def training_loop(train_dict, val_dict, idx_dict, encoder, decoder, criterion,
                  optimizer, opts):
    """Runs the main training loop; evaluates the model on the val set every epoch.
        * Prints training and val loss each epoch.
        * Prints qualitative translation results each epoch using TEST_SENTENCE
        * Saves an attention map for TEST_WORD_ATTN each epoch
        * Checkpoints the model whenever the validation loss improves.

    Arguments:
        train_dict: The training word pairs, organized by source and target lengths.
        val_dict: The validation word pairs, organized by source and target lengths.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        encoder: An encoder model to produce annotations for each step of the input sequence.
        decoder: A decoder model (with or without attention) to generate output tokens.
        criterion: Used to compute the CrossEntropyLoss for each decoder output.
        optimizer: Implements a step rule to update the parameters of the encoder and decoder.
        opts: The command-line arguments.
    """

    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    char_to_index = idx_dict['char_to_index']

    loss_log = open(os.path.join(opts.checkpoint_path, 'loss_log.txt'), 'w')

    best_val_loss = 1e6
    train_losses = []
    val_losses = []

    try:
        for epoch in range(opts.nepochs):

            # Exponential learning-rate decay, applied once per epoch.
            optimizer.param_groups[0]['lr'] *= opts.lr_decay

            epoch_losses = []

            # train_dict buckets word pairs by (source, target) length so every
            # pair within a bucket can be stacked into one rectangular tensor.
            for key in train_dict:

                input_strings, target_strings = zip(*train_dict[key])
                input_tensors = [
                    torch.LongTensor(
                        utils.string_to_index_list(s, char_to_index, end_token))
                    for s in input_strings
                ]
                target_tensors = [
                    torch.LongTensor(
                        utils.string_to_index_list(s, char_to_index, end_token))
                    for s in target_strings
                ]

                num_tensors = len(input_tensors)
                num_batches = int(np.ceil(num_tensors / float(opts.batch_size)))

                for batch_idx in range(num_batches):

                    start = batch_idx * opts.batch_size
                    end = start + opts.batch_size

                    inputs = utils.to_var(torch.stack(input_tensors[start:end]),
                                          opts.cuda)
                    targets = utils.to_var(torch.stack(target_tensors[start:end]),
                                           opts.cuda)

                    # The last batch of a bucket may be smaller than opts.batch_size.
                    BS = inputs.size(0)

                    encoder_annotations, encoder_hidden = encoder(inputs)

                    # The last hidden state of the encoder becomes the first
                    # hidden state of the decoder.
                    decoder_hidden = encoder_hidden

                    start_vector = torch.ones(BS).long().unsqueeze(
                        1) * start_token  # BS x 1
                    decoder_input = utils.to_var(start_vector, opts.cuda)  # BS x 1

                    loss = 0.0

                    seq_len = targets.size(1)  # Gets seq_len from BS x seq_len

                    # Sample teacher forcing once per batch, not per time step.
                    use_teacher_forcing = np.random.rand(
                    ) < opts.teacher_forcing_ratio

                    # NOTE: use a distinct index (t) for the decoding time steps
                    # so the batch index above is not shadowed.
                    for t in range(seq_len):
                        decoder_output, decoder_hidden, attention_weights = decoder(
                            decoder_input, decoder_hidden, encoder_annotations)

                        current_target = targets[:, t]
                        # Cross entropy between the decoder distribution and GT.
                        loss += criterion(decoder_output, current_target)
                        # Greedy prediction for each example in the batch.
                        ni = F.softmax(decoder_output, dim=1).data.max(1)[1]

                        if use_teacher_forcing:
                            # Condition the next step on the ground-truth token.
                            decoder_input = targets[:, t].unsqueeze(1)
                        else:
                            # Condition the next step on the model's own prediction.
                            decoder_input = utils.to_var(ni.unsqueeze(1),
                                                         opts.cuda)

                    # Average the summed per-step losses over the sequence length.
                    loss /= float(seq_len)
                    epoch_losses.append(loss.item())

                    # Zero gradients
                    optimizer.zero_grad()

                    # Compute gradients
                    loss.backward()

                    # Update the parameters of the encoder and decoder
                    optimizer.step()

            train_loss = np.mean(epoch_losses)
            val_loss = evaluate(val_dict, encoder, decoder, idx_dict, criterion,
                                opts)

            if val_loss < best_val_loss:
                # BUG FIX: record the new best so we only checkpoint on genuine
                # improvement (previously best_val_loss was never updated and the
                # model was checkpointed every epoch).
                best_val_loss = val_loss
                checkpoint(encoder, decoder, idx_dict, opts)

            if not opts.no_attention:
                # Save attention maps for the fixed word TEST_WORD_ATTN throughout training
                utils.visualize_attention(
                    TEST_WORD_ATTN,
                    encoder,
                    decoder,
                    idx_dict,
                    opts,
                    save=os.path.join(
                        opts.checkpoint_path,
                        'train_attns/attn-epoch-{}.png'.format(epoch)))

            gen_string = utils.translate_sentence(TEST_SENTENCE, encoder, decoder,
                                                  idx_dict, opts)
            print(
                "Epoch: {:3d} | Train loss: {:.3f} | Val loss: {:.3f} | Gen: {:20s}"
                .format(epoch, train_loss, val_loss, gen_string))

            loss_log.write('{} {} {}\n'.format(epoch, train_loss, val_loss))
            loss_log.flush()

            train_losses.append(train_loss)
            val_losses.append(val_loss)

            save_loss_plot(train_losses, val_losses, opts)
    finally:
        # Previously the log file handle was leaked; close it even on error.
        loss_log.close()
def evaluate(data_dict, encoder, decoder, idx_dict, criterion, opts):
    """Evaluates the model on a held-out validation or test set.

    Decoding is fully teacher-forced: each step is conditioned on the
    ground-truth token, so the reported loss measures per-step prediction
    quality rather than free-running generation quality.

    Arguments:
        data_dict: The validation/test word pairs, organized by source and target lengths.
        encoder: An encoder model to produce annotations for each step of the input sequence.
        decoder: A decoder model (with or without attention) to generate output tokens.
        idx_dict: Contains char-to-index and index-to-char mappings, and start & end token indexes.
        criterion: Used to compute the CrossEntropyLoss for each decoder output.
        opts: The command-line arguments.

    Returns:
        mean_loss: The average loss over all batches from data_dict.
    """

    start_token = idx_dict['start_token']
    end_token = idx_dict['end_token']
    char_to_index = idx_dict['char_to_index']

    losses = []

    # No parameter updates happen here, so skip building autograd graphs
    # (saves memory and time; the computed loss values are unchanged).
    with torch.no_grad():
        for key in data_dict:

            input_strings, target_strings = zip(*data_dict[key])
            input_tensors = [
                torch.LongTensor(
                    utils.string_to_index_list(s, char_to_index, end_token))
                for s in input_strings
            ]
            target_tensors = [
                torch.LongTensor(
                    utils.string_to_index_list(s, char_to_index, end_token))
                for s in target_strings
            ]

            num_tensors = len(input_tensors)
            num_batches = int(np.ceil(num_tensors / float(opts.batch_size)))

            for batch_idx in range(num_batches):

                start = batch_idx * opts.batch_size
                end = start + opts.batch_size

                inputs = utils.to_var(torch.stack(input_tensors[start:end]),
                                      opts.cuda)
                targets = utils.to_var(torch.stack(target_tensors[start:end]),
                                       opts.cuda)

                # The last batch of a bucket may be smaller than opts.batch_size.
                BS = inputs.size(0)

                encoder_annotations, encoder_hidden = encoder(inputs)

                # The final hidden state of the encoder becomes the initial
                # hidden state of the decoder.
                decoder_hidden = encoder_hidden
                start_vector = torch.ones(BS).long().unsqueeze(
                    1) * start_token  # BS x 1
                decoder_input = utils.to_var(start_vector, opts.cuda)  # BS x 1

                loss = 0.0

                seq_len = targets.size(1)  # Gets seq_len from BS x seq_len

                # NOTE: use a distinct index (t) for the decoding time steps so
                # the batch index above is not shadowed.
                for t in range(seq_len):
                    decoder_output, decoder_hidden, attention_weights = decoder(
                        decoder_input, decoder_hidden, encoder_annotations)

                    current_target = targets[:, t]
                    # Cross entropy between the decoder distribution and GT.
                    loss += criterion(decoder_output, current_target)

                    # Always teacher-force during evaluation.  (Removed a dead
                    # softmax/argmax computation whose result was never used.)
                    decoder_input = targets[:, t].unsqueeze(1)

                # Average the summed per-step losses over the sequence length.
                loss /= float(seq_len)
                losses.append(loss.item())

    mean_loss = np.mean(losses)

    return mean_loss