Example #1
def main(
    relation=Relation.CE_V_MAX.value,
    outfile="analysis/ce_v_max_digikey_discrepancies.csv",
):
    # Compare our gold with Digikey's gold for the analysis set of 66 docs
    # (i.e., the docs for which both we and Digikey have gold labels).
    # NOTE: our gold is treated as ground truth in this comparison against Digikey.
    dirname = os.path.dirname(os.path.abspath(__file__))
    outfile = os.path.join(dirname, outfile)

    # Us
    our_gold = os.path.join(dirname, "data/analysis/our_gold.csv")
    our_gold_set = get_gold_set(gold=[our_gold], attribute=relation)
    our_gold_dic = gold_set_to_dic(our_gold_set)

    # Digikey
    digikey_gold = os.path.join(dirname, "data/analysis/digikey_gold.csv")
    digikey_gold_set = get_gold_set(gold=[digikey_gold], attribute=relation)
    digikey_gold_dic = gold_set_to_dic(digikey_gold_set)

    # Score Digikey using our gold as metric
    score = entity_level_scores(
        digikey_gold_set, metric=our_gold_set, attribute=relation
    )
    print_score(
        score,
        description=f"Scoring on {digikey_gold.split('/')[-1]} " +
        f"against {our_gold.split('/')[-1]}.",
    )

    # Run final comparison using FN and FP
    compare_entities(
        set(score.FN),
        entity_dic=digikey_gold_dic,
        type="FN",
        gold_dic=our_gold_dic,
        outfile=outfile,
    )
    compare_entities(
        set(score.FP),
        type="FP",
        append=True,
        gold_dic=our_gold_dic,
        outfile=outfile,
    )
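
Both examples on this page lean on gold_set_to_dic to index gold entities by document before writing the discrepancy CSV. As a rough sketch of that grouping step, assuming each gold entry is a (filename, value) pair (the real tuples in this project carry more fields, so this is an illustration, not the actual helper):

from collections import defaultdict

def gold_set_to_dic_sketch(gold_set):
    """Group gold (filename, value) tuples into {filename: [value, ...]}."""
    dic = defaultdict(list)
    for filename, value in gold_set:  # assumed 2-tuple shape
        dic[filename].append(value)
    return dict(dic)

# Toy input: two docs, one of which has two gold values.
demo = {("bc546.pdf", "65V"), ("bc547.pdf", "45V"), ("bc547.pdf", "50V")}
print(gold_set_to_dic_sketch(demo))
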
Example #2
def main(
    num=100,
    relation=Relation.CE_V_MAX.value,
    devfile="ce_v_max_dev_probs.csv",
    testfile="ce_v_max_test_probs.csv",
    outfile="analysis/ce_v_max_analysis_discrepancies.csv",
    debug=False,
):
    # Define output
    dirname = os.path.dirname(os.path.abspath(__file__))
    discrepancy_file = os.path.join(dirname, outfile)

    # Analysis
    gold_file = os.path.join(dirname, "data/analysis/our_gold.csv")
    filenames_file = os.path.join(dirname, "data/analysis/filenames.csv")
    filenames = capitalize_filenames(get_filenames_from_file(filenames_file))
    # logger.info(f"Analysis dataset is {len(filenames)}" + " filenames long.")
    gold = filter_filenames(
        get_gold_set(gold=[gold_file], attribute=relation), filenames
    )
    # logger.info(f"Original gold set is {len(get_filenames(gold))} filenames long.")

    best_score = Score(0, 0, 0, [], [], [])
    best_b = 0
    best_entities = set()

    # Test
    test_file = os.path.join(dirname, testfile)
    test_filenames = capitalize_filenames(
        get_filenames_from_file(os.path.join(dirname, "data/test/filenames.csv"))
    )
    test_goldfile = os.path.join(dirname, "data/test/test_gold.csv")
    test_gold = filter_filenames(
        get_gold_set(gold=[test_goldfile], attribute=relation), test_filenames
    )

    best_test_score = Score(0, 0, 0, [], [], [])
    best_test_b = 0
    best_test_entities = set()

    # Dev
    dev_file = os.path.join(dirname, devfile)
    dev_filenames = capitalize_filenames(
        get_filenames_from_file(os.path.join(dirname, "data/dev/filenames.csv"))
    )
    dev_goldfile = os.path.join(dirname, "data/dev/dev_gold.csv")
    dev_gold = filter_filenames(
        get_gold_set(gold=[dev_goldfile], attribute=relation), dev_filenames
    )

    best_dev_score = Score(0, 0, 0, [], [], [])
    best_dev_b = 0
    best_dev_entities = set()

    # Iterate over `b` values
    logger.info(f"Determining best b...")
    parts_by_doc = load_parts_by_doc()
    for b in tqdm(np.linspace(0, 1, num=num)):
        # Dev and Test
        dev_entities = get_entity_set(dev_file, parts_by_doc, b=b)
        test_entities = get_entity_set(test_file, parts_by_doc, b=b)

        # Analysis (combo of dev and test)
        entities = filter_filenames(
            dev_entities.union(test_entities), get_filenames_from_file(filenames_file)
        )

        # Score entities against gold data and generate comparison CSV
        dev_score = entity_level_scores(
            dev_entities, attribute=relation, docs=dev_filenames
        )
        test_score = entity_level_scores(
            test_entities, attribute=relation, docs=test_filenames
        )
        score = entity_level_scores(entities, attribute=relation, docs=filenames)

        if dev_score.f1 > best_dev_score.f1:
            best_dev_score = dev_score
            best_dev_b = b
            best_dev_entities = dev_entities

        if test_score.f1 > best_test_score.f1:
            best_test_score = test_score
            best_test_b = b
            best_test_entities = test_entities

        if score.f1 > best_score.f1:
            best_score = score
            best_b = b
            best_entities = entities

    if debug:
        # Test
        logger.info("Scoring for test set...")
        logger.info(
            f"Entity set is {len(get_filenames(best_test_entities))} filenames long."
        )
        logger.info(f"Gold set is {len(get_filenames(test_gold))} filenames long.")
        print_score(
            best_test_score,
            description=f"Scoring on cands > {best_test_b:.3f} "
            + "against our gold labels.",
        )

        # Dev
        logger.info("Scoring for dev set...")
        logger.info(
            f"Entity set is {len(get_filenames(best_dev_entities))} filenames long."
        )
        logger.info(f"Gold set is {len(get_filenames(dev_gold))} filenames long.")
        print_score(
            best_dev_score,
            description=f"Scoring on cands > {best_dev_b:.3f} against our gold labels.",
        )

        logger.info("Scoring for analysis set...")
    # Analysis
    # logger.info(f"Entity set is {len(get_filenames(best_entities))} filenames long.")
    # logger.info(f"Gold set is {len(get_filenames(gold))} filenames long.")
    print_score(
        best_score,
        description=f"Scoring on cands > {best_b:.3f} against our gold labels.",
    )

    compare_entities(
        set(best_score.FP),
        attribute=relation,
        type="FP",
        outfile=discrepancy_file,
        gold_dic=gold_set_to_dic(gold),
    )
    compare_entities(
        set(best_score.FN),
        attribute=relation,
        type="FN",
        outfile=discrepancy_file,
        append=True,
        entity_dic=gold_set_to_dic(best_entities),
    )
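
Example #2's loop over b is a standard decision-threshold search: for each candidate probability cutoff, rebuild the predicted entity set, score it against gold, and keep the cutoff with the best F1. The sketch below shows that pattern on plain (entity, probability) pairs; the scoring function is a simplified stand-in for entity_level_scores, not the project's implementation:

import numpy as np

def f1_at_threshold(scored_entities, gold, b):
    """F1 of the entity set formed by keeping candidates with prob > b."""
    predicted = {e for e, p in scored_entities if p > b}
    tp = len(predicted & gold)
    fp = len(predicted - gold)
    fn = len(gold - predicted)
    prec = tp / (tp + fp) if tp + fp else 0.0
    rec = tp / (tp + fn) if tp + fn else 0.0
    return 2 * prec * rec / (prec + rec) if prec + rec else 0.0

def best_threshold(scored_entities, gold, num=100):
    """Return (best_b, best_f1) over an evenly spaced grid of cutoffs."""
    best_b, best_f1 = 0.0, 0.0
    for b in np.linspace(0, 1, num=num):
        f1 = f1_at_threshold(scored_entities, gold, b)
        if f1 > best_f1:
            best_b, best_f1 = b, f1
    return best_b, best_f1

# Toy example: three candidates, two of which are gold.
cands = [("doc1|65V", 0.9), ("doc2|45V", 0.6), ("doc3|80V", 0.2)]
gold = {"doc1|65V", "doc2|45V"}
print(best_threshold(cands, gold, num=11))
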
Example #3
)

logger = logging.getLogger(__name__)

if __name__ == "__main__":

    # Compare our gold with Digikey's gold for the analysis set of 66 docs
    # (i.e., the docs for which both we and Digikey have gold labels).
    # NOTE: our gold is treated as ground truth in this comparison against Digikey.
    attribute = Relation.CE_V_MAX.value
    dirname = os.path.dirname(os.path.abspath(__file__))
    outfile = os.path.join(dirname, "../analysis/gold_discrepancies.csv")

    # Us
    our_gold = os.path.join(dirname, "../analysis/our_gold.csv")
    our_gold_set = get_gold_set(gold=[our_gold], attribute=attribute)
    our_gold_dic = gold_set_to_dic(our_gold_set)

    # Digikey
    digikey_gold = os.path.join(dirname, "../analysis/digikey_gold.csv")
    digikey_gold_set = get_gold_set(gold=[digikey_gold], attribute=attribute)
    digikey_gold_dic = gold_set_to_dic(digikey_gold_set)

    # Score Digikey using our gold as metric
    score = entity_level_scores(
        digikey_gold_set, metric=our_gold_set, attribute=attribute
    )
    print_score(score, entities=digikey_gold, metric=our_gold)

    # Run final comparison using FN and FP
    # NOTE: the original listing is truncated at this point; the two calls
    # below follow the same FN/FP pattern as Example #1.
    compare_entities(
        set(score.FN),
        entity_dic=digikey_gold_dic,
        type="FN",
        gold_dic=our_gold_dic,
        outfile=outfile,
    )
    compare_entities(
        set(score.FP),
        type="FP",
        append=True,
        gold_dic=our_gold_dic,
        outfile=outfile,
    )
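
For reference across all three examples, the Score(0, 0, 0, [], [], []) initializers and the .f1/.FP/.FN accesses suggest a result carrying precision, recall, F1, and the TP/FP/FN entity lists. A minimal entity-level scorer with that shape might look like the sketch below; the field order and names are an assumption, not the project's actual Score definition:

from collections import namedtuple

# Assumed field layout, inferred from Score(0, 0, 0, [], [], []) and the
# attribute accesses (.f1, .FP, .FN) in the examples above.
Score = namedtuple("Score", ["prec", "rec", "f1", "TP", "FP", "FN"])

def entity_level_scores_sketch(entities, metric):
    """Score a predicted entity set against a gold ("metric") entity set."""
    tp = sorted(entities & metric)
    fp = sorted(entities - metric)
    fn = sorted(metric - entities)
    prec = len(tp) / (len(tp) + len(fp)) if tp or fp else 0.0
    rec = len(tp) / (len(tp) + len(fn)) if tp or fn else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return Score(prec, rec, f1, tp, fp, fn)

print(entity_level_scores_sketch({"a", "b"}, {"b", "c"}))
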