def test_improved_clustered_attention_forward(self):
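    """Forward a random batch through a 6-layer encoder built with
    improved clustered attention and check the output shape."""
    # The model and inputs are moved to the GPU below, so skip when no
    # CUDA capable device is available (assumes this test is GPU-only,
    # mirroring the explicit .to("cuda") calls).
    if not torch.cuda.is_available():
        self.skipTest("no CUDA capable device available")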
     d_model = 128
     n_heads = 4
     transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(ImprovedClusteredAttention(clusters=10, topk=5),
                            d_model, n_heads), d_model, n_heads)
         for i in range(6)
     ])
     transformer = transformer.to("cuda")
     x = torch.rand(100, 20, d_model).to("cuda")
     y = transformer(x)
     self.assertEqual(y.shape, (100, 20, d_model))
 def test_clustered_attention_forward(self):
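    """Forward a random batch through a 6-layer encoder built with
    clustered attention and check the output shape."""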
     d_model = 128
     n_heads = 4
     transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(ClusteredAttention(clusters=10), d_model,
                            n_heads), d_model, n_heads) for i in range(6)
     ])
     x = transformer(torch.rand(100, 20, d_model))
     self.assertEqual(x.shape, (100, 20, d_model))
 def test_full_attention_forward(self):
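    """Forward a random batch through a 6-layer encoder built with full
    softmax attention and check the output shape."""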
     d_model = 128
     n_heads = 4
     transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(FullAttention(), d_model, n_heads), d_model,
             n_heads) for i in range(6)
     ])
     x = transformer(torch.rand(10, 7, d_model))
     self.assertEqual(x.shape, (10, 7, d_model))
def test_reformer_attention_forward(self):
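    """Forward a random batch through a 6-layer encoder built with
    Reformer (LSH) attention and check the output shape."""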
     d_model = 128
     n_heads = 4
     transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(
                 ReformerAttention(
                     chunk_size=32,
                     rounds=4,
                     bits=8,
                     masked=False,
                 ), d_model, n_heads), d_model, n_heads) for i in range(6)
     ])
     x = torch.rand(12, 128, d_model)
     y = transformer(x)
     self.assertEqual(y.shape, (12, 128, d_model))
 def test_topk_equals_length_attention_masked(self):
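    """With topk equal to the sequence length, improved clustered
    attention should reproduce full attention; compare both encoders on
    the unmasked prefix of each variable-length sequence."""
    # The comparison runs on the GPU below, so skip when no CUDA capable
    # device is available.
    if not torch.cuda.is_available():
        self.skipTest("no CUDA capable device available")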
     d_model = 32
     n_heads = 4
     improved_transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(
                 ImprovedClusteredAttention(clusters=10, topk=20), d_model,
                 n_heads), d_model, n_heads) for i in range(6)
     ])
     full_transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(FullAttention(), d_model, n_heads), d_model,
             n_heads) for i in range(6)
     ])
     full_transformer = full_transformer.to("cuda")
     improved_transformer = improved_transformer.to("cuda")
     improved_transformer.load_state_dict(full_transformer.state_dict())
     improved_transformer.eval()
     full_transformer.eval()
     x = torch.rand(100, 20, d_model).to("cuda")
     lengths = x.new_full((100, ), 20, dtype=torch.int64)
     lengths[1] = 5
     lengths[10] = 10
     length_mask = LengthMask(lengths=lengths, max_len=20)
    y_improved = improved_transformer(x, length_mask=length_mask)
    y_full = full_transformer(x, length_mask=length_mask)
     self.assertLess(
         torch.max(torch.abs(y_improved[1, :5] - y_full[1, :5])), 1e-4)
     self.assertLess(
         torch.max(torch.abs(y_improved[10, :10] - y_full[10, :10])), 1e-4)
 def test_topk_equals_length_attention(self):
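    """With topk equal to the sequence length, improved clustered
    attention should reproduce full attention on unmasked inputs."""
    # The comparison runs on the GPU below, so skip when no CUDA capable
    # device is available.
    if not torch.cuda.is_available():
        self.skipTest("no CUDA capable device available")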
     d_model = 32
     n_heads = 4
     improved_transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(
                 ImprovedClusteredAttention(clusters=10, topk=20), d_model,
                 n_heads), d_model, n_heads) for i in range(6)
     ])
     full_transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(FullAttention(), d_model, n_heads), d_model,
             n_heads) for i in range(6)
     ])
     full_transformer = full_transformer.to("cuda")
     improved_transformer = improved_transformer.to("cuda")
     improved_transformer.load_state_dict(full_transformer.state_dict())
     improved_transformer.eval()
     full_transformer.eval()
     x = torch.rand(100, 20, d_model).to("cuda")
    y_improved = improved_transformer(x)
    y_full = full_transformer(x)
     self.assertLess(torch.max(torch.abs(y_improved - y_full)), 1e-4)
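
# Minimal runner sketch: assumes the methods above live in a
# unittest.TestCase subclass defined earlier in this file and that
# `unittest` is imported at the top (both outside this excerpt).
if __name__ == "__main__":
    unittest.main()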