Example #1
def load_model(model_dir, model_type, step=None):
    args = saver.ArgsDict(model_dir=model_dir, model_type=model_type, step=step)
    saver.restore_args(args)
    arguments.backport_default_args(args)
    dataset.set_vocab(args)
    m = models.get_model(args)
    eval_dataset = dataset.get_eval_dataset(args, m)
    m.model.eval()
    the_executor = executor.get_executor(args)()
    return m, eval_dataset, the_executor
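
The modules used above (saver, arguments, dataset, models, executor) belong to the surrounding project, so this snippet is not runnable on its own. As a hedged, standalone sketch of the same restore-then-eval pattern in plain PyTorch (MyModel and checkpoint_path are illustrative placeholders, not part of the project):

# Minimal sketch, assuming plain PyTorch; MyModel and checkpoint_path stand in
# for models.get_model and the checkpoints handled by saver.
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

def load_for_eval(checkpoint_path):
    model = MyModel()
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state)
    model.eval()  # disable dropout/batch-norm updates, as load_model() does
    return model
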
Example #2
def train_start(args):
    print("\tModel type: %s\n\tModel path: %s" %
          (args.model_type, args.model_dir))
    dataset.set_vocab(args)
    m = models.get_model(args)
    m.model.train()
    train_data = dataset.get_train_dataset(args, m, for_eval=False)
    dev_data = dataset.get_eval_dataset(args, m)
    dev_data.shuffle = True
    sampler = get_sampler(train_data, args)
    saver.save_args(args)
    return train_data, dev_data, m, sampler
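
train_start() is the training-time counterpart: it switches the model to train mode, builds the training and dev datasets, and wraps the training data in a sampler from the project-specific get_sampler. A minimal, hypothetical sketch of the equivalent setup in plain PyTorch (the TensorDataset and RandomSampler below are stand-ins for the project's own dataset and sampler):

# Sketch only; TensorDataset/RandomSampler stand in for get_train_dataset/get_sampler.
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

def train_setup(model):
    model.train()  # enable dropout/batch-norm updates, as train_start() does
    data = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
    sampler = RandomSampler(data)             # reshuffles examples each epoch
    loader = DataLoader(data, batch_size=8, sampler=sampler)
    return loader
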
Example #3
def evaluate(args):
    print("Evaluation:")
    print("\tModel type: %s\n\tModel path: %s" % (args.model_type, args.model_dir))
    saver.restore_args(args)
    arguments.backport_default_args(args)
    dataset.set_vocab(args)
    m = models.get_model(args)
    if args.eval_final:
        eval_dataset = dataset.get_eval_final_dataset(args, m)
    elif args.eval_train:
        eval_dataset = dataset.get_train_dataset(args, m, for_eval=True)
    else:
        eval_dataset = dataset.get_eval_dataset(args, m)
    if m.last_step == 0:
        raise ValueError('Attempting to evaluate on untrained model')
    m.model.eval()
    current_executor = executor.get_executor(args)()
    if args.example_id is not None:
        eval_dataset.data = [eval_dataset.task[args.example_id]]

    evaluation.run_eval(
        args.tag, eval_dataset, m.inference,
        current_executor.execute, not args.hide_example_info,
        args.report_path)
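
The eval_final/eval_train dispatch above is repeated verbatim in infer() (Example #4) and in the second evaluate() (Example #5). As a hypothetical refactoring sketch, the selection could be pulled into one helper, where the three getter callables stand in for dataset.get_eval_final_dataset, dataset.get_train_dataset, and dataset.get_eval_dataset:

# Hypothetical helper, not part of the project; the callables are placeholders.
def pick_eval_split(args, get_final, get_train, get_dev):
    if args.eval_final:
        return get_final()
    if args.eval_train:
        return get_train()
    return get_dev()
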
Example #4
def infer(args):
    print("\tModel type: %s\n\tModel path: %s" %
          (args.model_type, args.model_dir))
    saver.restore_args(args)
    arguments.backport_default_args(args)
    dataset.set_vocab(args)
    m = models.get_model(args)
    if m.last_step == 0:
        raise ValueError('Attempting to evaluate on untrained model')

    if args.eval_final:
        eval_dataset = dataset.get_eval_final_dataset(args, m)
    elif args.eval_train:
        eval_dataset = dataset.get_train_dataset(args, m, for_eval=True)
    else:
        eval_dataset = dataset.get_eval_dataset(args, m)
    m.model.eval()

    # Outputs are pickled and offsets are packed bytes, so open both files in binary mode.
    f = open(args.infer_output, 'wb')
    index_f = open(args.infer_output + '.index', 'wb')
    infer_counters = collections.Counter()
    num_outputs = 0
    iterator = tqdm.tqdm(eval_dataset, dynamic_ncols=True)
    for batch in iterator:
        infer_results = m.inference(batch)
        infer_outputs = m.process_infer_results(batch, infer_results,
                                                infer_counters)
        for output in infer_outputs:
            offset = f.tell()
            pickle.dump(output, f, pickle.HIGHEST_PROTOCOL)
            index_f.write(struct.pack('<Q', offset))
            num_outputs += 1
            if args.infer_limit and num_outputs >= args.infer_limit:
                return

        iterator.set_postfix(**infer_counters)
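
infer() writes each pickled output to args.infer_output and records its starting byte offset, packed as an unsigned little-endian 64-bit integer, in a parallel .index file. A hedged sketch of the reader side, assuming files produced exactly as above (read_record is a hypothetical helper, not part of the project):

# Reads the n-th record using the offset index; each '<Q' entry is 8 bytes.
import pickle
import struct

def read_record(data_path, index_path, n):
    with open(index_path, 'rb') as index_f:
        index_f.seek(n * 8)
        (offset,) = struct.unpack('<Q', index_f.read(8))
    with open(data_path, 'rb') as f:
        f.seek(offset)
        return pickle.load(f)  # unpickle exactly one output record
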
Example #5
def evaluate(args):
    print("Evaluation:")
    print("\tModel type: %s\n\tModel path: %s" %
          (args.model_type, args.model_dir))
    saver.restore_args(args)
    arguments.backport_default_args(args)
    dataset.set_vocab(args)
    m = models.get_model(args)
    if args.eval_final:
        eval_dataset = dataset.get_eval_final_dataset(args, m)
    elif args.eval_train:
        eval_dataset = dataset.get_train_dataset(args, m, for_eval=True)
    else:
        eval_dataset = dataset.get_eval_dataset(args, m)
    if m.last_step == 0:
        raise ValueError('Attempting to evaluate on untrained model')
    m.model.eval()
    the_executor = executor.get_executor(args)()

    top_k_exact = np.zeros(args.max_beam_trees, dtype=int)
    top_k_sem = np.zeros(args.max_beam_trees, dtype=int)
    top_k_gen = np.zeros(args.max_beam_trees, dtype=int)
    exact_ranks = []
    sem_ranks = []
    gen_ranks = []
    outputs = []
    total = 0.0

    iterator = tqdm.tqdm(eval_dataset, dynamic_ncols=True)
    for batch in iterator:
        sequences = m.inference(batch, filtered=False)

        for batch_idx, beams in enumerate(sequences):
            total += 1
            orig_example = batch.orig_examples[batch_idx]
            exact_found, sem_found, gen_found = False, False, False
            exact_rank, sem_rank, gen_rank = None, None, None
            for rank, tokens in enumerate(beams):
                # Exact match
                ref_code = getattr(orig_example, 'code_sequence',
                                   getattr(orig_example, 'goal_code', None))
                if not exact_found and tuple(tokens) == tuple(ref_code):
                    top_k_exact[rank:] += 1
                    exact_found = True
                    exact_rank = rank

                if not sem_found or not exact_found:
                    # Semantic match (passes all input tests)
                    input_tests_eval = executor.evaluate_code(
                        tokens, None, orig_example.input_tests,
                        the_executor.execute)
                    sem_match = input_tests_eval['correct'] == len(
                        orig_example.input_tests)
                    if not sem_found and sem_match:
                        top_k_sem[rank:] += 1
                        sem_found = True
                        sem_rank = rank

                    # Generalization (passes all input tests + other tests)
                    tests_eval = executor.evaluate_code(
                        tokens, None, orig_example.tests, the_executor.execute)
                    gen = sem_match and tests_eval['correct'] == len(
                        orig_example.tests)
                    if not gen_found and gen:
                        top_k_gen[rank:] += 1
                        gen_found = True
                        gen_rank = rank

                if exact_found and sem_found and gen_found:
                    break

            exact_ranks.append(exact_rank)
            sem_ranks.append(sem_rank)
            gen_ranks.append(gen_rank)
            if args.save_beam_outputs:
                outputs.append(beams)

        iterator.set_postfix(exact=top_k_exact[0] / total,
                             sem=top_k_sem[0] / total,
                             gen=top_k_gen[0] / total)

    with open(args.report_path, 'w') as f:
        json.dump(
            {
                # Total number of programs in this report.
                'total': total,
                # List where the Nth entry contains the number of programs with
                # exact match among the top N beam search outputs.
                # Length = args.max_beam_trees.
                'exact': top_k_exact.tolist(),
                # List where the Nth entry contains the number of programs with
                # semantic match (correct on all `input_tests`, given to the model)
                # among the top N beam search outputs.
                # Length = args.max_beam_trees.
                'semantic': top_k_sem.tolist(),
                # List where the Nth entry contains the number of programs with
                # generalization (correct on all `input_tests` and `tests`) among the
                # top N beam search outputs.
                # Length = args.max_beam_trees.
                'generalization': top_k_gen.tolist(),
                # For each program, the rank at which the corresponding type of
                # match was found (None if not applicable to any rank).
                'ranks': {
                    'exact': exact_ranks,
                    'semantic': sem_ranks,
                    'generalization': gen_ranks,
                },
                # List of length `total` where each item is a list containing
                # `args.max_beam_trees` programs, as output by the beam search.
                # Can be empty if this output was not saved.
                'beam_outputs': outputs
            },
            f)
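
The top_k_exact, top_k_sem, and top_k_gen arrays hold cumulative top-k counts: incrementing top_k[rank:] means entry N ends up counting the programs whose first match appeared at rank N or earlier. A standalone toy illustration of that bookkeeping (the ranks below are made up):

# Toy example of the cumulative top-k counting used in evaluate() above.
import numpy as np

beam_size = 5
top_k = np.zeros(beam_size, dtype=int)
first_match_ranks = [0, 2, 2, None]  # rank of the first match per program; None = no match

for rank in first_match_ranks:
    if rank is not None:
        top_k[rank:] += 1

print(top_k.tolist())  # [1, 1, 3, 3, 3]: top-1 = 1/4 programs, top-3 = 3/4
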