def eval_stats_to_thrift(template_stats, task_id):
    """Build a ``task_if.EvalData`` record summarizing one task's eval stats.

    Args:
        template_stats: Evaluation statistics dict. The keys read are
            ``'status_counts'``, ``'flags'``, ``'solutions'`` and
            ``'unstable_solutions'``, each mapping a tier name to per-task
            data.
        task_id: Id of the task to summarize.

    Returns:
        ``task_if.EvalData`` built from keyword arguments:

        * ``attempts_to_solve_<tier>``: total attempts divided by successful
          attempts (truncated to int), or -1 if the task was never solved in
          that tier.
        * ``flag_<tier>``: short code of the highest-priority flag present
          (see ``flags_order``), or None if none of the known flags is set.
        * ``known_solutions``: tier codes (via ``TIER_TO_CODE``) of tiers with
          a stored solution, followed by tier codes suffixed with ``'U'`` for
          tiers with a stored unstable solution.
    """
    # Checked in priority order: the first flag present determines the code.
    # Each flag needs a distinct code, otherwise "stable" and plain variants
    # collapse together (GOOD_STABLE/GOOD -> GS/G, BAD_STABLE/BAD -> BS/B).
    flags_order = (
        (Flags.GOOD_STABLE, 'GS'),
        (Flags.GOOD, 'G'),
        (Flags.BAD_STABLE, 'BS'),
        (Flags.BAD, 'B'),
        (Flags.IMPOSSIBLE, 'IMP'),
    )

    def find_flag_code(flags):
        """Return the code of the first ``flags_order`` flag in ``flags``."""
        for flag, code in flags_order:
            if flag in flags:
                return code
        return None

    thrift_eval_data = {}

    for tier, per_task in template_stats['status_counts'].items():
        not_solved = per_task[task_id][NOT_SOLVED]
        solved = per_task[task_id][SOLVED]
        if solved > 0:
            # Average number of attempts needed per success.
            attempts = int((not_solved + solved) / solved)
        else:
            attempts = -1
        thrift_eval_data[f'attempts_to_solve_{tier}'] = attempts

    for tier, per_task in template_stats['flags'].items():
        thrift_eval_data[f'flag_{tier}'] = find_flag_code(per_task[task_id])

    solutions_codes = [
        TIER_TO_CODE[tier]
        for tier, per_task in template_stats['solutions'].items()
        if per_task.get(task_id)
    ]
    # Unstable solutions are listed after stable ones, marked with 'U'.
    solutions_codes.extend(
        TIER_TO_CODE[tier] + 'U'
        for tier, per_task in template_stats['unstable_solutions'].items()
        if per_task.get(task_id))
    thrift_eval_data['known_solutions'] = solutions_codes

    return task_if.EvalData(**thrift_eval_data)
def load_evaluation_data(self, task_id_pattern):
    """Collect per-task and per-template evaluation summaries.

    Reads ``self.task_cache`` (iterated to obtain the known task ids; the
    part of an id before the first ``':'`` is treated as its template id)
    and ``self.eval_stats`` (whose values are stats dicts with a ``'flags'``
    mapping of tier -> task_id -> flags). Task ids not in ``task_cache`` are
    ignored.

    Args:
        task_id_pattern: If non-empty, only entries whose key starts with
            this prefix are returned (task entries and ``'<template_id>:'``
            template entries alike). If empty, only the template entries
            are returned.

    Returns:
        Dict mapping task ids to ``eval_stats_to_thrift`` records, and
        ``'<template_id>:'`` to a ``task_if.EvalData`` holding, for the
        ``'ball'``, ``'two_balls'`` and ``'ramp'`` tiers, the truncated
        percentage of the template's known tasks flagged ``GOOD_STABLE``,
        plus the template's known task count. Template entries are produced
        only for templates with at least one ``GOOD_STABLE`` task.
    """
    known_task_ids = frozenset(self.task_cache)
    # Number of known tasks per template id.
    tasks_in_templates = collections.Counter(
        task_id.split(':')[0] for task_id in known_task_ids)

    all_data = {}
    # template_id -> tier -> number of tasks flagged GOOD_STABLE in that tier.
    solved_in_template = collections.defaultdict(collections.Counter)
    for template_stats in self.eval_stats.values():
        for tier, tier_data in template_stats['flags'].items():
            for task_id, flags in tier_data.items():
                if task_id not in known_task_ids:
                    continue
                if Flags.GOOD_STABLE in flags:
                    solved_in_template[task_id.split(':')[0]][tier] += 1
                if task_id not in all_data:
                    all_data[task_id] = eval_stats_to_thrift(
                        template_stats, task_id)

    for template_id, counts in solved_in_template.items():
        # Use the exact per-template count rather than a prefix scan:
        # ``task_id.startswith(template_id)`` also matched tasks of other
        # templates whose id merely begins with this one, and rescanned all
        # task ids for every template. This is also the value reported as
        # ``num_tasks``, so the percentages and the count now agree.
        # Always >= 1, since template_id was derived from a known task id.
        num_tasks = tasks_in_templates[template_id]
        all_data[template_id + ':'] = task_if.EvalData(
            percent_ball=int(counts['ball'] * 100 / num_tasks),
            percent_two_balls=int(counts['two_balls'] * 100 / num_tasks),
            percent_ramp=int(counts['ramp'] * 100 / num_tasks),
            num_tasks=num_tasks,
        )

    if task_id_pattern:
        return {
            key: value
            for key, value in all_data.items()
            if key.startswith(task_id_pattern)
        }
    # Without a pattern, only the per-template summaries are returned.
    return {key: value for key, value in all_data.items() if key.endswith(':')}