def beam_generate(self, h, c, tembs, vembs, gembs, nerd, beamsz, k):
    """Run beam-search decoding and return the resulting ``Beam``.

    Each step works as follows:

    - Embed the current tokens: ``self.emb_w_vertex``, then ``self.Embedding``.
    - Advance ``self.dlstm``.
    - Attend over the supplied sources with ``self.hierattn``.
    - Mix two scores using a sigmoid gate computed from ``self.switch(h)``:
      a softmax over ``self.args.ntoks`` output tokens, scaled by
      ``1 - gate``, and the second output of ``self.hierattn``
      (``vweight``), scaled by ``gate``.

    The two parts are concatenated along dim 2 and log-transformed with a
    1e-6 floor. The top ``k`` entries are passed to ``Beam``.

    Args:
        h, c: initial recurrent state passed to ``self.dlstm``.
            ``h.transpose(0, 1)`` seeds ``last``, which is concatenated with
            the token embedding as decoder input. Presumably these are the
            outputs of ``self.encode_inputs``; ``h`` is presumably
            (1, batch, hsz) — TODO confirm.
        tembs, vembs, gembs: attention sources for ``self.t_attn``,
            ``self.e_attn`` and ``self.g_attn`` respectively. Sources that
            are None are skipped by the attention. ``vembs`` must not be
            None, because ``vembs[0].size(0)`` sets the number of start
            tokens. Each source is used as a pair: ``[0]`` is tiled as a
            3-D tensor and ``[1]`` as a 2-D tensor (presumably embeddings
            and a mask — confirm).
        nerd: forwarded to ``self.emb_w_vertex`` together with the current
            tokens.
        beamsz: beam width passed to ``Beam``. The initial ``h``, ``c`` and
            ``last`` are each replicated this many times.
        k: number of top candidates taken per step via ``topk``.

    Returns:
        The ``Beam`` after at most ``self.maxlen`` steps. Decoding stops
        early when ``beam.update`` returns a falsy value. Returns None if
        ``self.maxlen <= 0``, because no step runs.
    """
    # Keep only the attention sources that were actually supplied.
    embs = [
        (attn, src)
        for attn, src in (
            (self.t_attn, tembs),
            (self.g_attn, gembs),
            (self.e_attn, vembs),
        )
        if src is not None
    ]
    # One start token per row. Allocate on h's device instead of a
    # hard-coded .cuda(), so this also works on CPU or a non-default GPU.
    outp = torch.empty(
        (vembs[0].size(0), 1), dtype=torch.long, device=h.device
    ).fill_(self.starttok)
    last = h.transpose(0, 1)
    beam = None
    for _step in range(self.maxlen):
        # clone() so emb_w_vertex cannot mutate the tensor owned by the beam.
        outp = self.emb_w_vertex(outp.clone(), nerd)
        enc = self.Embedding(outp)
        decin = torch.cat((enc, last), 2)
        decout, (h, c) = self.dlstm(decin, (h, c))
        last, vweight, _ = self.hierattn(decout, embs)

        # Gate between generating from the output vocabulary and the
        # vweight scores.
        gate = torch.sigmoid(self.switch(h)).transpose(0, 1)
        outs = torch.cat((decout, last), 2)
        logits = self.outlin(outs.contiguous().view(-1, self.args.hsz * 2))
        logits = logits.view(outs.size(0), outs.size(1), self.args.ntoks)
        probs = torch.softmax(logits, 2)
        # Suppress token ids 0 and 1.
        # NOTE(review): these are in-place writes to the softmax output,
        # which autograd would reject on backward. This method is presumably
        # run under torch.no_grad() — confirm.
        probs[:, :, 0].fill_(0)
        probs[:, :, 1].fill_(0)
        probs = probs * (1 - gate.expand_as(probs))
        copy_probs = vweight * gate.expand_as(vweight)
        mixed = torch.cat([probs, copy_probs], 2)
        # The 1e-6 floor keeps log() finite for the zeroed entries.
        logprobs = torch.log(mixed + 1e-6)
        scores, words = logprobs.topk(dim=2, k=k)

        if beam is None:
            beam = Beam(
                words.squeeze(),
                scores.squeeze(),
                [h] * beamsz,
                [c] * beamsz,
                [last] * beamsz,
                beamsz,
                k,
            )
            beam.endtok = self.endtok
            # Tile each attention source once per hypothesis in the beam.
            width = len(beam.beam)
            embs = [
                (attn, (src[0].repeat(width, 1, 1), src[1].repeat(width, 1)))
                for attn, src in embs
            ]
        else:
            if not beam.update(scores, words, h, c, last):
                break
            # Trim the tiled sources to the current number of hypotheses.
            width = len(beam.beam)
            embs = [
                (attn, (src[0][:width], src[1][:width]))
                for attn, src in embs
            ]

        outp = beam.getwords()
        h = beam.geth()
        c = beam.getc()
        last = beam.getlast()
    return beam