Example #1
    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())

        in_proj_weight = params.get('in_proj_weight', None)
        in_proj_bias = params.get('in_proj_bias', None)
        bias_k = params.get('bias_k', None)
        bias_v = params.get('bias_v', None)

        if not self._qkv_same_embed_dim:
            return F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                in_proj_weight, in_proj_bias,
                bias_k, bias_v, self.add_zero_attn,
                self.dropout, params['out_proj.weight'], params['out_proj.bias'],
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=params['q_proj_weight'],
                k_proj_weight=params['k_proj_weight'],
                v_proj_weight=params['v_proj_weight'])
        else:
            return F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                in_proj_weight, in_proj_bias,
                bias_k, bias_v, self.add_zero_attn,
                self.dropout, params['out_proj.weight'], params['out_proj.bias'],
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
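
A minimal, hypothetical sketch of the functional call the forward above wraps. Every projection weight is passed explicitly (random tensors here), which is what makes the params-override pattern possible: a meta-learner can substitute "fast" copies of the weights without touching the module's registered parameters.

import torch
import torch.nn.functional as F

L, N, E, heads = 4, 2, 8, 2                 # target length, batch, embed dim, heads
x = torch.randn(L, N, E)

in_proj_weight = torch.randn(3 * E, E)      # stacked q/k/v projection
in_proj_bias = torch.zeros(3 * E)
out_proj_weight = torch.randn(E, E)
out_proj_bias = torch.zeros(E)

out, weights = F.multi_head_attention_forward(
    x, x, x, E, heads,
    in_proj_weight, in_proj_bias,
    None, None, False,                      # bias_k, bias_v, add_zero_attn
    0.0,                                    # dropout_p
    out_proj_weight, out_proj_bias,
    training=False)
print(out.shape, weights.shape)             # (L, N, E) and (N, L, L)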
Example #2
    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
        if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
            return F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            if not hasattr(self, '_qkv_same_embed_dim'):
                warnings.warn('A new version of MultiheadAttention module has been implemented. \
                    Please re-train your model with the new module',
                              UserWarning)

            return F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
Example #3
 def forward(self,
             query,
             key,
             value,
             key_padding_mask=None,
             need_weights=True,
             attn_mask=None):
     if not self._qkv_same_embed_dim:
         return F.multi_head_attention_forward(
             query,
             key,
             value,
             self.embed_dim,
             self.num_heads,
             self.in_proj_weight,
             self.in_proj_bias,
             self.bias_k,
             self.bias_v,
             self.add_zero_attn,
             self.dropout,
             self.out_proj.weight,
             self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
             attn_mask=attn_mask,
             use_separate_proj_weight=True,
             q_proj_weight=self.q_proj_weight,
             k_proj_weight=self.k_proj_weight,
             v_proj_weight=self.v_proj_weight)
     else:
         return F.multi_head_attention_forward(
             query,
             key,
             value,
             self.embed_dim,
             self.num_heads,
             self.in_proj_weight,
             self.in_proj_bias,
             self.bias_k,
             self.bias_v,
             self.add_zero_attn,
             self.dropout,
             self.out_proj.weight,
             self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
             attn_mask=attn_mask)
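
The use_separate_proj_weight=True branch above is what supports keys and values whose embedding size differs from the query's. A small sketch under assumed sizes and random weights:

import torch
import torch.nn.functional as F

L, S, N, E, kdim, vdim, heads = 4, 6, 2, 8, 5, 7, 2
q = torch.randn(L, N, E)
k = torch.randn(S, N, kdim)
v = torch.randn(S, N, vdim)

q_w = torch.randn(E, E)                     # q_proj_weight: (E, E)
k_w = torch.randn(E, kdim)                  # k_proj_weight: (E, kdim)
v_w = torch.randn(E, vdim)                  # v_proj_weight: (E, vdim)
in_proj_bias = torch.zeros(3 * E)
out_w, out_b = torch.randn(E, E), torch.zeros(E)

out, _ = F.multi_head_attention_forward(
    q, k, v, E, heads,
    None, in_proj_bias,                     # in_proj_weight is unused on this path
    None, None, False, 0.0,
    out_w, out_b,
    training=False,
    use_separate_proj_weight=True,
    q_proj_weight=q_w, k_proj_weight=k_w, v_proj_weight=v_w)
print(out.shape)                            # (L, N, E)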
Example #4
File: impl.py Project: yt752/aps
 def torch_forward(self,
                   query: th.Tensor,
                   key: th.Tensor,
                   value: th.Tensor,
                   key_padding_mask: Optional[th.Tensor] = None,
                   attn_mask: Optional[th.Tensor] = None) -> MHSAReturnType:
     """
     Args:
         query (Tensor): L x N x E
         key (Tensor): S x N x E
         value (Tensor): S x N x E
         key_padding_mask (Tensor): N x S
         attn_mask (Tensor): L x S, additional mask
     Return:
         context (Tensor): L x N x E
         weight (Tensor): N x L x S
     """
     return tf.multi_head_attention_forward(
         query,
         key,
         value,
         self.embed_dim,
         self.num_heads,
         self.in_proj_weight,
         self.in_proj_bias,
         None,
         None,
         False,
         self.dropout.p,
         self.out_proj.weight,
         self.out_proj.bias,
         training=self.training,
         key_padding_mask=key_padding_mask,
         need_weights=True,
         attn_mask=attn_mask)
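
The docstring above fixes key_padding_mask at shape N x S. As an illustration (the helper below is ours, not part of the project), such a boolean mask can be built from per-utterance lengths, with True marking padded key positions to ignore:

import torch

def key_padding_mask_from_lengths(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    steps = torch.arange(max_len, device=lengths.device)[None, :]   # (1, S)
    return steps >= lengths[:, None]                                # (N, S), True = pad

print(key_padding_mask_from_lengths(torch.tensor([3, 5]), max_len=5))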
Example #5
    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1],
                      x.shape[2] * x.shape[3]).permute(2, 0,
                                                       1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)

        return x[0]
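
For context, one plausible constructor consistent with the attributes this forward uses (names taken from the forward; hyper-parameters and initialisation are assumptions). It is the familiar CLIP-style attention pooling: a mean token is prepended, a positional embedding is added, and the pooled output is read off the first position.

import torch
import torch.nn as nn

class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # one position per spatial location plus one for the prepended mean token
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads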
Example #6
    def forward(self,
                inputs: pt.Tensor,
                previous_states: Optional[pt.Tensor] = None,
                mask: Optional[pt.Tensor] = None,
                **args) -> Tuple[pt.Tensor, pt.Tensor]:  # type: ignore
        """
        Computes multi-head attention on a set of inputs, serving as queries, keys, and values.
        If sequence lengths are provided, they will be used to mask the attention scores.
        A bias mask may also be used to mask the attention scores.
        May also use a cache of previously computed inputs.
        Returns a tensor of shape (max_length, batch, output_depth).

        :param inputs: Input Data. Shape: (length, batch, input_depth).
        :param previous_states: Optional list with two tensors - previous input's keys and values.
                                Shape: 2 * (batch, max_length+1, depth_att).
        :param mask: Optional attention mask. See DotAttentionCell for shape information.
        :return: tuple of (contexts with shape (max_length, batch, output_depth), states).
        """
        if self.training:  # use fused multi-head attention op during training
            assert not self.kv_interleaved
            contexts, _ = F.multi_head_attention_forward(
                query=inputs,
                key=inputs,
                value=inputs,
                embed_dim_to_check=self.depth,
                num_heads=self.heads,
                in_proj_weight=self.ff_in.weight,
                in_proj_bias=None,
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=self._drop_p,
                out_proj_weight=self.ff_out.weight,
                out_proj_bias=self.ff_out.bias,
                training=self.training,
                key_padding_mask=None,
                need_weights=False,
                attn_mask=mask,
                use_separate_proj_weight=False,
                q_proj_weight=None,
                k_proj_weight=None,
                v_proj_weight=None)
            return contexts, contexts  # dummy return
        else:  # during inference multi-head attention with interleaved key-value parameters is used
            proj = self.ff_in(inputs)
            queries, states = proj.split((self.depth_att, 2 * self.depth_att),
                                         dim=2)

            if previous_states is not None:
                states = pt.cat((previous_states, states), dim=0)

            return self._attend(queries=queries, key_values=states,
                                mask=mask), states
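
The attn_mask accepted by the fused op is additive for float tensors. A small sketch of a causal mask that could be passed via `mask` above: zeros on allowed positions, -inf strictly above the diagonal.

import torch

L = 5
causal_mask = torch.triu(torch.full((L, L), float("-inf")), diagonal=1)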
Example #7
    def forward(
        self,
        queries: pt.Tensor,
        key_values: pt.Tensor,
        mask: Optional[pt.Tensor] = None,
        projected_memory_kv: Optional[pt.Tensor] = None
    ) -> pt.Tensor:  # mypy: ignore
        """
        Computes multi-head attention for queries given a memory tensor.
        If sequence lengths are provided, they will be used to mask the attention scores.
        A bias mask may also be used to mask the attention scores.
        Returns a tensor of shape (max_length, batch, output_depth).

        :param queries: Query tensor. Shape: (queries_length, batch, input_depth).
        :param key_values: Memory data to attend to. Shape: (key_values_length, batch, input_depth).
        :param mask: Optional attention mask. See DotAttentionCell for shape information.
        :param projected_memory_kv: Optional previously projected memory keys and values.
        :return: tensor of shape (query_seq_len, batch, output_depth).
        """
        if self.training:  # use fused multi-head attention op during training
            assert not self.kv_interleaved
            assert projected_memory_kv is None, "caching not supported in training"
            contexts, _ = F.multi_head_attention_forward(
                query=queries,
                key=key_values,
                value=key_values,
                embed_dim_to_check=self.depth,
                num_heads=self.heads,
                in_proj_weight=None,
                in_proj_bias=None,
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=self._drop_p,
                out_proj_weight=self.ff_out.weight,
                out_proj_bias=self.ff_out.bias,
                training=self.training,
                key_padding_mask=None,
                need_weights=False,
                attn_mask=mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.ff_q.weight,
                k_proj_weight=self.ff_kv.weight[:self.depth, :],
                v_proj_weight=self.ff_kv.weight[self.depth:, :])
            return contexts
        else:  # during inference multi-head attention with interleaved key-value parameters is used
            queries = self.ff_q(queries)
            key_values = projected_memory_kv if projected_memory_kv is not None else self.ff_kv(
                key_values)
            return self._attend(queries=queries,
                                key_values=key_values,
                                mask=mask)
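
An illustration of the weight slicing in the training branch above (depths simplified to be equal): a single Linear produces keys and values, and its (2*depth, depth) weight is split row-wise into the k_proj_weight / v_proj_weight handed to the fused op.

import torch.nn as nn

depth = 8
ff_kv = nn.Linear(depth, 2 * depth, bias=False)
k_proj_weight = ff_kv.weight[:depth, :]     # rows that project keys
v_proj_weight = ff_kv.weight[depth:, :]     # rows that project values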
Example #8
    def forward(self, src):
        # returns (attn_output, attn_output_weights) when self.need_weights, else attn_output only
        r""" No weights / masks, this is for imgs, and forced to do only self attention
    Args:
        src: this is query, key, and value. 
        need_weights: output attn_output_weights.

    Shape:
        - Inputs:
        - src: [BS, D, H, W] (image format); converted to sequence-first [H*W, BS, D] layout for multi-head attention

        - Outputs:
        - attn_output: :math:`(BS, D, H, W)` 
        - attn_output_weights: :math:`(BS, H*W, H*W)` 
        """

        # NLP likes format: [L, BS, D]. Convert our [BS, D, H, W] -> [H*W, BS, D]:
        BS, D, H, W = src.size()
        src = src.view(BS, D, H * W)
        src = src.permute(2, 0, 1).contiguous()

        # this is self attention, so all the same
        query, key, value = src, src, src

        out, attn_weights = F.multi_head_attention_forward(
            query,
            key,
            value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,  # do not append an extra all-zero key/value position
            dropout_p=self.dropout,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            training=self.training,
            key_padding_mask=None,
            need_weights=self.need_weights,
            attn_mask=None)

        # [H*W, BS, D] -> [BS, D, H*W]:
        out = out.permute(1, 2, 0).contiguous()
        out = out.view(BS, D, H, W).contiguous()
        # !!! will have to reshape attn_weights into a usable format
        if self.need_weights: return out, attn_weights
        else: return out
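
A quick check of the layout conversion used above, [BS, D, H, W] -> [H*W, BS, D] for the attention op and back again; the permute/reshape pair is an exact inverse.

import torch

BS, D, H, W = 2, 4, 3, 3
src = torch.randn(BS, D, H, W)
seq = src.view(BS, D, H * W).permute(2, 0, 1)       # [H*W, BS, D]
back = seq.permute(1, 2, 0).reshape(BS, D, H, W)    # [BS, D, H, W]
assert torch.equal(src, back)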
Example #9
    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: mask that prevents attention to certain positions. This is an additive mask
            (i.e. the values will be added to the attention layer).

    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask, use_separate_proj_weight=True,
            q_proj_weight=self.qk_proj_weight, k_proj_weight=self.qk_proj_weight,
            v_proj_weight=self.v_proj_weight)
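
Note that the call above passes the same tensor as both q_proj_weight and k_proj_weight, so queries and keys share one projection. A matching parameter pair (names assumed) could be declared as:

import torch
import torch.nn as nn

embed_dim = 8
qk_proj_weight = nn.Parameter(torch.empty(embed_dim, embed_dim))
v_proj_weight = nn.Parameter(torch.empty(embed_dim, embed_dim))
nn.init.xavier_uniform_(qk_proj_weight)
nn.init.xavier_uniform_(v_proj_weight)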
Example #10
 def forward(self, x: torch.Tensor) -> torch.Tensor:
     x = x.flatten(-2).permute(2, 0, 1)  # NCHW -> (HW)NC
     x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (1+HW)NC
     x, _ = F.multi_head_attention_forward(
         query=x,
         key=x,
         value=x,
         embed_dim_to_check=x.shape[-1],
         num_heads=self.num_heads,
         q_proj_weight=self.q_proj,
         k_proj_weight=self.k_proj,
         v_proj_weight=self.v_proj,
         in_proj_weight=None,
         in_proj_bias=self.bias,
         bias_k=None,
         bias_v=None,
         add_zero_attn=False,
         dropout_p=0,
         out_proj_weight=self.c_proj.weight,
         out_proj_bias=self.c_proj.bias,
         use_separate_proj_weight=True,
         training=self.training,
         need_weights=False)
     return x[0]
Example #11
    def forward(self,
                src_user,
                src_loc,
                src_reg,
                src_time,
                src_square_mask,
                src_binary_mask,
                trg_loc,
                trg_reg,
                mem_mask,
                ds=None):
        loc_emb_src = self.emb_loc(src_loc)
        if self.extra_config.get("user_location_only", False):
            src = loc_emb_src
        else:
            user_emb_src = self.emb_user(src_user)
            # (L, N, LEN_QUADKEY, REG_DIM)
            reg_emb = self.emb_reg(src_reg)
            reg_emb = reg_emb.view(
                reg_emb.size(0) * reg_emb.size(1), reg_emb.size(2),
                reg_emb.size(3)).permute(1, 0, 2)
            # (LEN_QUADKEY, L * N, REG_DIM)

            reg_emb = self.region_pos_encoder(reg_emb)
            reg_emb = self.region_encoder(reg_emb)
            # avg pooling
            reg_emb = torch.mean(reg_emb, dim=0)

            # reg_emb, _ = self.region_gru_encoder(reg_emb, self.h_0.expand(4, reg_emb.size(1), -1).contiguous())
            # reg_emb = reg_emb[-1, :, :]

            # (L, N, REG_DIM)
            reg_emb = reg_emb.view(loc_emb_src.size(0), loc_emb_src.size(1),
                                   reg_emb.size(1))

            time_emb = self.emb_time(src_time)
            if self.extra_config.get("embedding_fusion",
                                     "multiply") == "multiply":
                if self.extra_config.get("user_embedding", False):
                    src = loc_emb_src * reg_emb * time_emb * user_emb_src
                else:
                    src = loc_emb_src * reg_emb * time_emb
            else:
                if self.extra_config.get("user_embedding", False):
                    src = torch.cat(
                        [user_emb_src, loc_emb_src, reg_emb, time_emb], dim=-1)
                else:
                    src = torch.cat([loc_emb_src, reg_emb], dim=-1)

        if self.extra_config.get("size_sqrt_regularize", True):
            src = src * math.sqrt(src.size(-1))

        src = self.pos_encoder(src)
        # shape: [L, N, ninp]
        src = self.encoder(src, mask=src_square_mask)

        # shape: [(1+K)*L, N, loc_dim]
        loc_emb_trg = self.emb_loc(trg_loc)

        reg_emb_trg = self.emb_reg(
            trg_reg)  # [(1+K)*L, N, LEN_QUADKEY, REG_DIM]
        # (LEN_QUADKEY, (1+K)*L * N, REG_DIM)
        reg_emb_trg = reg_emb_trg.view(
            reg_emb_trg.size(0) * reg_emb_trg.size(1), reg_emb_trg.size(2),
            reg_emb_trg.size(3)).permute(1, 0, 2)
        reg_emb_trg = self.region_pos_encoder(reg_emb_trg)
        reg_emb_trg = self.region_encoder(reg_emb_trg)
        reg_emb_trg = torch.mean(reg_emb_trg, dim=0)
        # [(1+K)*L, N, REG_DIM]
        reg_emb_trg = reg_emb_trg.view(loc_emb_trg.size(0),
                                       loc_emb_trg.size(1),
                                       reg_emb_trg.size(1))

        loc_emb_trg = torch.cat([loc_emb_trg, reg_emb_trg], dim=-1)
        if self.extra_config.get("use_attention_as_decoder", False):
            # multi-head attention
            output, _ = F.multi_head_attention_forward(
                query=loc_emb_trg,
                key=src,
                value=src,
                embed_dim_to_check=src.size(2),
                num_heads=1,
                in_proj_weight=None,
                in_proj_bias=None,
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=0.0,
                out_proj_weight=self.ident_mat,
                out_proj_bias=None,
                training=self.training,
                key_padding_mask=src_binary_mask,
                need_weights=False,
                attn_mask=mem_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.ident_mat,
                k_proj_weight=self.ident_mat,
                v_proj_weight=self.ident_mat)

            if self.training:
                src = src.repeat(loc_emb_trg.size(0) // src.size(0), 1, 1)
            else:
                src = src[torch.tensor(ds) - 1, torch.arange(len(ds)), :]
                src = src.unsqueeze(0).repeat(loc_emb_trg.size(0), 1, 1)

            output += src
            output = self.layer_norm(output)
        else:
            # No attention
            if self.training:
                output = src.repeat(loc_emb_trg.size(0) // src.size(0), 1, 1)
            else:
                output = src[torch.tensor(ds) - 1, torch.arange(len(ds)), :]
                output = output.unsqueeze(0).repeat(loc_emb_trg.size(0), 1, 1)

        # shape: [(1+K)*L, N]
        output = torch.sum(output * loc_emb_trg, dim=-1)
        return output
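
A sketch of the identity-projection trick used in the decoder branch above: with eye(E) as every projection weight and a single head, the fused op reduces to plain scaled dot-product attention over the raw inputs (toy sizes, random data).

import torch
import torch.nn.functional as F

L, S, N, E = 3, 5, 2, 8
q = torch.randn(L, N, E)
kv = torch.randn(S, N, E)
ident = torch.eye(E)

out, _ = F.multi_head_attention_forward(
    q, kv, kv, E, 1,
    None, None, None, None, False, 0.0,     # no in_proj, biases, zero-attn, or dropout
    ident, None,                            # identity out_proj, no bias
    training=False, need_weights=False,
    use_separate_proj_weight=True,
    q_proj_weight=ident, k_proj_weight=ident, v_proj_weight=ident)
print(out.shape)                            # (L, N, E)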
Example #12
    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                need_weights=True,
                attn_mask=None):
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: mask that prevents attention to certain positions. This is an additive mask
            (i.e. the values will be added to the attention layer).
    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        if hasattr(
                self,
                '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            if not hasattr(self, '_qkv_same_embed_dim'):
                warnings.warn(
                    'A new version of MultiheadAttention module has been implemented. \
                    Please re-train your model with the new module',
                    UserWarning)

            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask)
Example #13
    tsal.out_proj.weight = torch.nn.Parameter(torch.eye(Fv))

    X = torch.rand(128, 5, 512)
    Xt = X.transpose(0, 1)

    # res = F.multi_head_attention_forward()
    res1 = tsal(Xt, Xt, Xt)[0].transpose(0, 1)
    res2 = sal(X)

    resf = F.multi_head_attention_forward(
        Xt,
        Xt,
        Xt,
        Fin,
        nheads,
        sal.proj.weight,
        None,
        None,
        None,
        False,
        0.,
        torch.eye(512),
        None)[0].transpose(0, 1)

    # sparse reduction ops testing

    x = torch.rand(4, 10)
    xd = x.view(2, 2, 10)
    batch = torch.LongTensor([0, 0, 1, 1])

    # sparse self-attention layer testing
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        if (self.enable_torch_version and incremental_state is None
                and not static_kv):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0),
                                                   1),
                    ],
                    dim=1,
                )

        q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
        if k is not None:
            k = (k.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))
        if v is not None:
            v = (v.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1,
                                          self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = RelativeMultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert k is not None
        src_len = k.size(1)
        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0),
                                    1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(
            1, 2))  # [bsz * num_heads, tgt_len, src_len]

        # relative position
        if self.maximum_relative_position and self.self_attention:
            keys_length = k.size(1)
            relative_pos = self.relative_positions(
                keys_length, self.maximum_relative_position).to(k.device)
            relative_repr_keys = self.relative_position_keys(relative_pos)
            relative_repr_values = self.relative_position_values(relative_pos)
            attn_weights += self.matmul_with_relative_representations(
                q, relative_repr_keys)
        else:
            relative_repr_keys = relative_repr_values = None
        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"))
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )
        assert v is not None
        attn = torch.bmm(attn_probs, v)

        # relative position
        if relative_repr_values is not None:
            attn += self.matmul_with_relative_representations(
                attn_probs, relative_repr_values, transpose=False)

        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
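
The relative-position machinery above assumes a helper that builds a matrix of clipped relative distances (Shaw et al., 2018). A sketch of what `self.relative_positions(...)` is presumed to return: entry (i, j) is clamp(j - i, -max, max), shifted to a non-negative index into the relative embedding tables.

import torch

def relative_positions(length: int, maximum: int) -> torch.Tensor:
    rng = torch.arange(length)
    dist = rng[None, :] - rng[:, None]              # j - i
    return dist.clamp(-maximum, maximum) + maximum  # indices in [0, 2*maximum]

print(relative_positions(4, 2))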
    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.

    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero
          positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return my_multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
            # return self.multi_head_attention_forward(
            #     query, key, value, self.embed_dim, self.num_heads,
            #     self.in_proj_weight, self.in_proj_bias,
            #     self.bias_k, self.bias_v, self.add_zero_attn,
            #     self.dropout, self.out_proj.weight, self.out_proj.bias,
            #     training=self.training,
            #     key_padding_mask=key_padding_mask, need_weights=need_weights,
            #     attn_mask=attn_mask)

    # def multi_head_attention_forward(self,
    #         query: Tensor,
    #         key: Tensor,
    #         value: Tensor,
    #         embed_dim_to_check: int,
    #         num_heads: int,
    #         in_proj_weight: Tensor,
    #         in_proj_bias: Tensor,
    #         bias_k: Optional[Tensor],
    #         bias_v: Optional[Tensor],
    #         add_zero_attn: bool,
    #         dropout_p: float,
    #         out_proj_weight: Tensor,
    #         out_proj_bias: Tensor,
    #         training: bool = True,
    #         key_padding_mask: Optional[Tensor] = None,
    #         need_weights: bool = True,
    #         attn_mask: Optional[Tensor] = None,
    #         use_separate_proj_weight: bool = False,
    #         q_proj_weight: Optional[Tensor] = None,
    #         k_proj_weight: Optional[Tensor] = None,
    #         v_proj_weight: Optional[Tensor] = None,
    #         static_k: Optional[Tensor] = None,
    #         static_v: Optional[Tensor] = None,
    # ) -> Tuple[Tensor, Optional[Tensor]]:
    #     r"""
    #     Args:
    #         query, key, value: map a query and a set of key-value pairs to an output.
    #             See "Attention Is All You Need" for more details.
    #         embed_dim_to_check: total dimension of the model.
    #         num_heads: parallel attention heads.
    #         in_proj_weight, in_proj_bias: input projection weight and bias.
    #         bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
    #         add_zero_attn: add a new batch of zeros to the key and
    #                        value sequences at dim=1.
    #         dropout_p: probability of an element to be zeroed.
    #         out_proj_weight, out_proj_bias: the output projection weight and bias.
    #         training: apply dropout if is ``True``.
    #         key_padding_mask: if provided, specified padding elements in the key will
    #             be ignored by the attention. This is an binary mask. When the value is True,
    #             the corresponding value on the attention layer will be filled with -inf.
    #         need_weights: output attn_output_weights.
    #         attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
    #             the batches while a 3D mask allows to specify a different mask for the entries of each batch.
    #         use_separate_proj_weight: the function accept the proj. weights for query, key,
    #             and value in different forms. If false, in_proj_weight will be used, which is
    #             a combination of q_proj_weight, k_proj_weight, v_proj_weight.
    #         q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
    #         static_k, static_v: static key and value used for attention operators.
    #     Shape:
    #         Inputs:
    #         - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
    #           the embedding dimension.
    #         - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
    #           the embedding dimension.
    #         - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
    #           the embedding dimension.
    #         - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
    #           If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
    #           will be unchanged. If a BoolTensor is provided, the positions with the
    #           value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
    #         - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
    #           3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
    #           S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
    #           positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
    #           while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
    #           are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
    #           is provided, it will be added to the attention weight.
    #         - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
    #           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
    #         - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
    #           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
    #         Outputs:
    #         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
    #           E is the embedding dimension.
    #         - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
    #           L is the target sequence length, S is the source sequence length.
    #     """
    #     tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    #     if has_torch_function(tens_ops):
    #         return handle_torch_function(
    #             multi_head_attention_forward,
    #             tens_ops,
    #             query,
    #             key,
    #             value,
    #             embed_dim_to_check,
    #             num_heads,
    #             in_proj_weight,
    #             in_proj_bias,
    #             bias_k,
    #             bias_v,
    #             add_zero_attn,
    #             dropout_p,
    #             out_proj_weight,
    #             out_proj_bias,
    #             training=training,
    #             key_padding_mask=key_padding_mask,
    #             need_weights=need_weights,
    #             attn_mask=attn_mask,
    #             use_separate_proj_weight=use_separate_proj_weight,
    #             q_proj_weight=q_proj_weight,
    #             k_proj_weight=k_proj_weight,
    #             v_proj_weight=v_proj_weight,
    #             static_k=static_k,
    #             static_v=static_v,
    #         )
    #     tgt_len, bsz, embed_dim = query.size()
    #     assert embed_dim == embed_dim_to_check
    #     # allow MHA to have different sizes for the feature dimension
    #     assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
    #
    #     if isinstance(embed_dim, torch.Tensor):
    #         # embed_dim can be a tensor when JIT tracing
    #         head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    #     else:
    #         head_dim = embed_dim // num_heads
    #     assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    #     scaling = float(head_dim) ** -0.5
    #
    #     if not use_separate_proj_weight:
    #         if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
    #             # self-attention
    #             q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
    #
    #         elif key is value or torch.equal(key, value):
    #             # encoder-decoder attention
    #             # This is inline in_proj function with in_proj_weight and in_proj_bias
    #             _b = in_proj_bias
    #             _start = 0
    #             _end = embed_dim
    #             _w = in_proj_weight[_start:_end, :]
    #             if _b is not None:
    #                 _b = _b[_start:_end]
    #             q = linear(query, _w, _b)
    #
    #             if key is None:
    #                 assert value is None
    #                 k = None
    #                 v = None
    #             else:
    #
    #                 # This is inline in_proj function with in_proj_weight and in_proj_bias
    #                 _b = in_proj_bias
    #                 _start = embed_dim
    #                 _end = None
    #                 _w = in_proj_weight[_start:, :]
    #                 if _b is not None:
    #                     _b = _b[_start:]
    #                 k, v = linear(key, _w, _b).chunk(2, dim=-1)
    #
    #         else:
    #             # This is inline in_proj function with in_proj_weight and in_proj_bias
    #             _b = in_proj_bias
    #             _start = 0
    #             _end = embed_dim
    #             _w = in_proj_weight[_start:_end, :]
    #             if _b is not None:
    #                 _b = _b[_start:_end]
    #             q = linear(query, _w, _b)
    #
    #             # This is inline in_proj function with in_proj_weight and in_proj_bias
    #             _b = in_proj_bias
    #             _start = embed_dim
    #             _end = embed_dim * 2
    #             _w = in_proj_weight[_start:_end, :]
    #             if _b is not None:
    #                 _b = _b[_start:_end]
    #             k = linear(key, _w, _b)
    #
    #             # This is inline in_proj function with in_proj_weight and in_proj_bias
    #             _b = in_proj_bias
    #             _start = embed_dim * 2
    #             _end = None
    #             _w = in_proj_weight[_start:, :]
    #             if _b is not None:
    #                 _b = _b[_start:]
    #             v = linear(value, _w, _b)
    #     else:
    #         q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
    #         len1, len2 = q_proj_weight_non_opt.size()
    #         assert len1 == embed_dim and len2 == query.size(-1)
    #
    #         k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
    #         len1, len2 = k_proj_weight_non_opt.size()
    #         assert len1 == embed_dim and len2 == key.size(-1)
    #
    #         v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
    #         len1, len2 = v_proj_weight_non_opt.size()
    #         assert len1 == embed_dim and len2 == value.size(-1)
    #
    #         if in_proj_bias is not None:
    #             q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
    #             k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim: (embed_dim * 2)])
    #             v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
    #         else:
    #             q = linear(query, q_proj_weight_non_opt, in_proj_bias)
    #             k = linear(key, k_proj_weight_non_opt, in_proj_bias)
    #             v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    #     q = q * scaling
    #
    #     if attn_mask is not None:
    #         assert (
    #                 attn_mask.dtype == torch.float32
    #                 or attn_mask.dtype == torch.float64
    #                 or attn_mask.dtype == torch.float16
    #                 or attn_mask.dtype == torch.uint8
    #                 or attn_mask.dtype == torch.bool
    #         ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
    #         if attn_mask.dtype == torch.uint8:
    #             warnings.warn(
    #                 "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
    #             attn_mask = attn_mask.to(torch.bool)
    #
    #         if attn_mask.dim() == 2:
    #             attn_mask = attn_mask.unsqueeze(0)
    #             if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
    #                 raise RuntimeError("The size of the 2D attn_mask is not correct.")
    #         elif attn_mask.dim() == 3:
    #             if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
    #                 raise RuntimeError("The size of the 3D attn_mask is not correct.")
    #         else:
    #             raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
    #         # attn_mask's dim is 3 now.
    #
    #     # convert ByteTensor key_padding_mask to bool
    #     if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
    #         warnings.warn(
    #             "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
    #         )
    #         key_padding_mask = key_padding_mask.to(torch.bool)
    #
    #     if bias_k is not None and bias_v is not None:
    #         if static_k is None and static_v is None:
    #             k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
    #             v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
    #             if attn_mask is not None:
    #                 attn_mask = pad(attn_mask, (0, 1))
    #             if key_padding_mask is not None:
    #                 key_padding_mask = pad(key_padding_mask, (0, 1))
    #         else:
    #             assert static_k is None, "bias cannot be added to static key."
    #             assert static_v is None, "bias cannot be added to static value."
    #     else:
    #         assert bias_k is None
    #         assert bias_v is None
    #
    #     q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    #     if k is not None:
    #         k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    #     if v is not None:
    #         v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    #
    #     if static_k is not None:
    #         assert static_k.size(0) == bsz * num_heads
    #         assert static_k.size(2) == head_dim
    #         k = static_k
    #
    #     if static_v is not None:
    #         assert static_v.size(0) == bsz * num_heads
    #         assert static_v.size(2) == head_dim
    #         v = static_v
    #
    #     src_len = k.size(1)
    #
    #     if key_padding_mask is not None:
    #         assert key_padding_mask.size(0) == bsz
    #         assert key_padding_mask.size(1) == src_len
    #
    #     if add_zero_attn:
    #         src_len += 1
    #         k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
    #         v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
    #         if attn_mask is not None:
    #             attn_mask = pad(attn_mask, (0, 1))
    #         if key_padding_mask is not None:
    #             key_padding_mask = pad(key_padding_mask, (0, 1))
    #
    #     attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    #     assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
    #
    #     if attn_mask is not None:
    #         if attn_mask.dtype == torch.bool:
    #             attn_output_weights.masked_fill_(attn_mask, float("-inf"))
    #         else:
    #             attn_output_weights += attn_mask
    #
    #     if key_padding_mask is not None:
    #         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    #         attn_output_weights = attn_output_weights.masked_fill(
    #             key_padding_mask.unsqueeze(1).unsqueeze(2),
    #             float("-inf"),
    #         )
    #         attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
    #
    #     attn_output_weights = softmax(attn_output_weights, dim=-1)
    #     attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
    #
    #     attn_output = torch.bmm(attn_output_weights, v)
    #     assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    #     attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    #     attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
    #
    #     if need_weights:
    #         # average attention weights over heads
    #         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    #         return attn_output, attn_output_weights.sum(dim=1) / num_heads
    #     else:
    #         return attn_output, None
    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                need_weights=True,
                attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.

    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the
          zero positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask)
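# Usage sketch added for clarity (not from the original source): exercises the
# forward above with the documented shapes -- query (L, N, E), key/value (S, N, E),
# key_padding_mask (N, S) bool, attn_mask (L, S) float added to the attention
# logits. All sizes below are illustrative assumptions.
import torch
import torch.nn as nn

L, S, N, E, H = 4, 6, 2, 16, 4
mha = nn.MultiheadAttention(embed_dim=E, num_heads=H, dropout=0.0)

query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)

# True marks padded key positions that the attention must ignore.
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
key_padding_mask[:, -1] = True

# Additive float mask, broadcast over the batch; -inf blocks a position.
attn_mask = torch.zeros(L, S)
attn_mask[:, 0] = float("-inf")

attn_output, attn_output_weights = mha(query, key, value,
                                       key_padding_mask=key_padding_mask,
                                       need_weights=True,
                                       attn_mask=attn_mask)
assert attn_output.shape == (L, N, E)
assert attn_output_weights.shape == (N, L, S)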
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        ngrams: Optional[int] = None,
        is_translate: Optional[bool] = False,
        is_cascade: Optional[bool] = False,
        offset: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        #if attn_mask is not None:
        #    import pdb; pdb.set_trace()
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (self.enable_torch_version and not self.onnx_trace
                and incremental_state is None and not static_kv):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0),
                                                   1),
                    ],
                    dim=1,
                )

        q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
        flag_cascade = False
        if k is not None:
            kbsz = k.size(1)
            if kbsz != bsz:
                flag_cascade = True
                assert tgt_len == 1, tgt_len
                assert bsz % kbsz == 0, (kbsz, bsz)
                k = (
                    k.contiguous().view(-1, kbsz * self.num_heads,
                                        self.head_dim).transpose(
                                            0, 1)  # kbsz*num_heads, l, H
                )
            else:
                k = (k.contiguous().view(-1, bsz * self.num_heads,
                                         self.head_dim).transpose(0, 1))
        if v is not None:
            if flag_cascade:
                assert v.size(1) == kbsz, (v.size(), kbsz, bsz)
                v = (
                    v.contiguous().view(-1, kbsz * self.num_heads,
                                        self.head_dim).transpose(
                                            0, 1)  # kbsz*num_heads, l, H
                )
            else:
                v = (v.contiguous().view(-1, bsz * self.num_heads,
                                         self.head_dim).transpose(0, 1))

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                kbsz = _prev_key.size(0)
                if kbsz != bsz:
                    flag_cascade = True
                if flag_cascade:
                    prev_key = _prev_key.view(kbsz * self.num_heads, -1,
                                              self.head_dim)
                else:
                    #if bsz * self.num_heads == 2560 and not self.self_attention:
                    #    import pdb; pdb.set_trace()
                    prev_key = _prev_key.view(bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                if flag_cascade:
                    prev_value = _prev_value.view(kbsz * self.num_heads, -1,
                                                  self.head_dim)
                else:
                    prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                                  self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            if flag_cascade:
                key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                    key_padding_mask=key_padding_mask,
                    prev_key_padding_mask=prev_key_padding_mask,
                    batch_size=kbsz,
                    src_len=k.size(1),
                    static_kv=static_kv,
                )
            else:
                key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                    key_padding_mask=key_padding_mask,
                    prev_key_padding_mask=prev_key_padding_mask,
                    batch_size=bsz,
                    src_len=k.size(1),
                    static_kv=static_kv,
                )

            if flag_cascade:
                saved_state["prev_key"] = k.view(kbsz, self.num_heads, -1,
                                                 self.head_dim)
                saved_state["prev_value"] = v.view(kbsz, self.num_heads, -1,
                                                   self.head_dim)
            else:
                saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
                                                 self.head_dim)
                saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
                                                   self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert k is not None
        src_len = k.size(1)
        if flag_cascade:
            q = q.contiguous().view(
                kbsz, -1, self.num_heads,
                self.head_dim).transpose(1, 2).contiguous().view(
                    kbsz * self.num_heads, -1,
                    self.head_dim)  # kbsz*num_heads, f, H

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            if flag_cascade:
                assert key_padding_mask.size(0) == kbsz
            else:
                assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0),
                                    1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = MultiheadAttention.apply_sparse_mask(
            attn_weights, tgt_len, src_len, bsz)

        if flag_cascade:
            assert list(attn_weights.size()) == [
                kbsz * self.num_heads, bsz // kbsz, src_len
            ]
        else:
            assert list(attn_weights.size()) == [
                bsz * self.num_heads, tgt_len, src_len
            ]

        if attn_mask is not None:
            assert not flag_cascade
            attn_mask = attn_mask.unsqueeze(0)
            if self.self_attention:
                #import pdb; pdb.set_trace()
                if is_cascade:  # TODO: fully batching in batch dimension
                    import pdb
                    pdb.set_trace()
                    NGRAM = ngrams + 1  # validation
                    attn_mask = attn_mask.repeat(bsz, 1, 1)
                    assert attn_mask.size(-1) == attn_mask.size(
                        -2), attn_mask.size()
                    x, y = torch.meshgrid(
                        torch.arange(attn_mask.size(-1)).to(attn_mask.device),
                        torch.arange(attn_mask.size(-1)).to(attn_mask.device))
                    x = x.unsqueeze(0).expand(bsz, -1, -1).contiguous()
                    y = y.unsqueeze(0).expand(bsz, -1, -1).contiguous()

                    x_block_id = (x + NGRAM) // NGRAM
                    #x_id = (x + NGRAM).fmod(NGRAM)
                    y_block_id = (y + NGRAM) // NGRAM
                    #y_id = (y  + NGRAM).fmod(NGRAM)
                    attn_mask[x_block_id.ne(y_block_id)] = -float('inf')
                    #attn_mask[:, :, 0] = 0
                    attn_mask = attn_mask.unsqueeze(1).expand(
                        -1, self.num_heads, -1, -1).contiguous().view(
                            -1, attn_mask.size(-1),
                            attn_mask.size(-1))  # (bsz * num_heads, L, L)
                elif self.training:
                    NGRAM = 5
                    attn_mask = attn_mask.repeat(bsz, 1, 1)
                    assert attn_mask.size(-1) == attn_mask.size(
                        -2), attn_mask.size()
                    x, y = torch.meshgrid(
                        torch.arange(attn_mask.size(-1)).to(attn_mask.device),
                        torch.arange(attn_mask.size(-1)).to(attn_mask.device))
                    x = x.unsqueeze(0).expand(bsz, -1, -1).contiguous()
                    y = y.unsqueeze(0).expand(bsz, -1, -1).contiguous()
                    assert offset is not None, offset

                    x_block_id = (x - offset + NGRAM) // NGRAM
                    #x_id = (x - offset + NGRAM).fmod(NGRAM)
                    y_block_id = (y - offset + NGRAM) // NGRAM
                    #y_id = (y - offset + NGRAM).fmod(NGRAM)
                    attn_mask[x_block_id.ne(y_block_id)] = -float('inf')
                    #attn_mask[:, :, 0] = 0
                    attn_mask = attn_mask.unsqueeze(1).expand(
                        -1, self.num_heads, -1, -1).contiguous().view(
                            -1, attn_mask.size(-1),
                            attn_mask.size(-1))  # (bsz * num_heads, L, L)
                elif not is_translate:
                    NGRAM = ngrams  # validation TODO: use self.ngrams
                    attn_mask = attn_mask.repeat(bsz, 1, 1)
                    assert attn_mask.size(-1) == attn_mask.size(
                        -2), attn_mask.size()
                    x, y = torch.meshgrid(
                        torch.arange(attn_mask.size(-1)).to(attn_mask.device),
                        torch.arange(attn_mask.size(-1)).to(attn_mask.device))
                    x = x.unsqueeze(0).expand(bsz, -1, -1).contiguous()
                    y = y.unsqueeze(0).expand(bsz, -1, -1).contiguous()

                    x_block_id = (x + NGRAM) // NGRAM
                    #x_id = (x + NGRAM).fmod(NGRAM)
                    y_block_id = (y + NGRAM) // NGRAM
                    #y_id = (y  + NGRAM).fmod(NGRAM)
                    attn_mask[x_block_id.ne(y_block_id)] = -float('inf')
                    #attn_mask[:, :, 0] = 0
                    attn_mask = attn_mask.unsqueeze(1).expand(
                        -1, self.num_heads, -1, -1).contiguous().view(
                            -1, attn_mask.size(-1),
                            attn_mask.size(-1))  # (bsz * num_heads, L, L)
                else:
                    NGRAM = ngrams  # translation
                    attn_mask = attn_mask.repeat(bsz, 1, 1)
                    attn_mask.fill_(-float('inf'))
                    i = 1
                    while i <= tgt_len and i <= NGRAM:
                        if i == 1:
                            attn_mask[:, -i, -min(NGRAM, tgt_len):] = 0
                        else:
                            if i <= min(NGRAM, tgt_len):
                                attn_mask[:, -i,
                                          -min(NGRAM, tgt_len):(-(i - 1))] = 0
                        i += 1
                    #attn_mask[:, :, 0] = 0
                    attn_mask = attn_mask.unsqueeze(1).expand(
                        -1, self.num_heads, -1, -1).contiguous().view(
                            -1, attn_mask.size(-1),
                            attn_mask.size(-1))  # (bsz * num_heads, L, L)
                #import pdb; pdb.set_trace()
                #import pdb; pdb.set_trace()

            if self.onnx_trace:
                assert False, 'onnx'
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights[attn_mask.eq(-float('inf'))] = -float(
                'inf')  # += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols key_padding_mask: kbsz, 1, 1, src_len
            if flag_cascade:
                attn_weights = attn_weights.view(kbsz, self.num_heads,
                                                 bsz // kbsz, src_len)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"))
                attn_weights = attn_weights.view(kbsz * self.num_heads,
                                                 bsz // kbsz, src_len)
            else:
                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                                 src_len)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"))
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                                 src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights,
                                           dim=-1,
                                           onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )
        assert v is not None
        v[v.ne(v)] = 0.
        attn = torch.bmm(attn_probs, v)  # cascade: (kbsz*num_heads, bsz//kbsz, head_dim); else (bsz*num_heads, tgt_len, head_dim)
        #prev_value = _prev_value.view(kbsz*self.num_heads, src_len, self.head_dim)
        if flag_cascade:
            attn = attn.contiguous().view(
                kbsz, self.num_heads, bsz // kbsz, self.head_dim).transpose(
                    1, 2).contiguous().view(bsz * self.num_heads, 1,
                                            self.head_dim)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
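# Illustrative sketch added here (not part of the original module): the
# self-attention branches above build a block-diagonal mask in which a token
# may only attend within its own block of `ngram` positions. A standalone
# version under assumed shapes, with `offset` a per-batch LongTensor as in the
# training branch:
import torch

def block_ngram_mask(seq_len, ngram, bsz, offset):
    # offset: (bsz, 1, 1) long tensor; returns a (bsz, seq_len, seq_len) float mask
    x, y = torch.meshgrid(torch.arange(seq_len), torch.arange(seq_len))
    x = x.unsqueeze(0).expand(bsz, -1, -1)
    y = y.unsqueeze(0).expand(bsz, -1, -1)
    x_block_id = (x - offset + ngram) // ngram
    y_block_id = (y - offset + ngram) // ngram
    mask = torch.zeros(bsz, seq_len, seq_len)
    return mask.masked_fill(x_block_id.ne(y_block_id), float("-inf"))

offset = torch.zeros(2, 1, 1, dtype=torch.long)
print(block_ngram_mask(seq_len=6, ngram=3, bsz=2, offset=offset)[0])
# 0 inside a 3-token block, -inf across blocks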
Example #18
0
k_proj_weight = torch.randn(embed_size, embed_size)
v_proj_weight = torch.randn(embed_size, embed_size)

out_proj_weight = torch.randn(embed_size, embed_size)
out_proj_bias = torch.zeros(embed_size)

multi_head_attention = MultiheadAttention(query.size(2),
                                          num_heads,
                                          dropout=0.0)
multi_head_attention.training = training
multi_head_attention.q_linear.weight = nn.Parameter(q_proj_weight.clone())
multi_head_attention.k_linear.weight = nn.Parameter(k_proj_weight.clone())
multi_head_attention.v_linear.weight = nn.Parameter(v_proj_weight.clone())

multi_head_attention.out_linear.weight = nn.Parameter(out_proj_weight.clone())
multi_head_attention.out_linear.bias = nn.Parameter(out_proj_bias.clone())

# for p, n in multi_head_attention.named_parameters():
#     print(p, n)

attn_output1, attn_output_weights1 = multi_head_attention.forward(
    query, key, value)

attn_output2, attn_output_weights2 = F.multi_head_attention_forward(
    query, key, value, query.size(2), num_heads, None, None, None, None, False,
    0.0, out_proj_weight, out_proj_bias, training, None, True, None, True,
    q_proj_weight, k_proj_weight, v_proj_weight)

assert torch.equal(attn_output1, attn_output2)
assert torch.equal(attn_output_weights1, attn_output_weights2)
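# A self-contained variant of the check above (added for clarity; sizes are
# assumptions): the same equivalence holds for the stock nn.MultiheadAttention
# when its own packed in_proj parameters are fed to the functional call.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
L, S, N, E, H = 3, 5, 2, 8, 2
query, key, value = torch.randn(L, N, E), torch.randn(S, N, E), torch.randn(S, N, E)

mha = nn.MultiheadAttention(E, H, dropout=0.0)
mha.eval()

with torch.no_grad():
    out1, w1 = mha(query, key, value)
    out2, w2 = F.multi_head_attention_forward(
        query, key, value, E, H,
        mha.in_proj_weight, mha.in_proj_bias,
        None, None, False, 0.0,
        mha.out_proj.weight, mha.out_proj.bias,
        training=False, need_weights=True)

assert torch.allclose(out1, out2, atol=1e-6)
assert torch.allclose(w1, w2, atol=1e-6)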
Example #19
0
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert (src_len, bsz) == value.shape[:2]

        if (not self.onnx_trace
                and not is_tpu  # don't use PyTorch version on TPUs
                and incremental_state is None and not static_kv
                and not (self.normalized_attention
                         and self.encoder_decoder_attention)
                and not self.positional_embeddings_in_attention
                # A workaround for quantization to work. Otherwise JIT compilation
                # treats bias in linear module as method.
                and not torch.jit.is_scripting()):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0),
                                                   1),
                    ],
                    dim=1,
                )

        q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
        if k is not None:
            k = (k.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))
        if v is not None:
            v = (v.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1,
                                          self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert k is not None
        assert k.size(1) == src_len

        if self.positional_embeddings_in_attention:
            # todo: precompute
            pos_q_timestep = utils.get_incremental_state(
                self,
                incremental_state,
                'pos_q_timestep',
            )
            #print(incremental_state != None, pos_q_timestep.item() if pos_q_timestep is not None else None, end=" ")
            #if (incremental_state is not None) and (pos_q_timestep is None):
            if (pos_q_timestep is None):
                pos_q_timestep = torch.tensor(0,
                                              dtype=torch.int64,
                                              device=q.device)
            #print(pos_q_timestep.item() if pos_q_timestep is not None else None)

            pos_q = self.pos_q_proj(
                self.pos_embeddings(
                    q.new_ones([bsz, tgt_len]),
                    timestep=pos_q_timestep,
                    incremental_state=incremental_state)).transpose(
                        0, 1)  # tgt_len x bsz
            pos_k = self.pos_k_proj(
                self.pos_embeddings(q.new_ones([bsz, src_len]))).transpose(
                    0, 1)  # src_len x bsz
            pos_q *= self.scaling
            #if incremental_state is not None:
            #print(pos_q.shape, pos_k.shape)
            #print(saved_state.keys())
            #print(pos_q[0, 0, :10])
            #print("")
            pos_q = (pos_q.contiguous().view(tgt_len, bsz * self.num_heads,
                                             self.head_dim).transpose(0, 1))
            pos_k = (pos_k.contiguous().view(-1, bsz * self.num_heads,
                                             self.head_dim).transpose(0, 1))
            pos_attn_weights = torch.bmm(pos_q, pos_k.transpose(1, 2))
            pos_attn_weights = self.apply_sparse_mask(pos_attn_weights,
                                                      tgt_len, src_len, bsz)
            assert list(pos_attn_weights.size()) == [
                bsz * self.num_heads, tgt_len, src_len
            ]

            if True:  #if (incremental_state is not None):
                utils.set_incremental_state(
                    self,
                    incremental_state,
                    'pos_q_timestep',
                    pos_q_timestep + 1,
                )

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0),
                                    1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if self.positional_embeddings_in_attention:
            attn_weights = attn_weights + pos_attn_weights

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        # here attn_weights with encoder_decoder_attention is (Batch * Heads) x QueryT x KeyT
        if self.normalized_attention and self.encoder_decoder_attention:
            # compute the standard deviation of the attention on the key time dimension while ignoring masked elements
            attn_final_mask = torch.isinf(attn_weights)
            attn_final_neg_mask = torch.logical_not(attn_final_mask)
            if self.normalized_attention_logsoftmax:
                attn_probs = utils.softmax(attn_weights,
                                           dim=-1,
                                           onnx_trace=self.onnx_trace)
                attn_weights = torch.log(attn_probs + 1e-7)
            attn_weights_zero_masked = attn_weights.masked_fill(
                attn_final_mask, 0.0)
            attn_denom = attn_final_neg_mask.sum(dim=-1, keepdim=True) + 1e-05
            if not self.normalized_attention_by_entropy:
                attn_weight_mean = attn_weights_zero_masked.sum(
                    dim=-1,
                    keepdim=True) / attn_denom  # (Batch * Heads) x QueryT x 1
                attn_weight_centered_squares_zero_masked = torch.square(
                    attn_weights_zero_masked - attn_weight_mean).masked_fill(
                        attn_final_mask, 0.0)
                attn_weight_std = torch.sqrt(
                    attn_weight_centered_squares_zero_masked.sum(
                        dim=-1, keepdim=True) /
                    attn_denom)  # (Batch * Heads) x QueryT x 1
                attn_weight_scale = attn_weight_std
            elif self.normalized_attention_by_entropy:
                attn_entropy = (attn_probs * attn_weights_zero_masked).sum(
                    dim=-1, keepdim=True)  # (Batch * Heads) x QueryT x 1
                attn_weight_scale = attn_entropy
            else:
                assert (False)

            # assume unmasked
            #attn_weight_std = attn_weights.std(dim=-1, keepdim=True)
            #attn_weight_std = torch.ones_like(attn_weights)

            # compute gain
            attn_gain = self.attention_gain(query)  # QueryT x Batch x Heads
            #attn_gain = F.sigmoid(attn_gain)
            attn_gain = attn_gain.view(
                tgt_len, bsz * self.num_heads).transpose(0, 1).unsqueeze(
                    -1)  # (Batch * Heads) x QueryT x 1
            #attn_gain = torch.ones_like(attn_weights)

            # rescale
            attn_weights = attn_weights.masked_fill(attn_final_mask, 0.0)
            attn_weights = attn_weights / (attn_weight_scale + 1e-05)
            attn_weights = attn_weights * attn_gain
            attn_weights = attn_weights.masked_fill(attn_final_mask,
                                                    float("-inf"))
            #print("attn_final_neg_mask: %s, attn_weight_std: %s, attn_gain: %s, attn_weights: %s" % (attn_final_neg_mask.shape, attn_weight_std.shape, attn_gain.shape, attn_weights.shape))
            #print("sum: %s, attn_weight_std: %s, attn_gain: %s" % (attn_weights.sum().item(), attn_weight_std.mean().item(), attn_gain.mean().item()))

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights,
                                           dim=-1,
                                           onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
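# Standalone sketch of the `normalized_attention` rescaling above (added for
# clarity, not fairseq API): masked logits are standardized per query by their
# standard deviation over the unmasked keys and multiplied by a learned gain
# before the softmax. `gain` stands in for self.attention_gain(query).
import torch

def normalize_attn_logits(attn_weights, gain, eps=1e-5):
    # attn_weights: (B*H, Tq, Tk) with -inf at masked key positions
    # gain:         (B*H, Tq, 1)
    mask = torch.isinf(attn_weights)
    zero_masked = attn_weights.masked_fill(mask, 0.0)
    denom = (~mask).sum(dim=-1, keepdim=True) + eps
    mean = zero_masked.sum(dim=-1, keepdim=True) / denom
    var = torch.square(zero_masked - mean).masked_fill(mask, 0.0).sum(dim=-1, keepdim=True) / denom
    scaled = zero_masked / (torch.sqrt(var) + eps) * gain
    return scaled.masked_fill(mask, float("-inf"))

w = torch.randn(4, 2, 6)
w[:, :, -1] = float("-inf")        # pretend the last key position is padding
g = torch.ones(4, 2, 1)
print(normalize_attn_logits(w, g).softmax(dim=-1)[0])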
Example #20
0
    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                need_weights=True,
                static_kv=False,
                attn_mask=None):
        """Input shape: Time x Batch x Channel

        Timesteps can be masked by supplying a T x T mask in the
        `attn_mask` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(
                    query, key, value, self.embed_dim, self.num_heads,
                    self.in_proj_weight, self.in_proj_bias, self.bias_k,
                    self.bias_v, self.add_zero_attn, self.dropout,
                    self.out_proj.weight, self.out_proj.bias, self.training,
                    key_padding_mask, need_weights, attn_mask)
            else:
                return F.multi_head_attention_forward(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    torch.empty([0]),
                    self.in_proj_bias,
                    self.bias_k,
                    self.bias_v,
                    self.add_zero_attn,
                    self.dropout,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    self.training,
                    key_padding_mask,
                    need_weights,
                    attn_mask,
                    use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj_weight,
                    k_proj_weight=self.k_proj_weight,
                    v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)

        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
                ],
                                             dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size(
            []):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask)
                ],
                                             dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if self.onnx_trace:
                attn_weights = torch.where(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    torch.Tensor([float("-Inf")]),
                    attn_weights.float()).type_as(attn_weights)
            else:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float('-inf'),
                )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        # topk attn_weights

        attn_weights = utils.softmax(
            attn_weights,
            dim=-1,
            onnx_trace=self.onnx_trace,
        ).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights,
                                 p=self.dropout,
                                 training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if (self.onnx_trace and attn.size(1) == 1):
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads
            # attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            # attn_weights = attn_weights.sum(dim=1) / self.num_heads
            # learn from open-nmt, we use one attention
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)[:, 0, :, :].contiguous()
        else:
            attn_weights = None
        return attn, attn_weights
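# Minimal sketch of the prev_key / prev_value caching used above for
# incremental decoding (added for clarity; shapes and the plain dict cache are
# assumptions, not the fairseq incremental_state API): each step projects only
# the newest token and concatenates it onto the cached keys and values.
import torch

def attend_one_step(q_t, k_t, v_t, cache):
    # q_t, k_t, v_t: (1, bsz * num_heads, head_dim), already projected and scaled
    if "prev_key" in cache:
        k = torch.cat([cache["prev_key"], k_t], dim=0)
        v = torch.cat([cache["prev_value"], v_t], dim=0)
    else:
        k, v = k_t, v_t
    cache["prev_key"], cache["prev_value"] = k, v
    attn = torch.softmax(torch.einsum("tbh,sbh->bts", q_t, k), dim=-1)
    return torch.einsum("bts,sbh->tbh", attn, v)

cache = {}
bsz, num_heads, head_dim = 2, 4, 8
for _ in range(3):
    step = [torch.randn(1, bsz * num_heads, head_dim) for _ in range(3)]
    out = attend_one_step(*step, cache)
print(out.shape, cache["prev_key"].shape)  # (1, 8, 8) and (3, 8, 8)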
Example #21
0
in_proj = sparse_attn.qkv_in_proj.weight
in_proj_bias = sparse_attn.qkv_in_proj.bias
out_proj = sparse_attn.out_proj.weight
out_proj_bias = sparse_attn.out_proj.bias

ns = node_states.unsqueeze(1)
da_node_states, da_weights = F.multi_head_attention_forward(
    ns,
    ns,
    ns,
    200,
    8,
    in_proj,
    in_proj_bias,
    None,
    None,
    False,
    0.0,
    out_proj,
    out_proj_bias,
    training=False,
    key_padding_mask=None,
    need_weights=True,
    attn_mask=attn_mask)
da_sum = torch.sum(da_node_states)
da_weights = da_weights.squeeze()

print(da_sum)
# %%
print(da_weights.t())
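# Self-contained version of the snippet above (added for clarity; the sizes
# and random projection weights are assumptions standing in for
# sparse_attn.qkv_in_proj / node_states / attn_mask defined elsewhere).
import torch
import torch.nn.functional as F

num_nodes, E, H = 5, 200, 8
node_states = torch.randn(num_nodes, E)

in_proj = torch.randn(3 * E, E) * 0.02          # packed q/k/v projection
in_proj_bias = torch.zeros(3 * E)
out_proj = torch.randn(E, E) * 0.02
out_proj_bias = torch.zeros(E)

# Additive float mask over node pairs: 0 keeps an edge, -inf removes it.
attn_mask = torch.zeros(num_nodes, num_nodes)
attn_mask[0, 1:] = float("-inf")                # node 0 attends only to itself

ns = node_states.unsqueeze(1)                   # (num_nodes, 1, E): batch of one graph
da_node_states, da_weights = F.multi_head_attention_forward(
    ns, ns, ns, E, H,
    in_proj, in_proj_bias,
    None, None, False, 0.0,
    out_proj, out_proj_bias,
    training=False, key_padding_mask=None,
    need_weights=True, attn_mask=attn_mask)

print(torch.sum(da_node_states))
print(da_weights.squeeze().t())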
Example #22
0
    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shapes for inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
              source sequence length.

              If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
              length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
              the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

        Shapes for outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if self.batch_first:
            query, key, value = [
                x.transpose(1, 0) for x in (query, key, value)
            ]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
            )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
            )

        # average over the target dimension: (N, L, S) -> (N, S)
        scores = torch.mean(attn_output_weights, dim=1)
        pruning_mask = torch.sigmoid(
            (scores - self.soft_threshold) / self.temperature)
        attn_output = attn_output.transpose(1, 0)
        attn_output = pruning_mask[:, :, None] * attn_output

        if self.batch_first:
            return (
                attn_output,
                attn_output_weights,
                torch.sum(torch.norm(pruning_mask, p=1) / self.num_heads),
            )
        else:
            return (
                attn_output.transpose(1, 0),
                attn_output_weights,
                torch.sum(torch.norm(pruning_mask, p=1) / self.num_heads),
            )
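
Example #22 above collapses the averaged attention map into per-source-token scores and uses a temperature-scaled sigmoid around a soft threshold as a differentiable pruning gate. A minimal standalone sketch of just that gating step, with hypothetical soft_threshold and temperature values and self-attention shapes (L == S so the gate broadcasts over the output):

import torch

N, L, E = 2, 5, 8                                         # batch, sequence length (L == S), embed dim
S = L
attn_output = torch.randn(N, L, E)                        # output already transposed to (N, L, E)
attn_weights = torch.softmax(torch.randn(N, L, S), dim=-1)

soft_threshold, temperature = 0.15, 0.05                  # hypothetical values
scores = attn_weights.mean(dim=1)                         # (N, S): attention mass each source token receives
pruning_mask = torch.sigmoid((scores - soft_threshold) / temperature)  # soft 0/1 gate per token

gated = pruning_mask[:, :, None] * attn_output            # (N, L, E): low-score tokens are damped
l1_penalty = pruning_mask.norm(p=1)                       # sparsity regularizer (Example #22 also divides by num_heads)
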
Example #23
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        calc_head_importance=False
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """

        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        self.bsz = bsz
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (self.enable_torch_version and not self.onnx_trace
                and incremental_state is None and not static_kv
                # A workaround for quantization to work. Otherwise JIT compilation
                # treats bias in linear module as method.
                and not torch.jit.is_scripting()):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:

            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0),
                                                   1),
                    ],
                    dim=1,
                )
        q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
        if k is not None:
            k = (k.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))
        if v is not None:
            v = (v.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))
        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1,
                                          self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)

        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0),
                                    1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = MultiheadAttention.apply_sparse_mask(
            attn_weights, tgt_len, src_len, bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"))
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)
        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights,
                                           dim=-1,
                                           onnx_trace=self.onnx_trace)

        if self.mask_head is not None:
            head_masking_vector = torch.ones(self.num_heads)
            head_masking_vector[self.mask_head] = 0
            head_masking_vector = head_masking_vector.view(
                1, self.num_heads, 1, 1).to(attn_weights_float.device)
            attn_weights_float = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len) * head_masking_vector
            attn_weights_float = attn_weights_float.view(
                bsz * self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights_float.type_as(attn_weights)

        save_attn_for_guy = attn_weights.clone().detach().contiguous().view(
            bsz, self.num_heads, tgt_len, src_len)

        conf = None

        ## computing the confidence of all heads over the bsz sentences in the batch

        ## `heads` is an array of shape [num_heads + 1] which contains confidence*bsz
        ## for each head plus bsz: [conf_h_1*bsz, conf_h_2*bsz, ..., conf_h_n*bsz, bsz]
        # Viota's confidence: word-level attention confidence is a more fine-grained
        # version of conf
        if self.head_confidence_method is not None:

            if attn_weights is not None:
                if self.head_confidence_method == "base":
                    a = attn_weights.clone().view(bsz, self.num_heads, tgt_len,
                                                  src_len).transpose(1, 0)
                    a[:, :, -1, -1] = torch.zeros((self.num_heads, bsz))
                    heads = a[:, :, :, :].max(dim=3)
                    heads = heads[0].max(dim=2)
                    heads = heads[0].sum(dim=1) / bsz
                elif self.head_confidence_method == "advanced":
                    a = attn_weights.clone().view(bsz, self.num_heads, tgt_len,
                                                  src_len).transpose(1, 0)
                    a[:, :, -1, -1] = torch.zeros((self.num_heads, bsz))
                    heads = a[:, :, :, :].max(dim=2)
                    heads = heads[0].sum(dim=2) / (src_len - 1)
                    heads = heads.sum(dim=1) / bsz
                elif self.head_confidence_method == "pairwise":
                    a = attn_weights.clone().contiguous().view(
                        bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
                    a[:, :, -1, -1] = torch.zeros((self.num_heads, bsz))
                    a = a.contiguous().view(self.num_heads, bsz,
                                            tgt_len * src_len)
                    c_2 = torch.cdist(a.transpose(0, 1).contiguous(),
                                      a.transpose(0, 1).contiguous(),
                                      p=2)
                    c_2 = c_2.sum(dim=0) / bsz
                    c_2 = c_2.sum(dim=0)
                    heads = c_2
                elif self.head_confidence_method == "wasserstein":
                    a = attn_weights.clone().view(bsz, self.num_heads, tgt_len,
                                                  src_len).transpose(1, 0)
                    uniform_heads = torch.zeros(self.num_heads, bsz, tgt_len,
                                                src_len)
                    uniform_heads[:, :, :-1, :-1] = 1 / src_len
                    # requires numpy (np) and scipy.stats.wasserstein_distance
                    distances = np.zeros(self.num_heads)
                    for head in range(self.num_heads):
                        for batch in range(bsz):
                            for line in range(tgt_len - 1):
                                distances[head] += wasserstein_distance(
                                    a[head, batch, line],
                                    uniform_heads[head, batch, line])
                            distances[head] /= (tgt_len - 1)
                        distances[head] /= bsz
                    # use the per-head distances as the confidence scores
                    heads = torch.from_numpy(distances)

            # Take the max for each source word, then average over all
            # for j in range(self.num_heads):
            #    conf_temp = 0
            #    for batch in range(bsz):
            #        word_attn_sum = 0
            #        for tgt in range(tgt_len - 1):
            #            word_attn_sum += attn_weights.view(self.num_heads, bsz, tgt_len, src_len)[j, batch, tgt,
            #                             :-1].max()
            #        conf_temp += word_attn_sum / (tgt_len - 1)
            #    word_max["heads"].append(conf_temp)
            conf = heads

        self.head_conf = conf

        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )

        assert v is not None

        ctx = torch.bmm(attn_probs,
                        v)  # that's what I called 'Z' in my summary
        save_ctx = ctx.view(bsz, self.num_heads, tgt_len, self.head_dim)

        ctx = save_ctx.view(bsz * self.num_heads, tgt_len, self.head_dim)

        z = ctx.contiguous().view(bsz, self.num_heads, tgt_len,
                                  self.head_dim).transpose(0, 1)

        b = z.contiguous().view(self.num_heads, tgt_len * bsz * self.head_dim)

        # pdist cosine sim
        test_cos = save_ctx.contiguous().view(bsz, self.num_heads,
                                              tgt_len * self.head_dim)
        test_cos = test_cos.permute((1, 2, 0))
        cos_sim_pairwise = F.cosine_similarity(test_cos,
                                               test_cos.unsqueeze(1),
                                               dim=-2)
        cos_sim_pairwise = cos_sim_pairwise.permute((2, 0, 1))

        cos_sim_pairwise += 1.0

        self.cosine_similarity_matrix = cos_sim_pairwise

        #cos_sim_pairwise = torch.sum(cos_sim_pairwise, axis=0)/bsz #used for mean on bsz

        cos_sim_pairwise = torch.flatten(cos_sim_pairwise)

        cos_sim_sum = torch.sum(cos_sim_pairwise) / (self.num_heads ** 2)

        self.cosine_similarity_total = cos_sim_sum

        # pdist l2

        test_l2 = save_ctx.contiguous().view(bsz, self.num_heads,
                                             tgt_len * self.head_dim)

        pairwise_l2 = torch.cdist(test_l2.contiguous(),
                                  test_l2.contiguous(),
                                  p=2)

        self.l2_pdist_mat = pairwise_l2

        # alphas

        self.alphas.requires_grad = True
        #self.alphas_bias.requires_grad = True
        #b = torch.mm(self.alphas, b) + self.alphas_bias
        b = torch.mm(self.alphas, b)

        ctx = b.contiguous().view(self.num_heads, bsz, tgt_len,
                                  self.head_dim).transpose(0, 1)

        ctx = ctx.contiguous().view(bsz * self.num_heads, tgt_len,
                                    self.head_dim)

        assert list(
            ctx.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and ctx.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            ctx = ctx.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            ctx = ctx.transpose(0,
                                1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(ctx)
        attn_weights: Optional[Tensor] = None
        if calc_head_importance:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len, src_len)
        else:
            if need_weights:
                attn_weights = attn_weights_float.view(
                    bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
                if not need_head_weights:
                    # average attention weights over heads
                    attn_weights = attn_weights.mean(dim=0)
        return attn, attn_weights, save_attn_for_guy
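
Example #23 layers several per-head diagnostics on top of the standard attention; the pairwise cosine-similarity statistic over the per-head context vectors is the easiest to reproduce in isolation. A minimal sketch with toy sizes (all shapes are assumptions), using the num_heads ** 2 normalizer the original intends:

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 16
ctx = torch.randn(bsz, num_heads, tgt_len, head_dim)        # per-head context vectors ('Z')

flat = ctx.reshape(bsz, num_heads, tgt_len * head_dim)      # one long vector per head and sentence
flat = flat.permute(1, 2, 0)                                # (num_heads, D, bsz)
cos = F.cosine_similarity(flat, flat.unsqueeze(1), dim=-2)  # (num_heads, num_heads, bsz)
cos = cos.permute(2, 0, 1) + 1.0                            # (bsz, num_heads, num_heads), shifted to [0, 2]

total = cos.flatten().sum() / (num_heads ** 2)              # summed over the batch, normalized by head pairs
print(total)
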
Example #24
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        if not self.skip_embed_dim_check:
            assert (embed_dim == self.embed_dim
                    ), f"query dim {embed_dim} != {self.embed_dim}"
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert value is not None
                assert (src_len, key_bsz) == value.shape[:2]

        if (not self.onnx_trace
                and not is_tpu  # don't use PyTorch version on TPUs
                and incremental_state is None and not static_kv
                # A workaround for quantization to work. Otherwise JIT compilation
                # treats bias in linear module as method.
                and not torch.jit.is_scripting()
                # The multi-head attention implemented in PyTorch enforces a strict dimension
                # check between the input embedding dimension and the K, Q, V projection dimensions.
                # Pruning breaks this check, and the PyTorch API is not easy to modify,
                # so we bypass the PyTorch MHA whenever the embed_dim check must be skipped
                and not self.skip_embed_dim_check
                and self.positional_embedding is None):
            assert key is not None and value is not None

            if self.use_xformers:
                return self._xformers_attn_forward(query, key, value,
                                                   key_padding_mask,
                                                   need_weights, attn_mask)

            else:
                return F.multi_head_attention_forward(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    torch.empty([0]),
                    torch.cat((self.q_proj.bias, self.k_proj.bias,
                               self.v_proj.bias)),
                    self.bias_k,
                    self.bias_v,
                    self.add_zero_attn,
                    self.dropout_module.p,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    self.training
                    or self.dropout_module.apply_during_inference,
                    key_padding_mask,
                    need_weights,
                    attn_mask,
                    use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj.weight,
                    k_proj_weight=self.k_proj.weight,
                    v_proj_weight=self.v_proj.weight,
                )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                if self.beam_size > 1 and bsz == key.size(1):
                    # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
                    key = key.view(key.size(0), -1, self.beam_size,
                                   key.size(2))[:, :, 0, :]
                    if key_padding_mask is not None:
                        key_padding_mask = key_padding_mask.view(
                            -1, self.beam_size, key_padding_mask.size(1))[:,
                                                                          0, :]
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        if self.positional_embedding is not None:
            if not self.positional_embedding.learnable:
                q_with_bias_v = (q + self.pos_bias_v) * self.scaling
                q_with_bias_v = (q_with_bias_v.contiguous().view(
                    tgt_len, bsz * self.num_heads,
                    self.head_dim).transpose(0, 1))
                q = q + self.pos_bias_u
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k, v, attn_mask, key_padding_mask = self._add_bias(
                k, v, attn_mask, key_padding_mask, bsz)

        q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
        kv_bsz = bsz  # need default value for scripting
        if k is not None:
            kv_bsz = k.size(1)
            k = (k.contiguous().view(-1, kv_bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))
        if v is not None:
            v = (v.contiguous().view(-1, kv_bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                kv_bsz = _prev_key.size(0)
                prev_key = _prev_key.view(kv_bsz * self.num_heads, -1,
                                          self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                assert kv_bsz == _prev_value.size(0)
                prev_value = _prev_value.view(kv_bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=kv_bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == kv_bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k, v, key_padding_mask, attn_mask = self._append_zero_attn(
                k=k,
                v=v,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask)

        if self.encoder_decoder_attention and bsz != kv_bsz:
            attn_weights = torch.einsum(
                "bxhtd,bhsd->bxhts",
                q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
                k.view((kv_bsz, self.num_heads) + k.size()[1:]),
            )
            attn_weights = attn_weights.reshape((-1, ) +
                                                attn_weights.size()[-2:])
        else:
            attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        if self.positional_embedding is not None:
            # compute `attn_weights` as described in https://arxiv.org/abs/1901.02860 Section 3.3
            assert (
                not self.encoder_decoder_attention
            ), "positional embedding is only applicable to self attention"
            assert bsz == kv_bsz, f"{bsz} != {kv_bsz}"
            assert src_len >= tgt_len, f"{src_len} vs {tgt_len}"
            if key_padding_mask is not None:
                pe = self.positional_embedding(
                    ~(key_padding_mask.bool())
                )  # bsz x (2*src_len-1) x embed_dim
            else:
                pe = self.positional_embedding(
                    k.new_ones([bsz, src_len], dtype=torch.bool))
            if not self.positional_embedding.learnable:
                pe = self.pos_proj(pe)
            pe = pe.view(bsz, -1, self.num_heads, self.head_dim).transpose(
                1, 2)  # bsz x num_heads x (2*src_len-1) x head_dim
            pe = pe.reshape(bsz * self.num_heads, -1, self.head_dim)
            positional_logits = torch.bmm(
                q_with_bias_v
                if not self.positional_embedding.learnable else q,
                pe.transpose(1, 2),
            )
            assert list(positional_logits.size()) == [
                bsz * self.num_heads,
                tgt_len,
                2 * src_len - 1,
            ]
            batch_head_stride, tgt_stride, src_stride = positional_logits.stride(
            )
            # assume src (key) and tgt (query) sequences are right-aligned
            positional_logits = positional_logits.as_strided(
                (bsz * self.num_heads, tgt_len, src_len),
                (batch_head_stride, tgt_stride - src_stride, src_stride),
                storage_offset=src_stride * (tgt_len - 1),
            )
            attn_weights = attn_weights + positional_logits

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if not is_tpu:
                attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads,
                                                 tgt_len, src_len)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(
                        torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights,
                                           dim=-1,
                                           onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)

        if self.training and self.relaxed_attention_weight > 0.0:
            attn_weights = (
                1.0 - self.relaxed_attention_weight
            ) * attn_weights + self.relaxed_attention_weight / src_len

        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        if self.encoder_decoder_attention and bsz != kv_bsz:
            attn = torch.einsum(
                "bxhts,bhsd->bxhtd",
                attn_probs.view((
                    kv_bsz,
                    -1,
                    self.num_heads,
                ) + attn_probs.size()[1:]),
                v.view((
                    kv_bsz,
                    self.num_heads,
                ) + v.size()[1:]),
            )
            attn = attn.reshape((-1, ) + attn.size()[-2:])
        else:
            attn = torch.bmm(attn_probs, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz,
                                                       self.embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
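
The relative-position handling in Example #24 hinges on an as_strided shift: positional logits of shape (bsz*num_heads, tgt_len, 2*src_len - 1), indexed by relative offset, are re-strided into (bsz*num_heads, tgt_len, src_len) under the stated assumption that key and query sequences are right-aligned. A small numeric sketch of just that trick (toy sizes, contiguous input so the stride arithmetic holds):

import torch

BH, T, S = 1, 3, 3
# last dim indexes relative offsets -(S-1) .. +(S-1); values chosen to make the shift visible
rel = torch.arange(2 * S - 1).float().repeat(BH, T, 1)
bh_stride, t_stride, s_stride = rel.stride()
shifted = rel.as_strided(
    (BH, T, S),
    (bh_stride, t_stride - s_stride, s_stride),
    storage_offset=s_stride * (T - 1),
)
print(shifted[0])
# tensor([[2., 3., 4.],
#         [1., 2., 3.],
#         [0., 1., 2.]])
# row t picks relative offsets (S-1-t) .. (2S-2-t), i.e. key j maps to index j - t + (S - 1)
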
Example #25
    def forward(
            self,
            query,
            key: Optional[Tensor],
            value: Optional[Tensor],
            key_padding_mask: Optional[Tensor] = None,
            incremental_state: Optional[Dict[str,
                                             Dict[str,
                                                  Optional[Tensor]]]] = None,
            need_weights: bool = True,
            static_kv: bool = False,
            attn_mask: Optional[Tensor] = None,
            before_softmax: bool = False,
            need_head_weights: bool = False,
            output_all_attentions: bool = False,  # added by Goro Kobayashi
            output_all_norms: bool = False,  # added by Goro Kobayashi
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
            -----Comments below are added by Goro Kobayashi-----
            output_all_attentions (bool, optional): return the attention
                weights for all heads (default: False).
            output_all_norms (bool, optional): return the norms 
                (||f(x)||, ||αf(x)||, and ||Σαf(x)||, 
                detailed in https://arxiv.org/abs/2004.10102)
                for all heads (default: False).
        """
        if need_head_weights or output_all_attentions:  # Changed by Goro Kobayashi
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (not self.onnx_trace
                and not self.tpu  # don't use PyTorch version on TPUs
                and incremental_state is None and not static_kv
                # A workaround for quantization to work. Otherwise JIT compilation
                # treats bias in linear module as method.
                and not torch.jit.is_scripting()):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0),
                                                   1),
                    ],
                    dim=1,
                )

        q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
        if k is not None:
            k = (k.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))
        if v is not None:
            v = (v.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1,
                                          self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0),
                                    1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if not self.tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights,
                                           dim=-1,
                                           onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights and not output_all_attentions:  # Changed by Goro Kobayashi
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        # -----added below by Goro Kobayashi-----
        if output_all_norms:
            with torch.no_grad():
                # Reshape Value vectors into (bsz, src_len, num_heads, 1, head_dim)
                v = v.contiguous().view(bsz, self.num_heads, -1, 1,
                                        self.head_dim).transpose(1, 2)

                # Dense weights W^O: (embed_dim, embed_dim)
                dense = self.out_proj.weight

                # Reshape W^O into (num_heads, head_dim, embed_dim)
                dense = dense.view(embed_dim, self.num_heads,
                                   self.head_dim).permute(1, 2,
                                                          0).contiguous()

                # By matrix product, make transformed vectors f(x): (bsz, num_heads, src_len, embed_dim)
                transformed_vectors = v.matmul(dense).view(
                    bsz, -1, self.num_heads, embed_dim)
                transformed_vectors = transformed_vectors.permute(0, 2, 1, 3)

                # Calculate L2 norm ||f(x)||: (num_heads, bsz, src_len)
                transformed_vector_norm = torch.norm(transformed_vectors,
                                                     dim=-1).transpose(0, 1)

                # By element product, make weighted vectors αf(x): (bsz, num_heads, tgt_len, src_len, embed_dim)
                attn_probs = attn_probs.view(bsz, self.num_heads, tgt_len,
                                             src_len)
                weighted_vectors = torch.einsum('bhts,bhsd->bhtsd', attn_probs,
                                                transformed_vectors)

                # Calculate L2 norm ||αf(x)||: (num_heads, bsz, tgt_len, src_len)
                weighted_vector_norm = torch.norm(weighted_vectors,
                                                  dim=-1).transpose(0, 1)

                # Sum each αf(x) over all heads: (bsz, tgt_len, src_len, embed_dim)
                summed_weighted_vectors = weighted_vectors.sum(dim=1)

                # Calculate L2 norm of summed weighted vectors: (bsz, tgt_len, src_len)
                summed_weighted_vector_norm = torch.norm(
                    summed_weighted_vectors, dim=-1)
            return attn, attn_weights, (transformed_vector_norm,
                                        weighted_vector_norm,
                                        summed_weighted_vector_norm)
        # -----added above by Goro Kobayashi-----

        return attn, attn_weights
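
Example #25's norm-based analysis (Kobayashi et al., https://arxiv.org/abs/2004.10102) is easier to follow with the shapes spelled out; the sketch below uses random tensors, hypothetical sizes, and a random matrix standing in for the trained output projection W^O.

import torch

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 8
embed_dim = num_heads * head_dim

attn_probs = torch.softmax(torch.randn(bsz, num_heads, tgt_len, src_len), dim=-1)  # alpha
v = torch.randn(bsz, num_heads, src_len, 1, head_dim)                              # per-head value vectors
w_o = torch.randn(embed_dim, embed_dim)                                            # output projection W^O

# f(x): value vectors mapped through the per-head slice of W^O -> (bsz, num_heads, src_len, embed_dim)
w_o_heads = w_o.view(embed_dim, num_heads, head_dim).permute(1, 2, 0).contiguous()
fx = v.transpose(1, 2).matmul(w_o_heads)                     # (bsz, src_len, num_heads, 1, embed_dim)
fx = fx.reshape(bsz, src_len, num_heads, embed_dim).permute(0, 2, 1, 3)

afx = torch.einsum('bhts,bhsd->bhtsd', attn_probs, fx)       # alpha * f(x)
print(fx.norm(dim=-1).shape)                                 # ||f(x)||:        (bsz, num_heads, src_len)
print(afx.norm(dim=-1).shape)                                # ||alpha f(x)||:  (bsz, num_heads, tgt_len, src_len)
print(afx.sum(dim=1).norm(dim=-1).shape)                     # ||sum_h alpha f(x)||: (bsz, tgt_len, src_len)
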
Example #26
    def forward_with_args(
        self,
        embed_dim: int_or_int_dict,
        num_heads: int,
        kdim: int_or_int_dict | None,
        vdim: int_or_int_dict | None,
        dropout: float,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = True,
        attn_mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:

        if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
            raise ValueError(
                'num_heads, dropout do not support weighted sampling.')

        # by default, kdim and vdim can be None
        if kdim is None:
            kdim = embed_dim
        if vdim is None:
            vdim = embed_dim

        qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim

        if getattr(self, 'batch_first', False):
            # for backward compatibility: v1.7 doesn't have batch_first
            query, key, value = [
                x.transpose(1, 0) for x in (query, key, value)
            ]

        if isinstance(embed_dim, dict):
            used_embed_dim = self.embed_dim
        else:
            used_embed_dim = embed_dim

        embed_dim_ = _W(embed_dim)

        # the in-projection weight & bias have the q, k, v weights concatenated together
        in_proj_bias: Tensor | None = None
        in_proj_weight: Tensor | None = None
        if self.in_proj_bias is not None:
            in_proj_bias = _S(cast(
                Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim_)]
        if self.in_proj_weight is not None:
            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[
                self._to_proj_slice(embed_dim_), :embed_dim_]

        bias_k = _S(cast(Tensor, self.bias_k)
                    )[:, :, :embed_dim_] if self.bias_k is not None else None
        bias_v = _S(cast(Tensor, self.bias_v)
                    )[:, :, :embed_dim_] if self.bias_v is not None else None
        out_proj_weight = _S(cast(
            Tensor, self.out_proj.weight))[:embed_dim_, :embed_dim_]
        out_proj_bias = _S(
            cast(Tensor, self.out_proj.bias
                 ))[:embed_dim_] if self.out_proj.bias is not None else None

        if not qkv_same_embed_dim:
            q_proj = _S(cast(Tensor,
                             self.q_proj_weight))[:embed_dim_, :embed_dim_]
            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim_]
            k_proj = _S(k_proj)[:, :_W(kdim)]
            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
            v_proj = _S(v_proj)[:, :_W(vdim)]

            # The rest is basically the same as in PyTorch
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                used_embed_dim,
                num_heads,
                cast(Tensor, in_proj_weight),
                cast(Tensor, in_proj_bias),
                bias_k,
                bias_v,
                self.add_zero_attn,
                dropout,
                out_proj_weight,
                cast(Tensor, out_proj_bias),
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=q_proj,
                k_proj_weight=k_proj,
                v_proj_weight=v_proj)
        else:
            # Cast tensor here because of a bug in pytorch stub
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                used_embed_dim,
                num_heads,
                cast(Tensor, in_proj_weight),
                cast(Tensor, in_proj_bias),
                bias_k,
                bias_v,
                self.add_zero_attn,
                dropout,
                out_proj_weight,
                cast(Tensor, out_proj_bias),
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask)

        if getattr(self, 'batch_first', False):  # backward compatibility
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
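Aside: the `_to_proj_slice` indexing above relies on the standard PyTorch layout in which `in_proj_weight` stacks the q, k and v projection matrices along dim 0. A minimal sketch of that layout with plain tensors (hypothetical sizes; the `_W`/`_S` helpers are not reproduced here):

import torch

# Hypothetical sizes: full embed_dim of 8, a sampled sub-width of 4.
embed_dim, sub = 8, 4
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # rows stacked as [q | k | v]

q_w = in_proj_weight[0 * embed_dim:0 * embed_dim + sub, :sub]
k_w = in_proj_weight[1 * embed_dim:1 * embed_dim + sub, :sub]
v_w = in_proj_weight[2 * embed_dim:2 * embed_dim + sub, :sub]
assert q_w.shape == k_w.shape == v_w.shape == (sub, sub)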
Example #27
0
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"
        is_hpu = query.device.type == "hpu"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert (src_len, bsz) == value.shape[:2]

        if (not self.onnx_trace
                and not is_tpu  # don't use PyTorch version on TPUs
                and not is_hpu and incremental_state is None and not static_kv
                # A workaround needed for quantization to work; otherwise JIT
                # compilation treats the bias in the linear module as a method.
                and not torch.jit.is_scripting()):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

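        # Slow path from here on: handles ONNX export, TPU/HPU devices,
        # incremental decoding, static key/value and TorchScript, where the
        # fused F.multi_head_attention_forward call above does not apply.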
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0),
                                                   1),
                    ],
                    dim=1,
                )

        q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
        if k is not None:
            k = (k.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))
        if v is not None:
            v = (v.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1,
                                          self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0),
                                    1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights,
                                           dim=-1,
                                           onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
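As a usage reference for the Time x Batch x Channel convention and the two masks consumed above, here is a minimal self-attention sketch; `torch.nn.MultiheadAttention` stands in for the module, whose construction is not shown in this example:

import torch
import torch.nn as nn

seq_len, bsz, embed_dim, num_heads = 5, 2, 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads)  # stand-in for the module above

x = torch.randn(seq_len, bsz, embed_dim)           # Time x Batch x Channel

# (batch, src_len): True marks padded key positions to exclude from attention.
key_padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
key_padding_mask[:, -1] = True

# Additive (tgt_len, src_len) mask; -inf above the diagonal gives causal attention.
attn_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

attn, attn_weights = mha(x, x, x,
                         key_padding_mask=key_padding_mask,
                         attn_mask=attn_mask)
print(attn.shape, attn_weights.shape)  # (5, 2, 16) and (2, 5, 5)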
Example #28
0
    def forward(self, src_user, src_loc, src_reg, src_time, src_square_mask, src_binary_mask, trg_loc, mem_mask, ds=None):
        loc_emb_src = self.emb_loc(src_loc)
        if self.extra_config.get("user_location_only", False):
            src = loc_emb_src
        else:
            user_emb_src = self.emb_user(src_user)
            reg_emb = self.emb_reg(src_reg)
            time_emb = self.emb_time(src_time)
            if self.extra_config.get("embedding_fusion", "multiply") == "multiply":
                if self.extra_config.get("user_embedding", False):
                    src = loc_emb_src * reg_emb * time_emb * user_emb_src
                else:
                    src = loc_emb_src * reg_emb * time_emb
            else:
                if self.extra_config.get("user_embedding", False):
                    src = torch.cat([user_emb_src, loc_emb_src, reg_emb, time_emb], dim=-1)
                else:
                    src = torch.cat([loc_emb_src, reg_emb, time_emb], dim=-1)
                src = self.lin(src)

        if self.extra_config.get("size_sqrt_regularize", True):
            src = src * math.sqrt(src.size(-1))

        src = self.pos_encoder(src)
        # shape: [L, N, ninp]
        src = self.encoder(src, mask=src_square_mask)
        # shape: [(1+K)*L, N, loc_dim]
        loc_emb_trg = self.emb_loc(trg_loc)

        if self.extra_config.get("use_attention_as_decoder", False):
            # multi-head attention
            output, _ = F.multi_head_attention_forward(
                query=loc_emb_trg,
                key=src,
                value=src,
                embed_dim_to_check=src.size(2),
                num_heads=1,
                in_proj_weight=None,
                in_proj_bias=None,
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=0.0,
                out_proj_weight=self.ident_mat,
                out_proj_bias=None,
                training=self.training,
                key_padding_mask=src_binary_mask,
                need_weights=False,
                attn_mask=mem_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.ident_mat,
                k_proj_weight=self.ident_mat,
                v_proj_weight=self.ident_mat
            )
            
            if self.training:
                src = src.repeat(loc_emb_trg.size(0) // src.size(0), 1, 1)
            else:
                src = src[torch.tensor(ds) - 1, torch.arange(len(ds)), :]
                src = src.unsqueeze(0).repeat(loc_emb_trg.size(0), 1, 1) 

            output += src
            output = self.layer_norm(output)
        else:
            # No attention
            if self.training:
                output = src.repeat(loc_emb_trg.size(0) // src.size(0), 1, 1)
            else:
                output = src[torch.tensor(ds) - 1, torch.arange(len(ds)), :]
                output = output.unsqueeze(0).repeat(loc_emb_trg.size(0), 1, 1)

        # shape: [(1+K)*L, N]
        output = torch.sum(output * loc_emb_trg, dim=-1)
        return output
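The identity matrices passed as projection weights above reduce `F.multi_head_attention_forward` to plain single-head scaled dot-product attention over the raw embeddings. A standalone sketch of that trick (hypothetical sizes, unrelated to the model's tensors):

import torch
import torch.nn.functional as F

L, S, N, E = 4, 6, 2, 8
q = torch.randn(L, N, E)      # queries: (tgt_len, batch, embed)
mem = torch.randn(S, N, E)    # keys/values: (src_len, batch, embed)
ident = torch.eye(E)

out, _ = F.multi_head_attention_forward(
    query=q, key=mem, value=mem,
    embed_dim_to_check=E, num_heads=1,
    in_proj_weight=None, in_proj_bias=None,
    bias_k=None, bias_v=None, add_zero_attn=False,
    dropout_p=0.0, out_proj_weight=ident, out_proj_bias=None,
    training=False, need_weights=False,
    use_separate_proj_weight=True,
    q_proj_weight=ident, k_proj_weight=ident, v_proj_weight=ident)
print(out.shape)  # torch.Size([4, 2, 8])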
Example #29
0
    def forward(
        self,
        query, key, value,
        key_padding_mask=None,
        incremental_state=None,
        need_weights=True,
        static_kv=False,
        attn_mask=None,
        before_softmax=False,
        need_head_weights=False,
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      self.in_proj_weight,
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask)
            else:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      torch.empty([0]),
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask, use_separate_proj_weight=True,
                                                      q_proj_weight=self.q_proj_weight,
                                                      k_proj_weight=self.k_proj_weight,
                                                      v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)

        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
                prev_key_padding_mask = saved_state['prev_key_padding_mask']
                if static_kv:
                    key_padding_mask = prev_key_padding_mask
                else:
                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_key_padding_mask'] = key_padding_mask

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if (self.onnx_trace and attn.size(1) == 1):
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, attn_weights
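The `incremental_state` branch above is a key/value cache: at each decoding step only the new key/value rows are projected and then concatenated onto the cached ones along the time dimension. A minimal sketch of that caching idea in plain PyTorch (hypothetical shapes, independent of fairseq's buffer utilities):

import torch

bsz_heads, head_dim = 4, 16
cache = {"prev_key": torch.empty(bsz_heads, 0, head_dim),
         "prev_value": torch.empty(bsz_heads, 0, head_dim)}

def step(cache, new_k, new_v):
    # Append this step's key/value to the cache along the time dimension.
    k = torch.cat([cache["prev_key"], new_k], dim=1)
    v = torch.cat([cache["prev_value"], new_v], dim=1)
    cache["prev_key"], cache["prev_value"] = k, v
    return k, v

for _ in range(3):
    k, v = step(cache,
                torch.randn(bsz_heads, 1, head_dim),
                torch.randn(bsz_heads, 1, head_dim))
print(k.shape)  # torch.Size([4, 3, 16])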
Example #30
0
    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                need_weights=True,
                static_kv=False,
                attn_mask=None):
        """Input shape: Time x Batch x Channel

        Timesteps can be masked by supplying a T x T mask in the
        `attn_mask` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(
                    query, key, value, self.embed_dim, self.num_heads,
                    self.in_proj_weight, self.in_proj_bias, self.bias_k,
                    self.bias_v, self.add_zero_attn, self.dropout,
                    self.out_proj.weight, self.out_proj.bias, self.training,
                    key_padding_mask, need_weights, attn_mask)
            else:
                return F.multi_head_attention_forward(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    torch.empty([0]),
                    self.in_proj_bias,
                    self.bias_k,
                    self.bias_v,
                    self.add_zero_attn,
                    self.dropout,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    self.training,
                    key_padding_mask,
                    need_weights,
                    attn_mask,
                    use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj_weight,
                    k_proj_weight=self.k_proj_weight,
                    v_proj_weight=self.v_proj_weight)

        tmp_q, tmp_k = query, key
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        focus, salient_g = None, None
        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
            # _g = q.sum(0).unsqueeze(0).repeat(tgt_len, 1, 1)
            _g = torch.mean(q, dim=0, keepdim=True)
            tmp_tensor = torch.tanh(self.proj_p(q) + self.proj_g(_g))
            # tmp_tensor = tmp_tensor.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            tmp_tensor = tmp_tensor.view(tgt_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) \
                .transpose(1, 2)
            # calculate tanh(w_q*q+w_g*g)
            # mu = self.proj_c(tmp_tensor)
            # sigma = self.proj_d(tmp_tensor)
            # bsz * num_heads * tgt_len * head_dim
            mu = tmp_tensor * self.weight_c
            # bsz * num_heads * tgt_len
            mu = mu.sum(3).squeeze()
            sigma = tmp_tensor * self.weight_d
            sigma = sigma.sum(3).squeeze()

            # normalize mu and sigma to the key length via tgt_len * sigmoid(.);
            # for self-attention query == key == value, so tgt_len == key_len
            mu = tgt_len * torch.sigmoid(mu)  # size(tgt_len)
            sigma = tgt_len * torch.sigmoid(sigma)  # size(tgt_len)

            # mu = mu.repeat(1, 1, tgt_len)
            # sigma = sigma.repeat(1, 1, tgt_len)
            # abs_pos = torch.arange(start=0, end=tgt_len, dtype=mu.dtype).unsqueeze(0).repeat(tgt_len, 1).cuda()
            # 1 * 1 * 1 * tgt_len(key_len)
            abs_pos = torch.arange(
                start=0, end=tgt_len,
                dtype=mu.dtype).unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda()
            mu = mu.unsqueeze(-1)
            sigma = sigma.unsqueeze(-1)

            focus = -2 * (sigma**(-2)) * (
                (abs_pos - mu)**2)  # -\frac{(p-u)^2}{sigma^2/2}
            focus = focus.view(bsz * self.num_heads, tgt_len, tgt_len)

        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
            # tmp_q = q

            tmp_q = self.proj_h(tmp_q)
            tmp_q = tmp_q.contiguous().view(tgt_len, bsz * self.num_heads,
                                            self.head_dim).transpose(0, 1)
            tmp_k = self.proj_s(
                tmp_k
            )  # F.linear(k, self.s_proj_weight, bias=False)# self.proj_s(k)
            tmp_k = tmp_k.contiguous().view(-1, bsz * self.num_heads,
                                            self.head_dim).transpose(
                                                0, 1).transpose(1, 2)
            salient_g = torch.sigmoid(torch.bmm(tmp_q, tmp_k))
            # print(salient_g.size())
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
                ],
                                             dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size(
            []):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat([
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask)
                ],
                                             dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        if self.self_attention:
            # focus = -2 * (sigma**-2) * ((abs_pos-mu)**2)  # -\frac{(p-u)^2}{sigma^2/2}
            # print(focus.size())
            # print(attn_weights.size())
            attn_weights = attn_weights + focus

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if self.onnx_trace:
                attn_weights = torch.where(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    torch.Tensor([float("-Inf")]),
                    attn_weights.float()).type_as(attn_weights)
            else:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float('-inf'),
                )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        attn_weights = utils.softmax(
            attn_weights,
            dim=-1,
            onnx_trace=self.onnx_trace,
        ).type_as(attn_weights)

        if self.encoder_decoder_attention:
            # print(attn_weights.size())
            attn_weights = attn_weights * salient_g

        attn_weights = F.dropout(attn_weights,
                                 p=self.dropout,
                                 training=self.training)

        attn = torch.bmm(attn_weights, v)

        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if (self.onnx_trace and attn.size(1) == 1):
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            attn_weights = attn_weights.sum(dim=1) / self.num_heads
        else:
            attn_weights = None

        return attn, attn_weights
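The self-attention branch above biases the attention logits with a Gaussian "focus" term, -2 * (pos - mu)^2 / sigma^2, centred at a predicted position mu with width sigma. A standalone sketch of just that bias (names and shapes are assumptions; CPU tensors replace the `.cuda()` call used above):

import torch

def gaussian_focus_bias(mu, sigma, key_len):
    """mu, sigma: (bsz, num_heads, tgt_len) predicted centres / widths in key positions."""
    pos = torch.arange(key_len, dtype=mu.dtype).view(1, 1, 1, key_len)
    bias = -2.0 * sigma.unsqueeze(-1) ** -2 * (pos - mu.unsqueeze(-1)) ** 2
    return bias  # (bsz, num_heads, tgt_len, key_len), added to the attention logits

bsz, heads, tgt_len = 2, 4, 5
mu = tgt_len * torch.sigmoid(torch.randn(bsz, heads, tgt_len))
sigma = tgt_len * torch.sigmoid(torch.randn(bsz, heads, tgt_len))
bias = gaussian_focus_bias(mu, sigma, key_len=tgt_len)
print(bias.shape)  # torch.Size([2, 4, 5, 5])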