def beam_generate(self, h, c, tembs, vembs, gembs, nerd, beamsz, k):
    """Beam-search decode from pre-encoded inputs.

    Args:
        h, c: initial decoder LSTM hidden/cell states.
        tembs, vembs, gembs: title / entity / graph embedding tuples
            (tensor, mask); any of them may be None and is then skipped.
        nerd: vocabulary helper passed to ``emb_w_vertex`` — maps
            copied-vertex ids back into embeddable token ids.
        beamsz: number of hypotheses kept per step.
        k: candidates expanded per hypothesis (top-k of the distribution).

    Returns:
        The populated ``Beam`` object (finished hypotheses inside it).
    """
    # Attention sources as (attention module, embeddings) pairs; absent inputs drop out.
    embs = [x for x in [(self.t_attn, tembs), (self.g_attn, gembs), (self.e_attn, vembs)]
            if x[1] is not None]
    # Start token for the initial single hypothesis. Follow the encoder
    # output's device rather than hard-coding .cuda() so CPU runs also work.
    outp = torch.LongTensor(vembs[0].size(0), 1).fill_(self.starttok).to(vembs[0].device)
    last = h.transpose(0, 1)
    beam = None
    for _ in range(self.maxlen):
        # Map any copied-vertex ids in the previous output back to real tokens.
        outp = self.emb_w_vertex(outp.clone(), nerd)
        enc = self.Embedding(outp)
        decin = torch.cat((enc, last), 2)
        decout, (h, c) = self.dlstm(decin, (h, c))
        last, vweight, _ = self.hierattn(decout, embs)
        # Copy-vs-generate gate: scalar in (0, 1) from the hidden state.
        scalar = torch.sigmoid(self.switch(h))
        outs = torch.cat((decout, last), 2)
        decoded = self.outlin(outs.contiguous().view(-1, self.args.hsz * 2))
        decoded = decoded.view(outs.size(0), outs.size(1), self.args.ntoks)
        decoded = torch.softmax(decoded, 2)
        # Zero reserved vocab slots 0 and 1 (presumably pad/unk — confirm
        # against the vocabulary) so they are never proposed.
        decoded[:, :, 0].fill_(0)
        decoded[:, :, 1].fill_(0)
        scalars = scalar.transpose(0, 1)
        # Mix generation probabilities (weighted 1-gate) with copy
        # probabilities over vertices (weighted gate), concatenated so that
        # indices >= ntoks denote copied vertices.
        decoded = torch.mul(decoded, 1 - scalars.expand_as(decoded))
        vweights = torch.mul(vweight, scalars.expand_as(vweight))
        decoded = torch.cat([decoded, vweights], 2)
        # Small epsilon keeps log() away from exact zeros introduced above.
        decoded = (decoded + 1e-6 * torch.ones_like(decoded)).log()
        scores, words = decoded.topk(dim=2, k=k)
        if beam is None:
            # First step: fan the single hypothesis out into beamsz copies.
            beam = Beam(words.squeeze(), scores.squeeze(),
                        [h for _ in range(beamsz)],
                        [c for _ in range(beamsz)],
                        [last for _ in range(beamsz)], beamsz, k)
            beam.endtok = self.endtok
            # Tile each attention source to one copy per live hypothesis.
            newembs = []
            for a, x in embs:
                tmp = (x[0].repeat(len(beam.beam), 1, 1),
                       x[1].repeat(len(beam.beam), 1))
                newembs.append((a, tmp))
            embs = newembs
        else:
            if not beam.update(scores, words, h, c, last):
                break  # every hypothesis has finished
            # Shrink attention sources to the surviving hypotheses.
            newembs = []
            for a, x in embs:
                tmp = (x[0][:len(beam.beam), :, :], x[1][:len(beam.beam)])
                newembs.append((a, tmp))
            embs = newembs
        # Feed the beam's current words/states into the next step.
        outp = beam.getwords()
        h = beam.geth()
        c = beam.getc()
        last = beam.getlast()
    return beam
def beam_generate(self, batch, beam_size, k):
    """Beam-search decode a single encoded batch.

    Args:
        batch: object whose ``.input`` holds ``(source_ids, lengths)`` for
            the encoder.
        beam_size: number of hypotheses kept per step.
        k: candidates expanded per hypothesis (top-k of the distribution).

    Returns:
        The populated ``Beam`` object (finished hypotheses inside it).
    """
    batch = batch.input
    encoder_output, context = self.encoder(batch[0], batch[1])
    # Merge the bidirectional encoder's states: interleaved forward ([0::2])
    # and backward ([1::2]) layers are concatenated on the feature dim.
    hidden = []
    for each in context:
        hidden.append(torch.cat([each[0:each.size(0):2],
                                 each[1:each.size(0):2]], 2))
    hx = hidden[0]
    cx = hidden[1]
    # Seed decoding with token id 2 (presumably <sos> — confirm against vocab).
    recent_token = torch.LongTensor(1).fill_(2).to(self.device)
    beam = None
    for _ in range(1000):  # hard cap on decoded sequence length
        embedded = self.decoder.embedding(recent_token.long().to(self.device))
        # (beam, emb) -> (beam, 1, emb) for the step-wise RNN call.
        embedded = embedded.unsqueeze(0).permute(1, 0, 2)
        output, (hx, cx) = self.decoder.rnn(
            embedded, (hx.contiguous(), cx.contiguous()))
        hx = hx.permute(1, 0, -1)
        cx = cx.permute(1, 0, -1)
        output = self.decoder.out(output.contiguous())  # (beam, 1, vocab)
        output = self.softmax(output)
        # Zero reserved ids 0, 1, 2 so they are never proposed.
        output[:, :, 0].fill_(0)
        output[:, :, 1].fill_(0)
        output[:, :, 2].fill_(0)
        decoded = output.log().to(self.device)
        scores, words = decoded.topk(dim=-1, k=k)  # each (beam, 1, k)
        # NOTE(fix): the original called scores.to(self.device) and
        # words.to(self.device) without assigning the result — Tensor.to is
        # not in-place, so both were no-ops. topk of `decoded` is already on
        # self.device, so the dead calls are simply removed.
        if beam is None:
            beam = Beam(words.squeeze(), scores.squeeze(),
                        [hx] * beam_size, [cx] * beam_size,
                        beam_size, k, self.decoder.output_vocab_size,
                        self.device)
            beam.endtok = 5
            beam.eostok = 3
        else:
            if not beam.update(scores, words, hx, cx):
                break  # every hypothesis has finished
        recent_token = beam.getwords().view(-1)  # (beam,)
        hx = beam.get_h().permute(1, 0, -1)
        cx = beam.get_c().permute(1, 0, -1)
    return beam