def hierarchical_attention(query, key, value, v, hierarchical_key,
                           hierarchical_v, hierarchical_align,
                           hierarchical_bias=None, bias=None, mem_mask=None,
                           hierarchical_length=None, max_para_num=0):
    """ query: [(Bs,) B, D], key: [B, T, D], value: [B, T, D]"""
    if query.dim() == 2:
        # node-level and paragraph-level attention scores
        score_node = badanau_attention_score(query, key, v, bias)
        score_para = badanau_attention_score(
            query, hierarchical_key, hierarchical_v, hierarchical_bias)
        hierarchical_mask = len_mask(
            hierarchical_length, score_para.device).unsqueeze(-2)
        norm_score_para = prob_normalize(score_para, hierarchical_mask)
        norm_score_node = prob_normalize(score_node, mem_mask)
        # look up, for every node, the weight of the paragraph it belongs to
        nq = score_para.size(1)
        hierarchical_align = hierarchical_align.unsqueeze(1).repeat(1, nq, 1)
        score_para_node = norm_score_para.gather(2, hierarchical_align)
        # rescale node weights by their paragraph weights and renormalize
        norm_score_node = torch.mul(norm_score_node, score_para_node)
        norm_score = norm_score_node / norm_score_node.sum(
            dim=-1).unsqueeze(-1)
        output = attention_aggregate(value, norm_score)
    elif query.dim() == 3:  # for batch (beam) decoding
        score_node = badanau_attention_score(query, key, v, bias)
        score_para = badanau_attention_score(
            query, hierarchical_key, hierarchical_v, hierarchical_bias)
        hierarchical_mask = len_mask(
            hierarchical_length, score_para.device,
            max_num=max_para_num).unsqueeze(-2).unsqueeze(0).expand_as(
                score_para)
        norm_score_para = prob_normalize(score_para, hierarchical_mask)
        norm_score_node = prob_normalize(
            score_node, mem_mask.unsqueeze(0).expand_as(score_node))
        nq = score_para.size(2)
        beam = score_para.size(0)
        hierarchical_align = hierarchical_align.unsqueeze(1).unsqueeze(
            0).repeat(beam, 1, nq, 1)
        score_para_node = norm_score_para.gather(-1, hierarchical_align)
        norm_score_node = torch.mul(norm_score_node, score_para_node)
        norm_score = norm_score_node / norm_score_node.sum(
            dim=-1).unsqueeze(-1)
        output = attention_aggregate(value, norm_score)
    else:
        raise ValueError('query must be 2- or 3-dimensional')
    return output.squeeze(-2), norm_score.squeeze(-2)
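# NOTE: the helpers used above (badanau_attention_score, len_mask,
# prob_normalize, attention_aggregate) are defined elsewhere in this repo.
# The sketch below is an *assumed* minimal version of each contract --
# additive (Bahdanau) scoring, length masking, masked softmax, and
# weighted-sum aggregation -- included only to document the tensor shapes
# hierarchical_attention relies on; it is not the repo's actual
# implementation.
import torch
import torch.nn.functional as F

def _sketch_attention_score(query, key, v, bias=None):
    """Additive attention: score[b, q, t] = v . tanh(query[b, q] + key[b, t])."""
    # query: [B, Nq, D], key: [B, T, D], v: [D] -> score: [B, Nq, T]
    summed = query.unsqueeze(2) + key.unsqueeze(1)
    if bias is not None:
        summed = summed + bias
    return torch.matmul(torch.tanh(summed), v)

def _sketch_len_mask(lens, device, max_num=None):
    """mask[b, t] = 1.0 if t < lens[b] else 0.0."""
    max_num = max_num if max_num else max(lens)
    ar = torch.arange(max_num, device=device).unsqueeze(0)
    return (ar < torch.tensor(lens, device=device).unsqueeze(1)).float()

def _sketch_prob_normalize(score, mask):
    """Softmax over the last dim with masked positions forced to ~0."""
    return F.softmax(score.masked_fill(mask == 0, -1e18), dim=-1)

def _sketch_attention_aggregate(value, score):
    """Weighted sum of values: [B, Nq, T] x [B, T, D] -> [B, Nq, D]."""
    return torch.matmul(score, value)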
def sample(self, article, art_lens, extend_art, extend_vsize,
           go, eos, unk, max_len):
    """ sampling decode; supports batching"""
    batch_size = len(art_lens)
    vsize = self._embedding.num_embeddings
    attention, init_dec_states = self.encode(article, art_lens)
    mask = len_mask(art_lens, attention.device).unsqueeze(-2)
    attention = (attention, mask, extend_art, extend_vsize)
    tok = torch.LongTensor([go] * batch_size).to(article.device)
    outputs = []
    attns = []
    states = init_dec_states
    seqLogProbs = []
    for i in range(max_len):
        tok, states, attn_score, sampleProb = self._decoder.sample_step(
            tok, states, attention)
        # zero out tokens of sequences that have already produced EOS
        if i == 0:
            unfinished = (tok != eos)
            outputs.append(tok[:, 0].clone())
        else:
            it = tok * unfinished.type_as(tok)
            unfinished = unfinished * (it != eos)
            outputs.append(it[:, 0].clone())
        attns.append(attn_score.detach())
        tok = tok.masked_fill(tok >= vsize, unk)  # map copied OOV ids back to UNK
        seqLogProbs.append(sampleProb)
        if unfinished.sum().item() == 0:
            break
    return outputs, attns, seqLogProbs
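# NOTE: a standalone illustration (not repo code) of the `unfinished`
# bookkeeping used by sample()/greedy() above: once a sequence emits EOS,
# its row of the batch is multiplied to zero so later steps are ignored.
# The EOS id of 3 is an arbitrary assumption for the demo.
import torch

def _demo_unfinished_mask(step_tokens, eos=3):
    outputs = []
    unfinished = None
    for i, tok in enumerate(step_tokens):        # tok: [B] LongTensor
        if i == 0:
            unfinished = (tok != eos)
        else:
            tok = tok * unfinished.type_as(tok)  # zero out finished rows
            unfinished = unfinished * (tok != eos)
        outputs.append(tok.clone())
        if unfinished.sum().item() == 0:         # every sequence finished
            break
    return outputs

# e.g. the first sequence stops at step 1, the second at step 2:
# _demo_unfinished_mask([torch.tensor([5, 7]), torch.tensor([3, 8]),
#                        torch.tensor([9, 3])]) -> [[5, 7], [3, 8], [0, 3]]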
def forward(self, article, art_lens, abstract, extend_art, extend_vsize):
    attention, init_dec_states = self.encode(article, art_lens)
    mask = len_mask(art_lens, attention.device).unsqueeze(-2)
    logit = self._decoder(
        (attention, mask, extend_art, extend_vsize),
        init_dec_states, abstract
    )
    return logit
def forward(self, nodes, mask=None, _input=None, _sents=None, sent_nums=None):
    if self._type == 'gold':
        # use the provided gold mask directly
        assert mask is not None
        nodes = mask.unsqueeze(2) * nodes
        return nodes, mask
    elif self._type == 'none':
        # no masking: every node is kept
        bs, ns, fs = nodes.size()
        mask = torch.ones(bs, ns, device=nodes.device)
        return nodes, mask
    elif self._type == 'soft':
        # learned soft gate over nodes
        assert _input is not None
        mask = torch.sigmoid(self._mask(_input))  # B * N * 1
        nodes = nodes * mask
        return nodes, mask
    elif self._type == 'soft+sent':
        # soft gate informed by bilinear node-to-sentence attention
        assert _input is not None and _sents is not None
        noden = _input.size(1)
        _nodes_feat = torch.matmul(_input, self._bi_attn.unsqueeze(0))
        attention = torch.matmul(_nodes_feat, _sents.permute(0, 2, 1))
        if sent_nums is not None:
            # mask out padded sentences before the softmax
            sent_mask = len_mask(sent_nums, _input.device).unsqueeze(1).repeat(
                1, noden, 1)
            attention = attention.masked_fill(sent_mask == 0, -1e18)
        score = F.softmax(attention, dim=-1)
        weights = torch.matmul(score, _sents)
        output = torch.matmul(
            weights, self._wc.unsqueeze(0)) + torch.matmul(
                _input, self._wh.unsqueeze(0))  # B * N * emb_dim
        mask = torch.sigmoid(torch.matmul(
            torch.tanh(output), self._v.unsqueeze(0).unsqueeze(2)))
        nodes = nodes * mask
        return nodes, mask
    else:
        raise NotImplementedError(
            'mask type {} is not implemented'.format(self._type))
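# NOTE: a standalone sketch (random weights, made-up dimensions) of the
# tensor algebra in the 'soft+sent' branch above: a bilinear
# node-to-sentence attention pools sentence features for every node, and a
# learned sigmoid gate is computed from the pooled and original features.
# The w_* tensors stand in for self._bi_attn / self._wc / self._wh /
# self._v and are assumptions, not the trained parameters.
import torch

def _sketch_soft_sent_gate(nodes_feat, sents, d_feat=64, d_emb=64):
    # nodes_feat: [B, N, d_feat], sents: [B, S, d_feat]
    w_bi = torch.randn(d_feat, d_feat)
    w_c = torch.randn(d_feat, d_emb)
    w_h = torch.randn(d_feat, d_emb)
    v = torch.randn(d_emb)
    attn = torch.matmul(torch.matmul(nodes_feat, w_bi),
                        sents.permute(0, 2, 1))            # [B, N, S]
    score = torch.softmax(attn, dim=-1)
    weights = torch.matmul(score, sents)                   # [B, N, d_feat]
    output = torch.matmul(weights, w_c) + torch.matmul(nodes_feat, w_h)
    gate = torch.sigmoid(torch.matmul(
        torch.tanh(output), v.unsqueeze(0).unsqueeze(2)))  # [B, N, 1]
    return gate

# _sketch_soft_sent_gate(torch.randn(2, 5, 64), torch.randn(2, 9, 64))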
def greedy(self, article, art_lens, extend_art, extend_vsize,
           go, eos, unk, max_len, min_len=0, tar_in=None):
    """ greedy decode; supports batching"""
    batch_size = len(art_lens)
    vsize = self._embedding.num_embeddings
    attention, init_dec_states = self.encode(article, art_lens)
    mask = len_mask(art_lens, attention.device).unsqueeze(-2)
    attention = (attention, mask, extend_art, extend_vsize)
    tok = torch.LongTensor([go] * batch_size).to(article.device)
    outputs = []
    attns = []
    states = init_dec_states
    for i in range(max_len):
        # suppress EOS until min_len tokens have been produced
        force_not_stop = (i < min_len)
        tok, states, attn_score = self._decoder.decode_step(
            tok, states, attention, force_not_stop=force_not_stop,
            eos=self._eos)
        # zero out tokens of sequences that have already produced EOS
        if i == 0:
            unfinished = (tok != eos)
            outputs.append(tok[:, 0].clone())
        else:
            it = tok * unfinished.type_as(tok)
            unfinished = unfinished * (it != eos)
            outputs.append(it[:, 0].clone())
        attns.append(attn_score)
        tok.masked_fill_(tok >= vsize, unk)  # map copied OOV ids back to UNK
        if unfinished.sum().item() == 0:
            break
    return outputs, attns
def batch_decode(self, article, art_lens, extend_art, extend_vsize,
                 go, eos, unk, max_len):
    """ greedy decode; supports batching (always runs for max_len steps)"""
    batch_size = len(art_lens)
    vsize = self._embedding.num_embeddings
    attention, init_dec_states = self.encode(article, art_lens)
    mask = len_mask(art_lens, attention.device).unsqueeze(-2)
    attention = (attention, mask, extend_art, extend_vsize)
    tok = torch.LongTensor([go] * batch_size).to(article.device)
    outputs = []
    attns = []
    states = init_dec_states
    for i in range(max_len):
        tok, states, attn_score = self._decoder.decode_step(
            tok, states, attention)
        attns.append(attn_score)
        outputs.append(tok[:, 0].clone())
        tok.masked_fill_(tok >= vsize, unk)  # map copied OOV ids back to UNK
    return outputs, attns
def batched_beamsearch_cnn(self, article, art_lens,
                           extend_art, extend_vsize,
                           go, eos, unk, max_len, beam_size,
                           diverse=1.0, min_len=35):
    batch_size = len(art_lens)
    vsize = self._embedding.num_embeddings
    attention, init_dec_states = self.encode(article, art_lens)
    mask = len_mask(art_lens, attention.device).unsqueeze(-2)
    all_attention = (attention, mask, extend_art, extend_vsize)
    attention = all_attention
    (h, c), prev = init_dec_states
    all_beams = [bs.init_beam(go, (h[:, i, :], c[:, i, :], prev[i]))
                 for i in range(batch_size)]
    finished_beams = [[] for _ in range(batch_size)]
    outputs = [None for _ in range(batch_size)]
    for t in range(max_len):
        # flatten the live beams of all unfinished articles into one batch
        toks = []
        all_states = []
        for beam in filter(bool, all_beams):
            token, states = bs.pack_beam(beam, article.device)
            toks.append(token)
            all_states.append(states)
        token = torch.stack(toks, dim=1)
        states = ((torch.stack([h for (h, _), _ in all_states], dim=2),
                   torch.stack([c for (_, c), _ in all_states], dim=2)),
                  torch.stack([prev for _, prev in all_states], dim=1))
        token.masked_fill_(token >= vsize, unk)

        # suppress EOS until min_len tokens have been produced
        force_not_stop = (t < min_len)
        topk, lp, states, attn_score = self._decoder.topk_step(
            token, states, attention, beam_size,
            force_not_stop=force_not_stop)

        batch_i = 0
        for i, (beam, finished) in enumerate(zip(all_beams,
                                                 finished_beams)):
            if not beam:
                continue
            finished, new_beam = bs.next_search_beam_cnn(
                beam, beam_size, finished, eos,
                topk[:, batch_i, :], lp[:, batch_i, :],
                (states[0][0][:, :, batch_i, :],
                 states[0][1][:, :, batch_i, :],
                 states[1][:, batch_i, :]),
                attn_score[:, batch_i, :],
                diverse
            )
            batch_i += 1
            if len(finished) >= beam_size:
                all_beams[i] = []
                outputs[i] = finished[:beam_size]
                # exclude finished inputs: shrink the cached attention
                # tensors to the still-unfinished articles
                (attention, mask, extend_art, extend_vsize
                 ) = all_attention
                masks = [mask[j] for j, o in enumerate(outputs)
                         if o is None]
                ind = [j for j, o in enumerate(outputs) if o is None]
                ind = torch.LongTensor(ind).to(attention.device)
                attention, extend_art = map(
                    lambda v: v.index_select(dim=0, index=ind),
                    [attention, extend_art]
                )
                if masks:
                    mask = torch.stack(masks, dim=0)
                else:
                    mask = None
                attention = (
                    attention, mask, extend_art, extend_vsize)
            else:
                all_beams[i] = new_beam
                finished_beams[i] = finished
        if all(outputs):  # every article has a finished beam set
            break
    else:
        # max_len reached without break: fill in any article that never
        # finished with its best (possibly unfinished) hypotheses
        for i, (o, f, b) in enumerate(zip(outputs,
                                          finished_beams, all_beams)):
            if o is None:
                outputs[i] = (f + b)[:beam_size]
    return outputs
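# NOTE: a standalone illustration (dummy tensors, not the model's caches)
# of the batch-shrinking step in batched_beamsearch_cnn above: once an
# article's beam finishes, index_select drops its rows from the cached
# attention tensors so later steps only decode the unfinished articles.
import torch

def _demo_prune_finished():
    attention = torch.randn(4, 7, 16)            # [batch, T, D]
    outputs = [None, ['done'], None, ['done']]   # articles 1 and 3 finished
    ind = torch.LongTensor([j for j, o in enumerate(outputs) if o is None])
    return attention.index_select(dim=0, index=ind)  # now [2, 7, 16]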
def forward(self, raw_article_sents):
    # for RL abstractor training
    dec_args, id2word = self._prepro(raw_article_sents)
    (article, art_lens, extend_art, extend_vsize,
     go, eos, unk, max_len) = dec_args

    def argmax(arr, keys):
        return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

    # ----- greedy decode -----
    with torch.no_grad():
        self._net.eval()
        decs, attns = self._net.batch_decode(*dec_args)
    dec_sents_greedy = []
    for i, raw_words in enumerate(raw_article_sents):
        dec = []
        for id_, attn in zip(decs, attns):
            if id_[i] == END:
                break
            elif id_[i] == UNK:
                # replace UNK with the most-attended source word
                dec.append(argmax(raw_words, attn[i]))
            else:
                dec.append(id2word[id_[i].item()])
        dec_sents_greedy.append(dec)

    # ----- batch sampling -----
    if self.training:
        self._net.train()
        batch_size = len(art_lens)
        vsize = self._net._embedding.num_embeddings
        attention, init_dec_states = self._net.encode(article, art_lens)
        mask = len_mask(art_lens, attention.device).unsqueeze(-2)
        attention = (attention, mask, extend_art, extend_vsize)
        lstm_in = torch.LongTensor([go] * batch_size).to(article.device)
        outputs = []
        attns = []
        dists = []
        states = init_dec_states
        for i in range(max_len):
            tok, states, attn_score, logit = self._net._decoder.decode_step(
                lstm_in, states, attention, return_logit=True)
            prob = F.softmax(logit, dim=-1)
            out = torch.multinomial(prob, 1).detach()
            # track which sequences are still running (no EOS yet)
            if i == 0:
                flag = (out != eos)
            else:
                flag = flag * (out != eos)
            dist = torch.log(prob.gather(1, out))  # log-prob of the sample
            dists.append(dist)
            attns.append(attn_score)
            outputs.append(out[:, 0].clone())
            if flag.sum().item() == 0:
                break
            lstm_in = out.masked_fill(out >= vsize, unk)
        output_word_batch = []
        dist_batch = []
        for i, raw_words in enumerate(raw_article_sents):
            words = []
            diss = []
            for id_, attn, dis in zip(outputs, attns, dists):
                diss.append(dis[i:i + 1])
                if id_[i] == END:
                    break
                elif id_[i] == UNK:
                    words.append(argmax(raw_words, attn[i]))
                else:
                    words.append(id2word[id_[i].item()])
            output_word_batch.append(words)
            dist_batch.append(sum(diss))  # log-prob of the sampled sequence

    if self.training:
        return dec_sents_greedy, output_word_batch, dist_batch
    else:
        return dec_sents_greedy
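# NOTE: a self-contained sketch (random logits, not repo code) of the
# score-function bookkeeping used in the sampling branch above: sample a
# token with torch.multinomial, keep log p(token) via gather, and sum over
# steps to get a differentiable log-probability of the whole rollout that
# a reward can later scale.
import torch
import torch.nn.functional as F

def _sketch_sample_step(logit):
    # logit: [B, V]
    prob = F.softmax(logit, dim=-1)
    out = torch.multinomial(prob, 1).detach()   # sampled ids, [B, 1]
    log_p = torch.log(prob.gather(1, out))      # [B, 1], keeps grad to logit
    return out, log_p

# tokens, log_p = _sketch_sample_step(torch.randn(2, 10, requires_grad=True))
# summing log_p over decode steps and multiplying by a (negated) reward
# gives the usual policy-gradient loss.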