def test_tail_mean_rank_stat_gather(rank_results):
    """Checks the unfiltered tail mean rank against the expected tail ranks."""
    tool = stats.ElementMeanRankStatTool(constants.TAIL_KEY, filtered=False)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(constants.TAIL_KEY, MEAN_RANK_FEATURE_KEY)
    expected_ranks = [7, 6, 8, 5, 4]
    expected_mean = sum(expected_ranks) / float(len(expected_ranks))
    assert expected_mean == pytest.approx(stat_values[stat_key])
def test_head_mean_rank_stat_gather(rank_results):
    """Checks the unfiltered head mean rank against the expected head ranks."""
    tool = stats.ElementMeanRankStatTool(constants.HEAD_KEY, filtered=False)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(constants.HEAD_KEY, MEAN_RANK_FEATURE_KEY)
    expected_ranks = [9, 3, 7, 11, 10]
    expected_mean = sum(expected_ranks) / float(len(expected_ranks))
    assert expected_mean == pytest.approx(stat_values[stat_key])
def test_head_filtered_mean_rank_stat_gather(rank_results):
    """Checks the filtered head mean rank against the expected filtered head ranks."""
    tool = stats.ElementMeanRankStatTool(constants.HEAD_KEY, filtered=True)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(constants.HEAD_KEY, MEAN_FILTERED_RANK_FEATURE_KEY)
    expected_ranks = [9, 2, 6, 8, 9]
    expected_mean = sum(expected_ranks) / float(len(expected_ranks))
    assert expected_mean == pytest.approx(stat_values[stat_key])
def test_tail_filtered_hit_stat_gather(rank_results):
    """Checks the filtered tail hits@5 ratio (3 of the 5 triples are expected to hit)."""
    level = 5
    tool = stats.ElementHitStatTool(constants.TAIL_KEY, level, filtered=True)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(
        constants.TAIL_KEY, FILTERED_HITS_FEATURE_PREFIX, level)
    expected_ratio = 3 / 5.0
    assert expected_ratio == pytest.approx(stat_values[stat_key])
def test_head_hit_stat_gather(rank_results):
    """Checks the unfiltered head hits@5 ratio (1 of the 5 triples is expected to hit)."""
    level = 5
    tool = stats.ElementHitStatTool(constants.HEAD_KEY, level, filtered=False)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(
        constants.HEAD_KEY, HITS_FEATURE_PREFIX, level)
    expected_ratio = 1 / 5.0
    assert expected_ratio == pytest.approx(stat_values[stat_key])
def test_tail_filtered_mean_rank_stat_gather(rank_results):
    """Checks the filtered tail mean rank against the expected filtered tail ranks."""
    # The unused `hit_level` local (copied from a hit test) was removed: a
    # mean-rank tool takes no hit level.
    gather = stats.StatGather(
        [stats.ElementMeanRankStatTool(constants.TAIL_KEY, filtered=True)])
    results = gather(rank_results)
    assert (6 + 6 + 3 + 5 + 1) / 5.0 == pytest.approx(
        results[stats.StatTool.gen_key(constants.TAIL_KEY,
                                       MEAN_FILTERED_RANK_FEATURE_KEY)])
def build_stat_gather_from_config(config, drawer=None):
    """Builds a StatGather based on an evaluation config.

    It looks into ``config.report_dimension`` first to decide which keys are
    reported: head and tail separately, the combined entity, and/or the
    relation. For every key it then adds the stat tools enabled in
    ``config.report_features``. Hit levels come from the comma-separated
    ``config.report_hits`` / ``config.report_hits_filtered`` strings. Finally
    it attaches the printing hook and the drawer hook when requested.

    Args:
        config: Evaluation config read for ``report_dimension``,
            ``report_features``, ``report_hits``, ``report_hits_filtered``,
            ``print_stats`` and ``plot_graph``.
        drawer: Optional object whose ``hook_after_stat_epoch()`` result is
            registered as an after-gather hook when ``config.plot_graph`` is
            truthy.

    Returns:
        The configured ``stats.StatGather``.

    Raises:
        ValueError: If a non-blank hit level entry is not an integer.
    """
    g = stats.StatGather()
    for key in _stat_report_keys(config.report_dimension):
        _add_stats_for_key(g, key, config)
    if config.print_stats:
        g.add_after_gather(stats.print_hook_after_stat_epoch())
    if config.plot_graph and drawer is not None:
        g.add_after_gather(drawer.hook_after_stat_epoch())
    return g


def _stat_report_keys(report_dimension):
    """Returns the keys selected by the ``report_dimension`` flags, in report order."""
    keys = []
    if report_dimension & StatisticsDimension.SEPERATE_ENTITY:
        keys.append(constants.HEAD_KEY)
        keys.append(constants.TAIL_KEY)
    if report_dimension & StatisticsDimension.COMBINED_ENTITY:
        keys.append(constants.ENTITY_KEY)
    if report_dimension & StatisticsDimension.RELATION:
        keys.append(constants.RELATION_KEY)
    return keys


def _parse_hit_levels(spec):
    """Parses a comma-separated hit-level string (e.g. "1,3,10") into ints.

    Blank entries (e.g. from a trailing comma or an empty string) are skipped
    rather than letting ``int('')`` raise.
    """
    return [int(part) for part in spec.split(',') if part.strip()]


def _add_stats_for_key(g, key, config):
    """Adds to ``g`` every stat tool enabled in ``config.report_features`` for ``key``."""
    # ENTITY_KEY uses the CombinedEntity* tools; every other key uses the
    # per-element Element* tools. The feature checks are otherwise identical.
    if key == constants.ENTITY_KEY:
        mrr_tool = stats.CombinedEntityMeanReciprocalRankStatTool
        mean_rank_tool = stats.CombinedEntityMeanRankStatTool
        hit_tool = stats.CombinedEntityHitStatTool
    else:
        mrr_tool = stats.ElementMeanReciprocalRankStatTool
        mean_rank_tool = stats.ElementMeanRankStatTool
        hit_tool = stats.ElementHitStatTool

    features = config.report_features
    if features & LinkPredictionStatistics.MEAN_RECIPROCAL_RANK:
        g.add_stat(mrr_tool(key, filtered=False))
    if features & LinkPredictionStatistics.MEAN_FILTERED_RECIPROCAL_RANK:
        g.add_stat(mrr_tool(key, filtered=True))
    if features & LinkPredictionStatistics.MEAN_RANK:
        g.add_stat(mean_rank_tool(key, filtered=False))
    if features & LinkPredictionStatistics.MEAN_FILTERED_RANK:
        g.add_stat(mean_rank_tool(key, filtered=True))
    # Hit-level strings are only read when the corresponding feature is on.
    if features & LinkPredictionStatistics.HITS:
        for hit in _parse_hit_levels(config.report_hits):
            g.add_stat(hit_tool(key, hit, filtered=False))
    if features & LinkPredictionStatistics.HITS_FILTERED:
        for hit in _parse_hit_levels(config.report_hits_filtered):
            g.add_stat(hit_tool(key, hit, filtered=True))
def test_tail_mean_reciprocal_rank_stat_gather(rank_results):
    """Checks the unfiltered tail MRR against the expected tail ranks."""
    tool = stats.ElementMeanReciprocalRankStatTool(constants.TAIL_KEY, filtered=False)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(
        constants.TAIL_KEY, MEAN_RECIPROCAL_RANK_FEATURE_KEY)
    expected_ranks = [7, 6, 8, 5, 4]
    expected_mrr = sum(1.0 / r for r in expected_ranks) / len(expected_ranks)
    assert expected_mrr == pytest.approx(stat_values[stat_key])
def test_combined_filtered_mean_rank_stat_gather(rank_results):
    """Checks the combined filtered mean rank: the average of the head and tail filtered means."""
    tool = stats.CombinedEntityMeanRankStatTool(constants.ENTITY_KEY, filtered=True)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(
        constants.ENTITY_KEY, MEAN_FILTERED_RANK_FEATURE_KEY)
    head_mean = (9 + 2 + 6 + 8 + 9) / 5.0
    tail_mean = (6 + 6 + 3 + 5 + 1) / 5.0
    expected = (head_mean + tail_mean) / 2.0
    assert expected == pytest.approx(stat_values[stat_key])
def test_combined_mean_rank_stat_gather(rank_results):
    """Checks the combined unfiltered mean rank: the average of the head and tail means."""
    tool = stats.CombinedEntityMeanRankStatTool(constants.ENTITY_KEY, filtered=False)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(constants.ENTITY_KEY, MEAN_RANK_FEATURE_KEY)
    head_mean = (9 + 3 + 7 + 11 + 10) / 5.0
    tail_mean = (7 + 6 + 8 + 5 + 4) / 5.0
    expected = (head_mean + tail_mean) / 2.0
    assert expected == pytest.approx(stat_values[stat_key])
def test_combined_mean_reciprocal_rank_stat_gather(rank_results):
    """Checks the combined unfiltered MRR: the average of the head and tail MRRs."""
    tool = stats.CombinedEntityMeanReciprocalRankStatTool(
        constants.ENTITY_KEY, filtered=False)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(
        constants.ENTITY_KEY, MEAN_RECIPROCAL_RANK_FEATURE_KEY)
    head_ranks = [9, 3, 7, 11, 10]
    tail_ranks = [7, 6, 8, 5, 4]
    head_mrr = sum(1.0 / r for r in head_ranks) / len(head_ranks)
    tail_mrr = sum(1.0 / r for r in tail_ranks) / len(tail_ranks)
    expected = (head_mrr + tail_mrr) / 2.0
    assert expected == pytest.approx(stat_values[stat_key])
def test_head_filtered_mean_reciprocal_rank_stat_gather(rank_results):
    """Checks the filtered head MRR against the expected filtered head ranks."""
    tool = stats.ElementMeanReciprocalRankStatTool(constants.HEAD_KEY, filtered=True)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(
        constants.HEAD_KEY, MEAN_FILTERED_RECIPROCAL_RANK_FEATURE_KEY)
    expected_ranks = [9, 2, 6, 8, 9]
    expected_mrr = sum(1.0 / r for r in expected_ranks) / len(expected_ranks)
    assert expected_mrr == pytest.approx(stat_values[stat_key])
def test_rel_hit_stat_gather(rank_results):
    """Checks the unfiltered relation hits@9 ratio (2 of 3 relation triples are expected to hit)."""
    level = 9
    tool = stats.ElementHitStatTool(constants.RELATION_KEY, level, filtered=False)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(
        constants.RELATION_KEY, HITS_FEATURE_PREFIX, level)
    expected_ratio = 2 / 3.0
    assert expected_ratio == pytest.approx(stat_values[stat_key])
def test_combined_hit_stat_gather(rank_results):
    """Checks the combined unfiltered hits@6: the average of the head and tail hit ratios."""
    level = 6
    tool = stats.CombinedEntityHitStatTool(constants.ENTITY_KEY, level, filtered=False)
    stat_values = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(
        constants.ENTITY_KEY, HITS_FEATURE_PREFIX, level)
    head_ratio = 1 / 5.0
    tail_ratio = 3 / 5.0
    expected = (head_ratio + tail_ratio) / 2.0
    assert expected == pytest.approx(stat_values[stat_key])
def test_tail_filtered_mean_reciprocal_rank_stat_gather(rank_results):
    """Checks the filtered tail MRR against the expected filtered tail ranks."""
    # The unused `hit_level` local (copied from a hit test) was removed: an
    # MRR tool takes no hit level.
    gather = stats.StatGather([
        stats.ElementMeanReciprocalRankStatTool(constants.TAIL_KEY, filtered=True)
    ])
    results = gather(rank_results)
    assert (1 / 6.0 + 1 / 6.0 + 1 / 3.0 + 1 / 5.0 + 1 / 1.0) / 5.0 == pytest.approx(
        results[stats.StatTool.gen_key(
            constants.TAIL_KEY, MEAN_FILTERED_RECIPROCAL_RANK_FEATURE_KEY)])
def combined_stat_gather():
    """Returns a StatGather holding every combined-entity tool, unfiltered then filtered.

    Tool order: hits@6, mean rank, mean reciprocal rank; each as an
    unfiltered/filtered pair.
    """
    entity = constants.ENTITY_KEY
    flags = (False, True)
    tools = [stats.CombinedEntityHitStatTool(entity, 6, filtered=flag) for flag in flags]
    tools += [stats.CombinedEntityMeanRankStatTool(entity, filtered=flag) for flag in flags]
    tools += [
        stats.CombinedEntityMeanReciprocalRankStatTool(entity, filtered=flag)
        for flag in flags
    ]
    return stats.StatGather(tools)
def sep_stat_gather():
    """Returns a StatGather of per-element tools for head, tail and relation.

    Hit tools come first (with the per-key levels listed below). They are
    followed by mean-rank and then mean-reciprocal-rank tools for head and
    tail, each as an unfiltered/filtered pair.
    """
    # (key, hit level, filtered) for each hit tool, in gather order.
    hit_specs = [
        (constants.HEAD_KEY, 5, False),
        (constants.HEAD_KEY, 8, True),
        (constants.TAIL_KEY, 5, False),
        (constants.TAIL_KEY, 5, True),
        (constants.RELATION_KEY, 9, False),
        (constants.RELATION_KEY, 5, True),
    ]
    tools = [
        stats.ElementHitStatTool(key, level, filtered=flag)
        for key, level, flag in hit_specs
    ]
    entity_keys = (constants.HEAD_KEY, constants.TAIL_KEY)
    flags = (False, True)
    tools.extend(
        stats.ElementMeanRankStatTool(key, filtered=flag)
        for key in entity_keys for flag in flags)
    tools.extend(
        stats.ElementMeanReciprocalRankStatTool(key, filtered=flag)
        for key in entity_keys for flag in flags)
    return stats.StatGather(tools)