def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"):
    """Build random attention inputs together with all-ones masks.

    Returns a 6-tuple ``(queries, keys, values, attn_mask, query_lengths,
    key_lengths)``: queries are ``torch.rand(N, L, H, E)``, keys
    ``torch.rand(N, S, H, E)``, values ``torch.rand(N, S, H, D)``; the
    attention mask is ``FullMask(L, S)`` and the two length masks are
    ``FullMask(N, L)`` and ``FullMask(N, S)``. Everything lives on ``device``.
    """
    # Draw the random tensors in the same order as before so seeded tests
    # see identical values.
    queries = torch.rand(N, L, H, E).to(device)
    keys = torch.rand(N, S, H, E).to(device)
    values = torch.rand(N, S, H, D).to(device)
    attn_mask = FullMask(L, S, device=device)
    query_lengths = FullMask(N, L, device=device)
    key_lengths = FullMask(N, S, device=device)
    return queries, keys, values, attn_mask, query_lengths, key_lengths
Ejemplo n.º 2
0
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Apply fast-transformers Reformer attention to ``hidden_states``.

        Only ``hidden_states`` and ``attention_mask`` are used; the remaining
        arguments are accepted for signature compatibility and ignored.

        Args:
            hidden_states: input handed to ``self.project_QKV``.
            attention_mask: optional additive padding mask where ``0`` keeps a
                position and a negative value (e.g. ``-10000``) drops it. It is
                squeezed on dims 1 and 2 before use. NOTE(review): assumed to
                be a Hugging Face style ``(batch, 1, 1, seq)`` extended mask
                with right padding -- confirm with callers. ``None`` means
                every position is valid.

        Returns:
            A 1-tuple ``(context_layer,)``.

        Raises:
            ValueError: from ``FullMask.lengths`` when the kept positions of a
                row are not a contiguous prefix.
        """
        query_layer, key_layer, value_layer = self.project_QKV(hidden_states)

        # fast-transformers attention takes (batch, seq, heads, dim) inputs.
        query_layer = query_layer.transpose(1, 2)
        key_layer = key_layer.transpose(1, 2)
        value_layer = value_layer.transpose(1, 2)

        batch_size, seq_len = query_layer.shape[0], query_layer.shape[1]
        device = query_layer.device

        # Change mask behavior to be compatible with https://github.com/idiap/fast-transformers:
        # padding is expressed as per-sample lengths, while the (L, L)
        # attention mask itself is all ones.
        if attention_mask is None:
            lengths = torch.full((batch_size,), seq_len, dtype=torch.int64,
                                 device=device)
        else:
            # ``>= 0`` (instead of testing for exactly -10000) also handles
            # masks filled with a larger negative value such as finfo.min.
            keep = attention_mask.squeeze(1).squeeze(1) >= 0
            lengths = FullMask(mask=keep).lengths
        attn_mask = FullMask(seq_len, device=device)
        attention_length = LengthMask(lengths, max_len=seq_len)

        context_layer = self.reformer(queries=query_layer,
                                      keys=key_layer,
                                      values=value_layer,
                                      attn_mask=attn_mask,
                                      query_lengths=attention_length,
                                      key_lengths=attention_length)

        context_layer = self.reshape_output(context_layer.transpose(1, 2))

        return (context_layer, )
Ejemplo n.º 3
0
    def test_lower_triangular(self):
        """Check ``lower_triangular`` on causal, full and length masks."""
        tril = torch.tensor([
            [1, 0, 0],
            [1, 1, 0],
            [1, 1, 1]
        ]) > 0
        not_tril = torch.tensor([
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1]
        ]) > 0

        causal = TriangularCausalMask(3)
        self.assertTrue(causal.lower_triangular)
        self.assertTrue(torch.all(causal.bool_matrix == tril))

        self.assertTrue(FullMask(tril).lower_triangular)
        self.assertFalse(FullMask(not_tril).lower_triangular)

        length_cases = (
            ([1, 1, 3], {}, False),
            ([1, 2, 3], {}, True),
            ([1, 2, 3], {"max_len": 4}, True),
        )
        for lengths, extra, expected in length_cases:
            mask = LengthMask(torch.tensor(lengths), **extra)
            check = self.assertTrue if expected else self.assertFalse
            check(mask.lower_triangular)
Ejemplo n.º 4
0
    def test_feature_map_sharing(self):
        """Apply one Favor-based linear attention layer twice, then backprop."""
        n_heads, head_dim = 4, 32
        d_model = n_heads * head_dim
        inputs = torch.rand(3, 100, d_model)
        feature_map = Favor.factory(n_dims=64)
        layer = AttentionLayer(LinearAttention(head_dim, feature_map),
                               d_model, n_heads)

        attn_mask = FullMask(100)
        lengths = FullMask(3, 100)
        out = inputs
        for _ in range(2):
            out = layer(out, out, out, attn_mask, lengths, lengths)
        out.sum().backward()
Ejemplo n.º 5
0
    def test_full_mask_constructor_arguments(self):
        """FullMask accepts a bool tensor, a single size, or two sizes."""
        from_tensor = FullMask(torch.rand(10, 10) > 0.5)
        self.assertEqual(from_tensor.shape, (10, 10))
        self.assertFalse(from_tensor.all_ones)

        for sizes, expected_shape in (((10,), (10, 10)), ((10, 5), (10, 5))):
            from_sizes = FullMask(*sizes)
            self.assertEqual(from_sizes.shape, expected_shape)
            self.assertTrue(from_sizes.all_ones)
Ejemplo n.º 6
0
    def test_full_mask(self):
        """An all-ones FullMask exposes consistent bool/float/additive views."""
        ones = FullMask(N=10)
        self.assertEqual(ones.shape, (10, 10))
        self.assertTrue(torch.all(ones.bool_matrix))
        self.assertTrue(torch.all(ones.float_matrix == 1))
        self.assertTrue(torch.all(ones.additive_matrix == 0))

        # A 1-D float tensor is rejected.
        with self.assertRaises(ValueError):
            FullMask(torch.rand(10))

        rectangular = FullMask(torch.rand(10, 5) > 0.5)
        self.assertEqual(rectangular.shape, (10, 5))
Ejemplo n.º 7
0
    def test_masking(self):
        """AFTFull accepts an arbitrary mask; AFTSimple needs a causal one.

        The attention outputs are kept in their own variable so that every
        call receives the original ``(N, S, H, D)`` values from
        ``_get_inputs``. Before, the AFTFull output overwrote ``v`` and fed
        mismatched values into the AFTSimple ValueError check.
        """
        q, k, v, m1, m2, m3 = self._get_inputs()
        m1 = FullMask(torch.rand(5, 8) > 0.5)

        att = AFTFullAttention()
        out = att(q, k, v, m1, m2, m3)

        att = AFTSimpleAttention()
        with self.assertRaises(ValueError):
            out = att(q, k, v, m1, m2, m3)

        # A lower-triangular (causal) mask is accepted by AFTSimple.
        q, k, v, m1, m2, m3 = self._get_inputs(L=8, S=8)
        m1 = FullMask(torch.tril(torch.ones(8, 8, dtype=torch.bool)))
        out = att(q, k, v, m1, m2, m3)
Ejemplo n.º 8
0
    def test_casting_to_lengths(self):
        """A prefix-shaped FullMask exposes lengths; any other mask raises."""
        prefix_rows = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]) > 0
        prefix_mask = FullMask(prefix_rows)
        self.assertEqual(prefix_mask.shape, (3, 3))
        self.assertTrue(
            torch.all(prefix_mask.lengths == torch.tensor([1, 2, 3])))

        gapped_rows = torch.tensor([[1, 0, 1], [1, 1, 0], [1, 1, 1]]) > 0
        gapped_mask = FullMask(gapped_rows)
        with self.assertRaises(ValueError):
            gapped_mask.lengths
Ejemplo n.º 9
0
 def forward_pre(self, tgt, memory,
                 tgt_mask: Optional[Tensor] = None,
                 memory_mask: Optional[Tensor] = None,
                 tgt_key_padding_mask: Optional[Tensor] = None,
                 memory_key_padding_mask: Optional[Tensor] = None,
                 pos: Optional[Tensor] = None,
                 query_pos: Optional[Tensor] = None):
     """Pre-norm decoder layer: self-attention, cross-attention, feed-forward.

     The ``*_mask`` / ``*_key_padding_mask`` arguments are accepted for
     signature compatibility but ignored; every target and memory position
     is attended to.

     Args:
         tgt: target sequence, unpacked as ``(N, L, E)``.
         memory: encoder output, unpacked as ``(N, S, ...)``.
         pos: forwarded to ``self.with_pos_embed`` for the memory keys.
         query_pos: forwarded to ``self.with_pos_embed`` for the queries.

     Returns:
         The updated target tensor.
     """
     tgt2 = self.norm1(tgt)
     q = k = self.with_pos_embed(tgt2, query_pos)
     N, L, _ = q.shape
     S = memory.shape[1]

     # Self-attention: L queries over the same L positions.
     self_attn_mask = FullMask(L, device=q.device)
     tgt_length = LengthMask(tgt.new_full((N,), L, dtype=torch.int64))
     tgt2 = self.self_attn(q, k, tgt2, self_attn_mask,
                           tgt_length, tgt_length)
     tgt = tgt + self.dropout1(tgt2)
     tgt2 = self.norm2(tgt)

     # Cross-attention: L queries over S memory keys, so the attention mask
     # is L x S and the key lengths are the memory length.
     cross_attn_mask = FullMask(L, S, device=q.device)
     memory_length = LengthMask(tgt.new_full((N,), S, dtype=torch.int64))
     tgt2 = self.multihead_attn(self.with_pos_embed(tgt2, query_pos),
                                self.with_pos_embed(memory, pos),
                                memory, cross_attn_mask,
                                tgt_length, memory_length)
     tgt = tgt + self.dropout2(tgt2)
     tgt2 = self.norm3(tgt)
     tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
     tgt = tgt + self.dropout3(tgt2)
     return tgt
Ejemplo n.º 10
0
 def forward_post(self, tgt, memory,
                  tgt_mask: Optional[Tensor] = None,
                  memory_mask: Optional[Tensor] = None,
                  tgt_key_padding_mask: Optional[Tensor] = None,
                  memory_key_padding_mask: Optional[Tensor] = None,
                  pos: Optional[Tensor] = None,
                  query_pos: Optional[Tensor] = None):
     """Post-norm decoder layer: self-attention, cross-attention, feed-forward.

     The ``*_mask`` / ``*_key_padding_mask`` arguments are accepted for
     signature compatibility but ignored; every target and memory position
     is attended to.

     Args:
         tgt: target sequence, unpacked as ``(N, L, E)``.
         memory: encoder output, unpacked as ``(N, S, E)``.
         pos: forwarded to ``self.with_pos_embed`` for the memory keys.
         query_pos: forwarded to ``self.with_pos_embed`` for the queries.

     Returns:
         The updated target tensor.
     """
     N, L, _ = tgt.shape
     S = memory.shape[1]
     q = k = self.with_pos_embed(tgt, query_pos)
     query_length = LengthMask(tgt.new_full((N,), L, dtype=torch.int64))
     memory_length = LengthMask(tgt.new_full((N,), S, dtype=torch.int64))

     # Self-attention: L queries over the same L positions.
     self_attn_mask = FullMask(L, device=q.device)
     tgt2 = self.self_attn(q, k, tgt, self_attn_mask,
                           query_length, query_length)
     tgt = tgt + self.dropout1(tgt2)
     tgt = self.norm1(tgt)

     # Cross-attention: L queries over S memory keys -> L x S mask.
     cross_attn_mask = FullMask(L, S, device=q.device)
     tgt2 = self.multihead_attn(self.with_pos_embed(tgt, query_pos),
                                self.with_pos_embed(memory, pos),
                                memory, cross_attn_mask,
                                query_length, memory_length)
     tgt = tgt + self.dropout2(tgt2)
     tgt = self.norm2(tgt)
     tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
     tgt = tgt + self.dropout3(tgt2)
     tgt = self.norm3(tgt)
     return tgt
Ejemplo n.º 11
0
    def test_correctness(self):
        """Batch LinearAttention must match step-wise RecurrentCrossLinearAttention."""
        N, H, E, M, L, S = 10, 4, 25, 64, 42, 100
        q = torch.rand(N, L, H, E)
        k = torch.rand(N, S, H, E)
        v = torch.rand(N, S, H, M)
        attn_mask = FullMask(L, S)
        query_lengths = LengthMask(torch.full((N, ), L, dtype=torch.int64))
        key_lengths = LengthMask(torch.full((N, ), S, dtype=torch.int64))

        # Whole-sequence (batch) evaluation.
        batch_att = LinearAttention(E)
        batch_att.eval()
        batch_out = batch_att(q, k, v, attn_mask, query_lengths, key_lengths)

        # One query at a time, threading the recurrent state through.
        recurrent_att = RecurrentCrossLinearAttention(E)
        recurrent_att.eval()
        state = None
        per_step = []
        for step in range(L):
            out_step, state = recurrent_att(q[:, step], k, v, key_lengths,
                                            state=state)
            per_step.append(out_step)
        recurrent_out = torch.stack(per_step, dim=1)

        self.assertLess(torch.abs(batch_out - recurrent_out).max(), 1e-6)
Ejemplo n.º 12
0
    def forward(
        self,
        x,
        attn_mask=None,
        length_mask=None,
    ):
        """Apply one transformer encoder layer to ``x``.

        LayerNorm is applied either before (``self.layer_norm_first``) or
        after the self-attention / feed-forward sub-blocks, similar to the
        original Transformer implementation.

        Args:
            x: input tensor; dims 0 and 1 are read as ``N`` and ``L``.
            attn_mask: optional attention mask passed to ``self.self_attn``;
                defaults to ``FullMask(L)`` on ``x``'s device.
            length_mask: optional mask used as both query and key lengths;
                defaults to a LengthMask where every sample has length ``L``.

        Returns:
            ``(x, None)`` -- the transformed tensor and a ``None`` placeholder.
        """
        N, L = x.shape[0], x.shape[1]
        # Explicit ``is None`` checks: ``mask or default`` calls ``__bool__``
        # on the mask, which raises for multi-element tensors.
        if attn_mask is None:
            attn_mask = FullMask(L, device=x.device)
        if length_mask is None:
            length_mask = LengthMask(x.new_full((N,), L, dtype=torch.int64))

        def attend(h):
            # Self-attention: h is used as queries, keys and values.
            return self.self_attn(
                h, h, h,
                attn_mask=attn_mask,
                query_lengths=length_mask,
                key_lengths=length_mask,
            )

        def feed_forward(h):
            # fc1 -> activation -> dropout -> fc2 -> dropout.
            h = self.dropout2(self.activation_fn(self.fc1(h)))
            return self.dropout3(self.fc2(h))

        if self.layer_norm_first:
            x = x + self.dropout1(attend(self.self_attn_layer_norm(x)))
            x = x + feed_forward(self.final_layer_norm(x))
        else:
            x = self.self_attn_layer_norm(x + self.dropout1(attend(x)))
            x = self.final_layer_norm(x + feed_forward(x))

        return x, None
Ejemplo n.º 13
0
    def test_compare_with_batch(self):
        """Batch and recurrent decoders must agree step by step."""
        N, L, S, D = 10, 42, 100, 1024
        E = D // 4
        x = torch.rand(N, L, D)
        m = torch.rand(N, S, D)

        configs = [
            ("full", FullAttention, FullAttention, RecurrentFullAttention,
             RecurrentCrossFullAttention),
            ("linear", partial(CausalLinearAttention, E),
             partial(LinearAttention, E), partial(RecurrentLinearAttention, E),
             partial(RecurrentCrossLinearAttention, E)),
        ]

        for _name, self_att, cross_att, rec_self_att, rec_cross_att in configs:
            dec = TransformerDecoder([
                TransformerDecoderLayer(AttentionLayer(self_att(), D, 4),
                                        AttentionLayer(cross_att(), D, 4), D)
                for _ in range(4)
            ])
            rdec = RecurrentTransformerDecoder([
                RecurrentTransformerDecoderLayer(
                    RecurrentAttentionLayer(rec_self_att(), D, 4),
                    RecurrentCrossAttentionLayer(rec_cross_att(), D, 4), D)
                for _ in range(4)
            ])
            dec.eval()
            rdec.eval()
            # Share weights so only the evaluation mode differs.
            rdec.load_state_dict(dec.state_dict())

            causal = TriangularCausalMask(L)
            x_lengths = LengthMask(torch.full((N, ), L, dtype=torch.int64))
            memory_full = FullMask(L, S)
            memory_lengths = LengthMask(torch.full((N, ), S, dtype=torch.int64))

            batch_out = dec(x,
                            m,
                            x_mask=causal,
                            x_length_mask=x_lengths,
                            memory_mask=memory_full,
                            memory_length_mask=memory_lengths)

            state = None
            per_step = []
            for step in range(L):
                out_step, state = rdec(x[:, step],
                                       m,
                                       memory_length_mask=memory_lengths,
                                       state=state)
                per_step.append(out_step)
            recurrent_out = torch.stack(per_step, dim=1)

            self.assertLess(torch.abs(batch_out - recurrent_out).max(), 1e-5)
    def test_compare_with_full(self):
        """Check LocalAttention(17) against FullAttention with a |i - j| < 9 band mask."""
        local_att = LocalAttention(17, softmax_temp=1).eval()
        full_att = FullAttention(softmax_temp=1).eval()

        q, k, v, m1, m2, m3 = self._get_inputs(N=10, L=128, S=128, D=32)
        positions = torch.arange(128)
        band = FullMask((positions[:, None] - positions[None]).abs() < 9)

        full_out = full_att(q, k, v, band, m2, m3)
        local_out = local_att(q, k, v, m1, m2, m3)

        self.assertTrue(
            torch.allclose(full_out, local_out, atol=1e-5, rtol=1e-5))
    def test_masking(self):
        """LinearAttention rejects non-full masks and honours key lengths."""
        att = LinearAttention(32)
        q, k, v, m1, m2, m3 = self._get_inputs()

        # A random (not all-ones) attention mask must raise.
        with self.assertRaises(RuntimeError):
            att(q, k, v, FullMask(torch.rand(*m1.shape) > 0.5), m2, m3)

        # Values beyond each sample's key length are huge; they must not
        # leak into the output.
        q, k, v, m1, m2, m3 = self._get_inputs(S=10, D=1)
        m3 = LengthMask(torch.arange(10) + 1)
        for row in range(9):
            v[row, row + 1:] = 1e9
        out = att(q, k, v, m1, m2, m3)
        self.assertLess(out.max().item(), 1)
Ejemplo n.º 16
0
 def forward_pre(self, src,
                 src_mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
                 pos: Optional[Tensor] = None):
     """Pre-norm encoder layer: self-attention followed by feed-forward.

     ``src_mask`` and ``src_key_padding_mask`` are accepted for signature
     compatibility but ignored; every position attends to every other.

     Args:
         src: input sequence, unpacked as ``(N, L, E)``.
         pos: forwarded to ``self.with_pos_embed`` to build queries/keys.

     Returns:
         The updated ``src`` tensor.
     """
     src2 = self.norm1(src)
     q = k = self.with_pos_embed(src2, pos)
     N, L, _ = q.shape
     attn_mask = FullMask(L, device=q.device)
     length_mask = LengthMask(src.new_full((N,), L, dtype=torch.int64))
     # Values are the normalised input without the positional embedding.
     src2 = self.self_attn(q, k, src2, attn_mask, length_mask, length_mask)
     src = src + self.dropout1(src2)
     src2 = self.norm2(src)
     src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
     src = src + self.dropout2(src2)
     return src
Ejemplo n.º 17
0
    def forward(self, src, src_len_mask=None):
        """Encode a batch of sets into ``self.k`` pooled outputs per sample.

        Pipeline:
        1. ``initial_ff`` and the LSTM.
        2. ``encoder_pre``, applied ``tf_depth[0]`` times.
        3. Attention pooling against the learned ``self.seeds``.
        4. ``encoder_post``, applied ``tf_depth[1]`` times.
        5. ``final_ff``.

        Args:
            src: input batch; dim 0 is read as the batch size and dim 1 as
                the set size.
            src_len_mask: currently unused; every set element is attended to.

        Returns:
            Output of ``self.final_ff``.
        """
        x = self.initial_ff(src)

        # x = self.pos_encoding(x)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)

        for _ in range(self.tf_depth[0]):
            x = self.encoder_pre(x)

        set_size = src.shape[1]
        batch_size = src.shape[0]
        device = src.device

        def full_lengths(value):
            # Every sample has the same length ``value``. Lengths are integer
            # counts (int64), not the float32 that torch.ones produced.
            return LengthMask(
                torch.full((batch_size,), value, dtype=torch.int64),
                device=device)

        # in case the input sequence length has changed
        if self.mask.shape[1] != set_size:
            self.mask = FullMask(N=self.k, M=set_size, device=device)
            self.kl_mask = full_lengths(set_size)
        # in case the batch size or the device has changed
        if (self.ql_mask.shape[0] != batch_size
                or self.kl_mask.shape[0] != batch_size
                or self.ql_mask.lengths.device != device
                or self.kl_mask.lengths.device != device):
            self.ql_mask = full_lengths(self.k)
            self.kl_mask = full_lengths(set_size)

        # extend seeds to size of batch
        seeds = self.seeds.unsqueeze(0).repeat(batch_size, 1, 1)

        # perform pooling by multihead attention, reducing dimensionality of set
        x = self.attn_pooling(seeds, x, x, self.mask, self.ql_mask,
                              self.kl_mask)

        for _ in range(self.tf_depth[1]):
            x = self.encoder_post(x)

        return self.final_ff(x)
Ejemplo n.º 18
0
    def __init__(self,
                 num_feats,
                 num_output_points,
                 lstm_layers,
                 n_layers,
                 n_heads,
                 hidden_dim,
                 ff_dim,
                 tf_depth=3,
                 dropout=0.15):
        """Build the set transformer.

        Args:
            num_feats: feature size of each input element; also the output
                size of ``final_ff``.
            num_output_points: number ``k`` of learned seed vectors used for
                attention pooling.
            lstm_layers: number of stacked LSTM layers.
            n_layers, n_heads, hidden_dim, ff_dim, tf_depth: either an int,
                used for both the pre- and post-pooling stages, or a pair
                ``(pre, post)``.
            dropout: dropout passed to both transformer encoder builders.
        """
        super(SetTransformer, self).__init__()

        self.num_feats = num_feats
        self.k = num_output_points

        def dup(x):
            # An int is shared by both stages; anything else is assumed to
            # already be a (pre, post) pair.
            return (x, x) if isinstance(x, int) else x

        self.lstm_layers = lstm_layers
        self.n_layers = dup(n_layers)
        self.n_heads = dup(n_heads)
        self.hidden_dim = dup(hidden_dim)
        self.ff_dim = dup(ff_dim)
        self.tf_depth = dup(tf_depth)

        self.d_model = [self.hidden_dim[i] * self.n_heads[i] for i in (0, 1)]

        encoder_builder_pre = TransformerEncoderBuilder.from_kwargs(
            n_layers=self.n_layers[0],
            n_heads=self.n_heads[0],
            query_dimensions=self.hidden_dim[0],
            value_dimensions=self.hidden_dim[0],
            feed_forward_dimensions=self.ff_dim[0],
            attention_type='linear',
            dropout=dropout)

        encoder_builder_post = TransformerEncoderBuilder.from_kwargs(
            n_layers=self.n_layers[1],
            n_heads=self.n_heads[1],
            query_dimensions=self.hidden_dim[1],
            value_dimensions=self.hidden_dim[1],
            feed_forward_dimensions=self.ff_dim[1],
            attention_type='linear',
            dropout=dropout)

        # Creation order is unchanged so parameter initialisation consumes
        # the RNG in the same sequence as before.
        self.seeds = nn.Parameter(torch.normal(0, 1,
                                               (self.k, self.d_model[0])))
        self.encoder_pre = encoder_builder_pre.get()
        self.encoder_post = encoder_builder_post.get()

        self.initial_ff = nn.Linear(self.num_feats, self.d_model[0])
        # self.pos_encoding = PositionalEncoding(self.d_model[0], dropout=dropout)
        # Honour ``lstm_layers`` (previously hard-coded to 2).
        self.lstm = nn.LSTM(self.d_model[0],
                            self.d_model[0],
                            self.lstm_layers,
                            batch_first=True,
                            bidirectional=False)
        self.attn_pooling = AttentionLayer(LinearAttention(self.d_model[0]),
                                           self.d_model[0], self.n_heads[0])
        # NOTE(review): attn_pooling produces d_model[0] features, but
        # final_ff expects d_model[1]. There is no projection between them,
        # so the two stages presumably need equal hidden_dim * n_heads --
        # confirm.
        self.final_ff = nn.Linear(self.d_model[1], self.num_feats)

        # Placeholder masks (set size 5, batch size 5). forward() rebuilds
        # them on the first call whose shape or device differs. Lengths are
        # int64 counts.
        self.mask = FullMask(N=self.k, M=5)
        self.kl_mask = LengthMask(torch.full((5,), 5, dtype=torch.int64))
        self.ql_mask = LengthMask(torch.full((5,), self.k, dtype=torch.int64))