def __init__(self, opt_path=None, *args, **kwargs):
    """Set up options, the backend object and (optionally) the training data.

    Args:
        opt_path: Option source handed to ``self.get_option``. When ``None``,
            ``W2VOption().get_default_option()`` is used instead.
        *args: Forwarded unchanged to every base-class ``__init__``.
        **kwargs: Forwarded unchanged to every base-class ``__init__``.
            Two keys are additionally read here:

            * ``data_opt`` -- overrides ``self.opt['data_opt']``; when the
              resulting value is truthy it is loaded with
              ``buffalo.data.load`` and ``create()`` is called on the result.
            * ``data`` -- a ``Data`` instance, used as-is only when no data
              option is in effect.

    Raises:
        AssertionError: If the backend object rejects the option file, or if
            the resolved data is not of type ``'stream'``.
    """
    Algo.__init__(self, *args, **kwargs)
    W2VOption.__init__(self, *args, **kwargs)
    Evaluable.__init__(self, *args, **kwargs)
    Serializable.__init__(self, *args, **kwargs)
    Optimizable.__init__(self, *args, **kwargs)

    if opt_path is None:
        opt_path = W2VOption().get_default_option()

    self.logger = log.get_logger('W2V')
    self.opt, self.opt_path = self.get_option(opt_path)
    self.obj = CyW2V()

    # The backend call has a side effect, so it must not live inside an
    # ``assert``: with ``python -O`` the whole statement (and therefore the
    # initialisation itself) would be skipped. AssertionError is raised
    # explicitly to keep the exception type callers already see.
    parsed = self.obj.init(bytes(self.opt_path, 'utf-8'))
    if not parsed:
        raise AssertionError('cannot parse option file: %s' % opt_path)

    self.data = None
    data = kwargs.get('data')
    # An explicit ``data_opt`` keyword wins over the one stored in the options.
    data_opt = kwargs.get('data_opt', self.opt.get('data_opt'))
    if data_opt:
        self.data = buffalo.data.load(data_opt)
        # Checked before create() so that create() is never run on
        # non-stream data.
        if self.data.data_type != 'stream':
            raise AssertionError(
                'W2V requires stream data, got: %s' % self.data.data_type)
        self.data.create()
    elif isinstance(data, Data):
        self.data = data

    self.logger.info('W2V(%s)' % json.dumps(self.opt, indent=2))
    if self.data:
        self.logger.info(self.data.show_info())
        # Also covers a ``Data`` instance passed in directly via ``data``.
        if self.data.data_type not in ['stream']:
            raise AssertionError(
                'W2V requires stream data, got: %s' % self.data.data_type)

    # Vocabulary bookkeeping; starts empty here.
    self._vocab = aux.Option({
        'size': 0,
        'index': None,
        'inv_index': None,
        'scale': None,
        'dist': None,
        'total_word_count': 0,
    })
def test1_is_valid_option(self):
    """Exercise ``W2VOption.is_valid_option`` on default and altered options.

    The default option must validate, setting ``save_best`` to ``1`` is
    expected to raise ``RuntimeError``, and setting it to ``False`` must
    validate again.
    """
    options = W2VOption().get_default_option()
    self.assertTrue(W2VOption().is_valid_option(options))

    options['save_best'] = 1
    # Build the validator outside the context manager so only the
    # validation call itself is expected to raise.
    validator = W2VOption()
    with self.assertRaises(RuntimeError):
        validator.is_valid_option(options)

    options['save_best'] = False
    self.assertTrue(W2VOption().is_valid_option(options))
def load_text8_model(self):
    """Return a W2V model for text8, loading it from disk when available.

    If ``text8.w2v.bin`` exists in the working directory, a fresh ``W2V`` is
    created and the file is loaded into it. Otherwise a model is configured
    (12 workers, d=40, min_count=4, 10 iterations), initialised, trained on
    the stream data rooted at ``self.text8`` and saved before being returned.
    """
    model_file = 'text8.w2v.bin'
    if os.path.isfile(model_file):
        cached_model = W2V()
        cached_model.load(model_file)
        return cached_model

    # No saved model: configure and train one.
    set_log_level(3)
    w2v_opt = W2VOption().get_default_option()
    w2v_opt.num_workers = 12
    w2v_opt.d = 40
    w2v_opt.min_count = 4
    w2v_opt.num_iters = 10
    w2v_opt.model_path = model_file

    stream_opt = StreamOptions().get_default_option()
    stream_opt.input.main = self.text8 + 'main'
    stream_opt.data.path = './text8.h5py'
    stream_opt.data.use_cache = True
    stream_opt.data.validation = {}

    trained_model = W2V(w2v_opt, data_opt=stream_opt)
    trained_model.initialize()
    trained_model.train()
    trained_model.save()
    return trained_model
def test5_text8_accuracy(self):
    """Check top-1 analogy accuracy on the 'capital-common-countries' set.

    A W2V model (12 workers, d=200, 15 iterations, min_count=4) is loaded
    from ``text8.accuracy.w2v.bin`` when that file exists, otherwise it is
    initialised and trained on the text8 stream data. Every question
    ``a b c answer`` of the target category is scored by querying the
    nearest neighbours of ``b + c - a``; the first neighbour that is not one
    of ``a``, ``b``, ``c`` counts as a hit when it equals ``answer`` and as
    a miss otherwise. Questions containing an out-of-vocabulary word are
    skipped. The test passes when hit / total exceeds 0.7.
    """
    set_log_level(2)
    opt = W2VOption().get_default_option()
    opt.num_workers = 12
    opt.d = 200
    opt.num_iters = 15
    opt.min_count = 4

    data_opt = StreamOptions().get_default_option()
    data_opt.input.main = self.text8 + 'main'
    data_opt.data.path = './text8.h5py'
    data_opt.data.use_cache = True
    data_opt.data.validation = {}

    # NOTE(review): nothing in this test writes model_path, so the load
    # branch is only taken if the file was produced elsewhere -- confirm.
    model_path = 'text8.accuracy.w2v.bin'
    w = W2V(opt, data_opt=data_opt)
    if os.path.isfile(model_path):
        w.load(model_path)
    else:
        w.initialize()
        w.train()
        w.build_itemid_map()

    with open('./ext/text8/questions-words.txt') as fin:
        questions = fin.read().strip().split('\n')

    met = {}
    target_class = ['capital-common-countries']
    class_name = None
    for line in questions:
        if not line:
            continue
        if line.startswith(':'):
            # Section header of the form ": <class name>".
            _, class_name = line.split(' ', 1)
            if class_name in target_class and class_name not in met:
                met[class_name] = {'hit': 0, 'miss': 0, 'total': 0}
            continue
        if class_name not in target_class:
            continue

        a, b, c, answer = line.lower().strip().split()
        # Skip questions with any out-of-vocabulary word; the generator
        # stops at the first missing feature.
        if any(w.get_feature(t) is None for t in (a, b, c, answer)):
            continue

        topk = w.most_similar(w.get_weighted_feature({b: 1, c: 1, a: -1}))
        for nn, _ in topk:
            if nn in (a, b, c):
                continue
            if nn == answer:
                met[class_name]['hit'] += 1
            else:
                met[class_name]['miss'] += 1
            break  # top-1
        met[class_name]['total'] += 1

    # Fail with a readable message instead of KeyError / ZeroDivisionError
    # when the category is missing or every question was skipped.
    self.assertIn('capital-common-countries', met,
                  'category not found in questions-words.txt')
    stat = met['capital-common-countries']
    self.assertGreater(stat['total'], 0,
                       'no in-vocabulary questions were evaluated')

    acc = float(stat['hit']) / stat['total']
    print('Top1-Accuracy={:0.3f}'.format(acc))
    self.assertTrue(acc > 0.7)
def test2_init_with_dict(self):
    """Smoke test: ``W2V`` can be constructed from the default option object."""
    set_log_level(3)
    default_opt = W2VOption().get_default_option()
    # Construction alone is the check; any exception fails the test.
    W2V(default_opt)
    self.assertTrue(True)
def test0_get_default_option(self):
    """Smoke test: fetching the default W2V option must not raise."""
    option_source = W2VOption()
    option_source.get_default_option()
    self.assertTrue(True)