Example #1
def evaluate_model(model_weights, vocabulary_path, dataset_path, nb_ios,
                   nb_samples, use_grammar, output_path, beam_size, top_k,
                   batch_size, use_cuda, dump_programs):
    all_outputs_path = []
    all_semantic_output_path = []
    all_syntax_output_path = []
    all_generalize_output_path = []
    res_dir = os.path.dirname(output_path)
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)
    for k in range(top_k):

        new_term = "exactmatch_top%d.txt" % (k + 1)
        new_semantic_term = "semantic_top%d.txt" % (k + 1)
        new_syntax_term = "syntax_top%d.txt" % (k + 1)
        new_generalize_term = "fullgeneralize_top%d.txt" % (k + 1)

        new_file_name = output_path + new_term
        new_semantic_file_name = output_path + new_semantic_term
        new_syntax_file_name = output_path + new_syntax_term
        new_generalize_file_name = output_path + new_generalize_term

        all_outputs_path.append(new_file_name)
        all_semantic_output_path.append(new_semantic_file_name)
        all_syntax_output_path.append(new_syntax_file_name)
        all_generalize_output_path.append(new_generalize_file_name)
    program_dump_path = os.path.join(res_dir, "generated")

    if os.path.exists(all_outputs_path[0]):
        with open(all_outputs_path[0], "r") as out_file:
            out_file_content = out_file.read()
            print("Using cached result from {}".format(all_outputs_path[0]))
            print("Greedy select accuracy: {}".format(out_file_content))
            return

    # Load the dataset and the vocabulary of the trained model
    dataset, vocab = load_input_file(dataset_path, vocabulary_path)
    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]

    simulator = Simulator(vocab["idx2tkn"])
    # Load the model
    if not use_cuda:
        # A checkpoint saved on GPU must be remapped to load on a CPU-only
        # machine; see
        # https://discuss.pytorch.org/t/on-a-cpu-device-how-to-load-checkpoint-saved-on-gpu-device/349/8
        model = torch.load(model_weights,
                           map_location=lambda storage, loc: storage)
    else:
        model = torch.load(model_weights)
        model.cuda()
    # And put it into evaluation mode
    model.eval()
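    # Note: the decoding below runs without a torch.no_grad() guard; the
    # extended variant in Example #2 wraps its loop in one, which avoids
    # building autograd graphs during evaluation.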

    syntax_checker = PySyntaxChecker(vocab["tkn2idx"], use_cuda)
    if use_grammar:
        model.set_syntax_checker(syntax_checker)

    if beam_size == 1:
        top_k = 1
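    # With a beam of size 1 only the single greedy candidate exists, so any
    # top-k beyond 1 would be meaningless.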
    nb_correct = [0 for _ in range(top_k)]
    nb_semantic_correct = [0 for _ in range(top_k)]
    nb_syntax_correct = [0 for _ in range(top_k)]
    nb_generalize_correct = [0 for _ in range(top_k)]
    total_nb = 0
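    # The counters above implement cumulative top-k accuracy: when the first
    # hit for a sample occurs at 0-based rank r, every k > r gets credit.
    # A minimal standalone sketch of this bookkeeping (hypothetical helper,
    # not part of this codebase):
    #
    #     def credit_from_rank(counts, rank):
    #         for top_idx in range(rank, len(counts)):
    #             counts[top_idx] += 1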

    dataset = shuffle_dataset(dataset, batch_size, randomize=False)
    for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):


        inp_grids, out_grids, \
        in_tgt_seq, in_tgt_seq_list, out_tgt_seq, \
        inp_worlds, out_worlds, \
        _, \
        inp_test_worlds, out_test_worlds = get_minibatch(dataset, sp_idx, batch_size,
                                                         tgt_start, tgt_end, tgt_pad,
                                                         nb_ios, shuffle=False, volatile_vars=True)

        max_len = out_tgt_seq.size(1) + 10
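        # Allow decoded programs to be up to 10 tokens longer than the
        # reference before the search is cut off.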
        if use_cuda:
            inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()
            in_tgt_seq, out_tgt_seq = in_tgt_seq.cuda(), out_tgt_seq.cuda()

        if dump_programs:
            import numpy as np
            decoder_logit, syntax_logit = model(inp_grids, out_grids,
                                                in_tgt_seq, in_tgt_seq_list)
            if syntax_logit is not None and model.decoder.learned_syntax_checker is not None:
                syntax_logit = syntax_logit.cpu().data.numpy()
                for n in range(in_tgt_seq.size(0)):
                    decoded_dump_dir = os.path.join(program_dump_path,
                                                    str(n + sp_idx))
                    if not os.path.exists(decoded_dump_dir):
                        os.makedirs(decoded_dump_dir)
                    seq = in_tgt_seq.cpu().data.numpy()[n].tolist()
                    seq_len = seq.index(0) if 0 in seq else len(seq)
                    file_name = str(n) + "_learned_syntax"
                    norm_logit = syntax_logit[n, :seq_len]
                    norm_logit = np.log(-norm_logit)
                    norm_logit = 1 / (1 + np.exp(-norm_logit))
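                    # Assumption about the representation: syntax_logit holds
                    # strictly negative scores, so log(-x) is well-defined and
                    # the sigmoid maps the result into (0, 1), comparable with
                    # the binary mask from the manual checker below.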
                    np.save(os.path.join(decoded_dump_dir, file_name),
                            norm_logit)
                    ini_state = syntax_checker.get_initial_checker_state()
                    file_name = str(n) + "_manual_syntax"
                    mask = syntax_checker.get_sequence_mask(
                        ini_state, seq).squeeze().cpu().numpy()[:seq_len]
                    np.save(os.path.join(decoded_dump_dir, file_name), mask)
                    file_name = str(n) + "_diff"
                    diff = mask.astype(float) - norm_logit
                    diff = (diff + 1) / 2  # remap to [0,1]
                    np.save(os.path.join(decoded_dump_dir, file_name), diff)

        decoded = model.beam_sample(inp_grids, out_grids, tgt_start, tgt_end,
                                    max_len, beam_size, top_k)
        for batch_idx, (target, sp_decoded,
                        sp_input_worlds, sp_output_worlds,
                        sp_test_input_worlds, sp_test_output_worlds) in \
            enumerate(zip(out_tgt_seq.chunk(out_tgt_seq.size(0)), decoded,
                          inp_worlds, out_worlds,
                          inp_test_worlds, out_test_worlds)):

            total_nb += 1
            target = target.cpu().data.squeeze().numpy().tolist()
            target = [tkn_idx for tkn_idx in target if tkn_idx != tgt_pad]

            if dump_programs:
                decoded_dump_dir = os.path.join(program_dump_path,
                                                str(batch_idx + sp_idx))
                if not os.path.exists(decoded_dump_dir):
                    os.makedirs(decoded_dump_dir)
                write_program(os.path.join(decoded_dump_dir, "target"), target,
                              vocab["idx2tkn"])
                for rank, dec in enumerate(sp_decoded):
                    pred = dec[1]
                    ll = dec[0]
                    file_name = str(rank) + " - " + str(ll)
                    write_program(os.path.join(decoded_dump_dir, file_name),
                                  pred, vocab["idx2tkn"])

            # Exact matches
            for rank, dec in enumerate(sp_decoded):
                pred = dec[-1]
                if pred == target:
                    # This prediction is correct, so it counts toward the
                    # accuracy at this rank and at every larger top-k.
                    for top_idx in range(rank, top_k):
                        nb_correct[top_idx] += 1
                    break

            # Semantic matches
            for rank, dec in enumerate(sp_decoded):
                pred = dec[-1]
                parse_success, cand_prog = simulator.get_prog_ast(pred)
                if (not parse_success):
                    continue
                semantically_correct = True
                for (input_world, output_world) in zip(sp_input_worlds,
                                                       sp_output_worlds):
                    res_emu = simulator.run_prog(cand_prog, input_world)
                    if (res_emu.status != 'OK') or res_emu.crashed or (
                            res_emu.outgrid != output_world):
                        # This prediction is semantically incorrect.
                        semantically_correct = False
                        break
                if semantically_correct:
                    # Score for all the following ranks
                    for top_idx in range(rank, top_k):
                        nb_semantic_correct[top_idx] += 1
                    break

            # Generalization
            for rank, dec in enumerate(sp_decoded):
                pred = dec[-1]
                parse_success, cand_prog = simulator.get_prog_ast(pred)
                if (not parse_success):
                    continue
                generalizes = True
                for (input_world, output_world) in zip(sp_input_worlds,
                                                       sp_output_worlds):
                    res_emu = simulator.run_prog(cand_prog, input_world)
                    if (res_emu.status != 'OK') or res_emu.crashed or (
                            res_emu.outgrid != output_world):
                        # The program fails one of the specification I/O pairs.
                        generalizes = False
                        break
                for (input_world, output_world) in zip(sp_test_input_worlds,
                                                       sp_test_output_worlds):
                    res_emu = simulator.run_prog(cand_prog, input_world)
                    if (res_emu.status != 'OK') or res_emu.crashed or (
                            res_emu.outgrid != output_world):
                        # The program fails a held-out test I/O pair, so it
                        # does not generalize.
                        generalizes = False
                        break
                if generalizes:
                    # Score for all the following ranks
                    for top_idx in range(rank, top_k):
                        nb_generalize_correct[top_idx] += 1
                    break

            # Correct syntaxes
            for rank, dec in enumerate(sp_decoded):
                pred = dec[-1]
                parse_success, cand_prog = simulator.get_prog_ast(pred)
                if parse_success:
                    for top_idx in range(rank, top_k):
                        nb_syntax_correct[top_idx] += 1
                    break

    for k in range(top_k):
        with open(str(all_outputs_path[k]), "w") as res_file:
            res_file.write(str(100 * nb_correct[k] / total_nb))
        with open(str(all_semantic_output_path[k]), "w") as sem_res_file:
            sem_res_file.write(str(100 * nb_semantic_correct[k] / total_nb))
        with open(str(all_syntax_output_path[k]), "w") as stx_res_file:
            stx_res_file.write(str(100 * nb_syntax_correct[k] / total_nb))
        with open(str(all_generalize_output_path[k]), "w") as gen_res_file:
            gen_res_file.write(str(100 * nb_generalize_correct[k] / total_nb))

    semantic_at_one = 100 * nb_semantic_correct[0] / total_nb
    return semantic_at_one
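
# A minimal invocation sketch for evaluate_model (the paths and
# hyperparameters below are illustrative assumptions, not values taken from
# the original experiments):
#
#     evaluate_model(model_weights="exp/Weights/best.model",
#                    vocabulary_path="data/vocab.vocab",
#                    dataset_path="data/val.json",
#                    nb_ios=5, nb_samples=0, use_grammar=True,
#                    output_path="exp/eval/val_", beam_size=64, top_k=10,
#                    batch_size=32, use_cuda=False, dump_programs=False)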
Example #2
def train_seq2seq_model(
        # Optimization
        signal,
        nb_ios,
        nb_epochs,
        optim_alg,
        batch_size,
        learning_rate,
        use_grammar,
        beta,
        val_frequency,
        # Model
        kernel_size,
        conv_stack,
        fc_stack,
        tgt_embedding_size,
        lstm_hidden_size,
        nb_lstm_layers,
        learn_syntax,
        # RL specific options
        environment,
        reward_comb,
        nb_rollouts,
        rl_beam,
        rl_inner_batch,
        rl_use_ref,
        # What to train
        train_file,
        val_file,
        vocab_file,
        nb_samples,
        initialisation,
        # Where to write results
        result_folder,
        args_dict,
        # Run options
        use_cuda,
        log_frequency):

    #############################
    # Admin / Bookkeeping stuff #
    #############################
    # Creating the results directory
    result_dir = Path(result_folder)
    if not result_dir.exists():
        os.makedirs(str(result_dir))
    else:
        # The result directory exists. Let's check whether or not all of our
        # work has already been done.

        # The sign of all the works being done would be the model after the
        # last epoch, let's check if it's here
        last_epoch_model_path = result_dir / "Weights" / ("weights_%d.model" %
                                                          (nb_epochs - 1))
        if last_epoch_model_path.exists():
            print("{} already exists -- skipping this training".format(
                last_epoch_model_path))
            return

    # Dumping the arguments
    args_dump_path = result_dir / "args.json"
    with open(str(args_dump_path), "w") as args_dump_file:
        json.dump(args_dict, args_dump_file, indent=2)
    # Setting up the logs
    log_file = result_dir / "logs.txt"
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        filename=str(log_file),
                        filemode='w')
    train_loss_path = result_dir / "train_loss.json"
    models_dir = result_dir / "Weights"
    if not models_dir.exists():
        os.makedirs(str(models_dir))
        time.sleep(1)  # Let some time for the dir to be created

    #####################################
    # Load Model / Dataset / Vocabulary #
    #####################################
    # Load-up the dataset
    dataset, vocab = load_input_file(train_file, vocab_file)

    if use_grammar:
        syntax_checker = PySyntaxChecker(vocab["tkn2idx"], use_cuda)
    # Reduce the number of samples in the dataset, if needed
    if nb_samples > 0:
        # Shuffle the dataset first, in case the original ordering of the
        # samples carries meaning.
        random.seed(0)
        dataset = shuffle_dataset(dataset, batch_size)
        dataset = {
            'sources': dataset['sources'][:nb_samples],
            'targets': dataset['targets'][:nb_samples],
        }
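        # Together with random.seed(0) above, this produces a reproducible
        # random subsample of nb_samples programs.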

    vocabulary_size = len(vocab["tkn2idx"])
    if initialisation is None:
        # Create the model
        model = IOs2Seq(kernel_size, conv_stack, fc_stack, vocabulary_size,
                        tgt_embedding_size, lstm_hidden_size, nb_lstm_layers,
                        learn_syntax)
        # Dump initial weights
        path_to_ini_weight_dump = models_dir / "ini_weights.model"
        with open(str(path_to_ini_weight_dump), "wb") as weight_file:
            torch.save(model, weight_file)
    else:
        model = torch.load(initialisation,
                           map_location=lambda storage, loc: storage)
    if use_grammar:
        model.set_syntax_checker(syntax_checker)
    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]

    ############################################
    # Setup Loss / Optimizer / Eventual Critic #
    ############################################
    if signal == TrainSignal.SUPERVISED:
        # Create a mask to not penalize bad prediction on the padding
        weight_mask = torch.ones(vocabulary_size)
        weight_mask[tgt_pad] = 0
        # Setup the criterion
        loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
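        # An equivalent, arguably more idiomatic formulation (assumption:
        # with reduction='mean', zero-weighting the pad class matches
        # ignoring it):
        #     loss_criterion = nn.CrossEntropyLoss(ignore_index=tgt_pad)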
    elif signal == TrainSignal.RL or signal == TrainSignal.BEAM_RL:
        simulator = Simulator(vocab["idx2tkn"])

        if signal == TrainSignal.BEAM_RL:
            reward_comb_fun = RewardCombinationFun[reward_comb]
    else:
        raise Exception("Unknown TrainingSignal.")

    if use_cuda:
        model.cuda()
        if signal == TrainSignal.SUPERVISED:
            loss_criterion.cuda()

    # Setup the optimizers
    optimizer_cls = getattr(optim, optim_alg)
    optimizer = optimizer_cls(model.parameters(), lr=learning_rate)
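    # optim_alg names a class in torch.optim, e.g. "Adam" or "SGD";
    # getattr(optim, "Adam") resolves to torch.optim.Adam.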

    #####################
    # ################# #
    # # Training Loop # #
    # ################# #
    #####################
    losses = []
    recent_losses = []
    best_val_acc = -np.inf
    for epoch_idx in range(0, nb_epochs):
        nb_ios_for_epoch = nb_ios
        # This is definitely not the most efficient way to do it but oh well
        dataset = shuffle_dataset(dataset, batch_size)
        for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):

            batch_idx = int(sp_idx / batch_size)
            optimizer.zero_grad()

            if signal == TrainSignal.SUPERVISED:
                inp_grids, out_grids, \
                    in_tgt_seq, in_tgt_seq_list, out_tgt_seq, \
                    _, _, _, _, _ = get_minibatch(dataset, sp_idx, batch_size,
                                                  tgt_start, tgt_end, tgt_pad,
                                                  nb_ios_for_epoch)
                if use_cuda:
                    inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()
                    in_tgt_seq, out_tgt_seq = in_tgt_seq.cuda(
                    ), out_tgt_seq.cuda()
                if learn_syntax:
                    minibatch_loss = do_syntax_weighted_minibatch(
                        model, inp_grids, out_grids, in_tgt_seq,
                        in_tgt_seq_list, out_tgt_seq, loss_criterion, beta)
                else:
                    minibatch_loss = do_supervised_minibatch(
                        model, inp_grids, out_grids, in_tgt_seq,
                        in_tgt_seq_list, out_tgt_seq, loss_criterion)
                recent_losses.append(minibatch_loss)
            elif signal == TrainSignal.RL or signal == TrainSignal.BEAM_RL:
                inp_grids, out_grids, \
                    _, _, _, \
                    inp_worlds, out_worlds, \
                    targets, \
                    inp_test_worlds, out_test_worlds = get_minibatch(dataset, sp_idx, batch_size,
                                                                     tgt_start, tgt_end, tgt_pad,
                                                                     nb_ios_for_epoch)
                if use_cuda:
                    inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()
                # We use 1/nb_rollouts as the reward to normalize wrt the
                # size of the rollouts
                if signal == TrainSignal.RL:
                    reward_norm = 1 / float(nb_rollouts)
                elif signal == TrainSignal.BEAM_RL:
                    reward_norm = 1
                else:
                    raise NotImplementedError("Unknown training signal")

                lens = [len(target) for target in targets]
                max_len = max(lens) + 10
                env_cls = EnvironmentClasses[environment]
                if "Consistency" in environment:
                    envs = [
                        env_cls(reward_norm, trg_prog, sp_inp_worlds,
                                sp_out_worlds, simulator)
                        for trg_prog, sp_inp_worlds, sp_out_worlds in zip(
                            targets, inp_worlds, out_worlds)
                    ]
                elif "Generalization" or "Perf" in environment:
                    envs = [
                        env_cls(reward_norm, trg_prog, sp_inp_test_worlds,
                                sp_out_test_worlds, simulator)
                        for trg_prog, sp_inp_test_worlds, sp_out_test_worlds in
                        zip(targets, inp_test_worlds, out_test_worlds)
                    ]
                else:
                    raise NotImplementedError("Unknown environment type")

                if signal == TrainSignal.RL:
                    minibatch_reward = do_rl_minibatch(model, inp_grids,
                                                       out_grids, envs,
                                                       tgt_start, tgt_end,
                                                       max_len, nb_rollouts)
                    # minibatch_reward = do_rl_minibatch_two_steps(model,
                    #                                              inp_grids, out_grids,
                    #                                              envs,
                    #                                              tgt_start, tgt_end, tgt_pad,
                    #                                              max_len, nb_rollouts,
                    #                                              rl_inner_batch)
                elif signal == TrainSignal.BEAM_RL:
                    minibatch_reward = do_beam_rl(model, inp_grids, out_grids,
                                                  targets, envs,
                                                  reward_comb_fun, tgt_start,
                                                  tgt_end, tgt_pad, max_len,
                                                  rl_beam, rl_inner_batch,
                                                  rl_use_ref)
                else:
                    raise NotImplementedError("Unknown Environment type")
                recent_losses.append(minibatch_reward)
            else:
                raise NotImplementedError("Unknown Training method")
            optimizer.step()
            if (batch_idx % log_frequency == log_frequency - 1 and len(recent_losses) > 0) or \
               (len(dataset["sources"]) - sp_idx) < batch_size:
                logging.info('Epoch : %d Minibatch : %d Loss : %.5f' %
                             (epoch_idx, batch_idx,
                              sum(recent_losses) / len(recent_losses)))
                losses.extend(recent_losses)
                recent_losses = []
                # Dump the training losses
                with open(str(train_loss_path), "w") as train_loss_file:
                    json.dump(losses, train_loss_file, indent=2)

                if signal == TrainSignal.BEAM_RL:
                    # RL is much slower so we dump more frequently
                    path_to_weight_dump = models_dir / ("weights_%d.model" %
                                                        epoch_idx)
                    with open(str(path_to_weight_dump), "wb") as weight_file:
                        # Save from CPU so the checkpoint loads on machines
                        # without a GPU.
                        if use_cuda:
                            model.cpu()
                        torch.save(model, weight_file)
                        if use_cuda:
                            model.cuda()

        # Dump the weights at the end of the epoch
        path_to_weight_dump = models_dir / ("weights_%d.model" % epoch_idx)
        with open(str(path_to_weight_dump), "wb") as weight_file:
            # Save from CPU so the checkpoint loads on machines without a GPU.
            if use_cuda:
                model.cpu()
            torch.save(model, weight_file)
            if use_cuda:
                model.cuda()
        previous_weight_dump = models_dir / ("weights_%d.model" %
                                             (epoch_idx - 1))
        if previous_weight_dump.exists():
            os.remove(str(previous_weight_dump))
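        # Only the newest epoch checkpoint is kept on disk; best.model below
        # is saved separately.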
        # Dump the training losses
        with open(str(train_loss_path), "w") as train_loss_file:
            json.dump(losses, train_loss_file, indent=2)

        logging.info("Done with epoch %d." % epoch_idx)

        if (epoch_idx + 1) % val_frequency == 0 or (epoch_idx +
                                                    1) == nb_epochs:
            # Evaluate the model on the validation set
            out_path = str(result_dir / ("eval/epoch_%d/val_.txt" % epoch_idx))
            val_acc = evaluate_model(str(path_to_weight_dump), vocab_file,
                                     val_file, 5, 0, use_grammar, out_path,
                                     100, 50, batch_size, use_cuda, False)
            logging.info("Epoch : %d ValidationAccuracy : %f." %
                         (epoch_idx, val_acc))
            if val_acc > best_val_acc:
                logging.info("Epoch : %d ValidationBest : %f." %
                             (epoch_idx, val_acc))
                best_val_acc = val_acc
                path_to_weight_dump = models_dir / "best.model"
                with open(str(path_to_weight_dump), "wb") as weight_file:
                    # Save from CPU so the checkpoint loads on machines
                    # without a GPU.
                    if use_cuda:
                        model.cpu()
                    torch.save(model, weight_file)
                    if use_cuda:
                        model.cuda()
def evaluate_model(model_weights,
                   vocabulary_path,
                   dataset_path,
                   nb_ios,
                   nb_samples,
                   use_grammar,
                   output_path,
                   beam_size,
                   top_k,
                   batch_size,
                   use_cuda,
                   dump_programs,
                   return_individual_results=False,
                   use_beam=True,
                   num_successful_targets_per_precursor=None,
                   preloaded_data=False,
                   minibatch_cache_path=None):
    if return_individual_results:
        print('beginning evaluate_model for data augmentation')
    all_outputs_path = []
    all_semantic_output_path = []
    all_syntax_output_path = []
    all_generalize_output_path = []
    pred_filter_path = 'pred_filter_top1_l%d.txt' % top_k
    pred_filter_path = output_path + pred_filter_path
    res_dir = os.path.dirname(output_path)
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)
    for k in range(top_k):

        new_term = "exactmatch_top%d.txt" % (k + 1)
        new_semantic_term = "semantic_top%d.txt" % (k + 1)
        new_syntax_term = "syntax_top%d.txt" % (k + 1)
        new_generalize_term = "fullgeneralize_top%d.txt" % (k + 1)

        new_file_name = output_path + new_term
        new_semantic_file_name = output_path + new_semantic_term
        new_syntax_file_name = output_path + new_syntax_term
        new_generalize_file_name = output_path + new_generalize_term

        all_outputs_path.append(new_file_name)
        all_semantic_output_path.append(new_semantic_file_name)
        all_syntax_output_path.append(new_syntax_file_name)
        all_generalize_output_path.append(new_generalize_file_name)
    program_dump_path = os.path.join(res_dir, "generated")

    # if os.path.exists(all_outputs_path[0]):
    #     with open(all_outputs_path[0], "r") as out_file:
    #         out_file_content = out_file.read()
    #         print("Using cached result from {}".format(all_outputs_path[0]))
    #         print("Greedy select accuracy: {}".format(out_file_content))
    #         return

    # Load the dataset and the vocabulary of the trained model
    if preloaded_data:
        dataset, vocab = pickle.loads(dataset_path), pickle.loads(
            vocabulary_path)  # pre-loaded
    else:
        dataset, vocab = load_input_file(dataset_path, vocabulary_path)

    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]

    simulator = Simulator(vocab["idx2tkn"])
    # Load the model
    if not use_cuda:
        # A checkpoint saved on GPU must be remapped to load on a CPU-only
        # machine; see
        # https://discuss.pytorch.org/t/on-a-cpu-device-how-to-load-checkpoint-saved-on-gpu-device/349/8
        model = torch.load(model_weights,
                           map_location=lambda storage, loc: storage)
    else:
        model = torch.load(model_weights)
        model.cuda()
    # And put it into evaluation mode
    model.eval()

    syntax_checker = PySyntaxChecker(vocab["tkn2idx"], use_cuda)
    if use_grammar:
        model.set_syntax_checker(syntax_checker)

    if beam_size == 1:
        top_k = 1
    nb_correct = [0 for _ in range(top_k)]
    nb_semantic_correct = [0 for _ in range(top_k)]
    nb_syntax_correct = [0 for _ in range(top_k)]
    nb_generalize_correct = [0 for _ in range(top_k)]
    nb_pred_filter_correct = 0
    total_nb = 0

    with torch.no_grad():
        if minibatch_cache_path is None or not os.path.isdir(
                minibatch_cache_path):
            if minibatch_cache_path is None:
                minibatch_cache_path = output_path
            dataset, sort_idx = shuffle_dataset(dataset,
                                                batch_size,
                                                randomize=False,
                                                return_sort_idx=True)
            if not os.path.isdir(minibatch_cache_path):
                os.makedirs(minibatch_cache_path)
            with open(
                    os.path.join(minibatch_cache_path, 'dataset_sortidx.pkl'),
                    'wb') as f:
                pickle.dump((dataset, sort_idx), f)
            if not os.path.exists(os.path.dirname(output_path)):
                os.makedirs(os.path.dirname(output_path))
            print('create inputs for saving minibatches')
            save_minibatch_inputs = []
            for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):
                sources = dataset['sources'][sp_idx:sp_idx + batch_size]
                targets = dataset['targets'][sp_idx:sp_idx + batch_size]
                minibatch_output_path = make_minibatch_path(
                    minibatch_cache_path, sp_idx)
                save_minibatch_inputs.append(
                    (sources, targets, minibatch_output_path, batch_size,
                     tgt_start, tgt_end, tgt_pad, nb_ios, False, True))
            print('save minibatches')
            print(str(datetime.datetime.now()))
            # Process in chunks of 60, recreating the pool each time so that
            # worker memory is periodically released.
            for i in range(0, len(save_minibatch_inputs), 60):
                pool = Pool(processes=30, maxtasksperchild=1)
                pool.map(save_minibatch_parallel,
                         save_minibatch_inputs[i:i + 60],
                         chunksize=1)
                pool.close()
            print(str(datetime.datetime.now()))
        else:
            with open(
                    os.path.join(minibatch_cache_path, 'dataset_sortidx.pkl'),
                    'rb') as f:
                dataset, sort_idx = pickle.load(f)

        print('sample decodings from model')
        all_decoded = []
        load_count = 0
        pool = Pool(processes=30, maxtasksperchild=1)
        for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):
            if load_count % PARALLEL_LOAD_MINIBATCHES == 0:
                minibatch_load_paths = []
                for i in range(PARALLEL_LOAD_MINIBATCHES):
                    mb_sp_idx = sp_idx + i * batch_size
                    if mb_sp_idx < len(dataset["sources"]):
                        minibatch_load_paths.append(
                            make_minibatch_path(minibatch_cache_path,
                                                mb_sp_idx))
                # Drop the reference to the previous chunk so its memory can
                # be reclaimed before loading the next one.
                loaded_minibatches = None
                loaded_minibatches = pool.map(load_minibatch_parallel,
                                              minibatch_load_paths,
                                              chunksize=1)
            minibatch = loaded_minibatches[load_count %
                                           PARALLEL_LOAD_MINIBATCHES]
            load_count += 1

            inp_grids, out_grids, \
            in_tgt_seq, in_tgt_seq_list, out_tgt_seq, \
            inp_worlds, out_worlds, \
            _, \
            inp_test_worlds, out_test_worlds = minibatch

            if return_individual_results:
                with open(make_minibatch_path(minibatch_cache_path, sp_idx),
                          'wb') as f:
                    pickle.dump(minibatch, f, protocol=pickle.HIGHEST_PROTOCOL)

            max_len = out_tgt_seq.size(1) + 10
            if use_cuda:
                inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()
                in_tgt_seq, out_tgt_seq = in_tgt_seq.cuda(), out_tgt_seq.cuda()

            if dump_programs:
                decoder_logit, syntax_logit = model(inp_grids, out_grids,
                                                    in_tgt_seq,
                                                    in_tgt_seq_list)
                if syntax_logit is not None and model.decoder.learned_syntax_checker is not None:
                    syntax_logit = syntax_logit.cpu().data.numpy()
                    for n in range(in_tgt_seq.size(0)):
                        decoded_dump_dir = os.path.join(
                            program_dump_path, str(n + sp_idx))
                        if not os.path.exists(decoded_dump_dir):
                            os.makedirs(decoded_dump_dir)
                        seq = in_tgt_seq.cpu().data.numpy()[n].tolist()
                        seq_len = seq.index(0) if 0 in seq else len(seq)
                        file_name = str(n) + "_learned_syntax"
                        norm_logit = syntax_logit[n, :seq_len]
                        norm_logit = np.log(-norm_logit)
                        norm_logit = 1 / (1 + np.exp(-norm_logit))
                        np.save(os.path.join(decoded_dump_dir, file_name),
                                norm_logit)
                        ini_state = syntax_checker.get_initial_checker_state()
                        file_name = str(n) + "_manual_syntax"
                        mask = syntax_checker.get_sequence_mask(
                            ini_state, seq).squeeze().cpu().numpy()[:seq_len]
                        np.save(os.path.join(decoded_dump_dir, file_name),
                                mask)
                        file_name = str(n) + "_diff"
                        diff = mask.astype(float) - norm_logit
                        diff = (diff + 1) / 2  # remap to [0,1]
                        np.save(os.path.join(decoded_dump_dir, file_name),
                                diff)

            decoded = model.beam_sample(inp_grids, out_grids, tgt_start,
                                        tgt_end, max_len, beam_size, top_k,
                                        True, use_beam)
            all_decoded.append((sp_idx, decoded))

            if load_count % PARALLEL_LOAD_MINIBATCHES == PARALLEL_LOAD_MINIBATCHES - 1:
                # save decodings
                decoded_load_paths = []
                for d_sp_idx, decoded in all_decoded:
                    decoded_load_paths.append(
                        make_decoded_path(output_path, d_sp_idx))
                pool.map(save_decoded_parallel,
                         [(decoded, path) for decoded, path in zip(
                             [tup[1]
                              for tup in all_decoded], decoded_load_paths)],
                         chunksize=1)
                all_decoded = []
        if len(all_decoded) > 0:
            # save decodings
            decoded_load_paths = []
            for d_sp_idx, decoded in all_decoded:
                decoded_load_paths.append(
                    make_decoded_path(output_path, d_sp_idx))
            pool.map(save_decoded_parallel,
                     [(decoded, path) for decoded, path in zip(
                         [tup[1] for tup in all_decoded], decoded_load_paths)],
                     chunksize=1)
            all_decoded = []
        pool.close()

        if return_individual_results:
            print('done sampling from model, starting filter')
            generalization_inputs = []
            vocab_pickle = pickle.dumps(vocab,
                                        protocol=pickle.HIGHEST_PROTOCOL)
            for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):
                generalization_inputs.append(
                    (minibatch_cache_path, sp_idx, output_path, vocab_pickle,
                     tgt_pad, num_successful_targets_per_precursor, top_k))
            print('starting pool.map')
            pool = Pool(processes=30, maxtasksperchild=1)
            print(str(datetime.datetime.now()))
            all_batch_results = pool.map(find_generalizations_parallel,
                                         generalization_inputs,
                                         chunksize=1)
            print(str(datetime.datetime.now()))
            pool.close()
            full_results = []
            num_found_full = 0
            for batch_results, stats in all_batch_results:
                full_results += batch_results
                num_found_full += stats[0]
        else:
            for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):
                inp_grids, out_grids, \
                in_tgt_seq, in_tgt_seq_list, out_tgt_seq, \
                inp_worlds, out_worlds, \
                _, \
                inp_test_worlds, out_test_worlds = get_minibatch(dataset, sp_idx, batch_size,
                                                                tgt_start, tgt_end, tgt_pad,
                                                                nb_ios, shuffle=False, volatile_vars=True)

                max_len = out_tgt_seq.size(1) + 10
                # decoded = model.beam_sample(inp_grids, out_grids,
                #                             tgt_start, tgt_end, max_len,
                #                             beam_size, top_k)
                # decoded = all_decoded[int(sp_idx / batch_size)]
                with open(make_decoded_path(output_path, sp_idx), 'rb') as f:
                    decoded = pickle.load(f)
                for batch_idx, (target, sp_decoded,
                                sp_input_worlds, sp_output_worlds,
                                sp_test_input_worlds, sp_test_output_worlds) in \
                    enumerate(zip(out_tgt_seq.chunk(out_tgt_seq.size(0)), decoded,
                                inp_worlds, out_worlds,
                                inp_test_worlds, out_test_worlds)):

                    if return_individual_results:
                        current_results = []
                    total_nb += 1
                    target = target.cpu().data.squeeze().numpy().tolist()
                    target = [
                        tkn_idx for tkn_idx in target if tkn_idx != tgt_pad
                    ]

                    if dump_programs:
                        decoded_dump_dir = os.path.join(
                            program_dump_path, str(batch_idx + sp_idx))
                        if not os.path.exists(decoded_dump_dir):
                            os.makedirs(decoded_dump_dir)
                        write_program(os.path.join(decoded_dump_dir, "target"),
                                      target, vocab["idx2tkn"])
                        for rank, dec in enumerate(sp_decoded):
                            pred = dec[1]
                            ll = dec[0]
                            file_name = str(rank) + " - " + str(ll)
                            write_program(
                                os.path.join(decoded_dump_dir, file_name),
                                pred, vocab["idx2tkn"])

                    # Exact matches
                    for rank, dec in enumerate(sp_decoded):
                        pred = dec[-1]
                        if pred == target:
                            # This prediction is correct, so it counts toward
                            # the accuracy at this rank and every larger top-k.
                            for top_idx in range(rank, top_k):
                                nb_correct[top_idx] += 1
                            break

                    # Semantic matches
                    for rank, dec in enumerate(sp_decoded):
                        pred = dec[-1]
                        parse_success, cand_prog = simulator.get_prog_ast(pred)
                        if (not parse_success):
                            continue
                        semantically_correct = True
                        for (input_world,
                             output_world) in zip(sp_input_worlds,
                                                  sp_output_worlds):
                            res_emu = simulator.run_prog(
                                cand_prog, input_world)
                            if (res_emu.status != 'OK') or res_emu.crashed or (
                                    res_emu.outgrid != output_world):
                                # This prediction is semantically incorrect.
                                semantically_correct = False
                                break
                        if semantically_correct:
                            # Score for all the following ranks
                            for top_idx in range(rank, top_k):
                                nb_semantic_correct[top_idx] += 1
                            break

                    # Generalization
                    found_full = False
                    pred_filter_success = 0
                    # These flags persist across ranks: a sample is scored
                    # once, while successful candidates keep accumulating.
                    scored = False
                    num_success = 0
                    for rank, dec in enumerate(sp_decoded):
                        pred = dec[-1]
                        parse_success, cand_prog = simulator.get_prog_ast(pred)
                        if (not parse_success):
                            continue
                        generalizes = True
                        for (input_world,
                             output_world) in zip(sp_input_worlds,
                                                  sp_output_worlds):
                            res_emu = simulator.run_prog(
                                cand_prog, input_world)
                            if (res_emu.status != 'OK') or res_emu.crashed or (
                                    res_emu.outgrid != output_world):
                                # The program fails one of the specification
                                # I/O pairs.
                                generalizes = False
                                break
                        if generalizes:
                            for (input_world,
                                 output_world) in zip(sp_test_input_worlds,
                                                      sp_test_output_worlds):
                                res_emu = simulator.run_prog(
                                    cand_prog, input_world)
                                if (res_emu.status !=
                                        'OK') or res_emu.crashed or (
                                            res_emu.outgrid != output_world):
                                    # The program fails a held-out test I/O
                                    # pair, so it does not generalize.
                                    generalizes = False
                                    pred_filter_success = -1  # we picked something that was syntactically correct and passed the inputs, but fails on test
                                    # NOTE: our setup is that we are allowed to filter as much as we want using the input test cases, but only get 1 try for the real "test" test cases.
                                    break
                        if return_individual_results and generalizes:
                            current_results.append(pred)
                            num_success += 1
                            if num_success >= num_successful_targets_per_precursor:
                                found_full = True
                                break
                        if generalizes and not scored:
                            if pred_filter_success != -1:
                                nb_pred_filter_correct += 1
                            # Score for all the following ranks
                            for top_idx in range(rank, top_k):
                                nb_generalize_correct[top_idx] += 1
                            if return_individual_results:
                                scored = True
                            else:
                                break
                    if return_individual_results:
                        full_results.append(current_results)
                        if found_full:
                            num_found_full += 1

                    # Correct syntaxes
                    for rank, dec in enumerate(sp_decoded):
                        pred = dec[-1]
                        parse_success, cand_prog = simulator.get_prog_ast(pred)
                        if parse_success:
                            for top_idx in range(rank, top_k):
                                nb_syntax_correct[top_idx] += 1
                            break

            with open(pred_filter_path, 'w') as pred_filter_file:
                pred_filter_file.write(
                    str(100 * nb_pred_filter_correct / total_nb))
            for k in range(top_k):
                with open(str(all_outputs_path[k]), "w") as res_file:
                    res_file.write(str(100 * nb_correct[k] / total_nb))
                with open(str(all_semantic_output_path[k]),
                          "w") as sem_res_file:
                    sem_res_file.write(
                        str(100 * nb_semantic_correct[k] / total_nb))
                with open(str(all_syntax_output_path[k]), "w") as stx_res_file:
                    stx_res_file.write(
                        str(100 * nb_syntax_correct[k] / total_nb))
                with open(str(all_generalize_output_path[k]),
                          "w") as gen_res_file:
                    gen_res_file.write(
                        str(100 * nb_generalize_correct[k] / total_nb))

        if return_individual_results:
            print(str(datetime.datetime.now()))
            print('found full set of new targets for ' + str(num_found_full) +
                  ' out of ' + str(len(full_results)))
            unsort_idx = [sort_idx.index(i) for i in range(len(sort_idx))]
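            # sort_idx[j] is the original index of the sample at sorted
            # position j, so unsort_idx is the inverse permutation
            # (list.index makes this O(n^2); numpy.argsort(sort_idx) would be
            # an equivalent O(n log n) alternative).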
            full_results = [full_results[u] for u in unsort_idx]
            return full_results, (num_found_full, len(full_results))
        else:
            semantic_at_one = 100 * nb_semantic_correct[0] / total_nb
            return semantic_at_one
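
# A usage sketch for this augmented variant in data-augmentation mode (all
# values are illustrative assumptions; with preloaded_data=True the
# dataset_path and vocabulary_path arguments would instead be pickled bytes):
#
#     results, (n_full, n_total) = evaluate_model(
#         "exp/Weights/best.model", "data/vocab.vocab", "data/val.json",
#         nb_ios=5, nb_samples=0, use_grammar=True,
#         output_path="exp/aug/val_", beam_size=64, top_k=32,
#         batch_size=16, use_cuda=False, dump_programs=False,
#         return_individual_results=True,
#         num_successful_targets_per_precursor=8)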