def predict_and_save(dataset: GroundedScanDataset,
                     model: nn.Module,
                     output_file_path: str,
                     max_decoding_steps: int,
                     max_testing_examples=None,
                     **kwargs):
    """
    Predict all data in dataset with a model and write the predictions to output_file_path.
    :param dataset: a dataset with test examples
    :param model: a trained model from model.py
    :param output_file_path: a path where a .json file with predictions will be saved.
    :param max_decoding_steps: maximum number of decoding steps before decoding is cut off
    :param max_testing_examples: stop predicting after this many examples; if None, all examples are evaluated
    """
    cfg = locals().copy()

    with open(output_file_path, mode='w') as outfile:
        output = []
        with torch.no_grad():
            i = 0
            for (input_sequence, derivation_spec, situation_spec,
                 output_sequence, target_sequence, attention_weights_commands,
                 attention_weights_situations, position_accuracy) in predict(
                     dataset.get_data_iterator(batch_size=1),
                     model=model,
                     max_decoding_steps=max_decoding_steps,
                     pad_idx=dataset.target_vocabulary.pad_idx,
                     sos_idx=dataset.target_vocabulary.sos_idx,
                     eos_idx=dataset.target_vocabulary.eos_idx):
                i += 1
                accuracy = sequence_accuracy(output_sequence,
                                             target_sequence[0].tolist()[1:-1])
                input_str_sequence = dataset.array_to_sentence(
                    input_sequence[0].tolist(), vocabulary="input")
                # Strip the <SOS> and <EOS> tokens.
                input_str_sequence = input_str_sequence[1:-1]
                target_str_sequence = dataset.array_to_sentence(
                    target_sequence[0].tolist(), vocabulary="target")
                # Strip the <SOS> and <EOS> tokens.
                target_str_sequence = target_str_sequence[1:-1]
                output_str_sequence = dataset.array_to_sentence(
                    output_sequence, vocabulary="target")
                output.append({
                    "input": input_str_sequence,
                    "prediction": output_str_sequence,
                    "derivation": derivation_spec,
                    "target": target_str_sequence,
                    "situation": situation_spec,
                    "attention_weights_input": attention_weights_commands,
                    "attention_weights_situation":
                    attention_weights_situations,
                    "accuracy": accuracy,
                    "exact_match": True if accuracy == 100 else False,
                    "position_accuracy": position_accuracy
                })
        logger.info("Wrote predictions for {} examples.".format(i))
        json.dump(output, outfile, indent=4)
    return output_file_path
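
For orientation, a minimal sketch of how predict_and_save might be called once a test split and a trained model are in memory (as in the main() example further below); the output path and decoding limit are placeholder values.

# Hypothetical call; "predictions.json" and the step limit are placeholders.
output_file = predict_and_save(dataset=test_set,
                               model=model,
                               output_file_path="predictions.json",
                               max_decoding_steps=120)
logger.info("Saved predictions to {}".format(output_file))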
Example no. 2
def load_training_set(self, data_path):
    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(
        data_path,
        self.hparams.data_directory,
        split="train",
        input_vocabulary_file=self.hparams.input_vocab_path,
        target_vocabulary_file=self.hparams.target_vocab_path,
        generate_vocabulary=self.hparams.generate_vocabularies,
        k=self.hparams.k)
    training_set.read_dataset(
        max_examples=self.hparams.max_training_examples,
        simple_situation_representation=self.hparams.simple_situation_representation,
        load_tensors_from_path=self.hparams.load_tensors_from_path)
    logger.info("Done Loading Training set.")
    return training_set
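
This method reads everything from self.hparams; a minimal sketch of such an object, assuming an argparse-style Namespace with the attributes accessed above (the values shown are placeholders):

from argparse import Namespace

# Placeholder hyperparameters; only the attributes used in load_training_set are listed.
hparams = Namespace(
    data_directory="data/",
    input_vocab_path="training_input_vocab.txt",
    target_vocab_path="training_target_vocab.txt",
    generate_vocabularies=False,
    k=0,
    max_training_examples=None,
    simple_situation_representation=True,
    load_tensors_from_path=False,
)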
Example no. 3
def predict_and_save(dataset: GroundedScanDataset, model: nn.Module, output_file_path: str, max_decoding_steps: int,
                     max_testing_examples=None, **kwargs):
    """

    :param dataset:
    :param model:
    :param output_file_path:
    :param max_decoding_steps:
    :param max_testing_examples:
    :param kwargs:
    :return:
    """
    cfg = locals().copy()

    with open(output_file_path, mode='w') as outfile:
        output = []
        with torch.no_grad():
            i = 0
            for (input_sequence, derivation_spec, situation_spec, output_sequence, target_sequence,
                 attention_weights_commands, attention_weights_situations, _) in predict(
                    dataset.get_data_iterator(batch_size=1), model=model, max_decoding_steps=max_decoding_steps,
                    pad_idx=dataset.target_vocabulary.pad_idx, sos_idx=dataset.target_vocabulary.sos_idx,
                    eos_idx=dataset.target_vocabulary.eos_idx):
                i += 1
                accuracy = sequence_accuracy(output_sequence, target_sequence[0].tolist()[1:-1])
                input_str_sequence = dataset.array_to_sentence(input_sequence[0].tolist(), vocabulary="input")
                input_str_sequence = input_str_sequence[1:-1]  # Get rid of <SOS> and <EOS>
                target_str_sequence = dataset.array_to_sentence(target_sequence[0].tolist(), vocabulary="target")
                target_str_sequence = target_str_sequence[1:-1]  # Get rid of <SOS> and <EOS>
                output_str_sequence = dataset.array_to_sentence(output_sequence, vocabulary="target")
                output.append({"input": input_str_sequence, "prediction": output_str_sequence,
                               "derivation": derivation_spec,
                               "target": target_str_sequence, "situation": situation_spec,
                               "attention_weights_input": attention_weights_commands,
                               "attention_weights_situation": attention_weights_situations,
                               "accuracy": accuracy,
                               "exact_match": True if accuracy == 100 else False})
        logger.info("Wrote predictions for {} examples.".format(i))
        json.dump(output, outfile, indent=4)
    return output_file_path
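
Both prediction examples rely on a sequence_accuracy helper that is not shown. A plausible token-level implementation (an assumption, not necessarily the project's actual helper) is:

def sequence_accuracy(prediction: list, target: list) -> float:
    """Return the percentage of positions where prediction and target agree,
    counting any length difference as errors (assumed semantics)."""
    length = max(len(prediction), len(target))
    if length == 0:
        return 100.0
    correct = sum(int(p == t) for p, t in zip(prediction, target))
    return 100.0 * correct / length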
Example no. 4
def main(flags):
    for argument, value in flags.items():
        logger.info("{}: {}".format(argument, value))

    if not os.path.exists(flags["output_directory"]):
        os.mkdir(os.path.join(os.getcwd(), flags["output_directory"]))

    # Some checks on the flags
    if flags["generate_vocabularies"]:
        assert flags["input_vocab_path"] and flags["target_vocab_path"], \
            "Please specify paths to vocabularies to save."

    if flags["test_batch_size"] > 1:
        raise NotImplementedError(
            "Test batch size larger than 1 not implemented.")

    data_path = os.path.join(flags["data_directory"], "dataset.txt")
    if flags["mode"] == "train":
        train(data_path=data_path, **flags)
    elif flags["mode"] == "test":
        assert os.path.exists(os.path.join(flags["data_directory"], flags["input_vocab_path"])) and os.path.exists(
            os.path.join(flags["data_directory"], flags["target_vocab_path"])), \
            "No vocabs found at {} and {}".format(flags["input_vocab_path"], flags["target_vocab_path"])
        logger.info("Loading {} dataset split...".format(flags["split"]))
        test_set = GroundedScanDataset(
            data_path,
            flags["data_directory"],
            split=flags["split"],
            input_vocabulary_file=flags["input_vocab_path"],
            target_vocabulary_file=flags["target_vocab_path"],
            generate_vocabulary=False)
        test_set.read_dataset(
            max_examples=flags["max_testing_examples"],
            simple_situation_representation=flags["simple_situation_representation"])
        logger.info("Done Loading {} dataset split.".format(flags["split"]))
        logger.info("  Loaded {} examples.".format(test_set.num_examples))
        logger.info("  Input vocabulary size: {}".format(
            test_set.input_vocabulary_size))
        logger.info("  Most common input words: {}".format(
            test_set.input_vocabulary.most_common(5)))
        logger.info("  Output vocabulary size: {}".format(
            test_set.target_vocabulary_size))
        logger.info("  Most common target words: {}".format(
            test_set.target_vocabulary.most_common(5)))

        model = Model(input_vocabulary_size=test_set.input_vocabulary_size,
                      target_vocabulary_size=test_set.target_vocabulary_size,
                      num_cnn_channels=test_set.image_channels,
                      input_padding_idx=test_set.input_vocabulary.pad_idx,
                      target_pad_idx=test_set.target_vocabulary.pad_idx,
                      target_eos_idx=test_set.target_vocabulary.eos_idx,
                      **flags)
        model = model.cuda() if use_cuda else model

        # Load model and vocabularies if resuming.
        assert os.path.isfile(
            flags["resume_from_file"]), "No checkpoint found at {}".format(
                flags["resume_from_file"])
        logger.info("Loading checkpoint from file at '{}'".format(
            flags["resume_from_file"]))
        model.load_model(flags["resume_from_file"])
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            flags["resume_from_file"], start_iteration))
        output_file_path = os.path.join(flags["output_directory"],
                                        flags["output_file_name"])
        output_file = predict_and_save(dataset=test_set,
                                       model=model,
                                       output_file_path=output_file_path,
                                       **flags)
        logger.info("Saved predictions to {}".format(output_file))
    elif flags["mode"] == "predict":
        raise NotImplementedError()
    else:
        raise ValueError("Wrong value for parameter --mode ({}).".format(
            flags["mode"]))
Example no. 5
def train(
        data_path: str,
        data_directory: str,
        generate_vocabularies: bool,
        input_vocab_path: str,
        target_vocab_path: str,
        embedding_dimension: int,
        num_encoder_layers: int,
        encoder_dropout_p: float,
        encoder_bidirectional: bool,
        training_batch_size: int,
        test_batch_size: int,
        max_decoding_steps: int,
        num_decoder_layers: int,
        decoder_dropout_p: float,
        cnn_kernel_size: int,
        cnn_dropout_p: float,
        cnn_hidden_num_channels: int,
        simple_situation_representation: bool,
        decoder_hidden_size: int,
        encoder_hidden_size: int,
        learning_rate: float,
        adam_beta_1: float,
        adam_beta_2: float,
        lr_decay: float,
        lr_decay_steps: int,
        resume_from_file: str,
        max_training_iterations: int,
        output_directory: str,
        print_every: int,
        evaluate_every: int,
        conditional_attention: bool,
        auxiliary_task: bool,
        weight_target_loss: float,
        attention_type: str,
        k: int,
        max_training_examples,
        max_testing_examples,
        # SeqGAN params begin
        pretrain_gen_path,
        pretrain_gen_epochs,
        pretrain_disc_path,
        pretrain_disc_epochs,
        rollout_trails,
        rollout_update_rate,
        disc_emb_dim,
        disc_hid_dim,
        load_tensors_from_path,
        # SeqGAN params end
        seed=42,
        **kwargs):
    device = torch.device("cpu")
    cfg = locals().copy()
    torch.manual_seed(seed)

    logger.info("Loading Training set...")

    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies,
        k=k)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation,
        load_tensors_from_path=load_tensors_from_path
    )  # set this to False if no pickle file available

    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    # logger.info("Loading Dev. set...")
    # test_set = GroundedScanDataset(data_path, data_directory, split="dev",
    #                                input_vocabulary_file=input_vocab_path,
    #                                target_vocabulary_file=target_vocab_path, generate_vocabulary=False, k=0)
    # test_set.read_dataset(max_examples=max_testing_examples,
    #                       simple_situation_representation=simple_situation_representation)
    #
    # # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    # test_set.shuffle_data()

    # logger.info("Done Loading Dev. set.")

    generator = Model(
        input_vocabulary_size=training_set.input_vocabulary_size,
        target_vocabulary_size=training_set.target_vocabulary_size,
        num_cnn_channels=training_set.image_channels,
        input_padding_idx=training_set.input_vocabulary.pad_idx,
        target_pad_idx=training_set.target_vocabulary.pad_idx,
        target_eos_idx=training_set.target_vocabulary.eos_idx,
        **cfg)
    total_vocabulary = set(
        list(training_set.input_vocabulary._word_to_idx.keys()) +
        list(training_set.target_vocabulary._word_to_idx.keys()))
    total_vocabulary_size = len(total_vocabulary)
    discriminator = Discriminator(embedding_dim=disc_emb_dim,
                                  hidden_dim=disc_hid_dim,
                                  vocab_size=total_vocabulary_size,
                                  max_seq_len=max_decoding_steps)

    generator = generator.cuda() if use_cuda else generator
    discriminator = discriminator.cuda() if use_cuda else discriminator
    rollout = Rollout(generator, rollout_update_rate)
    log_parameters(generator)
    trainable_parameters = [
        parameter for parameter in generator.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = generator.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = generator.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    if pretrain_gen_path is None:
        print('Pretraining generator with MLE...')
        pre_train_generator(training_set,
                            training_batch_size,
                            generator,
                            seed,
                            pretrain_gen_epochs,
                            name='pretrained_generator')
    else:
        print('Load pretrained generator weights')
        generator_weights = torch.load(pretrain_gen_path)
        generator.load_state_dict(generator_weights)

    if pretrain_disc_path is None:
        print('Pretraining Discriminator....')
        train_discriminator(training_set,
                            discriminator,
                            training_batch_size,
                            generator,
                            seed,
                            pretrain_disc_epochs,
                            name="pretrained_discriminator")
    else:
        print('Loading Discriminator....')
        discriminator_weights = torch.load(pretrain_disc_path)
        discriminator.load_state_dict(discriminator_weights)

    logger.info("Training starts..")
    training_iteration = start_iteration
    torch.autograd.set_detect_anomaly(True)
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()

        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions, target_positions) in \
                training_set.get_data_iterator(batch_size=training_batch_size):

            is_best = False
            generator.train()

            # Forward pass.
            samples = generator.sample(
                batch_size=training_batch_size,
                max_seq_len=max(target_lengths).astype(int),
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                sos_idx=training_set.input_vocabulary.sos_idx,
                eos_idx=training_set.input_vocabulary.eos_idx)

            rewards = rollout.get_reward(samples, rollout_trails, input_batch,
                                         input_lengths, situation_batch,
                                         target_batch,
                                         training_set.input_vocabulary.sos_idx,
                                         training_set.input_vocabulary.eos_idx,
                                         discriminator)

            assert samples.shape == rewards.shape

            # calculate rewards
            rewards = torch.exp(rewards).contiguous().view((-1, ))
            if use_cuda:
                rewards = rewards.cuda()

            # get generator scores for sequence
            target_scores = generator.get_normalized_logits(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                samples=samples,
                sample_lengths=target_lengths,
                sos_idx=training_set.input_vocabulary.sos_idx)

            del samples

            # calculate loss on the generated sequence given the rewards
            loss = generator.get_gan_loss(target_scores, target_batch, rewards)

            del rewards

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step(training_iteration)
            optimizer.zero_grad()
            generator.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                # accuracy, exact_match = generator.get_metrics(target_scores, target_batch)
                learning_rate = scheduler.get_lr()[0]
                logger.info("Iteration %08d, loss %8.4f, learning_rate %.5f," %
                            (training_iteration, loss, learning_rate))
                # logger.info("Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                #             % (training_iteration, loss, accuracy, exact_match, learning_rate))
            del target_scores, target_batch

            # # Evaluate on test set.
            # if training_iteration % evaluate_every == 0:
            #     with torch.no_grad():
            #         generator.eval()
            #         logger.info("Evaluating..")
            #         accuracy, exact_match, target_accuracy = evaluate(
            #             test_set.get_data_iterator(batch_size=1), model=generator,
            #             max_decoding_steps=max_decoding_steps, pad_idx=test_set.target_vocabulary.pad_idx,
            #             sos_idx=test_set.target_vocabulary.sos_idx,
            #             eos_idx=test_set.target_vocabulary.eos_idx,
            #             max_examples_to_evaluate=kwargs["max_testing_examples"])
            #         logger.info("  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
            #                     " Target Accuracy: %5.2f" % (accuracy, exact_match, target_accuracy))
            #         if exact_match > best_exact_match:
            #             is_best = True
            #             best_accuracy = accuracy
            #             best_exact_match = exact_match
            #             generator.update_state(accuracy=accuracy, exact_match=exact_match, is_best=is_best)
            #         file_name = "checkpoint.pth.tar".format(str(training_iteration))
            #         if is_best:
            #             generator.save_checkpoint(file_name=file_name, is_best=is_best,
            #                                       optimizer_state_dict=optimizer.state_dict())

            rollout.update_params()

            train_discriminator(training_set,
                                discriminator,
                                training_batch_size,
                                generator,
                                seed,
                                epochs=1,
                                name="training_discriminator")
            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
            del loss

        torch.save(
            generator.state_dict(),
            '{}/{}'.format(output_directory,
                           'gen_{}_{}.ckpt'.format(training_iteration, seed)))
        torch.save(
            discriminator.state_dict(),
            '{}/{}'.format(output_directory,
                           'dis_{}_{}.ckpt'.format(training_iteration, seed)))

    logger.info("Finished training.")
Example no. 6
def train(data_path: str,
          data_directory: str,
          generate_vocabularies: bool,
          input_vocab_path: str,
          target_vocab_path: str,
          embedding_dimension: int,
          num_encoder_layers: int,
          encoder_dropout_p: float,
          encoder_bidirectional: bool,
          training_batch_size: int,
          test_batch_size: int,
          max_decoding_steps: int,
          num_decoder_layers: int,
          decoder_dropout_p: float,
          cnn_kernel_size: int,
          cnn_dropout_p: float,
          cnn_hidden_num_channels: int,
          simple_situation_representation: bool,
          decoder_hidden_size: int,
          encoder_hidden_size: int,
          learning_rate: float,
          adam_beta_1: float,
          adam_beta_2: float,
          lr_decay: float,
          lr_decay_steps: int,
          resume_from_file: str,
          max_training_iterations: int,
          output_directory: str,
          print_every: int,
          evaluate_every: int,
          conditional_attention: bool,
          auxiliary_task: bool,
          weight_target_loss: float,
          attention_type: str,
          max_training_examples=None,
          seed=42,
          **kwargs):
    device = torch.device(type='cuda') if use_cuda else torch.device(
        type='cpu')
    cfg = locals().copy()

    torch.manual_seed(seed)

    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation)
    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    logger.info("Loading Test set...")
    test_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="test",  # TODO: use dev set here
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=False)
    test_set.read_dataset(
        max_examples=None,
        simple_situation_representation=simple_situation_representation)

    # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    test_set.shuffle_data()
    logger.info("Done Loading Test set.")

    model = Model(input_vocabulary_size=training_set.input_vocabulary_size,
                  target_vocabulary_size=training_set.target_vocabulary_size,
                  num_cnn_channels=training_set.image_channels,
                  input_padding_idx=training_set.input_vocabulary.pad_idx,
                  target_pad_idx=training_set.target_vocabulary.pad_idx,
                  target_eos_idx=training_set.target_vocabulary.eos_idx,
                  **cfg)
    model = model.cuda() if use_cuda else model
    log_parameters(model)
    trainable_parameters = [
        parameter for parameter in model.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = model.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    logger.info("Training starts..")
    training_iteration = start_iteration
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()
        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions,
             target_positions) in training_set.get_data_iterator(
                 batch_size=training_batch_size):
            is_best = False
            model.train()

            # Forward pass.
            target_scores, target_position_scores = model(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                target_lengths=target_lengths)
            loss = model.get_loss(target_scores, target_batch)
            if auxiliary_task:
                target_loss = model.get_auxiliary_loss(target_position_scores,
                                                       target_positions)
            else:
                target_loss = 0
            loss += weight_target_loss * target_loss

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                accuracy, exact_match = model.get_metrics(
                    target_scores, target_batch)
                if auxiliary_task:
                    auxiliary_accuracy_target = model.get_auxiliary_accuracy(
                        target_position_scores, target_positions)
                else:
                    auxiliary_accuracy_target = 0.
                learning_rate = scheduler.get_lr()[0]
                logger.info(
                    "Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                    " aux. accuracy target pos %5.2f" %
                    (training_iteration, loss, accuracy, exact_match,
                     learning_rate, auxiliary_accuracy_target))

            # Evaluate on test set.
            if training_iteration % evaluate_every == 0:
                with torch.no_grad():
                    model.eval()
                    logger.info("Evaluating..")
                    accuracy, exact_match, target_accuracy = evaluate(
                        test_set.get_data_iterator(batch_size=1),
                        model=model,
                        max_decoding_steps=max_decoding_steps,
                        pad_idx=test_set.target_vocabulary.pad_idx,
                        sos_idx=test_set.target_vocabulary.sos_idx,
                        eos_idx=test_set.target_vocabulary.eos_idx,
                        max_examples_to_evaluate=kwargs["max_testing_examples"]
                    )
                    logger.info(
                        "  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
                        " Target Accuracy: %5.2f" %
                        (accuracy, exact_match, target_accuracy))
                    if exact_match > best_exact_match:
                        is_best = True
                        best_accuracy = accuracy
                        best_exact_match = exact_match
                        model.update_state(accuracy=accuracy,
                                           exact_match=exact_match,
                                           is_best=is_best)
                    file_name = "checkpoint.pth.tar".format(
                        str(training_iteration))
                    if is_best:
                        model.save_checkpoint(
                            file_name=file_name,
                            is_best=is_best,
                            optimizer_state_dict=optimizer.state_dict())

            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
    logger.info("Finished training.")