Пример #1
0
def _sample_relations(n_rels, n_cls):
    """Return the set of relation ids to keep, or None to keep all of them.

    If ``n_cls`` covers the whole relation vocabulary no filtering is done.
    Otherwise ``n_cls`` distinct relation ids are drawn uniformly at random
    (without replacement, so exactly ``n_cls`` classes survive).
    The draw happens once, so every triple is filtered against the same set.
    """
    if n_cls >= n_rels:
        return None
    return set(np.random.choice(n_rels, n_cls, replace=False).tolist())


def _kb_pair_features(kb_path, ent_vocab, rel_vocab, n_cls, embed):
    """Build ``(embed(sub) - embed(obj), rel_id)`` pairs from a triple file.

    Args:
        kb_path: path to a file with one ``sub\\trel\\tobj`` triple per line.
            Blank lines are skipped.
        ent_vocab: mapping from entity name to entity id.
        rel_vocab: mapping from relation name to relation id; ``len()`` must
            give the number of relations.
        n_cls: number of relation classes to keep (see ``_sample_relations``).
        embed: callable mapping an entity id to its embedding vector.

    Returns:
        Two lists ``(X, y)`` of feature vectors and relation ids.
    """
    valid_rels = _sample_relations(len(rel_vocab), n_cls)
    X, y = [], []
    with open(kb_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            sub, rel, obj = line.split('\t')
            rel_id = rel_vocab[rel]
            if valid_rels is not None and rel_id not in valid_rels:
                # Skip before computing embeddings for triples we discard.
                continue
            X.append(embed(ent_vocab[sub]) - embed(ent_vocab[obj]))
            y.append(rel_id)
    return X, y


def read_data(path, data_type, args):
    """Load a classification dataset ``(X, y)`` from trained embeddings.

    Args:
        path: embedding file. For ``'transe'`` it is a saved TransE model;
            for every other ``data_type`` it is a text word2vec-format file.
        data_type: ``'deepwalk'``, ``'transe'``, or any other value, which
            is treated as generic word2vec-format entity vectors.
        args: namespace read only for the KB-based types. Uses
            ``args.ent`` / ``args.rel`` (vocab files), ``args.kb`` (triple
            file) and ``args.n_cls`` (number of relation classes to keep).

    Returns:
        ``(np.array(X), np.array(y)))``, where the parts depend on
        ``data_type``:
          * ``'deepwalk'``: rows are the stored vectors and labels are
            ``int(word)`` of their keys.
          * otherwise: rows are ``sub_emb - obj_emb`` for each KB triple and
            labels are relation ids.
    """
    if data_type == 'deepwalk':
        word_vector = KeyedVectors.load_word2vec_format(path, binary=False)
        X, y = [], []
        # NOTE(review): ``.vocab`` is the gensim < 4 API; gensim 4 renamed it.
        for word in word_vector.vocab:
            X.append(word_vector[word])
            y.append(int(word))
        return np.array(X), np.array(y)

    if data_type == 'transe':
        from models.transe import TransE as Model
        model = Model.load_model(path)
        embed = model.pick_ent
    else:
        word_vector = KeyedVectors.load_word2vec_format(path, binary=False)
        # Entities with no stored vector fall back to all zeros. Compute the
        # zero vector once; ``word_vector[0]`` is not a valid key lookup.
        zero = np.zeros(word_vector.vector_size)

        def embed(eid):
            key = str(eid)
            return word_vector[key] if key in word_vector else zero

    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)
    X, y = _kb_pair_features(args.kb, ent_vocab, rel_vocab, args.n_cls, embed)
    return np.array(X), np.array(y)
Пример #2
0
def test(args):
    """Evaluate a saved KB embedding model on a test set and print metrics.

    Args:
        args: namespace providing
            ``ent`` / ``rel``: vocab files;
            ``task``: ``'kbc'`` or ``'tc'``;
            ``data``: test data path;
            ``method``: ``'transe'``, ``'complex'`` or ``'analogy'``;
            ``model``: saved model path;
            ``metric``: ``'all'`` or a single metric name;
            ``filtered``: bool;
            ``graphall``: raw whole-graph file, used only when filtering.

    Raises:
        ValueError: unknown ``args.task``.
        NotImplementedError: unknown ``args.method``.
    """
    ent_vocab = Vocab.load(args.ent)
    rel_vocab = Vocab.load(args.rel)

    # preparing data
    if args.task == 'kbc':
        test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab)
    elif args.task == 'tc':
        test_dat = LabeledTripletDataset.load(args.data, ent_vocab, rel_vocab)
    else:
        raise ValueError('Invalid task: {}'.format(args.task))

    print('loading model...')
    if args.method == 'transe':
        from models.transe import TransE as Model
    elif args.method == 'complex':
        from models.complex import ComplEx as Model
    elif args.method == 'analogy':
        from models.analogy import ANALOGY as Model
    else:
        raise NotImplementedError('Invalid method: {}'.format(args.method))

    model = Model.load_model(args.model)

    if args.metric == 'all':
        # The whole graph is needed only for filtered evaluation. The
        # single-metric path below never used it, so it is loaded only here.
        if args.filtered:
            print('loading whole graph...')
            from utils.graph import TensorTypeGraph
            graphall = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab,
                                                     rel_vocab)
        else:
            graphall = None

        evaluator = Evaluator('all', None, args.filtered, False, graphall)
        if args.filtered:
            evaluator.prepare_valid(test_dat)

        all_res = evaluator.run_all_matric(model, test_dat)
        for metric in sorted(all_res.keys()):
            print('{:20s}: {}'.format(metric, all_res[metric]))
    else:
        # NOTE(review): ``args.filtered`` is ignored for single metrics (the
        # evaluator is built with filtering off and no graph), as before.
        evaluator = Evaluator(args.metric, None, False, True, None)
        res = evaluator.run(model, test_dat)
        print('{:20s}: {}'.format(args.metric, res))