def beamsearch(self, src, src_mask, beam_size=10, normalize=False, max_len=None, min_len=None):
    """Beam-search decode a batch-of-one source sequence.

    Args:
        src: source token tensor of shape (1, src_len) — TODO confirm
            batch size 1; hypotheses are expanded along dim 0.
        src_mask: mask over ``src`` positions (1 = valid token).
        beam_size: number of live hypotheses kept per step.
        normalize: if True, divide each hypothesis score by its length.
        max_len: hard decode-length cap (default: 3 * src_len).
        min_len: EOS is forbidden before this many steps
            (default: src_len // 2).

    Returns:
        List of ``(token_list, score)`` tuples sorted by ascending score
        (scores are negative log-probabilities; smaller is better).
        SOS and EOS are stripped from each ``token_list``.
    """
    max_len = src.size(1) * 3 if max_len is None else max_len
    # Floor division: plain `/` yields a float on Python 3; this code was
    # ported from a Python 2 version where `/` on ints already floored,
    # so `//` restores the original integer semantics.
    min_len = src.size(1) // 2 if min_len is None else min_len

    enc_context, _ = self.encoder(src, src_mask)
    enc_context = enc_context.contiguous()

    # Mean-pool encoder states over valid positions to seed the decoder.
    avg_enc_context = enc_context.sum(1)
    enc_context_len = src_mask.sum(1).unsqueeze(-1).expand_as(avg_enc_context)
    avg_enc_context = avg_enc_context / enc_context_len

    attn_mask = src_mask.bool()
    hidden = torch.tanh(self.init_affine(avg_enc_context))

    prev_beam = Beam(beam_size)
    prev_beam.candidates = [[self.dec_sos]]
    prev_beam.scores = [0]

    # A hypothesis is finished once its last token is EOS.
    f_done = lambda x: x[-1] == self.dec_eos

    valid_size = beam_size
    hyp_list = []
    for k in range(max_len):
        candidates = prev_beam.candidates
        # Feed the last token of each live hypothesis.
        # (Renamed from `input` to avoid shadowing the builtin.)
        step_input = src.new_tensor([cand[-1] for cand in candidates])
        step_input = self.dec_emb_dp(self.emb(step_input))
        output, hidden = self.decoder(step_input, hidden, attn_mask, enc_context)
        log_prob = F.log_softmax(self.affine(output), dim=1)
        # Forbid EOS before min_len; force EOS on the very last step so
        # every surviving hypothesis terminates.
        if k < min_len:
            log_prob[:, self.dec_eos] = -float("inf")
        if k == max_len - 1:
            eos_prob = log_prob[:, self.dec_eos].clone()
            log_prob[:, :] = -float("inf")
            log_prob[:, self.dec_eos] = eos_prob

        next_beam = Beam(valid_size)
        # Beam scores are negated log-probs, so lower is better.
        done_list, remain_list = next_beam.step(-log_prob, prev_beam, f_done)
        hyp_list.extend(done_list)
        valid_size -= len(done_list)
        if valid_size == 0:
            break

        # Keep only the state rows belonging to still-live hypotheses.
        beam_remain_ix = src.new_tensor(remain_list)
        enc_context = enc_context.index_select(0, beam_remain_ix)
        attn_mask = attn_mask.index_select(0, beam_remain_ix)
        hidden = hidden.index_select(0, beam_remain_ix)
        prev_beam = next_beam

    score_list = [hyp[1] for hyp in hyp_list]
    # Strip the leading SOS and everything from EOS onward.
    hyp_list = [
        hyp[0][1:hyp[0].index(self.dec_eos)] if self.dec_eos in hyp[0] else hyp[0][1:]
        for hyp in hyp_list
    ]
    if normalize:
        for k, (hyp, score) in enumerate(zip(hyp_list, score_list)):
            if len(hyp) > 0:
                score_list[k] = score_list[k] / len(hyp)

    score = hidden.new_tensor(score_list)
    sort_score, sort_ix = torch.sort(score)
    output = []
    for ix in sort_ix.tolist():
        output.append((hyp_list[ix], score[ix].item()))
    return output
def beamsearch(self, src, src_mask, beam_size=8, normalize=False, max_len=None, min_len=None):
    """Beam-search decode using the multi-layer decoder variant.

    Args:
        src: source token tensor of shape (1, src_len) — TODO confirm
            batch size 1; hypotheses are expanded along dim 0.
        src_mask: mask over ``src`` positions (1 = valid token).
        beam_size: number of live hypotheses kept per step.
        normalize: if True, divide each hypothesis score by its length.
        max_len: hard decode-length cap (default: 3 * src_len).
        min_len: EOS is forbidden before this many steps
            (default: src_len // 2).

    Returns:
        List of ``(token_list, score)`` tuples sorted by ascending score
        (scores are negative log-probabilities; smaller is better).
        SOS and EOS are stripped from each ``token_list``.
    """
    max_len = src.size(1) * 3 if max_len is None else max_len
    # Floor division keeps min_len an int on Python 3 (the original `/ 2`
    # relied on Python 2 integer division).
    min_len = src.size(1) // 2 if min_len is None else min_len

    enc_context = self.encoder(src, src_mask)
    enc_context = enc_context.contiguous()

    # Mean-pool encoder states over valid positions to seed the decoder.
    avg_enc_context = enc_context.sum(1)
    enc_context_len = src_mask.sum(1).unsqueeze(-1).expand_as(avg_enc_context)
    avg_enc_context = avg_enc_context / enc_context_len

    # Kept as .byte() to preserve the original mask dtype; the decoder's
    # handling of this mask is not visible here — confirm before switching
    # to .bool().
    attn_mask = src_mask.byte()
    # torch.tanh: F.tanh is deprecated and removed in recent PyTorch.
    init_hidden = torch.tanh(self.eLN(self.e2d(avg_enc_context)))
    # One hidden state per decoder layer.
    hidden = [init_hidden.clone() for _ in range(self.dec_nlayer)]

    prev_beam = Beam(beam_size)
    prev_beam.candidates = [[self.dec_sos]]
    prev_beam.scores = [0]

    # A hypothesis is finished once its last token is EOS.
    f_done = lambda x: x[-1] == self.dec_eos

    valid_size = beam_size
    hyp_list = []
    # range, not xrange: Python 3 compatibility (works on both versions).
    for k in range(max_len):
        candidates = prev_beam.candidates
        # List comprehension instead of map(lambda ...): on Python 3,
        # map returns a lazy iterator, which new_tensor does not accept.
        step_input = src.new_tensor([cand[-1] for cand in candidates])
        step_input = self.dec_emb_dp(self.emb(step_input))
        output, hidden = self.decoder(step_input, hidden, attn_mask, enc_context)
        # Tied embeddings reuse the input embedding matrix as the output
        # projection; otherwise use the dedicated projection layer.
        if self.tied_emb:
            logit = torch.matmul(output, self.emb.weight.t())
        else:
            logit = self.proj(output)
        log_prob = F.log_softmax(logit, dim=1)
        # Forbid EOS before min_len; force EOS on the very last step so
        # every surviving hypothesis terminates.
        if k < min_len:
            log_prob[:, self.dec_eos] = -float('inf')
        if k == max_len - 1:
            eos_prob = log_prob[:, self.dec_eos].clone()
            log_prob[:, :] = -float('inf')
            log_prob[:, self.dec_eos] = eos_prob

        next_beam = Beam(valid_size)
        # Beam scores are negated log-probs, so lower is better.
        done_list, remain_list = next_beam.step(-log_prob, prev_beam, f_done)
        hyp_list.extend(done_list)
        valid_size -= len(done_list)
        if valid_size == 0:
            break

        # Keep only the state rows belonging to still-live hypotheses.
        beam_remain_ix = src.new_tensor(remain_list)
        enc_context = enc_context.index_select(0, beam_remain_ix)
        attn_mask = attn_mask.index_select(0, beam_remain_ix)
        hidden = [h.index_select(0, beam_remain_ix) for h in hidden]
        prev_beam = next_beam

    score_list = [hyp[1] for hyp in hyp_list]
    # Strip the leading SOS and everything from EOS onward.
    hyp_list = [
        hyp[0][1:hyp[0].index(self.dec_eos)] if self.dec_eos in hyp[0] else hyp[0][1:]
        for hyp in hyp_list
    ]
    if normalize:
        for k, (hyp, score) in enumerate(zip(hyp_list, score_list)):
            if len(hyp) > 0:
                score_list[k] = score_list[k] / len(hyp)

    score = hidden[0].new_tensor(score_list)
    sort_score, sort_ix = torch.sort(score)
    output = []
    for ix in sort_ix.tolist():
        output.append((hyp_list[ix], score[ix].item()))
    return output