Example #1
    def forward(self, source, memory_bank, memory_lengths=None,
                memory_turns=None, coverage=None):
        # here we implement a hierarchical attention
        if source.dim() == 2:
            source = source.unsqueeze(1)

        batch, source_tl, source_wl, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        # word level attention
        word_align = self.word_score(source, memory_bank.contiguous()
                           .view(batch, -1, dim))

        # transform align (b, 1, tl * wl) -> (b * tl, 1, wl)
        word_align = word_align.view(batch * source_tl, 1, source_wl)
        if memory_lengths is not None:
            word_mask = sequence_mask_herd(memory_lengths.view(-1), max_len=word_align.size(-1))
            word_mask = word_mask.unsqueeze(1)  # Make it broadcastable.
            word_align.masked_fill_(~word_mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            word_align_vectors = F.softmax(word_align.view(batch * source_tl, source_wl), -1)
        else:
            word_align_vectors = sparsemax(word_align.view(batch * source_tl, source_wl), -1)

        # mask out the fully padded sentences
        sent_pad_mask = memory_lengths.view(-1).eq(0).unsqueeze(1)
        word_align_vectors = torch.mul(word_align_vectors,
                                       (~sent_pad_mask).type_as(word_align_vectors))
        word_align_vectors = word_align_vectors.view(batch * source_tl, target_l, source_wl)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        cw = torch.bmm(word_align_vectors, memory_bank.view(batch * source_tl, source_wl, -1))
        cw = cw.view(batch, source_tl, -1)
        # concat_cw = torch.cat([cw, source.repeat(1, source_tl, 1)], 2).view(batch*source_tl, -1)
        # attn_hw = self.word_linear_out(concat_cw).view(batch, source_tl, -1)
        # attn_hw = torch.tanh(attn_hw)

        # turn level attention
        turn_align = self.turn_score(source, cw)

        if memory_turns is not None:
            turn_mask = sequence_mask(memory_turns, max_len=turn_align.size(-1))
            turn_mask = turn_mask.unsqueeze(1)  # Make it broadcastable.
            turn_align.masked_fill_(~turn_mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            turn_align_vectors = F.softmax(turn_align.view(batch * target_l, source_tl), -1)
        else:
            turn_align_vectors = sparsemax(turn_align.view(batch * target_l, source_tl), -1)
        turn_align_vectors = turn_align_vectors.view(batch, target_l, source_tl)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        ct = torch.bmm(turn_align_vectors, cw)

        return ct.squeeze(1), None
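The word-level step above scores a flattened (turn x word) memory and then regroups it so normalization runs independently inside each turn. A small standalone illustration of that reshape, with made-up sizes (not the module itself):

import torch

# Toy sizes; word_align plays the role of the self.word_score(...) output above.
batch, source_tl, source_wl = 2, 3, 4
word_align = torch.randn(batch, 1, source_tl * source_wl)        # (b, 1, tl*wl)
word_align = word_align.view(batch * source_tl, 1, source_wl)    # (b*tl, 1, wl)
# softmax over words, independently for each (example, turn) pair
word_align_vectors = torch.softmax(
    word_align.view(batch * source_tl, source_wl), dim=-1)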
Example #2
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        # here we do not need to calculate the alignment
        # because the answer vector is already an averaged representation
        if source.dim() == 2:
            source = source.unsqueeze(1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        return c.squeeze(1), align_vectors
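These examples call a `sequence_mask` helper that is not shown; in OpenNMT-py it builds a boolean mask marking valid (non-padded) source positions. A minimal sketch, assuming `lengths` is a 1-D LongTensor:

import torch

def sequence_mask(lengths, max_len=None):
    """Return a (batch, max_len) bool mask with True at valid positions."""
    max_len = max_len or int(lengths.max())
    return (torch.arange(max_len, device=lengths.device)
            .unsqueeze(0) < lengths.unsqueeze(1))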
Example #3
    def forward(self, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Note:
          the query is the learned vector ``self.source``, expanded to
          ``(batch, 1, dim)`` inside this method.

        Returns:
          FloatTensor: the computed vector ``(batch, dim)``, mean-pooled over
          the query dimension.
        """
        batch, source_l, dim = memory_bank.size()
        source = self.source
        source = source.expand(batch, -1)
        source = source.unsqueeze(1)
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)  # batch, target_l, dim
        c = c.mean(dim=1)  # batch, dim

        # Check output sizes
        batch_, dim_ = c.size()
        aeq(batch, batch_)
        aeq(dim, dim_)

        return c
Example #4
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """
                Args:
                  source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
                  memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
                  memory_lengths (`LongTensor`): the source context lengths `[batch]`
                  coverage (`FloatTensor`): None (not supported yet)

                Returns:
                  (`FloatTensor`, `FloatTensor`):

                  * Computed vector `[tgt_len x batch x dim]`
                  * Attention distribtutions for each query
                     `[tgt_len x batch x src_len]`
                """
        if source.dim() == 2:
            source = source.unsqueeze(1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l,
                                                  self.indim + self.outdim)
        attn_h = self.linear_out(concat_c).view(batch, target_l, self.outdim)

        return attn_h.squeeze(1), align_vectors.squeeze(1)
Example #5
    def forward(cls, ctx, input, target):
        """
        input (FloatTensor): n x num_classes
        target (LongTensor): n, the indices of the target classes
        """
        assert_equal(input.shape[0], target.shape[0])

        p_star = sparsemax(input, 1)
        cls.p_star = p_star.clone().detach()
        loss = _omega_sparsemax(p_star)

        p_star.scatter_add_(1, target.unsqueeze(1),
                            torch.full_like(p_star, -1))
        loss += torch.einsum("ij,ij->i", p_star, input)

        ctx.save_for_backward(p_star)

        return loss
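Example #5 imports `sparsemax` (and `_omega_sparsemax`) from elsewhere. As a reference, the sort-based sparsemax projection of Martins & Astudillo (2016) can be sketched as below; this is an illustration, not the autograd-aware implementation the loss above actually uses:

import torch

def sparsemax_reference(z, dim=-1):
    """Project scores z onto the probability simplex (sparse output)."""
    z_sorted, _ = torch.sort(z, dim=dim, descending=True)
    z_cumsum = z_sorted.cumsum(dim) - 1
    k = torch.arange(1, z.size(dim) + 1, device=z.device, dtype=z.dtype)
    view = [1] * z.dim()
    view[dim] = -1
    support = k.view(view) * z_sorted > z_cumsum      # positions kept in the support
    k_support = support.sum(dim=dim, keepdim=True)    # size of the support
    tau = z_cumsum.gather(dim, k_support - 1) / k_support.to(z.dtype)
    return torch.clamp(z - tau, min=0)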
Example #6
    def attn_map(self, Z):
        if self.attn_func == "softmax":
            return F.softmax(Z, -1)

        elif self.attn_func == "esoftmax":
            return esoftmax(Z, -1)

        elif self.attn_func == "sparsemax":
            return sparsemax(Z, -1)

        elif self.attn_func == "tsallis15":
            return tsallis15(Z, -1)

        elif self.attn_func == "tsallis":
            if self.attn_alpha == 2:  # slightly faster specialized impl
                return sparsemax_bisect(Z, self.bisect_iter)
            else:
                return tsallis_bisect(Z, self.attn_alpha, self.bisect_iter)

        raise ValueError("invalid combination of arguments")
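A quick way to see the difference between the softmax and sparsemax mappings dispatched above, on toy scores (assuming `sparsemax` comes from the `entmax` package, which exposes the same call signature used here):

import torch
import torch.nn.functional as F
from entmax import sparsemax  # assumption: sparsemax provided by the entmax package

Z = torch.tensor([[2.0, 1.0, 0.1, -1.0]])
print(F.softmax(Z, -1))   # dense: every entry strictly positive
print(sparsemax(Z, -1))   # sparse: low-scoring entries are exactly zero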
Example #7
    def forward(self,
                source,
                memory_bank,
                memory_lengths=None,
                coverage=None,
                sent_align_vectors=None,
                sent_nums=None):
        """
        Only one-step attention is supported now.
        Args:
          source (`FloatTensor`): query vectors `[batch x dim]`
          memory_bank (`FloatTensor`): word_memory_bank is `FloatTensor` with shape `[batch x s_num x s_len x dim]`
          memory_lengths (`LongTensor`): sentence lengths for word_memory_bank, `[batch x s_num]`
          coverage (`FloatTensor`): None (not supported yet)
          sent_align_vectors (`FloatTensor`): the computed sentence align distribution, `[batch x s_num]`
          sent_nums (`LongTensor`): the sentence numbers of inputs, `[batch]`
          use_tanh (`bool`): True, whether to use a tanh activation for `general` and `dot` attention

        Returns:
          (`FloatTensor`, `FloatTensor`):
            * Computed word attentional vector `[batch x dim]`
            * Word Attention distributions for the query of word `[batch x s_num x s_len]`
        """

        # only one step input is supported
        assert source.dim() == 2, \
            "Only one step input is supported for current attention."
        one_step = True
        # [batch, 1, dim]
        source = source.unsqueeze(1)
        batch, tgt_l, dim = source.size()

        # check the specification for word level attention
        assert sent_align_vectors is not None, "For word level attention, the 'sent_align' must be specified."
        assert sent_nums is not None, "For word level attention, the 'sent_nums' must be specified."
        assert memory_lengths is not None, "The lengths for the word memory bank are required."
        sent_lens = memory_lengths

        batch_1, s_num, s_len, dim_ = memory_bank.size()
        batch_2, s_num_ = sent_align_vectors.size()
        batch_3 = sent_nums.size(0)

        aeq(batch, batch_1, batch_2, batch_3)
        aeq(dim, dim_, self.dim)
        aeq(s_num, s_num_)

        # if coverage is not None:
        #     batch_, source_l_ = coverage.size()
        #     aeq(batch, batch_)
        #     aeq(source_l, source_l_)
        #
        # if coverage is not None:
        #     cover = coverage.view(-1).unsqueeze(1)
        #     memory_bank += self.linear_cover(cover).view_as(memory_bank)
        #     memory_bank = torch.tanh(memory_bank)

        # compute word attention scores, as in Luong et al.
        # [batch, s_num, s_len, dim] -> [batch, s_num * s_len, dim]
        memory_bank = memory_bank.view(batch, s_num * s_len, dim)
        # [batch, 1, s_num * s_len]
        word_align = self.score(source, memory_bank)
        # [batch, s_num * s_len]
        word_align = word_align.squeeze(1)
        # [batch, s_num, s_len]
        word_align = word_align.view(batch, s_num, s_len)

        # remove the empty sentences
        # [s_total, s_len], [s_total]
        valid_word_align, valid_sent_lens = valid_src_compress(
            word_align, sent_nums=sent_nums, sent_lens=sent_lens)

        # [s_total, s_len]
        word_mask = sequence_mask(valid_sent_lens,
                                  max_len=valid_word_align.size(-1))

        # word_mask = word_mask.view(batch, s_num, s_len)

        # # [batch, s_num]
        # sent_mask = sequence_mask(sent_nums, max_len=s_num)
        # # [batch, s_num, 1]
        # sent_mask = sent_mask.unsqueeze(2)
        # # [batch, s_num, s_len]
        # align_vectors.masked_fill_(1 - sent_mask, 0.0)

        # [s_total, s_len]
        valid_word_align.masked_fill_(~word_mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(valid_word_align, -1)
        else:
            align_vectors = sparsemax(valid_word_align, -1)

        # Recover the original shape by padding 0s for empty sentences
        # [batch, s_num, s_len]
        align_vectors = recover_src(align_vectors, sent_nums)

        # # For the whole invalid sentence, we set all the word aligns to 0s.
        # # Since
        # # [batch, s_num]
        # sent_mask = sequence_mask(sent_nums, max_len=s_num)
        # # [batch, s_num, 1]
        # sent_mask = sent_mask.unsqueeze(2)
        # # [batch, s_num, s_len]
        # align_vectors.masked_fill_(1 - sent_mask, 0.0)

        # [batch, s_num, 1]
        sent_align_vectors = sent_align_vectors.unsqueeze(-1)
        # [batch, s_num, s_len]
        align_vectors = align_vectors * sent_align_vectors

        # each context vector c_t is the weighted average
        # over all the source hidden states
        # [batch, 1, s_num * s_len]
        align_vectors = align_vectors.view(batch, -1).unsqueeze(1)
        # [batch, 1, s_num * s_len] x [batch, s_num * s_len, dim] -> [batch, 1, dim]
        c = torch.bmm(align_vectors, memory_bank)
        # [batch, dim]
        c = c.squeeze(1)
        returned_vec = c

        # If output_attn_h is False, the linear output layer is applied in the decoder instead
        if self.output_attn_h:
            # concatenate
            # [batch, dim]
            source = source.squeeze(1)
            # [batch, 2*dim]
            concat_c = torch.cat([c, source], 1)
            # [batch, dim]
            attn_h = self.linear_out(concat_c)
            if self.attn_type in ["general", "dot"]:
                attn_h = torch.tanh(attn_h)
            returned_vec = attn_h

        align_vectors = align_vectors.squeeze(1).view(batch, s_num, s_len)
        # Check output sizes
        batch_, dim_ = returned_vec.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        # check
        batch_, s_num_, s_len_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(s_num, s_num_)

        return returned_vec, align_vectors
Example #8
    def forward(self, source, memory_bank, memory_lengths=None):
        """

        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x tgt_enc_dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x src_enc_dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`

        Returns:
          (`FloatTensor`):

          * Attention distributions for each query
             `[batch x tgt_len x src_len]`
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, src_enc_dim = memory_bank.size()
        batch_, target_l, tgt_enc_dim = source.size()
        aeq(batch, batch_)
        aeq(self.src_enc_dim, src_enc_dim)
        aeq(self.tgt_enc_dim, tgt_enc_dim)

        # compute attention scores, as in Luong et al.
        # (batch, t_len, s_len)
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l),
                                      -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # # each context vector c_t is the weighted average
        # # over all the source hidden states
        # c = torch.bmm(align_vectors, memory_bank)
        #
        # # concatenate
        # concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
        # attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        # if self.attn_type in ["general", "dot"]:
        #     attn_h = torch.tanh(attn_h)

        if one_step:
            # attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            # batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            # aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            # attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.contiguous()
            # Check output sizes
            # target_l_, batch_, dim_ = attn_h.size()
            # aeq(target_l, target_l_)
            aeq(batch, batch_)
            # aeq(dim, dim_)
            batch_, target_l_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return align_vectors
Example #9
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x dim]` (the RNN output)
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` (the encoder output)
          memory_lengths (`LongTensor`): the source context lengths `[batch]` (encoder output lengths)
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)  # not reached when coverage is shared with attn
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l),
                                      -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
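The `self.score` call in these forwards is not shown; in OpenNMT-style global attention it computes Luong-style scores. A hedged sketch of the dot / general variants (the mlp variant additionally uses learned projections and a tanh):

import torch

def luong_score(h_t, h_s, linear_in=None):
    """Unnormalized attention scores, shape (batch, tgt_len, src_len).

    h_t: query vectors (batch, tgt_len, dim); h_s: memory bank (batch, src_len, dim).
    Pass an nn.Linear(dim, dim, bias=False) as linear_in for the "general" score;
    leave it as None for the plain "dot" score.
    """
    if linear_in is not None:
        h_t = linear_in(h_t)
    return torch.bmm(h_t, h_s.transpose(1, 2))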
Example #10
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        # print('source', source.size())
        # print('memory_bank', memory_bank.size())
        align = self.score(source, memory_bank)     # the attention-weight score formula
        # print('align', align.size())              # align: torch.Size([bz, 1, WxH])

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)
        # c (5, 1, 512)
        # concatenate

        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)  # o_t = tanh(W_c [h_t; c_t])

        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:

            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)


            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
Example #11
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          FloatTensor:

          * Attention distributions (log-probabilities when ``attn_func``
            is "softmax") for each query ``(tgt_len, batch, src_len)``
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Log-softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = torch.log_softmax(
                align.view(batch * target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        if one_step:
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return align_vectors
Example #12
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
Example #13
    def forward(self, source, memory_bank,memory_lengths=None, coverage=None):
        """

        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """
        # print ('Source..',source.size())
        # print ('memory_bank..',memory_bank.size())
        # Source..torch.Size([16, 512])
        # memory_bank..torch.Size([16, 400, 512])


        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            #????
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.

            # fix for tensor version > 1.2
            # refer to https://github.com/OpenNMT/OpenNMT-py/pull/1527/commits/234f9a5f6fca989fe6804e44ea68b58786ed58b8
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        # print ('Atten Hidden...',attn_h.size()) # torch.Size([16, 512])
        # print ('Align...',align_vectors.size()) # torch.Size([16, 400])
        return attn_h, align_vectors
Example #14
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        #            BxT'xd     BxTxd         B
        """

        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """
        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)
        #  (B x T' x T)
        #  align[b,t',t] = score(w_t, w_t') in batch b

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            # {0,1}^{B x T}
            mask = mask.unsqueeze(1)  # Make it broadcastable: {0,1}^{B x 1 x T}

            align.masked_fill_(~mask, -float('inf'))
            #  align[b,t',t] = -inf           if t > len(b)

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            # (BT' x T): apply softmax on each column
            align_vectors = F.softmax(align.view(batch * target_l, source_l),
                                      -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)
        # B x T' x T: [b,t',:] is a distrib like [0.02, 0.02, ..., 0.0, 0.0]

        # each context vector c_t is the weighted average
        # over all the source hidden states
        #
        #              (B x T' x T) (B x T x d) ---------> (B x T' x d)
        c = torch.bmm(align_vectors, memory_bank)
        # c[b,t',:] = context vector for w_t' in batch b

        # context + query
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)

        # shrink it: (BT' x 2d) ---> (B x T' x d)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)

        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)  # Eq. (5) in Luong (2015)

        # T' = 1 (e.g., input feeding)
        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()  # TRANSPOSED
            align_vectors = align_vectors.transpose(0, 1).contiguous()  # ''
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        # (T' x B x d)  (T' x B x T)
        return attn_h, align_vectors
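The shape comments in example #14 can be reproduced end to end with dummy tensors; a self-contained walk-through of the mask + softmax + weighted-sum pipeline (dot scores, no output projection, sizes made up):

import torch
import torch.nn.functional as F

B, T_src, T_tgt, D = 2, 5, 3, 4
memory_bank = torch.randn(B, T_src, D)                    # encoder states (B x T x d)
source = torch.randn(B, T_tgt, D)                         # queries (B x T' x d)
lengths = torch.tensor([5, 3])                            # valid source lengths

align = torch.bmm(source, memory_bank.transpose(1, 2))    # (B x T' x T) dot scores
mask = torch.arange(T_src).unsqueeze(0) < lengths.unsqueeze(1)   # (B x T)
align.masked_fill_(~mask.unsqueeze(1), float('-inf'))     # hide padded positions
align_vectors = F.softmax(align, dim=-1)                  # each row sums to 1
c = torch.bmm(align_vectors, memory_bank)                 # context vectors (B x T' x d)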
Example #15
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (`FloatTensor`): query vectors `[batch x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[batch x dim]`
          * Attention distributions for each query
             `[batch x src_len]`
        """

        # one step input
        assert source.dim() == 2, "Only one step input is supported"
        #one_step = True
        source = source.unsqueeze(1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        # compute attention scores, as in Luong et al.
        # [batch x 1 x src_len]
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l),
                                      -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)
        if self.coverage_attn and coverage is not None:
            # [batch, src_len]
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            # [batch, src_len]
            coverage_reversed = -1 * coverage
            coverage_reversed.masked_fill_(~mask, -float('inf'))
            coverage_reversed = F.softmax(coverage_reversed, -1)
            coverage_reversed = coverage_reversed.unsqueeze(1)
            # coverage_reversed only rescales the current sentence attention; no gradient is backpropagated through it
            coverage_reversed = coverage_reversed.detach()

            align_vectors = align_vectors * coverage_reversed
            norm_term = align_vectors.sum(dim=2, keepdim=True)
            align_vectors = align_vectors / norm_term

        # each context vector c_t is the weighted average
        # over all the source hidden states
        # [batch, target_l, dim]
        c = torch.bmm(align_vectors, memory_bank)
        # [batch, dim]
        returned_vec = c.squeeze(1)

        # # concatenate
        if self.output_attn_h:
            concat_c = torch.cat([c, source], 2).view(batch * target_l,
                                                      dim * 2)
            attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
            if self.attn_type in ["general", "dot"]:
                attn_h = torch.tanh(attn_h)
            attn_h = attn_h.squeeze(1)
            returned_vec = attn_h

        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = returned_vec.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        # Check output sizes
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

        return returned_vec, align_vectors
Example #16
    def forward(self,
                source,
                memory_bank,
                memory_lengths=None,
                coverage=None,
                sent_align_vectors=None,
                sent_position_tuple=None,
                src_word_sent_ids=None):
        """
        Only one-step attention is supported now.
        Args:
          source (`FloatTensor`): query vectors `[batch x dim]`
          memory_bank (`FloatTensor`): word_memory_bank is `FloatTensor` with shape `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): for word_memory_bank, `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
          sent_align_vectors (`FloatTensor`): the computed sentence align distribution, `[batch x s_num]`
          sent_position_tuple (:obj:`tuple`): Only used for seqhr_enc, (sent_p, sent_nums) with size
                `([batch_size, s_num, 2], [batch])`.
          src_word_sent_ids (:obj:`tuple`): (word_sent_ids, src_lengths) with size `([batch, src_len], [batch])`
          use_tanh (`bool`): True, whether to use a tanh activation for `general` and `dot` attention

        Returns:
          (`FloatTensor`, `FloatTensor`):
            * Computed word attentional vector `[batch x dim]`
            * Word Attention distributions for the query of word `[batch x src_len]`
        """

        # only one step input is supported
        assert source.dim() == 2, \
            "Only one step input is supported for current attention."
        assert isinstance(sent_position_tuple, tuple)
        sent_position, sent_nums = sent_position_tuple
        one_step = True
        # [batch, 1, dim]
        source = source.unsqueeze(1)
        batch, tgt_l, dim = source.size()

        # check the specification for word level attention
        assert sent_align_vectors is not None, "For word level attention, the 'sent_align' must be specified."
        assert sent_position is not None, "For word level attention, the 'sent_position' must be specified."
        assert sent_nums is not None, "For word level attention, the 'sent_nums' must be specified."
        assert memory_lengths is not None, "The lengths for the word memory bank are required."
        sent_lens = memory_lengths

        batch_1, src_len, dim_ = memory_bank.size()
        batch_2, sent_num = sent_align_vectors.size()
        batch_3 = sent_nums.size(0)

        aeq(batch, batch_1, batch_2, batch_3)
        aeq(dim, dim_, self.dim)

        # if coverage is not None:
        #     batch_, source_l_ = coverage.size()
        #     aeq(batch, batch_)
        #     aeq(source_l, source_l_)
        #
        # if coverage is not None:
        #     cover = coverage.view(-1).unsqueeze(1)
        #     memory_bank += self.linear_cover(cover).view_as(memory_bank)
        #     memory_bank = torch.tanh(memory_bank)

        # compute word attention scores, as in Luong et al.
        # [batch, 1, src_len]
        word_align = self.score(source, memory_bank)
        # [batch, src_len]
        word_align = word_align.squeeze(1)

        # [batch, src_len]
        word_mask = sequence_mask(memory_lengths, max_len=word_align.size(-1))

        # [batch, src_len]
        word_align.masked_fill_(~word_mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(word_align, -1)
        else:
            align_vectors = sparsemax(word_align, -1)

        if self.seqHRE_attn_rescale:
            word_sent_ids, memory_lengths_ = src_word_sent_ids
            assert memory_lengths.eq(memory_lengths_).all(), \
                "The src lengths in src_word_sent_ids should be the same as the memory_lengths"

            # # attention score reweighting method 1
            # # broadcast the sent_align_vectors from [batch, sent_num] to [batch, src_len]
            # # according to sent_position [batch, sent_num, 2] and sent_nums [batch]
            # expand_sent_align_vectors = []
            # for b_idx in range(batch):
            #     one_ex_expand = []
            #     for sent_idx in range(sent_num):
            #         sent_token_num = sent_position[b_idx][sent_idx][0] - sent_position[b_idx][sent_idx][1] + 1
            #         if sent_token_num != 1:
            #             one_ex_expand.append(sent_align_vectors[b_idx][sent_idx].expand(sent_token_num))
            #         else:
            #             break
            #     one_ex_expand = torch.cat(one_ex_expand, dim=0)
            #     if one_ex_expand.size(0) < src_len:
            #         pad_vector = torch.zeros([src_len - one_ex_expand.size(0)],
            #                                  dtype=one_ex_expand.dtype, device=one_ex_expand.device)
            #         one_ex_expand = torch.cat([one_ex_expand, pad_vector], dim=0).contiguous()
            #     expand_sent_align_vectors.append(one_ex_expand)
            #
            # # [batch, src_len]
            # expand_sent_align_vectors = torch.stack(expand_sent_align_vectors, dim=0).contiguous()
            # # reweight and renormalize the word align_vectors
            # align_vectors = align_vectors * expand_sent_align_vectors
            # norm_term = align_vectors.sum(dim=1, keepdim=True)
            # align_vectors = align_vectors / norm_term

            # attention score reweighting method 2
            # word_sent_ids: [batch, src_len]
            # sent_align_vectors: [batch, sent_num]
            # expand_sent_align_vectors: [batch, src_len]
            expand_sent_align_vectors = sent_align_vectors.gather(
                dim=1, index=word_sent_ids)
            # reweight and renormalize the word align_vectors.
            # Although word_sent_ids are padded with 0s (which would gather the attention
            # score of sentence 0), align_vectors are 0.0 at these padded positions.
            align_vectors = align_vectors * expand_sent_align_vectors
            norm_term = align_vectors.sum(dim=1, keepdim=True)
            align_vectors = align_vectors / norm_term

        # each context vector c_t is the weighted average
        # over all the source hidden states
        # [batch, 1, src_len]
        align_vectors = align_vectors.unsqueeze(1)
        # [batch, 1, src_len] x [batch, src_len, dim] -> [batch, 1, dim]
        c = torch.bmm(align_vectors, memory_bank)
        # [batch, dim]
        c = c.squeeze(1)
        returned_vec = c
        # If output_attn_h is False, the linear output layer is applied in the decoder instead
        if self.output_attn_h:
            # concatenate
            # [batch, dim]
            source = source.squeeze(1)
            # [batch, 2*dim]
            concat_c = torch.cat([c, source], 1)
            # [batch, dim]
            attn_h = self.linear_out(concat_c)
            if self.attn_type in ["general", "dot"]:
                attn_h = torch.tanh(attn_h)
            returned_vec = attn_h

        # [batch, src_len]
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_1, dim_ = returned_vec.size()
        batch_2, _ = align_vectors.size()
        aeq(batch, batch_1, batch_2)
        aeq(dim, dim_)

        return returned_vec, align_vectors
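The gather-based reweighting in example #16 expands sentence-level attention to word level by indexing with each word's sentence id. A toy illustration (values made up):

import torch

sent_align = torch.tensor([[0.7, 0.2, 0.1]])            # (batch=1, sent_num=3)
word_sent_ids = torch.tensor([[0, 0, 1, 1, 2, 0]])      # (batch=1, src_len=6); padding points at sentence 0
expanded = sent_align.gather(dim=1, index=word_sent_ids)
# expanded -> tensor([[0.7000, 0.7000, 0.2000, 0.2000, 0.1000, 0.7000]])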
Example #17
                rel2id[id2emotion[e]] for e in tgt_emotions[s_id:e_id]
            ])  # (batch, )
            # print(batch_rel_ids)
            # batch_score = -torch.norm(ent_emb[batch_concept_ids].unsqueeze(2) +
            #     rel_emb[batch_rel_ids].unsqueeze(1).unsqueeze(2) - ent_emb.unsqueeze(0).unsqueeze(0), dim=-1) # (batch, max_num, vocab), costly

            batch_score = torch.mm((ent_emb[batch_concept_ids] + rel_emb[batch_rel_ids].unsqueeze(1)).view(-1, ent_emb_dim), ent_emb.transpose(1,0))\
                .view(len(batch_concept_ids), max_num_concepts, -1) # (batch, max_num, vocab)

            # batch_score = -torch.norm(ent_emb[batch_concept_ids].unsqueeze(2) *
            #     rel_emb[batch_rel_ids].unsqueeze(1).unsqueeze(2) - ent_emb.unsqueeze(0).unsqueeze(0), p=1, dim=-1) # (batch, max_num, vocab), costly
            if s_id == 0:
                print(src_concepts[:3])
                print(tgt_emotions[:3])
            if use_sparsemax:
                batch_attn = sparsemax(sparsemax_temp * batch_score,
                                       -1)  # (batch, max_num, vocab)
                combined_emb = torch.mm(
                    batch_attn.view(-1, batch_attn.shape[-1]),
                    concept_embedding).view(-1, max_num_concepts, emb_dim)
            elif top_k != 0:
                top_k_scores, top_k_indices = batch_score.topk(
                    top_k,
                    dim=2)  # (batch, max_num, top_k), (batch, max_num, top_k)
                top_k_attn = torch.softmax(top_k_scores,
                                           dim=-1)  # (batch, max_num, top_k)
                if s_id == 0:
                    print("Top k probs: ", top_k_attn[:3, 0])
                # augment VAD
                if concept_VAD_strength_embedding is not None:
                    VAD_attn = torch.softmax(
                        concept_VAD_strength_temp *
Example #18
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None, modification_method=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        # dimension: batch x target len x source len
        align = self.score(source, memory_bank)

        if modification_method is not None:
            align = align.detach() # Is this OK?

            top_indices = torch.argsort(align, descending=True)
            memory_lengths_vector = memory_lengths.cpu()

            for i in range(align.shape[0]):
                true_length = memory_lengths_vector[i]  # == memory_lengths[i], moved to CPU

                if modification_method == 'uniform':
                    align[i,:,:true_length] = 1
                    continue

                for j in range(align.shape[1]):
                    if modification_method == 'zero_out_max':
                        max_index = align[i][j][0:true_length].argmax()
                        align[i][j][max_index] = -float('inf')
                    elif modification_method == 'random_permute':
                        rand_indices = torch.randperm(true_length, requires_grad=False)
                        # cloned = align[i,j,rand_indices].clone()
                        align[i, j, 0:true_length] = align[i, j, rand_indices]  # .clone()
                        #align[i,j,0:true_length] = cloned
                    elif modification_method == 'second_max':

                        #top_indices = torch.argsort(align[i][j][:true_length], descending=True)

                        if true_length <= 2:
                            continue


                        first_max = None
                        second_max = None
                        third_max = None

                        for el in top_indices[i][j].cpu():
                            if el >= true_length:#.cuda(el.get_device()):
                                continue

                            if first_max is None:
                                first_max = el
                            elif second_max is None:
                                second_max = el
                            elif third_max is None:
                                third_max = el
                            else:
                                break

                        align[i][j][first_max] = align[i][j][third_max]
                    else:
                        print(">>> shit (Nothing was selected for attetnion modification) <<< ")

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
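
A standalone sketch (not part of the original module; the helper name modify_scores and the toy shapes are assumptions) reproducing the four modification strategies above on a toy score tensor, so the effect of each can be inspected outside the decoder. Padding masking is left to the surrounding code, as in the snippet.

import torch

def modify_scores(align, lengths, method):
    """align: (batch, tgt_len, src_len) raw scores; lengths: (batch,) valid source lengths."""
    align = align.detach().clone()
    for i in range(align.size(0)):
        n = int(lengths[i])
        for j in range(align.size(1)):
            if method == 'uniform':
                align[i, j, :n] = 1.0                          # equal weight after softmax
            elif method == 'zero_out_max':
                align[i, j, align[i, j, :n].argmax()] = -float('inf')
            elif method == 'random_permute':
                align[i, j, :n] = align[i, j, torch.randperm(n)]
            elif method == 'second_max':
                if n > 2:
                    top = align[i, j, :n].argsort(descending=True)
                    align[i, j, top[0]] = align[i, j, top[2]]  # demote the max to the 3rd-largest value
    return align

scores = torch.randn(2, 1, 5)
lengths = torch.tensor([5, 3])
print(torch.softmax(modify_scores(scores, lengths, 'zero_out_max'), dim=-1))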
예제 #19
0
def get_sparse_attention(src_ent, rel):
    src_ent_emb_ = ent_emb[ent2id[src_ent]]
    rel_emb_ = rel_emb[rel2id[rel]]
    return sparsemax(-torch.norm(src_ent_emb_ + rel_emb_ - ent_emb, dim=1), -1)
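
A hedged usage sketch for the snippet above: the embedding tables and id maps it reads are assumed to be module-level globals, and the sparsemax import, entity names, and shapes below are illustrative assumptions rather than details from the source.

import torch
from entmax import sparsemax  # assumed source of sparsemax

# hypothetical knowledge-graph setup
ent2id = {'Paris': 0, 'France': 1, 'Berlin': 2}
rel2id = {'capital_of': 0}
ent_emb = torch.randn(len(ent2id), 64)   # (num_entities, dim)
rel_emb = torch.randn(len(rel2id), 64)   # (num_relations, dim)

# TransE-style scoring: entities whose embedding is close to src_ent + rel get
# non-zero weight; sparsemax zeroes out the distant ones.
attn = get_sparse_attention('Paris', 'capital_of')
print(attn.shape, attn.sum())            # (num_entities,), sums to 1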
예제 #20
0
def test_sparsemax():
    for _ in range(10):
        x = 0.5 * torch.randn(10, 30000, dtype=torch.float32)
        p1 = sparsemax(x, 1)
        p2 = sparsemax_bisect(x)
        assert torch.sum((p1 - p2)**2) < 1e-7
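
A small illustrative run (assuming sparsemax and sparsemax_bisect as provided by the entmax package) showing what the test above relies on: both routines compute the same Euclidean projection onto the probability simplex, and unlike softmax the result contains exact zeros.

import torch
from entmax import sparsemax, sparsemax_bisect  # assumed imports

x = torch.tensor([[1.5, 0.4, 0.3, -2.0]])
print(torch.softmax(x, dim=1))   # every entry strictly positive
print(sparsemax(x, dim=1))       # low-scoring entries are exactly zero
print(sparsemax_bisect(x))       # same distribution, computed by bisection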
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        '''
        added by zhengquan
        The call site looks like this:
        # rnn_output.size() = [3, 100] = [batch_size, hidden_size]
        if self.attentional:
            # What is p_attn, and why does its size change between batches?
            # p_attn.size() = [3, 32] = [batch_size, 32]; the 32 varies per example because
            # it is the source length of the current batch (memory_bank's src_len below is also 32).
            # memory_bank.size() = [32, 3, 100] = [src_len, batch_size, rnn_size],
            # so the "src" mentioned above is src_len.
            # memory_lengths.size() = [3], memory_lengths = [32, 32, 31]
            decoder_output, p_attn = self.attn(
                rnn_output,
                memory_bank.transpose(0, 1),
                memory_lengths=memory_lengths)
            attns["std"].append(p_attn)
        '''
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)  #[batch_size , 1 , hidden_size]
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)
        # align: [batch_size, tgt_l(=1), src_l]

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l),
                                      -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
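
A minimal call-site sketch matching the translated note above. The shapes are the ones the note reports; the class name GlobalAttention and its constructor arguments are assumptions about the surrounding (OpenNMT-style) code, not taken from the source.

import torch

attn = GlobalAttention(100, attn_type="general", attn_func="softmax")  # assumed constructor

rnn_output = torch.randn(3, 100)             # (batch, hidden): a one-step decoder query
memory_bank = torch.randn(32, 3, 100)        # (src_len, batch, rnn_size) from the encoder
memory_lengths = torch.tensor([32, 32, 31])  # valid source length per batch element

decoder_output, p_attn = attn(rnn_output,
                              memory_bank.transpose(0, 1),   # -> (batch, src_len, dim)
                              memory_lengths=memory_lengths)
# one-step case: decoder_output is (batch, dim), p_attn is (batch, src_len)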
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None, experiment_type=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)
          experiment_type : type of attention-modification experiment. Possible values:
            'permute', 'uniform', 'last_state', 'only_max', 'zero_out_max',
            'keep_max_uniform_others', 'zero_out'
        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
        """
        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        # i is batch index, j is target token index, third dimension is for source token index

        if experiment_type is not None and experiment_type not in ['zero_out', 'keep_max_uniform_others']:
            assert (align.size()[1] == 1)
            new_align = align.detach().clone().cpu().numpy()

            for i in range(align.size()[0]):
                length = int(memory_lengths[i]) if memory_lengths is not None else new_align.shape[2]

                for j in range(align.size()[1]):
                    if experiment_type == 'permute':
                        max_index = new_align[i][j][0:length].argmax()

                        succeed = False
                        for _ in range(10):
                            random.shuffle(new_align[i][j][0:length])
                            if(new_align[i][j][0:length].argmax() != max_index):
                                succeed = True
                                break

                        if succeed is False:
                            print("Couldn't permute properly! Be careful")

                    elif experiment_type == 'uniform':
                        new_align[i][j][0:length] = 1
                    elif experiment_type == 'last_state':
                        new_align[i][j][0:length] = -float('inf')
                        new_align[i][j][length-1] = 1
                    elif experiment_type == 'only_max':

                        keep_k = 1

                        indices = new_align[i][j][0:length].argsort()[-keep_k:][::-1]

                        backup = np.copy(new_align[i][j][indices])

                        new_align[i][j][0:length] = -float('inf')

                        for k in range(keep_k):
                            new_align[i][j][indices[k]] = backup[k]

                    elif experiment_type == 'zero_out_max':
                        max_index = new_align[i][j][0:length].argmax()
                        new_align[i][j][max_index] = -float('inf')

                    else:
                        raise ValueError("Unknown experiment_type: %s" % experiment_type)

            align = torch.from_numpy(new_align).cuda()


        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        if experiment_type == 'keep_max_uniform_others':
            for i in range(align_vectors.size()[0]): # Batch
                assert (align_vectors.size()[1] == 1)

                length = memory_lengths[i] if memory_lengths is not None else align_vectors[i][0].size()[0]
                if length == 1:
                    print("Source length is 1; skipping redistribution for this example")
                    continue

                max_index = align_vectors[i][0][0:length].argmax()
                max_val = align_vectors[i][0][max_index].item()

                assert (max_val > 0)
                assert (max_val <= 1)

                align_vectors[i][0][0:length] = (1 - max_val) * 1.0 / float(length - 1)
                align_vectors[i][0][max_index] = max_val


        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if experiment_type == 'zero_out':
            attn_h = torch.zeros_like(attn_h)  # zero out the attentional hidden state entirely

        return attn_h, align_vectors
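
A standalone numeric sketch (values are illustrative, not from the source) of the 'keep_max_uniform_others' rule applied above: after softmax, the winning probability is kept and the remaining mass is spread uniformly over the other valid source positions.

import torch

p = torch.tensor([0.70, 0.20, 0.05, 0.05, 0.00])  # softmaxed attention; last slot is padding
length = 4                                          # number of valid source tokens
max_index = p[:length].argmax()
max_val = p[max_index].item()
p[:length] = (1.0 - max_val) / (length - 1)         # uniform share for the non-max tokens
p[max_index] = max_val
print(p)   # tensor([0.70, 0.10, 0.10, 0.10, 0.00]); still sums to 1 over the valid positions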
예제 #23
0
    def forward(self,
                source,
                memory_bank,
                n,
                latt=False,
                memory_lengths=None,
                coverage=None):
        """

        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False
        #  one_step = True     #latt for lattice use only

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l),
                                      -1)  # latt comment   align vectors
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate context and source
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            #         print('test else in global attn', attn_h.size()) #test
            #  attn_h = attn_h.transpose(0, 1).contiguous() # comment out for lattice
            # latt
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            #target_l_, batch_, dim_ = attn_h.size()
            batch_, target_l_, dim_ = attn_h.size()  #latt use only
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()

            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        # attn_h: [batch(=num_sentences), tgt_len, dim] (e.g. torch.Size([1, 9, 256]))  # latt
        # align_vectors: [tgt_len, batch, src_len] (e.g. torch.Size([9, 1, 27]))  # latt
        torch.set_printoptions(profile="full")
        #print('align_vectors', align_vectors)    # latt

        logger.info('align_vectors')  # latt
        logger.info(align_vectors)  # latt
        align_vectors[align_vectors == 0.5] = 0
        align_sum = torch.sum(align_vectors, 1)
        # logger.info(align_sum)
        # latt

        torch.set_printoptions(profile="default")

        return attn_h, align_vectors, align_sum  #latt ignore
예제 #24
0
    def forward(
        self,
        source: torch.FloatTensor,  # [batch, tgt_len, dim]
        memory_bank_list: List[
            torch.FloatTensor],  # [num_srcs] x [batch, src_len, dim]
        memory_lengths_list: List[
            torch.LongTensor] = None,  # [num_srcs] x [batch]
        coverage=None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        assert coverage is None

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False
        # end if

        # Join memory bank
        memory_bank = torch.cat(memory_bank_list, dim=1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths_list is not None:
            mask = torch.cat([
                sequence_mask(memory_lengths,
                              max_len=memory_bank_list[src_i].size(1))
                for src_i, memory_lengths in enumerate(memory_lengths_list)
            ],
                             dim=1)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))
        # end if

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l),
                                      -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)
        # end if

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)
        # end if

        return attn_h, align_vectors
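
A toy sketch of the multi-source masking above (shapes are illustrative, and sequence_mask is a locally defined stand-in for the helper these snippets assume): each source gets a validity mask against its own padded length, and the masks are concatenated so one joint attention distribution ignores padding in every source.

import torch

def sequence_mask(lengths, max_len):
    # stand-in: True for positions < length, as the snippets above expect
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]

banks = [torch.randn(2, 4, 8), torch.randn(2, 6, 8)]       # two sources: batch=2, dim=8
lengths = [torch.tensor([4, 2]), torch.tensor([6, 5])]     # valid lengths per source

memory_bank = torch.cat(banks, dim=1)                      # (2, 10, 8) joint memory bank
mask = torch.cat([sequence_mask(l, b.size(1)) for l, b in zip(lengths, banks)], dim=1)
print(memory_bank.shape, mask)                             # (2, 10, 8) and a (2, 10) bool mask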
예제 #25
0
    def forward(self,
                source,
                memory_bank,
                memory_lengths=None,
                coverage=None,
                sent_align_vectors=None):
        """

        Args:
          source (`FloatTensor`): query vectors `[batch x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)
          sent_align_vectors (`FloatTensor`): sentence-level attention scores `[batch x src_len]`

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[batch x dim]`
          * Attention distributions for each query
             `[batch x src_len]`
        """

        # one step input
        assert source.dim() == 2, "Only one step input is supported"
        #one_step = True
        source = source.unsqueeze(1)
        sent_align_vectors = sent_align_vectors.unsqueeze(1)

        batch, src_len, dim = memory_bank.size()
        batch1, tgt_len, dim1 = source.size()
        batch2, tgt_len2, src_len2 = sent_align_vectors.size()
        aeq(batch, batch1, batch2)
        aeq(self.dim, dim, dim1)
        aeq(src_len, src_len2)
        aeq(tgt_len, tgt_len2)

        # if coverage is not None:
        #     batch_, source_l_ = coverage.size()
        #     aeq(batch, batch_)
        #     aeq(source_l, source_l_)
        # if coverage is not None:
        #     cover = coverage.view(-1).unsqueeze(1)
        #     memory_bank += self.linear_cover(cover).view_as(memory_bank)
        #     memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        # [batch, tgt_len, src_len]
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * tgt_len, src_len), -1)
        else:
            align_vectors = sparsemax(align.view(batch * tgt_len, src_len), -1)
        align_vectors = align_vectors.view(batch, tgt_len, src_len)

        # rescale the word attention scores using the sent_align_vec
        align_vectors = align_vectors * sent_align_vectors
        norm_vec = align_vectors.sum(dim=-1, keepdim=True)
        align_vectors = align_vectors / norm_vec

        # each context vector c_t is the weighted average
        # over all the source hidden states
        # [batch, tgt_len, dim]
        c = torch.bmm(align_vectors, memory_bank)
        # [batch, dim]
        returned_vec = c.squeeze(1)

        # # concatenate
        if self.output_attn_h:
            concat_c = torch.cat([c, source], 2).view(batch * tgt_len, dim * 2)
            attn_h = self.linear_out(concat_c).view(batch, tgt_len, dim)
            if self.attn_type in ["general", "dot"]:
                attn_h = torch.tanh(attn_h)
            attn_h = attn_h.squeeze(1)
            returned_vec = attn_h

        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = returned_vec.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        # Check output sizes
        batch_, src_len_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(src_len, src_len_)

        return returned_vec, align_vectors
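
A toy sketch (numbers are illustrative) of the rescaling step in the snippet above: the word-level attention is multiplied by a sentence-level weight per source position and renormalized so the combined distribution still sums to one.

import torch

word_attn = torch.tensor([[0.5, 0.3, 0.2]])   # (batch=1, src_len=3), already softmaxed
sent_attn = torch.tensor([[0.1, 0.1, 0.8]])   # sentence-level weight per source position
rescaled = word_attn * sent_attn
rescaled = rescaled / rescaled.sum(dim=-1, keepdim=True)
print(rescaled)   # approximately tensor([[0.2083, 0.1250, 0.6667]]) -- sums to 1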