def beam_generate(self, b, beamsz, k):
    """Beam-search decode a single example.

    Args:
        b: batch object carrying ``src`` (title), ``ent`` (entity tuple),
           ``rel`` (graph relations) and ``nerd`` fields — assumed batch
           size 1 at entry (TODO confirm against caller).
        beamsz: beam width (number of live hypotheses).
        k: top-k candidate expansions per hypothesis per step.

    Returns:
        The finished ``Beam`` object; the caller extracts the best
        hypothesis from it.
    """
    if self.args.title:
        # Encode the title and build its attention mask.
        tencs, _ = self.tenc(b.src)
        tmask = self.maskFromList(tencs.size(), b.src[1]).unsqueeze(1)
    ents = b.ent  # presumably an (ent, phlens, elens) triple — TODO confirm
    entlens = ents[2]
    ents = self.le(ents)  # entity encoder
    if self.graph:
        # Graph encoder supplies the initial hidden state (global node)
        # and the relation keys/mask used for attention.
        gents, glob, grels = self.ge(b.rel[0], b.rel[1], (ents, entlens))
        hx = glob
        keys, mask = grels
        mask = mask == 0
    else:
        # No graph: attend directly over the encoded entities.
        mask = self.maskFromList(ents.size(), entlens)
        hx = ents.max(dim=1)[0]
        keys = ents
        mask = mask.unsqueeze(1)
    if self.args.plan:
        # Decode a content plan; "-1" in the raw plan separates sentences,
        # so sorder[i] lists the planned entity indices for sentence i.
        planlogits = self.splan.plan_decode(hx, keys, mask.clone(), entlens)
        print(planlogits.size())
        sorder = ' '.join([str(x) for x in planlogits.max(1)[1][0].tolist()])
        print(sorder)
        sorder = [x.strip() for x in sorder.split("-1")]
        sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
        # Expose only the first planned entity of each example to attention.
        mask.fill_(0)
        planplace = torch.zeros(hx.size(0)).long()
        for i, m in enumerate(sorder):
            mask[i][0][m[0]] = 1
    else:
        planlogits = None

    # detach().clone() replaces the original torch.tensor(hx): it produces
    # the same detached copy without the "copy construct from a tensor"
    # UserWarning that torch.tensor(tensor) emits.
    cx = hx.detach().clone()
    a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
    if self.args.title:
        a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
        a = torch.cat((a, a2), 1)
    outp = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).cuda()
    beam = None
    for i in range(self.maxlen):
        # Map indexed entity tags back to plain vertex tokens before embedding.
        op = self.emb_w_vertex(outp.clone(), b.nerd)
        if self.args.plan:
            # At sentence boundaries (dot token) advance each hypothesis's
            # plan pointer and re-point the attention mask at the next entity.
            schange = op == self.args.dottok
            if schange.nonzero().size(0) > 0:
                print(schange, planplace, sorder)
                planplace[schange.nonzero().squeeze()] += 1
                for j in schange.nonzero().squeeze(1):
                    if planplace[j] < len(sorder[j]):
                        mask[j] = 0
                        mask[j][0][sorder[j][planplace[j]]] = 1
        op = self.emb(op).squeeze(1)
        prev = torch.cat((a, op), 1)
        hx, cx = self.lstm(prev, (hx, cx))
        a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
        if self.args.title:
            a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
            a = torch.cat((a, a2), 1)
        l = torch.cat((hx, a), 1).unsqueeze(1)
        s = torch.sigmoid(self.switch(l))  # generate-vs-copy gate
        o = self.out(l)
        o = torch.softmax(o, 2)
        o = s * o
        # Copy attention over the entities; (1 - s) weights the copy branch.
        _, z = self.mattn(l, (ents, entlens))
        z = (1 - s) * z
        o = torch.cat((o, z), 2)
        # Zero the first two vocabulary slots so they can never be emitted
        # (presumably special tokens such as <unk>/<init> — verify).
        o[:, :, 0].fill_(0)
        o[:, :, 1].fill_(0)
        # Tiny epsilon keeps log() away from exact zeros.
        o = o + (1e-6 * torch.ones_like(o))
        decoded = o.log()
        scores, words = decoded.topk(dim=2, k=k)
        if not beam:
            # First step: spawn the beam and tile every per-example tensor
            # from batch size 1 up to the beam width.
            beam = Beam(words.squeeze(), scores.squeeze(),
                        [hx for i in range(beamsz)],
                        [cx for i in range(beamsz)],
                        [a for i in range(beamsz)],
                        beamsz, k, self.args.ntoks)
            beam.endtok = self.endtok
            beam.eostok = self.eostok
            keys = keys.repeat(len(beam.beam), 1, 1)
            mask = mask.repeat(len(beam.beam), 1, 1)
            if self.args.title:
                tencs = tencs.repeat(len(beam.beam), 1, 1)
                tmask = tmask.repeat(len(beam.beam), 1, 1)
            if self.args.plan:
                planplace = planplace.unsqueeze(0).repeat(len(beam.beam), 1)
                sorder = sorder * len(beam.beam)
            ents = ents.repeat(len(beam.beam), 1, 1)
            entlens = entlens.repeat(len(beam.beam))
        else:
            if not beam.update(scores, words, hx, cx, a):
                break
            # The beam may have shrunk (finished hypotheses); trim all
            # tiled state to the surviving beam size.
            keys = keys[:len(beam.beam)]
            mask = mask[:len(beam.beam)]
            if self.args.title:
                tencs = tencs[:len(beam.beam)]
                tmask = tmask[:len(beam.beam)]
            if self.args.plan:
                planplace = planplace[:len(beam.beam)]
                # NOTE(review): this replicates the FIRST plan (sorder[0]) for
                # every surviving hypothesis instead of slicing
                # sorder[:len(beam.beam)] — kept as-is to match the original,
                # but this looks suspect and should be verified.
                sorder = sorder[0] * len(beam.beam)
            ents = ents[:len(beam.beam)]
            entlens = entlens[:len(beam.beam)]
        outp = beam.getwords()
        hx = beam.geth()
        cx = beam.getc()
        a = beam.getlast()
    return beam
def beam_generate(self, b, beamsz, k):
    """Beam-search decode a single example; return the finished Beam.

    b: batch carrying src (title), ent (entity tuple), rel (graph
       relations) and nerd fields.
    beamsz: beam width (number of live hypotheses).
    k: top-k candidate expansions per hypothesis per step.
    """
    if self.args.title:
        tencs, _ = self.tenc(b.src)  # (1, title_len, 500)
        tmask = self.maskFromList(tencs.size(), b.src[1]).unsqueeze(
            1)  # (1, 1, title_len)
    ents = b.ent  # tuple of (ent, phlens, elens)
    entlens = ents[2]
    ents = self.le(ents)  # ents: (1, entity num, 500) / encoded hidden state of entities in b
    if self.graph:
        # Graph encoder: glob is the global-node state used as the
        # initial decoder hidden state; grels supplies keys + mask.
        gents, glob, grels = self.ge(b.rel[0], b.rel[1], (ents, entlens))
        hx = glob
        # hx = ents.max(dim=1)[0]
        keys, mask = grels
        mask = mask == 0
    else:
        # No graph: attend directly over the encoded entities.
        mask = self.maskFromList(ents.size(), entlens)
        hx = ents.max(dim=1)[0]
        keys = ents
        mask = mask.unsqueeze(1)
    if self.args.plan:
        # Decode a content plan; "-1" in the raw plan separates sentences,
        # so sorder[i] lists the planned entity indices for sentence i.
        planlogits = self.splan.plan_decode(hx, keys, mask.clone(), entlens)
        print(planlogits.size())
        sorder = ' '.join(
            [str(x) for x in planlogits.max(1)[1][0].tolist()])
        print(sorder)
        sorder = [x.strip() for x in sorder.split("-1")]
        sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
        # Expose only the first planned entity of each example to attention.
        mask.fill_(0)
        planplace = torch.zeros(hx.size(0)).long()
        for i, m in enumerate(sorder):
            mask[i][0][m[0]] = 1
    else:
        planlogits = None

    cx = torch.tensor(hx)  # (beam size, 500)
    a = self.attn(hx.unsqueeze(1), keys,
                  mask=mask).squeeze(1)  # (1, 500) / c_g
    if self.args.title:
        a2 = self.attn2(hx.unsqueeze(1), tencs,
                        mask=tmask).squeeze(1)  # (1, 500) / c_s
        a = torch.cat((a, a2), 1)  # (beam size, 1000) / c_t
    outp = torch.LongTensor(ents.size(0), 1).fill_(
        self.starttok).cuda()  # initially, (1, 1) / start token
    beam = None
    for i in range(self.maxlen):
        # Strip the index suffix from entity tags: training mapped
        # non-indexed tags => indexed tags, so decoding must undo that here.
        op = self.emb_w_vertex(outp.clone(), b.nerd)
        if self.args.plan:
            # At sentence boundaries (dot token) advance each hypothesis's
            # plan pointer and re-point the mask at the next planned entity.
            schange = op == self.args.dottok
            if schange.nonzero().size(0) > 0:
                print(schange, planplace, sorder)
                planplace[schange.nonzero().squeeze()] += 1
                for j in schange.nonzero().squeeze(1):
                    if planplace[j] < len(sorder[j]):
                        mask[j] = 0
                        m = sorder[j][planplace[j]]
                        mask[j][0][sorder[j][planplace[j]]] = 1
        op = self.emb(op).squeeze(1)  # (beam size, 500)
        prev = torch.cat((a, op), 1)  # (beam size, 1000)
        hx, cx = self.lstm(prev, (hx, cx))
        a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
        if self.args.title:
            a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
            a = torch.cat((a, a2), 1)  # (beam size, 1000)
        l = torch.cat((hx, a), 1).unsqueeze(1)  # (beam size, 1, 1500)
        s = torch.sigmoid(self.switch(l))  # (beam size, 1, 1) generate-vs-copy gate
        o = self.out(l)  # (beam size, 1, target vocab size)
        o = torch.softmax(o, 2)
        o = s * o
        # compute copy attn; (1 - s) weights the copy branch
        _, z = self.mattn(l, (ents, entlens))
        # z = torch.softmax(z,2)
        z = (1 - s) * z
        o = torch.cat((o, z), 2)  # (beam size, 1, target vocab size + entity num)
        o[:, :, 0].fill_(
            0)  # remove probability for special tokens <unk>, <init>
        o[:, :, 1].fill_(0)
        '''
        if beam:
            for p,q in enumerate(beam.getPrevEnt()):
                o[p,:,q].fill_(0)
            for p,q in beam.getIsStart():
                for r in q:
                    o[p,:,r].fill_(0)
        '''
        # Tiny epsilon keeps log() away from exact zeros.
        o = o + (1e-6 * torch.ones_like(o))
        decoded = o.log()
        scores, words = decoded.topk(
            dim=2, k=k)  # (beam size, 1, k), (beam size, 1, k)
        if not beam:
            # First step: spawn the beam and tile every per-example tensor
            # from batch size 1 up to the beam width.
            beam = Beam(words.squeeze(), scores.squeeze(),
                        [hx for i in range(beamsz)],
                        [cx for i in range(beamsz)],
                        [a for i in range(beamsz)],
                        beamsz, k, self.args.ntoks)
            beam.endtok = self.endtok
            beam.eostok = self.eostok
            keys = keys.repeat(len(beam.beam), 1,
                               1)  # (beam size, adjacency matrix len, 500)
            mask = mask.repeat(
                len(beam.beam), 1,
                1)  # (beam size, 1, adjacency matrix len) => all 0?
            if self.args.title:
                tencs = tencs.repeat(len(beam.beam), 1,
                                     1)  # (beam size, title_len, 500)
                tmask = tmask.repeat(len(beam.beam), 1,
                                     1)  # (1, 1, title_len)
            if self.args.plan:
                planplace = planplace.unsqueeze(0).repeat(
                    len(beam.beam), 1)
                sorder = sorder * len(beam.beam)
            ents = ents.repeat(len(beam.beam), 1,
                               1)  # (beam size, entity num, 500)
            entlens = entlens.repeat(len(beam.beam))  # (beam size,)
        else:
            if not beam.update(scores, words, hx, cx, a):
                break
            # if beam size changes (i.e. any of beam ends), change size of
            # weight matrices accordingly
            keys = keys[:len(beam.beam)]
            mask = mask[:len(beam.beam)]
            if self.args.title:
                tencs = tencs[:len(beam.beam)]
                tmask = tmask[:len(beam.beam)]
            if self.args.plan:
                planplace = planplace[:len(beam.beam)]
                # NOTE(review): replicates sorder[0] rather than slicing
                # sorder[:len(beam.beam)] — looks suspect; verify intent.
                sorder = sorder[0] * len(beam.beam)
            ents = ents[:len(beam.beam)]
            entlens = entlens[:len(beam.beam)]
        outp = beam.getwords()  # (beam size,) / next word for each beam
        hx = beam.geth()  # (beam size, 500)
        cx = beam.getc()  # (beam size, 500)
        a = beam.getlast()  # (beam size, 1000)
    return beam