Example #1
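# Dependencies assumed to be in scope for this listing (from NVIDIA Apex,
# apex.contrib.multihead_attn): math, torch, nn and Parameter from torch.nn,
# FusedLayerNorm from apex.normalization, the fused attention functions
# self_attn_func, fast_self_attn_func and fast_self_attn_norm_add_func, and
# the jit_dropout_add helper.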
class SelfMultiheadAttn(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=False,
        include_norm_add=False,
        impl="fast",
        separate_qkv_params=False,
        mask_additive=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.bias = bias
        self.include_norm_add = include_norm_add
        self.impl = impl
        self.scaling = self.head_dim**-0.5
        self.separate_qkv_params = separate_qkv_params
        self.mask_additive = mask_additive
        if mask_additive:
            assert not self.include_norm_add, "additive mask not supported with layer norm"
            assert impl == "default" or (
                impl == "fast" and bias
            ), "additive mask not supported for fast mode without bias"
        if separate_qkv_params:
            self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        else:
            self.in_proj_weight = Parameter(
                torch.Tensor(3 * embed_dim, embed_dim))
        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        if self.bias:
            if separate_qkv_params:
                self.q_bias = Parameter(torch.Tensor(embed_dim))
                self.k_bias = Parameter(torch.Tensor(embed_dim))
                self.v_bias = Parameter(torch.Tensor(embed_dim))
            else:
                self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
        else:
            if separate_qkv_params:
                self.register_parameter("q_bias", None)
                self.register_parameter("k_bias", None)
                self.register_parameter("v_bias", None)
                self.q_bias = None
                self.k_bias = None
                self.v_bias = None
            else:
                self.register_parameter("in_proj_bias", None)
                self.in_proj_bias = None
            self.register_parameter("out_proj_bias", None)
            self.out_proj_bias = None
        if self.include_norm_add:
            if impl == "fast":
                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm = None
            else:
                self.register_parameter("lyr_norm_gamma_weights", None)
                self.register_parameter("lyr_norm_beta_weights", None)
                self.lyr_nrm_gamma_weights = None
                self.lyr_nrm_beta_weights = None
                self.lyr_nrm = FusedLayerNorm(embed_dim)
        self.reset_parameters()

        if self.include_norm_add:
            if impl == "fast":
                self.attn_func = fast_self_attn_norm_add_func
            elif impl == "default":
                self.attn_func = self_attn_func
            else:
                assert False, "Unsupported impl: {} !".format(impl)
        else:
            if impl == "fast":
                self.attn_func = fast_self_attn_func
            elif impl == "default":
                self.attn_func = self_attn_func
            else:
                assert False, "Unsupported impl: {} !".format(impl)

    def reset_parameters(self):
        if self.separate_qkv_params:
            nn.init.xavier_uniform_(self.q_weight)
            nn.init.xavier_uniform_(self.k_weight)
            nn.init.xavier_uniform_(self.v_weight)
        else:
            # in_proj_weight has shape [3 * hidden, hidden] but it should be
            # initialized like a [hidden, hidden] matrix.
            # sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
            # therefore xavier_uniform gain should be set to sqrt(2).
            nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj_weight)
        if self.bias:
            if self.separate_qkv_params:
                nn.init.constant_(self.q_bias, 0.0)
                nn.init.constant_(self.k_bias, 0.0)
                nn.init.constant_(self.v_bias, 0.0)
            else:
                nn.init.constant_(self.in_proj_bias, 0.0)
            nn.init.constant_(self.out_proj_bias, 0.0)
        if self.include_norm_add:
            if self.impl == "fast":
                nn.init.ones_(self.lyr_nrm_gamma_weights)
                nn.init.zeros_(self.lyr_nrm_beta_weights)
            else:
                self.lyr_nrm.reset_parameters()

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                need_weights=False,
                attn_mask=None,
                is_training=True):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        if self.separate_qkv_params:
            input_weights = (torch.cat(
                [
                    self.q_weight.view(self.num_heads, 1, self.head_dim,
                                       self.embed_dim),
                    self.k_weight.view(self.num_heads, 1, self.head_dim,
                                       self.embed_dim),
                    self.v_weight.view(self.num_heads, 1, self.head_dim,
                                       self.embed_dim),
                ],
                dim=1,
            ).reshape(3 * self.embed_dim, self.embed_dim).contiguous())
        else:
            input_weights = self.in_proj_weight
        if self.bias:
            if self.separate_qkv_params:
                input_bias = (torch.cat(
                    [
                        self.q_bias.view(self.num_heads, 1, self.head_dim),
                        self.k_bias.view(self.num_heads, 1, self.head_dim),
                        self.v_bias.view(self.num_heads, 1, self.head_dim),
                    ],
                    dim=1,
                ).reshape(3 * self.embed_dim).contiguous())
            else:
                input_bias = self.in_proj_bias
        else:
            input_bias = None
        if key_padding_mask is not None:
            assert attn_mask is None, "attn_mask and key_padding_mask must not both be defined!"
            mask = key_padding_mask
        elif attn_mask is not None:
            assert not self.mask_additive, "additive mask not supported for time mask"
            mask = attn_mask
        else:
            mask = None

        if self.include_norm_add:
            if self.impl == "fast":
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    query,
                    self.lyr_nrm_gamma_weights,
                    self.lyr_nrm_beta_weights,
                    input_weights,
                    self.out_proj_weight,
                    mask,
                    self.dropout,
                )
            else:
                lyr_nrm_results = self.lyr_nrm(query)
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    self.scaling,
                    lyr_nrm_results,
                    input_weights,
                    self.out_proj_weight,
                    input_bias,
                    self.out_proj_bias,
                    mask,
                    self.mask_additive,
                    self.dropout,
                )
                if is_training:
                    outputs = jit_dropout_add(outputs, query, self.dropout,
                                              is_training)
                else:
                    outputs = outputs + query
        else:
            if self.impl == "fast":
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    query,
                    input_weights,
                    self.out_proj_weight,
                    input_bias,
                    self.out_proj_bias,
                    mask,
                    self.mask_additive,
                    self.dropout,
                )
            else:
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    self.scaling,
                    query,
                    input_weights,
                    self.out_proj_weight,
                    input_bias,
                    self.out_proj_bias,
                    mask,
                    self.mask_additive,
                    self.dropout,
                )

        return outputs, None
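
A minimal usage sketch for SelfMultiheadAttn, assuming Apex's fused multihead-attention extensions are built and a CUDA device is available; the default impl="fast" path is typically run with half-precision CUDA tensors of shape [seq_len, batch, embed_dim], and the sizes below are illustrative.

import torch

embed_dim, num_heads = 1024, 16
seq_len, batch = 64, 8

# Self-attention layer in its default configuration (impl="fast", no bias, no fused norm-add).
attn = SelfMultiheadAttn(embed_dim, num_heads, dropout=0.1).cuda().half()

x = torch.randn(seq_len, batch, embed_dim, device="cuda", dtype=torch.half)
# Self-attention: the same tensor is passed for query, key and value.
out, _ = attn(x, x, x, key_padding_mask=None, attn_mask=None, is_training=True)
print(out.shape)  # torch.Size([64, 8, 1024])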
Example #2
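# Dependencies assumed to be in scope for this listing (from NVIDIA Apex,
# apex.contrib.multihead_attn): torch, nn and Parameter from torch.nn,
# FusedLayerNorm from apex.normalization, the fused attention functions
# encdec_attn_func, fast_encdec_attn_func and fast_encdec_attn_norm_add_func,
# and the jit_dropout_add helper.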
class EncdecMultiheadAttn(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """
    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=False,
                 include_norm_add=False,
                 impl='fast'):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.bias = bias
        self.include_norm_add = include_norm_add
        self.impl = impl
        self.scaling = self.head_dim**-0.5

        self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
        self.in_proj_weight_kv = Parameter(
            torch.Tensor(2 * embed_dim, embed_dim))
        self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        if self.bias:
            assert impl != 'fast', "the 'fast' implementation does not support biases!"
            self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
            self.in_proj_bias_kv = Parameter(torch.Tensor(2 * embed_dim))
            self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
        else:
            self.register_parameter('in_proj_bias_q', None)
            self.register_parameter('in_proj_bias_kv', None)
            self.in_proj_bias_q = None
            self.in_proj_bias_kv = None
            self.out_proj_bias = None
        if self.include_norm_add:
            if impl == 'fast':
                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm = None
            else:
                self.register_parameter('lyr_nrm_gamma_weights', None)
                self.register_parameter('lyr_nrm_beta_weights', None)
                self.lyr_nrm_gamma_weights = None
                self.lyr_nrm_beta_weights = None
                self.lyr_nrm = FusedLayerNorm(embed_dim)
        self.reset_parameters()

        if self.include_norm_add:
            if impl == 'fast':
                self.attn_func = fast_encdec_attn_norm_add_func
            elif impl == 'default':
                self.attn_func = encdec_attn_func
            else:
                raise ValueError("Unsupported impl: {}!".format(impl))
        else:
            if impl == 'fast':
                self.attn_func = fast_encdec_attn_func
            elif impl == 'default':
                self.attn_func = encdec_attn_func
            else:
                raise ValueError("Unsupported impl: {}!".format(impl))

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.in_proj_weight_q)
        nn.init.xavier_uniform_(self.in_proj_weight_kv)
        nn.init.xavier_uniform_(self.out_proj_weight)
        if self.bias:
            nn.init.constant_(self.in_proj_bias_q, 0.)
            nn.init.constant_(self.in_proj_bias_kv, 0.)
            nn.init.constant_(self.out_proj_bias, 0.)
        if self.include_norm_add:
            if self.impl == 'fast':
                nn.init.ones_(self.lyr_nrm_gamma_weights)
                nn.init.zeros_(self.lyr_nrm_beta_weights)
            else:
                self.lyr_nrm.reset_parameters()

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                need_weights=False,
                attn_mask=None,
                is_training=True):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """

        if key_padding_mask is not None:
            assert (
                attn_mask is None
            ), "attn_mask and key_padding_mask must not both be defined!"
            mask = key_padding_mask
        elif attn_mask is not None:
            mask = attn_mask
        else:
            mask = None

        if self.include_norm_add:
            if self.impl == 'fast':
                outputs = self.attn_func(
                    attn_mask is not None, is_training, self.num_heads, query,
                    key, self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
                    self.in_proj_weight_q, self.in_proj_weight_kv,
                    self.out_proj_weight, mask, self.dropout)
            else:
                lyr_nrm_results = self.lyr_nrm(query)
                outputs = self.attn_func(
                    attn_mask is not None, is_training, self.num_heads,
                    self.scaling, lyr_nrm_results, key, self.in_proj_weight_q,
                    self.in_proj_weight_kv, self.out_proj_weight,
                    self.in_proj_bias_q, self.in_proj_bias_kv,
                    self.out_proj_bias, mask, self.dropout)
                if is_training:
                    outputs = jit_dropout_add(outputs, query, self.dropout,
                                              is_training)
                else:
                    outputs = outputs + query
        else:
            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training,
                                         self.num_heads, query, key,
                                         self.in_proj_weight_q,
                                         self.in_proj_weight_kv,
                                         self.out_proj_weight, mask,
                                         self.dropout)
            else:
                outputs = self.attn_func(
                    attn_mask is not None, is_training, self.num_heads,
                    self.scaling, query, key, self.in_proj_weight_q,
                    self.in_proj_weight_kv, self.out_proj_weight,
                    self.in_proj_bias_q, self.in_proj_bias_kv,
                    self.out_proj_bias, mask, self.dropout)

        return outputs, None
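
A minimal usage sketch for EncdecMultiheadAttn, under the same assumptions as above (compiled Apex CUDA extensions, half-precision CUDA inputs for the 'fast' path, illustrative sizes): the decoder states serve as the query and the encoder output is passed as both key and value, since forward() projects the key tensor to both keys and values.

import torch

embed_dim, num_heads = 1024, 16
tgt_len, src_len, batch = 32, 64, 8

# Encoder-decoder attention layer in its default configuration (impl='fast', no bias).
attn = EncdecMultiheadAttn(embed_dim, num_heads, dropout=0.1).cuda().half()

q = torch.randn(tgt_len, batch, embed_dim, device="cuda", dtype=torch.half)   # decoder states
kv = torch.randn(src_len, batch, embed_dim, device="cuda", dtype=torch.half)  # encoder output
out, _ = attn(q, kv, kv, key_padding_mask=None, attn_mask=None, is_training=True)
print(out.shape)  # torch.Size([32, 8, 1024])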