 def test_topk_equals_length_attention_masked(self):
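     """Check that ImprovedClusteredAttention with topk equal to the full
     sequence length matches FullAttention on the unmasked positions when a
     LengthMask is applied."""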
     d_model = 32
     n_heads = 4
     improved_transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(
                 ImprovedClusteredAttention(clusters=10, topk=20), d_model,
                 n_heads), d_model, n_heads) for i in range(6)
     ])
     full_transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(FullAttention(), d_model, n_heads), d_model,
             n_heads) for i in range(6)
     ])
     full_transformer = full_transformer.to("cuda")
     improved_transformer = improved_transformer.to("cuda")
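     # Copy the weights so both encoders compute the same function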
     improved_transformer.load_state_dict(full_transformer.state_dict())
     improved_transformer.eval()
     full_transformer.eval()
     x = torch.rand(100, 20, d_model).to("cuda")
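     # Mask out the tails of two sequences (lengths 5 and 10 out of 20)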
     lengths = x.new_full((100, ), 20, dtype=torch.int64)
     lengths[1] = 5
     lengths[10] = 10
     length_mask = LengthMask(lengths=lengths, max_len=20)
     y_improved = improved_transformer(x, length_mask=length_mask)
     y_full = full_transformer(x, length_mask=length_mask)
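     # Only the unmasked positions are expected to match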
     self.assertLess(
         torch.max(torch.abs(y_improved[1, :5] - y_full[1, :5])), 1e-4)
     self.assertLess(
         torch.max(torch.abs(y_improved[10, :10] - y_full[10, :10])), 1e-4)

 def test_improved_clustered_attention_forward(self):
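     """Smoke test: a forward pass through an encoder built with
     ImprovedClusteredAttention returns an output of the expected shape."""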
     d_model = 128
     n_heads = 4
     transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(ImprovedClusteredAttention(clusters=10, topk=5),
                            d_model, n_heads), d_model, n_heads)
         for i in range(6)
     ])
     transformer = transformer.to("cuda")
     x = torch.rand(100, 20, d_model).to("cuda")
     y = transformer(x)
     self.assertEqual(y.shape, (100, 20, d_model))

 def test_topk_equals_length_attention(self):
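     """Check that ImprovedClusteredAttention with topk equal to the full
     sequence length matches FullAttention when no length mask is used."""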
     d_model = 32
     n_heads = 4
     improved_transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(
                 ImprovedClusteredAttention(clusters=10, topk=20), d_model,
                 n_heads), d_model, n_heads) for i in range(6)
     ])
     full_transformer = TransformerEncoder([
         TransformerEncoderLayer(
             AttentionLayer(FullAttention(), d_model, n_heads), d_model,
             n_heads) for i in range(6)
     ])
     full_transformer = full_transformer.to("cuda")
     improved_transformer = improved_transformer.to("cuda")
     improved_transformer.load_state_dict(full_transformer.state_dict())
     improved_transformer.eval()
     full_transformer.eval()
     x = torch.rand(100, 20, d_model).to("cuda")
     y_improved = improved_transformer(x)
     y_full = full_transformer(x)
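     # With topk equal to the sequence length the outputs should match everywhere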
     self.assertLess(torch.max(torch.abs(y_improved - y_full)), 1e-4)