コード例 #1
0
ファイル: evaluation.py プロジェクト: vishalbelsare/kgegrok
def _parse_hit_levels(spec):
  """Parses a comma-separated hit-level string such as "1,3,10" into ints.

  Blank tokens (from an empty string, a trailing comma, or ",,") are skipped
  instead of raising ValueError from int(''). Whitespace around a token is
  tolerated because int() strips it.
  """
  return [int(token) for token in spec.split(',') if token.strip()]


def build_stat_gather_from_config(config, drawer=None):
  """Builds a stats.StatGather from the reporting options in `config`.

  Steps:
    1. `config.report_dimension` is tested against StatisticsDimension flags
       to choose the reported keys: head and tail separately
       (SEPERATE_ENTITY), the combined entity (COMBINED_ENTITY), and/or the
       relation (RELATION).
    2. For every chosen key, `config.report_features` is tested against
       LinkPredictionStatistics flags to add mean reciprocal rank, mean rank
       and hits@k tools, each in raw and/or filtered form. Hit levels are read
       from the comma-separated strings `config.report_hits` and
       `config.report_hits_filtered`.
    3. After-gather hooks are attached: the print hook when
       `config.print_stats` is truthy, and `drawer.hook_after_stat_epoch()`
       when `config.plot_graph` is truthy and a drawer was given.

  Args:
    config: configuration object exposing the attributes listed above.
    drawer: optional object providing `hook_after_stat_epoch()`; ignored when
      None.

  Returns:
    The configured stats.StatGather.

  Raises:
    ValueError: if a non-blank hit-level token is not an integer.
  """
  g = stats.StatGather()
  dimension = config.report_dimension
  features = config.report_features

  keys = []
  if dimension & StatisticsDimension.SEPERATE_ENTITY:
    keys.extend((constants.HEAD_KEY, constants.TAIL_KEY))
  if dimension & StatisticsDimension.COMBINED_ENTITY:
    keys.append(constants.ENTITY_KEY)
  if dimension & StatisticsDimension.RELATION:
    keys.append(constants.RELATION_KEY)

  for key in keys:
    # The combined-entity key has dedicated tool classes; head, tail and
    # relation keys share the per-element tools. The feature selection below
    # is the same for both families.
    if key == constants.ENTITY_KEY:
      mrr_tool = stats.CombinedEntityMeanReciprocalRankStatTool
      mr_tool = stats.CombinedEntityMeanRankStatTool
      hit_tool = stats.CombinedEntityHitStatTool
    else:
      mrr_tool = stats.ElementMeanReciprocalRankStatTool
      mr_tool = stats.ElementMeanRankStatTool
      hit_tool = stats.ElementHitStatTool

    if features & LinkPredictionStatistics.MEAN_RECIPROCAL_RANK:
      g.add_stat(mrr_tool(key, filtered=False))
    if features & LinkPredictionStatistics.MEAN_FILTERED_RECIPROCAL_RANK:
      g.add_stat(mrr_tool(key, filtered=True))
    if features & LinkPredictionStatistics.MEAN_RANK:
      g.add_stat(mr_tool(key, filtered=False))
    if features & LinkPredictionStatistics.MEAN_FILTERED_RANK:
      g.add_stat(mr_tool(key, filtered=True))
    # The hit strings are only read when the matching feature flag is set,
    # so configs that leave them unset keep working.
    if features & LinkPredictionStatistics.HITS:
      for hit in _parse_hit_levels(config.report_hits):
        g.add_stat(hit_tool(key, hit, filtered=False))
    if features & LinkPredictionStatistics.HITS_FILTERED:
      for hit in _parse_hit_levels(config.report_hits_filtered):
        g.add_stat(hit_tool(key, hit, filtered=True))

  if config.print_stats:
    g.add_after_gather(stats.print_hook_after_stat_epoch())
  if config.plot_graph and drawer is not None:
    g.add_after_gather(drawer.hook_after_stat_epoch())

  return g
コード例 #2
0
def test_tail_filtered_hit_stat_gather(rank_results):
    """Filtered hits@5 on the tail key is expected to be 3/5 for rank_results."""
    level = 5
    tool = stats.ElementHitStatTool(constants.TAIL_KEY, level, filtered=True)
    outcome = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(constants.TAIL_KEY,
                                      FILTERED_HITS_FEATURE_PREFIX, level)
    expected = 3 / 5.0
    assert expected == pytest.approx(outcome[stat_key])
コード例 #3
0
def test_head_hit_stat_gather(rank_results):
    """Unfiltered hits@5 on the head key is expected to be 1/5 for rank_results."""
    level = 5
    tool = stats.ElementHitStatTool(constants.HEAD_KEY, level, filtered=False)
    outcome = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(constants.HEAD_KEY, HITS_FEATURE_PREFIX,
                                      level)
    expected = 1 / 5.0
    assert expected == pytest.approx(outcome[stat_key])
コード例 #4
0
def test_rel_hit_stat_gather(rank_results):
    """Unfiltered hits@9 on the relation key is expected to be 2/3 for rank_results."""
    level = 9
    tool = stats.ElementHitStatTool(constants.RELATION_KEY, level,
                                    filtered=False)
    outcome = stats.StatGather([tool])(rank_results)
    stat_key = stats.StatTool.gen_key(constants.RELATION_KEY,
                                      HITS_FEATURE_PREFIX, level)
    expected = 2 / 3.0
    assert expected == pytest.approx(outcome[stat_key])
コード例 #5
0
def sep_stat_gather():
    """Returns a StatGather with per-element tools for separate head/tail keys.

    Contents, in order: six hit tools over head, tail and relation (raw and
    filtered, at the levels listed below), then mean-rank and
    mean-reciprocal-rank tools for head and tail, each raw then filtered.
    """
    head = constants.HEAD_KEY
    tail = constants.TAIL_KEY
    relation = constants.RELATION_KEY

    # (key, hit level, filtered) for each hit tool, in construction order.
    hit_specs = (
        (head, 5, False),
        (head, 8, True),
        (tail, 5, False),
        (tail, 5, True),
        (relation, 9, False),
        (relation, 5, True),
    )
    hit_tools = [
        stats.ElementHitStatTool(key, level, filtered=is_filtered)
        for key, level, is_filtered in hit_specs
    ]

    # Tool class outermost, then key, then raw-before-filtered.
    rank_tools = [
        tool_cls(key, filtered=is_filtered)
        for tool_cls in (stats.ElementMeanRankStatTool,
                         stats.ElementMeanReciprocalRankStatTool)
        for key in (head, tail)
        for is_filtered in (False, True)
    ]

    return stats.StatGather(hit_tools + rank_tools)