Example #1
 def reorder_beam(self, beam_ids):
     if isinstance(self.hidden, tuple):
         hidden = (u.swap(self.hidden[0], 1,
                          beam_ids), u.swap(self.hidden[1], 1, beam_ids))
     else:
         hidden = u.swap(self.hidden, 1, beam_ids)
     self.hidden = hidden
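
Every example on this page relies on a helper u.swap(tensor, dim, index) to reorder a tensor along one dimension by a list of beam indices. The module imported as u is not shown here, so the following is only a minimal sketch of what such a helper presumably does, built on torch.index_select; the name swap and the shapes are illustrative assumptions, not the original implementation.

# Minimal sketch (an assumption, not the actual u.swap) of a helper that
# reorders a tensor along one dimension according to beam indices.
import torch


def swap(tensor, dim, index):
    """Return `tensor` with its slices along `dim` reordered by `index`."""
    if not isinstance(index, torch.Tensor):
        index = torch.tensor(index, dtype=torch.long, device=tensor.device)
    return tensor.index_select(dim, index)


# Example: reorder the beam dimension (dim=1) of an RNN hidden state.
hidden = torch.arange(2 * 5 * 3, dtype=torch.float).view(2, 5, 3)  # layers x width x dim
source_beam = [4, 4, 0, 1, 2]  # parent beam of each surviving hypothesis
print(swap(hidden, 1, source_beam).shape)  # torch.Size([2, 5, 3])
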
Example #2
    def beam(self,
             width=5,
             seed_texts=None,
             max_seq_len=25,
             ignore_eos=False,
             bos=False,
             eos=False,
             **kwargs):
        """
        Approximation to the highest probability output over the generated
        sequence using beam search.
        """
        prev, hidden = self._seed(seed_texts, 1, bos, eos)
        eos = self.eos if not ignore_eos else None
        beam = Beam(width, prev.squeeze().data[0], eos=eos)

        while beam.active and len(beam) < max_seq_len:
            prev_data = beam.get_current_state().unsqueeze(0)
            prev = Variable(prev_data, volatile=True)
            outs, hidden, _ = self.model(prev, hidden=hidden, **kwargs)
            beam.advance(outs.data)
            if self.model.cell.startswith('LSTM'):
                hidden = (u.swap(hidden[0], 1, beam.get_source_beam()),
                          u.swap(hidden[1], 1, beam.get_source_beam()))
            else:
                hidden = u.swap(hidden, 1, beam.get_source_beam())

        scores, hyps = beam.decode(n=width)

        return scores, hyps
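
The loop above assumes a Beam object with a small interface: get_current_state() returns the last token of each live hypothesis, advance() consumes a (width x vocab_size) matrix of log-probabilities, get_source_beam() reports which parent each surviving hypothesis came from, and decode() backtracks the n best sequences. The class below is only a rough sketch of that interface under standard beam-search bookkeeping; it is not the class used by these examples (for instance, it stops only once every hypothesis emits eos instead of retiring finished hypotheses individually, and it ignores the gpu flag).

import torch


class Beam:
    """Rough sketch of the Beam bookkeeping assumed above (not the original
    class): it tracks cumulative scores, the chosen token per step and the
    parent ("source") beam of every hypothesis. Device handling and the
    retirement of individual finished hypotheses are omitted."""

    def __init__(self, width, bos, eos=None, gpu=False):
        self.width, self.eos = width, eos
        self.scores = torch.zeros(width)      # cumulative log-probs per beam
        self.source_beams = []                # parent beam indices per step
        self.history = [torch.full((width,), bos, dtype=torch.long)]
        self.active = True

    def __len__(self):
        return len(self.history) - 1          # number of decoded steps

    def get_current_state(self):
        return self.history[-1]               # last token per hypothesis: (width,)

    def get_source_beam(self):
        return self.source_beams[-1]          # parent index per hypothesis: (width,)

    def advance(self, logprobs):
        # logprobs: (width x vocab_size) log-probabilities for the next token
        vocab = logprobs.size(1)
        if len(self) == 0:
            flat = logprobs[0]                # first step: all rows are identical
        else:
            flat = (logprobs + self.scores.unsqueeze(1)).view(-1)
        self.scores, best = flat.topk(self.width)
        self.source_beams.append(torch.div(best, vocab, rounding_mode='floor'))
        self.history.append(best % vocab)
        if self.eos is not None and bool((self.history[-1] == self.eos).all()):
            self.active = False               # simplified stopping criterion

    def decode(self, n=1):
        # backtrack the n best hypotheses through the source-beam pointers
        order = self.scores.argsort(descending=True)[:n]
        scores, hyps = [], []
        for k in order.tolist():
            hyp, beam = [], k
            for step in range(len(self), 0, -1):
                hyp.append(self.history[step][beam].item())
                beam = self.source_beams[step - 1][beam].item()
            scores.append(self.scores[k].item())
            hyps.append(list(reversed(hyp)))
        return scores, hyps
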
Example #3
    def translate_beam(self, src, max_decode_len=2, beam_width=5):
        """
        Translate a single input sequence using beam search.

        Parameters:
        -----------

        src: torch.LongTensor (seq_len x 1)
        """
        pad = self.src_dict.get_pad()
        eos = self.src_dict.get_eos()
        bos = self.src_dict.get_bos()
        gpu = src.is_cuda
        # encode
        emb = self.src_embeddings(src)
        enc_outs, enc_hidden = self.encoder(emb,
                                            compute_mask=False,
                                            mask_symbol=pad)
        # decode
        enc_outs = enc_outs.repeat(1, beam_width, 1)
        if self.cell.startswith('LSTM'):
            enc_hidden = (enc_hidden[0].repeat(1, beam_width, 1),
                          enc_hidden[1].repeat(1, beam_width, 1))
        else:
            enc_hidden = enc_hidden.repeat(1, beam_width, 1)
        dec_hidden = self.decoder.init_hidden_for(enc_hidden)
        dec_out, enc_att = None, None
        if self.decoder.att_type == 'Bahdanau':
            enc_att = self.decoder.attn.project_enc_outs(enc_outs)
        beam = Beam(beam_width, bos, eos=eos, gpu=gpu)
        while beam.active and len(beam) < len(src) * max_decode_len:
            # add seq_len singleton dim (1 x width)
            prev_data = beam.get_current_state().unsqueeze(0)
            prev = Variable(prev_data, volatile=True)
            prev_emb = self.trg_embeddings(prev).squeeze(0)
            dec_out, dec_hidden, att_weights = self.decoder(prev_emb,
                                                            dec_hidden,
                                                            enc_outs,
                                                            out=dec_out,
                                                            enc_att=enc_att)
            # (width x vocab_size)
            outs = self.project(dec_out)
            beam.advance(outs.data)
            # TODO: this doesn't seem to affect the output :-s
            dec_out = u.swap(dec_out, 0, beam.get_source_beam())
            if self.cell.startswith('LSTM'):
                dec_hidden = (u.swap(dec_hidden[0], 1, beam.get_source_beam()),
                              u.swap(dec_hidden[1], 1, beam.get_source_beam()))
            else:
                dec_hidden = u.swap(dec_hidden, 1, beam.get_source_beam())
        # decode beams
        scores, hyps = beam.decode(n=beam_width)
        return scores, hyps, None  # TODO: return attention
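
Before the decoding loop, translate_beam tiles the encoder outputs and hidden state beam_width times along the batch dimension, so the single source sentence is decoded against by every hypothesis. A standalone illustration of that step, with made-up shapes:

# Illustration (made-up shapes) of repeating the encoder state beam_width
# times along the batch dimension before decoding.
import torch

beam_width = 5
seq_len, batch, hid = 7, 1, 16                        # a single source sentence
enc_outs = torch.randn(seq_len, batch, hid)
enc_hidden = (torch.randn(1, batch, hid),             # LSTM state: (h_n, c_n)
              torch.randn(1, batch, hid))

enc_outs = enc_outs.repeat(1, beam_width, 1)          # (seq_len x width x hid)
enc_hidden = tuple(h.repeat(1, beam_width, 1) for h in enc_hidden)
print(enc_outs.shape, enc_hidden[0].shape)
# torch.Size([7, 5, 16]) torch.Size([1, 5, 16])
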
Example #4
 def reorder_beam(self, beam_ids):
     """
     Reorder state attributes to match the previously decoded beam order
     """
     if self.input_feed is not None:
         self.input_feed = u.swap(self.input_feed, 0, beam_ids)
     if isinstance(self.hidden, tuple):
         hidden = (u.swap(self.hidden[0], 1,
                          beam_ids), u.swap(self.hidden[1], 1, beam_ids))
     else:
         hidden = u.swap(self.hidden, 1, beam_ids)
     self.hidden = hidden
Example #5
    def translate_beam(self, src, max_decode_len=2, beam_width=5):
        """
        Translate a single input sequence using beam search.

        Parameters:
        -----------

        src: torch.LongTensor (seq_len x 1)
        """
        pad = self.src_dict.get_pad()
        eos = self.src_dict.get_eos()
        bos = self.src_dict.get_bos()
        gpu = src.is_cuda
        # encode
        emb = self.src_embeddings(src)
        enc_outs, enc_hidden = self.encoder(
            emb, compute_mask=False, mask_symbol=pad)
        # decode
        enc_outs = enc_outs.repeat(1, beam_width, 1)
        if self.cell.startswith('LSTM'):
            enc_hidden = (enc_hidden[0].repeat(1, beam_width, 1),
                          enc_hidden[1].repeat(1, beam_width, 1))
        else:
            enc_hidden = enc_hidden.repeat(1, beam_width, 1)
        dec_hidden = self.decoder.init_hidden_for(enc_hidden)
        dec_out, enc_att = None, None
        if self.decoder.att_type == 'Bahdanau':
            enc_att = self.decoder.attn.project_enc_outs(enc_outs)
        beam = Beam(beam_width, bos, eos=eos, gpu=gpu)
        while beam.active and len(beam) < len(src) * max_decode_len:
            # add seq_len singleton dim (1 x width)
            prev_data = beam.get_current_state().unsqueeze(0)
            prev = Variable(prev_data, volatile=True)
            prev_emb = self.trg_embeddings(prev).squeeze(0)
            dec_out, dec_hidden, att_weights = self.decoder(
                prev_emb, dec_hidden, enc_outs, out=dec_out, enc_att=enc_att)
            # (width x vocab_size)
            outs = self.project(dec_out)
            beam.advance(outs.data)
            # TODO: this doesn't seem to affect the output :-s
            dec_out = u.swap(dec_out, 0, beam.get_source_beam())
            if self.cell.startswith('LSTM'):
                dec_hidden = (u.swap(dec_hidden[0], 1, beam.get_source_beam()),
                              u.swap(dec_hidden[1], 1, beam.get_source_beam()))
            else:
                dec_hidden = u.swap(dec_hidden, 1, beam.get_source_beam())
        # decode beams
        scores, hyps = beam.decode(n=beam_width)
        return scores, hyps, None  # TODO: return attention
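
The LSTM-versus-GRU branching that both translate_beam examples (and reorder_beam above) repeat can be folded into one small helper. The name reorder_hidden below is ours, not part of the original code; it assumes u.swap behaves like index_select along the given dimension.

# Sketch of a helper (our naming) that collapses the repeated LSTM/GRU branch.
import torch


def reorder_hidden(hidden, beam_ids):
    """Reorder RNN state(s) along the beam dimension (dim 1)."""
    beam_ids = torch.as_tensor(beam_ids, dtype=torch.long)
    if isinstance(hidden, tuple):                       # LSTM state: (h_n, c_n)
        return tuple(h.index_select(1, beam_ids) for h in hidden)
    return hidden.index_select(1, beam_ids)             # GRU / vanilla RNN state
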
Example #6
 def beam(self, width=5, seed_text=None, max_seq_len=25, **kwargs):
     prev, hidden = self.seed(seed_text)
     beam = Beam(width, prev.squeeze().data[0], eos=self.eos)
     while beam.active and len(beam) < max_seq_len:
         prev_data = beam.get_current_state().unsqueeze(0)
         prev = Variable(prev_data, volatile=True)
         outs, hidden, _ = self.model(prev, hidden=hidden, **kwargs)
         beam.advance(outs.data)
         if self.model.cell.startswith('LSTM'):
             hidden = (u.swap(hidden[0], 1, beam.get_source_beam()),
                       u.swap(hidden[1], 1, beam.get_source_beam()))
         else:
             hidden = u.swap(hidden, 1, beam.get_source_beam())
     scores, hyps = beam.decode(n=width)
     return scores, hyps
Example #7
 def beam(self, width=5, seed_texts=None, max_seq_len=25, batch_size=1,
          ignore_eos=False, bos=False, **kwargs):
     if (seed_texts is not None and len(seed_texts) > 1) or batch_size > 1:
         raise ValueError(
             "Currently beam search is limited to single item batches")
     prev, hidden = self._seed(seed_texts, batch_size, 'argmax', bos)
     eos = self.eos if not ignore_eos else None
     beam = Beam(width, prev.squeeze().data[0], eos=eos)
     while beam.active and len(beam) < max_seq_len:
         prev_data = beam.get_current_state().unsqueeze(0)
         prev = Variable(prev_data, volatile=True)
         outs, hidden, _ = self.model(prev, hidden=hidden, **kwargs)
         beam.advance(outs.data)
         if self.model.cell.startswith('LSTM'):
             hidden = (u.swap(hidden[0], 1, beam.get_source_beam()),
                       u.swap(hidden[1], 1, beam.get_source_beam()))
         else:
             hidden = u.swap(hidden, 1, beam.get_source_beam())
     scores, hyps = beam.decode(n=width)
     return scores, hyps
Example #8
    def translate_beam(self, src, max_decode_len=2, beam_width=5, conds=None):
        """
        Translate a single input sequence using beam search.

        Parameters:
        -----------

        src: torch.LongTensor (seq_len x 1)
        """
        eos = self.src_dict.get_eos()
        bos = self.src_dict.get_bos()
        gpu = src.is_cuda

        # Encode
        emb = self.src_embeddings(src)
        enc_outs, enc_hidden = self.encoder(emb)
        enc_outs = enc_outs.repeat(1, beam_width, 1)
        if self.cell.startswith('LSTM'):
            enc_hidden = (enc_hidden[0].repeat(1, beam_width, 1),
                          enc_hidden[1].repeat(1, beam_width, 1))
        else:
            enc_hidden = enc_hidden.repeat(1, beam_width, 1)

        # Decode
        # (handle conditions)
        if self.cond_dim is not None:
            if conds is None:
                raise ValueError("Conditional decoder needs conds")
            conds = [emb(cond) for cond, emb in zip(conds, self.cond_embs)]
            # (batch_size x total emb dim)
            conds = torch.cat(conds, 1)
            conds = conds.repeat(beam_width, 1)

        dec_hidden = self.decoder.init_hidden_for(enc_hidden)
        dec_out, enc_att = None, None
        if self.decoder.att_type == 'Bahdanau':
            enc_att = self.decoder.attn.project_enc_outs(enc_outs)

        beam = Beam(beam_width, bos, eos=eos, gpu=gpu)

        while beam.active and len(beam) < len(src) * max_decode_len:
            # (width) -> (1 x width)
            prev = beam.get_current_state().unsqueeze(0)
            prev = Variable(prev, volatile=True)
            prev_emb = self.trg_embeddings(prev).squeeze(0)

            dec_out, dec_hidden, att_weights = self.decoder(
                prev_emb, dec_hidden, enc_outs, prev_out=dec_out,
                enc_att=enc_att, conds=conds)
            # (width x vocab_size)
            logprobs = self.project(dec_out)
            beam.advance(logprobs.data)

            # repackage according to source beam
            dec_out = u.swap(dec_out, 0, beam.get_source_beam())
            if self.cell.startswith('LSTM'):
                dec_hidden = (u.swap(dec_hidden[0], 1, beam.get_source_beam()),
                              u.swap(dec_hidden[1], 1, beam.get_source_beam()))
            else:
                dec_hidden = u.swap(dec_hidden, 1, beam.get_source_beam())

        scores, hyps = beam.decode(n=beam_width)

        return scores, hyps, None
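
The condition handling in this last example embeds each condition id, concatenates the embeddings, and tiles the result across the beam so every hypothesis is conditioned identically. A self-contained illustration with made-up ids and embedding sizes (none of these dimensions come from the original model):

# Illustration of the condition handling: embed, concatenate, tile across beam.
import torch
import torch.nn as nn

beam_width = 5
cond_embs = nn.ModuleList([nn.Embedding(4, 8), nn.Embedding(3, 6)])
conds = [torch.tensor([2]), torch.tensor([1])]          # one id per condition
conds = torch.cat([emb(c) for c, emb in zip(conds, cond_embs)], 1)  # (1 x 14)
conds = conds.repeat(beam_width, 1)                     # (width x 14)
print(conds.shape)  # torch.Size([5, 14])
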