def beam_generate(self, batch, beam_size, k):
    batch = batch.input
    encoder_output, context = self.encoder(batch[0], batch[1])

    # The encoder is bidirectional: interleave its forward (even-indexed) and
    # backward (odd-indexed) direction states and concatenate them along the
    # feature dimension to form the decoder's initial (h, c) state.
    hidden = []
    for each in context:
        hidden.append(torch.cat([each[0:each.size(0):2],
                                 each[1:each.size(0):2]], 2))
    hx = hidden[0]
    cx = hidden[1]

    # Start decoding from the start-of-sequence token (id 2).
    recent_token = torch.LongTensor(1).fill_(2).to(self.device)
    beam = None
    for _ in range(1000):  # hard cap on decoded length
        embedded = self.decoder.embedding(
            recent_token.long().to(self.device))        # (beam_size, embedding_size)
        embedded = embedded.unsqueeze(1)                # (beam_size, 1, embedding_size)
        output, (hx, cx) = self.decoder.rnn(
            embedded, (hx.contiguous(), cx.contiguous()))
        hx = hx.permute(1, 0, -1)
        cx = cx.permute(1, 0, -1)
        output = self.decoder.out(output.contiguous())  # (beam_size, 1, target_vocab_size)
        output = self.softmax(output)

        # Zero out pad (0), unk (1), and start (2) so that log() maps them to
        # -inf and they can never be selected.
        output[:, :, 0].fill_(0)
        output[:, :, 1].fill_(0)
        output[:, :, 2].fill_(0)

        decoded = output.log().to(self.device)
        scores, words = decoded.topk(dim=-1, k=k)       # both (beam_size, 1, k)

        if beam is None:
            # First step: seed the beam with the top-k continuations of the
            # single start hypothesis, replicating the initial decoder state.
            beam = Beam(words.squeeze(), scores.squeeze(),
                        [hx] * beam_size, [cx] * beam_size,
                        beam_size, k, self.decoder.output_vocab_size,
                        self.device)
            beam.endtok = 5
            beam.eostok = 3
        elif not beam.update(scores, words, hx, cx):
            # update() returns False once every hypothesis has finished.
            break

        # Feed the surviving hypotheses' latest tokens and states back in.
        recent_token = beam.getwords().view(-1)         # (beam_size,)
        hx = beam.get_h().permute(1, 0, -1)
        cx = beam.get_c().permute(1, 0, -1)
        # context = beam.get_context()
    return beam
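
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original model): one way beam_generate might
# be driven from an evaluation loop. The names `model`, `val_loader`, and
# `idx2word` are assumptions for illustration: `model` is an instance of the
# class defining beam_generate above, `val_loader` is assumed to yield batch
# objects with the `.input` attribute that beam_generate unpacks, and
# `idx2word` is an assumed id-to-string vocabulary lookup. Only
# Beam.getwords(), which beam_generate itself relies on, is used here; a full
# Beam implementation would normally also expose completed hypotheses with
# their scores, which is the output a real decoder would report.
# ---------------------------------------------------------------------------
def decode_validation_set(model, val_loader, idx2word, beam_size=4, k=6):
    model.eval()
    outputs = []
    with torch.no_grad():
        for batch in val_loader:
            beam = model.beam_generate(batch, beam_size, k)
            # getwords() returns the most recent token id per live hypothesis;
            # joining them directly is a simplification for illustration.
            token_ids = beam.getwords().view(-1).tolist()
            outputs.append(" ".join(idx2word[t] for t in token_ids))
    return outputs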