def test_qkv_bias(self):
        # test qkv_bias=True
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=True)
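        # qkv is a single fused Linear projecting to queries, keys and values,
        # so its bias has length 3 * embed_dims.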
        self.assertEqual(attn.qkv.bias.shape, (96 * 3, ))

        # test qkv_bias=False
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=False)
        self.assertIsNone(attn.qkv.bias)

    def test_qk_scale(self):
        # test default qk_scale
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=None)
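        # With qk_scale=None, WindowMSA falls back to the standard
        # 1 / sqrt(head_dim) attention scaling.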
        head_dims = 96 // 4
        self.assertAlmostEqual(attn.scale, head_dims**-0.5)

        # test specified qk_scale
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=0.3)
        self.assertEqual(attn.scale, 0.3)

    def test_forward(self):
        attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
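        # Inputs are flattened windows of shape (num_windows * B, Wh * Ww, C):
        # here 16 windows of 7 * 7 tokens with 96 channels.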
        inputs = torch.rand((16, 7 * 7, 96))
        output = attn(inputs)
        self.assertEqual(output.shape, inputs.shape)

        # test non-square window_size
        attn = WindowMSA(embed_dims=96, window_size=(6, 7), num_heads=4)
        inputs = torch.rand((16, 6 * 7, 96))
        output = attn(inputs)
        self.assertEqual(output.shape, inputs.shape)

    def test_relative_pos_embed(self):
        attn = WindowMSA(embed_dims=96, window_size=(7, 8), num_heads=4)
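        # Relative offsets along each axis take 2 * W - 1 distinct values, so
        # the bias table holds (2 * Wh - 1) * (2 * Ww - 1) entries per head.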
        self.assertEqual(attn.relative_position_bias_table.shape,
                         ((2 * 7 - 1) * (2 * 8 - 1), 4))
        # test relative_position_index
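        # get_relative_position_index is assumed to be a helper in this test
        # module that rebuilds the expected (Wh*Ww, Wh*Ww) index map.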
        expected_rel_pos_index = get_relative_position_index((7, 8))
        self.assertTrue(
            torch.allclose(attn.relative_position_index,
                           expected_rel_pos_index))

        # test default init
        self.assertTrue(
            torch.allclose(attn.relative_position_bias_table,
                           torch.tensor(0.)))
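        # init_weights is expected to re-initialize the bias table (truncated
        # normal in the reference implementation), so it is no longer all zeros.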
        attn.init_weights()
        self.assertFalse(
            torch.allclose(attn.relative_position_bias_table,
                           torch.tensor(0.)))

    def test_mask(self):
        inputs = torch.rand(16, 7 * 7, 96)
        attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
        mask = torch.zeros((4, 49, 49))
        # Mask the attention weights to and from the first token
        mask[:, 0, :] = -100
        mask[:, :, 0] = -100
        outs = attn(inputs, mask=mask)
        # Randomize the first token; with the mask applied, the outputs of the
        # remaining tokens should be unchanged.
        inputs[:, 0, :].normal_()
        outs_with_mask = attn(inputs, mask=mask)
        torch.testing.assert_allclose(outs[:, 1:, :], outs_with_mask[:, 1:, :])

    def test_proj_drop(self):
        inputs = torch.rand(16, 7 * 7, 96)
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, proj_drop=1.0)
        self.assertTrue(torch.allclose(attn(inputs), torch.tensor(0.)))

    def test_attn_drop(self):
        inputs = torch.rand(16, 7 * 7, 96)
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, attn_drop=1.0)
        # attn_drop=1.0 drops every attention weight, so the output should be
        # equal to proj.bias
        self.assertTrue(torch.allclose(attn(inputs), attn.proj.bias))


def test_window_msa():
    batch_size = 1
    num_windows = (4, 4)
    embed_dims = 96
    window_size = (7, 7)
    num_heads = 4
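    # Each window is treated as an independent batch element, so the batch
    # dimension is batch_size * num_windows[0] * num_windows[1].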
    attn = WindowMSA(embed_dims=embed_dims,
                     window_size=window_size,
                     num_heads=num_heads)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))

    # test forward
    output = attn(inputs)
    assert output.shape == inputs.shape
    assert attn.relative_position_bias_table.shape == (
        (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
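    # one row per relative (dh, dw) offset, one column per attention head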

    # test relative_position_bias_table init
    attn.init_weights()
    assert abs(attn.relative_position_bias_table).sum() > 0

    # test non-square window_size
    window_size = (6, 7)
    attn = WindowMSA(embed_dims=embed_dims,
                     window_size=window_size,
                     num_heads=num_heads)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    output = attn(inputs)
    assert output.shape == inputs.shape

    # test relative_position_index
    expected_rel_pos_index = get_relative_position_index(window_size)
    assert (attn.relative_position_index == expected_rel_pos_index).all()

    # test qkv_bias=True
    attn = WindowMSA(embed_dims=embed_dims,
                     window_size=window_size,
                     num_heads=num_heads,
                     qkv_bias=True)
    assert attn.qkv.bias.shape == (embed_dims * 3, )

    # test qkv_bias=False
    attn = WindowMSA(embed_dims=embed_dims,
                     window_size=window_size,
                     num_heads=num_heads,
                     qkv_bias=False)
    assert attn.qkv.bias is None

    # test default qk_scale
    attn = WindowMSA(embed_dims=embed_dims,
                     window_size=window_size,
                     num_heads=num_heads,
                     qk_scale=None)
    head_dims = embed_dims // num_heads
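    # default scaling is the standard 1 / sqrt(head_dim)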
    assert np.isclose(attn.scale, head_dims**-0.5)

    # test specified qk_scale
    attn = WindowMSA(embed_dims=embed_dims,
                     window_size=window_size,
                     num_heads=num_heads,
                     qk_scale=0.3)
    assert attn.scale == 0.3

    # test attn_drop
    attn = WindowMSA(embed_dims=embed_dims,
                     window_size=window_size,
                     num_heads=num_heads,
                     attn_drop=1.0)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    # attn_drop=1.0 drops every attention weight, so the output reduces to
    # proj.bias
    assert torch.allclose(attn(inputs), attn.proj.bias)

    # test proj_drop
    attn = WindowMSA(embed_dims=embed_dims,
                     window_size=window_size,
                     num_heads=num_heads,
                     proj_drop=1.0)
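    # proj_drop=1.0 zeroes the projected output in training mode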
    assert (attn(inputs) == 0).all()