def evaluate_model(model_weights,
                   vocabulary_path,
                   dataset_path,
                   nb_ios,
                   nb_samples,
                   use_grammar,
                   output_path,
                   beam_size,
                   top_k,
                   batch_size,
                   use_cuda,
                   dump_programs):
    all_outputs_path = []
    all_semantic_output_path = []
    all_syntax_output_path = []
    all_generalize_output_path = []
    res_dir = os.path.dirname(output_path)
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)
    for k in range(top_k):
        new_term = "exactmatch_top%d.txt" % (k + 1)
        new_semantic_term = "semantic_top%d.txt" % (k + 1)
        new_syntax_term = "syntax_top%d.txt" % (k + 1)
        new_generalize_term = "fullgeneralize_top%d.txt" % (k + 1)

        new_file_name = output_path + new_term
        new_semantic_file_name = output_path + new_semantic_term
        new_syntax_file_name = output_path + new_syntax_term
        new_generalize_file_name = output_path + new_generalize_term

        all_outputs_path.append(new_file_name)
        all_semantic_output_path.append(new_semantic_file_name)
        all_syntax_output_path.append(new_syntax_file_name)
        all_generalize_output_path.append(new_generalize_file_name)
    program_dump_path = os.path.join(res_dir, "generated")

    # Reuse a cached result if this evaluation has already been run
    if os.path.exists(all_outputs_path[0]):
        with open(all_outputs_path[0], "r") as out_file:
            out_file_content = out_file.read()
            print("Using cached result from {}".format(all_outputs_path[0]))
            print("Greedy select accuracy: {}".format(out_file_content))
            return

    # Load the vocabulary of the trained model
    dataset, vocab = load_input_file(dataset_path, vocabulary_path)
    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]
    simulator = Simulator(vocab["idx2tkn"])

    # Load the model
    if not use_cuda:
        # https://discuss.pytorch.org/t/on-a-cpu-device-how-to-load-checkpoint-saved-on-gpu-device/349/8
        model = torch.load(model_weights,
                           map_location=lambda storage, loc: storage)
    else:
        model = torch.load(model_weights)
        model.cuda()
    # And put it into evaluation mode
    model.eval()

    syntax_checker = PySyntaxChecker(vocab["tkn2idx"], use_cuda)
    if use_grammar:
        model.set_syntax_checker(syntax_checker)

    if beam_size == 1:
        top_k = 1
    nb_correct = [0 for _ in range(top_k)]
    nb_semantic_correct = [0 for _ in range(top_k)]
    nb_syntax_correct = [0 for _ in range(top_k)]
    nb_generalize_correct = [0 for _ in range(top_k)]
    total_nb = 0

    dataset = shuffle_dataset(dataset, batch_size, randomize=False)
    for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):

        inp_grids, out_grids, \
            in_tgt_seq, in_tgt_seq_list, out_tgt_seq, \
            inp_worlds, out_worlds, \
            _, \
            inp_test_worlds, out_test_worlds = get_minibatch(
                dataset, sp_idx, batch_size, tgt_start, tgt_end, tgt_pad,
                nb_ios, shuffle=False, volatile_vars=True)

        max_len = out_tgt_seq.size(1) + 10
        if use_cuda:
            inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()
            in_tgt_seq, out_tgt_seq = in_tgt_seq.cuda(), out_tgt_seq.cuda()

        if dump_programs:
            import numpy as np
            decoder_logit, syntax_logit = model(inp_grids, out_grids,
                                                in_tgt_seq, in_tgt_seq_list)
            if syntax_logit is not None and model.decoder.learned_syntax_checker is not None:
                syntax_logit = syntax_logit.cpu().data.numpy()
                for n in range(in_tgt_seq.size(0)):
                    decoded_dump_dir = os.path.join(program_dump_path,
                                                    str(n + sp_idx))
                    if not os.path.exists(decoded_dump_dir):
                        os.makedirs(decoded_dump_dir)
                    seq = in_tgt_seq.cpu().data.numpy()[n].tolist()
                    seq_len = seq.index(0) if 0 in seq else len(seq)
                    file_name = str(n) + "_learned_syntax"
                    norm_logit = syntax_logit[n, :seq_len]
                    norm_logit = np.log(-norm_logit)
                    norm_logit = 1 / (1 + np.exp(-norm_logit))
                    np.save(os.path.join(decoded_dump_dir, file_name),
                            norm_logit)
                    ini_state = syntax_checker.get_initial_checker_state()
                    file_name = str(n) + "_manual_syntax"
                    mask = syntax_checker.get_sequence_mask(
                        ini_state, seq).squeeze().cpu().numpy()[:seq_len]
                    np.save(os.path.join(decoded_dump_dir, file_name), mask)
                    file_name = str(n) + "_diff"
                    diff = mask.astype(float) - norm_logit
                    diff = (diff + 1) / 2  # remap to [0, 1]
                    np.save(os.path.join(decoded_dump_dir, file_name), diff)

        decoded = model.beam_sample(inp_grids, out_grids,
                                    tgt_start, tgt_end, max_len,
                                    beam_size, top_k)

        for batch_idx, (target, sp_decoded,
                        sp_input_worlds, sp_output_worlds,
                        sp_test_input_worlds, sp_test_output_worlds) in \
                enumerate(zip(out_tgt_seq.chunk(out_tgt_seq.size(0)), decoded,
                              inp_worlds, out_worlds,
                              inp_test_worlds, out_test_worlds)):

            total_nb += 1
            target = target.cpu().data.squeeze().numpy().tolist()
            target = [tkn_idx for tkn_idx in target if tkn_idx != tgt_pad]

            if dump_programs:
                decoded_dump_dir = os.path.join(program_dump_path,
                                                str(batch_idx + sp_idx))
                if not os.path.exists(decoded_dump_dir):
                    os.makedirs(decoded_dump_dir)
                write_program(os.path.join(decoded_dump_dir, "target"),
                              target, vocab["idx2tkn"])
                for rank, dec in enumerate(sp_decoded):
                    pred = dec[1]
                    ll = dec[0]
                    file_name = str(rank) + " - " + str(ll)
                    write_program(os.path.join(decoded_dump_dir, file_name),
                                  pred, vocab["idx2tkn"])

            # Exact matches
            for rank, dec in enumerate(sp_decoded):
                pred = dec[-1]
                if pred == target:
                    # This prediction is correct, so it also scores for
                    # all the following ranks
                    for top_idx in range(rank, top_k):
                        nb_correct[top_idx] += 1
                    break

            # Semantic matches
            for rank, dec in enumerate(sp_decoded):
                pred = dec[-1]
                parse_success, cand_prog = simulator.get_prog_ast(pred)
                if (not parse_success):
                    continue
                semantically_correct = True
                for (input_world, output_world) in zip(sp_input_worlds,
                                                       sp_output_worlds):
                    res_emu = simulator.run_prog(cand_prog, input_world)
                    if (res_emu.status != 'OK') or res_emu.crashed or \
                            (res_emu.outgrid != output_world):
                        # This prediction is semantically incorrect.
                        semantically_correct = False
                        break
                if semantically_correct:
                    # Score for all the following ranks
                    for top_idx in range(rank, top_k):
                        nb_semantic_correct[top_idx] += 1
                    break

            # Generalization
            for rank, dec in enumerate(sp_decoded):
                pred = dec[-1]
                parse_success, cand_prog = simulator.get_prog_ast(pred)
                if (not parse_success):
                    continue
                generalizes = True
                for (input_world, output_world) in zip(sp_input_worlds,
                                                       sp_output_worlds):
                    res_emu = simulator.run_prog(cand_prog, input_world)
                    if (res_emu.status != 'OK') or res_emu.crashed or \
                            (res_emu.outgrid != output_world):
                        # This prediction is semantically incorrect.
                        generalizes = False
                        break
                for (input_world, output_world) in zip(sp_test_input_worlds,
                                                       sp_test_output_worlds):
                    res_emu = simulator.run_prog(cand_prog, input_world)
                    if (res_emu.status != 'OK') or res_emu.crashed or \
                            (res_emu.outgrid != output_world):
                        # This prediction fails on the held-out test IOs.
                        generalizes = False
                        break
                if generalizes:
                    # Score for all the following ranks
                    for top_idx in range(rank, top_k):
                        nb_generalize_correct[top_idx] += 1
                    break

            # Correct syntaxes
            for rank, dec in enumerate(sp_decoded):
                pred = dec[-1]
                parse_success, cand_prog = simulator.get_prog_ast(pred)
                if parse_success:
                    for top_idx in range(rank, top_k):
                        nb_syntax_correct[top_idx] += 1
                    break

    for k in range(top_k):
        with open(str(all_outputs_path[k]), "w") as res_file:
            res_file.write(str(100 * nb_correct[k] / total_nb))
        with open(str(all_semantic_output_path[k]), "w") as sem_res_file:
            sem_res_file.write(str(100 * nb_semantic_correct[k] / total_nb))
        with open(str(all_syntax_output_path[k]), "w") as stx_res_file:
            stx_res_file.write(str(100 * nb_syntax_correct[k] / total_nb))
        with open(str(all_generalize_output_path[k]), "w") as gen_res_file:
            gen_res_file.write(str(100 * nb_generalize_correct[k] / total_nb))

    semantic_at_one = 100 * nb_semantic_correct[0] / total_nb
    return semantic_at_one
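
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original pipeline): how
# evaluate_model might be invoked after training. All paths and hyper-parameter
# values below are hypothetical placeholders.
# ---------------------------------------------------------------------------
def _example_evaluate_call():
    # Returns the top-1 semantic accuracy (in percent) on the chosen dataset.
    return evaluate_model(
        model_weights="exps/supervised/Weights/best.model",  # hypothetical path
        vocabulary_path="data/vocab.vocab",                  # hypothetical path
        dataset_path="data/val.json",                        # hypothetical path
        nb_ios=5,           # IO pairs shown to the encoder
        nb_samples=0,       # accepted but unused by this evaluation path
        use_grammar=True,   # prune beams with the syntax checker
        output_path="exps/supervised/eval/val_",  # prefix for the result files
        beam_size=64,
        top_k=10,
        batch_size=32,
        use_cuda=False,
        dump_programs=False)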
def train_seq2seq_model(
        # Optimization
        signal, nb_ios, nb_epochs, optim_alg,
        batch_size, learning_rate, use_grammar,
        beta, val_frequency,
        # Model
        kernel_size, conv_stack, fc_stack,
        tgt_embedding_size, lstm_hidden_size, nb_lstm_layers,
        learn_syntax,
        # RL specific options
        environment, reward_comb, nb_rollouts,
        rl_beam, rl_inner_batch, rl_use_ref,
        # What to train
        train_file, val_file, vocab_file, nb_samples, initialisation,
        # Where to write results
        result_folder, args_dict,
        # Run options
        use_cuda, log_frequency):

    #############################
    # Admin / Bookkeeping stuff #
    #############################
    # Create the results directory
    result_dir = Path(result_folder)
    if not result_dir.exists():
        os.makedirs(str(result_dir))
    else:
        # The result directory exists. Check whether all of the work has
        # already been done: the sign of that would be the presence of the
        # model from the last epoch.
        last_epoch_model_path = result_dir / "Weights" / \
            ("weights_%d.model" % (nb_epochs - 1))
        if last_epoch_model_path.exists():
            print("{} already exists -- skipping this training".format(
                last_epoch_model_path))
            return

    # Dump the arguments
    args_dump_path = result_dir / "args.json"
    with open(str(args_dump_path), "w") as args_dump_file:
        json.dump(args_dict, args_dump_file, indent=2)

    # Set up the logs
    log_file = result_dir / "logs.txt"
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        filename=str(log_file),
        filemode='w')

    train_loss_path = result_dir / "train_loss.json"

    models_dir = result_dir / "Weights"
    if not models_dir.exists():
        os.makedirs(str(models_dir))
        time.sleep(1)  # Leave some time for the dir to be created

    #####################################
    # Load Model / Dataset / Vocabulary #
    #####################################
    # Load up the dataset
    dataset, vocab = load_input_file(train_file, vocab_file)

    if use_grammar:
        syntax_checker = PySyntaxChecker(vocab["tkn2idx"], use_cuda)

    # Reduce the number of samples in the dataset, if needed
    if nb_samples > 0:
        # Shuffle the dataset first, in case there is meaning in the
        # ordering of the samples
        random.seed(0)
        dataset = shuffle_dataset(dataset, batch_size)
        dataset = {
            'sources': dataset['sources'][:nb_samples],
            'targets': dataset['targets'][:nb_samples],
        }

    vocabulary_size = len(vocab["tkn2idx"])
    if initialisation is None:
        # Create the model
        model = IOs2Seq(kernel_size, conv_stack, fc_stack,
                        vocabulary_size, tgt_embedding_size,
                        lstm_hidden_size, nb_lstm_layers,
                        learn_syntax)
        # Dump initial weights
        path_to_ini_weight_dump = models_dir / "ini_weights.model"
        with open(str(path_to_ini_weight_dump), "wb") as weight_file:
            torch.save(model, weight_file)
    else:
        model = torch.load(initialisation,
                           map_location=lambda storage, loc: storage)
    if use_grammar:
        model.set_syntax_checker(syntax_checker)

    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]

    ############################################
    # Setup Loss / Optimizer / Eventual Critic #
    ############################################
    if signal == TrainSignal.SUPERVISED:
        # Create a mask so that bad predictions on the padding are not
        # penalized
        weight_mask = torch.ones(vocabulary_size)
        weight_mask[tgt_pad] = 0
        # Set up the criterion
        loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
    elif signal == TrainSignal.RL or signal == TrainSignal.BEAM_RL:
        simulator = Simulator(vocab["idx2tkn"])
        if signal == TrainSignal.BEAM_RL:
            reward_comb_fun = RewardCombinationFun[reward_comb]
    else:
        raise Exception("Unknown TrainingSignal.")

    if use_cuda:
        model.cuda()
        if signal == TrainSignal.SUPERVISED:
            loss_criterion.cuda()

    # Set up the optimizer
    optimizer_cls = getattr(optim, optim_alg)
    optimizer = optimizer_cls(model.parameters(), lr=learning_rate)

    #####################
    # ################# #
    # # Training Loop # #
    # ################# #
    #####################
    losses = []
    recent_losses = []
    best_val_acc = np.NINF
    for epoch_idx in range(0, nb_epochs):
        nb_ios_for_epoch = nb_ios
        # This is definitely not the most efficient way to do it but oh well
        dataset = shuffle_dataset(dataset, batch_size)
        for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):

            batch_idx = int(sp_idx / batch_size)
            optimizer.zero_grad()

            if signal == TrainSignal.SUPERVISED:
                inp_grids, out_grids, \
                    in_tgt_seq, in_tgt_seq_list, out_tgt_seq, \
                    _, _, _, _, _ = get_minibatch(dataset, sp_idx, batch_size,
                                                  tgt_start, tgt_end, tgt_pad,
                                                  nb_ios_for_epoch)
                if use_cuda:
                    inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()
                    in_tgt_seq, out_tgt_seq = in_tgt_seq.cuda(), out_tgt_seq.cuda()

                if learn_syntax:
                    minibatch_loss = do_syntax_weighted_minibatch(
                        model, inp_grids, out_grids,
                        in_tgt_seq, in_tgt_seq_list, out_tgt_seq,
                        loss_criterion, beta)
                else:
                    minibatch_loss = do_supervised_minibatch(
                        model, inp_grids, out_grids,
                        in_tgt_seq, in_tgt_seq_list, out_tgt_seq,
                        loss_criterion)
                recent_losses.append(minibatch_loss)
            elif signal == TrainSignal.RL or signal == TrainSignal.BEAM_RL:
                inp_grids, out_grids, \
                    _, _, _, \
                    inp_worlds, out_worlds, \
                    targets, \
                    inp_test_worlds, out_test_worlds = get_minibatch(
                        dataset, sp_idx, batch_size,
                        tgt_start, tgt_end, tgt_pad,
                        nb_ios_for_epoch)
                if use_cuda:
                    inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()

                # We use 1/nb_rollouts as the reward to normalize wrt the
                # size of the rollouts
                if signal == TrainSignal.RL:
                    reward_norm = 1 / float(nb_rollouts)
                elif signal == TrainSignal.BEAM_RL:
                    reward_norm = 1
                else:
                    raise NotImplementedError("Unknown training signal")

                lens = [len(target) for target in targets]
                max_len = max(lens) + 10

                env_cls = EnvironmentClasses[environment]
                if "Consistency" in environment:
                    envs = [env_cls(reward_norm, trg_prog,
                                    sp_inp_worlds, sp_out_worlds, simulator)
                            for trg_prog, sp_inp_worlds, sp_out_worlds
                            in zip(targets, inp_worlds, out_worlds)]
                elif "Generalization" in environment or "Perf" in environment:
                    envs = [env_cls(reward_norm, trg_prog,
                                    sp_inp_test_worlds, sp_out_test_worlds,
                                    simulator)
                            for trg_prog, sp_inp_test_worlds, sp_out_test_worlds
                            in zip(targets, inp_test_worlds, out_test_worlds)]
                else:
                    raise NotImplementedError("Unknown environment type")

                if signal == TrainSignal.RL:
                    minibatch_reward = do_rl_minibatch(model,
                                                       inp_grids, out_grids,
                                                       envs,
                                                       tgt_start, tgt_end,
                                                       max_len, nb_rollouts)
                    # minibatch_reward = do_rl_minibatch_two_steps(model,
                    #                                              inp_grids, out_grids,
                    #                                              envs,
                    #                                              tgt_start, tgt_end, tgt_pad,
                    #                                              max_len, nb_rollouts,
                    #                                              rl_inner_batch)
                elif signal == TrainSignal.BEAM_RL:
                    minibatch_reward = do_beam_rl(model,
                                                  inp_grids, out_grids, targets,
                                                  envs, reward_comb_fun,
                                                  tgt_start, tgt_end, tgt_pad,
                                                  max_len, rl_beam,
                                                  rl_inner_batch, rl_use_ref)
                else:
                    raise NotImplementedError("Unknown Environment type")
                recent_losses.append(minibatch_reward)
            else:
                raise NotImplementedError("Unknown Training method")
            optimizer.step()

            if (batch_idx % log_frequency == log_frequency - 1 and len(recent_losses) > 0) or \
                    (len(dataset["sources"]) - sp_idx) < batch_size:
                logging.info('Epoch : %d Minibatch : %d Loss : %.5f' % (
                    epoch_idx, batch_idx,
                    sum(recent_losses) / len(recent_losses)))
                losses.extend(recent_losses)
                recent_losses = []
                # Dump the training losses
                with open(str(train_loss_path), "w") as train_loss_file:
                    json.dump(losses, train_loss_file, indent=2)

                if signal == TrainSignal.BEAM_RL:
                    # RL is much slower, so we dump the weights more frequently
                    path_to_weight_dump = models_dir / ("weights_%d.model" % epoch_idx)
                    with open(str(path_to_weight_dump), "wb") as weight_file:
                        # The model needs to be in CPU mode when dumped,
                        # otherwise it will be annoying to load
                        if use_cuda:
                            model.cpu()
                        torch.save(model, weight_file)
                        if use_cuda:
                            model.cuda()

        # Dump the weights at the end of the epoch
        path_to_weight_dump = models_dir / ("weights_%d.model" % epoch_idx)
        with open(str(path_to_weight_dump), "wb") as weight_file:
            # The model needs to be in CPU mode when dumped, otherwise it
            # will be annoying to load
            if use_cuda:
                model.cpu()
            torch.save(model, weight_file)
            if use_cuda:
                model.cuda()
        previous_weight_dump = models_dir / ("weights_%d.model" % (epoch_idx - 1))
        if previous_weight_dump.exists():
            os.remove(str(previous_weight_dump))
        # Dump the training losses
        with open(str(train_loss_path), "w") as train_loss_file:
            json.dump(losses, train_loss_file, indent=2)

        logging.info("Done with epoch %d." % epoch_idx)

        if (epoch_idx + 1) % val_frequency == 0 or (epoch_idx + 1) == nb_epochs:
            # Evaluate the model on the validation set
            # (nb_ios=5, nb_samples=0, beam_size=100, top_k=50)
            out_path = str(result_dir / ("eval/epoch_%d/val_.txt" % epoch_idx))
            val_acc = evaluate_model(str(path_to_weight_dump), vocab_file,
                                     val_file, 5, 0, use_grammar, out_path,
                                     100, 50, batch_size, use_cuda, False)
            logging.info("Epoch : %d ValidationAccuracy : %f." % (epoch_idx, val_acc))
            if val_acc > best_val_acc:
                logging.info("Epoch : %d ValidationBest : %f." % (epoch_idx, val_acc))
                best_val_acc = val_acc
                path_to_weight_dump = models_dir / "best.model"
                with open(str(path_to_weight_dump), "wb") as weight_file:
                    # The model needs to be in CPU mode when dumped, otherwise
                    # it will be annoying to load
                    if use_cuda:
                        model.cpu()
                    torch.save(model, weight_file)
                    if use_cuda:
                        model.cuda()
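
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a supervised training run. Every value
# below is a hypothetical placeholder chosen to satisfy the signature; the
# RL-specific options are ignored when signal == TrainSignal.SUPERVISED.
# ---------------------------------------------------------------------------
def _example_train_call():
    train_seq2seq_model(
        # Optimization
        signal=TrainSignal.SUPERVISED,
        nb_ios=5, nb_epochs=100, optim_alg="Adam",
        batch_size=128, learning_rate=1e-4,
        use_grammar=True, beta=1e-5, val_frequency=1,
        # Model (stack shapes are hypothetical)
        kernel_size=3, conv_stack=[64, 64, 64], fc_stack=[512],
        tgt_embedding_size=256, lstm_hidden_size=256,
        nb_lstm_layers=2, learn_syntax=False,
        # RL specific options (unused for supervised training;
        # names are hypothetical examples)
        environment="BlackBoxGeneralization", reward_comb="RenormExpected",
        nb_rollouts=100, rl_beam=64, rl_inner_batch=8, rl_use_ref=False,
        # What to train (hypothetical paths)
        train_file="data/train.json", val_file="data/val.json",
        vocab_file="data/vocab.vocab", nb_samples=0, initialisation=None,
        # Where to write results
        result_folder="exps/supervised", args_dict={},
        # Run options
        use_cuda=False, log_frequency=100)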
def evaluate_model(model_weights,
                   vocabulary_path,
                   dataset_path,
                   nb_ios,
                   nb_samples,
                   use_grammar,
                   output_path,
                   beam_size,
                   top_k,
                   batch_size,
                   use_cuda,
                   dump_programs,
                   return_individual_results=False,
                   use_beam=True,
                   num_successful_targets_per_precursor=None,
                   preloaded_data=False,
                   minibatch_cache_path=None):
    if return_individual_results:
        print('beginning evaluate_model for data augmentation')
    all_outputs_path = []
    all_semantic_output_path = []
    all_syntax_output_path = []
    all_generalize_output_path = []
    pred_filter_path = 'pred_filter_top1_l%d.txt' % top_k
    pred_filter_path = output_path + pred_filter_path
    res_dir = os.path.dirname(output_path)
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)
    for k in range(top_k):
        new_term = "exactmatch_top%d.txt" % (k + 1)
        new_semantic_term = "semantic_top%d.txt" % (k + 1)
        new_syntax_term = "syntax_top%d.txt" % (k + 1)
        new_generalize_term = "fullgeneralize_top%d.txt" % (k + 1)

        new_file_name = output_path + new_term
        new_semantic_file_name = output_path + new_semantic_term
        new_syntax_file_name = output_path + new_syntax_term
        new_generalize_file_name = output_path + new_generalize_term

        all_outputs_path.append(new_file_name)
        all_semantic_output_path.append(new_semantic_file_name)
        all_syntax_output_path.append(new_syntax_file_name)
        all_generalize_output_path.append(new_generalize_file_name)
    program_dump_path = os.path.join(res_dir, "generated")

    # Unlike the original evaluate_model, do not reuse cached results:
    # if os.path.exists(all_outputs_path[0]):
    #     with open(all_outputs_path[0], "r") as out_file:
    #         out_file_content = out_file.read()
    #         print("Using cached result from {}".format(all_outputs_path[0]))
    #         print("Greedy select accuracy: {}".format(out_file_content))
    #         return

    # Load the vocabulary of the trained model
    if preloaded_data:
        # dataset_path / vocabulary_path hold pickled data rather than paths
        dataset, vocab = pickle.loads(dataset_path), pickle.loads(vocabulary_path)
    else:
        dataset, vocab = load_input_file(dataset_path, vocabulary_path)
    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]
    simulator = Simulator(vocab["idx2tkn"])

    # Load the model
    if not use_cuda:
        # https://discuss.pytorch.org/t/on-a-cpu-device-how-to-load-checkpoint-saved-on-gpu-device/349/8
        model = torch.load(model_weights,
                           map_location=lambda storage, loc: storage)
    else:
        model = torch.load(model_weights)
        model.cuda()
    # And put it into evaluation mode
    model.eval()

    syntax_checker = PySyntaxChecker(vocab["tkn2idx"], use_cuda)
    if use_grammar:
        model.set_syntax_checker(syntax_checker)

    if beam_size == 1:
        top_k = 1
    nb_correct = [0 for _ in range(top_k)]
    nb_semantic_correct = [0 for _ in range(top_k)]
    nb_syntax_correct = [0 for _ in range(top_k)]
    nb_generalize_correct = [0 for _ in range(top_k)]
    nb_pred_filter_correct = 0
    total_nb = 0

    with torch.no_grad():
        if minibatch_cache_path is None or not os.path.isdir(minibatch_cache_path):
            if minibatch_cache_path is None:
                minibatch_cache_path = output_path
            dataset, sort_idx = shuffle_dataset(dataset, batch_size,
                                                randomize=False,
                                                return_sort_idx=True)
            if not os.path.isdir(minibatch_cache_path):
                os.makedirs(minibatch_cache_path)
            with open(os.path.join(minibatch_cache_path, 'dataset_sortidx.pkl'),
                      'wb') as f:
                pickle.dump((dataset, sort_idx), f)
            if not os.path.exists(os.path.dirname(output_path)):
                os.makedirs(os.path.dirname(output_path))

            print('create inputs for saving minibatches')
            save_minibatch_inputs = []
            for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):
                sources = dataset['sources'][sp_idx:sp_idx + batch_size]
                targets = dataset['targets'][sp_idx:sp_idx + batch_size]
                minibatch_output_path = make_minibatch_path(
                    minibatch_cache_path, sp_idx)
                save_minibatch_inputs.append(
                    (sources, targets, minibatch_output_path, batch_size,
                     tgt_start, tgt_end, tgt_pad, nb_ios, False, True))

            print('save minibatches')
            print(str(datetime.datetime.now()))
            for i in range(0, len(save_minibatch_inputs), 60):
                # Recreate the pool periodically to release RAM
                pool = Pool(processes=30, maxtasksperchild=1)
                pool.map(save_minibatch_parallel,
                         save_minibatch_inputs[i:i + 60],
                         chunksize=1)
                pool.close()
            print(str(datetime.datetime.now()))
        else:
            with open(os.path.join(minibatch_cache_path, 'dataset_sortidx.pkl'),
                      'rb') as f:
                dataset, sort_idx = pickle.load(f)

        print('sample decodings from model')
        all_decoded = []
        load_count = 0
        pool = Pool(processes=30, maxtasksperchild=1)
        for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):
            # Load the next chunk of cached minibatches in parallel
            if load_count % PARALLEL_LOAD_MINIBATCHES == 0:
                minibatch_load_paths = []
                for i in range(PARALLEL_LOAD_MINIBATCHES):
                    mb_sp_idx = sp_idx + i * batch_size
                    if mb_sp_idx < len(dataset["sources"]):
                        minibatch_load_paths.append(
                            make_minibatch_path(minibatch_cache_path, mb_sp_idx))
                # Drop the previous chunk before loading the next one
                loaded_minibatches = None
                loaded_minibatches = pool.map(load_minibatch_parallel,
                                              minibatch_load_paths,
                                              chunksize=1)
            minibatch = loaded_minibatches[load_count % PARALLEL_LOAD_MINIBATCHES]
            load_count += 1

            inp_grids, out_grids, \
                in_tgt_seq, in_tgt_seq_list, out_tgt_seq, \
                inp_worlds, out_worlds, \
                _, \
                inp_test_worlds, out_test_worlds = minibatch

            if return_individual_results:
                with open(make_minibatch_path(minibatch_cache_path, sp_idx),
                          'wb') as f:
                    pickle.dump(minibatch, f, protocol=pickle.HIGHEST_PROTOCOL)

            max_len = out_tgt_seq.size(1) + 10
            if use_cuda:
                inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()
                in_tgt_seq, out_tgt_seq = in_tgt_seq.cuda(), out_tgt_seq.cuda()

            if dump_programs:
                decoder_logit, syntax_logit = model(inp_grids, out_grids,
                                                    in_tgt_seq, in_tgt_seq_list)
                if syntax_logit is not None and model.decoder.learned_syntax_checker is not None:
                    syntax_logit = syntax_logit.cpu().data.numpy()
                    for n in range(in_tgt_seq.size(0)):
                        decoded_dump_dir = os.path.join(program_dump_path,
                                                        str(n + sp_idx))
                        if not os.path.exists(decoded_dump_dir):
                            os.makedirs(decoded_dump_dir)
                        seq = in_tgt_seq.cpu().data.numpy()[n].tolist()
                        seq_len = seq.index(0) if 0 in seq else len(seq)
                        file_name = str(n) + "_learned_syntax"
                        norm_logit = syntax_logit[n, :seq_len]
                        norm_logit = np.log(-norm_logit)
                        norm_logit = 1 / (1 + np.exp(-norm_logit))
                        np.save(os.path.join(decoded_dump_dir, file_name),
                                norm_logit)
                        ini_state = syntax_checker.get_initial_checker_state()
                        file_name = str(n) + "_manual_syntax"
                        mask = syntax_checker.get_sequence_mask(
                            ini_state, seq).squeeze().cpu().numpy()[:seq_len]
                        np.save(os.path.join(decoded_dump_dir, file_name), mask)
                        file_name = str(n) + "_diff"
                        diff = mask.astype(float) - norm_logit
                        diff = (diff + 1) / 2  # remap to [0, 1]
                        np.save(os.path.join(decoded_dump_dir, file_name), diff)

            decoded = model.beam_sample(inp_grids, out_grids,
                                        tgt_start, tgt_end, max_len,
                                        beam_size, top_k, True, use_beam)
            all_decoded.append((sp_idx, decoded))

            if load_count % PARALLEL_LOAD_MINIBATCHES == PARALLEL_LOAD_MINIBATCHES - 1:
                # Save the accumulated decodings
                decoded_load_paths = []
                for d_sp_idx, decoded in all_decoded:
                    decoded_load_paths.append(
                        make_decoded_path(output_path, d_sp_idx))
                pool.map(save_decoded_parallel,
                         [(decoded, path) for decoded, path in zip(
                             [tup[1] for tup in all_decoded],
                             decoded_load_paths)],
                         chunksize=1)
                all_decoded = []

        if len(all_decoded) > 0:
            # Save the remaining decodings
            decoded_load_paths = []
            for d_sp_idx, decoded in all_decoded:
                decoded_load_paths.append(
                    make_decoded_path(output_path, d_sp_idx))
            pool.map(save_decoded_parallel,
                     [(decoded, path) for decoded, path in zip(
                         [tup[1] for tup in all_decoded],
                         decoded_load_paths)],
                     chunksize=1)
            all_decoded = []
        pool.close()

        if return_individual_results:
            print('done sampling from model, starting filter')
            generalization_inputs = []
            vocab_pickle = pickle.dumps(vocab, protocol=pickle.HIGHEST_PROTOCOL)
            for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):
                generalization_inputs.append(
                    (minibatch_cache_path, sp_idx, output_path, vocab_pickle,
                     tgt_pad, num_successful_targets_per_precursor, top_k))
            print('starting pool.map')
            pool = Pool(processes=30, maxtasksperchild=1)
            print(str(datetime.datetime.now()))
            all_batch_results = pool.map(find_generalizations_parallel,
                                         generalization_inputs,
                                         chunksize=1)
            print(str(datetime.datetime.now()))
            pool.close()
            full_results = []
            num_found_full = 0
            for batch_results, stats in all_batch_results:
                full_results += batch_results
                num_found_full += stats[0]
        else:
            for sp_idx in tqdm(range(0, len(dataset["sources"]), batch_size)):

                inp_grids, out_grids, \
                    in_tgt_seq, in_tgt_seq_list, out_tgt_seq, \
                    inp_worlds, out_worlds, \
                    _, \
                    inp_test_worlds, out_test_worlds = get_minibatch(
                        dataset, sp_idx, batch_size, tgt_start, tgt_end,
                        tgt_pad, nb_ios, shuffle=False, volatile_vars=True)

                max_len = out_tgt_seq.size(1) + 10

                # Reload the decodings saved above instead of re-sampling:
                # decoded = model.beam_sample(inp_grids, out_grids,
                #                             tgt_start, tgt_end, max_len,
                #                             beam_size, top_k)
                # decoded = all_decoded[int(sp_idx / batch_size)]
                with open(make_decoded_path(output_path, sp_idx), 'rb') as f:
                    decoded = pickle.load(f)

                for batch_idx, (target, sp_decoded,
                                sp_input_worlds, sp_output_worlds,
                                sp_test_input_worlds, sp_test_output_worlds) in \
                        enumerate(zip(out_tgt_seq.chunk(out_tgt_seq.size(0)),
                                      decoded,
                                      inp_worlds, out_worlds,
                                      inp_test_worlds, out_test_worlds)):

                    if return_individual_results:
                        current_results = []
                    total_nb += 1
                    target = target.data.squeeze().numpy().tolist()
                    target = [tkn_idx for tkn_idx in target
                              if tkn_idx != tgt_pad]

                    if dump_programs:
                        decoded_dump_dir = os.path.join(program_dump_path,
                                                        str(batch_idx + sp_idx))
                        if not os.path.exists(decoded_dump_dir):
                            os.makedirs(decoded_dump_dir)
                        write_program(os.path.join(decoded_dump_dir, "target"),
                                      target, vocab["idx2tkn"])
                        for rank, dec in enumerate(sp_decoded):
                            pred = dec[1]
                            ll = dec[0]
                            file_name = str(rank) + " - " + str(ll)
                            write_program(
                                os.path.join(decoded_dump_dir, file_name),
                                pred, vocab["idx2tkn"])

                    # Exact matches
                    for rank, dec in enumerate(sp_decoded):
                        pred = dec[-1]
                        if pred == target:
                            # This prediction is correct, so it also scores
                            # for all the following ranks
                            for top_idx in range(rank, top_k):
                                nb_correct[top_idx] += 1
                            break

                    # Semantic matches
                    for rank, dec in enumerate(sp_decoded):
                        pred = dec[-1]
                        parse_success, cand_prog = simulator.get_prog_ast(pred)
                        if (not parse_success):
                            continue
                        semantically_correct = True
                        for (input_world, output_world) in zip(sp_input_worlds,
                                                               sp_output_worlds):
                            res_emu = simulator.run_prog(cand_prog, input_world)
                            if (res_emu.status != 'OK') or res_emu.crashed or \
                                    (res_emu.outgrid != output_world):
                                # This prediction is semantically incorrect.
                                semantically_correct = False
                                break
                        if semantically_correct:
                            # Score for all the following ranks
                            for top_idx in range(rank, top_k):
                                nb_semantic_correct[top_idx] += 1
                            break

                    # Generalization
                    found_full = False
                    pred_filter_success = 0
                    # scored / num_success are hoisted out of the rank loop so
                    # that they accumulate across ranks (resetting them per
                    # rank would make the "not scored" check and the
                    # num_success threshold vacuous)
                    scored = False
                    num_success = 0
                    for rank, dec in enumerate(sp_decoded):
                        pred = dec[-1]
                        parse_success, cand_prog = simulator.get_prog_ast(pred)
                        if (not parse_success):
                            continue
                        generalizes = True
                        for (input_world, output_world) in zip(sp_input_worlds,
                                                               sp_output_worlds):
                            res_emu = simulator.run_prog(cand_prog, input_world)
                            if (res_emu.status != 'OK') or res_emu.crashed or \
                                    (res_emu.outgrid != output_world):
                                # This prediction is semantically incorrect.
                                generalizes = False
                                break
                        if generalizes:
                            for (input_world, output_world) in zip(
                                    sp_test_input_worlds, sp_test_output_worlds):
                                res_emu = simulator.run_prog(cand_prog,
                                                             input_world)
                                if (res_emu.status != 'OK') or res_emu.crashed or \
                                        (res_emu.outgrid != output_world):
                                    # We picked something that was
                                    # syntactically correct and passed the
                                    # given inputs, but it fails on the
                                    # held-out test cases.
                                    # NOTE: our setup is that we are allowed to
                                    # filter as much as we want using the input
                                    # test cases, but only get one try on the
                                    # real "test" test cases.
                                    generalizes = False
                                    pred_filter_success = -1
                                    break
                        if return_individual_results and generalizes:
                            current_results.append(pred)
                            num_success += 1
                            if num_success >= num_successful_targets_per_precursor:
                                found_full = True
                                break
                        if generalizes and not scored:
                            if pred_filter_success != -1:
                                nb_pred_filter_correct += 1
                            # Score for all the following ranks
                            for top_idx in range(rank, top_k):
                                nb_generalize_correct[top_idx] += 1
                            if return_individual_results:
                                scored = True
                            else:
                                break
                    if return_individual_results:
                        full_results.append(current_results)
                        if found_full:
                            num_found_full += 1

                    # Correct syntaxes
                    for rank, dec in enumerate(sp_decoded):
                        pred = dec[-1]
                        parse_success, cand_prog = simulator.get_prog_ast(pred)
                        if parse_success:
                            for top_idx in range(rank, top_k):
                                nb_syntax_correct[top_idx] += 1
                            break

            with open(pred_filter_path, 'w') as pred_filter_file:
                pred_filter_file.write(
                    str(100 * nb_pred_filter_correct / total_nb))

            for k in range(top_k):
                with open(str(all_outputs_path[k]), "w") as res_file:
                    res_file.write(str(100 * nb_correct[k] / total_nb))
                with open(str(all_semantic_output_path[k]), "w") as sem_res_file:
                    sem_res_file.write(
                        str(100 * nb_semantic_correct[k] / total_nb))
                with open(str(all_syntax_output_path[k]), "w") as stx_res_file:
                    stx_res_file.write(
                        str(100 * nb_syntax_correct[k] / total_nb))
                with open(str(all_generalize_output_path[k]), "w") as gen_res_file:
                    gen_res_file.write(
                        str(100 * nb_generalize_correct[k] / total_nb))

    if return_individual_results:
        print(str(datetime.datetime.now()))
        print('found full set of new targets for ' + str(num_found_full) +
              ' out of ' + str(len(full_results)))
        # Undo the length-based sort so results line up with the original
        # dataset order
        unsort_idx = [sort_idx.index(i) for i in range(len(sort_idx))]
        full_results = [full_results[u] for u in unsort_idx]
        return full_results, (num_found_full, len(full_results))
    else:
        semantic_at_one = 100 * nb_semantic_correct[0] / total_nb
        return semantic_at_one
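
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only) for the data-augmentation path: with
# return_individual_results=True the function returns, for every source
# sample, up to num_successful_targets_per_precursor decoded programs that
# pass both the given and the held-out IO grids. All paths and values are
# hypothetical placeholders.
# ---------------------------------------------------------------------------
def _example_augmentation_call():
    full_results, (num_found_full, total) = evaluate_model(
        model_weights="exps/rl/Weights/best.model",  # hypothetical path
        vocabulary_path="data/vocab.vocab",          # hypothetical path
        dataset_path="data/train.json",              # hypothetical path
        nb_ios=5, nb_samples=0, use_grammar=False,
        output_path="exps/rl/augment/", beam_size=64, top_k=32,
        batch_size=32, use_cuda=False, dump_programs=False,
        return_individual_results=True,
        use_beam=True,
        num_successful_targets_per_precursor=8)
    print("%d / %d sources got a full set of new targets"
          % (num_found_full, total))
    return full_results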