def forward(self, inputs, listener_outputs, function=F.log_softmax,
            teacher_forcing_ratio=0.90, use_beam_search=False):
    batch_size = inputs.size(0)
    max_length = inputs.size(1) - 1  # minus the start of sequence symbol
    decode_results = list()
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    hidden = self._init_state(batch_size)

    if use_beam_search:  # TopK Decoding
        input_ = inputs[:, 0].unsqueeze(1)
        beam = Beam(k=self.k, decoder=self, batch_size=batch_size,
                    max_length=max_length, function=function, device=self.device)
        logits = None
        y_hats = beam.search(input_, listener_outputs)
    else:
        if use_teacher_forcing:  # if teacher_forcing, infer all at once
            speller_inputs = inputs[inputs != self.eos_id].view(batch_size, -1)
            predicted_softmax, hidden = self.forward_step(input_=speller_inputs,
                                                          hidden=hidden,
                                                          listener_outputs=listener_outputs,
                                                          function=function)
            for di in range(predicted_softmax.size(1)):
                step_output = predicted_softmax[:, di, :]
                decode_results.append(step_output)
        else:
            speller_input = inputs[:, 0].unsqueeze(1)
            for di in range(max_length):
                predicted_softmax, hidden = self.forward_step(input_=speller_input,
                                                              hidden=hidden,
                                                              listener_outputs=listener_outputs,
                                                              function=function)
                step_output = predicted_softmax.squeeze(1)
                decode_results.append(step_output)
                speller_input = decode_results[-1].topk(1)[1]

        logits = torch.stack(decode_results, dim=1).to(self.device)
        y_hats = logits.max(-1)[1]

    return y_hats, logits
def forward(self, inputs, encoder_outputs, function=F.log_softmax,
            teacher_forcing_ratio=0.90, use_beam_search=False):
    y_hats, logits = None, None
    decode_results = []
    batch_size = inputs.size(0)
    max_len = inputs.size(1) - 1  # minus the start of sequence symbol
    decoder_hidden = torch.FloatTensor(self.n_layers, batch_size, self.hidden_size).uniform_(-0.1, 0.1).to(self.device)
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_beam_search:
        # Beam-Search Decoding
        inputs = inputs[:, 0].unsqueeze(1)
        beam = Beam(k=self.k, decoder_hidden=decoder_hidden, decoder=self,
                    batch_size=batch_size, max_len=max_len,
                    function=function, device=self.device)
        y_hats = beam.search(inputs, encoder_outputs)
    else:
        if use_teacher_forcing:
            # if teacher_forcing, infer all at once
            inputs = inputs[:, :-1]
            predicted_softmax = self._forward_step(input=inputs,
                                                   decoder_hidden=decoder_hidden,
                                                   encoder_outputs=encoder_outputs,
                                                   function=function)
            for di in range(predicted_softmax.size(1)):
                step_output = predicted_softmax[:, di, :]
                decode_results.append(step_output)
        else:
            input = inputs[:, 0].unsqueeze(1)
            for di in range(max_len):
                predicted_softmax = self._forward_step(input=input,
                                                       decoder_hidden=decoder_hidden,
                                                       encoder_outputs=encoder_outputs,
                                                       function=function)
                step_output = predicted_softmax.squeeze(1)
                decode_results.append(step_output)
                input = decode_results[-1].topk(1)[1]

        logits = torch.stack(decode_results, dim=1).to(self.device)
        y_hats = logits.max(-1)[1]

    return y_hats, logits
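# --- Hedged illustration (not from the source repos) --------------------------
# Minimal, self-contained sketch of the teacher-forcing split used by the two
# forward() methods above: with teacher forcing the whole target prefix is fed
# in one call; otherwise the decoder is unrolled step by step on its own argmax
# predictions. The toy GRU decoder, its sizes, and the call signature are
# assumptions made purely for illustration.
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyDecoder(nn.Module):
    def __init__(self, vocab_size=32, hidden_size=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward_step(self, tokens, hidden):
        # tokens: (batch, step_len) -> log-probs: (batch, step_len, vocab)
        output, hidden = self.rnn(self.emb(tokens), hidden)
        return F.log_softmax(self.out(output), dim=-1), hidden

def decode(decoder, inputs, teacher_forcing_ratio=0.9):
    batch_size, max_length = inputs.size(0), inputs.size(1) - 1
    hidden = torch.zeros(1, batch_size, 16)  # matches the default hidden_size above
    if random.random() < teacher_forcing_ratio:
        # feed the gold prefix (everything except the last token) at once
        log_probs, _ = decoder.forward_step(inputs[:, :-1], hidden)
    else:
        token, steps = inputs[:, 0:1], []
        for _ in range(max_length):
            step_log_probs, hidden = decoder.forward_step(token, hidden)
            steps.append(step_log_probs.squeeze(1))
            token = steps[-1].topk(1)[1]  # greedy: feed back the model's own prediction
        log_probs = torch.stack(steps, dim=1)
    return log_probs.max(-1)[1], log_probs  # (y_hats, logits), as in the functions above

# usage sketch: decode(ToyDecoder(), torch.randint(0, 32, (4, 11)))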
def beam_search(self, source_tensor, beam_size):
    """
    Use beam search to generate summaries one by one.
    :param source_tensor: (src seq len, batch size), batch size needs to be 1.
    :param beam_size: beam search size
    :return: same as forward
    """
    batch_size = source_tensor.size(1)
    assert batch_size == 1

    # run encoder
    encoder_init_hidden = torch.zeros(1, batch_size, self.params.hidden_layer_units, device=device)
    encoder_word_embeddings = self.embedding(source_tensor)
    encoder_outputs, encoder_hidden = self.encoder(encoder_word_embeddings, encoder_init_hidden)

    # build a batch of beam size and initialize states
    encoder_outputs = encoder_outputs.expand(-1, beam_size, -1).contiguous()
    decoder_hidden_cur_step = encoder_hidden.expand(-1, beam_size, -1).contiguous()
    be = Beam(beam_size, self.special_tokens)

    step = 0
    while step <= self.params.max_dec_steps:
        decoder_input_cur_step = be.states[-1]
        decoder_cur_word_embedding = self.embedding(decoder_input_cur_step)
        decoder_output_cur_step, decoder_hidden_cur_step = self.decoder(
            decoder_cur_word_embedding, decoder_hidden_cur_step, encoder_outputs)
        if be.advance(decoder_output_cur_step):
            break
        step += 1

    result_tokens = be.trace(0)
    output_token_idx = torch.tensor(result_tokens, device=device, dtype=torch.long).unsqueeze(1).expand(-1, batch_size)
    return output_token_idx, torch.tensor(0., device=device)
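# --- Hedged illustration (not the repo's actual Beam class) --------------------
# beam_search() above relies on a Beam helper with three behaviours: `states`
# holds the token column fed to the decoder at each step, `advance()` consumes
# the step's scores and reports whether search can stop, and `trace()` walks
# back-pointers to recover a hypothesis. Below is a minimal sketch of that
# contract; the constructor arguments, the special_tokens layout, and the
# (1, beam_size, vocab) score shape are assumptions for illustration only.
import torch

class MinimalBeam:
    def __init__(self, beam_size, special_tokens):
        self.beam_size = beam_size
        self.bos, self.eos = special_tokens["bos"], special_tokens["eos"]  # assumed layout
        self.scores = torch.zeros(beam_size)
        self.states = [torch.full((1, beam_size), self.bos, dtype=torch.long)]
        self.backpointers = []

    def advance(self, step_log_probs):
        # step_log_probs: (1, beam_size, vocab) log-probabilities from the decoder
        vocab = step_log_probs.size(-1)
        total = step_log_probs.squeeze(0) + self.scores.unsqueeze(1)  # (beam, vocab)
        if not self.backpointers:  # first step: all beams are identical, keep one row
            total = total[0:1]
        self.scores, flat = total.view(-1).topk(self.beam_size)
        self.backpointers.append(flat // vocab)          # which beam each survivor came from
        self.states.append((flat % vocab).view(1, -1))   # which token was chosen
        return bool((self.states[-1] == self.eos).all()) # stop once every beam emitted EOS

    def trace(self, k):
        # follow back-pointers from beam k to recover its token sequence
        tokens = []
        for step in range(len(self.backpointers), 0, -1):
            tokens.append(int(self.states[step][0, k]))
            k = int(self.backpointers[step - 1][k])
        return list(reversed(tokens))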
def beamsearch(memory, model, device, beam_size=4, candidates=1,
               max_seq_length=128, sos_token=1, eos_token=2):
    # memory: Tx1xE
    model.eval()

    beam = Beam(beam_size=beam_size, min_length=0, n_top=candidates, ranker=None,
                start_token_id=sos_token, end_token_id=eos_token)

    with torch.no_grad():
        # memory = memory.repeat(1, beam_size, 1)  # TxNxE
        memory = model.expand_memory(memory, beam_size)

        for _ in range(max_seq_length):
            tgt_inp = beam.get_current_state().transpose(0, 1).to(device)  # TxN
            decoder_outputs, memory = model.transformer.forward_decoder(tgt_inp, memory)

            log_prob = log_softmax(decoder_outputs[:, -1, :].squeeze(0), dim=-1)
            beam.advance(log_prob.cpu())

            if beam.done():
                break

        scores, ks = beam.sort_finished(minimum=1)

        hypothesises = []
        for i, (times, k) in enumerate(ks[:candidates]):
            hypothesis = beam.get_hypothesis(times, k)
            hypothesises.append(hypothesis)

    # prepend the start token id and drop the final token (assumed to be EOS)
    return [1] + [int(i) for i in hypothesises[0][:-1]]
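# --- Hedged illustration (shapes assumed for clarity) --------------------------
# The commented-out line above shows the generic way to share one encoder memory
# across all beam hypotheses: repeat the batch dimension of a Tx1xE tensor
# beam_size times. model.expand_memory presumably does the equivalent for this
# model; the sketch below only demonstrates the tensor manipulation and the
# per-step extraction of the last-position log-probabilities fed to the beam.
import torch
import torch.nn.functional as F

T, E, beam_size, vocab = 7, 16, 4, 100
memory = torch.randn(T, 1, E)                 # Tx1xE, a single source sequence
memory_beam = memory.repeat(1, beam_size, 1)  # TxNxE with N = beam_size
assert memory_beam.shape == (T, beam_size, E)

# at each step only the distribution at the last target position is needed
decoder_outputs = torch.randn(beam_size, 5, vocab)            # N x tgt_len x vocab (assumed layout)
log_prob = F.log_softmax(decoder_outputs[:, -1, :], dim=-1)   # N x vocab, passed to beam.advance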
def beam_generate(self, b, beamsz, k):
    if self.args.title:
        tencs, _ = self.tenc(b.src)
        tmask = self.maskFromList(tencs.size(), b.src[1]).unsqueeze(1)
    ents = b.ent
    entlens = ents[2]
    ents = self.le(ents)
    if self.graph:
        gents, glob, grels = self.ge(b.rel[0], b.rel[1], (ents, entlens))
        hx = glob
        # hx = ents.max(dim=1)[0]
        keys, mask = grels
        mask = mask == 0
    else:
        mask = self.maskFromList(ents.size(), entlens)
        hx = ents.max(dim=1)[0]
        keys = ents
    mask = mask.unsqueeze(1)
    if self.args.plan:
        planlogits = self.splan.plan_decode(hx, keys, mask.clone(), entlens)
        print(planlogits.size())
        sorder = ' '.join([str(x) for x in planlogits.max(1)[1][0].tolist()])
        print(sorder)
        sorder = [x.strip() for x in sorder.split("-1")]
        sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
        mask.fill_(0)
        planplace = torch.zeros(hx.size(0)).long()
        for i, m in enumerate(sorder):
            mask[i][0][m[0]] = 1
    else:
        planlogits = None

    cx = torch.tensor(hx)
    a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
    if self.args.title:
        a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
        a = torch.cat((a, a2), 1)
    outputs = []
    outp = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).cuda()
    beam = None
    for i in range(self.maxlen):
        op = self.emb_w_vertex(outp.clone(), b.nerd)
        if self.args.plan:
            schange = op == self.args.dottok
            if schange.nonzero().size(0) > 0:
                print(schange, planplace, sorder)
                planplace[schange.nonzero().squeeze()] += 1
                for j in schange.nonzero().squeeze(1):
                    if planplace[j] < len(sorder[j]):
                        mask[j] = 0
                        m = sorder[j][planplace[j]]
                        mask[j][0][sorder[j][planplace[j]]] = 1
        op = self.emb(op).squeeze(1)
        prev = torch.cat((a, op), 1)
        hx, cx = self.lstm(prev, (hx, cx))
        a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
        if self.args.title:
            a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
            # a = a + (self.mix(hx) * a2)
            a = torch.cat((a, a2), 1)
        l = torch.cat((hx, a), 1).unsqueeze(1)
        s = torch.sigmoid(self.switch(l))
        o = self.out(l)
        o = torch.softmax(o, 2)
        o = s * o
        # compute copy attn
        _, z = self.mattn(l, (ents, entlens))
        # z = torch.softmax(z, 2)
        z = (1 - s) * z
        o = torch.cat((o, z), 2)
        o[:, :, 0].fill_(0)
        o[:, :, 1].fill_(0)
        '''
        if beam:
          for p,q in enumerate(beam.getPrevEnt()):
            o[p,:,q].fill_(0)
          for p,q in beam.getIsStart():
            for r in q:
              o[p,:,r].fill_(0)
        '''
        o = o + (1e-6 * torch.ones_like(o))
        decoded = o.log()
        scores, words = decoded.topk(dim=2, k=k)
        if not beam:
            beam = Beam(words.squeeze(), scores.squeeze(),
                        [hx for i in range(beamsz)],
                        [cx for i in range(beamsz)],
                        [a for i in range(beamsz)],
                        beamsz, k, self.args.ntoks)
            beam.endtok = self.endtok
            beam.eostok = self.eostok
            keys = keys.repeat(len(beam.beam), 1, 1)
            mask = mask.repeat(len(beam.beam), 1, 1)
            if self.args.title:
                tencs = tencs.repeat(len(beam.beam), 1, 1)
                tmask = tmask.repeat(len(beam.beam), 1, 1)
            if self.args.plan:
                planplace = planplace.unsqueeze(0).repeat(len(beam.beam), 1)
                sorder = sorder * len(beam.beam)
            ents = ents.repeat(len(beam.beam), 1, 1)
            entlens = entlens.repeat(len(beam.beam))
        else:
            if not beam.update(scores, words, hx, cx, a):
                break
            keys = keys[:len(beam.beam)]
            mask = mask[:len(beam.beam)]
            if self.args.title:
                tencs = tencs[:len(beam.beam)]
                tmask = tmask[:len(beam.beam)]
            if self.args.plan:
                planplace = planplace[:len(beam.beam)]
                sorder = sorder[0] * len(beam.beam)
            ents = ents[:len(beam.beam)]
            entlens = entlens[:len(beam.beam)]
        outp = beam.getwords()
        hx = beam.geth()
        cx = beam.getc()
        a = beam.getlast()

    return beam
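# --- Hedged illustration (not from the source repo) ----------------------------
# The s*o / (1-s)*z lines above implement a copy/pointer switch: a sigmoid gate
# splits probability mass between the target vocabulary and the copyable
# entities, and the two pieces are concatenated into one distribution that the
# beam scores jointly. Sizes below are arbitrary, and the random tensors stand
# in for the model's own switch, output, and copy-attention modules.
import torch

batch, vocab, n_entities = 2, 50, 6
vocab_dist = torch.softmax(torch.randn(batch, 1, vocab), dim=2)      # generator distribution
copy_dist = torch.softmax(torch.randn(batch, 1, n_entities), dim=2)  # attention over entities
gate = torch.sigmoid(torch.randn(batch, 1, 1))                       # p(generate) per step

mixed = torch.cat((gate * vocab_dist, (1 - gate) * copy_dist), dim=2)  # (batch, 1, vocab + n_entities)
assert torch.allclose(mixed.sum(dim=2), torch.ones(batch, 1))          # still a valid distribution
log_scores = (mixed + 1e-6).log()                                      # same epsilon trick as above
scores, words = log_scores.topk(k=4, dim=2)                            # candidates handed to the beam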
def beam_generate(self, batch, beam_size, k):
    if self.args.title:
        title_encoding, _ = self.title_encoder(batch.title)  # (1, title_len, 500)
        title_mask = self.create_mask(title_encoding.size(), batch.title[1]).unsqueeze(1)  # (1, 1, title_len)
    ents = batch.ent
    ent_num_list = ents[2]
    ents = self.entity_encoder(ents)  # ents: (1, entity num, 500) / encoded hidden states of the entities in the batch
    glob, node_embeddings, node_mask = self.graph_encoder(batch.rel[0], batch.rel[1], (ents, ent_num_list))
    hx = glob
    node_mask = node_mask == 0
    node_mask = node_mask.unsqueeze(1)
    cx = hx.clone().detach().requires_grad_(True)  # (beam size, 500)
    context = self.attention_graph(hx.unsqueeze(1), node_embeddings, mask=node_mask).squeeze(1)  # (1, 500) / c_g
    if self.args.title:
        title_context = self.attention_title(hx.unsqueeze(1), title_encoding, mask=title_mask).squeeze(1)  # (1, 500) / c_s
        context = torch.cat((context, title_context), 1)  # (beam size, 1000) / concatenation of c_g and c_s

    recent_token = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).to(self.args.device)  # initially (1, 1) / start token
    beam = None
    for i in range(self.args.maxlen):
        # strip the entity indices from the tags when building the previous token;
        # reason: training also mapped the output vocab (input) to the target vocab (output)
        op = self.trim_entity_index(recent_token.clone(), batch.nerd)
        op = self.embed(op).squeeze(1)  # (beam size, 500)
        prev = torch.cat((context, op), 1)  # (beam size, 1000)
        hx, cx = self.decoder(prev, (hx, cx))
        context = self.attention_graph(hx.unsqueeze(1), node_embeddings, mask=node_mask).squeeze(1)
        if self.args.title:
            title_context = self.attention_title(hx.unsqueeze(1), title_encoding, mask=title_mask).squeeze(1)
            context = torch.cat((context, title_context), 1)  # (beam size, 1000)
        total_context = torch.cat((hx, context), 1).unsqueeze(1)  # (beam size, 1, 1500)
        sampling_prob = torch.sigmoid(self.switch(total_context))  # (beam size, 1, 1)
        out = self.out(total_context)  # (beam size, 1, target vocab size)
        out = torch.softmax(out, 2)
        out = sampling_prob * out

        # compute copy attention
        z = self.mat_attention(total_context, (ents, ent_num_list))
        # z: (1, max_abstract_len, max_entity_num) / entities attended over each abstract word, then softmaxed
        z = (1 - sampling_prob) * z
        out = torch.cat((out, z), 2)  # (beam size, 1, target vocab size + entity num)
        out[:, :, 0].fill_(0)  # remove probability for special tokens <unk>, <init>
        out[:, :, 1].fill_(0)
        out = out + (1e-6 * torch.ones_like(out))
        decoded = out.log()
        scores, words = decoded.topk(dim=2, k=k)  # (beam size, 1, k), (beam size, 1, k)

        if not beam:
            beam = Beam(words.squeeze(), scores.squeeze(), [hx] * beam_size, [cx] * beam_size,
                        [context] * beam_size, beam_size, k, self.args.output_vocab_size, self.args.device)
            beam.endtok = self.endtok
            beam.eostok = self.eostok
            node_embeddings = node_embeddings.repeat(len(beam.beam), 1, 1)  # (beam size, adjacency matrix len, 500)
            node_mask = node_mask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, adjacency matrix len)
            if self.args.title:
                title_encoding = title_encoding.repeat(len(beam.beam), 1, 1)  # (beam size, title_len, 500)
                title_mask = title_mask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, title_len)
            ents = ents.repeat(len(beam.beam), 1, 1)  # (beam size, entity num, 500)
            ent_num_list = ent_num_list.repeat(len(beam.beam))  # (beam size,)
        else:
            # if all beam hypotheses have ended, stop generating
            if not beam.update(scores, words, hx, cx, context):
                break
            # if the beam size changes (i.e. any of the beams ends), resize the cached tensors accordingly
            node_embeddings = node_embeddings[:len(beam.beam)]
            node_mask = node_mask[:len(beam.beam)]
            if self.args.title:
                title_encoding = title_encoding[:len(beam.beam)]
                title_mask = title_mask[:len(beam.beam)]
            ents = ents[:len(beam.beam)]
            ent_num_list = ent_num_list[:len(beam.beam)]
        recent_token = beam.getwords()  # (beam size,) / next word for each beam
        hx = beam.get_h()  # (beam size, 500)
        cx = beam.get_c()  # (beam size, 500)
        context = beam.get_context()  # (beam size, 1000)

    return beam
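# --- Hedged illustration (dimension sizes assumed) -----------------------------
# beam_generate() first tiles every cached encoder tensor to the beam dimension
# and later truncates it whenever finished hypotheses shrink the beam, so the
# decoder always runs on exactly len(beam.beam) rows. The sketch below shows
# only that tiling/truncation pattern on a dummy tensor.
import torch

node_embeddings = torch.randn(1, 9, 500)          # (1, nodes, hidden) before beaming
beam_width = 4
tiled = node_embeddings.repeat(beam_width, 1, 1)  # (beam, nodes, hidden)

surviving = 2                                      # pretend two hypotheses already ended
tiled = tiled[:surviving]                          # keep rows only for live hypotheses
assert tiled.shape == (surviving, 9, 500)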
def beam_sample(self, src, src_mask, segment_ids, src_len, beam_size=1, eval_=False):
    # (1) Run the encoder on the src.
    lengths, indices = torch.sort(src_len, dim=0, descending=True)
    _, ind = torch.sort(indices)
    src = torch.index_select(src, dim=0, index=indices)
    if self.bert_model is None:
        src = src.t()
        batch_size = src.size(1)
    else:
        batch_size = src.size(0)
    contexts, encState = self.encoder(src, src_mask, segment_ids, lengths.tolist())

    # (1b) Initialize for the decoder.
    def var(a):
        return a.clone().detach().requires_grad_(False)

    def rvar(a):
        return var(a.repeat(1, beam_size, 1))

    def bottle(m):
        return m.view(batch_size * beam_size, -1)

    def unbottle(m):
        return m.view(beam_size, batch_size, -1)

    if self.config.sgm.cell == 'lstm':
        decState = (rvar(encState[0]), rvar(encState[1]))
    else:
        decState = rvar(encState)

    beam = [
        Beam(beam_size, n_best=1, cuda=self.use_cuda, length_norm=self.config.sgm.length_norm)
        for __ in range(batch_size)
    ]

    # if self.decoder.attention is not None:
    #     self.decoder.attention.init_context(contexts)
    if self.decoder.attention is not None:
        if self.bert_model is None:
            # Repeat everything beam_size times.
            contexts = rvar(contexts)
            contexts = contexts.transpose(0, 1)
        contexts = var(contexts.repeat(beam_size, 1, 1))
        self.decoder.attention.init_context(context=contexts)

    # (2) Run the decoder to generate sentences, using beam search.
    for i in range(self.config.sgm.max_time_step):
        if all((b.done() for b in beam)):
            break

        # Construct batch x beam_size next words.
        # Get all the pending current beam words and arrange for forward.
        inp = var(torch.stack([b.getCurrentState() for b in beam]).t().contiguous().view(-1))
        # if self.bert_model is None:
        #     inp = var(torch.stack([b.getCurrentState() for b in beam]).t().contiguous().view(-1))
        # else:
        #     inp = var(torch.stack([b.getCurrentState() for b in beam]).contiguous().view(-1))

        # Run one step.
        output, decState, attn = self.decoder(inp, decState)  # decOut: beam x rnn_size

        # (b) Compute a vector of batch*beam word scores.
        output = unbottle(self.log_softmax(output))
        attn = unbottle(attn)  # beam x tgt_vocab

        # (c) Advance each beam and update its decoder state.
        for j, b in enumerate(beam):
            b.advance(output[:, j], attn[:, j])
            if self.config.sgm.cell == 'lstm':
                b.beam_update(decState, j)
            else:
                b.beam_update_gru(decState, j)

    # (3) Package everything up.
    allHyps, allScores, allAttn = [], [], []
    if eval_:
        allWeight = []

    for j in ind:
        b = beam[j]
        n_best = 1
        scores, ks = b.sortFinished(minimum=n_best)
        hyps, attn = [], []
        if eval_:
            weight = []
        for i, (times, k) in enumerate(ks[:n_best]):
            hyp, att = b.getHyp(times, k)
            hyps.append(hyp)
            attn.append(att.max(1)[1])
            if eval_:
                weight.append(att)
        allHyps.append(hyps[0])
        allScores.append(scores[0])
        allAttn.append(attn[0])
        if eval_:
            allWeight.append(weight[0])

    if eval_:
        return allHyps, allAttn, allWeight
    return allHyps, allAttn
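# --- Hedged illustration (shapes assumed for clarity) --------------------------
# beam_sample() keeps batch_size independent Beam objects and packs them into a
# single forward pass by flattening (beam, batch) into one leading dimension.
# bottle/unbottle above are the two directions of that reshape; the sketch below
# just shows the round trip on a dummy score tensor.
import torch

batch_size, beam_size, vocab = 3, 5, 11
scores = torch.randn(beam_size, batch_size, vocab)

bottled = scores.view(batch_size * beam_size, -1)    # what the decoder consumes
unbottled = bottled.view(beam_size, batch_size, -1)  # back to the per-beam layout
assert torch.equal(unbottled, scores)

# each Beam j then reads its own column, as in b.advance(output[:, j], ...)
for j in range(batch_size):
    per_beam_scores = unbottled[:, j]  # (beam_size, vocab)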
def beam_generate(self, b, beamsz, k):
    if self.args.title:
        tencs, _ = self.tenc(b.src)  # (1, title_len, 500)
        tmask = self.maskFromList(tencs.size(), b.src[1]).unsqueeze(1)  # (1, 1, title_len)
    ents = b.ent  # tuple of (ent, phlens, elens)
    entlens = ents[2]
    ents = self.le(ents)  # ents: (1, entity num, 500) / encoded hidden states of the entities in b
    if self.graph:
        gents, glob, grels = self.ge(b.rel[0], b.rel[1], (ents, entlens))
        hx = glob
        # hx = ents.max(dim=1)[0]
        keys, mask = grels
        mask = mask == 0
    else:
        mask = self.maskFromList(ents.size(), entlens)
        hx = ents.max(dim=1)[0]
        keys = ents
    mask = mask.unsqueeze(1)
    if self.args.plan:
        planlogits = self.splan.plan_decode(hx, keys, mask.clone(), entlens)
        print(planlogits.size())
        sorder = ' '.join([str(x) for x in planlogits.max(1)[1][0].tolist()])
        print(sorder)
        sorder = [x.strip() for x in sorder.split("-1")]
        sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
        mask.fill_(0)
        planplace = torch.zeros(hx.size(0)).long()
        for i, m in enumerate(sorder):
            mask[i][0][m[0]] = 1
    else:
        planlogits = None

    cx = torch.tensor(hx)  # (beam size, 500)
    a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)  # (1, 500) / c_g
    if self.args.title:
        a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)  # (1, 500) / c_s
        a = torch.cat((a, a2), 1)  # (beam size, 1000) / c_t
    outp = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).cuda()  # initially (1, 1) / start token
    beam = None
    for i in range(self.maxlen):
        # strip the indices from the tags, because training also mapped not-indexed inputs to indexed targets
        op = self.emb_w_vertex(outp.clone(), b.nerd)
        if self.args.plan:
            schange = op == self.args.dottok
            if schange.nonzero().size(0) > 0:
                print(schange, planplace, sorder)
                planplace[schange.nonzero().squeeze()] += 1
                for j in schange.nonzero().squeeze(1):
                    if planplace[j] < len(sorder[j]):
                        mask[j] = 0
                        m = sorder[j][planplace[j]]
                        mask[j][0][sorder[j][planplace[j]]] = 1
        op = self.emb(op).squeeze(1)  # (beam size, 500)
        prev = torch.cat((a, op), 1)  # (beam size, 1000)
        hx, cx = self.lstm(prev, (hx, cx))
        a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
        if self.args.title:
            a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
            a = torch.cat((a, a2), 1)  # (beam size, 1000)
        l = torch.cat((hx, a), 1).unsqueeze(1)  # (beam size, 1, 1500)
        s = torch.sigmoid(self.switch(l))  # (beam size, 1, 1)
        o = self.out(l)  # (beam size, 1, target vocab size)
        o = torch.softmax(o, 2)
        o = s * o
        # compute copy attn
        _, z = self.mattn(l, (ents, entlens))
        # z = torch.softmax(z, 2)
        z = (1 - s) * z
        o = torch.cat((o, z), 2)  # (beam size, 1, target vocab size + entity num)
        o[:, :, 0].fill_(0)  # remove probability for special tokens <unk>, <init>
        o[:, :, 1].fill_(0)
        '''
        if beam:
          for p,q in enumerate(beam.getPrevEnt()):
            o[p,:,q].fill_(0)
          for p,q in beam.getIsStart():
            for r in q:
              o[p,:,r].fill_(0)
        '''
        o = o + (1e-6 * torch.ones_like(o))
        decoded = o.log()
        scores, words = decoded.topk(dim=2, k=k)  # (beam size, 1, k), (beam size, 1, k)
        if not beam:
            beam = Beam(words.squeeze(), scores.squeeze(),
                        [hx for i in range(beamsz)],
                        [cx for i in range(beamsz)],
                        [a for i in range(beamsz)],
                        beamsz, k, self.args.ntoks)
            beam.endtok = self.endtok
            beam.eostok = self.eostok
            keys = keys.repeat(len(beam.beam), 1, 1)  # (beam size, adjacency matrix len, 500)
            mask = mask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, adjacency matrix len)
            if self.args.title:
                tencs = tencs.repeat(len(beam.beam), 1, 1)  # (beam size, title_len, 500)
                tmask = tmask.repeat(len(beam.beam), 1, 1)  # (beam size, 1, title_len)
            if self.args.plan:
                planplace = planplace.unsqueeze(0).repeat(len(beam.beam), 1)
                sorder = sorder * len(beam.beam)
            ents = ents.repeat(len(beam.beam), 1, 1)  # (beam size, entity num, 500)
            entlens = entlens.repeat(len(beam.beam))  # (beam size,)
        else:
            if not beam.update(scores, words, hx, cx, a):
                break
            # if beam size changes (i.e. any of the beams ends), resize the cached tensors accordingly
            keys = keys[:len(beam.beam)]
            mask = mask[:len(beam.beam)]
            if self.args.title:
                tencs = tencs[:len(beam.beam)]
                tmask = tmask[:len(beam.beam)]
            if self.args.plan:
                planplace = planplace[:len(beam.beam)]
                sorder = sorder[0] * len(beam.beam)
            ents = ents[:len(beam.beam)]
            entlens = entlens[:len(beam.beam)]
        outp = beam.getwords()  # (beam size,) / next word for each beam
        hx = beam.geth()  # (beam size, 500)
        cx = beam.getc()  # (beam size, 500)
        a = beam.getlast()  # (beam size, 1000)

    return beam
def forward(self, inputs=None, listener_hidden=None, listener_outputs=None,
            function=F.log_softmax, teacher_forcing_ratio=0.99):
    y_hats, logit = None, None
    decode_results = []

    # Read batch size and maximum decode length from the inputs
    batch_size = inputs.size(0)
    max_length = inputs.size(1) - 1  # minus the start of sequence symbol

    # Initialize the speller hidden state with uniform noise : LxBxH
    speller_hidden = torch.FloatTensor(self.layer_size, batch_size, self.hidden_size).uniform_(-1.0, 1.0)  # .cuda()

    # Decide whether to use teacher forcing
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if self.use_beam_search:
        # Beam-search decoding
        speller_input = inputs[:, 0].unsqueeze(1)
        beam = Beam(k=self.k, speller_hidden=speller_hidden, decoder=self,
                    batch_size=batch_size, max_len=max_length, decode_func=function)
        y_hats = beam.search(speller_input, listener_outputs)
    else:
        # Manual unrolling is used to support random teacher forcing.
        # If teacher_forcing_ratio is True or False instead of a probability, the unrolling can be done in graph.
        if use_teacher_forcing:
            speller_input = inputs[:, :-1]  # except </s>
            # with teacher forcing, infer all steps at once
            predicted_softmax = self._forward_step(speller_input, speller_hidden,
                                                   listener_outputs, function=function)
            # extract the output of each step
            for di in range(predicted_softmax.size(1)):
                step_output = predicted_softmax[:, di, :]
                decode_results.append(step_output)
        else:
            speller_input = inputs[:, 0].unsqueeze(1)
            for di in range(max_length):
                predicted_softmax = self._forward_step(speller_input, speller_hidden,
                                                       listener_outputs, function=function)
                step_output = predicted_softmax.squeeze(1)  # (batch_size, classification_num)
                decode_results.append(step_output)
                speller_input = decode_results[-1].topk(1)[1]

        logit = torch.stack(decode_results, dim=1).to(self.device)
        y_hats = logit.max(-1)[1]

    print("Speller y_hats ====================")
    print(y_hats)
    # during training also return the log-probabilities; at inference return only the predictions
    return (y_hats, logit) if self.training else y_hats
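# --- Hedged illustration (not from the source repo) ----------------------------
# In training mode the speller above returns step-wise log-probabilities next to
# the greedy picks; a typical consumer would align them with the target sequence
# shifted past the start symbol and apply NLL loss (log_softmax was already
# applied inside the decoder). The function name, argument layout, and pad id
# below are assumptions for illustration only.
import torch
import torch.nn.functional as F

def speller_loss(logit, targets, pad_id=0):
    # logit: (batch, max_length, vocab) log-probabilities
    # targets: (batch, max_length + 1) including the start-of-sequence symbol
    gold = targets[:, 1:]  # drop <sos>, align with the decoder outputs
    return F.nll_loss(logit.reshape(-1, logit.size(-1)),
                      gold.reshape(-1),
                      ignore_index=pad_id)

# usage sketch: y_hats, logit = speller(inputs, listener_outputs=enc); loss = speller_loss(logit, inputs)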