def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = 'relu',
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:

        super().__init__(embedding_dim, ffn_embedding_dim, num_attention_heads,
                         dropout, attention_dropout, activation_dropout,
                         activation_fn, export)

        # Replace the dense self-attention built by the parent class with the
        # fixed-pattern sparse variant ("Generating Long Sequences with Sparse
        # Transformers", Child et al., 2019).
        self.self_attn = SparseMultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            add_bias_kv=False,
            add_zero_attn=False,
            self_attention=True,
            is_bidirectional=is_bidirectional,
            stride=stride,
            expressivity=expressivity,
        )
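
A minimal usage sketch for the sparse attention configured above. This is an assumption-laden sketch, not taken from the example: the fairseq-style import path, the (seq_len, batch, embed_dim) tensor layout, and the direct query/key/value call are all assumed.

# Hypothetical sketch: import path, tensor layout, and call signature are assumptions.
import torch
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention

attn = SparseMultiheadAttention(
    768, 8, dropout=0.1, self_attention=True,
    is_bidirectional=True, stride=32, expressivity=8,
)
x = torch.randn(128, 2, 768)  # (seq_len, batch, embed_dim); seq_len chosen as a multiple of stride
out, weights = attn(query=x, key=x, value=x)  # sparse self-attention over the sequence
print(out.shape)  # expected: torch.Size([128, 2, 768])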
Example #2
    def test_sparse_multihead_attention(self):
        attn_weights = torch.randn(1, 8, 8)
        neg_inf = float("-inf")  # marks attention edges removed by the mask

        # Expected 8x8 additive mask for the bidirectional fixed sparsity
        # pattern with stride=4, expressivity=1: 0 keeps an edge, -inf drops it.
        bidirectional_sparse_mask = torch.tensor([
            [0, 0, 0, 0, 0, neg_inf, neg_inf, 0],
            [0, 0, 0, 0, 0, neg_inf, neg_inf, 0],
            [0, 0, 0, 0, 0, neg_inf, neg_inf, 0],
            [0, 0, 0, 0, 0, neg_inf, neg_inf, 0],
            [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
            [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
            [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
            [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
        ])

        bidirectional_attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=True)
        bidirectional_attention_sparse_mask = (
            bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8))
        # The generated bidirectional mask must match the expected pattern.
        self.assertTrue(torch.all(torch.eq(
            bidirectional_attention_sparse_mask, bidirectional_sparse_mask)))

        # Expected 8x8 additive mask for the unidirectional (causal) fixed
        # sparsity pattern with stride=4, expressivity=1.
        sparse_mask = torch.tensor([
            [0, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf],
            [0, 0, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf],
            [0, 0, 0, neg_inf, neg_inf, neg_inf, neg_inf, neg_inf],
            [0, 0, 0, 0, neg_inf, neg_inf, neg_inf, neg_inf],
            [0, 0, 0, 0, 0, neg_inf, neg_inf, neg_inf],
            [neg_inf, neg_inf, neg_inf, 0, 0, 0, neg_inf, neg_inf],
            [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, neg_inf],
            [neg_inf, neg_inf, neg_inf, 0, 0, 0, 0, 0],
        ])

        attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=False)
        attention_sparse_mask = attention.buffered_sparse_mask(
            attn_weights, 8, 8)

        # The generated causal mask must match the expected pattern as well.
        self.assertTrue(
            torch.all(torch.eq(attention_sparse_mask, sparse_mask)))
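
The masks compared above are additive log-space masks: masks of this form are added to the raw attention scores before the softmax, so every float("-inf") entry ends up with exactly zero attention probability. A self-contained illustration of that mechanism in plain PyTorch (not tied to the fairseq module):

import torch

# Toy 4x4 additive mask: 0 keeps an attention edge, -inf removes it.
mask = torch.tensor([
    [0.0, float("-inf"), float("-inf"), float("-inf")],
    [0.0, 0.0, float("-inf"), float("-inf")],
    [0.0, 0.0, 0.0, float("-inf")],
    [0.0, 0.0, 0.0, 0.0],
])
scores = torch.randn(4, 4)                    # raw attention logits
probs = torch.softmax(scores + mask, dim=-1)  # rows renormalize over unmasked positions only
assert torch.all(probs[mask == float("-inf")] == 0)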