Code example #1
    def forward(self, decoder_state, source_hids, src_lengths):
        assert self.decoder_hidden_state_dim == self.context_dim
        max_src_len = source_hids.size()[0]
        assert max_src_len == src_lengths.data.max()
        batch_size = source_hids.size()[1]

        # padding mask reshaped to srclen x bsz x 1: 1 for real tokens, 0 for pad
        src_mask = (attention_utils.create_src_lengths_mask(
            batch_size, src_lengths).type_as(source_hids).t().unsqueeze(2))

        if self.pool_type == "mean":
            # need to make src_lengths a 3-D tensor to normalize masked_hiddens
            denom = src_lengths.view(1, batch_size, 1).type_as(source_hids)
            masked_hiddens = source_hids * src_mask
            context = (masked_hiddens / denom).sum(dim=0)
        elif self.pool_type == "max":
            masked_hiddens = source_hids - 10e6 * (1 - src_mask)
            context = masked_hiddens.max(dim=0)[0]
        else:
            raise ValueError(
                f"Pooling type {self.pool_type} is not supported.")
        # pooling has no per-position weights, so return uniform (all-ones)
        # attention scores of shape srclen x bsz
        attn_scores = Variable(
            torch.ones(src_mask.shape[1],
                       src_mask.shape[0]).type_as(source_hids.data),
            requires_grad=False,
        ).t()

        return context, attn_scores
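All five examples on this page rely on attention_utils.create_src_lengths_mask to build a padding mask of shape [batch_size, max_src_len], with 1 at real token positions and 0 at padding. The helper itself is not shown in this listing; the following is a minimal sketch of the assumed behavior, inferred from how the snippets use its output:

import torch

def create_src_lengths_mask(batch_size, src_lengths):
    # Sketch of the assumed helper: int mask of shape [batch_size, max_src_len],
    # 1 where the position index is < src_lengths[i], 0 where it is padding.
    max_src_len = int(src_lengths.max())
    src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths)
    src_indices = src_indices.expand(batch_size, max_src_len)
    src_lengths = src_lengths.unsqueeze(1).expand(batch_size, max_src_len)
    return (src_indices < src_lengths).int().detach()

# create_src_lengths_mask(2, torch.tensor([3, 1]))
# -> tensor([[1, 1, 1],
#            [1, 0, 0]], dtype=torch.int32)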
Code example #2
    def forward(self, decoder_state, source_hids, src_lengths):
        # decoder_state: bsz x context_dim
        if self.input_proj is not None:
            decoder_state = self.input_proj(decoder_state)
        # compute attention: elementwise product summed over the hidden dim
        # gives srclen x bsz dot-product scores, transposed to bsz x srclen
        attn_scores = (source_hids * decoder_state.unsqueeze(0)).sum(dim=2).t()

        if self.src_length_masking:
            max_src_len = source_hids.size()[0]
            assert max_src_len == src_lengths.data.max()
            batch_size = source_hids.size()[1]
            src_mask = attention_utils.create_src_lengths_mask(
                batch_size,
                src_lengths,
            )
            masked_attn_scores = attn_scores.masked_fill(
                src_mask == 0, -np.inf)
            # Since inputs have varying lengths, make sure the attn_scores
            # for each sentence sum up to one
            attn_scores = F.softmax(masked_attn_scores, dim=-1)  # bsz x srclen
            score_denom = torch.sum(attn_scores,
                                    dim=1).unsqueeze(dim=1).expand(
                                        batch_size, max_src_len)
            normalized_masked_attn_scores = torch.div(attn_scores,
                                                      score_denom).t()
        else:
            normalized_masked_attn_scores = F.softmax(attn_scores, dim=-1).t()

        # sum weighted sources
        attn_weighted_context = (
            source_hids *
            normalized_masked_attn_scores.unsqueeze(2)).sum(dim=0)

        return attn_weighted_context, normalized_masked_attn_scores
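In code example #2, masked_fill(src_mask == 0, -np.inf) followed by F.softmax is what prevents padded positions from receiving attention weight: a logit of -inf becomes exactly zero after the softmax. A small standalone demonstration with made-up numbers (not part of the module):

import numpy as np
import torch
import torch.nn.functional as F

# batch of 2 sentences, max source length 3; sentence 0 has true length 2,
# so its last position is padding
attn_scores = torch.tensor([[0.5, 1.0, 0.2],
                            [0.3, 0.7, 0.9]])
src_mask = torch.tensor([[1, 1, 0],
                         [1, 1, 1]])
masked = attn_scores.masked_fill(src_mask == 0, -np.inf)
probs = F.softmax(masked, dim=-1)
print(probs)             # the padded position of sentence 0 gets weight 0.0
print(probs.sum(dim=1))  # each row already sums to 1, so the explicit
                         # renormalization in the snippet is a no-op here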
Code example #3
    def forward(self, decoder_state, source_hids, src_lengths, squeeze=True):
        """
        Computes MultiheadAttention with respect to either a vector
        or a tensor

        Inputs:
            decoder_state: (bsz x decoder_hidden_state_dim) or
                (bsz x T x decoder_hidden_state_dim)
            source_hids: srclen x bsz x context_dim
            src_lengths: bsz x 1, actual sequence lengths
            squeeze: Whether or not to squeeze on the time dimension.
                Even if decoder_state.dim() is 2 dimensional an
                explicit time step dimension will be unsqueezed.
        Outputs:
          [batch_size, max_src_len] if decoder_state.dim() == 2 & squeeze
            or
          [batch_size, 1, max_src_len] if decoder_state.dim() == 2 & !squeeze
            or
          [batch_size, T, max_src_len] if decoder_state.dim() == 3 & !squeeze
            or
          [batch_size, T, max_src_len] if decoder_state.dim() == 3 & squeeze & T != 1
            or
          [batch_size, max_src_len] if decoder_state.dim() == 3 & squeeze & T == 1
        """
        batch_size = decoder_state.shape[0]
        if decoder_state.dim() == 3:
            query = decoder_state
        elif decoder_state.dim() == 2:
            query = decoder_state.unsqueeze(1)
        else:
            raise ValueError("decoder state must be either 2 or 3 dimensional")
        query = query.transpose(0, 1)
        value = key = source_hids

        src_len_mask = None
        if src_lengths is not None and self.use_src_length_mask:
            # [batch_size, 1, seq_len]
            src_len_mask_int = attention_utils.create_src_lengths_mask(
                batch_size=batch_size, src_lengths=src_lengths)
            src_len_mask = src_len_mask_int != 1

        attn, attn_weights = self._fair_attn.forward(
            query,
            key,
            value,
            key_padding_mask=src_len_mask,
            need_weights=True)
        # attn.shape = T X bsz X embed_dim
        # attn_weights.shape = bsz X T X src_len

        attn_weights = attn_weights.transpose(0, 2)
        # attn_weights.shape = src_len X T X bsz

        if squeeze:
            attn = attn.squeeze(0)
            # attn.shape = squeeze(T) X bsz X embed_dim
            attn_weights = attn_weights.squeeze(1)
            # attn_weights.shape = src_len X squeeze(T) X bsz
            return attn, attn_weights
        return attn, attn_weights
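Code example #3 delegates the actual attention computation to self._fair_attn, whose construction is not shown in this listing. The calling convention it assumes (time-major query/key/value, a boolean key_padding_mask that is True at padded positions, and averaged attention weights of shape bsz x T x src_len) matches torch.nn.MultiheadAttention, which can be used as a stand-in to check the shapes. This is only an illustration of the interface, not the module's actual implementation:

import torch
import torch.nn as nn

srclen, T, bsz, dim = 7, 1, 4, 16
mha = nn.MultiheadAttention(embed_dim=dim, num_heads=2)

query = torch.randn(T, bsz, dim)             # decoder_state after unsqueeze/transpose
key = value = torch.randn(srclen, bsz, dim)  # source_hids
# True marks padding, mirroring src_len_mask = src_len_mask_int != 1
key_padding_mask = torch.zeros(bsz, srclen, dtype=torch.bool)
key_padding_mask[:, 5:] = True               # pretend the last two positions are pad

attn, attn_weights = mha(query, key, value,
                         key_padding_mask=key_padding_mask,
                         need_weights=True)
print(attn.shape)          # torch.Size([1, 4, 16])  i.e. T x bsz x embed_dim
print(attn_weights.shape)  # torch.Size([4, 1, 7])   i.e. bsz x T x src_len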
Code example #4
def apply_masks(scores, batch_size, unseen_mask, src_lengths):
    seq_len = scores.shape[-1]

    # [1, seq_len, seq_len]
    sequence_mask = torch.ones(seq_len, seq_len).unsqueeze(0).int()
    if unseen_mask:
        # [1, seq_len, seq_len]: lower-triangular mask so position i cannot
        # attend to positions after i
        sequence_mask = (
            torch.tril(torch.ones(seq_len, seq_len),
                       diagonal=0).unsqueeze(0).int())

    if src_lengths is not None:
        # [batch_size, 1, seq_len]
        src_lengths_mask = attention_utils.create_src_lengths_mask(
            batch_size=batch_size, src_lengths=src_lengths
        ).unsqueeze(-2)

        # [batch_size, seq_len, seq_len]
        sequence_mask = sequence_mask & src_lengths_mask

    # [batch_size, 1, seq_len, seq_len], broadcast over the heads dimension
    sequence_mask = sequence_mask.unsqueeze(1)

    scores = scores.masked_fill(sequence_mask == 0, -np.inf)
    return scores
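A tiny numeric sketch of the mask combination in apply_masks, with made-up values (seq_len = 4, a single sentence of true length 3):

import torch

seq_len = 4
# lower-triangular causal mask: position i may attend only to positions <= i
causal = torch.tril(torch.ones(seq_len, seq_len),
                    diagonal=0).unsqueeze(0).int()                       # [1, 4, 4]
# length mask for one sentence of true length 3 (last position is padding)
length = torch.tensor([[1, 1, 1, 0]], dtype=torch.int32).unsqueeze(-2)   # [1, 1, 4]
combined = causal & length                                               # [1, 4, 4]
print(combined[0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0]], dtype=torch.int32)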
Code example #5
    def forward(self, decoder_state, source_hids, src_lengths):
        batch_size = decoder_state.shape[0]
        query = decoder_state.unsqueeze(1).transpose(0, 1)
        value = key = source_hids

        src_len_mask = None
        if src_lengths is not None and self.use_src_length_mask:
            # [batch_size, 1, seq_len]
            src_len_mask_int = attention_utils.create_src_lengths_mask(
                batch_size=batch_size, src_lengths=src_lengths)
            src_len_mask = src_len_mask_int != 1

        attn, attn_weights = self._fair_attn.forward(
            query,
            key,
            value,
            key_padding_mask=src_len_mask,
            need_weights=True)

        # attn.shape = tgt_len X bsz X embed_dim
        # attn_weights comes back as bsz X tgt_len X src_len; the transpose in
        # the return statement below reorders it to src_len X tgt_len X bsz
        return attn.squeeze(0), attn_weights.transpose(0, 2).squeeze(1)
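For reference, the final squeeze(0) / transpose(0, 2).squeeze(1) in code example #5 collapses the single decoder time step and reorders the weights. A quick shape check with dummy tensors (assumed bsz = 3, srclen = 6, embed_dim = 8):

import torch

attn = torch.randn(1, 3, 8)          # tgt_len(=1) x bsz x embed_dim
attn_weights = torch.randn(3, 1, 6)  # bsz x tgt_len x src_len from the attention call
print(attn.squeeze(0).shape)                          # torch.Size([3, 8]) -> bsz x embed_dim
print(attn_weights.transpose(0, 2).squeeze(1).shape)  # torch.Size([6, 3]) -> src_len x bsz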