def load_model(model_dir, model_type, step=None):
    """Restore a saved model and construct its eval dataset and executor.

    Args:
        model_dir: Directory containing the saved model and its args.
        model_type: Identifier of the model architecture to restore.
        step: Optional checkpoint step; when None the saved default is used.

    Returns:
        A tuple ``(model, eval_dataset, the_executor)`` where the model is
        already switched to eval mode.
    """
    # Rebuild the argument namespace from what was saved at train time.
    restored_args = saver.ArgsDict(
        model_dir=model_dir, model_type=model_type, step=step)
    saver.restore_args(restored_args)
    arguments.backport_default_args(restored_args)

    dataset.set_vocab(restored_args)
    model = models.get_model(restored_args)
    eval_dataset = dataset.get_eval_dataset(restored_args, model)
    model.model.eval()

    # get_executor returns a class; instantiate it immediately.
    the_executor = executor.get_executor(restored_args)()
    return model, eval_dataset, the_executor
def train_start(args):
    """Set up everything needed to begin a training run.

    Builds the vocab, model, train/dev datasets and batch sampler, switches
    the model to train mode, and persists ``args`` for later restoration.

    Args:
        args: Parsed argument namespace for this run.

    Returns:
        A tuple ``(train_data, dev_data, model, sampler)``.
    """
    print("\tModel type: %s\n\tModel path: %s" % (args.model_type, args.model_dir))

    # Vocab must be set before the model is constructed.
    dataset.set_vocab(args)
    model = models.get_model(args)
    model.model.train()

    train_data = dataset.get_train_dataset(args, model, for_eval=False)
    dev_data = dataset.get_eval_dataset(args, model)
    # Shuffle the dev set so periodic eval sees varied batches.
    dev_data.shuffle = True

    sampler = get_sampler(train_data, args)
    # Persist args so evaluate()/infer() can restore the same configuration.
    saver.save_args(args)
    return train_data, dev_data, model, sampler
def evaluate(args):
    """Evaluate a restored model on a chosen dataset split via run_eval.

    The split is selected by ``args.eval_final`` / ``args.eval_train``
    (default: the standard eval set).  Raises ValueError if the restored
    model has never been trained.

    Args:
        args: Parsed argument namespace; must identify a trained model.
    """
    print("Evaluation:")
    print("\tModel type: %s\n\tModel path: %s" % (args.model_type, args.model_dir))

    saver.restore_args(args)
    arguments.backport_default_args(args)
    dataset.set_vocab(args)
    model = models.get_model(args)

    # Pick the dataset split to score.
    if args.eval_final:
        chosen_dataset = dataset.get_eval_final_dataset(args, model)
    elif args.eval_train:
        chosen_dataset = dataset.get_train_dataset(args, model, for_eval=True)
    else:
        chosen_dataset = dataset.get_eval_dataset(args, model)

    if model.last_step == 0:
        raise ValueError('Attempting to evaluate on untrained model')

    model.model.eval()
    the_executor = executor.get_executor(args)()

    # Restrict evaluation to a single example when requested.
    # NOTE(review): reads from `.task` but writes `.data` — presumably the
    # dataset exposes both; confirm against the dataset class.
    if args.example_id is not None:
        chosen_dataset.data = [chosen_dataset.task[args.example_id]]

    evaluation.run_eval(
        args.tag, chosen_dataset, model.inference,
        the_executor.execute, not args.hide_example_info,
        args.report_path)
def infer(args):
    """Run inference over a dataset split and write pickled outputs to disk.

    Writes each processed inference output to ``args.infer_output`` with
    ``pickle``, and appends the byte offset of each record (little-endian
    uint64) to ``args.infer_output + '.index'`` so records can be seeked
    later.  Stops early once ``args.infer_limit`` outputs are written.

    Args:
        args: Parsed argument namespace; must identify a trained model.

    Raises:
        ValueError: If the restored model has never been trained.
    """
    print("\tModel type: %s\n\tModel path: %s" % (args.model_type, args.model_dir))
    saver.restore_args(args)
    arguments.backport_default_args(args)
    dataset.set_vocab(args)
    m = models.get_model(args)
    if m.last_step == 0:
        raise ValueError('Attempting to evaluate on untrained model')

    # Pick the dataset split to run on.
    if args.eval_final:
        eval_dataset = dataset.get_eval_final_dataset(args, m)
    elif args.eval_train:
        eval_dataset = dataset.get_train_dataset(args, m, for_eval=True)
    else:
        eval_dataset = dataset.get_eval_dataset(args, m)
    m.model.eval()

    infer_counters = collections.Counter()
    num_outputs = 0
    # BUG FIX: pickle.dump and struct.pack produce bytes, so both files must
    # be opened in binary mode ('wb'), not text mode ('w').  Context managers
    # also guarantee the files are flushed and closed on the early return
    # below (the original leaked both handles).
    with open(args.infer_output, 'wb') as f, \
            open(args.infer_output + '.index', 'wb') as index_f:
        iterator = tqdm.tqdm(eval_dataset, dynamic_ncols=True)
        for batch in iterator:
            infer_results = m.inference(batch)
            infer_outputs = m.process_infer_results(
                batch, infer_results, infer_counters)
            for output in infer_outputs:
                # Record where this pickle record starts so it can be
                # random-accessed via the index file.
                offset = f.tell()
                pickle.dump(output, f, pickle.HIGHEST_PROTOCOL)
                index_f.write(struct.pack('<Q', offset))
                num_outputs += 1
                if args.infer_limit and num_outputs >= args.infer_limit:
                    return
            iterator.set_postfix(**infer_counters)
def evaluate(args):
    """Beam-search evaluation reporting exact/semantic/generalization rates.

    For each example, scans the beam outputs in rank order and records the
    first rank at which:
      * the tokens exactly equal the reference program ("exact"),
      * the program passes all of the example's ``input_tests`` ("semantic"),
      * the program additionally passes all held-out ``tests``
        ("generalization").
    Cumulative top-k counts and per-example ranks are written as JSON to
    ``args.report_path``.

    Args:
        args: Parsed argument namespace; must identify a trained model.

    Raises:
        ValueError: If the restored model has never been trained.
    """
    print("Evaluation:")
    print("\tModel type: %s\n\tModel path: %s" % (args.model_type, args.model_dir))
    saver.restore_args(args)
    arguments.backport_default_args(args)
    dataset.set_vocab(args)
    m = models.get_model(args)

    # Pick the dataset split to score.
    if args.eval_final:
        eval_dataset = dataset.get_eval_final_dataset(args, m)
    elif args.eval_train:
        eval_dataset = dataset.get_train_dataset(args, m, for_eval=True)
    else:
        eval_dataset = dataset.get_eval_dataset(args, m)
    if m.last_step == 0:
        raise ValueError('Attempting to evaluate on untrained model')
    m.model.eval()
    the_executor = executor.get_executor(args)()

    # top_k_*[k] = number of examples matched within the top k+1 beams.
    top_k_exact = np.zeros(args.max_beam_trees, dtype=int)
    top_k_sem = np.zeros(args.max_beam_trees, dtype=int)
    top_k_gen = np.zeros(args.max_beam_trees, dtype=int)
    exact_ranks = []
    sem_ranks = []
    gen_ranks = []
    outputs = []
    total = 0.0

    iterator = tqdm.tqdm(eval_dataset, dynamic_ncols=True)
    for batch in iterator:
        sequences = m.inference(batch, filtered=False)
        for batch_idx, beams in enumerate(sequences):
            total += 1
            orig_example = batch.orig_examples[batch_idx]
            exact_found, sem_found, gen_found = False, False, False
            exact_rank, sem_rank, gen_rank = None, None, None
            for rank, tokens in enumerate(beams):
                # Exact match against the reference program.
                # NOTE(review): assumes the example carries either
                # `code_sequence` or `goal_code`; if neither exists,
                # tuple(None) below would raise — confirm upstream.
                ref_code = getattr(orig_example, 'code_sequence',
                                   getattr(orig_example, 'goal_code', None))
                if not exact_found and tuple(tokens) == tuple(ref_code):
                    top_k_exact[rank:] += 1
                    exact_found = True
                    exact_rank = rank
                # BUG FIX: the original guard was
                # `if not sem_found or not exact_found:` which stopped
                # executing tests once exact+semantic matches were found,
                # so a program at a *later* rank that generalizes was never
                # recorded.  Keep executing until both the semantic and
                # generalization matches have been found.
                if not sem_found or not gen_found:
                    # Semantic match (passes all input tests)
                    input_tests_eval = executor.evaluate_code(
                        tokens, None, orig_example.input_tests,
                        the_executor.execute)
                    sem_match = input_tests_eval['correct'] == len(
                        orig_example.input_tests)
                    if not sem_found and sem_match:
                        top_k_sem[rank:] += 1
                        sem_found = True
                        sem_rank = rank
                    # Generalization (passes all input tests + other tests)
                    tests_eval = executor.evaluate_code(
                        tokens, None, orig_example.tests,
                        the_executor.execute)
                    gen = sem_match and tests_eval['correct'] == len(
                        orig_example.tests)
                    if not gen_found and gen:
                        top_k_gen[rank:] += 1
                        gen_found = True
                        gen_rank = rank
                # Nothing left to discover for this example.
                if exact_found and sem_found and gen_found:
                    break
            exact_ranks.append(exact_rank)
            sem_ranks.append(sem_rank)
            gen_ranks.append(gen_rank)
            if args.save_beam_outputs:
                outputs.append(beams)
        iterator.set_postfix(
            exact=top_k_exact[0] / total,
            sem=top_k_sem[0] / total,
            gen=top_k_gen[0] / total)

    with open(args.report_path, 'w') as f:
        json.dump(
            {
                # Total number of programs in this report.
                'total': total,
                # List where the Nth entry contains the number of programs with
                # exact match among the top N beam search outputs.
                # Length = args.max_beam_trees.
                'exact': top_k_exact.tolist(),
                # List where the Nth entry contains the number of programs with
                # semantic match (correct on all `input_tests`, given to the model)
                # among the top N beam search outputs.
                # Length = args.max_beam_trees.
                'semantic': top_k_sem.tolist(),
                # List where the Nth entry contains the number of programs with
                # generalization (correct on all `input_tests` and `tests`) among
                # the top N beam search outputs.
                # Length = args.max_beam_trees.
                'generalization': top_k_gen.tolist(),
                # For each program, the rank at which the corresponding type of
                # match was found (None if not applicable to any rank).
                'ranks': {
                    'exact': exact_ranks,
                    'semantic': sem_ranks,
                    'generalization': gen_ranks,
                },
                # List of length `total` where each item is a list containing
                # `args.max_beam_trees` programs, as output by the beam search.
                # Can be empty if this output was not saved.
                'beam_outputs': outputs,
            }, f)