Example #1
0
def run_ablation_studies(args):
    """
    Run the ablation study experiments reported in the paper.

    Three systems are evaluated on the dev triples found in ``args.data_dir``:

    * ``'ours'`` -- the configuration passed in through ``args``;
    * ``'-ad'``  -- same, with ``args.action_dropout_rate`` forced to 0.0
      (plus adjusted embedding / feed-forward dropout rates on ``umls``);
    * ``'-rs'``  -- arguments re-parsed with the module-level ``parser`` and
      completed by ``data_utils.load_configs`` from ``configs/<dataset>.sh``.

    For every system the MRR is computed overall, per relation type
    (to-many / to-one) and per seen / unseen queries, once against
    ``lf.kg.dev_objects`` (stored under the key ``''``, reported as "partial
    graph") and once against ``lf.kg.all_objects`` (key ``'full_kg'``).
    The results are printed as tab-separated tables; nothing is returned.

    Note: the ``'-ad'`` branch mutates the ``args`` object passed by the caller.

    :param args: parsed command-line namespace; must provide at least
        ``data_dir`` and ``dev_batch_size`` (other attributes are consumed by
        the model construction helpers).
    """
    def set_up_lf_for_inference(args):
        # Build the model for this configuration, move it to the GPU, restore
        # its checkpoint and switch it to evaluation mode.
        initialize_model_directory(args)
        lf = construct_model(args)
        lf.cuda()
        lf.batch_size = args.dev_batch_size
        lf.load_checkpoint(get_checkpoint_path(args))
        lf.eval()
        return lf

    def rel_change(metrics, ab_system, kg_portion):
        # Relative change (integer percent) of an ablated system w.r.t. 'ours'.
        # Assumes the base metric is non-zero.
        ab_system_metrics = metrics[ab_system][kg_portion]
        base_metrics = metrics['ours'][kg_portion]
        return int(np.round((ab_system_metrics - base_metrics) / base_metrics * 100))

    def evaluate_kg_portion(pred_scores, objects):
        # Compute (overall, to-many, to-one, seen, unseen) MRR of `pred_scores`
        # against the answer sets in `objects`.
        _, _, _, _, mrr = src.eval.hits_and_ranks(dev_data, pred_scores, objects, verbose=True)
        if to_1_ratio == 0:
            # No to-one relation in this dataset: every query is to-many and
            # the to-one score is reported with the placeholder -1.
            to_m_mrr, to_1_mrr = mrr, -1
        else:
            to_m_mrr, to_1_mrr = src.eval.hits_and_ranks_by_relation_type(
                dev_data, pred_scores, objects, relation_by_types, verbose=True)
        seen_mrr, unseen_mrr = src.eval.hits_and_ranks_by_seen_queries(
            dev_data, pred_scores, objects, seen_queries, verbose=True)
        return mrr, to_m_mrr, to_1_mrr, seen_mrr, unseen_mrr

    entity_index_path = os.path.join(args.data_dir, 'entity2id.txt')
    relation_index_path = os.path.join(args.data_dir, 'relation2id.txt')
    if 'NELL' in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, 'adj_list.pkl')
        seen_entities = data_utils.load_seen_entities(adj_list_path, entity_index_path)
    else:
        seen_entities = set()
    # normpath strips a trailing separator; without it basename('data/umls/')
    # is '' and both the 'umls' check and the config path below break.
    dataset = os.path.basename(os.path.normpath(args.data_dir))
    dev_path = os.path.join(args.data_dir, 'dev.triples')
    dev_data = data_utils.load_triples(
        dev_path, entity_index_path, relation_index_path, seen_entities=seen_entities, verbose=False)
    to_m_rels, to_1_rels, (to_m_ratio, to_1_ratio) = data_utils.get_relations_by_type(
        args.data_dir, relation_index_path)
    relation_by_types = (to_m_rels, to_1_rels)
    # ratios are reported as percentages
    to_m_ratio *= 100
    to_1_ratio *= 100
    seen_queries, (seen_ratio, unseen_ratio) = data_utils.get_seen_queries(
        args.data_dir, entity_index_path, relation_index_path)
    seen_ratio *= 100
    unseen_ratio *= 100

    systems = ['ours', '-ad', '-rs']
    # metric[system][kg_portion] -> MRR in percent
    mrrs, to_m_mrrs, to_1_mrrs, seen_mrrs, unseen_mrrs = {}, {}, {}, {}, {}
    # Same order as the tuple returned by evaluate_kg_portion().
    metric_tables = (mrrs, to_m_mrrs, to_1_mrrs, seen_mrrs, unseen_mrrs)
    # (result key, attribute of lf.kg holding the answer sets to rank against)
    kg_portions = (('', 'dev_objects'), ('full_kg', 'all_objects'))
    for system in systems:
        print('** Evaluating {} system **'.format(system))
        if system == '-ad':
            args.action_dropout_rate = 0.0
            if dataset == 'umls':
                # adjust dropout hyperparameters
                args.emb_dropout_rate = 0.3
                args.ff_dropout_rate = 0.1
        elif system == '-rs':
            # Start again from freshly parsed arguments (dropping the '-ad'
            # overrides above) and complete them from the dataset config.
            config_path = os.path.join('configs', '{}.sh'.format(dataset.lower()))
            args = parser.parse_args()
            args = data_utils.load_configs(args, config_path)

        lf = set_up_lf_for_inference(args)
        pred_scores = lf.forward(dev_data, verbose=False)
        for table in metric_tables:
            table[system] = {}
        for kg_portion, objects_attr in kg_portions:
            scores = evaluate_kg_portion(pred_scores, getattr(lf.kg, objects_attr))
            for table, score in zip(metric_tables, scores):
                table[system][kg_portion] = score * 100

    separator = '--------------------------'
    column_header = '%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD'
    row_format = ('{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})'
                  '\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})')

    def breakdown_cells(ratio, metrics, kg_portion):
        # One half of a breakdown row: ratio, ours, -rs (+change), -ad (+change).
        return (ratio,
                metrics['ours'][kg_portion],
                metrics['-rs'][kg_portion], rel_change(metrics, '-rs', kg_portion),
                metrics['-ad'][kg_portion], rel_change(metrics, '-ad', kg_portion))

    def print_report(title, kg_portion):
        # Print the three result tables for one KG portion ('' or 'full_kg').
        print(title)
        print(separator)
        # overall system comparison (table 3)
        print('Overall system performance')
        print('Ours(ConvE)\t-RS\t-AD')
        print('{:.1f}\t{:.1f}\t{:.1f}'.format(
            mrrs['ours'][kg_portion], mrrs['-rs'][kg_portion], mrrs['-ad'][kg_portion]))
        print(separator)
        # performance w.r.t. relation types (table 4, 6)
        print('Performance w.r.t. relation types')
        print('\tTo-many\t\t\t\tTo-one\t\t')
        print(column_header)
        print(row_format.format(
            *(breakdown_cells(to_m_ratio, to_m_mrrs, kg_portion)
              + breakdown_cells(to_1_ratio, to_1_mrrs, kg_portion))))
        print(separator)
        # performance w.r.t. seen queries (table 5, 7)
        print('Performance w.r.t. seen/unseen queries')
        print('\tSeen\t\t\t\tUnseen\t\t')
        print(column_header)
        print(row_format.format(
            *(breakdown_cells(seen_ratio, seen_mrrs, kg_portion)
              + breakdown_cells(unseen_ratio, unseen_mrrs, kg_portion))))

    print_report('Partial graph evaluation', '')
    print()
    print_report('Full graph evaluation', 'full_kg')
def run_ablation_studies(args):
    """
    Run the ablation study experiments reported in the paper.

    Three systems are evaluated on the dev triples found in ``args.data_dir``:

    * ``"ours"`` -- the configuration passed in through ``args``;
    * ``"-ad"``  -- same, with ``args.action_dropout_rate`` forced to 0.0
      (plus adjusted embedding / feed-forward dropout rates on ``umls``);
    * ``"-rs"``  -- arguments re-parsed with the module-level ``parser`` and
      completed by ``data_utils.load_configs`` from ``configs/<dataset>.sh``.

    For every system the MRR is computed overall, per relation type
    (to-many / to-one) and per seen / unseen queries, once against
    ``lf.kg.dev_objects`` (stored under the key ``""``, reported as "partial
    graph") and once against ``lf.kg.all_objects`` (key ``"full_kg"``).
    The results are printed as tab-separated tables; nothing is returned.

    Note: the ``"-ad"`` branch mutates the ``args`` object passed by the caller.

    :param args: parsed command-line namespace; must provide at least
        ``data_dir`` and ``dev_batch_size`` (other attributes are consumed by
        the model construction helpers).
    """
    def set_up_lf_for_inference(args):
        # Build the model for this configuration, move it to the GPU, restore
        # its checkpoint and switch it to evaluation mode.
        initialize_model_directory(args)
        lf = construct_model(args)
        lf.cuda()
        lf.batch_size = args.dev_batch_size
        lf.load_checkpoint(get_checkpoint_path(args))
        lf.eval()
        return lf

    def rel_change(metrics, ab_system, kg_portion):
        # Relative change (integer percent) of an ablated system w.r.t. "ours".
        # Assumes the base metric is non-zero.
        ab_system_metrics = metrics[ab_system][kg_portion]
        base_metrics = metrics["ours"][kg_portion]
        return int(
            np.round((ab_system_metrics - base_metrics) / base_metrics * 100))

    def evaluate_kg_portion(pred_scores, objects):
        # Compute (overall, to-many, to-one, seen, unseen) MRR of `pred_scores`
        # against the answer sets in `objects`.
        _, _, _, _, mrr = src.eval.hits_and_ranks(dev_data,
                                                  pred_scores,
                                                  objects,
                                                  verbose=True)
        if to_1_ratio == 0:
            # No to-one relation in this dataset: every query is to-many and
            # the to-one score is reported with the placeholder -1.
            to_m_mrr, to_1_mrr = mrr, -1
        else:
            to_m_mrr, to_1_mrr = src.eval.hits_and_ranks_by_relation_type(
                dev_data,
                pred_scores,
                objects,
                relation_by_types,
                verbose=True,
            )
        seen_mrr, unseen_mrr = src.eval.hits_and_ranks_by_seen_queries(
            dev_data,
            pred_scores,
            objects,
            seen_queries,
            verbose=True)
        return mrr, to_m_mrr, to_1_mrr, seen_mrr, unseen_mrr

    entity_index_path = os.path.join(args.data_dir, "entity2id.txt")
    relation_index_path = os.path.join(args.data_dir, "relation2id.txt")
    if "NELL" in args.data_dir:
        adj_list_path = os.path.join(args.data_dir, "adj_list.pkl")
        seen_entities = data_utils.load_seen_entities(adj_list_path,
                                                      entity_index_path)
    else:
        seen_entities = set()
    # normpath strips a trailing separator; without it basename("data/umls/")
    # is "" and both the "umls" check and the config path below break.
    dataset = os.path.basename(os.path.normpath(args.data_dir))
    dev_path = os.path.join(args.data_dir, "dev.triples")
    dev_data = data_utils.load_triples(
        dev_path,
        entity_index_path,
        relation_index_path,
        seen_entities=seen_entities,
        verbose=False,
    )
    to_m_rels, to_1_rels, (to_m_ratio,
                           to_1_ratio) = data_utils.get_relations_by_type(
                               args.data_dir, relation_index_path)
    relation_by_types = (to_m_rels, to_1_rels)
    # ratios are reported as percentages
    to_m_ratio *= 100
    to_1_ratio *= 100
    seen_queries, (seen_ratio, unseen_ratio) = data_utils.get_seen_queries(
        args.data_dir, entity_index_path, relation_index_path)
    seen_ratio *= 100
    unseen_ratio *= 100

    systems = ["ours", "-ad", "-rs"]
    # metric[system][kg_portion] -> MRR in percent
    mrrs, to_m_mrrs, to_1_mrrs, seen_mrrs, unseen_mrrs = {}, {}, {}, {}, {}
    # Same order as the tuple returned by evaluate_kg_portion().
    metric_tables = (mrrs, to_m_mrrs, to_1_mrrs, seen_mrrs, unseen_mrrs)
    # (result key, attribute of lf.kg holding the answer sets to rank against)
    kg_portions = (("", "dev_objects"), ("full_kg", "all_objects"))
    for system in systems:
        print("** Evaluating {} system **".format(system))
        if system == "-ad":
            args.action_dropout_rate = 0.0
            if dataset == "umls":
                # adjust dropout hyperparameters
                args.emb_dropout_rate = 0.3
                args.ff_dropout_rate = 0.1
        elif system == "-rs":
            # Start again from freshly parsed arguments (dropping the "-ad"
            # overrides above) and complete them from the dataset config.
            config_path = os.path.join("configs",
                                       "{}.sh".format(dataset.lower()))
            args = parser.parse_args()
            args = data_utils.load_configs(args, config_path)

        lf = set_up_lf_for_inference(args)
        pred_scores = lf.forward(dev_data, verbose=False)
        for table in metric_tables:
            table[system] = {}
        for kg_portion, objects_attr in kg_portions:
            scores = evaluate_kg_portion(pred_scores,
                                         getattr(lf.kg, objects_attr))
            for table, score in zip(metric_tables, scores):
                table[system][kg_portion] = score * 100

    separator = "--------------------------"
    column_header = "%\tOurs\t-RS\t-AD\t%\tOurs\t-RS\t-AD"
    row_format = ("{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})"
                  "\t{:.1f}\t{:.1f}\t{:.1f} ({:d})\t{:.1f} ({:d})")

    def breakdown_cells(ratio, metrics, kg_portion):
        # One half of a breakdown row: ratio, ours, -rs (+change), -ad (+change).
        return (
            ratio,
            metrics["ours"][kg_portion],
            metrics["-rs"][kg_portion],
            rel_change(metrics, "-rs", kg_portion),
            metrics["-ad"][kg_portion],
            rel_change(metrics, "-ad", kg_portion),
        )

    def print_report(title, kg_portion):
        # Print the three result tables for one KG portion ("" or "full_kg").
        print(title)
        print(separator)
        # overall system comparison (table 3)
        print("Overall system performance")
        print("Ours(ConvE)\t-RS\t-AD")
        print("{:.1f}\t{:.1f}\t{:.1f}".format(mrrs["ours"][kg_portion],
                                              mrrs["-rs"][kg_portion],
                                              mrrs["-ad"][kg_portion]))
        print(separator)
        # performance w.r.t. relation types (table 4, 6)
        print("Performance w.r.t. relation types")
        print("\tTo-many\t\t\t\tTo-one\t\t")
        print(column_header)
        print(
            row_format.format(
                *(breakdown_cells(to_m_ratio, to_m_mrrs, kg_portion) +
                  breakdown_cells(to_1_ratio, to_1_mrrs, kg_portion))))
        print(separator)
        # performance w.r.t. seen queries (table 5, 7)
        print("Performance w.r.t. seen/unseen queries")
        print("\tSeen\t\t\t\tUnseen\t\t")
        print(column_header)
        print(
            row_format.format(
                *(breakdown_cells(seen_ratio, seen_mrrs, kg_portion) +
                  breakdown_cells(unseen_ratio, unseen_mrrs, kg_portion))))

    print_report("Partial graph evaluation", "")
    print()
    print_report("Full graph evaluation", "full_kg")