def beam_generate(self, batch, beam_size, k):
    if self.args.title:
        title_encoding, _ = self.title_encoder(batch.title)  # (1, title_len, 500)
        title_mask = self.create_mask(title_encoding.size(), batch.title[1]).unsqueeze(1)  # (1, 1, title_len)
    ents = batch.ent
    ent_num_list = ents[2]
    ents = self.entity_encoder(ents)  # ents: (1, entity num, 500) / encoded hidden state of entities in batch
    glob, node_embeddings, node_mask = self.graph_encoder(batch.rel[0], batch.rel[1], (ents, ent_num_list))
    hx = glob
    node_mask = node_mask == 0
    node_mask = node_mask.unsqueeze(1)
    cx = hx.clone().detach().requires_grad_(True)  # (beam size, 500)
    context = self.attention_graph(hx.unsqueeze(1), node_embeddings, mask=node_mask).squeeze(1)  # (1, 500) / c_g
    if self.args.title:
        title_context = self.attention_title(hx.unsqueeze(1), title_encoding, mask=title_mask).squeeze(1)  # (1, 500) / c_s
        context = torch.cat((context, title_context), 1)  # (beam size, 1000) / concatenation of c_g and c_s
    recent_token = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).to(self.args.device)  # initially (1, 1) / start token
    beam = None
    for i in range(self.args.maxlen):
        # Strip the index off entity tags to build the previous token.
        # Reason: training also mapped output-vocab tokens (input) to target-vocab tokens (output).
        op = self.trim_entity_index(recent_token.clone(), batch.nerd)
        op = self.embed(op).squeeze(1)  # (beam size, 500)
        prev = torch.cat((context, op), 1)  # (beam size, 1000)
        hx, cx = self.decoder(prev, (hx, cx))
        context = self.attention_graph(hx.unsqueeze(1), node_embeddings, mask=node_mask).squeeze(1)
        if self.args.title:
            title_context = self.attention_title(hx.unsqueeze(1), title_encoding, mask=title_mask).squeeze(1)
            context = torch.cat((context, title_context), 1)  # (beam size, 1000)
        total_context = torch.cat((hx, context), 1).unsqueeze(1)  # (beam size, 1, 1500)
        sampling_prob = torch.sigmoid(self.switch(total_context))  # (beam size, 1, 1)
        out = self.out(total_context)  # (beam size, 1, target vocab size)
        out = torch.softmax(out, 2)
        out = sampling_prob * out
        # compute copy attention
        z = self.mat_attention(total_context, (ents, ent_num_list))  # z: (1, max_abstract_len, max_entity_num) / entities attended over each abstract word, then softmaxed
        z = (1 - sampling_prob) * z
        out = torch.cat((out, z), 2)  # (beam size, 1, target vocab size + entity num)
        out[:, :, 0].fill_(0)  # zero out probability of the special tokens <unk>, <init>
        out[:, :, 1].fill_(0)
        out = out + (1e-6 * torch.ones_like(out))
        decoded = out.log()
        scores, words = decoded.topk(dim=2, k=k)  # (beam size, 1, k), (beam size, 1, k)
        if not beam:
            beam = Beam(words.squeeze(), scores.squeeze(), [hx] * beam_size, [cx] * beam_size,
                        [context] * beam_size, beam_size, k,
                        self.args.output_vocab_size, self.args.device)
            beam.endtok = self.endtok
            beam.eostok = self.eostok
            node_embeddings = node_embeddings.repeat(len(beam.beam), 1, 1)  # (beam size, adjacency matrix len, 500)
            node_mask = node_mask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, adjacency matrix len)
            if self.args.title:
                title_encoding = title_encoding.repeat(len(beam.beam), 1, 1)  # (beam size, title_len, 500)
                title_mask = title_mask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, title_len)
            ents = ents.repeat(len(beam.beam), 1, 1)  # (beam size, entity num, 500)
            ent_num_list = ent_num_list.repeat(len(beam.beam))  # (beam size,)
        else:
            # if all beam nodes have ended, stop generating
            if not beam.update(scores, words, hx, cx, context):
                break
            # if the beam size changes (i.e. any hypothesis ends), shrink the cached tensors accordingly
            node_embeddings = node_embeddings[:len(beam.beam)]
            node_mask = node_mask[:len(beam.beam)]
            if self.args.title:
                title_encoding = title_encoding[:len(beam.beam)]
                title_mask = title_mask[:len(beam.beam)]
            ents = ents[:len(beam.beam)]
            ent_num_list = ent_num_list[:len(beam.beam)]
        recent_token = beam.getwords()  # (beam size,) / next word for each beam
        hx = beam.get_h()  # (beam size, 500)
        cx = beam.get_c()  # (beam size, 500)
        context = beam.get_context()  # (beam size, 1000)
    return beam
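# A minimal, self-contained sketch of the switch-gated copy step used above, with
# hypothetical sizes (vocab size 10, 3 entities) and random inputs. It illustrates
# how the sigmoid switch mixes the target-vocab softmax with the already-softmaxed
# copy distribution over entities into one distribution of size
# target_vocab_size + entity_num; the projections here are stand-ins for
# self.switch / self.out, not the model's actual modules.
import torch

def switch_copy_sketch(total_context, out_proj, switch_proj, copy_scores):
    # total_context: (beam, 1, 1500); copy_scores: (beam, 1, num_entities), rows sum to 1
    sampling_prob = torch.sigmoid(switch_proj(total_context))  # (beam, 1, 1)
    gen = torch.softmax(out_proj(total_context), 2)            # (beam, 1, vocab)
    # s * gen sums to s, (1 - s) * copy sums to 1 - s, so each row still sums to 1
    return torch.cat((sampling_prob * gen, (1 - sampling_prob) * copy_scores), 2)

if __name__ == "__main__":
    beam, vocab, n_ents = 4, 10, 3
    out_proj = torch.nn.Linear(1500, vocab)
    switch_proj = torch.nn.Linear(1500, 1)
    ctx = torch.randn(beam, 1, 1500)
    copy = torch.softmax(torch.randn(beam, 1, n_ents), 2)
    probs = switch_copy_sketch(ctx, out_proj, switch_proj, copy)
    print(probs.shape, probs.sum(2))  # torch.Size([4, 1, 13]), each row ~1.0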
def beam_generate(self, b, beamsz, k):
    if self.args.title:
        tencs, _ = self.tenc(b.src)  # (1, title_len, 500)
        tmask = self.maskFromList(tencs.size(), b.src[1]).unsqueeze(1)  # (1, 1, title_len)
    ents = b.ent  # tuple of (ent, phlens, elens)
    entlens = ents[2]
    ents = self.le(ents)  # ents: (1, entity num, 500) / encoded hidden state of entities in b
    if self.graph:
        gents, glob, grels = self.ge(b.rel[0], b.rel[1], (ents, entlens))
        hx = glob
        # hx = ents.max(dim=1)[0]
        keys, mask = grels
        mask = mask == 0
    else:
        mask = self.maskFromList(ents.size(), entlens)
        hx = ents.max(dim=1)[0]
        keys = ents
    mask = mask.unsqueeze(1)
    if self.args.plan:
        planlogits = self.splan.plan_decode(hx, keys, mask.clone(), entlens)
        print(planlogits.size())
        sorder = ' '.join([str(x) for x in planlogits.max(1)[1][0].tolist()])
        print(sorder)
        sorder = [x.strip() for x in sorder.split("-1")]
        sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
        mask.fill_(0)
        planplace = torch.zeros(hx.size(0)).long()
        for i, m in enumerate(sorder):
            mask[i][0][m[0]] = 1
    else:
        planlogits = None
    cx = torch.tensor(hx)  # (beam size, 500)
    a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)  # (1, 500) / c_g
    if self.args.title:
        a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)  # (1, 500) / c_s
        a = torch.cat((a, a2), 1)  # (beam size, 1000) / c_t
    outp = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).cuda()  # initially (1, 1) / start token
    beam = None
    for i in range(self.maxlen):
        # strip the index off entity tags, since training also went not-indexed (input) => indexed (output)
        op = self.emb_w_vertex(outp.clone(), b.nerd)
        if self.args.plan:
            schange = op == self.args.dottok
            if schange.nonzero().size(0) > 0:
                print(schange, planplace, sorder)
                planplace[schange.nonzero().squeeze()] += 1
                for j in schange.nonzero().squeeze(1):
                    if planplace[j] < len(sorder[j]):
                        mask[j] = 0
                        m = sorder[j][planplace[j]]
                        mask[j][0][sorder[j][planplace[j]]] = 1
        op = self.emb(op).squeeze(1)  # (beam size, 500)
        prev = torch.cat((a, op), 1)  # (beam size, 1000)
        hx, cx = self.lstm(prev, (hx, cx))
        a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
        if self.args.title:
            a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
            # a = a + (self.mix(hx) * a2)
            a = torch.cat((a, a2), 1)  # (beam size, 1000)
        l = torch.cat((hx, a), 1).unsqueeze(1)  # (beam size, 1, 1500)
        s = torch.sigmoid(self.switch(l))  # (beam size, 1, 1)
        o = self.out(l)  # (beam size, 1, target vocab size)
        o = torch.softmax(o, 2)
        o = s * o
        # compute copy attn
        _, z = self.mattn(l, (ents, entlens))
        # z = torch.softmax(z, 2)
        z = (1 - s) * z
        o = torch.cat((o, z), 2)  # (beam size, 1, target vocab size + entity num)
        o[:, :, 0].fill_(0)  # remove probability for special tokens <unk>, <init>
        o[:, :, 1].fill_(0)
        '''
        if beam:
          for p,q in enumerate(beam.getPrevEnt()):
            o[p,:,q].fill_(0)
          for p,q in beam.getIsStart():
            for r in q:
              o[p,:,r].fill_(0)
        '''
        o = o + (1e-6 * torch.ones_like(o))
        decoded = o.log()
        scores, words = decoded.topk(dim=2, k=k)  # (beam size, 1, k), (beam size, 1, k)
        if not beam:
            beam = Beam(words.squeeze(), scores.squeeze(),
                        [hx for i in range(beamsz)], [cx for i in range(beamsz)],
                        [a for i in range(beamsz)], beamsz, k, self.args.ntoks)
            beam.endtok = self.endtok
            beam.eostok = self.eostok
            keys = keys.repeat(len(beam.beam), 1, 1)  # (beam size, adjacency matrix len, 500)
            mask = mask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, adjacency matrix len)
            if self.args.title:
                tencs = tencs.repeat(len(beam.beam), 1, 1)  # (beam size, title_len, 500)
                tmask = tmask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, title_len)
            if self.args.plan:
                planplace = planplace.unsqueeze(0).repeat(len(beam.beam), 1)
                sorder = sorder * len(beam.beam)
            ents = ents.repeat(len(beam.beam), 1, 1)  # (beam size, entity num, 500)
            entlens = entlens.repeat(len(beam.beam))  # (beam size,)
        else:
            # if all beam nodes have ended, stop generating
            if not beam.update(scores, words, hx, cx, a):
                break
            # if the beam size changes (i.e. any hypothesis ends), shrink the cached tensors accordingly
            keys = keys[:len(beam.beam)]
            mask = mask[:len(beam.beam)]
            if self.args.title:
                tencs = tencs[:len(beam.beam)]
                tmask = tmask[:len(beam.beam)]
            if self.args.plan:
                planplace = planplace[:len(beam.beam)]
                sorder = sorder[0] * len(beam.beam)
            ents = ents[:len(beam.beam)]
            entlens = entlens[:len(beam.beam)]
        outp = beam.getwords()  # (beam size,) / next word for each beam
        hx = beam.geth()  # (beam size, 500)
        cx = beam.getc()  # (beam size, 500)
        a = beam.getlast()  # (beam size, 1000)
    return beam
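# Hedged usage sketch for either beam_generate above. `model`, `batch`, and the
# Beam attributes used here (`sort`, `done`, `words`) are assumptions about the
# surrounding codebase, not guaranteed by this snippet; adapt them to the actual
# Beam implementation and data loader.
with torch.no_grad():
    beam = model.beam_generate(batch, 4, 6)  # beam size 4, top-k 6 per step
    beam.sort()                    # hypothetical: rank finished hypotheses by score
    best_ids = beam.done[0].words  # token ids of the highest-scoring hypothesis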