예제 #1
0
    def generate(batch):
        """Greedily decode `batch.src` and return one hypothesis (a list of
        token ids) per source sentence.

        Relies on closure variables: `model`, `tgt_field`, plus `Variable`,
        `DecoderState`, and `torch`.
        """
        # Greedy search
        vocab = tgt_field.vocab
        bos_id = vocab.stoi[tgt_field.init_token]
        # BUG FIX: the <eos> id is needed to detect finished sequences;
        # previously `done` was never updated, so the `done.all()` early
        # exit below could never fire and decoding always ran the full
        # max_length steps.
        eos_id = vocab.stoi[tgt_field.eos_token]
        batch_size = batch.src[0].size(1)
        # Decode at most twice the source length.
        max_length = batch.src[0].size(0) * 2

        src_words, src_length = batch.src
        annotations, encoder_rnn_state = model.encoder(words=src_words,
                                                       length=src_length)

        # First decoder input: a (1, batch) tensor of <bos> ids, allocated
        # via src_length.new so it lands on the same device/type.
        prev_pred = Variable(src_length.new(1, batch_size).fill_(bos_id))
        # done[i] == 1 once sentence i has emitted <eos>.
        done = torch.zeros(batch_size).byte()
        hyps = []
        prev_state = DecoderState(rnn_state=encoder_rnn_state,
                                  input_feeding=model.input_feeding)
        for t in range(max_length):
            if done.all():
                break
            decoder_input = prev_pred
            logits, prev_state, attn_weights = model.decoder(
                annotations=annotations,
                annotations_length=src_length,
                state=prev_state,
                words=decoder_input)
            # Greedy pick over the vocab axis; pred keeps shape (1, batch).
            pred = logits.max(2)[1]
            prev_pred = pred
            hyps.append(pred.data)
            # Mark sequences that just produced <eos> as finished (move to
            # CPU so it matches `done` regardless of GPU use).
            done = done | (pred.data[0] == eos_id).cpu().byte()
        # (time, batch) -> per-sentence lists of token ids.
        hyps = torch.cat(hyps, dim=0).transpose(0, 1).tolist()
        # Trim each hypothesis to at most twice its source length.
        hyps = [hyps[i][:2 * length] for i, length in enumerate(src_length)]
        return hyps
예제 #2
0
 def test_lstm_decoder_train_complex(self):
     """Shape check: 3-layer input-feeding LSTM decoder with dot attention."""
     tokens = Variable(torch.arange(0, 12).view(4, 3).long())
     annotations = Variable(torch.randn(5, 3, 6))
     annotations_length = torch.LongTensor([5, 4, 2])
     h0 = Variable(torch.randn(3, 3, 6))
     initial_state = DecoderState(rnn_state=(h0, h0), input_feeding=True)
     decoder = decoders.RecurrentDecoder(
         rnn_type='lstm', num_words=20, word_dim=2, hidden_dim=6,
         annotation_dim=6, num_layers=3, attention_type='dot',
         input_feeding=True, dropout_prob=0.1)
     logits, new_state, attn = decoder.forward(
         annotations=annotations,
         annotations_length=annotations_length,
         state=initial_state,
         words=tokens)
     self.assertTupleEqual(tuple(logits.size()), (4, 3, 20))
     self.assertTupleEqual(tuple(new_state.attention.size()), (3, 6))
     self.assertTupleEqual(tuple(new_state.rnn[0].size()), (3, 3, 6))
     self.assertTupleEqual(tuple(new_state.rnn[1].size()), (3, 3, 6))
     self.assertTupleEqual(tuple(attn.size()), (4, 3, 5))
예제 #3
0
 def test_gru_decoder_train_simple2(self):
     """Shape check: 1-layer GRU decoder, MLP attention, no input feeding."""
     tokens = Variable(torch.arange(0, 12).view(4, 3).long())
     annotations = Variable(torch.randn(5, 3, 12))
     annotations_length = torch.LongTensor([5, 4, 2])
     initial_rnn_state = Variable(torch.randn(1, 3, 6))
     decoder = decoders.RecurrentDecoder(
         rnn_type='gru', num_words=20, word_dim=2, hidden_dim=6,
         annotation_dim=12, num_layers=1, attention_type='mlp',
         input_feeding=False, dropout_prob=0.1)
     state_in = DecoderState(rnn_state=initial_rnn_state,
                             input_feeding=False)
     logits, state_out, attn = decoder.forward(
         annotations=annotations,
         annotations_length=annotations_length,
         state=state_in,
         words=tokens)
     self.assertTupleEqual(tuple(logits.size()), (4, 3, 20))
     self.assertFalse(state_out.input_feeding)
     self.assertTupleEqual(tuple(state_out.rnn.size()), (1, 3, 6))
     self.assertTupleEqual(tuple(attn.size()), (4, 3, 5))
예제 #4
0
    def generate(batch):
        """Beam-search decode `batch.src` and return the single best
        hypothesis (a list of token ids) per source sentence.

        Relies on closure variables: `model`, `tgt_field`, `args`
        (beam_size, gpu), plus `Beam`, `DecoderState`, `Variable`,
        `functional`, and `torch`.
        """
        vocab = tgt_field.vocab
        pad_id = vocab.stoi[tgt_field.pad_token]
        bos_id = vocab.stoi[tgt_field.init_token]
        eos_id = vocab.stoi[tgt_field.eos_token]

        src_words, src_length = batch.src
        src_max_length, batch_size = src_words.size()
        # Sort by descending length for the encoder (presumably a
        # pack_padded-style requirement — confirm against model.encoder),
        # remembering the permutation that restores the original order.
        src_length_sorted, sort_indices = src_length.sort(0, descending=True)
        orig_indices = sort_indices.sort()[1]
        src_words_sorted = src_words[:, sort_indices]

        # One independent beam per source sentence.
        beam = [
            Beam(size=args.beam_size,
                 n_best=1,
                 pad_id=pad_id,
                 bos_id=bos_id,
                 eos_id=eos_id,
                 device=args.gpu,
                 vocab=vocab,
                 global_scorer=None) for _ in range(batch_size)
        ]
        context, encoder_state = model.encoder(words=src_words_sorted,
                                               length=src_length_sorted)
        # Undo the length sort on both the annotations and the RNN state.
        context = context[:, orig_indices, :]
        encoder_state = DecoderState.apply_to_rnn_state(
            lambda s: s[:, orig_indices], rnn_state=encoder_state)

        prev_state = DecoderState(rnn_state=encoder_state,
                                  input_feeding=model.input_feeding)

        # Tile everything beam_size times along the batch axis. The tiling
        # is beam-major ([beam, batch] when flattened), matching the
        # stack/view layout of decoder_input and the view(...) reshapes in
        # the loop below.
        context = context.repeat(1, args.beam_size, 1)
        src_length = src_length.repeat(args.beam_size)
        prev_state = prev_state.repeat(args.beam_size)

        # Decode at most twice the source length.
        for t in range(src_max_length * 2):
            if all((b.done() for b in beam)):
                break
            # Current frontier token of every beam, stacked along batch.
            decoder_input = torch.stack([b.get_current_state() for b in beam],
                                        dim=1)
            # decoder_input: (1, beam_size * batch_size)
            decoder_input = decoder_input.view(1, -1)
            # volatile=True: legacy (pre-0.4) PyTorch inference mode, no
            # autograd graph is built.
            decoder_input = Variable(decoder_input, volatile=True)
            logits, prev_state, attn_weights = model.decoder(
                annotations=context,
                annotations_length=src_length,
                state=prev_state,
                words=decoder_input)
            log_probs = functional.log_softmax(logits.squeeze(0), dim=1)
            # log_probs: (beam_size, batch_size, num_words)
            log_probs = log_probs.view(args.beam_size, batch_size, -1)
            # attn_weights: (beam_size, batch_size, source_length)
            attn_weights = attn_weights.view(args.beam_size, batch_size, -1)
            for j, b in enumerate(beam):
                b.advance(word_lk=log_probs[:, j].data,
                          attn_out=attn_weights[:, j].data)
                # Update prev_state to point correct parents
                current_origin = b.get_current_origin()
                prev_state.beam_update(batch_index=j,
                                       beam_indices=current_origin,
                                       beam_size=args.beam_size)

        # Extract the single best finished hypothesis from each beam.
        hyps = []
        for i, b in enumerate(beam):
            scores, ks = b.sort_finished(minimum=1)
            hyp, att = b.get_hyp(timestep=ks[0][0], k=ks[0][1])
            # src_length was repeated above, but indices 0..batch_size-1
            # still hold the original per-sentence lengths; trim to twice
            # the source length.
            hyps.append(hyp[:src_length[i] * 2])
        return hyps