def predict_and_save(dataset: GroundedScanDataset, model: nn.Module, output_file_path: str,
                     max_decoding_steps: int, max_testing_examples=None, **kwargs):
    """
    Predict all data in dataset with a model and write the predictions to output_file_path.
    :param dataset: a dataset with test examples
    :param model: a trained model from model.py
    :param output_file_path: a path where a .json file with predictions will be saved.
    :param max_decoding_steps: after how many steps to force quit decoding
    :param max_testing_examples: after how many examples to stop predicting, if None all examples will be evaluated
    """
    cfg = locals().copy()
    with open(output_file_path, mode='w') as outfile:
        output = []
        with torch.no_grad():
            i = 0
            for (input_sequence, derivation_spec, situation_spec, output_sequence, target_sequence,
                 attention_weights_commands, attention_weights_situations, position_accuracy) in predict(
                    dataset.get_data_iterator(batch_size=1), model=model, max_decoding_steps=max_decoding_steps,
                    pad_idx=dataset.target_vocabulary.pad_idx, sos_idx=dataset.target_vocabulary.sos_idx,
                    eos_idx=dataset.target_vocabulary.eos_idx):
                i += 1
                accuracy = sequence_accuracy(output_sequence, target_sequence[0].tolist()[1:-1])
                input_str_sequence = dataset.array_to_sentence(input_sequence[0].tolist(), vocabulary="input")
                input_str_sequence = input_str_sequence[1:-1]  # Get rid of <SOS> and <EOS>
                target_str_sequence = dataset.array_to_sentence(target_sequence[0].tolist(), vocabulary="target")
                target_str_sequence = target_str_sequence[1:-1]  # Get rid of <SOS> and <EOS>
                output_str_sequence = dataset.array_to_sentence(output_sequence, vocabulary="target")
                output.append({"input": input_str_sequence, "prediction": output_str_sequence,
                               "derivation": derivation_spec, "target": target_str_sequence,
                               "situation": situation_spec,
                               "attention_weights_input": attention_weights_commands,
                               "attention_weights_situation": attention_weights_situations,
                               "accuracy": accuracy,
                               "exact_match": True if accuracy == 100 else False,
                               "position_accuracy": position_accuracy})
        logger.info("Wrote predictions for {} examples.".format(i))
        json.dump(output, outfile, indent=4)
    return output_file_path
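# Sketch (not part of the original repo): a minimal driver showing how predict_and_save is typically wired
# together with a dataset split and an already-trained model. The split name, vocabulary file names, output
# path, and the helper name `run_test_predictions` are illustrative assumptions; `trained_model` is assumed
# to be a Model from model.py that has already been trained or restored from a checkpoint.
def run_test_predictions(data_path, data_directory, trained_model, max_decoding_steps=120):
    test_set = GroundedScanDataset(data_path, data_directory, split="test",
                                   input_vocabulary_file="training_input_vocab.txt",
                                   target_vocabulary_file="training_target_vocab.txt",
                                   generate_vocabulary=False)
    test_set.read_dataset(max_examples=None, simple_situation_representation=True)
    test_set.shuffle_data()
    return predict_and_save(dataset=test_set, model=trained_model,
                            output_file_path="predictions.json",
                            max_decoding_steps=max_decoding_steps)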
def predict_and_save(dataset: GroundedScanDataset, model: nn.Module, output_file_path: str,
                     max_decoding_steps: int, max_testing_examples=None, **kwargs):
    """
    Predict all data in dataset with a model and write the predictions to output_file_path.
    :param dataset: a dataset with test examples
    :param model: a trained model from model.py
    :param output_file_path: a path where a .json file with predictions will be saved.
    :param max_decoding_steps: after how many steps to force quit decoding
    :param max_testing_examples: after how many examples to stop predicting, if None all examples will be evaluated
    :param kwargs: additional arguments, unused here
    :return: the path of the file the predictions were written to
    """
    cfg = locals().copy()
    with open(output_file_path, mode='w') as outfile:
        output = []
        with torch.no_grad():
            i = 0
            for (input_sequence, derivation_spec, situation_spec, output_sequence, target_sequence,
                 attention_weights_commands, attention_weights_situations, _) in predict(
                    dataset.get_data_iterator(batch_size=1), model=model, max_decoding_steps=max_decoding_steps,
                    pad_idx=dataset.target_vocabulary.pad_idx, sos_idx=dataset.target_vocabulary.sos_idx,
                    eos_idx=dataset.target_vocabulary.eos_idx):
                i += 1
                accuracy = sequence_accuracy(output_sequence, target_sequence[0].tolist()[1:-1])
                input_str_sequence = dataset.array_to_sentence(input_sequence[0].tolist(), vocabulary="input")
                input_str_sequence = input_str_sequence[1:-1]  # Get rid of <SOS> and <EOS>
                target_str_sequence = dataset.array_to_sentence(target_sequence[0].tolist(), vocabulary="target")
                target_str_sequence = target_str_sequence[1:-1]  # Get rid of <SOS> and <EOS>
                output_str_sequence = dataset.array_to_sentence(output_sequence, vocabulary="target")
                output.append({"input": input_str_sequence, "prediction": output_str_sequence,
                               "derivation": derivation_spec, "target": target_str_sequence,
                               "situation": situation_spec,
                               "attention_weights_input": attention_weights_commands,
                               "attention_weights_situation": attention_weights_situations,
                               "accuracy": accuracy,
                               "exact_match": True if accuracy == 100 else False})
        logger.info("Wrote predictions for {} examples.".format(i))
        json.dump(output, outfile, indent=4)
    return output_file_path
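# Sketch (not part of the original code): aggregating the JSON file written by predict_and_save above.
# The keys "accuracy" and "exact_match" follow the dicts appended to `output`; the default path and the
# helper name `summarize_predictions` are illustrative assumptions.
def summarize_predictions(predictions_path="predictions.json"):
    with open(predictions_path, "r") as infile:
        predictions = json.load(infile)
    # Percentage of examples decoded exactly right, and mean per-example sequence accuracy.
    exact_match_rate = 100. * sum(example["exact_match"] for example in predictions) / len(predictions)
    mean_accuracy = sum(example["accuracy"] for example in predictions) / len(predictions)
    return exact_match_rate, mean_accuracy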
def train(data_path: str, data_directory: str, generate_vocabularies: bool, input_vocab_path: str,
          target_vocab_path: str, embedding_dimension: int, num_encoder_layers: int, encoder_dropout_p: float,
          encoder_bidirectional: bool, training_batch_size: int, test_batch_size: int, max_decoding_steps: int,
          num_decoder_layers: int, decoder_dropout_p: float, cnn_kernel_size: int, cnn_dropout_p: float,
          cnn_hidden_num_channels: int, simple_situation_representation: bool, decoder_hidden_size: int,
          encoder_hidden_size: int, learning_rate: float, adam_beta_1: float, adam_beta_2: float, lr_decay: float,
          lr_decay_steps: int, resume_from_file: str, max_training_iterations: int, output_directory: str,
          print_every: int, evaluate_every: int, conditional_attention: bool, auxiliary_task: bool,
          weight_target_loss: float, attention_type: str, k: int, max_training_examples, max_testing_examples,
          # SeqGAN params begin
          pretrain_gen_path, pretrain_gen_epochs, pretrain_disc_path, pretrain_disc_epochs, rollout_trails,
          rollout_update_rate, disc_emb_dim, disc_hid_dim, load_tensors_from_path,
          # SeqGAN params end
          seed=42, **kwargs):
    device = torch.device("cpu")
    cfg = locals().copy()

    torch.manual_seed(seed)

    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(data_path, data_directory, split="train",
                                       input_vocabulary_file=input_vocab_path,
                                       target_vocabulary_file=target_vocab_path,
                                       generate_vocabulary=generate_vocabularies, k=k)
    training_set.read_dataset(max_examples=max_training_examples,
                              simple_situation_representation=simple_situation_representation,
                              load_tensors_from_path=load_tensors_from_path)  # set this to False if no pickle file available
    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info("Saved vocabularies to {} for input and {} for target.".format(input_vocab_path,
                                                                                   target_vocab_path))

    # logger.info("Loading Dev. set...")
    # test_set = GroundedScanDataset(data_path, data_directory, split="dev",
    #                                input_vocabulary_file=input_vocab_path,
    #                                target_vocabulary_file=target_vocab_path, generate_vocabulary=False, k=0)
    # test_set.read_dataset(max_examples=max_testing_examples,
    #                       simple_situation_representation=simple_situation_representation)
    #
    # # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    # test_set.shuffle_data()
    # logger.info("Done Loading Dev. set.")

    generator = Model(input_vocabulary_size=training_set.input_vocabulary_size,
                      target_vocabulary_size=training_set.target_vocabulary_size,
                      num_cnn_channels=training_set.image_channels,
                      input_padding_idx=training_set.input_vocabulary.pad_idx,
                      target_pad_idx=training_set.target_vocabulary.pad_idx,
                      target_eos_idx=training_set.target_vocabulary.eos_idx,
                      **cfg)
    total_vocabulary = set(list(training_set.input_vocabulary._word_to_idx.keys()) +
                           list(training_set.target_vocabulary._word_to_idx.keys()))
    total_vocabulary_size = len(total_vocabulary)
    discriminator = Discriminator(embedding_dim=disc_emb_dim, hidden_dim=disc_hid_dim,
                                  vocab_size=total_vocabulary_size, max_seq_len=max_decoding_steps)
    generator = generator.cuda() if use_cuda else generator
    discriminator = discriminator.cuda() if use_cuda else discriminator
    rollout = Rollout(generator, rollout_update_rate)
    log_parameters(generator)
    trainable_parameters = [parameter for parameter in generator.parameters() if parameter.requires_grad]
    optimizer = torch.optim.Adam(trainable_parameters, lr=learning_rate, betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer, lr_lambda=lambda t: lr_decay ** (t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(resume_from_file), "No checkpoint found at {}".format(resume_from_file)
        logger.info("Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = generator.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = generator.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(resume_from_file, start_iteration))

    if pretrain_gen_path is None:
        print('Pretraining generator with MLE...')
        pre_train_generator(training_set, training_batch_size, generator, seed, pretrain_gen_epochs,
                            name='pretrained_generator')
    else:
        print('Load pretrained generator weights')
        generator_weights = torch.load(pretrain_gen_path)
        generator.load_state_dict(generator_weights)

    if pretrain_disc_path is None:
        print('Pretraining Discriminator....')
        train_discriminator(training_set, discriminator, training_batch_size, generator, seed,
                            pretrain_disc_epochs, name="pretrained_discriminator")
    else:
        print('Loading Discriminator....')
        discriminator_weights = torch.load(pretrain_disc_path)
        discriminator.load_state_dict(discriminator_weights)

    logger.info("Training starts..")
    training_iteration = start_iteration
    torch.autograd.set_detect_anomaly(True)
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()
        for (input_batch, input_lengths, _, situation_batch, _, target_batch, target_lengths,
             agent_positions, target_positions) in \
                training_set.get_data_iterator(batch_size=training_batch_size):
            is_best = False
            generator.train()

            # Forward pass.
            samples = generator.sample(batch_size=training_batch_size,
                                       max_seq_len=max(target_lengths).astype(int),
                                       commands_input=input_batch, commands_lengths=input_lengths,
                                       situations_input=situation_batch, target_batch=target_batch,
                                       sos_idx=training_set.input_vocabulary.sos_idx,
                                       eos_idx=training_set.input_vocabulary.eos_idx)
            rewards = rollout.get_reward(samples, rollout_trails, input_batch, input_lengths, situation_batch,
                                         target_batch, training_set.input_vocabulary.sos_idx,
                                         training_set.input_vocabulary.eos_idx, discriminator)
            assert samples.shape == rewards.shape

            # calculate rewards
            rewards = torch.exp(rewards).contiguous().view((-1,))
            if use_cuda:
                rewards = rewards.cuda()

            # get generator scores for sequence
            target_scores = generator.get_normalized_logits(commands_input=input_batch,
                                                            commands_lengths=input_lengths,
                                                            situations_input=situation_batch,
                                                            samples=samples, sample_lengths=target_lengths,
                                                            sos_idx=training_set.input_vocabulary.sos_idx)
            del samples

            # calculate loss on the generated sequence given the rewards
            loss = generator.get_gan_loss(target_scores, target_batch, rewards)
            del rewards

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step(training_iteration)
            optimizer.zero_grad()
            generator.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                # accuracy, exact_match = generator.get_metrics(target_scores, target_batch)
                learning_rate = scheduler.get_lr()[0]
                logger.info("Iteration %08d, loss %8.4f, learning_rate %.5f,"
                            % (training_iteration, loss, learning_rate))
                # logger.info("Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                #             % (training_iteration, loss, accuracy, exact_match, learning_rate))
            del target_scores, target_batch

            # # Evaluate on test set.
            # if training_iteration % evaluate_every == 0:
            #     with torch.no_grad():
            #         generator.eval()
            #         logger.info("Evaluating..")
            #         accuracy, exact_match, target_accuracy = evaluate(
            #             test_set.get_data_iterator(batch_size=1), model=generator,
            #             max_decoding_steps=max_decoding_steps, pad_idx=test_set.target_vocabulary.pad_idx,
            #             sos_idx=test_set.target_vocabulary.sos_idx,
            #             eos_idx=test_set.target_vocabulary.eos_idx,
            #             max_examples_to_evaluate=kwargs["max_testing_examples"])
            #         logger.info("  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
            #                     " Target Accuracy: %5.2f" % (accuracy, exact_match, target_accuracy))
            #         if exact_match > best_exact_match:
            #             is_best = True
            #             best_accuracy = accuracy
            #             best_exact_match = exact_match
            #             generator.update_state(accuracy=accuracy, exact_match=exact_match, is_best=is_best)
            #         file_name = "checkpoint.pth.tar".format(str(training_iteration))
            #         if is_best:
            #             generator.save_checkpoint(file_name=file_name, is_best=is_best,
            #                                       optimizer_state_dict=optimizer.state_dict())

            rollout.update_params()
            train_discriminator(training_set, discriminator, training_batch_size, generator, seed, epochs=1,
                                name="training_discriminator")

            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
            del loss

    torch.save(generator.state_dict(),
               '{}/{}'.format(output_directory, 'gen_{}_{}.ckpt'.format(training_iteration, seed)))
    torch.save(discriminator.state_dict(),
               '{}/{}'.format(output_directory, 'dis_{}_{}.ckpt'.format(training_iteration, seed)))
    logger.info("Finished training.")
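# Sketch (assumption, not the repo's implementation): `generator.get_gan_loss(target_scores, target_batch,
# rewards)` above is used as a SeqGAN-style policy-gradient objective, where rollout rewards from the
# discriminator weight the log-likelihood of the sampled action sequence. A minimal version of such a loss
# is given below; the function name, argument names, and shape conventions ([batch, time, vocab]
# log-probabilities, flattened per-token rewards) are illustrative assumptions only.
def seqgan_policy_gradient_loss(log_probs, sampled_tokens, rewards):
    """log_probs: [batch, time, vocab] log-softmax scores; sampled_tokens: [batch, time] token ids;
    rewards: [batch * time] per-token rewards obtained from discriminator rollouts."""
    # Log-probability the generator assigned to each token it actually sampled.
    picked = log_probs.gather(dim=2, index=sampled_tokens.unsqueeze(-1)).squeeze(-1)  # [batch, time]
    # REINFORCE: maximize reward-weighted log-likelihood, i.e. minimize its negative mean.
    return -(picked.view(-1) * rewards).mean()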
def train(data_path: str, data_directory: str, generate_vocabularies: bool, input_vocab_path: str,
          target_vocab_path: str, embedding_dimension: int, num_encoder_layers: int, encoder_dropout_p: float,
          encoder_bidirectional: bool, training_batch_size: int, test_batch_size: int, max_decoding_steps: int,
          num_decoder_layers: int, decoder_dropout_p: float, cnn_kernel_size: int, cnn_dropout_p: float,
          cnn_hidden_num_channels: int, simple_situation_representation: bool, decoder_hidden_size: int,
          encoder_hidden_size: int, learning_rate: float, adam_beta_1: float, adam_beta_2: float, lr_decay: float,
          lr_decay_steps: int, resume_from_file: str, max_training_iterations: int, output_directory: str,
          print_every: int, evaluate_every: int, conditional_attention: bool, auxiliary_task: bool,
          weight_target_loss: float, attention_type: str, max_training_examples=None, seed=42, **kwargs):
    device = torch.device(type='cuda') if use_cuda else torch.device(type='cpu')
    cfg = locals().copy()

    torch.manual_seed(seed)

    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(data_path, data_directory, split="train",
                                       input_vocabulary_file=input_vocab_path,
                                       target_vocabulary_file=target_vocab_path,
                                       generate_vocabulary=generate_vocabularies)
    training_set.read_dataset(max_examples=max_training_examples,
                              simple_situation_representation=simple_situation_representation)
    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info("Saved vocabularies to {} for input and {} for target.".format(input_vocab_path,
                                                                                   target_vocab_path))

    logger.info("Loading Test set...")
    test_set = GroundedScanDataset(data_path, data_directory, split="test",  # TODO: use dev set here
                                   input_vocabulary_file=input_vocab_path,
                                   target_vocabulary_file=target_vocab_path, generate_vocabulary=False)
    test_set.read_dataset(max_examples=None,
                          simple_situation_representation=simple_situation_representation)

    # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    test_set.shuffle_data()
    logger.info("Done Loading Test set.")

    model = Model(input_vocabulary_size=training_set.input_vocabulary_size,
                  target_vocabulary_size=training_set.target_vocabulary_size,
                  num_cnn_channels=training_set.image_channels,
                  input_padding_idx=training_set.input_vocabulary.pad_idx,
                  target_pad_idx=training_set.target_vocabulary.pad_idx,
                  target_eos_idx=training_set.target_vocabulary.eos_idx,
                  **cfg)
    model = model.cuda() if use_cuda else model
    log_parameters(model)
    trainable_parameters = [parameter for parameter in model.parameters() if parameter.requires_grad]
    optimizer = torch.optim.Adam(trainable_parameters, lr=learning_rate, betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer, lr_lambda=lambda t: lr_decay ** (t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(resume_from_file), "No checkpoint found at {}".format(resume_from_file)
        logger.info("Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = model.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(resume_from_file, start_iteration))

    logger.info("Training starts..")
    training_iteration = start_iteration
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()
        for (input_batch, input_lengths, _, situation_batch, _, target_batch, target_lengths,
             agent_positions, target_positions) in training_set.get_data_iterator(
                batch_size=training_batch_size):
            is_best = False
            model.train()

            # Forward pass.
            target_scores, target_position_scores = model(commands_input=input_batch,
                                                          commands_lengths=input_lengths,
                                                          situations_input=situation_batch,
                                                          target_batch=target_batch,
                                                          target_lengths=target_lengths)
            loss = model.get_loss(target_scores, target_batch)
            if auxiliary_task:
                target_loss = model.get_auxiliary_loss(target_position_scores, target_positions)
            else:
                target_loss = 0
            loss += weight_target_loss * target_loss

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                accuracy, exact_match = model.get_metrics(target_scores, target_batch)
                if auxiliary_task:
                    auxiliary_accuracy_target = model.get_auxiliary_accuracy(target_position_scores,
                                                                             target_positions)
                else:
                    auxiliary_accuracy_target = 0.
                learning_rate = scheduler.get_lr()[0]
                logger.info("Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                            " aux. accuracy target pos %5.2f" % (training_iteration, loss, accuracy, exact_match,
                                                                 learning_rate, auxiliary_accuracy_target))

            # Evaluate on test set.
            if training_iteration % evaluate_every == 0:
                with torch.no_grad():
                    model.eval()
                    logger.info("Evaluating..")
                    accuracy, exact_match, target_accuracy = evaluate(
                        test_set.get_data_iterator(batch_size=1), model=model,
                        max_decoding_steps=max_decoding_steps, pad_idx=test_set.target_vocabulary.pad_idx,
                        sos_idx=test_set.target_vocabulary.sos_idx,
                        eos_idx=test_set.target_vocabulary.eos_idx,
                        max_examples_to_evaluate=kwargs["max_testing_examples"])
                    logger.info("  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
                                " Target Accuracy: %5.2f" % (accuracy, exact_match, target_accuracy))
                    if exact_match > best_exact_match:
                        is_best = True
                        best_accuracy = accuracy
                        best_exact_match = exact_match
                        model.update_state(accuracy=accuracy, exact_match=exact_match, is_best=is_best)
                    file_name = "checkpoint.pth.tar".format(str(training_iteration))
                    if is_best:
                        model.save_checkpoint(file_name=file_name, is_best=is_best,
                                              optimizer_state_dict=optimizer.state_dict())

            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
    logger.info("Finished training.")
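# Sketch (illustrative, not part of the original code): the LambdaLR schedule used in both train functions
# multiplies the base learning rate by lr_decay ** (t / lr_decay_steps) at iteration t, i.e. an exponential
# decay that shrinks the rate by a factor of lr_decay every lr_decay_steps iterations. The helper below,
# and the example values in the comment, only make that arithmetic explicit.
def decayed_learning_rate(base_learning_rate, lr_decay, lr_decay_steps, iteration):
    return base_learning_rate * lr_decay ** (iteration / lr_decay_steps)
# For example, with base_learning_rate=0.001, lr_decay=0.9, and lr_decay_steps=20000, iteration 20000
# yields 0.0009 and iteration 40000 yields 0.00081.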