Exemplo n.º 1
0
def train(lf):
    """Train *lf* on the full training split, evaluating on dev and test.

    Relies on the module-level ``args`` namespace and ``data_utils`` helpers.
    """
    # Resolve all dataset file locations (pure path construction, no I/O yet).
    train_path = data_utils.get_train_path(args)
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    dev_path = os.path.join(args.data_dir, 'dev.triples')
    test_path = os.path.join(args.data_dir, 'test.triples')

    train_data = data_utils.load_triples(
        train_path,
        entity_index_path,
        relation_index_path,
        group_examples_by_query=args.group_examples_by_query,
        add_reverse_relations=args.add_reversed_training_edges)

    # NELL-style datasets restrict dev/test loading to entities present in
    # the precomputed adjacency list; other datasets use no restriction.
    if 'NELL' in args.data_dir:
        seen_entities = data_utils.load_seen_entities(
            os.path.join(args.data_dir, 'adj_list.pkl'), entity_index_path)
    else:
        seen_entities = set()

    dev_data = data_utils.load_triples(
        dev_path, entity_index_path, relation_index_path,
        seen_entities=seen_entities)
    test_data = data_utils.load_triples(
        test_path, entity_index_path, relation_index_path,
        seen_entities=seen_entities)

    # Optionally warm-start from a previously saved checkpoint.
    if args.checkpoint_path is not None:
        lf.load_checkpoint(args.checkpoint_path)
    lf.run_train(train_data, dev_data, test_data)
Exemplo n.º 2
0
def train(lf):
    """Train *lf*, optionally under few-shot or per-relation adaptation.

    Reads the module-level ``args`` namespace and dispatches to one of
    three regimes: adaptation (one run per few-shot relation), few-shot
    (train on the normal partition), or standard training.
    """
    # Both few-shot and adaptation modes use the (normal, few) data split.
    use_few_split = args.few_shot or args.adaptation

    train_path = data_utils.get_train_path(args)
    dev_path = os.path.join(args.data_dir, 'dev.triples')
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')

    if use_few_split:
        # In few-shot mode load_triples returns two partitions.
        normal_train_data, few_train_data = data_utils.load_triples(
            train_path,
            entity_index_path,
            relation_index_path,
            group_examples_by_query=args.group_examples_by_query,
            add_reverse_relations=args.add_reversed_training_edges,
            few_shot=True,
            lf=lf)
    else:
        train_data = data_utils.load_triples(
            train_path,
            entity_index_path,
            relation_index_path,
            group_examples_by_query=args.group_examples_by_query,
            add_reverse_relations=args.add_reversed_training_edges)

    # NELL-style datasets restrict dev loading to entities present in the
    # precomputed adjacency list.
    if 'NELL' in args.data_dir:
        seen_entities = data_utils.load_seen_entities(
            os.path.join(args.data_dir, 'adj_list.pkl'), entity_index_path)
    else:
        seen_entities = set()

    if use_few_split:
        normal_dev_data, few_dev_data = data_utils.load_triples(
            dev_path,
            entity_index_path,
            relation_index_path,
            seen_entities=seen_entities,
            few_shot=True,
            lf=lf)
    else:
        dev_data = data_utils.load_triples(
            dev_path,
            entity_index_path,
            relation_index_path,
            seen_entities=seen_entities)

    if args.adaptation:
        # One adaptation run per few-shot relation; reload the checkpoint
        # before each run so runs do not contaminate one another.
        for relation in few_train_data:
            if args.checkpoint_path is not None:
                lf.load_checkpoint(args.checkpoint_path, adaptation=True)
            lf.run_train(few_train_data[relation],
                         few_dev_data,
                         adaptation=True,
                         adaptation_relation=relation)
    elif args.few_shot:
        if args.checkpoint_path is not None:
            lf.load_checkpoint(args.checkpoint_path)
        lf.run_train(normal_train_data, normal_dev_data, few_shot=True)
    else:
        if args.checkpoint_path is not None:
            # emb_few presumably restricts which weights are restored —
            # TODO confirm against load_checkpoint's implementation.
            if args.emb_few:
                lf.load_checkpoint(args.checkpoint_path, emb_few=True)
            else:
                lf.load_checkpoint(args.checkpoint_path)
        lf.run_train(train_data, dev_data)
Exemplo n.º 3
0
def compute_fact_scores(lf):
    """Score the train/dev/test triples as facts and print the mean scores."""
    data_dir = args.data_dir
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')

    def _load_split(split):
        # All three splits share the same entity/relation index files.
        return data_utils.load_triples(
            os.path.join(data_dir, '{}.triples'.format(split)),
            entity_index_path, relation_index_path)

    train_data = _load_split('train')
    dev_data = _load_split('dev')
    test_data = _load_split('test')

    lf.eval()
    lf.load_checkpoint(get_checkpoint_path(args))
    train_scores = lf.forward_fact(train_data)
    dev_scores = lf.forward_fact(dev_data)
    test_scores = lf.forward_fact(test_data)

    print('Train set average fact score: {}'.format(float(train_scores.mean())))
    print('Dev set average fact score: {}'.format(float(dev_scores.mean())))
    print('Test set average fact score: {}'.format(float(test_scores.mean())))
Exemplo n.º 4
0
def export_error_cases(lf):
    """Export dev-set prediction error cases of *lf* to ``error_cases.pkl``.

    Loads the model checkpoint, scores the dev split, prints ranking
    metrics, and writes the error cases into the model directory.
    """
    # Load the trained weights once. (Fix: the original loaded the same
    # checkpoint a second time after building dev_data — redundant work
    # with no effect on the result.)
    lf.load_checkpoint(get_checkpoint_path(args))
    lf.batch_size = args.dev_batch_size
    lf.eval()
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    dev_path = os.path.join(args.data_dir, 'dev.triples')
    dev_data = data_utils.load_triples(dev_path, entity_index_path, relation_index_path)
    print('Dev set performance:')
    pred_scores = lf.forward(dev_data, verbose=False)
    src.eval.hits_and_ranks(dev_data, pred_scores, lf.kg.dev_objects, verbose=True)
    src.eval.export_error_cases(
        dev_data, pred_scores, lf.kg.dev_objects,
        os.path.join(lf.model_dir, 'error_cases.pkl'))
Exemplo n.º 5
0
def inference(lf):
    """Evaluate *lf* on length-bucketed test files and return metrics.

    Reads the module-level ``args``. When ``args.eval_by_length_type`` is
    false, no evaluation runs and the returned dict stays empty.
    """
    lf.batch_size = args.dev_batch_size
    lf.eval()
    lf.load_checkpoint(get_checkpoint_path(args))
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    seen_entities = set()

    eval_metrics = {
        'dev': {},
        'test': {}
    }
    if args.eval_by_length_type:

        print('Test set performance:')
        # One test file per bucket under <data_dir>/test/.
        lenlist=['triple4.txt','triple5.txt','triple6.txt','triple7.txt']
        test_path_list=[]
        for i in range(len(lenlist)):
            str1=os.path.join(args.data_dir, 'test/'+lenlist[i])
            test_path_list.append(str1)
        for j in range(len(test_path_list)):
            test_data = data_utils.load_triples(
                test_path_list[j], entity_index_path, relation_index_path, seen_entities=seen_entities, verbose=False)
            pred_scores = lf.forward(test_data, verbose=False)
            test_metrics = src.eval.hits_and_ranks(test_data, pred_scores, lf.kg.all_objects, verbose=True)
            # NOTE(review): these keys are overwritten on every iteration, so
            # the returned dict only holds the metrics of the LAST bucket; the
            # per-bucket numbers survive only in the log file written below.
            # Confirm whether callers expect per-bucket results.
            eval_metrics['test']['hits_at_1'] = test_metrics[0]
            eval_metrics['test']['hits_at_3'] = test_metrics[1]
            eval_metrics['test']['hits_at_5'] = test_metrics[2]
            eval_metrics['test']['hits_at_10'] = test_metrics[3]
            eval_metrics['test']['mrr'] = test_metrics[4]
            print(str(j)+' done!!!')
            # Append this bucket's raw metric values to '<log_file>eval'.
            with open(args.log_file+'eval','a') as logf:
                for value in test_metrics:
                    logf.write(str(value)+'\n')

    return eval_metrics
Exemplo n.º 6
0
def run_ablation_studies(args):
    """
    Run the ablation study experiments reported in the paper.

    Evaluates three systems on the dev split — 'ours' (full model), '-ad'
    (action dropout disabled) and '-rs' (baseline config reloaded) — then
    prints comparison tables for partial- and full-graph evaluation,
    broken down by relation type and by seen/unseen queries.
    """
    def set_up_lf_for_inference(args):
        # Rebuild the model for each system so hyperparameter changes in
        # `args` take effect, then load its checkpoint for evaluation.
        initialize_model_directory(args)
        lf = construct_model(args)
        lf.cuda()
        lf.batch_size = args.dev_batch_size
        lf.load_checkpoint(get_checkpoint_path(args))
        lf.eval()
        return lf

    def rel_change(metrics, ab_system, kg_portion):
        # Relative change (%) of an ablated system w.r.t. the full model.
        ab_system_metrics = metrics[ab_system][kg_portion]
        base_metrics = metrics['ours'][kg_portion]
        return int(np.round((ab_system_metrics - base_metrics) / base_metrics * 100))

    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    # NELL-style datasets restrict dev loading to entities from the
    # precomputed adjacency list.
    if 'NELL' in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, 'adj_list.pkl')
        seen_entities = data_utils.load_seen_entities(adj_list_path, entity_index_path)
    else:
        seen_entities = set()
    dataset = os.path.basename(args.data_dir)
    dev_path = os.path.join(args.data_dir, 'dev.triples')
    dev_data = data_utils.load_triples(
        dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities, verbose=False)
    # Relation-type and seen-query partitions; ratios converted to percent.
    to_m_rels, to_1_rels, (to_m_ratio, to_1_ratio) = data_utils.get_relations_by_type(args.data_dir, relation_index_path)
    relation_by_types = (to_m_rels, to_1_rels)
    to_m_ratio *= 100
    to_1_ratio *= 100
    seen_queries, (seen_ratio, unseen_ratio) = data_utils.get_seen_queries(args.data_dir, entity_index_path, relation_index_path)
    seen_ratio *= 100
    unseen_ratio *= 100

    systems = ['ours', '-ad', '-rs']
    mrrs, to_m_mrrs, to_1_mrrs, seen_mrrs, unseen_mrrs = {}, {}, {}, {}, {}
    for system in systems:
        print('** Evaluating {} system **'.format(system))
        if system == '-ad':
            # Disable action dropout for the -AD ablation.
            args.action_dropout_rate = 0.0
            if dataset == 'umls':
                # adjust dropout hyperparameters
                args.emb_dropout_rate = 0.3
                args.ff_dropout_rate = 0.1
        elif system == '-rs':
            # NOTE(review): re-parses args from the module-level `parser`
            # (discarding the -AD tweaks above) and applies the baseline
            # config for this dataset — relies on that global being set up.
            config_path = os.path.join('configs', '{}.sh'.format(dataset.lower()))
            args = parser.parse_args()
            args = data_utils.load_configs(args, config_path)

        lf = set_up_lf_for_inference(args)
        pred_scores = lf.forward(dev_data, verbose=False)
        # Partial-graph evaluation (dev objects only).
        _, _, _, _, mrr = src.eval.hits_and_ranks(dev_data, pred_scores, lf.kg.dev_objects, verbose=True)
        if to_1_ratio == 0:
            # No to-one relations in this dataset: -1 marks "undefined".
            to_m_mrr = mrr
            to_1_mrr = -1
        else:
            to_m_mrr, to_1_mrr = src.eval.hits_and_ranks_by_relation_type(
                dev_data, pred_scores, lf.kg.dev_objects, relation_by_types, verbose=True)
        seen_mrr, unseen_mrr = src.eval.hits_and_ranks_by_seen_queries(
            dev_data, pred_scores, lf.kg.dev_objects, seen_queries, verbose=True)
        # Store as percentages; key '' = partial graph, 'full_kg' = full graph.
        mrrs[system] = {'': mrr * 100}
        to_m_mrrs[system] = {'': to_m_mrr * 100}
        to_1_mrrs[system] = {'': to_1_mrr  * 100}
        seen_mrrs[system] = {'': seen_mrr * 100}
        unseen_mrrs[system] = {'': unseen_mrr * 100}
        # Full-graph evaluation (all objects).
        _, _, _, _, mrr_full_kg = src.eval.hits_and_ranks(dev_data, pred_scores, lf.kg.all_objects, verbose=True)
        if to_1_ratio == 0:
            to_m_mrr_full_kg = mrr_full_kg
            to_1_mrr_full_kg = -1
        else:
            to_m_mrr_full_kg, to_1_mrr_full_kg = src.eval.hits_and_ranks_by_relation_type(
                dev_data, pred_scores, lf.kg.all_objects, relation_by_types, verbose=True)
        seen_mrr_full_kg, unseen_mrr_full_kg = src.eval.hits_and_ranks_by_seen_queries(
            dev_data, pred_scores, lf.kg.all_objects, seen_queries, verbose=True)
        mrrs[system]['full_kg'] = mrr_full_kg * 100
        to_m_mrrs[system]['full_kg'] = to_m_mrr_full_kg * 100
        to_1_mrrs[system]['full_kg'] = to_1_mrr_full_kg * 100
        seen_mrrs[system]['full_kg'] = seen_mrr_full_kg * 100
        unseen_mrrs[system]['full_kg'] = unseen_mrr_full_kg * 100

    # overall system comparison (table 3)
    print('Partial graph evaluation')
    print('--------------------------')
    print('Overall system performance')
    print('Ours(ConvE)\t-RS\t-AD')
    print('{:.1f}\t{:.1f}\t{:.1f}'.format(mrrs['ours'][''], mrrs['-rs'][''], mrrs['-ad']['']))
    print('--------------------------')
    # performance w.r.t. relation types (table 4, 6)
    print('Performance w.r.t. relation types')
    print('\tTo-many\t\t\t\tTo-one\t\t')
    print('%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD')
    print('{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})'.format(
        to_m_ratio, to_m_mrrs['ours'][''], to_m_mrrs['-rs'][''], rel_change(to_m_mrrs, '-rs', ''), to_m_mrrs['-ad'][''], rel_change(to_m_mrrs, '-ad', ''),
        to_1_ratio, to_1_mrrs['ours'][''], to_1_mrrs['-rs'][''], rel_change(to_1_mrrs, '-rs', ''), to_1_mrrs['-ad'][''], rel_change(to_1_mrrs, '-ad', '')))
    print('--------------------------')
    # performance w.r.t. seen queries (table 5, 7)
    print('Performance w.r.t. seen/unseen queries')
    print('\tSeen\t\t\t\tUnseen\t\t')
    print('%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD')
    print('{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})'.format(
        seen_ratio, seen_mrrs['ours'][''], seen_mrrs['-rs'][''], rel_change(seen_mrrs, '-rs', ''), seen_mrrs['-ad'][''], rel_change(seen_mrrs, '-ad', ''),
        unseen_ratio, unseen_mrrs['ours'][''], unseen_mrrs['-rs'][''], rel_change(unseen_mrrs, '-rs', ''), unseen_mrrs['-ad'][''], rel_change(unseen_mrrs, '-ad', '')))
    print()
    print('Full graph evaluation')
    print('--------------------------')
    print('Overall system performance')
    print('Ours(ConvE)\t-RS\t-AD')
    print('{:.1f}\t{:.1f}\t{:.1f}'.format(mrrs['ours']['full_kg'], mrrs['-rs']['full_kg'], mrrs['-ad']['full_kg']))
    print('--------------------------')
    print('Performance w.r.t. relation types')
    print('\tTo-many\t\t\t\tTo-one\t\t')
    print('%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD')
    print('{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})'.format(
        to_m_ratio, to_m_mrrs['ours']['full_kg'], to_m_mrrs['-rs']['full_kg'], rel_change(to_m_mrrs, '-rs', 'full_kg'), to_m_mrrs['-ad']['full_kg'], rel_change(to_m_mrrs, '-ad', 'full_kg'),
        to_1_ratio, to_1_mrrs['ours']['full_kg'], to_1_mrrs['-rs']['full_kg'], rel_change(to_1_mrrs, '-rs', 'full_kg'), to_1_mrrs['-ad']['full_kg'], rel_change(to_1_mrrs, '-ad', 'full_kg')))
    print('--------------------------')
    print('Performance w.r.t. seen/unseen queries')
    print('\tSeen\t\t\t\tUnseen\t\t')
    print('%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD')
    print('{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})'.format(
        seen_ratio, seen_mrrs['ours']['full_kg'], seen_mrrs['-rs']['full_kg'], rel_change(seen_mrrs, '-rs', 'full_kg'), seen_mrrs['-ad']['full_kg'], rel_change(seen_mrrs, '-ad', 'full_kg'),
        unseen_ratio, unseen_mrrs['ours']['full_kg'], unseen_mrrs['-rs']['full_kg'], rel_change(unseen_mrrs, '-rs', 'full_kg'), unseen_mrrs['-ad']['full_kg'], rel_change(unseen_mrrs, '-ad', 'full_kg')))
Exemplo n.º 7
0
def inference(lf):
    """Run inference with *lf* and return a dict of evaluation metrics.

    Dispatches on the module-level ``args``:
      * ``args.compute_map``           — per-relation MAP on NELL-style tasks;
      * ``args.eval_by_relation_type`` — dev metrics split by to-many/to-one;
      * ``args.eval_by_seen_queries``  — dev metrics split by seen/unseen;
      * otherwise                      — standard dev/test hits@k and MRR.

    Returns a dict ``{'dev': {...}, 'test': {...}}``; branches that only
    print leave the corresponding sub-dict empty.
    """
    lf.batch_size = args.dev_batch_size
    lf.eval()
    # Ensemble models restore pretrained KG embeddings instead of a
    # framework checkpoint.
    if args.model == 'hypere':
        conve_kg_state_dict = get_conve_kg_state_dict(torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        secondary_kg_state_dict = get_complex_kg_state_dict(torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(secondary_kg_state_dict)
    elif args.model == 'triplee':
        conve_kg_state_dict = get_conve_kg_state_dict(torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        complex_kg_state_dict = get_complex_kg_state_dict(torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(complex_kg_state_dict)
        distmult_kg_state_dict = get_distmult_kg_state_dict(torch.load(args.distmult_state_dict_path))
        lf.tertiary_kg.load_state_dict(distmult_kg_state_dict)
    else:
        lf.load_checkpoint(get_checkpoint_path(args))
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    if 'NELL' in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, 'adj_list.pkl')
        seen_entities = data_utils.load_seen_entities(adj_list_path, entity_index_path)
    else:
        seen_entities = set()

    eval_metrics = {
        'dev': {},
        'test': {}
    }

    if args.compute_map:
        relation_sets = [
            'concept:athletehomestadium',
            'concept:athleteplaysforteam',
            'concept:athleteplaysinleague',
            'concept:athleteplayssport',
            'concept:organizationheadquarteredincity',
            'concept:organizationhiredperson',
            'concept:personborninlocation',
            'concept:teamplayssport',
            'concept:worksfor'
        ]
        mps = []
        for r in relation_sets:
            print('* relation: {}'.format(r))
            test_path = os.path.join(args.data_dir, 'tasks', r, 'test.pairs')
            test_data, labels = data_utils.load_triples_with_label(
                test_path, r, entity_index_path, relation_index_path, seen_entities=seen_entities)
            pred_scores = lf.forward(test_data, verbose=False)
            mp = src.eval.link_MAP(test_data, pred_scores, labels, lf.kg.all_objects, verbose=True)
            mps.append(mp)
        map_ = np.mean(mps)
        print('Overall MAP = {}'.format(map_))
        # Bug fix: previously stored the builtin `map` function here
        # instead of the computed mean `map_`.
        eval_metrics['test']['avg_map'] = map_
    elif args.eval_by_relation_type:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        to_m_rels, to_1_rels, _ = data_utils.get_relations_by_type(args.data_dir, relation_index_path)
        relation_by_types = (to_m_rels, to_1_rels)
        print('Dev set evaluation by relation type (partial graph)')
        src.eval.hits_and_ranks_by_relation_type(
            dev_data, pred_scores, lf.kg.dev_objects, relation_by_types, verbose=True)
        print('Dev set evaluation by relation type (full graph)')
        src.eval.hits_and_ranks_by_relation_type(
            dev_data, pred_scores, lf.kg.all_objects, relation_by_types, verbose=True)
    elif args.eval_by_seen_queries:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        seen_queries = data_utils.get_seen_queries(args.data_dir, entity_index_path, relation_index_path)
        print('Dev set evaluation by seen queries (partial graph)')
        src.eval.hits_and_ranks_by_seen_queries(
            dev_data, pred_scores, lf.kg.dev_objects, seen_queries, verbose=True)
        print('Dev set evaluation by seen queries (full graph)')
        src.eval.hits_and_ranks_by_seen_queries(
            dev_data, pred_scores, lf.kg.all_objects, seen_queries, verbose=True)
    else:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        test_path = os.path.join(args.data_dir, 'test.triples')
        dev_data = data_utils.load_triples(
            dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities, verbose=False)
        test_data = data_utils.load_triples(
            test_path, entity_index_path, relation_index_path, seen_entities=seen_entities, verbose=False)
        print('Dev set performance:')
        pred_scores = lf.forward(dev_data, verbose=args.save_beam_search_paths)
        dev_metrics = src.eval.hits_and_ranks(dev_data, pred_scores, lf.kg.dev_objects, verbose=True)
        eval_metrics['dev'] = {}
        eval_metrics['dev']['hits_at_1'] = dev_metrics[0]
        eval_metrics['dev']['hits_at_3'] = dev_metrics[1]
        eval_metrics['dev']['hits_at_5'] = dev_metrics[2]
        eval_metrics['dev']['hits_at_10'] = dev_metrics[3]
        eval_metrics['dev']['mrr'] = dev_metrics[4]
        # Also report full-graph dev numbers (printed only, not returned).
        src.eval.hits_and_ranks(dev_data, pred_scores, lf.kg.all_objects, verbose=True)
        print('Test set performance:')
        pred_scores = lf.forward(test_data, verbose=False)
        test_metrics = src.eval.hits_and_ranks(test_data, pred_scores, lf.kg.all_objects, verbose=True)
        eval_metrics['test']['hits_at_1'] = test_metrics[0]
        eval_metrics['test']['hits_at_3'] = test_metrics[1]
        eval_metrics['test']['hits_at_5'] = test_metrics[2]
        eval_metrics['test']['hits_at_10'] = test_metrics[3]
        eval_metrics['test']['mrr'] = test_metrics[4]

    return eval_metrics
Exemplo n.º 8
0
if __name__ == '__main__':

    # Script entry point: build config, construct the model, load the
    # training/dev data, and train.
    args = Args()

    with torch.enable_grad():
        initialize_model_directory(args)
        lf: LFramework = construct_model(args)
        to_cuda(lf)

        train_path = data_utils.get_train_path(args)
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
        relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
        train_data = data_utils.load_triples(
            train_path,
            entity_index_path,
            relation_index_path,
            group_examples_by_query=args.group_examples_by_query,
            add_reverse_relations=args.add_reversed_training_edges)

        # No seen-entity filtering here: dev triples are loaded unrestricted.
        seen_entities = set()
        dev_data = data_utils.load_triples(dev_path,
                                           entity_index_path,
                                           relation_index_path,
                                           seen_entities=seen_entities)
        # Optionally warm-start from a previously saved checkpoint.
        if args.checkpoint_path is not None:
            lf.load_checkpoint(args.checkpoint_path)
        lf.run_train(train_data, dev_data)
'''

Epoch 18: average training loss = 0.09945450909435749
=> saving checkpoint to './model/umls-conve-RV-xavier-200-200-0.003-32-3-0.3-0.3-0.2-0.1/checkpoint-18.tar'
def run_ablation_studies(args):
    """
    Run the ablation study experiments reported in the paper.

    Evaluates three systems on the dev split — "ours" (full model), "-ad"
    (action dropout disabled) and "-rs" (baseline config reloaded) — then
    prints comparison tables for partial- and full-graph evaluation,
    broken down by relation type and by seen/unseen queries.
    """
    def set_up_lf_for_inference(args):
        # Rebuild the model for each system so hyperparameter changes in
        # `args` take effect, then load its checkpoint for evaluation.
        initialize_model_directory(args)
        lf = construct_model(args)
        lf.cuda()
        lf.batch_size = args.dev_batch_size
        lf.load_checkpoint(get_checkpoint_path(args))
        lf.eval()
        return lf

    def rel_change(metrics, ab_system, kg_portion):
        # Relative change (%) of an ablated system w.r.t. the full model.
        ab_system_metrics = metrics[ab_system][kg_portion]
        base_metrics = metrics["ours"][kg_portion]
        return int(
            np.round((ab_system_metrics - base_metrics) / base_metrics * 100))

    entity_index_path = os.path.join(args.data_dir, "entity2id.txt")
    relation_index_path = os.path.join(args.data_dir, "relation2id.txt")
    # NELL-style datasets restrict dev loading to entities from the
    # precomputed adjacency list.
    if "NELL" in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, "adj_list.pkl")
        seen_entities = data_utils.load_seen_entities(adj_list_path,
                                                      entity_index_path)
    else:
        seen_entities = set()
    dataset = os.path.basename(args.data_dir)
    dev_path = os.path.join(args.data_dir, "dev.triples")
    dev_data = data_utils.load_triples(
        dev_path,
        entity_index_path,
        relation_index_path,
        seen_entities=seen_entities,
        verbose=False,
    )
    # Relation-type and seen-query partitions; ratios converted to percent.
    to_m_rels, to_1_rels, (to_m_ratio,
                           to_1_ratio) = data_utils.get_relations_by_type(
                               args.data_dir, relation_index_path)
    relation_by_types = (to_m_rels, to_1_rels)
    to_m_ratio *= 100
    to_1_ratio *= 100
    seen_queries, (seen_ratio, unseen_ratio) = data_utils.get_seen_queries(
        args.data_dir, entity_index_path, relation_index_path)
    seen_ratio *= 100
    unseen_ratio *= 100

    systems = ["ours", "-ad", "-rs"]
    mrrs, to_m_mrrs, to_1_mrrs, seen_mrrs, unseen_mrrs = {}, {}, {}, {}, {}
    for system in systems:
        print("** Evaluating {} system **".format(system))
        if system == "-ad":
            # Disable action dropout for the -AD ablation.
            args.action_dropout_rate = 0.0
            if dataset == "umls":
                # adjust dropout hyperparameters
                args.emb_dropout_rate = 0.3
                args.ff_dropout_rate = 0.1
        elif system == "-rs":
            # NOTE(review): re-parses args from the module-level `parser`
            # (discarding the -AD tweaks above) and applies the baseline
            # config for this dataset — relies on that global being set up.
            config_path = os.path.join("configs",
                                       "{}.sh".format(dataset.lower()))
            args = parser.parse_args()
            args = data_utils.load_configs(args, config_path)

        lf = set_up_lf_for_inference(args)
        pred_scores = lf.forward(dev_data, verbose=False)
        # Partial-graph evaluation (dev objects only).
        _, _, _, _, mrr = src.eval.hits_and_ranks(dev_data,
                                                  pred_scores,
                                                  lf.kg.dev_objects,
                                                  verbose=True)
        if to_1_ratio == 0:
            # No to-one relations in this dataset: -1 marks "undefined".
            to_m_mrr = mrr
            to_1_mrr = -1
        else:
            to_m_mrr, to_1_mrr = src.eval.hits_and_ranks_by_relation_type(
                dev_data,
                pred_scores,
                lf.kg.dev_objects,
                relation_by_types,
                verbose=True,
            )
        seen_mrr, unseen_mrr = src.eval.hits_and_ranks_by_seen_queries(
            dev_data,
            pred_scores,
            lf.kg.dev_objects,
            seen_queries,
            verbose=True)
        # Store as percentages; key "" = partial graph, "full_kg" = full graph.
        mrrs[system] = {"": mrr * 100}
        to_m_mrrs[system] = {"": to_m_mrr * 100}
        to_1_mrrs[system] = {"": to_1_mrr * 100}
        seen_mrrs[system] = {"": seen_mrr * 100}
        unseen_mrrs[system] = {"": unseen_mrr * 100}
        # Full-graph evaluation (all objects).
        _, _, _, _, mrr_full_kg = src.eval.hits_and_ranks(dev_data,
                                                          pred_scores,
                                                          lf.kg.all_objects,
                                                          verbose=True)
        if to_1_ratio == 0:
            to_m_mrr_full_kg = mrr_full_kg
            to_1_mrr_full_kg = -1
        else:
            (
                to_m_mrr_full_kg,
                to_1_mrr_full_kg,
            ) = src.eval.hits_and_ranks_by_relation_type(
                dev_data,
                pred_scores,
                lf.kg.all_objects,
                relation_by_types,
                verbose=True,
            )
        seen_mrr_full_kg, unseen_mrr_full_kg = src.eval.hits_and_ranks_by_seen_queries(
            dev_data,
            pred_scores,
            lf.kg.all_objects,
            seen_queries,
            verbose=True)
        mrrs[system]["full_kg"] = mrr_full_kg * 100
        to_m_mrrs[system]["full_kg"] = to_m_mrr_full_kg * 100
        to_1_mrrs[system]["full_kg"] = to_1_mrr_full_kg * 100
        seen_mrrs[system]["full_kg"] = seen_mrr_full_kg * 100
        unseen_mrrs[system]["full_kg"] = unseen_mrr_full_kg * 100

    # overall system comparison (table 3)
    print("Partial graph evaluation")
    print("--------------------------")
    print("Overall system performance")
    print("Ours(ConvE)\t-RS\t-AD")
    print("{:.1f}\t{:.1f}\t{:.1f}".format(mrrs["ours"][""], mrrs["-rs"][""],
                                          mrrs["-ad"][""]))
    print("--------------------------")
    # performance w.r.t. relation types (table 4, 6)
    print("Performance w.r.t. relation types")
    print("\tTo-many\t\t\t\tTo-one\t\t")
    print("%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD")
    print(
        "{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})"
        .format(
            to_m_ratio,
            to_m_mrrs["ours"][""],
            to_m_mrrs["-rs"][""],
            rel_change(to_m_mrrs, "-rs", ""),
            to_m_mrrs["-ad"][""],
            rel_change(to_m_mrrs, "-ad", ""),
            to_1_ratio,
            to_1_mrrs["ours"][""],
            to_1_mrrs["-rs"][""],
            rel_change(to_1_mrrs, "-rs", ""),
            to_1_mrrs["-ad"][""],
            rel_change(to_1_mrrs, "-ad", ""),
        ))
    print("--------------------------")
    # performance w.r.t. seen queries (table 5, 7)
    print("Performance w.r.t. seen/unseen queries")
    print("\tSeen\t\t\t\tUnseen\t\t")
    print("%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD")
    print(
        "{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})"
        .format(
            seen_ratio,
            seen_mrrs["ours"][""],
            seen_mrrs["-rs"][""],
            rel_change(seen_mrrs, "-rs", ""),
            seen_mrrs["-ad"][""],
            rel_change(seen_mrrs, "-ad", ""),
            unseen_ratio,
            unseen_mrrs["ours"][""],
            unseen_mrrs["-rs"][""],
            rel_change(unseen_mrrs, "-rs", ""),
            unseen_mrrs["-ad"][""],
            rel_change(unseen_mrrs, "-ad", ""),
        ))
    print()
    print("Full graph evaluation")
    print("--------------------------")
    print("Overall system performance")
    print("Ours(ConvE)\t-RS\t-AD")
    print("{:.1f}\t{:.1f}\t{:.1f}".format(mrrs["ours"]["full_kg"],
                                          mrrs["-rs"]["full_kg"],
                                          mrrs["-ad"]["full_kg"]))
    print("--------------------------")
    print("Performance w.r.t. relation types")
    print("\tTo-many\t\t\t\tTo-one\t\t")
    print("%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD")
    print(
        "{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})"
        .format(
            to_m_ratio,
            to_m_mrrs["ours"]["full_kg"],
            to_m_mrrs["-rs"]["full_kg"],
            rel_change(to_m_mrrs, "-rs", "full_kg"),
            to_m_mrrs["-ad"]["full_kg"],
            rel_change(to_m_mrrs, "-ad", "full_kg"),
            to_1_ratio,
            to_1_mrrs["ours"]["full_kg"],
            to_1_mrrs["-rs"]["full_kg"],
            rel_change(to_1_mrrs, "-rs", "full_kg"),
            to_1_mrrs["-ad"]["full_kg"],
            rel_change(to_1_mrrs, "-ad", "full_kg"),
        ))
    print("--------------------------")
    print("Performance w.r.t. seen/unseen queries")
    print("\tSeen\t\t\t\tUnseen\t\t")
    print("%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD")
    print(
        "{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})"
        .format(
            seen_ratio,
            seen_mrrs["ours"]["full_kg"],
            seen_mrrs["-rs"]["full_kg"],
            rel_change(seen_mrrs, "-rs", "full_kg"),
            seen_mrrs["-ad"]["full_kg"],
            rel_change(seen_mrrs, "-ad", "full_kg"),
            unseen_ratio,
            unseen_mrrs["ours"]["full_kg"],
            unseen_mrrs["-rs"]["full_kg"],
            rel_change(unseen_mrrs, "-rs", "full_kg"),
            unseen_mrrs["-ad"]["full_kg"],
            rel_change(unseen_mrrs, "-ad", "full_kg"),
        ))
Exemplo n.º 10
0
def inference(lf):
    """Evaluate the trained model ``lf`` and return the collected metrics.

    The evaluation mode is selected by the module-level ``args``:
      * ``args.compute_map``           -- per-relation MAP over NELL-style task files
      * ``args.eval_by_relation_type`` -- dev-set hits/ranks split into to-many/to-one
      * ``args.eval_by_seen_queries``  -- dev-set hits/ranks split into seen/unseen
      * ``args.few_shot``              -- per-relation test evaluation, loading one
                                          checkpoint per relation
      * otherwise                      -- plain test-set evaluation

    Args:
        lf: learning framework / model wrapper; must expose ``batch_size``,
            ``eval()``, ``forward()``, ``load_checkpoint()`` and a ``kg`` member.

    Returns:
        dict: ``{"dev": {...}, "test": {...}}`` mapping metric names to values.
    """
    lf.batch_size = args.dev_batch_size
    lf.eval()
    # Hybrid ensemble models are assembled from separately pre-trained
    # embedding state dicts; every other model restores its own checkpoint.
    if args.model == "hypere":
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        secondary_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(secondary_kg_state_dict)
    elif args.model == "triplee":
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        complex_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(complex_kg_state_dict)
        distmult_kg_state_dict = get_distmult_kg_state_dict(
            torch.load(args.distmult_state_dict_path))
        lf.tertiary_kg.load_state_dict(distmult_kg_state_dict)
    else:
        lf.load_checkpoint(get_checkpoint_path(args))
    print(lf.kg.entity_embeddings)
    print(lf.kg.relation_embeddings)
    entity_index_path = os.path.join(args.data_dir, "entity2id.txt")
    relation_index_path = os.path.join(args.data_dir, "relation2id.txt")
    # NELL datasets restrict evaluation to entities observed in the adjacency
    # list; other datasets use an empty set (no filtering).
    if "NELL" in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, "adj_list.pkl")
        seen_entities = data_utils.load_seen_entities(adj_list_path,
                                                      entity_index_path)
    else:
        seen_entities = set()

    eval_metrics = {"dev": {}, "test": {}}

    if args.compute_map:
        import numpy as np

        relation_sets = [
            "concept:athletehomestadium",
            "concept:athleteplaysforteam",
            "concept:athleteplaysinleague",
            "concept:athleteplayssport",
            "concept:organizationheadquarteredincity",
            "concept:organizationhiredperson",
            "concept:personborninlocation",
            "concept:teamplayssport",
            "concept:worksfor",
        ]
        mps = []
        for r in relation_sets:
            print("* relation: {}".format(r))
            test_path = os.path.join(args.data_dir, "tasks", r, "test.pairs")
            test_data, labels = data_utils.load_triples_with_label(
                test_path,
                r,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
            )
            pred_scores = lf.forward(test_data, verbose=False)
            mp = src.eval.link_MAP(test_data,
                                   pred_scores,
                                   labels,
                                   lf.kg.all_objects,
                                   verbose=True)
            mps.append(mp)
        map_ = np.mean(mps)
        print("Overall MAP = {}".format(map_))
        # BUG FIX: previously stored the *builtin* `map` function here instead
        # of the computed mean average precision `map_`.
        eval_metrics["test"]["avg_map"] = map_
    elif args.eval_by_relation_type:
        dev_path = os.path.join(args.data_dir, "dev.triples")
        dev_data = data_utils.load_triples(
            dev_path,
            entity_index_path,
            relation_index_path,
            seen_entities=seen_entities,
        )
        pred_scores = lf.forward(dev_data, verbose=False)
        to_m_rels, to_1_rels, _ = data_utils.get_relations_by_type(
            args.data_dir, relation_index_path)
        relation_by_types = (to_m_rels, to_1_rels)
        print("Dev set evaluation by relation type (partial graph)")
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.dev_objects,
                                                 relation_by_types,
                                                 verbose=True)
        print("Dev set evaluation by relation type (full graph)")
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.all_objects,
                                                 relation_by_types,
                                                 verbose=True)
    elif args.eval_by_seen_queries:
        dev_path = os.path.join(args.data_dir, "dev.triples")
        dev_data = data_utils.load_triples(
            dev_path,
            entity_index_path,
            relation_index_path,
            seen_entities=seen_entities,
        )
        pred_scores = lf.forward(dev_data, verbose=False)
        seen_queries = data_utils.get_seen_queries(args.data_dir,
                                                   entity_index_path,
                                                   relation_index_path)
        print("Dev set evaluation by seen queries (partial graph)")
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.dev_objects,
                                                seen_queries,
                                                verbose=True)
        print("Dev set evaluation by seen queries (full graph)")
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.all_objects,
                                                seen_queries,
                                                verbose=True)
    else:
        if args.few_shot:
            dev_path = os.path.join(args.data_dir, "dev.triples")
            test_path = os.path.join(args.data_dir, "test.triples")
            # NOTE(review): dev_data is loaded but never used below; kept to
            # preserve any side effects of the loader — confirm and drop.
            _, dev_data = data_utils.load_triples(
                dev_path,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
                verbose=False,
                few_shot=True,
                lf=lf,
            )
            # few-shot loader groups triples per relation: {rel: [triples]}
            _, test_data = data_utils.load_triples(
                test_path,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
                verbose=False,
                few_shot=True,
                lf=lf,
            )
            num = 0
            hits_1 = hits_3 = hits_5 = hits_10 = mrr = 0.0
            for x in test_data:
                # One checkpoint per relation: "[relation]" placeholder in the
                # checkpoint path is substituted with the relation id.
                lf.load_checkpoint(
                    args.checkpoint_path.replace("[relation]", str(x)))
                print("Test set of relation {} performance:".format(x))
                pred_scores = lf.forward(test_data[x], verbose=False)
                test_metrics = src.eval.hits_and_ranks(test_data[x],
                                                       pred_scores,
                                                       lf.kg.all_objects,
                                                       verbose=True)
                n_x = len(test_data[x])
                num += n_x
                hits_1 += float(test_metrics[0]) * n_x
                hits_3 += float(test_metrics[1]) * n_x
                hits_5 += float(test_metrics[2]) * n_x
                hits_10 += float(test_metrics[3]) * n_x
                mrr += float(test_metrics[4]) * n_x
            if num > 0:  # guard against empty test data (ZeroDivisionError)
                # BUG FIX: previously the per-relation metrics overwrote the
                # dict each iteration, so only the last relation survived.
                # Store the size-weighted averages over all relations instead.
                eval_metrics["test"]["hits_at_1"] = hits_1 / num
                eval_metrics["test"]["hits_at_3"] = hits_3 / num
                eval_metrics["test"]["hits_at_5"] = hits_5 / num
                eval_metrics["test"]["hits_at_10"] = hits_10 / num
                eval_metrics["test"]["mrr"] = mrr / num
                print("Hits@1 = {}".format(hits_1 / num))
                print("Hits@3 = {}".format(hits_3 / num))
                print("Hits@5 = {}".format(hits_5 / num))
                print("Hits@10 = {}".format(hits_10 / num))
                print("MRR = {}".format(mrr / num))
        else:
            test_path = os.path.join(args.data_dir, "test.triples")
            test_data = data_utils.load_triples(
                test_path,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities,
                verbose=False,
            )
            print("Test set performance:")
            pred_scores = lf.forward(test_data, verbose=False)
            test_metrics = src.eval.hits_and_ranks(
                test_data,
                pred_scores,
                lf.kg.all_objects,
                verbose=True,
                output=True,
                kg=lf.kg,
                model_name=args.model,
            )
            eval_metrics["test"]["hits_at_1"] = test_metrics[0]
            eval_metrics["test"]["hits_at_3"] = test_metrics[1]
            eval_metrics["test"]["hits_at_5"] = test_metrics[2]
            eval_metrics["test"]["hits_at_10"] = test_metrics[3]
            eval_metrics["test"]["mrr"] = test_metrics[4]

    return eval_metrics
# Exemplo n.º 11
# 0
def train(lf):
    """Train the model ``lf`` in one of three regimes selected by ``args``:
    adaptation (per-relation fine-tuning), few-shot meta-training, or
    standard embedding training.
    """
    path_train = data_utils.get_train_path(args)
    path_dev = os.path.join(args.data_dir, "dev.triples")
    path_entity_index = os.path.join(args.data_dir, "entity2id.txt")
    path_relation_index = os.path.join(args.data_dir, "relation2id.txt")

    # Both few-shot and adaptation use the relation-grouped loader.
    meta_mode = args.few_shot or args.adaptation

    if meta_mode:
        # NOTE: train_data: {"11": [(362, 11, 57), (246, 11, 42), ...], ...}
        normal_train, few_train = data_utils.load_triples(
            path_train,
            path_entity_index,
            path_relation_index,
            group_examples_by_query=args.
            group_examples_by_query,  # NOTE: False in meta-learning
            add_reverse_relations=args.add_reversed_training_edges,
            few_shot=True,
            lf=lf,
        )
    else:
        # NOTE: train_data: [(36221, [11], 57), (4203, [7, 8, 13, 15, 48], 3), ...]
        print("jxtu: load all train data...")

        all_train = data_utils.load_triples(
            path_train,
            path_entity_index,
            path_relation_index,
            group_examples_by_query=args.
            group_examples_by_query,  # NOTE: True in embedding training
            add_reverse_relations=args.add_reversed_training_edges,
        )

    # NELL datasets restrict dev triples to entities seen in the adjacency list.
    seen_entities = set()
    if "NELL" in args.data_dir:
        seen_entities = data_utils.load_seen_entities(
            os.path.join(args.data_dir, "adj_list.pkl"), path_entity_index)

    if meta_mode:
        normal_dev, few_dev = data_utils.load_triples(
            path_dev,
            path_entity_index,
            path_relation_index,
            seen_entities=seen_entities,
            few_shot=True,
            lf=lf,
        )
    else:
        all_dev = data_utils.load_triples(
            path_dev,
            path_entity_index,
            path_relation_index,
            seen_entities=seen_entities,
        )

    if args.adaptation:
        # Fine-tune per relation, restarting from the base checkpoint each time.
        for relation in few_train:
            if args.checkpoint_path is not None:
                lf.load_checkpoint(args.checkpoint_path, adaptation=True)
            lf.run_train(few_train[relation],
                         few_dev,
                         adaptation=True,
                         adaptation_relation=relation)
    elif args.few_shot:
        if args.checkpoint_path is not None:
            lf.load_checkpoint(args.checkpoint_path)
        lf.run_train(normal_train, normal_dev, few_shot=True)
    else:
        if args.checkpoint_path is not None:
            if args.emb_few:
                lf.load_checkpoint(args.checkpoint_path, emb_few=True)
            else:
                lf.load_checkpoint(args.checkpoint_path)
        lf.run_train(all_train, all_dev)
# Exemplo n.º 12
# 0
def inference(lf):
    """Evaluate the trained model ``lf`` and return the collected metrics.

    The evaluation mode is selected by the module-level ``args``:
      * ``args.compute_map``           -- per-relation MAP over NELL-style task files
      * ``args.eval_by_relation_type`` -- dev-set hits/ranks split into to-many/to-one
      * ``args.eval_by_seen_queries``  -- dev-set hits/ranks split into seen/unseen
      * otherwise                      -- plain test-set evaluation

    Args:
        lf: learning framework / model wrapper; must expose ``batch_size``,
            ``eval()``, ``forward()``, ``load_checkpoint()`` and a ``kg`` member.

    Returns:
        dict: ``{'dev': {...}, 'test': {...}}`` mapping metric names to values.
    """
    lf.batch_size = args.dev_batch_size
    lf.eval()
    # Hybrid ensemble models are assembled from separately pre-trained
    # embedding state dicts; every other model restores its own checkpoint.
    if args.model == 'hypere':
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        secondary_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(secondary_kg_state_dict)
    elif args.model == 'triplee':
        conve_kg_state_dict = get_conve_kg_state_dict(
            torch.load(args.conve_state_dict_path))
        lf.kg.load_state_dict(conve_kg_state_dict)
        complex_kg_state_dict = get_complex_kg_state_dict(
            torch.load(args.complex_state_dict_path))
        lf.secondary_kg.load_state_dict(complex_kg_state_dict)
        distmult_kg_state_dict = get_distmult_kg_state_dict(
            torch.load(args.distmult_state_dict_path))
        lf.tertiary_kg.load_state_dict(distmult_kg_state_dict)
    else:
        lf.load_checkpoint(get_checkpoint_path(args))
    print(lf.kg.entity_embeddings)
    print(lf.kg.relation_embeddings)
    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    # NELL datasets restrict evaluation to entities observed in the adjacency
    # list; other datasets use an empty set (no filtering).
    if 'NELL' in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, 'adj_list.pkl')
        seen_entities = data_utils.load_seen_entities(adj_list_path,
                                                      entity_index_path)
    else:
        seen_entities = set()

    eval_metrics = {'dev': {}, 'test': {}}

    if args.compute_map:
        import numpy as np

        relation_sets = [
            'concept:athletehomestadium', 'concept:athleteplaysforteam',
            'concept:athleteplaysinleague', 'concept:athleteplayssport',
            'concept:organizationheadquarteredincity',
            'concept:organizationhiredperson', 'concept:personborninlocation',
            'concept:teamplayssport', 'concept:worksfor'
        ]
        mps = []
        for r in relation_sets:
            print('* relation: {}'.format(r))
            test_path = os.path.join(args.data_dir, 'tasks', r, 'test.pairs')
            test_data, labels = data_utils.load_triples_with_label(
                test_path,
                r,
                entity_index_path,
                relation_index_path,
                seen_entities=seen_entities)
            pred_scores = lf.forward(test_data, verbose=False)
            mp = src.eval.link_MAP(test_data,
                                   pred_scores,
                                   labels,
                                   lf.kg.all_objects,
                                   verbose=True)
            mps.append(mp)
        map_ = np.mean(mps)
        print('Overall MAP = {}'.format(map_))
        # BUG FIX: previously stored the *builtin* `map` function here instead
        # of the computed mean average precision `map_`.
        eval_metrics['test']['avg_map'] = map_
    elif args.eval_by_relation_type:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path,
                                           entity_index_path,
                                           relation_index_path,
                                           seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        to_m_rels, to_1_rels, _ = data_utils.get_relations_by_type(
            args.data_dir, relation_index_path)
        relation_by_types = (to_m_rels, to_1_rels)
        print('Dev set evaluation by relation type (partial graph)')
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.dev_objects,
                                                 relation_by_types,
                                                 verbose=True)
        print('Dev set evaluation by relation type (full graph)')
        src.eval.hits_and_ranks_by_relation_type(dev_data,
                                                 pred_scores,
                                                 lf.kg.all_objects,
                                                 relation_by_types,
                                                 verbose=True)
    elif args.eval_by_seen_queries:
        dev_path = os.path.join(args.data_dir, 'dev.triples')
        dev_data = data_utils.load_triples(dev_path,
                                           entity_index_path,
                                           relation_index_path,
                                           seen_entities=seen_entities)
        pred_scores = lf.forward(dev_data, verbose=False)
        seen_queries = data_utils.get_seen_queries(args.data_dir,
                                                   entity_index_path,
                                                   relation_index_path)
        print('Dev set evaluation by seen queries (partial graph)')
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.dev_objects,
                                                seen_queries,
                                                verbose=True)
        print('Dev set evaluation by seen queries (full graph)')
        src.eval.hits_and_ranks_by_seen_queries(dev_data,
                                                pred_scores,
                                                lf.kg.all_objects,
                                                seen_queries,
                                                verbose=True)
    else:
        test_path = os.path.join(args.data_dir, 'test.triples')
        test_data = data_utils.load_triples(test_path,
                                            entity_index_path,
                                            relation_index_path,
                                            seen_entities=seen_entities,
                                            verbose=False)
        print('Test set performance:')
        pred_scores = lf.forward(test_data, verbose=False)
        test_metrics = src.eval.hits_and_ranks(test_data,
                                               pred_scores,
                                               lf.kg.all_objects,
                                               verbose=True,
                                               output=False,
                                               kg=lf.kg,
                                               model_name=args.model,
                                               split_relation=False)
        # NOTE(review): test-set metrics are stored under the 'dev' key;
        # callers appear to rely on this layout — confirm before renaming.
        eval_metrics['dev'] = {}
        eval_metrics['dev']['hits_at_1'] = test_metrics[0]
        eval_metrics['dev']['hits_at_3'] = test_metrics[1]
        eval_metrics['dev']['hits_at_5'] = test_metrics[2]
        eval_metrics['dev']['hits_at_10'] = test_metrics[3]
        eval_metrics['dev']['mrr'] = test_metrics[4]

    return eval_metrics