import torch

# `Beam` (beam-search bookkeeping) and `subsequent_mask` (causal decoder mask)
# are project-local helpers assumed to be importable in this module.


def beam_decode(self, batch, max_len, oov_nums):
    bos_token = self.data_utils.bos
    beam_size = self.args.beam_size

    src = batch['src'].long()
    src_mask = batch['src_mask']
    src_extended = batch['src_extended'].long()  # source ids including per-batch OOV ids

    memory = self.model.encode(src, src_mask)
    batch_size = src.size(0)

    beam = Beam(self.data_utils.pad, bos_token, self.data_utils.eos,
                beam_size, batch_size, self.args.n_best, True, max_len)

    # Step 0: decode once from <bos> to seed the beams.
    ys = torch.full((batch_size, 1), bos_token).type_as(src.data).cuda()
    tgt_mask = subsequent_mask(ys.size(1)).type_as(src.data).expand(
        (ys.size(0), ys.size(1), ys.size(1)))
    log_prob = self.model.decode(memory, src_mask, ys, tgt_mask,
                                 src_extended, oov_nums)
    # log_prob: [batch_size, 1, vocab_size (+ per-batch OOVs)]

    top_prob, top_indices = torch.topk(log_prob, k=beam_size, dim=-1)
    top_prob = top_prob.view(-1, 1)
    top_indices = top_indices.view(-1, 1)
    beam.update_prob(top_prob.detach().cpu(), top_indices.detach().cpu())
    ys = top_indices  # [batch_size * beam_size, 1]

    # Tile the encoder-side tensors so every beam hypothesis sees its source:
    # [batch_size, ...] -> [batch_size * beam_size, ...]
    src = torch.repeat_interleave(src, beam_size, dim=0)
    src_mask = torch.repeat_interleave(src_mask, beam_size, dim=0)
    src_extended = torch.repeat_interleave(src_extended, beam_size, dim=0)
    memory = torch.repeat_interleave(memory, beam_size, dim=0)

    for t in range(1, max_len):
        tgt_mask = subsequent_mask(ys.size(1)).type_as(src.data).expand(
            (ys.size(0), ys.size(1), ys.size(1)))
        # Pass the tiled extended source (not `src`) plus `oov_nums`, matching
        # the step-0 call, so copy scores keep pointing at the same OOV ids.
        log_prob = self.model.decode(memory, src_mask, ys, tgt_mask,
                                     src_extended, oov_nums)
        log_prob = log_prob[:, -1].unsqueeze(1)  # keep only the newest time step

        real_top = beam.advance(log_prob.detach().cpu())
        ys = torch.cat((ys, real_top.view(-1, 1).cuda()), dim=-1)

    # Only the first batch element's hypotheses are returned.
    return [beam.seq[0]]
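
# A minimal usage sketch, assuming a solver object wired up the way this
# method expects (`data_utils`, `args` with `beam_size`/`n_best`, and a
# `model` exposing `encode`/`decode`); the batch keys mirror the ones read
# above. The dataloader, the `oov_nums` key, and `max_len=100` are
# illustrative assumptions, not project defaults.
#
#     batch = next(iter(test_loader))      # hypothetical dataloader
#     oov_nums = batch['oov_nums']         # hypothetical key: OOV count per example
#     hyps = solver.beam_decode(batch, max_len=100, oov_nums=oov_nums)
#     best = hyps[0]                       # beam hypotheses for batch element 0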