def _parse_hit_levels(hits):
    """Parse a comma-separated string of hit levels into ints.

    Empty entries (e.g. from an empty string or a trailing comma) are skipped;
    surrounding whitespace is tolerated by ``int``.

    Raises:
        ValueError: if a non-empty entry is not an integer.
    """
    return [int(hit) for hit in hits.split(',') if hit.strip()]


def _add_key_stats(g, config, key, mrr_tool, mr_tool, hit_tool):
    """Register on ``g`` every stat tool for ``key`` enabled in ``config.report_features``.

    ``mrr_tool``/``mr_tool``/``hit_tool`` are the tool classes to instantiate
    (combined-entity or per-element variants). Tools are added in the order:
    MRR, filtered MRR, MR, filtered MR, hits, filtered hits.
    """
    features = config.report_features
    if features & LinkPredictionStatistics.MEAN_RECIPROCAL_RANK:
        g.add_stat(mrr_tool(key, filtered=False))
    if features & LinkPredictionStatistics.MEAN_FILTERED_RECIPROCAL_RANK:
        g.add_stat(mrr_tool(key, filtered=True))
    if features & LinkPredictionStatistics.MEAN_RANK:
        g.add_stat(mr_tool(key, filtered=False))
    if features & LinkPredictionStatistics.MEAN_FILTERED_RANK:
        g.add_stat(mr_tool(key, filtered=True))
    # Hit strings are only parsed when the corresponding feature is enabled,
    # so an unused/unset hits string is never touched.
    if features & LinkPredictionStatistics.HITS:
        for hit in _parse_hit_levels(config.report_hits):
            g.add_stat(hit_tool(key, hit, filtered=False))
    if features & LinkPredictionStatistics.HITS_FILTERED:
        for hit in _parse_hit_levels(config.report_hits_filtered):
            g.add_stat(hit_tool(key, hit, filtered=True))


def build_stat_gather_from_config(config, drawer=None):
    """Build a ``stats.StatGather`` from a configuration object.

    ``config.report_dimension`` selects which keys are evaluated:
    head and tail for ``SEPERATE_ENTITY``, the combined entity key for
    ``COMBINED_ENTITY`` and the relation key for ``RELATION``.

    For each key, one stat tool is added per feature flag set in
    ``config.report_features``. Hit levels come from the comma-separated
    ``config.report_hits`` / ``config.report_hits_filtered`` strings.

    Finally, a printing hook and/or the drawer's hook are registered as
    after-gather callbacks.

    Args:
        config: object exposing ``report_dimension``, ``report_features``,
            ``report_hits``, ``report_hits_filtered``, ``print_stats`` and
            ``plot_graph``.
        drawer: optional object providing ``hook_after_stat_epoch()``; used
            only when ``config.plot_graph`` is truthy.

    Returns:
        The configured ``stats.StatGather``.

    Raises:
        ValueError: if a hit level in an enabled hits string is not an integer.
    """
    g = stats.StatGather()

    keys = []
    if config.report_dimension & StatisticsDimension.SEPERATE_ENTITY:
        keys.append(constants.HEAD_KEY)
        keys.append(constants.TAIL_KEY)
    if config.report_dimension & StatisticsDimension.COMBINED_ENTITY:
        keys.append(constants.ENTITY_KEY)
    if config.report_dimension & StatisticsDimension.RELATION:
        keys.append(constants.RELATION_KEY)

    for key in keys:
        # The combined entity key has dedicated tools; head/tail/relation share
        # the per-element tools. The feature selection logic is identical.
        if key == constants.ENTITY_KEY:
            _add_key_stats(
                g, config, key,
                stats.CombinedEntityMeanReciprocalRankStatTool,
                stats.CombinedEntityMeanRankStatTool,
                stats.CombinedEntityHitStatTool,
            )
        else:
            _add_key_stats(
                g, config, key,
                stats.ElementMeanReciprocalRankStatTool,
                stats.ElementMeanRankStatTool,
                stats.ElementHitStatTool,
            )

    if config.print_stats:
        g.add_after_gather(stats.print_hook_after_stat_epoch())
    if config.plot_graph and drawer is not None:
        g.add_after_gather(drawer.hook_after_stat_epoch())
    return g
def test_tail_filtered_hit_stat_gather(rank_results):
    """Filtered hits@5 on the tail key over ``rank_results`` should be 3/5."""
    hit_level = 5
    gather = stats.StatGather([
        stats.ElementHitStatTool(constants.TAIL_KEY, hit_level, filtered=True)
    ])
    results = gather(rank_results)
    result_key = stats.StatTool.gen_key(
        constants.TAIL_KEY, FILTERED_HITS_FEATURE_PREFIX, hit_level)
    # approx wraps the expected value so the tolerance is relative to it.
    assert results[result_key] == pytest.approx(3 / 5.0)
def test_head_hit_stat_gather(rank_results):
    """Unfiltered hits@5 on the head key over ``rank_results`` should be 1/5."""
    hit_level = 5
    gather = stats.StatGather([
        stats.ElementHitStatTool(constants.HEAD_KEY, hit_level, filtered=False)
    ])
    results = gather(rank_results)
    result_key = stats.StatTool.gen_key(
        constants.HEAD_KEY, HITS_FEATURE_PREFIX, hit_level)
    # approx wraps the expected value so the tolerance is relative to it.
    assert results[result_key] == pytest.approx(1 / 5.0)
def test_rel_hit_stat_gather(rank_results):
    """Unfiltered hits@9 on the relation key over ``rank_results`` should be 2/3."""
    hit_level = 9
    gather = stats.StatGather([
        stats.ElementHitStatTool(constants.RELATION_KEY, hit_level, filtered=False)
    ])
    results = gather(rank_results)
    result_key = stats.StatTool.gen_key(
        constants.RELATION_KEY, HITS_FEATURE_PREFIX, hit_level)
    # approx wraps the expected value so the tolerance is relative to it.
    assert results[result_key] == pytest.approx(2 / 3.0)
def sep_stat_gather():
    """Return a StatGather with per-element tools for head, tail and relation.

    Contents:
      * hit tools, as listed in ``hit_specs`` below;
      * mean-rank and mean-reciprocal-rank tools for head and tail, each in
        unfiltered and filtered form.
    """
    # (key, hit level, filtered) for each hit tool, in registration order.
    hit_specs = [
        (constants.HEAD_KEY, 5, False),
        (constants.HEAD_KEY, 8, True),
        (constants.TAIL_KEY, 5, False),
        (constants.TAIL_KEY, 5, True),
        (constants.RELATION_KEY, 9, False),
        (constants.RELATION_KEY, 5, True),
    ]
    tools = [
        stats.ElementHitStatTool(key, level, filtered=is_filtered)
        for key, level, is_filtered in hit_specs
    ]

    # Rank tools: class-major, then key, then unfiltered before filtered.
    for rank_tool_cls in (stats.ElementMeanRankStatTool,
                          stats.ElementMeanReciprocalRankStatTool):
        for key in (constants.HEAD_KEY, constants.TAIL_KEY):
            for is_filtered in (False, True):
                tools.append(rank_tool_cls(key, filtered=is_filtered))

    return stats.StatGather(tools)