Example #1
0
def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader,
          model_name, val_data_loader: DataLoader, keep_prob,
          teacher_forcing_schedule, lr, max_length, use_decay, data_path):
    """Train ``encoder_decoder``, checkpointing whenever dev BLEU improves.

    One epoch is run per entry of ``teacher_forcing_schedule``; that entry is
    forwarded to the model as its ``teacher_forcing`` argument for the whole
    epoch. Batch metrics, parameter/gradient histograms, validation metrics
    and embedding projections are written to the module-level ``writer``.

    Args:
        encoder_decoder: model to train (updated in place).
        train_data_loader: yields ``(input_idxs, target_idxs, input_tokens,
            target_tokens)`` batches; index 0 is treated as padding.
        model_name: names the checkpoint directory and files
            (``./saved/<model_name>/<model_name>_<epoch>_<bleu>.pt``).
        val_data_loader: passed to ``evaluate`` after every epoch.
        keep_prob: forwarded to the model's forward call.
        teacher_forcing_schedule: iterable with one teacher-forcing value per
            epoch; its length sets the number of epochs.
        lr: initial Adam learning rate.
        max_length: number of decoder steps the model emits; used to flatten
            the model output for the loss.
        use_decay: if truthy, the learning rate is halved after every epoch.
        data_path: forwarded to ``get_bleu`` to score the ``'dev'`` split.
    """
    import os  # only needed to make sure the checkpoint directory exists

    global_step = 0
    # Index 0 is padding and must not contribute to the loss.
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './saved/' + model_name + '/'
    # torch.save does not create missing directories.
    os.makedirs(model_path, exist_ok=True)

    # step_size=1: the decay factor (if any) is applied once per epoch.
    gamma = 0.5 if use_decay else 1.0
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    best_bleu = 0.0

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % (epoch), flush=True)
        # Report the LR the optimizer actually uses. scheduler.get_lr() is not
        # meant to be called outside step(); with the chainable StepLR it
        # reports one extra decay factor after the first epoch.
        print('lr: ' + str([group['lr'] for group in optimizer.param_groups]))

        for input_idxs, target_idxs, _, _ in tqdm(train_data_loader):
            # input_idxs and target_idxs have dim (batch_size x max_len);
            # they are NOT sorted by length, so sort by descending input length.
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            # Drop input padding columns beyond the longest sequence in batch.
            input_variable = Variable(input_idxs[order, :][:, :max(lengths)])
            target_variable = Variable(target_idxs[order, :])

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(
                input_variable,
                list(sorted_lengths),
                targets=target_variable,
                keep_prob=keep_prob,
                teacher_forcing=teacher_forcing)
            batch_size = input_variable.shape[0]

            flattened_outputs = output_log_probs.view(batch_size * max_length,
                                                      -1)
            batch_loss = loss_function(flattened_outputs,
                                       target_variable.contiguous().view(-1))

            batch_loss.backward()
            optimizer.step()

            if global_step % 100 == 0:
                # Batch BLEU is only used for logging, so compute it only on
                # logging steps.
                batch_outputs = trim_seqs(output_seqs)
                batch_targets = [[list(seq[seq > 0])]
                                 for seq in list(to_np(target_variable))]
                batch_bleu_score = corpus_bleu(batch_targets, batch_outputs)

                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score,
                                  global_step)

                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag,
                                         value,
                                         global_step,
                                         bins='doane')
                    # Parameters that took no part in this batch's graph have
                    # no gradient; to_np(None) would fail.
                    if value.grad is not None:
                        writer.add_histogram('grads/' + tag,
                                             to_np(value.grad),
                                             global_step,
                                             bins='doane')

            global_step += 1

        val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)

        writer.add_scalar('val_loss', val_loss, global_step=global_step)
        writer.add_scalar('val_bleu_score',
                          val_bleu_score,
                          global_step=global_step)

        # Encoder and decoder are labelled with the same vocabulary
        # (encoder_decoder.lang). NOTE(review): labels follow the dict order
        # of tok_to_idx, assumed to match embedding row order — confirm.
        vocab = list(encoder_decoder.lang.tok_to_idx.keys())
        writer.add_embedding(encoder_decoder.encoder.embedding.weight.data,
                             metadata=vocab,
                             global_step=0,
                             tag='encoder_embeddings')
        writer.add_embedding(encoder_decoder.decoder.embedding.weight.data,
                             metadata=vocab,
                             global_step=0,
                             tag='decoder_embeddings')

        calc_bleu_score = get_bleu(encoder_decoder, data_path, None, 'dev')
        print('val loss: %.5f, val BLEU score: %.5f' %
              (val_loss, calc_bleu_score),
              flush=True)
        if calc_bleu_score > best_bleu:
            print("Best BLEU score! Saving model...")
            best_bleu = calc_bleu_score
            torch.save(
                encoder_decoder, "%s%s_%i_%.3f.pt" %
                (model_path, model_name, epoch, calc_bleu_score))

        print('-' * 100, flush=True)

        # Decay the LR once per epoch, after that epoch's optimizer steps.
        scheduler.step()
Example #2
0
def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader,
          model_name, val_data_loader: DataLoader, keep_prob,
          teacher_forcing_schedule, lr, max_length, device,
          test_data_loader: DataLoader):
    """Train ``encoder_decoder`` and save a checkpoint after every epoch.

    One epoch is run per entry of ``teacher_forcing_schedule``; that entry is
    forwarded to the model as its ``teacher_forcing`` argument. Batch loss and
    BLEU, parameter/gradient histograms and embedding projections are written
    to the module-level ``writer``. Per-epoch exact-match training accuracy is
    printed.

    Args:
        encoder_decoder: model to train (updated in place).
        train_data_loader: yields ``(input_idxs, target_idxs, input_tokens,
            target_tokens)`` batches; index 0 is treated as padding.
        model_name: names the checkpoint directory and files under
            ``./model/<model_name>/``.
        val_data_loader: accepted for interface compatibility; not used here.
        keep_prob: forwarded to the model's forward call.
        teacher_forcing_schedule: iterable with one teacher-forcing value per
            epoch; its length sets the number of epochs.
        lr: Adam learning rate.
        max_length: number of decoder steps the model emits; used to flatten
            the model output for the loss.
        device: device that input/target batches are moved to.
        test_data_loader: accepted for interface compatibility; not used here.

    Returns:
        The trained ``encoder_decoder`` (the same object that was passed in).
    """
    import os  # only needed to make sure the checkpoint directory exists

    global_step = 0
    # Index 0 is padding and must not contribute to the loss.
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './model/' + model_name + '/'
    # torch.save does not create missing directories.
    os.makedirs(model_path, exist_ok=True)

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % epoch, flush=True)
        correct_predictions = 0
        all_predictions = 0
        for input_idxs, target_idxs, _, _ in tqdm(train_data_loader):
            # Kept from the original: release cached GPU blocks every batch.
            # NOTE(review): this adds per-batch overhead and cannot free memory
            # held by live tensors — consider removing.
            torch.cuda.empty_cache()
            # input_idxs and target_idxs have dim (batch_size x max_len);
            # they are NOT sorted by length, so sort by descending input length.
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            # Drop input padding columns beyond the longest sequence in batch.
            input_variable = input_idxs[order, :][:, :max(lengths)].to(device)
            target_variable = target_idxs[order, :].to(device)

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(
                input_variable,
                list(sorted_lengths),
                targets=target_variable,
                keep_prob=keep_prob,
                teacher_forcing=teacher_forcing)

            batch_size = input_variable.shape[0]

            flattened_outputs = output_log_probs.view(batch_size * max_length,
                                                      -1)
            batch_loss = loss_function(flattened_outputs,
                                       target_variable.contiguous().view(-1))

            batch_outputs = trim_seqs(output_seqs)
            batch_targets = [[list(seq[seq > 0])]
                             for seq in list(to_np(target_variable))]

            # Exact-match accuracy: a prediction counts only if the whole
            # trimmed output sequence equals the unpadded target sequence.
            correct_predictions += sum(
                1 for output, (target,) in zip(batch_outputs, batch_targets)
                if output == target)
            all_predictions += len(batch_outputs)

            batch_loss.backward()
            optimizer.step()

            if global_step % 100 == 0:
                # Batch BLEU is only used for logging, so compute it only on
                # logging steps.
                batch_bleu_score = corpus_bleu(
                    batch_targets,
                    batch_outputs,
                    smoothing_function=SmoothingFunction().method1)

                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score,
                                  global_step)

                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag,
                                         value,
                                         global_step,
                                         bins='doane')
                    # Parameters that took no part in this batch's graph have
                    # no gradient; to_np(None) would fail.
                    if value.grad is not None:
                        writer.add_histogram('grads/' + tag,
                                             to_np(value.grad),
                                             global_step,
                                             bins='doane')

            global_step += 1

        # Encoder and decoder are labelled with the same vocabulary
        # (encoder_decoder.lang). NOTE(review): labels follow the dict order
        # of tok_to_idx, assumed to match embedding row order — confirm.
        vocab = list(encoder_decoder.lang.tok_to_idx.keys())
        writer.add_embedding(encoder_decoder.encoder.embedding.weight.data,
                             metadata=vocab,
                             global_step=0,
                             tag='encoder_embeddings')
        writer.add_embedding(encoder_decoder.decoder.embedding.weight.data,
                             metadata=vocab,
                             global_step=0,
                             tag='decoder_embeddings')

        # Guard against an empty loader, which would divide by zero.
        if all_predictions:
            print('training accuracy %.5f' %
                  (100.0 * correct_predictions / all_predictions))
        else:
            print('training accuracy n/a (no training batches)')
        torch.save(encoder_decoder,
                   "%s%s_%i.pt" % (model_path, model_name, epoch))

        print('-' * 100, flush=True)

    torch.save(encoder_decoder, "%s%s_final.pt" % (model_path, model_name))
    return encoder_decoder