def reorder_beam(self, beam_ids):
    """Reorder the stored hidden state along the beam axis to follow `beam_ids`."""
    if isinstance(self.hidden, tuple):
        # LSTM-style state: reorder both h and c along dim 1
        h, c = self.hidden
        self.hidden = (u.swap(h, 1, beam_ids), u.swap(c, 1, beam_ids))
    else:
        self.hidden = u.swap(self.hidden, 1, beam_ids)
def beam(self, width=5, seed_texts=None, max_seq_len=25,
         ignore_eos=False, bos=False, eos=False, **kwargs):
    """
    Approximation to the highest probability output over the generated
    sequence using beam search.
    """
    prev, hidden = self._seed(seed_texts, 1, bos, eos)
    beam = Beam(width, prev.squeeze().data[0],
                eos=None if ignore_eos else self.eos)
    while beam.active and len(beam) < max_seq_len:
        # current beam state as a (1 x width) input batch
        inp = Variable(beam.get_current_state().unsqueeze(0), volatile=True)
        outs, hidden, _ = self.model(inp, hidden=hidden, **kwargs)
        beam.advance(outs.data)
        source_beam = beam.get_source_beam()
        if self.model.cell.startswith('LSTM'):
            # reorder both LSTM state tensors along the beam dimension
            hidden = (u.swap(hidden[0], 1, source_beam),
                      u.swap(hidden[1], 1, source_beam))
        else:
            hidden = u.swap(hidden, 1, source_beam)
    return beam.decode(n=width)
def translate_beam(self, src, max_decode_len=2, beam_width=5):
    """
    Translate a single input sequence using beam search.

    Parameters:
    -----------
    src: torch.LongTensor (seq_len x 1)
    max_decode_len: int, output is capped at len(src) * max_decode_len steps
    beam_width: int, number of hypotheses kept alive per step

    Returns (scores, hyps, None); the last slot is a placeholder for
    attention weights (see TODO below).
    """
    pad = self.src_dict.get_pad()
    eos = self.src_dict.get_eos()
    bos = self.src_dict.get_bos()
    gpu = src.is_cuda
    # encode
    emb = self.src_embeddings(src)
    enc_outs, enc_hidden = self.encoder(
        emb, compute_mask=False, mask_symbol=pad)
    # decode
    # tile encoder outputs/state beam_width times so every hypothesis
    # attends over its own copy (repeat along dim 1, the batch dim)
    enc_outs = enc_outs.repeat(1, beam_width, 1)
    if self.cell.startswith('LSTM'):
        # LSTM state is an (h, c) tuple; repeat both members
        enc_hidden = (enc_hidden[0].repeat(1, beam_width, 1),
                      enc_hidden[1].repeat(1, beam_width, 1))
    else:
        enc_hidden = enc_hidden.repeat(1, beam_width, 1)
    dec_hidden = self.decoder.init_hidden_for(enc_hidden)
    dec_out, enc_att = None, None
    if self.decoder.att_type == 'Bahdanau':
        # precompute the encoder-side attention projection once
        enc_att = self.decoder.attn.project_enc_outs(enc_outs)
    beam = Beam(beam_width, bos, eos=eos, gpu=gpu)
    while beam.active and len(beam) < len(src) * max_decode_len:
        # add seq_len singleton dim (1 x width)
        prev_data = beam.get_current_state().unsqueeze(0)
        prev = Variable(prev_data, volatile=True)
        prev_emb = self.trg_embeddings(prev).squeeze(0)
        dec_out, dec_hidden, att_weights = self.decoder(
            prev_emb, dec_hidden, enc_outs, out=dec_out, enc_att=enc_att)
        # (width x vocab_size)
        outs = self.project(dec_out)
        beam.advance(outs.data)
        # reorder decoder state to match the surviving hypotheses
        # TODO: this doesn't seem to affect the output :-s
        dec_out = u.swap(dec_out, 0, beam.get_source_beam())
        if self.cell.startswith('LSTM'):
            dec_hidden = (u.swap(dec_hidden[0], 1, beam.get_source_beam()),
                          u.swap(dec_hidden[1], 1, beam.get_source_beam()))
        else:
            dec_hidden = u.swap(dec_hidden, 1, beam.get_source_beam())
    # decode beams
    scores, hyps = beam.decode(n=beam_width)
    return scores, hyps, None  # TODO: return attention
def reorder_beam(self, beam_ids):
    """
    Reorder state attributes to match the previously decoded beam order
    """
    if self.input_feed is not None:
        # input feed carries the beam on dim 0
        self.input_feed = u.swap(self.input_feed, 0, beam_ids)
    if isinstance(self.hidden, tuple):
        # LSTM-style state: reorder h and c along dim 1
        h, c = self.hidden
        self.hidden = (u.swap(h, 1, beam_ids), u.swap(c, 1, beam_ids))
    else:
        self.hidden = u.swap(self.hidden, 1, beam_ids)
def translate_beam(self, src, max_decode_len=2, beam_width=5):
    """
    Translate a single input sequence using beam search.

    Parameters:
    -----------
    src: torch.LongTensor (seq_len x 1)
    """
    src_dict = self.src_dict
    pad, eos, bos = src_dict.get_pad(), src_dict.get_eos(), src_dict.get_bos()
    gpu = src.is_cuda
    # --- encoder pass ---
    enc_outs, enc_hidden = self.encoder(
        self.src_embeddings(src), compute_mask=False, mask_symbol=pad)
    # tile encoder outputs/state so each hypothesis has its own copy
    enc_outs = enc_outs.repeat(1, beam_width, 1)
    if self.cell.startswith('LSTM'):
        h, c = enc_hidden
        enc_hidden = (h.repeat(1, beam_width, 1), c.repeat(1, beam_width, 1))
    else:
        enc_hidden = enc_hidden.repeat(1, beam_width, 1)
    # --- decoder setup ---
    dec_hidden = self.decoder.init_hidden_for(enc_hidden)
    dec_out, enc_att = None, None
    if self.decoder.att_type == 'Bahdanau':
        enc_att = self.decoder.attn.project_enc_outs(enc_outs)
    beam = Beam(beam_width, bos, eos=eos, gpu=gpu)
    max_steps = len(src) * max_decode_len
    while beam.active and len(beam) < max_steps:
        # (width) -> (1 x width): add the seq_len singleton dimension
        prev = Variable(beam.get_current_state().unsqueeze(0), volatile=True)
        prev_emb = self.trg_embeddings(prev).squeeze(0)
        dec_out, dec_hidden, att_weights = self.decoder(
            prev_emb, dec_hidden, enc_outs, out=dec_out, enc_att=enc_att)
        # project decoder output to (width x vocab_size) scores
        beam.advance(self.project(dec_out).data)
        # reorder decoder state to follow the surviving hypotheses
        # TODO: this doesn't seem to affect the output :-s
        source_beam = beam.get_source_beam()
        dec_out = u.swap(dec_out, 0, source_beam)
        if self.cell.startswith('LSTM'):
            dec_hidden = (u.swap(dec_hidden[0], 1, source_beam),
                          u.swap(dec_hidden[1], 1, source_beam))
        else:
            dec_hidden = u.swap(dec_hidden, 1, source_beam)
    scores, hyps = beam.decode(n=beam_width)
    return scores, hyps, None  # TODO: return attention
def beam(self, width=5, seed_text=None, max_seq_len=25, **kwargs):
    """
    Decode with beam search, seeding the generator from `seed_text`.

    Returns (scores, hyps) for the `width` best hypotheses.
    """
    prev, hidden = self.seed(seed_text)
    beam = Beam(width, prev.squeeze().data[0], eos=self.eos)
    while beam.active and len(beam) < max_seq_len:
        # current beam state as a (1 x width) input batch
        inp = Variable(beam.get_current_state().unsqueeze(0), volatile=True)
        outs, hidden, _ = self.model(inp, hidden=hidden, **kwargs)
        beam.advance(outs.data)
        source_beam = beam.get_source_beam()
        if self.model.cell.startswith('LSTM'):
            # reorder both LSTM state tensors along the beam dimension
            hidden = (u.swap(hidden[0], 1, source_beam),
                      u.swap(hidden[1], 1, source_beam))
        else:
            hidden = u.swap(hidden, 1, source_beam)
    return beam.decode(n=width)
def beam(self, width=5, seed_texts=None, max_seq_len=25, batch_size=1,
         ignore_eos=False, bos=False, **kwargs):
    """
    Approximate the highest-probability output sequence with beam search.

    Parameters:
    -----------
    width: int, number of beam hypotheses kept per step
    seed_texts: list or None, seed material (at most one item supported)
    max_seq_len: int, hard cap on the number of decoding steps
    batch_size: int, must be 1 (beam search is single-item only)
    ignore_eos: bool, if True decoding never stops on EOS
    bos: bool, forwarded to self._seed

    Returns (scores, hyps) for the `width` best hypotheses.

    Raises ValueError when asked to decode more than one item at a time.
    """
    # BUG FIX: the guard referenced the undefined name `seed_text`
    # (the parameter is `seed_texts`), so every call raised NameError;
    # also guard against the default seed_texts=None before taking len().
    if (seed_texts is not None and len(seed_texts) > 1) or batch_size > 1:
        raise ValueError(
            "Currently beam search is limited to single item batches")
    prev, hidden = self._seed(seed_texts, batch_size, 'argmax', bos)
    eos = self.eos if not ignore_eos else None
    beam = Beam(width, prev.squeeze().data[0], eos=eos)
    while beam.active and len(beam) < max_seq_len:
        # current beam state as a (1 x width) input batch
        prev_data = beam.get_current_state().unsqueeze(0)
        prev = Variable(prev_data, volatile=True)
        outs, hidden, _ = self.model(prev, hidden=hidden, **kwargs)
        beam.advance(outs.data)
        if self.model.cell.startswith('LSTM'):
            # reorder both LSTM state tensors along the beam dimension
            hidden = (u.swap(hidden[0], 1, beam.get_source_beam()),
                      u.swap(hidden[1], 1, beam.get_source_beam()))
        else:
            hidden = u.swap(hidden, 1, beam.get_source_beam())
    scores, hyps = beam.decode(n=width)
    return scores, hyps
def translate_beam(self, src, max_decode_len=2, beam_width=5, conds=None):
    """
    Translate a single input sequence using beam search.

    Parameters:
    -----------
    src: torch.LongTensor (seq_len x 1)
    max_decode_len: int, output is capped at len(src) * max_decode_len steps
    beam_width: int, number of hypotheses kept alive per step
    conds: sequence of condition tensors or None; required when the
        decoder is conditional (self.cond_dim is not None)

    Returns (scores, hyps, None); the last slot is a placeholder for
    attention weights.
    """
    eos = self.src_dict.get_eos()
    bos = self.src_dict.get_bos()
    gpu = src.is_cuda
    # Encode
    emb = self.src_embeddings(src)
    enc_outs, enc_hidden = self.encoder(emb)
    # tile encoder outputs/state beam_width times along dim 1 so every
    # hypothesis attends over its own copy
    enc_outs = enc_outs.repeat(1, beam_width, 1)
    if self.cell.startswith('LSTM'):
        # LSTM state is an (h, c) tuple; repeat both members
        enc_hidden = (enc_hidden[0].repeat(1, beam_width, 1),
                      enc_hidden[1].repeat(1, beam_width, 1))
    else:
        enc_hidden = enc_hidden.repeat(1, beam_width, 1)
    # Decode
    # (handler conditions)
    if self.cond_dim is not None:
        if conds is None:
            raise ValueError("Conditional decoder needs conds")
        # embed each condition with its matching embedding table
        # NOTE: this rebinds `emb` (the source embedding above), which is
        # no longer needed at this point
        conds = [emb(cond) for cond, emb in zip(conds, self.cond_embs)]
        # (batch_size x total emb dim)
        conds = torch.cat(conds, 1)
        # tile conditions for every beam hypothesis
        conds = conds.repeat(beam_width, 1)
    dec_hidden = self.decoder.init_hidden_for(enc_hidden)
    dec_out, enc_att = None, None
    if self.decoder.att_type == 'Bahdanau':
        # precompute the encoder-side attention projection once
        enc_att = self.decoder.attn.project_enc_outs(enc_outs)
    beam = Beam(beam_width, bos, eos=eos, gpu=gpu)
    while beam.active and len(beam) < len(src) * max_decode_len:
        # (width) -> (1 x width)
        prev = beam.get_current_state().unsqueeze(0)
        prev = Variable(prev, volatile=True)
        prev_emb = self.trg_embeddings(prev).squeeze(0)
        dec_out, dec_hidden, att_weights = self.decoder(
            prev_emb, dec_hidden, enc_outs, prev_out=dec_out,
            enc_att=enc_att, conds=conds)
        # (width x vocab_size)
        logprobs = self.project(dec_out)
        beam.advance(logprobs.data)
        # repackage according to source beam
        dec_out = u.swap(dec_out, 0, beam.get_source_beam())
        if self.cell.startswith('LSTM'):
            dec_hidden = (u.swap(dec_hidden[0], 1, beam.get_source_beam()),
                          u.swap(dec_hidden[1], 1, beam.get_source_beam()))
        else:
            dec_hidden = u.swap(dec_hidden, 1, beam.get_source_beam())
    scores, hyps = beam.decode(n=beam_width)
    return scores, hyps, None