# Example #1
# 0
def get_bleu(encoder_decoder: EncoderDecoder, data_path, model_name,
             data_type):
    """Compute the corpus-average sentence BLEU of the model on one split.

    Reads tab-separated rows (``<gold natural language>\\t<sql>``) from
    ``data_path + 'copynet_' + data_type + '.txt'``, generates a response
    for each SQL query with ``encoder_decoder.get_response``, and averages
    per-sentence BLEU scores, scaled to 0-100.  For any split other than
    'dev', the raw predictions are also written to
    ``results/<basename of model_name>.txt``.

    Args:
        encoder_decoder: trained model exposing ``get_response(sql) -> str``.
        data_path: directory containing the ``copynet_<split>.txt`` files.
        model_name: model identifier; only the part after the last '/' is
            used to name the results file.
        data_type: split name, e.g. 'dev' or 'test'.

    Returns:
        Average sentence BLEU * 100 over all rows (0.0 for an empty file).
    """
    is_dev = data_type == 'dev'
    total_score = 0.0
    num = 0
    out_file = None
    try:
        # `with` guarantees the input file is closed even if generation or
        # scoring raises (the original leaked both handles on error).
        with open(data_path + 'copynet_' + data_type + '.txt',
                  'r',
                  encoding='utf-8') as test_file:
            if not is_dev:
                out_file = open(
                    "results/" + model_name.split('/')[-1] + ".txt",
                    'w',
                    encoding='utf-8')
            for row in tqdm(test_file):
                # Split once per row instead of twice.
                fields = row.split('\t')
                gold_nl, sql = fields[0], fields[1]
                predicted = encoder_decoder.get_response(sql)
                # Strip the special tokens and any trailing whitespace/newline.
                predicted = predicted.replace('<SOS>', '')
                predicted = predicted.replace('<EOS>', '')
                predicted = predicted.rstrip()
                if not is_dev:
                    out_file.write(predicted + "\n")
                total_score += sentence_bleu([gold_nl.split()],
                                             predicted.split())
                num += 1
    finally:
        # The output file is opened conditionally, so it cannot share the
        # `with` block above; close it explicitly instead.
        if out_file is not None:
            out_file.close()
    # Guard against an empty input file (original raised ZeroDivisionError).
    final_score = total_score * 100 / num if num else 0.0
    print("DEV set" if is_dev else "TEST set")
    print("BLEU score is " + str(final_score))
    return final_score
# Example #2
# 0
def _train_batch(encoder_decoder, optimizer, loss_function, input_idxs,
                 target_idxs, keep_prob, teacher_forcing, max_length):
    """Run one optimization step; return (batch loss, batch BLEU score)."""
    # input_idxs and target_idxs have dim (batch_size x max_len);
    # they are NOT sorted by length, so sort descending here (required by
    # the encoder's packed-sequence handling — presumably; confirm).
    lengths = (input_idxs != 0).long().sum(dim=1)
    sorted_lengths, order = torch.sort(lengths, descending=True)

    # Trim the inputs to the longest sequence actually in the batch.
    input_variable = Variable(input_idxs[order, :][:, :max(lengths)])
    target_variable = Variable(target_idxs[order, :])

    optimizer.zero_grad()
    output_log_probs, output_seqs = encoder_decoder(
        input_variable,
        list(sorted_lengths),
        targets=target_variable,
        keep_prob=keep_prob,
        teacher_forcing=teacher_forcing)

    batch_size = input_variable.shape[0]

    # Flatten (batch, time, vocab) -> (batch*time, vocab) for NLLLoss.
    flattened_outputs = output_log_probs.view(batch_size * max_length, -1)

    batch_loss = loss_function(flattened_outputs,
                               target_variable.contiguous().view(-1))
    batch_loss.backward()
    optimizer.step()

    # Corpus BLEU over the batch, dropping padding (index 0) from targets.
    batch_outputs = trim_seqs(output_seqs)
    batch_targets = [[list(seq[seq > 0])]
                     for seq in list(to_np(target_variable))]
    batch_bleu_score = corpus_bleu(
        batch_targets,
        batch_outputs,
        smoothing_function=SmoothingFunction().method1)

    return batch_loss, batch_bleu_score


def _log_sample_responses(encoder_decoder, global_step):
    """Log the model's responses to two fixed probe emails to TensorBoard."""
    input_string = "Amy, Please schedule a meeting with Marcos on Tuesday April 3rd. Adam Kleczewski"
    output_string = encoder_decoder.get_response(input_string)
    writer.add_text('schedule', output_string, global_step=global_step)

    input_string = "Amy, Please cancel this meeting. Adam Kleczewski"
    output_string = encoder_decoder.get_response(input_string)
    writer.add_text('cancel', output_string, global_step=global_step)


def _log_parameter_histograms(encoder_decoder, global_step):
    """Write weight and gradient histograms for every named parameter."""
    for tag, value in encoder_decoder.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram('weights/' + tag,
                             value,
                             global_step,
                             bins='doane')
        # NOTE(review): value.grad is None for any parameter that received
        # no gradient this step; to_np(None) would then fail — confirm all
        # parameters participate in the backward pass.
        writer.add_histogram('grads/' + tag,
                             to_np(value.grad),
                             global_step,
                             bins='doane')


def _log_embeddings(encoder_decoder):
    """Log encoder/decoder embedding matrices with vocab tokens as metadata."""
    # global_step=0 on every call, so each epoch overwrites the same
    # embedding snapshot — presumably intentional; TODO confirm.
    encoder_embeddings = encoder_decoder.encoder.embedding.weight.data
    encoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
    writer.add_embedding(encoder_embeddings,
                         metadata=encoder_vocab,
                         global_step=0,
                         tag='encoder_embeddings')

    decoder_embeddings = encoder_decoder.decoder.embedding.weight.data
    decoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
    writer.add_embedding(decoder_embeddings,
                         metadata=decoder_vocab,
                         global_step=0,
                         tag='decoder_embeddings')


def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader,
          model_name, val_data_loader: DataLoader, keep_prob,
          teacher_forcing_schedule, lr, max_length):
    """Train the encoder-decoder, logging to TensorBoard and checkpointing.

    Runs one epoch per entry in ``teacher_forcing_schedule``.  Every batch
    performs an Adam step on NLL loss (padding index 0 ignored); scalar
    losses, BLEU, parameter histograms, probe responses and embeddings are
    logged to the module-level ``writer``.  After each epoch the model is
    evaluated on ``val_data_loader`` and the whole model object is saved to
    ``./model/<model_name>/<model_name>_<epoch>.pt`` (the directory is
    assumed to exist — TODO confirm it is created by the caller).

    Args:
        encoder_decoder: model to train, callable with (inputs, lengths,
            targets, keep_prob, teacher_forcing).
        train_data_loader: yields (input_idxs, target_idxs, input_tokens,
            target_tokens) batches; idx tensors are (batch x max_len),
            0-padded, unsorted.
        model_name: name used for the checkpoint path and file names.
        val_data_loader: validation data for ``evaluate``.
        keep_prob: dropout keep probability forwarded to the model.
        teacher_forcing_schedule: iterable of per-epoch teacher-forcing
            ratios; its length determines the number of epochs.
        lr: Adam learning rate.
        max_length: fixed decoder output length used to flatten logits.
    """
    global_step = 0
    # ignore_index=0 skips the padding token when computing the loss.
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './model/' + model_name + '/'

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % epoch, flush=True)

        for batch_idx, (input_idxs, target_idxs, input_tokens,
                        target_tokens) in enumerate(tqdm(train_data_loader)):
            batch_loss, batch_bleu_score = _train_batch(
                encoder_decoder, optimizer, loss_function, input_idxs,
                target_idxs, keep_prob, teacher_forcing, max_length)

            # Log probe responses densely early in training, sparsely later.
            if global_step < 10 or (global_step % 10 == 0
                                    and global_step < 100) or (
                                        global_step % 100 == 0 and epoch < 2):
                _log_sample_responses(encoder_decoder, global_step)

            if global_step % 100 == 0:
                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score,
                                  global_step)
                _log_parameter_histograms(encoder_decoder, global_step)

            global_step += 1

        val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)

        writer.add_scalar('val_loss', val_loss, global_step=global_step)
        writer.add_scalar('val_bleu_score',
                          val_bleu_score,
                          global_step=global_step)

        _log_embeddings(encoder_decoder)
        _log_sample_responses(encoder_decoder, global_step)

        print('val loss: %.5f, val BLEU score: %.5f' %
              (val_loss, val_bleu_score),
              flush=True)
        torch.save(encoder_decoder,
                   "%s%s_%i.pt" % (model_path, model_name, epoch))

        print('-' * 100, flush=True)