Esempio n. 1
0
    def translate(self):
        """Translate the whole dataset with ``self.decode_batch`` and print BLEU.

        The source data ``self.src['data']`` is processed in minibatches of
        ``self.config['data']['batch_size']`` sentences. For each batch the
        hypotheses returned by ``self.decode_batch(j)`` are converted from ids
        to words (dropping their last token). The gold target minibatch built
        by ``get_minibatch`` is converted the same way. All predictions are
        written to ``self.output``, one sentence per line, and the corpus BLEU
        from ``get_bleu`` is printed.
        """
        batch_size = self.config['data']['batch_size']
        n_sentences = len(self.src['data'])
        id2word = self.trg['id2word']
        trg_preds = []
        trg_gold = []
        # ``with`` guarantees the output file is closed even if decoding raises.
        with open(self.output, 'w') as output_res:
            for j in xrange(0, n_sentences, batch_size):
                # Decode a single minibatch.
                print('Decoding %d out of %d ' % (j, n_sentences))
                hypotheses, _ = self.decode_batch(j)
                # Each hypothesis is a sequence whose items carry the token id
                # in position 0.
                all_hyp_inds = [[x[0] for x in hyp] for hyp in hypotheses]
                all_preds = [
                    ' '.join([id2word[x] for x in hyp[:-1]])
                    for hyp in all_hyp_inds
                ]

                # Get target minibatch; only the gold output ids are needed.
                _, output_lines_trg_gold, _, _ = get_minibatch(
                    self.trg['data'], self.tgt_dict, j,
                    batch_size,
                    self.config['data']['max_trg_length'],
                    is_gui=False, add_start=True, add_end=True
                )

                output_lines_trg_gold = output_lines_trg_gold.data.cpu().numpy()
                all_gold = [
                    ' '.join([id2word[x] for x in hyp[:-1]])
                    for hyp in output_lines_trg_gold
                ]

                trg_preds += all_preds
                trg_gold += all_gold

            # One prediction per line. ``write`` is used because the joined
            # value is a single string (``writelines`` would iterate it
            # character by character).
            output_res.write('\n'.join(trg_preds))

        bleu_score = get_bleu(trg_preds, trg_gold)
        print('BLEU : %.5f ' % (bleu_score))
Esempio n. 2
0
    def translate(self):
        """Evaluate the model: greedily decode the whole dataset and print BLEU.

        The source data ``self.src['data']`` is processed in minibatches of
        ``self.config['data']['batch_size']`` sentences. For each batch:
        - the source and gold target minibatches are built with ``get_minibatch``;
        - the target side is seeded with '<s>' and decoded by
          ``self.decode_minibatch``;
        - predictions and gold sentences are mapped back to words and cut after
          the first '</s>'.
        The corpus BLEU from ``get_bleu`` is printed at the end.
        """
        trg_word2id = self.trg['word2id']
        trg_id2word = self.trg['id2word']
        batch_size = self.config['data']['batch_size']
        n_sentences = len(self.src['data'])

        def _up_to_eos(tokens):
            # Keep tokens up to and including the first '</s>'; keep all of
            # them when the sentence never terminates.
            if '</s>' in tokens:
                return tokens[:tokens.index('</s>') + 1]
            return tokens

        preds = []
        ground_truths = []
        for j in xrange(0, n_sentences, batch_size):
            print('Decoding : %d out of %d ' % (j, n_sentences))

            # Get source minibatch; only the encoder input ids are used.
            input_lines_src, _, _, _ = get_minibatch(
                self.src['data'],
                self.src['word2id'],
                j,
                batch_size,
                self.config['data']['max_src_length'],
                add_start=True,
                add_end=True)
            input_lines_src = Variable(input_lines_src.data, volatile=True)

            # Get target minibatch; only the gold output ids are used.
            _, output_lines_trg_gold, _, _ = get_minibatch(
                self.trg['data'],
                trg_word2id,
                j,
                batch_size,
                self.config['data']['max_trg_length'],
                add_start=True,
                add_end=True)
            output_lines_trg_gold = Variable(output_lines_trg_gold.data,
                                             volatile=True)

            # Initialize target with <s> for every sentence
            input_lines_trg = Variable(torch.LongTensor(
                [[trg_word2id['<s>']]
                 for _ in range(input_lines_src.size(0))]),
                                       volatile=True).cuda()

            # Decode a minibatch greedily. TODO: add beam search decoding.
            input_lines_trg = self.decode_minibatch(input_lines_src,
                                                    input_lines_trg,
                                                    output_lines_trg_gold)

            # Copy minibatch outputs to cpu and convert ids to words,
            # for both the predictions and the gold sentences.
            pred_sentences = [[trg_id2word[x] for x in line]
                              for line in input_lines_trg.data.cpu().numpy()]
            gold_sentences = [[trg_id2word[x] for x in line]
                              for line in
                              output_lines_trg_gold.data.cpu().numpy()]

            for sentence_pred, sentence_real in zip(pred_sentences,
                                                    gold_sentences):
                # decode_minibatch presumably extends the '<s>'-seeded input
                # it was given, so the prediction already starts with '<s>';
                # prepending another one produced a doubled '<s>'.
                preds.append(_up_to_eos(sentence_pred))
                # Prepend '<s>' so the gold sentence starts like the prediction.
                ground_truths.append(['<s>'] + _up_to_eos(sentence_real))

        bleu_score = get_bleu(preds, ground_truths)
        print('BLEU score : %.5f ' % (bleu_score))
Esempio n. 3
0
# -*- coding: utf-8 -*-
from evaluate import get_bleu

# Sanity check for ``get_bleu``: score one hypothesis against one reference.
# Both arguments are lists of tokenized sentences (each a list of str).
preds = [
    "It is a guide to action that ensures that the military will forever heed Party commands."
    .split()
]
ground_truths = [
    "It is a guide to action which ensures that the military always <unk> the commands of the party."
    .split()
]
print(preds)
print(ground_truths)
print(get_bleu(preds, ground_truths))

# BUGS — output recorded from an earlier evaluation run. The prediction
# starts with a doubled '<s>' and degenerates into repeated tokens, while the
# ground truth starts with a single '<s>':
# preds:  ['<s>', '<s>', 'I', 'I', 'I', 'the', 'the', 'the', 'the', 'the', 'of', 'the', 'the', 'the', 'the', 'the', 'of', 'the', 'the', 'the', 'the', 'the', 'of', 'the', 'the', 'the', 'the', 'the', 'of', 'the', 'the', 'the', 'the', 'the', 'of', 'the', 'the', 'the', 'the', 'the', 'of', 'the', 'the', 'the', 'the', 'the', 'of', 'the', 'the', 'the', 'the', 'the']
# ground_truths:  ['<s>', '<unk>', 'will', 'they', '<unk>', '<unk>', '<unk>', 'knowledge', 'of', 'both', '<unk>', 'and', '<unk>', '<unk>', 'in', '<unk>', '</s>']
Esempio n. 4
0
    def translate(self):
        """Evaluate the model with guide input, write predictions, print BLEU.

        The source data ``self.src['data']`` is processed in minibatches of
        ``self.config['data']['batch_size']`` sentences. For each batch the
        source, gold target and guide (``self.gui``) minibatches are built with
        ``get_minibatch``, and the target side is seeded with '<s>' and decoded
        greedily by ``self.decode_minibatch``. Each prediction is then handled
        as follows:
        - the text without its first token and without the first '</s>' and
          anything after it is written as one line of ``self.output``;
        - it is printed next to the gold sentence.
        The corpus BLEU from ``get_bleu`` is printed at the end.
        """
        trg_word2id = self.trg['word2id']
        trg_id2word = self.trg['id2word']
        batch_size = self.config['data']['batch_size']
        n_sentences = len(self.src['data'])

        def _eos_index(tokens):
            # Position of the first '</s>', or len(tokens) when absent.
            return tokens.index('</s>') if '</s>' in tokens else len(tokens)

        preds = []
        ground_truths = []
        # ``with`` guarantees the output file is closed even if decoding raises.
        with open(self.output, 'w') as out_put:
            for j in xrange(0, n_sentences, batch_size):
                print('Decoding : %d out of %d ' % (j, n_sentences))

                # Get source minibatch; only the encoder input ids are used.
                input_lines_src, _, _, _ = get_minibatch(
                    self.src['data'],
                    self.src['word2id'],
                    j,
                    batch_size,
                    self.config['data']['max_src_length'],
                    is_gui=False,
                    add_start=True,
                    add_end=True)
                input_lines_src = Variable(input_lines_src.data, volatile=True)

                # Get target minibatch; only the gold output ids are used.
                _, output_lines_trg_gold, _, _ = get_minibatch(
                    self.trg['data'],
                    trg_word2id,
                    j,
                    batch_size,
                    self.config['data']['max_trg_length'],
                    is_gui=False,
                    add_start=True,
                    add_end=True)
                output_lines_trg_gold = Variable(output_lines_trg_gold.data,
                                                 volatile=True)

                # Get guide minibatch plus the type tensor that get_minibatch
                # derives from self.gui['type'].
                input_lines_gui, _, _, _, input_type_gui = get_minibatch(
                    self.gui['data'],
                    self.gui['word2id'],
                    j,
                    batch_size,
                    self.config['data']['max_gui_length'],
                    is_gui=True,
                    add_start=True,
                    add_end=True,
                    line_types=self.gui['type'])
                input_lines_gui = Variable(input_lines_gui.data, volatile=True)
                input_type_gui = Variable(input_type_gui.data, volatile=True)

                # Initialize target with <s> for every sentence
                input_lines_trg = Variable(torch.LongTensor(
                    [[trg_word2id['<s>']]
                     for _ in xrange(input_lines_src.size(0))]),
                                           volatile=True).cuda()

                # Decode a minibatch greedily. TODO: add beam search decoding.
                input_lines_trg = self.decode_minibatch(input_lines_src,
                                                        input_lines_trg,
                                                        input_lines_gui,
                                                        input_type_gui,
                                                        output_lines_trg_gold)

                # Copy minibatch outputs to cpu and convert ids to words,
                # for both the predictions and the gold sentences.
                pred_sentences = [[trg_id2word[x] for x in line]
                                  for line in input_lines_trg.data.cpu().numpy()]
                gold_sentences = [[trg_id2word[x] for x in line]
                                  for line in
                                  output_lines_trg_gold.data.cpu().numpy()]

                for sentence_pred, sentence_real in zip(pred_sentences,
                                                        gold_sentences):
                    index = _eos_index(sentence_pred)
                    preds.append(sentence_pred[:index + 1])

                    # Human-readable form: skip the first token (the '<s>'
                    # seed position) and stop before '</s>'.
                    predicted = ' '.join(sentence_pred[1:index])
                    out_put.write(predicted + '\n')
                    print('Predicted : %s ' % (predicted))

                    index = _eos_index(sentence_real)
                    ground_truths.append(['<s>'] + sentence_real[:index + 1])

                    print('-----------------------------------------------')
                    print('Real : %s ' % (' '.join(sentence_real[:index])))

                    print('===============================================')

        bleu_score = get_bleu(preds, ground_truths)
        print('BLEU score : %.5f ' % (bleu_score))
Esempio n. 5
0
def _log_parameter_histograms(encoder_decoder, global_step):
    """Log weight and gradient histograms of every named parameter to ``writer``.

    Parameters whose ``.grad`` is ``None`` (not reached by the last backward
    pass) only get a weight histogram instead of crashing ``to_np``.
    """
    for tag, value in encoder_decoder.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram('weights/' + tag,
                             value,
                             global_step,
                             bins='doane')
        if value.grad is not None:
            writer.add_histogram('grads/' + tag,
                                 to_np(value.grad),
                                 global_step,
                                 bins='doane')


def _log_embeddings(encoder_decoder):
    """Log the encoder and decoder embedding matrices to ``writer``.

    Both are labelled with the vocabulary of ``encoder_decoder.lang``. Labels
    are sorted by token index so that label i describes embedding row i; dict
    key order is insertion order, which need not match the indices.
    """
    tok_to_idx = encoder_decoder.lang.tok_to_idx
    vocab = sorted(tok_to_idx, key=tok_to_idx.get)
    writer.add_embedding(encoder_decoder.encoder.embedding.weight.data,
                         metadata=vocab,
                         global_step=0,
                         tag='encoder_embeddings')
    writer.add_embedding(encoder_decoder.decoder.embedding.weight.data,
                         metadata=vocab,
                         global_step=0,
                         tag='decoder_embeddings')


def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader,
          model_name, val_data_loader: DataLoader, keep_prob,
          teacher_forcing_schedule, lr, max_length, use_decay, data_path):
    """Train ``encoder_decoder`` with Adam, one epoch per teacher-forcing value.

    Each epoch does the following:
    - runs over ``train_data_loader``, minimizing NLL loss (id 0 ignored);
    - logs batch loss/BLEU and parameter histograms every 100 steps;
    - evaluates on ``val_data_loader`` and logs the embeddings;
    - saves the whole model under ``./saved/<model_name>/`` whenever the dev
      BLEU from ``get_bleu`` improves;
    - steps the learning-rate scheduler, which halves the rate each epoch when
      ``use_decay`` is truthy.
    """
    import os

    global_step = 0
    # Id 0 is treated as padding throughout (see the length and target
    # filtering below), so it is excluded from the loss.
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './saved/' + model_name + '/'

    gamma = 0.5 if use_decay else 1.0
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    best_bleu = 0.0

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % (epoch), flush=True)
        # Read the rate actually in use from the optimizer: on recent PyTorch
        # versions ``scheduler.get_lr()`` called outside ``step()`` reports
        # the next epoch's rate instead.
        print('lr: ' + str([group['lr'] for group in optimizer.param_groups]))

        for input_idxs, target_idxs, _, _ in tqdm(train_data_loader):
            # input_idxs and target_idxs have dim (batch_size x max_len);
            # they are NOT sorted by length, so sort by source length.
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            # Also drop source padding columns beyond the longest sequence.
            input_variable = Variable(input_idxs[order, :][:, :max(lengths)])
            target_variable = Variable(target_idxs[order, :])

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(
                input_variable,
                list(sorted_lengths),
                targets=target_variable,
                keep_prob=keep_prob,
                teacher_forcing=teacher_forcing)
            batch_size = input_variable.shape[0]

            flattened_outputs = output_log_probs.view(batch_size * max_length,
                                                      -1)

            batch_loss = loss_function(flattened_outputs,
                                       target_variable.contiguous().view(-1))

            batch_loss.backward()
            optimizer.step()

            batch_outputs = trim_seqs(output_seqs)
            batch_targets = [[list(seq[seq > 0])]
                             for seq in list(to_np(target_variable))]
            batch_bleu_score = corpus_bleu(batch_targets, batch_outputs)

            if global_step % 100 == 0:
                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score,
                                  global_step)
                _log_parameter_histograms(encoder_decoder, global_step)

            global_step += 1

        val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)

        writer.add_scalar('val_loss', val_loss, global_step=global_step)
        writer.add_scalar('val_bleu_score',
                          val_bleu_score,
                          global_step=global_step)

        _log_embeddings(encoder_decoder)

        calc_bleu_score = get_bleu(encoder_decoder, data_path, None, 'dev')
        print('val loss: %.5f, val BLEU score: %.5f' %
              (val_loss, calc_bleu_score),
              flush=True)
        if calc_bleu_score > best_bleu:
            print("Best BLEU score! Saving model...")
            best_bleu = calc_bleu_score
            # torch.save does not create missing parent directories.
            os.makedirs(model_path, exist_ok=True)
            torch.save(
                encoder_decoder, "%s%s_%i_%.3f.pt" %
                (model_path, model_name, epoch, calc_bleu_score))

        print('-' * 100, flush=True)

        scheduler.step()