def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"):
    # Random queries/keys/values with shape (batch, length, heads, dims),
    # plus an attention mask and the query/key "length" masks expected by the
    # attention modules.
    return (
        torch.rand(N, L, H, E).to(device),
        torch.rand(N, S, H, E).to(device),
        torch.rand(N, S, H, D).to(device),
        FullMask(L, S, device=device),
        FullMask(N, L, device=device),
        FullMask(N, S, device=device)
    )
def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
):
    query_layer, key_layer, value_layer = self.project_QKV(hidden_states)
    query_layer = query_layer.transpose(1, 2)
    key_layer = key_layer.transpose(1, 2)
    value_layer = value_layer.transpose(1, 2)

    # Change mask behavior to be compatible with
    # https://github.com/idiap/fast-transformers: turn the additive mask
    # (0 keep / -10000 drop) into a boolean mask and per-sample lengths.
    t = hidden_states.shape[1]  # sequence length, used as max_len below
    attention_mask = (attention_mask.squeeze(1).squeeze(1) + 10000) / 10000
    attention_mask = FullMask(mask=attention_mask.bool())
    attention_length = LengthMask(attention_mask.lengths, max_len=t)

    context_layer = self.reformer(
        queries=query_layer,
        keys=key_layer,
        values=value_layer,
        attn_mask=attention_mask,
        query_lengths=attention_length,
        key_lengths=attention_length
    )
    context_layer = self.reshape_output(context_layer.transpose(1, 2))

    return (context_layer,)
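# A standalone sketch of the mask conversion used above, assuming the usual
# HuggingFace extended attention mask (0 for visible, -10000 for padded
# positions); the names and shapes here are illustrative only.
import torch
from fast_transformers.masking import FullMask, LengthMask

N, L = 2, 6
extended_mask = torch.zeros(N, 1, 1, L)
extended_mask[0, ..., 4:] = -10000.0                          # sample 0: 2 padded tokens

keep = (extended_mask.squeeze(1).squeeze(1) + 10000) / 10000  # 1.0 keep, 0.0 drop
attn_mask = FullMask(mask=keep.bool())                        # boolean (N, L) mask
length_mask = LengthMask(attn_mask.lengths, max_len=L)        # lengths [4, 6]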
def test_lower_triangular(self):
    m = TriangularCausalMask(3)
    self.assertTrue(m.lower_triangular)
    self.assertTrue(torch.all(m.bool_matrix == (torch.tensor([
        [1, 0, 0],
        [1, 1, 0],
        [1, 1, 1]
    ]) > 0)))

    m = FullMask(torch.tensor([
        [1, 0, 0],
        [1, 1, 0],
        [1, 1, 1]
    ]) > 0)
    self.assertTrue(m.lower_triangular)

    m = FullMask(torch.tensor([
        [1, 0, 1],
        [1, 1, 0],
        [1, 1, 1]
    ]) > 0)
    self.assertFalse(m.lower_triangular)

    m = LengthMask(torch.tensor([1, 1, 3]))
    self.assertFalse(m.lower_triangular)

    m = LengthMask(torch.tensor([1, 2, 3]))
    self.assertTrue(m.lower_triangular)

    m = LengthMask(torch.tensor([1, 2, 3]), max_len=4)
    self.assertTrue(m.lower_triangular)
def test_feature_map_sharing(self):
    x = torch.rand(3, 100, 4 * 32)
    f = Favor.factory(n_dims=64)
    att = AttentionLayer(LinearAttention(32, f), 4 * 32, 4)
    attn_mask = FullMask(100)
    lengths = FullMask(3, 100)

    y = att(x, x, x, attn_mask, lengths, lengths)
    y = att(y, y, y, attn_mask, lengths, lengths)
    y.sum().backward()
def test_full_mask_constructor_arguments(self):
    m = FullMask(torch.rand(10, 10) > 0.5)
    self.assertEqual(m.shape, (10, 10))
    self.assertFalse(m.all_ones)

    m = FullMask(10)
    self.assertEqual(m.shape, (10, 10))
    self.assertTrue(m.all_ones)

    m = FullMask(10, 5)
    self.assertEqual(m.shape, (10, 5))
    self.assertTrue(m.all_ones)
def test_full_mask(self):
    m = FullMask(N=10)
    self.assertEqual(m.shape, (10, 10))
    self.assertTrue(torch.all(m.bool_matrix))
    self.assertTrue(torch.all(m.float_matrix == 1))
    self.assertTrue(torch.all(m.additive_matrix == 0))

    with self.assertRaises(ValueError):
        m = FullMask(torch.rand(10))

    m = FullMask(torch.rand(10, 5) > 0.5)
    self.assertEqual(m.shape, (10, 5))
def test_masking(self):
    q, k, v, m1, m2, m3 = self._get_inputs()
    m1 = FullMask(torch.rand(5, 8) > 0.5)

    # AFTFullAttention can handle an arbitrary attention mask
    att = AFTFullAttention()
    v = att(q, k, v, m1, m2, m3)

    # AFTSimpleAttention rejects the arbitrary mask...
    att = AFTSimpleAttention()
    with self.assertRaises(ValueError):
        v = att(q, k, v, m1, m2, m3)

    # ...but accepts a lower-triangular (causal) mask
    q, k, v, m1, m2, m3 = self._get_inputs(L=8, S=8)
    m1 = FullMask(torch.tril(torch.ones(8, 8, dtype=torch.bool)))
    v = att(q, k, v, m1, m2, m3)
def test_casting_to_lengths(self):
    m = FullMask(torch.tensor([
        [1, 0, 0],
        [1, 1, 0],
        [1, 1, 1]
    ]) > 0)
    self.assertEqual(m.shape, (3, 3))
    self.assertTrue(torch.all(m.lengths == torch.tensor([1, 2, 3])))

    m = FullMask(torch.tensor([
        [1, 0, 1],
        [1, 1, 0],
        [1, 1, 1]
    ]) > 0)
    with self.assertRaises(ValueError):
        m.lengths
def forward_pre(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
    tgt2 = self.norm1(tgt)
    q = k = self.with_pos_embed(tgt2, query_pos)
    N, L, E = q.shape
    _, S, _ = memory.shape
    attn_mask = FullMask(L, device=q.device)
    length_mask = LengthMask(tgt.new_full((N,), L, dtype=torch.int64))
    memory_length = LengthMask(tgt.new_full((N,), S, dtype=torch.int64))
    tgt2 = self.self_attn(q, k, tgt2, attn_mask, length_mask, length_mask)
    tgt = tgt + self.dropout1(tgt2)
    tgt2 = self.norm2(tgt)
    tgt2 = self.multihead_attn(self.with_pos_embed(tgt2, query_pos),
                               self.with_pos_embed(memory, pos),
                               memory, attn_mask, length_mask, memory_length)
    tgt = tgt + self.dropout2(tgt2)
    tgt2 = self.norm3(tgt)
    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
    tgt = tgt + self.dropout3(tgt2)
    return tgt
def forward_post(self, tgt, memory,
                 tgt_mask: Optional[Tensor] = None,
                 memory_mask: Optional[Tensor] = None,
                 tgt_key_padding_mask: Optional[Tensor] = None,
                 memory_key_padding_mask: Optional[Tensor] = None,
                 pos: Optional[Tensor] = None,
                 query_pos: Optional[Tensor] = None):
    N, L, E = tgt.shape
    _, S, E = memory.shape
    q = k = self.with_pos_embed(tgt, query_pos)
    attn_mask = FullMask(L, device=q.device)
    memory_length = LengthMask(tgt.new_full((N,), S, dtype=torch.int64))
    query_length = LengthMask(tgt.new_full((N,), L, dtype=torch.int64))
    tgt2 = self.self_attn(q, k, tgt, attn_mask, query_length, query_length)
    tgt = tgt + self.dropout1(tgt2)
    tgt = self.norm1(tgt)
    tgt2 = self.multihead_attn(self.with_pos_embed(tgt, query_pos),
                               self.with_pos_embed(memory, pos),
                               memory, attn_mask, query_length, memory_length)
    tgt = tgt + self.dropout2(tgt2)
    tgt = self.norm2(tgt)
    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
    tgt = tgt + self.dropout3(tgt2)
    tgt = self.norm3(tgt)
    return tgt
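# How self_attn and multihead_attn in the two decoder paths above are
# presumably constructed when porting a DETR-style decoder layer to
# fast-transformers; d_model and nhead are illustrative names, not values
# taken from the original module.
from fast_transformers.attention import AttentionLayer, FullAttention

d_model, nhead = 256, 8
self_attn = AttentionLayer(FullAttention(), d_model, nhead)
multihead_attn = AttentionLayer(FullAttention(), d_model, nhead)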
def test_correctness(self):
    # Prepare the inputs
    N = 10
    H = 4
    E = 25
    M = 64
    L = 42
    S = 100
    q = torch.rand(N, L, H, E)
    k = torch.rand(N, S, H, E)
    v = torch.rand(N, S, H, M)
    m1 = FullMask(L, S)
    m2 = LengthMask(torch.full((N,), L, dtype=torch.int64))
    m3 = LengthMask(torch.full((N,), S, dtype=torch.int64))

    # Get the output from the attention in batch mode
    att = LinearAttention(E)
    att.eval()
    v_out1 = att(q, k, v, m1, m2, m3)

    # Get the output from the attention in recurrent mode
    att = RecurrentCrossLinearAttention(E)
    att.eval()
    v_out2_unstacked = []
    state = None
    for i in range(L):
        vi, state = att(q[:, i], k, v, m3, state=state)
        v_out2_unstacked.append(vi)
    v_out2 = torch.stack(v_out2_unstacked, dim=1)

    # Check that they match
    self.assertLess(torch.abs(v_out1 - v_out2).max(), 1e-6)
def forward(
    self,
    x,
    attn_mask=None,
    length_mask=None,
):
    """
    LayerNorm is applied either before or after the self-attention/ffn
    modules, similar to the original Transformer implementation.
    """
    N = x.shape[0]
    L = x.shape[1]
    attn_mask = attn_mask or FullMask(L, device=x.device)
    length_mask = length_mask or \
        LengthMask(x.new_full((N,), L, dtype=torch.int64))

    residual = x
    if self.layer_norm_first:
        x = self.self_attn_layer_norm(x)
        x = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            query_lengths=length_mask,
            key_lengths=length_mask,
        )
        x = self.dropout1(x)
        x = residual + x

        residual = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        x = self.dropout3(x)
        x = residual + x
    else:
        x = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            query_lengths=length_mask,
            key_lengths=length_mask,
        )
        x = self.dropout1(x)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        residual = x
        x = self.activation_fn(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        x = self.dropout3(x)
        x = residual + x
        x = self.final_layer_norm(x)

    return x, None
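# The default masks built above, as a standalone sketch: when no masks are
# given, every position may attend everywhere and every sequence counts as
# full length. Shapes are illustrative.
import torch
from fast_transformers.masking import FullMask, LengthMask

x = torch.rand(4, 12, 256)                                        # (N, L, C)
N, L = x.shape[:2]
attn_mask = FullMask(L, device=x.device)                          # (L, L), all ones
length_mask = LengthMask(x.new_full((N,), L, dtype=torch.int64))  # all lengths == L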
def test_compare_with_batch(self):
    N = 10
    L = 42
    S = 100
    D = 1024
    E = D // 4
    x = torch.rand(N, L, D)
    m = torch.rand(N, S, D)
    tests = [
        ("full", FullAttention, FullAttention,
         RecurrentFullAttention, RecurrentCrossFullAttention),
        ("linear", partial(CausalLinearAttention, E), partial(LinearAttention, E),
         partial(RecurrentLinearAttention, E), partial(RecurrentCrossLinearAttention, E))
    ]

    for name, a1, a2, a3, a4 in tests:
        dec = TransformerDecoder([
            TransformerDecoderLayer(
                AttentionLayer(a1(), D, 4),
                AttentionLayer(a2(), D, 4),
                D
            )
            for i in range(4)
        ])
        rdec = RecurrentTransformerDecoder([
            RecurrentTransformerDecoderLayer(
                RecurrentAttentionLayer(a3(), D, 4),
                RecurrentCrossAttentionLayer(a4(), D, 4),
                D
            )
            for i in range(4)
        ])
        dec.eval()
        rdec.eval()
        rdec.load_state_dict(dec.state_dict())

        x_mask = TriangularCausalMask(L)
        x_length = LengthMask(torch.full((N,), L, dtype=torch.int64))
        m_mask = FullMask(L, S)
        m_length = LengthMask(torch.full((N,), S, dtype=torch.int64))

        y1 = dec(x, m, x_mask=x_mask, x_length_mask=x_length,
                 memory_mask=m_mask, memory_length_mask=m_length)
        state = None
        y2 = []
        for i in range(L):
            y2i, state = rdec(x[:, i], m, memory_length_mask=m_length, state=state)
            y2.append(y2i)
        y2 = torch.stack(y2, dim=1)

        self.assertLess(torch.abs(y1 - y2).max(), 1e-5)
def test_compare_with_full(self):
    local_att = LocalAttention(17, softmax_temp=1).eval()
    full_att = FullAttention(softmax_temp=1).eval()
    q, k, v, m1, m2, m3 = self._get_inputs(N=10, L=128, S=128, D=32)
    m = FullMask(
        torch.abs(torch.arange(128)[:, None] - torch.arange(128)[None]) < 9
    )

    v_full = full_att(q, k, v, m, m2, m3)
    v_local = local_att(q, k, v, m1, m2, m3)

    self.assertTrue(torch.allclose(v_full, v_local, atol=1e-5, rtol=1e-5))
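# Why a window of 17 matches the |i - j| < 9 band above: judging from the test,
# local_context counts the query position itself plus local_context // 2 = 8
# neighbours on each side, which is exactly the band built in this sketch.
import torch
from fast_transformers.masking import FullMask

local_context = 17
radius = local_context // 2                  # 8 positions on either side
idx = torch.arange(128)
band_mask = FullMask(torch.abs(idx[:, None] - idx[None]) <= radius)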
def test_masking(self):
    att = LinearAttention(32)
    q, k, v, m1, m2, m3 = self._get_inputs()

    # Make sure that an error is raised if m1 is not all ones
    with self.assertRaises(RuntimeError):
        att(q, k, v, FullMask(torch.rand(*m1.shape) > 0.5), m2, m3)

    # Make sure that the key lengths are respected
    q, k, v, m1, m2, m3 = self._get_inputs(S=10, D=1)
    m3 = LengthMask(torch.tensor(list(range(10))) + 1)
    for i in range(9):
        v[i, i + 1:] = 1e9
    v_new = att(q, k, v, m1, m2, m3)
    self.assertLess(v_new.max().item(), 1)
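# As the test above shows, LinearAttention rejects any attention mask that is
# not all ones, so padding has to be expressed through the key lengths instead.
# A minimal sketch, with illustrative shapes matching _get_inputs.
import torch
from fast_transformers.attention import LinearAttention
from fast_transformers.masking import FullMask, LengthMask

att = LinearAttention(32)
q = torch.rand(10, 5, 4, 32)
k = torch.rand(10, 8, 4, 32)
v = torch.rand(10, 8, 4, 64)
attn_mask = FullMask(5, 8)                                        # must be all ones
query_lengths = LengthMask(torch.full((10,), 5, dtype=torch.int64))
key_lengths = LengthMask(torch.randint(1, 9, (10,)), max_len=8)   # per-sample valid keys
out = att(q, k, v, attn_mask, query_lengths, key_lengths)         # (10, 5, 4, 64)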
def forward_pre(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
    src2 = self.norm1(src)
    q = k = self.with_pos_embed(src2, pos)
    N, L, E = q.shape
    attn_mask = FullMask(L, device=q.device)
    length_mask = LengthMask(src.new_full((N,), L, dtype=torch.int64))
    src2 = self.self_attn(q, k, src2, attn_mask, length_mask, length_mask)
    src = src + self.dropout1(src2)
    src2 = self.norm2(src)
    src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
    src = src + self.dropout2(src2)
    return src
def forward(self, src, src_len_mask=None):
    x = self.initial_ff(src)
    # x = self.pos_encoding(x)

    self.lstm.flatten_parameters()
    x, _ = self.lstm(x)

    for i in range(self.tf_depth[0]):
        x = self.encoder_pre(x)

    # in case the input sequence length has changed
    set_size = src.shape[1]
    batch_size = src.shape[0]
    if self.mask.shape[1] != set_size:
        self.mask = FullMask(N=self.k, M=set_size, device=src.device)
        self.kl_mask = LengthMask(torch.ones(batch_size) * set_size,
                                  device=src.device)

    # in case the batch size has changed
    if (self.ql_mask.shape[0] != batch_size) or (self.kl_mask.shape[0] != batch_size):
        self.ql_mask = LengthMask(torch.ones(batch_size) * self.k,
                                  device=src.device)
        self.kl_mask = LengthMask(torch.ones(batch_size) * set_size,
                                  device=src.device)

    # extend seeds to size of batch
    S = self.seeds.unsqueeze(0).repeat(batch_size, 1, 1)

    # perform pooling by multihead attention, reducing dimensionality of set
    x = self.attn_pooling(S, x, x, self.mask, self.ql_mask, self.kl_mask)

    for i in range(self.tf_depth[1]):
        x = self.encoder_post(x)

    x = self.final_ff(x)
    return x
def __init__(self, num_feats, num_output_points, lstm_layers, n_layers, n_heads,
             hidden_dim, ff_dim, tf_depth=3, dropout=0.15):
    super(SetTransformer, self).__init__()

    self.num_feats = num_feats
    self.k = num_output_points

    def dup(x):
        return (x, x) if type(x) == int else x

    self.lstm_layers = lstm_layers
    self.n_layers = dup(n_layers)
    self.n_heads = dup(n_heads)
    self.hidden_dim = dup(hidden_dim)
    self.ff_dim = dup(ff_dim)
    self.tf_depth = dup(tf_depth)
    self.d_model = [self.hidden_dim[i] * self.n_heads[i] for i in [0, 1]]

    encoder_builder_pre = TransformerEncoderBuilder.from_kwargs(
        n_layers=self.n_layers[0],
        n_heads=self.n_heads[0],
        query_dimensions=self.hidden_dim[0],
        value_dimensions=self.hidden_dim[0],
        feed_forward_dimensions=self.ff_dim[0],
        attention_type='linear',
        dropout=dropout)
    encoder_builder_post = TransformerEncoderBuilder.from_kwargs(
        n_layers=self.n_layers[1],
        n_heads=self.n_heads[1],
        query_dimensions=self.hidden_dim[1],
        value_dimensions=self.hidden_dim[1],
        feed_forward_dimensions=self.ff_dim[1],
        attention_type='linear',
        dropout=dropout)

    self.seeds = nn.Parameter(torch.normal(0, 1, (self.k, self.d_model[0])))
    self.encoder_pre = encoder_builder_pre.get()
    self.encoder_post = encoder_builder_post.get()
    self.initial_ff = nn.Linear(self.num_feats, self.d_model[0])
    # self.pos_encoding = PositionalEncoding(self.d_model[0], dropout=dropout)
    self.lstm = nn.LSTM(self.d_model[0], self.d_model[0], self.lstm_layers,
                        batch_first=True, bidirectional=False)
    self.attn_pooling = AttentionLayer(LinearAttention(self.d_model[0]),
                                       self.d_model[0], self.n_heads[0])
    self.final_ff = nn.Linear(self.d_model[1], self.num_feats)

    # init masks to placeholder values; they are rebuilt in forward() to match
    # the actual batch size and set size.
    self.mask = FullMask(N=self.k, M=5)
    self.kl_mask = LengthMask(torch.ones(5) * 5)
    self.ql_mask = LengthMask(torch.ones(5) * self.k)
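# Usage sketch for the SetTransformer above; every size below is illustrative,
# not taken from the original training code.
import torch

model = SetTransformer(num_feats=64, num_output_points=8, lstm_layers=2,
                       n_layers=2, n_heads=4, hidden_dim=32, ff_dim=128)
src = torch.rand(16, 100, 64)   # (batch, set size, num_feats)
out = model(src)                # (batch, num_output_points, num_feats) == (16, 8, 64)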