Exemplo n.º 1
0
def eval_stats_to_thrift(template_stats, task_id):
    """Build a ``task_if.EvalData`` record for one task.

    Args:
        template_stats: Mapping with the keys ``'status_counts'``,
            ``'flags'``, ``'solutions'`` and ``'unstable_solutions'``; each
            value is a mapping from tier to per-task data.
        task_id: Key of the task inside the per-tier mappings.

    Returns:
        ``task_if.EvalData`` constructed from keyword arguments:
        ``attempts_to_solve_<tier>`` and ``flag_<tier>`` for every tier
        present, plus ``known_solutions``.
    """
    # Ordered by priority: the first flag found in a task's flags wins.
    code_by_priority = (
        (Flags.GOOD_STABLE, 'GS'),
        (Flags.GOOD, 'G'),
        (Flags.BAD_STABLE, 'B'),
        (Flags.BAD, 'B'),
        (Flags.IMPOSSIBLE, 'IMP'),
    )

    def find_flag_code(flags):
        # Yields None when none of the listed flags is present.
        return next(
            (code for flag, code in code_by_priority if flag in flags), None)

    fields = {}

    for tier, per_task in template_stats['status_counts'].items():
        counts = per_task[task_id]
        failures = counts[NOT_SOLVED]
        successes = counts[SOLVED]
        # Total tries per success, truncated; -1 marks "never solved".
        fields[f'attempts_to_solve_{tier}'] = (
            int((failures + successes) / successes) if successes > 0 else -1)

    for tier, per_task in template_stats['flags'].items():
        fields[f'flag_{tier}'] = find_flag_code(per_task[task_id])

    # Stable solution codes come first, then unstable ones suffixed with 'U'.
    known = [
        TIER_TO_CODE[tier]
        for tier, per_task in template_stats['solutions'].items()
        if per_task.get(task_id)
    ]
    known += [
        TIER_TO_CODE[tier] + 'U'
        for tier, per_task in template_stats['unstable_solutions'].items()
        if per_task.get(task_id)
    ]
    fields['known_solutions'] = known

    return task_if.EvalData(**fields)
Exemplo n.º 2
0
    def load_evaluation_data(self, task_id_pattern):
        """Collect evaluation data for known tasks and their templates.

        Per-task entries are keyed by task id and built with
        ``eval_stats_to_thrift``. Per-template summaries are keyed by
        ``'<template_id>:'``, where the template id is the part of a task id
        before the first ``':'``. A summary is only produced for templates
        that have at least one task flagged ``Flags.GOOD_STABLE`` in some
        tier.

        Args:
            task_id_pattern: Key prefix to filter on. When falsy, only the
                per-template summaries (keys ending in ``':'``) are returned.

        Returns:
            Dict mapping task ids / template keys to ``task_if.EvalData``.
        """
        known_task_ids = frozenset(self.task_cache)

        # Exact number of known tasks per template id.
        tasks_in_templates = collections.Counter(
            task_id.split(':')[0] for task_id in known_task_ids)

        all_data = {}
        # template id -> Counter(tier -> number of GOOD_STABLE tasks).
        solved_in_template = collections.defaultdict(collections.Counter)
        for template_stats in self.eval_stats.values():
            for tier, tier_data in template_stats['flags'].items():
                for task_id, flags in tier_data.items():
                    if task_id not in known_task_ids:
                        continue
                    if Flags.GOOD_STABLE in flags:
                        solved_in_template[task_id.split(':')[0]][tier] += 1
                    # A task shows up once per tier; convert it only once.
                    if task_id not in all_data:
                        all_data[task_id] = eval_stats_to_thrift(
                            template_stats, task_id)

        for template_id, counts in solved_in_template.items():
            # Use the exact per-template count rather than a prefix match on
            # the task id: a prefix match overcounts when one template id is
            # a prefix of another, and rescans every task for each template.
            # Every template id here comes from a known task, so this is >= 1.
            num_tasks = tasks_in_templates[template_id]
            percent = {
                tier: int(counts[tier] * 100 / num_tasks)
                for tier in ('ball', 'two_balls', 'ramp')
            }
            all_data[template_id + ':'] = task_if.EvalData(
                percent_ball=percent['ball'],
                percent_two_balls=percent['two_balls'],
                percent_ramp=percent['ramp'],
                num_tasks=num_tasks,
            )

        if task_id_pattern:
            return {
                key: value
                for key, value in all_data.items()
                if key.startswith(task_id_pattern)
            }
        # No pattern given: return only the per-template summaries.
        return {
            key: value for key, value in all_data.items() if key.endswith(':')
        }