示例#1
0
def test_tail_mean_rank_stat_gather(rank_results):
    """Unfiltered tail mean rank from StatGather equals (7+6+8+5+4)/5."""
    tool = stats.ElementMeanRankStatTool(constants.TAIL_KEY, filtered=False)
    results = stats.StatGather([tool])(rank_results)

    result_key = stats.StatTool.gen_key(constants.TAIL_KEY,
                                        MEAN_RANK_FEATURE_KEY)
    # Expected ranks presumably mirror the rank_results fixture (not shown
    # here) -- confirm against the fixture when changing them.
    expected = (7 + 6 + 8 + 5 + 4) / 5.0
    assert expected == pytest.approx(results[result_key])
示例#2
0
def test_head_mean_rank_stat_gather(rank_results):
    """Unfiltered head mean rank from StatGather equals (9+3+7+11+10)/5."""
    tool = stats.ElementMeanRankStatTool(constants.HEAD_KEY, filtered=False)
    results = stats.StatGather([tool])(rank_results)

    result_key = stats.StatTool.gen_key(constants.HEAD_KEY,
                                        MEAN_RANK_FEATURE_KEY)
    # Expected ranks presumably mirror the rank_results fixture (not shown
    # here) -- confirm against the fixture when changing them.
    expected = (9 + 3 + 7 + 11 + 10) / 5.0
    assert expected == pytest.approx(results[result_key])
示例#3
0
def test_head_filtered_mean_rank_stat_gather(rank_results):
    """Filtered head mean rank from StatGather equals (9+2+6+8+9)/5."""
    tool = stats.ElementMeanRankStatTool(constants.HEAD_KEY, filtered=True)
    results = stats.StatGather([tool])(rank_results)

    result_key = stats.StatTool.gen_key(constants.HEAD_KEY,
                                        MEAN_FILTERED_RANK_FEATURE_KEY)
    # Expected ranks presumably mirror the rank_results fixture (not shown
    # here) -- confirm against the fixture when changing them.
    expected = (9 + 2 + 6 + 8 + 9) / 5.0
    assert expected == pytest.approx(results[result_key])
示例#4
0
def test_tail_filtered_hit_stat_gather(rank_results):
    """Filtered tail hit statistic at level 5 equals 3/5."""
    hit_level = 5
    tool = stats.ElementHitStatTool(constants.TAIL_KEY,
                                    hit_level,
                                    filtered=True)
    results = stats.StatGather([tool])(rank_results)

    # The hit level is part of the generated result key.
    result_key = stats.StatTool.gen_key(constants.TAIL_KEY,
                                        FILTERED_HITS_FEATURE_PREFIX,
                                        hit_level)
    expected = 3 / 5.0
    assert expected == pytest.approx(results[result_key])
示例#5
0
def test_head_hit_stat_gather(rank_results):
    """Unfiltered head hit statistic at level 5 equals 1/5."""
    hit_level = 5
    tool = stats.ElementHitStatTool(constants.HEAD_KEY,
                                    hit_level,
                                    filtered=False)
    results = stats.StatGather([tool])(rank_results)

    # The hit level is part of the generated result key.
    result_key = stats.StatTool.gen_key(constants.HEAD_KEY,
                                        HITS_FEATURE_PREFIX, hit_level)
    expected = 1 / 5.0
    assert expected == pytest.approx(results[result_key])
示例#6
0
def test_tail_filtered_mean_rank_stat_gather(rank_results):
    """Filtered tail mean rank from StatGather equals (6+6+3+5+1)/5.

    The mean-rank tool takes no hit level, so none is defined here.
    """
    gather = stats.StatGather(
        [stats.ElementMeanRankStatTool(constants.TAIL_KEY, filtered=True)])
    results = gather(rank_results)

    result_key = stats.StatTool.gen_key(constants.TAIL_KEY,
                                        MEAN_FILTERED_RANK_FEATURE_KEY)
    # Expected ranks presumably mirror the rank_results fixture (not shown
    # here) -- confirm against the fixture when changing them.
    assert (6 + 6 + 3 + 5 + 1) / 5.0 == pytest.approx(results[result_key])
示例#7
0
def build_stat_gather_from_config(config, drawer=None):
  """Assemble a stats.StatGather from the reporting options on config.

  config.report_dimension selects the keys to report: head and tail, the
  combined entity key, and/or the relation key. For every selected key,
  config.report_features selects which statistics are added. Hit statistics
  are added once per comma-separated level in config.report_hits
  (unfiltered) and config.report_hits_filtered (filtered).

  Finally, a printing hook is attached when config.print_stats is set, and
  the drawer's hook when config.plot_graph is set and a drawer was supplied.

  Args:
    config: object exposing report_dimension, report_features, report_hits,
      report_hits_filtered, print_stats and plot_graph.
    drawer: optional object providing hook_after_stat_epoch(); ignored when
      None.

  Returns:
    The configured stats.StatGather.
  """

  gather = stats.StatGather()

  dimension = config.report_dimension
  report_keys = []
  if dimension & StatisticsDimension.SEPERATE_ENTITY:
    report_keys.extend([constants.HEAD_KEY, constants.TAIL_KEY])
  if dimension & StatisticsDimension.COMBINED_ENTITY:
    report_keys.append(constants.ENTITY_KEY)
  if dimension & StatisticsDimension.RELATION:
    report_keys.append(constants.RELATION_KEY)

  for key in report_keys:
    # The combined entity key uses the CombinedEntity* tools; every other key
    # uses the per-element tools. Feature handling is otherwise identical.
    if key == constants.ENTITY_KEY:
      reciprocal_rank_tool = stats.CombinedEntityMeanReciprocalRankStatTool
      mean_rank_tool = stats.CombinedEntityMeanRankStatTool
      hit_tool = stats.CombinedEntityHitStatTool
    else:
      reciprocal_rank_tool = stats.ElementMeanReciprocalRankStatTool
      mean_rank_tool = stats.ElementMeanRankStatTool
      hit_tool = stats.ElementHitStatTool

    features = config.report_features
    if features & LinkPredictionStatistics.MEAN_RECIPROCAL_RANK:
      gather.add_stat(reciprocal_rank_tool(key, filtered=False))
    if features & LinkPredictionStatistics.MEAN_FILTERED_RECIPROCAL_RANK:
      gather.add_stat(reciprocal_rank_tool(key, filtered=True))
    if features & LinkPredictionStatistics.MEAN_RANK:
      gather.add_stat(mean_rank_tool(key, filtered=False))
    if features & LinkPredictionStatistics.MEAN_FILTERED_RANK:
      gather.add_stat(mean_rank_tool(key, filtered=True))
    if features & LinkPredictionStatistics.HITS:
      for level in config.report_hits.split(','):
        gather.add_stat(hit_tool(key, int(level), filtered=False))
    if features & LinkPredictionStatistics.HITS_FILTERED:
      for level in config.report_hits_filtered.split(','):
        gather.add_stat(hit_tool(key, int(level), filtered=True))

  if config.print_stats:
    gather.add_after_gather(stats.print_hook_after_stat_epoch())
  if config.plot_graph and drawer is not None:
    gather.add_after_gather(drawer.hook_after_stat_epoch())

  return gather
示例#8
0
def test_tail_mean_reciprocal_rank_stat_gather(rank_results):
    """Unfiltered tail mean reciprocal rank equals the mean of 1/rank terms."""
    tool = stats.ElementMeanReciprocalRankStatTool(constants.TAIL_KEY,
                                                   filtered=False)
    results = stats.StatGather([tool])(rank_results)

    result_key = stats.StatTool.gen_key(constants.TAIL_KEY,
                                        MEAN_RECIPROCAL_RANK_FEATURE_KEY)
    # Ranks 7, 6, 8, 5, 4 presumably mirror the rank_results fixture (not
    # shown here) -- confirm against the fixture when changing them.
    reciprocal_total = 1 / 7.0 + 1 / 6.0 + 1 / 8.0 + 1 / 5.0 + 1 / 4.0
    expected = reciprocal_total / 5.0
    assert expected == pytest.approx(results[result_key])
示例#9
0
def test_combined_filtered_mean_rank_stat_gather(rank_results):
    """Filtered combined-entity mean rank equals the average of two means."""
    tool = stats.CombinedEntityMeanRankStatTool(constants.ENTITY_KEY,
                                                filtered=True)
    results = stats.StatGather([tool])(rank_results)

    result_key = stats.StatTool.gen_key(constants.ENTITY_KEY,
                                        MEAN_FILTERED_RANK_FEATURE_KEY)
    # The two rank groups presumably come from the rank_results fixture (not
    # shown here) -- confirm against the fixture when changing them.
    first_mean = (9 + 2 + 6 + 8 + 9) / 5.0
    second_mean = (6 + 6 + 3 + 5 + 1) / 5.0
    expected = (first_mean + second_mean) / 2.0
    assert expected == pytest.approx(results[result_key])
示例#10
0
def test_combined_mean_rank_stat_gather(rank_results):
    """Unfiltered combined-entity mean rank equals the average of two means."""
    tool = stats.CombinedEntityMeanRankStatTool(constants.ENTITY_KEY,
                                                filtered=False)
    results = stats.StatGather([tool])(rank_results)

    result_key = stats.StatTool.gen_key(constants.ENTITY_KEY,
                                        MEAN_RANK_FEATURE_KEY)
    # The two rank groups presumably come from the rank_results fixture (not
    # shown here) -- confirm against the fixture when changing them.
    first_mean = (9 + 3 + 7 + 11 + 10) / 5.0
    second_mean = (7 + 6 + 8 + 5 + 4) / 5.0
    expected = (first_mean + second_mean) / 2.0
    assert expected == pytest.approx(results[result_key])
示例#11
0
def test_combined_mean_reciprocal_rank_stat_gather(rank_results):
    """Unfiltered combined-entity MRR equals the average of two group MRRs."""
    tool = stats.CombinedEntityMeanReciprocalRankStatTool(
        constants.ENTITY_KEY, filtered=False)
    results = stats.StatGather([tool])(rank_results)

    result_key = stats.StatTool.gen_key(constants.ENTITY_KEY,
                                        MEAN_RECIPROCAL_RANK_FEATURE_KEY)
    # The two rank groups presumably come from the rank_results fixture (not
    # shown here) -- confirm against the fixture when changing them.
    first_mrr = (1 / 9.0 + 1 / 3.0 + 1 / 7.0 + 1 / 11.0 + 1 / 10.0) / 5.0
    second_mrr = (1 / 7.0 + 1 / 6.0 + 1 / 8.0 + 1 / 5.0 + 1 / 4.0) / 5.0
    expected = (first_mrr + second_mrr) / 2.0
    assert expected == pytest.approx(results[result_key])
示例#12
0
def test_head_filtered_mean_reciprocal_rank_stat_gather(rank_results):
    """Filtered head mean reciprocal rank equals the mean of 1/rank terms."""
    tool = stats.ElementMeanReciprocalRankStatTool(constants.HEAD_KEY,
                                                   filtered=True)
    results = stats.StatGather([tool])(rank_results)

    result_key = stats.StatTool.gen_key(
        constants.HEAD_KEY, MEAN_FILTERED_RECIPROCAL_RANK_FEATURE_KEY)
    # Ranks 9, 2, 6, 8, 9 presumably mirror the rank_results fixture (not
    # shown here) -- confirm against the fixture when changing them.
    reciprocal_total = 1 / 9.0 + 1 / 2.0 + 1 / 6.0 + 1 / 8.0 + 1 / 9.0
    expected = reciprocal_total / 5.0
    assert expected == pytest.approx(results[result_key])
示例#13
0
def test_rel_hit_stat_gather(rank_results):
    """Unfiltered relation hit statistic at level 9 equals 2/3."""
    hit_level = 9
    tool = stats.ElementHitStatTool(constants.RELATION_KEY,
                                    hit_level,
                                    filtered=False)
    results = stats.StatGather([tool])(rank_results)

    # The hit level is part of the generated result key.
    result_key = stats.StatTool.gen_key(constants.RELATION_KEY,
                                        HITS_FEATURE_PREFIX, hit_level)
    expected = 2 / 3.0
    assert expected == pytest.approx(results[result_key])
示例#14
0
def test_combined_hit_stat_gather(rank_results):
    """Unfiltered combined-entity hit statistic at level 6 is (1/5 + 3/5)/2."""
    hit_level = 6
    tool = stats.CombinedEntityHitStatTool(constants.ENTITY_KEY,
                                           hit_level,
                                           filtered=False)
    results = stats.StatGather([tool])(rank_results)

    # The hit level is part of the generated result key.
    result_key = stats.StatTool.gen_key(constants.ENTITY_KEY,
                                        HITS_FEATURE_PREFIX, hit_level)
    # The expectation averages two hit ratios, 1/5 and 3/5.
    expected = (1 / 5.0 + 3 / 5.0) / 2.0
    assert expected == pytest.approx(results[result_key])
示例#15
0
def test_tail_filtered_mean_reciprocal_rank_stat_gather(rank_results):
    """Filtered tail mean reciprocal rank equals the mean of 1/rank terms.

    The reciprocal-rank tool takes no hit level, so none is defined here.
    """
    gather = stats.StatGather([
        stats.ElementMeanReciprocalRankStatTool(constants.TAIL_KEY,
                                                filtered=True)
    ])
    results = gather(rank_results)

    result_key = stats.StatTool.gen_key(
        constants.TAIL_KEY, MEAN_FILTERED_RECIPROCAL_RANK_FEATURE_KEY)
    # Ranks 6, 6, 3, 5, 1 presumably mirror the rank_results fixture (not
    # shown here) -- confirm against the fixture when changing them.
    assert (1 / 6.0 + 1 / 6.0 + 1 / 3.0 + 1 / 5.0 +
            1 / 1.0) / 5.0 == pytest.approx(results[result_key])
示例#16
0
def combined_stat_gather():
    """Return a StatGather holding every combined-entity stat tool.

    Each tool (hit at level 6, mean rank, mean reciprocal rank) is added
    twice: first unfiltered, then filtered.
    """
    entity_key = constants.ENTITY_KEY
    tools = [
        stats.CombinedEntityHitStatTool(entity_key, 6, filtered=False),
        stats.CombinedEntityHitStatTool(entity_key, 6, filtered=True),
        stats.CombinedEntityMeanRankStatTool(entity_key, filtered=False),
        stats.CombinedEntityMeanRankStatTool(entity_key, filtered=True),
        stats.CombinedEntityMeanReciprocalRankStatTool(entity_key,
                                                       filtered=False),
        stats.CombinedEntityMeanReciprocalRankStatTool(entity_key,
                                                       filtered=True),
    ]
    return stats.StatGather(tools)
示例#17
0
def sep_stat_gather():
    """Return a StatGather holding the per-element (separate) stat tools.

    Tools are created in this order: hit tools for head, tail and relation,
    then mean-rank tools for head and tail, then mean-reciprocal-rank tools
    for head and tail. Within each key the unfiltered tool precedes the
    filtered one.
    """
    head, tail, relation = (constants.HEAD_KEY, constants.TAIL_KEY,
                            constants.RELATION_KEY)

    # (key, hit level, filtered) for every hit tool, in creation order.
    hit_specs = [
        (head, 5, False),
        (head, 8, True),
        (tail, 5, False),
        (tail, 5, True),
        (relation, 9, False),
        (relation, 5, True),
    ]
    # (key, filtered) pairs shared by the two rank-based tool families.
    rank_specs = [(head, False), (head, True), (tail, False), (tail, True)]

    tools = [
        stats.ElementHitStatTool(key, level, filtered=is_filtered)
        for key, level, is_filtered in hit_specs
    ]
    tools.extend(
        stats.ElementMeanRankStatTool(key, filtered=is_filtered)
        for key, is_filtered in rank_specs)
    tools.extend(
        stats.ElementMeanReciprocalRankStatTool(key, filtered=is_filtered)
        for key, is_filtered in rank_specs)
    return stats.StatGather(tools)