Exemplo n.º 1
0
def hierarchical_attention(query,
                           key,
                           value,
                           v,
                           hierarchical_key,
                           hierarchical_v,
                           hierarchical_align,
                           hierarchical_bias=None,
                           bias=None,
                           mem_mask=None,
                           hierarchical_length=None,
                           max_para_num=0):
    """Two-level attention: node scores re-weighted by their paragraph's score.

    Node-level scores (``query`` vs ``key``) and paragraph-level scores
    (``query`` vs ``hierarchical_key``) are computed with
    ``badanau_attention_score`` and normalised with ``prob_normalize``. Each
    is normalised under its own mask. Every node's probability is then
    multiplied by the probability of the paragraph it belongs to (looked up
    via ``hierarchical_align``). The product is renormalised over the last
    dim and used to aggregate ``value`` with ``attention_aggregate``.

    Shapes as noted by the original author: query[(Bs), B, D],
    key[B, T, D], value[B, T, D]. A rank-3 query is the batched (beam)
    decoding case. There, dim 0 of the paragraph scores is treated as the
    beam.

    Args:
        query: rank-2 or rank-3 query tensor.
        key, value: node-level keys and values.
        v, bias: parameters forwarded to the node-level score function.
        hierarchical_key, hierarchical_v, hierarchical_bias: the same, for
            the paragraph-level score.
        hierarchical_align: integer index of each node's paragraph. It is
            used as the ``gather`` index into the paragraph distribution.
        mem_mask: node mask passed to ``prob_normalize``.
        hierarchical_length: paragraph counts used to build the paragraph
            mask with ``len_mask``.
        max_para_num: forwarded to ``len_mask`` as ``max_num``. It is only
            used for a rank-3 query.

    Returns:
        ``(output, norm_score)``, each with dim -2 squeezed out.

    Raises:
        ValueError: if ``query`` is neither rank 2 nor rank 3.
    """
    query_rank = len(query.size())
    if query_rank not in (2, 3):
        # Without this guard an unsupported rank fell through both branches
        # and failed later with an UnboundLocalError on ``output``.
        raise ValueError(
            'hierarchical_attention expects a rank-2 or rank-3 query, '
            'got rank {}'.format(query_rank))

    score_node = badanau_attention_score(query, key, v, bias)
    score_para = badanau_attention_score(query, hierarchical_key,
                                         hierarchical_v, hierarchical_bias)
    if query_rank == 2:
        hierarchical_mask = len_mask(hierarchical_length,
                                     score_para.device).unsqueeze(-2)
        node_mask = mem_mask
        nq = score_para.size(1)
        # Repeat each node's paragraph index for every query position so it
        # can index the paragraph distribution with ``gather``.
        para_index = hierarchical_align.unsqueeze(1).repeat(1, nq, 1)
        gather_dim = 2
    else:
        # Batched decoding: broadcast a leading beam dim onto the masks and
        # onto the node->paragraph alignment.
        hierarchical_mask = len_mask(hierarchical_length,
                                     score_para.device,
                                     max_num=max_para_num).unsqueeze(
                                         -2).unsqueeze(0).expand_as(score_para)
        node_mask = mem_mask.unsqueeze(0).expand_as(score_node)
        nq = score_para.size(2)
        beam = score_para.size(0)
        para_index = hierarchical_align.unsqueeze(1).unsqueeze(0).repeat(
            beam, 1, nq, 1)
        gather_dim = -1

    norm_score_para = prob_normalize(score_para, hierarchical_mask)
    norm_score_node = prob_normalize(score_node, node_mask)
    # Probability of each node's paragraph, laid out along the node axis.
    score_para_node = norm_score_para.gather(gather_dim, para_index)
    norm_score_node = torch.mul(norm_score_node, score_para_node)
    # Renormalise over nodes (yields NaN if a row's product is all zeros).
    norm_score = norm_score_node / norm_score_node.sum(dim=-1).unsqueeze(-1)
    output = attention_aggregate(value, norm_score)

    return output.squeeze(-2), norm_score.squeeze(-2)
Exemplo n.º 2
0
 def sample(self, article, art_lens, extend_art, extend_vsize,
                  go, eos, unk, max_len):
     """Decode a batch token by token using ``self._decoder.sample_step``.

     Tokens come from ``sample_step`` (presumably sampled rather than
     argmax-selected; confirm against the decoder). Decoding stops after
     ``max_len`` steps, or earlier once every sequence has produced ``eos``.

     Args:
         article, art_lens: source ids and lengths passed to ``self.encode``.
         extend_art, extend_vsize: extended-vocabulary source ids and size,
             bundled into the decoder's attention tuple.
         go: start token id fed at the first step.
         eos: end token id; a sequence is finished once it emits it.
         unk: id substituted for out-of-vocabulary ids before they are fed
             back to the decoder.
         max_len: maximum number of decoding steps.

     Returns:
         ``(outputs, attns, seq_log_probs)``, with one entry per step:
         - ``outputs``: ids taken from column 0 of each step's token. Ids
           keep their extended-vocabulary values. After a sequence has
           emitted ``eos``, its later entries are 0.
         - ``attns``: detached attention scores.
         - ``seq_log_probs``: the fourth value returned by ``sample_step``
           (presumably the sampled tokens' log-probabilities).
     """
     batch_size = len(art_lens)
     vsize = self._embedding.num_embeddings
     attention, init_dec_states = self.encode(article, art_lens)
     mask = len_mask(art_lens, attention.device).unsqueeze(-2)
     attention = (attention, mask, extend_art, extend_vsize)
     tok = torch.LongTensor([go] * batch_size).to(article.device)
     outputs = []
     attns = []
     seq_log_probs = []
     states = init_dec_states
     for step in range(max_len):
         tok, states, attn_score, sample_prob = self._decoder.sample_step(
             tok, states, attention)
         if step == 0:
             emitted = tok
             unfinished = tok != eos
         else:
             # Zero out tokens of sequences that already emitted eos.
             emitted = tok * unfinished.type_as(tok)
             unfinished = unfinished * (emitted != eos)
         attns.append(attn_score.detach())
         outputs.append(emitted[:, 0].clone())
         # Ids beyond the embedding table cannot be fed back; use unk.
         tok = tok.masked_fill(tok >= vsize, unk)
         seq_log_probs.append(sample_prob)
         if not unfinished.any():
             break
     return outputs, attns, seq_log_probs
Exemplo n.º 3
0
 def forward(self, article, art_lens, abstract, extend_art, extend_vsize):
     """Encode ``article`` and run the decoder over ``abstract``.

     Args:
         article: source token ids, encoded by ``self.encode``.
         art_lens: source lengths, used to build the padding mask.
         abstract: decoder input sequence (presumably the reference summary
             for teacher forcing; TODO confirm).
         extend_art, extend_vsize: extended-vocabulary source ids and size,
             passed to the decoder inside the attention tuple.

     Returns:
         Whatever ``self._decoder`` returns for these inputs (logits).
     """
     attention, init_dec_states = self.encode(article, art_lens)
     mask = len_mask(art_lens, attention.device).unsqueeze(-2)
     logit = self._decoder(
         (attention, mask, extend_art, extend_vsize),
         init_dec_states, abstract
     )
     return logit
Exemplo n.º 4
0
    def forward(self,
                nodes,
                mask=None,
                _input=None,
                _sents=None,
                sent_nums=None):
        """Gate ``nodes`` with a mask chosen by ``self._type``.

        Modes:
            'gold': multiply ``nodes`` by the supplied ``mask`` (unsqueezed
                on dim 2) and return that mask unchanged.
            'none': leave ``nodes`` untouched and return an all-ones mask of
                shape ``(bs, ns)`` on the same device as ``nodes``.
            'soft': mask = sigmoid(self._mask(_input)); nodes are scaled by it.
            'soft+sent': every node attends over the sentence features
                ``_sents`` through a bilinear score with ``self._bi_attn``.
                Sentences beyond ``sent_nums`` are masked out when that is
                given. The attended summary and ``_input`` are projected with
                ``self._wc`` / ``self._wh``, passed through tanh and scored by
                ``self._v`` into a sigmoid gate.

        Returns:
            ``(nodes, mask)`` after gating.

        Raises:
            AssertionError: an input required by the chosen mode is missing.
            NotImplementedError: ``self._type`` is not one of the modes above.
        """
        if self._type == 'gold':
            assert mask is not None
            nodes = mask.unsqueeze(2) * nodes
            return nodes, mask
        elif self._type == 'none':
            bs, ns, _ = nodes.size()
            # Create the mask on nodes' device; a CPU-only mask breaks
            # downstream ops when nodes live on a GPU.
            mask = torch.ones(bs, ns, device=nodes.device)
            return nodes, mask
        elif self._type == 'soft':
            assert _input is not None
            mask = torch.sigmoid(self._mask(_input))  # B * N * 1
            nodes = nodes * mask
            return nodes, mask
        elif self._type == 'soft+sent':
            assert _input is not None and _sents is not None
            noden = _input.size(1)
            # Bilinear node/sentence scores: (_input @ W_bi) @ _sents^T.
            _nodes_feat = torch.matmul(_input, self._bi_attn.unsqueeze(0))
            attention = torch.matmul(_nodes_feat, _sents.permute(0, 2, 1))
            if sent_nums is not None:
                # Push padded sentence positions to ~-inf before the softmax.
                sent_mask = len_mask(sent_nums,
                                     _input.device).unsqueeze(1).repeat(
                                         1, noden, 1)
                attention = attention.masked_fill(sent_mask == 0, -1e18)
            score = F.softmax(attention, dim=-1)
            weights = torch.matmul(score, _sents)
            output = torch.matmul(
                weights, self._wc.unsqueeze(0)) + torch.matmul(
                    _input, self._wh.unsqueeze(0))  # B * N * emb_dim
            mask = torch.sigmoid(
                torch.matmul(torch.tanh(output),
                             self._v.unsqueeze(0).unsqueeze(2)))
            nodes = nodes * mask
            return nodes, mask
        else:
            raise NotImplementedError('Not Implemented yet')
Exemplo n.º 5
0
 def greedy(self,
            article,
            art_lens,
            extend_art,
            extend_vsize,
            go,
            eos,
            unk,
            max_len,
            min_len=0,
            tar_in=None):
     """Greedy-decode a batch of articles, advancing all rows together.

     For the first ``min_len`` steps the decoder is called with
     ``force_not_stop=True``. The loop ends after ``max_len`` steps, or as
     soon as every row has emitted ``eos``. ``tar_in`` is accepted but not
     used.

     NOTE(review): the decoder receives ``eos=self._eos`` while the stopping
     check below uses the ``eos`` argument; confirm they always agree.

     Returns:
         ``(outputs, attns)``, with one entry per step. ``outputs`` holds the
         column-0 ids of each step (extended-vocabulary ids preserved; 0 for
         rows that already finished). ``attns`` holds the raw attention scores.
     """
     batch = len(art_lens)
     vocab_size = self._embedding.num_embeddings
     memory, states = self.encode(article, art_lens)
     src_mask = len_mask(art_lens, memory.device).unsqueeze(-2)
     attention = (memory, src_mask, extend_art, extend_vsize)
     tok = torch.LongTensor([go] * batch).to(article.device)
     outputs, attns = [], []
     for step in range(max_len):
         tok, states, attn_score = self._decoder.decode_step(
             tok,
             states,
             attention,
             force_not_stop=step <= min_len - 1,
             eos=self._eos)
         if step == 0:
             emitted = tok
             alive = tok != eos
         else:
             # Rows that already finished contribute 0 from now on.
             emitted = tok * alive.type_as(tok)
             alive = alive * (emitted != eos)
         attns.append(attn_score)
         outputs.append(emitted[:, 0].clone())
         # Out-of-table ids cannot be embedded; feed unk back instead.
         tok.masked_fill_(tok >= vocab_size, unk)
         if alive.sum().item() == 0:
             break
     return outputs, attns
Exemplo n.º 6
0
 def batch_decode(self, article, art_lens, extend_art, extend_vsize,
                  go, eos, unk, max_len):
     """Greedy-decode a batch for exactly ``max_len`` steps.

     There is no early stop: ``eos`` is accepted but not checked here, so
     callers must truncate each sequence themselves.

     Returns:
         ``(outputs, attns)``, with one entry per step: the column-0 token
         ids (extended-vocabulary ids preserved) and the attention scores.
     """
     n_articles = len(art_lens)
     vocab_size = self._embedding.num_embeddings
     enc_out, dec_state = self.encode(article, art_lens)
     attention = (
         enc_out,
         len_mask(art_lens, enc_out.device).unsqueeze(-2),
         extend_art,
         extend_vsize,
     )
     tok = torch.LongTensor([go] * n_articles).to(article.device)
     step_ids, step_attns = [], []
     for _ in range(max_len):
         tok, dec_state, attn = self._decoder.decode_step(
             tok, dec_state, attention)
         step_attns.append(attn)
         step_ids.append(tok[:, 0].clone())
         # Map ids outside the embedding table to unk for the next input.
         tok.masked_fill_(tok >= vocab_size, unk)
     return step_ids, step_attns
Exemplo n.º 7
0
    def batched_beamsearch_cnn(self, article, art_lens,
                           extend_art, extend_vsize,
                           go, eos, unk, max_len, beam_size, diverse=1.0, min_len=35):
        """Batched beam search returning up to ``beam_size`` hypotheses per article.

        Each article keeps its own beam, built and updated with the ``bs``
        helpers. At every step:
        - the live beams are packed into one decoder batch;
        - ``self._decoder.topk_step`` proposes ``beam_size`` candidates;
        - ``bs.next_search_beam_cnn`` advances each article's beam.
        Once an article has at least ``beam_size`` finished hypotheses, it is
        removed from the batch. The attention inputs are then re-indexed to
        the articles still searching.

        Args:
            article, art_lens: source ids and lengths passed to ``self.encode``.
            extend_art, extend_vsize: extended-vocabulary source ids and size
                placed in the attention tuple.
            go, eos, unk: start, end and unknown token ids.
            max_len: maximum number of decoding steps.
            beam_size: beam width and number of hypotheses kept per article.
            diverse: forwarded to ``bs.next_search_beam_cnn`` (presumably a
                diversity penalty; TODO confirm).
            min_len: for steps ``t < min_len`` the decoder is called with
                ``force_not_stop=True`` (presumably blocking ``eos``).

        Returns:
            A list with one entry per article. If the article completed, the
            entry is its first ``beam_size`` finished hypotheses. If
            ``max_len`` ran out first, it is finished hypotheses followed by
            the still-live ones, cut to ``beam_size``.
        """
        batch_size = len(art_lens)
        vsize = self._embedding.num_embeddings
        attention, init_dec_states = self.encode(article, art_lens)
        mask = len_mask(art_lens, attention.device).unsqueeze(-2)
        # Keep the full-batch attention inputs; they are re-indexed whenever
        # an article finishes.
        all_attention = (attention, mask, extend_art, extend_vsize)
        attention = all_attention
        # Split the initial decoder state per article: the article index is
        # dim 1 of h/c and dim 0 of prev.
        (h, c), prev = init_dec_states
        all_beams = [bs.init_beam(go, (h[:, i, :], c[:, i, :], prev[i]))
                     for i in range(batch_size)]
        finished_beams = [[] for _ in range(batch_size)]
        outputs = [None for _ in range(batch_size)]
        for t in range(max_len):
            toks = []
            all_states = []
            # Only articles still searching; finished ones hold [].
            for beam in filter(bool, all_beams):
                token, states = bs.pack_beam(beam, article.device)
                toks.append(token)
                all_states.append(states)
            # Stack the per-article beams so the article index is dim 1 of
            # the tokens and dim 2 of h/c.
            token = torch.stack(toks, dim=1)
            states = ((torch.stack([h for (h, _), _ in all_states], dim=2),
                       torch.stack([c for (_, c), _ in all_states], dim=2)),
                      torch.stack([prev for _, prev in all_states], dim=1))
            # Ids >= the embedding table size (e.g. extended-vocab ids)
            # cannot be embedded, so feed unk instead.
            token.masked_fill_(token >= vsize, unk)

            if t < min_len:
                force_not_stop = True
            else:
                force_not_stop = False
            topk, lp, states, attn_score = self._decoder.topk_step(
                token, states, attention, beam_size, force_not_stop=force_not_stop)

            # i indexes the original batch; batch_i indexes the packed batch,
            # which only contains articles that are still searching.
            batch_i = 0
            for i, (beam, finished) in enumerate(zip(all_beams,
                                                     finished_beams)):
                if not beam:
                    continue
                finished, new_beam = bs.next_search_beam_cnn(
                    beam, beam_size, finished, eos,
                    topk[:, batch_i, :], lp[:, batch_i, :],
                    (states[0][0][:, :, batch_i, :],
                     states[0][1][:, :, batch_i, :],
                     states[1][:, batch_i, :]),
                    attn_score[:, batch_i, :],
                    diverse
                )
                batch_i += 1
                if len(finished) >= beam_size:
                    all_beams[i] = []
                    outputs[i] = finished[:beam_size]
                    # exclude finished inputs
                    # Rebuild the attention tuple from the full-batch copy,
                    # keeping only the articles whose output is still None.
                    (attention, mask, extend_art, extend_vsize
                    ) = all_attention
                    masks = [mask[j] for j, o in enumerate(outputs)
                             if o is None]
                    ind = [j for j, o in enumerate(outputs) if o is None]
                    ind = torch.LongTensor(ind).to(attention.device)
                    attention, extend_art = map(
                        lambda v: v.index_select(dim=0, index=ind),
                        [attention, extend_art]
                    )
                    # No article left: mask becomes None, and the all(outputs)
                    # check below ends the search.
                    if masks:
                        mask = torch.stack(masks, dim=0)
                    else:
                        mask = None
                    attention = (
                        attention, mask, extend_art, extend_vsize)
                else:
                    all_beams[i] = new_beam
                    finished_beams[i] = finished
            if all(outputs):
                break
        else:
            # Runs only when the loop was not broken (max_len reached): fill
            # unfinished articles with finished + live hypotheses.
            for i, (o, f, b) in enumerate(zip(outputs,
                                              finished_beams, all_beams)):
                if o is None:
                    outputs[i] = (f+b)[:beam_size]
        return outputs
Exemplo n.º 8
0
    def forward(self, raw_article_sents):  # for RL abs
        """Decode ``raw_article_sents`` greedily and, when training, by sampling.

        The greedy pass runs under ``torch.no_grad()`` with ``self._net`` in
        eval mode. When ``self.training`` is set, a second pass puts
        ``self._net`` in train mode and samples one token per step from the
        decoder's softmax. It also records each sampled token's
        log-probability.

        Args:
            raw_article_sents: per-article source words (presumably lists of
                token strings). They are preprocessed by ``self._prepro``, and
                a source word is copied wherever the decoder emits UNK.

        Returns:
            ``dec_sents_greedy`` (one word list per article) when not
            training. Otherwise ``(dec_sents_greedy, output_word_batch,
            dist_batch)``, where ``dist_batch[i]`` is the summed
            log-probability of article ``i``'s sampled tokens, up to and
            including its END token.
        """
        dec_args, id2word = self._prepro(raw_article_sents)
        article, art_lens, extend_art, extend_vsize, go, eos, unk, max_len = dec_args

        def argmax(arr, keys):
            # Element of ``arr`` whose attention weight in ``keys`` is largest.
            return arr[max(range(len(arr)), key=lambda i: keys[i].item())]

        # ------- greedy decode -------
        with torch.no_grad():
            self._net.eval()
            decs, attns = self._net.batch_decode(*dec_args)
            dec_sents_greedy = []
            for i, raw_words in enumerate(raw_article_sents):
                dec = []
                for id_, attn in zip(decs, attns):
                    if id_[i] == END:
                        break
                    elif id_[i] == UNK:
                        dec.append(argmax(raw_words, attn[i]))
                    else:
                        dec.append(id2word[id_[i].item()])
                dec_sents_greedy.append(dec)

        if not self.training:
            return dec_sents_greedy

        # ------- batch sampling -------
        self._net.train()
        batch_size = len(art_lens)
        vsize = self._net._embedding.num_embeddings
        attention, init_dec_states = self._net.encode(article, art_lens)
        mask = len_mask(art_lens, attention.device).unsqueeze(-2)
        attention = (attention, mask, extend_art, extend_vsize)
        lstm_in = torch.LongTensor([go] * batch_size).to(article.device)
        outputs = []
        attns = []
        dists = []
        states = init_dec_states
        for i in range(max_len):
            _, states, attn_score, logit = self._net._decoder.decode_step(
                lstm_in, states, attention, return_logit=True)
            prob = F.softmax(logit, dim=-1)
            out = torch.multinomial(prob, 1).detach()
            # flag[b] stays True until sequence b has sampled eos.
            if i == 0:
                flag = (out != eos)
            else:
                flag = flag * (out != eos)
            # log_softmax rather than log(softmax): numerically stable value
            # and gradient even when the sampled token's probability is tiny.
            dist = F.log_softmax(logit, dim=-1).gather(1, out)
            dists.append(dist)
            attns.append(attn_score)
            outputs.append(out[:, 0].clone())
            if flag.sum().item() == 0:
                break
            # Ids outside the embedding table are fed back as unk.
            lstm_in = out.masked_fill(out >= vsize, unk)

        output_word_batch = []
        dist_batch = []
        for i, raw_words in enumerate(raw_article_sents):
            words = []
            diss = []
            for id_, attn, dis in zip(outputs, attns, dists):
                # Collected before the END check, so the score includes the
                # probability of stopping.
                diss.append(dis[i:i + 1])
                if id_[i] == END:
                    break
                elif id_[i] == UNK:
                    words.append(argmax(raw_words, attn[i]))
                else:
                    words.append(id2word[id_[i].item()])
            output_word_batch.append(words)
            dist_batch.append(sum(diss))

        return dec_sents_greedy, output_word_batch, dist_batch