def _load_relation_samples(kb_path, ent_vocab, rel_vocab, n_cls, pick_emb):
    """Build (subject - object) samples labelled by relation id from a triple file.

    Args:
        kb_path: Path to a text file with one ``sub<TAB>rel<TAB>obj`` triple
            per line.
        ent_vocab: Mapping from entity string to entity id (indexed with ``[]``).
        rel_vocab: Mapping from relation string to relation id; must support
            ``len()`` and ``[]``.
        n_cls: Number of relation ids to draw for the filter. When it is at
            least ``len(rel_vocab)`` every triple is kept.
        pick_emb: Callable mapping an entity id to its embedding vector.

    Returns:
        A ``(features, labels)`` pair of Python lists.
    """
    n_rels = len(rel_vocab)
    keep_all = n_cls >= n_rels
    # Draw the relation subset exactly once so the same filter applies to
    # every triple.  NOTE: randint samples with replacement, so the subset may
    # contain fewer than n_cls distinct relation ids.
    valid_rels = set(np.random.randint(0, n_rels, n_cls).tolist())

    features = []
    labels = []
    with open(kb_path) as f:
        for line in f:
            sub, rel, obj = line.strip().split('\t')
            sub_emb = pick_emb(ent_vocab[sub])
            obj_emb = pick_emb(ent_vocab[obj])
            rel_id = rel_vocab[rel]
            if keep_all or rel_id in valid_rels:
                features.append(sub_emb - obj_emb)
                labels.append(rel_id)
    return features, labels


def read_data(path, data_type, args):
    """Load embeddings from ``path`` and turn them into a labelled dataset.

    Three input layouts are handled:

    * ``'deepwalk'``: ``path`` is a word2vec text file; every vocabulary entry
      becomes one sample whose label is ``int(word)``.
    * ``'transe'``: ``path`` is a saved TransE model; every triple in
      ``args.kb`` becomes one sample ``emb(sub) - emb(obj)`` labelled with the
      relation id.
    * anything else: ``path`` is a word2vec text file keyed by ``str(entity
      id)``; triples from ``args.kb`` are converted as for ``'transe'``, and
      entities missing from the file get an all-zero vector.

    Args:
        path: Location of the embedding file or saved model.
        data_type: Selects the layout described above.
        args: Namespace; the triple-based layouts read ``args.ent``,
            ``args.rel``, ``args.kb`` and ``args.n_cls`` from it.

    Returns:
        ``(X, y)`` as numpy arrays of features and labels.
    """
    X = []
    y = []
    if data_type == 'deepwalk':
        word_vector = KeyedVectors.load_word2vec_format(path, binary=False)
        for word in word_vector.vocab:
            X.append(word_vector[word])
            y.append(int(word))
    # `elif`, not `if`: otherwise the deepwalk case would also fall into the
    # final `else` and mix relation samples into the node-label data.
    elif data_type == "transe":
        from models.transe import TransE as Model
        model = Model.load_model(path)
        ent_vocab = Vocab.load(args.ent)
        rel_vocab = Vocab.load(args.rel)
        X, y = _load_relation_samples(
            args.kb, ent_vocab, rel_vocab, args.n_cls, model.pick_ent)
    else:
        word_vector = KeyedVectors.load_word2vec_format(path, binary=False)
        ent_vocab = Vocab.load(args.ent)
        rel_vocab = Vocab.load(args.rel)
        # Take the dimensionality from the model itself instead of indexing
        # word_vector[0], which is an integer lookup on a string-keyed table.
        dim = word_vector.vector_size

        def pick_emb(eid):
            key = str(eid)
            if key in word_vector:
                return word_vector[key]
            # Entity has no trained embedding: fall back to a zero vector.
            return np.zeros(dim)

        X, y = _load_relation_samples(
            args.kb, ent_vocab, rel_vocab, args.n_cls, pick_emb)
    return np.array(X), np.array(y)
def test(args):
    """Evaluate a saved model on a test set and print the resulting metric(s).

    Reads ``args.ent``, ``args.rel``, ``args.task``, ``args.data``,
    ``args.method``, ``args.filtered``, ``args.graphall``, ``args.model`` and
    ``args.metric``.

    Raises:
        ValueError: if ``args.task`` is neither ``'kbc'`` nor ``'tc'``.
        NotImplementedError: if ``args.method`` is not a supported model name.
    """
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # Pick the dataset class matching the task, rejecting unknown tasks first.
    if args.task not in ('kbc', 'tc'):
        raise ValueError('Invalid task: {}'.format(args.task))
    dataset_cls = TripletDataset if args.task == 'kbc' else LabeledTripletDataset
    test_dat = dataset_cls.load(args.data, ent_vocab, rel_vocab)

    print('loading model...')
    method = args.method
    # Model classes are imported lazily so only the selected one is loaded.
    if method == 'transe':
        from models.transe import TransE as Model
    elif method == 'complex':
        from models.complex import ComplEx as Model
    elif method == 'analogy':
        from models.analogy import ANALOGY as Model
    else:
        raise NotImplementedError

    # The whole graph is only loaded when filtered evaluation is requested.
    graphall = None
    if args.filtered:
        print('loading whole graph...')
        from utils.graph import TensorTypeGraph
        graphall = TensorTypeGraph.load_from_raw(
            args.graphall, ent_vocab, rel_vocab)

    model = Model.load_model(args.model)
    row = '{:20s}: {}'

    if args.metric != 'all':
        # Single-metric run.
        evaluator = Evaluator(args.metric, None, False, True, None)
        res = evaluator.run(model, test_dat)
        print(row.format(args.metric, res))
        return

    # All-metrics run, reported in sorted metric-name order.
    evaluator = Evaluator('all', None, args.filtered, False, graphall)
    if args.filtered:
        evaluator.prepare_valid(test_dat)
    all_res = evaluator.run_all_matric(model, test_dat)
    for metric in sorted(all_res.keys()):
        print(row.format(metric, all_res[metric]))