def test_qkv_bias(self):
    # test qkv_bias=True
    attn = WindowMSA(
        embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=True)
    self.assertEqual(attn.qkv.bias.shape, (96 * 3, ))

    # test qkv_bias=False
    attn = WindowMSA(
        embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=False)
    self.assertIsNone(attn.qkv.bias)
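# NOTE (illustrative assumption about WindowMSA internals, consistent with the
# assertions above): the qkv projection is presumably a single linear layer,
# roughly nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias), so its bias
# has shape (96 * 3, ) when qkv_bias=True and is None when qkv_bias=False.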
def test_qk_scale(self):
    # test default qk_scale
    attn = WindowMSA(
        embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=None)
    head_dims = 96 // 4
    self.assertAlmostEqual(attn.scale, head_dims**-0.5)

    # test specified qk_scale
    attn = WindowMSA(
        embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=0.3)
    self.assertEqual(attn.scale, 0.3)
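# Worked example for the default scale asserted above:
#   head_dims = 96 // 4 = 24
#   scale = 24 ** -0.5 ≈ 0.2041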
def test_forward(self):
    attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
    inputs = torch.rand((16, 7 * 7, 96))
    output = attn(inputs)
    self.assertEqual(output.shape, inputs.shape)

    # test non-square window_size
    attn = WindowMSA(embed_dims=96, window_size=(6, 7), num_heads=4)
    inputs = torch.rand((16, 6 * 7, 96))
    output = attn(inputs)
    self.assertEqual(output.shape, inputs.shape)
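# NOTE: WindowMSA consumes window-partitioned inputs of shape
# (batch_size * num_windows, window_height * window_width, embed_dims);
# the leading 16 above plays the role of batch_size * num_windows, and the
# output keeps the same shape as the input.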
def test_relative_pos_embed(self):
    attn = WindowMSA(embed_dims=96, window_size=(7, 8), num_heads=4)
    self.assertEqual(attn.relative_position_bias_table.shape,
                     ((2 * 7 - 1) * (2 * 8 - 1), 4))
    # test relative_position_index
    expected_rel_pos_index = get_relative_position_index((7, 8))
    self.assertTrue(
        torch.allclose(attn.relative_position_index, expected_rel_pos_index))

    # test default init
    self.assertTrue(
        torch.allclose(attn.relative_position_bias_table, torch.tensor(0.)))
    attn.init_weights()
    self.assertFalse(
        torch.allclose(attn.relative_position_bias_table, torch.tensor(0.)))
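# NOTE: illustrative sketch only. The get_relative_position_index helper used
# above is defined elsewhere in this test module; assuming the standard Swin
# Transformer formulation, it is expected to compute something equivalent to
# the reference below (the name _reference_relative_position_index is made up
# here for illustration and is not used by the tests).
def _reference_relative_position_index(window_size):
    Wh, Ww = window_size
    # pair-wise coordinates of every token inside one window
    coords = torch.stack(
        torch.meshgrid(torch.arange(Wh), torch.arange(Ww)))  # (2, Wh, Ww)
    coords_flatten = torch.flatten(coords, 1)  # (2, Wh*Ww)
    # relative offsets between every pair of tokens
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    # shift offsets to start from 0 and flatten the 2-D offset into one index
    relative_coords[:, :, 0] += Wh - 1
    relative_coords[:, :, 1] += Ww - 1
    relative_coords[:, :, 0] *= 2 * Ww - 1
    return relative_coords.sum(-1)  # (Wh*Ww, Wh*Ww)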
def test_mask(self):
    inputs = torch.rand(16, 7 * 7, 96)
    attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
    mask = torch.zeros((4, 49, 49))
    # Mask the attention to and from the first token (first row and first
    # column of the attention matrix).
    mask[:, 0, :] = -100
    mask[:, :, 0] = -100
    outs = attn(inputs, mask=mask)
    # Re-randomize the first token; since attention to it is masked, the
    # outputs at all other positions should be unchanged.
    inputs[:, 0, :].normal_()
    outs_with_mask = attn(inputs, mask=mask)
    torch.testing.assert_allclose(outs[:, 1:, :], outs_with_mask[:, 1:, :])
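# Illustrative note on the -100 mask value used above: the mask is added to
# the attention logits before softmax, and an additive bias of -100 drives the
# corresponding attention weights to effectively zero, e.g.
#   torch.softmax(torch.tensor([0., 0., -100.]), dim=0)
#   # -> tensor([0.5000, 0.5000, 0.0000])
# which is why re-randomizing the masked token does not affect the other
# outputs.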
def test_window_msa():
    batch_size = 1
    num_windows = (4, 4)
    embed_dims = 96
    window_size = (7, 7)
    num_heads = 4
    attn = WindowMSA(
        embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))

    # test forward
    output = attn(inputs)
    assert output.shape == inputs.shape
    assert attn.relative_position_bias_table.shape == (
        (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)

    # test relative_position_bias_table init
    attn.init_weights()
    assert abs(attn.relative_position_bias_table).sum() > 0

    # test non-square window_size
    window_size = (6, 7)
    attn = WindowMSA(
        embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    output = attn(inputs)
    assert output.shape == inputs.shape

    # test relative_position_index
    expected_rel_pos_index = get_relative_position_index(window_size)
    assert (attn.relative_position_index == expected_rel_pos_index).all()

    # test qkv_bias=True
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qkv_bias=True)
    assert attn.qkv.bias.shape == (embed_dims * 3, )

    # test qkv_bias=False
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qkv_bias=False)
    assert attn.qkv.bias is None

    # test default qk_scale
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qk_scale=None)
    head_dims = embed_dims // num_heads
    assert np.isclose(attn.scale, head_dims**-0.5)

    # test specified qk_scale
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qk_scale=0.3)
    assert attn.scale == 0.3

    # test attn_drop
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        attn_drop=1.0)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    # drop all attn output, output should be equal to proj.bias
    assert torch.allclose(attn(inputs), attn.proj.bias)

    # test proj_drop
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        proj_drop=1.0)
    assert (attn(inputs) == 0).all()
def test_proj_drop(self):
    inputs = torch.rand(16, 7 * 7, 96)
    attn = WindowMSA(
        embed_dims=96, window_size=(7, 7), num_heads=4, proj_drop=1.0)
    # proj_drop=1.0 drops every activation after the output projection,
    # so the result is all zeros.
    self.assertTrue(torch.allclose(attn(inputs), torch.tensor(0.)))
def test_attn_drop(self):
    inputs = torch.rand(16, 7 * 7, 96)
    attn = WindowMSA(
        embed_dims=96, window_size=(7, 7), num_heads=4, attn_drop=1.0)
    # attn_drop=1.0 drops all attention weights, so the attention output is
    # zero and the final output should be equal to proj.bias.
    self.assertTrue(torch.allclose(attn(inputs), attn.proj.bias))
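# NOTE: the dropout-based tests above rely on the module being in training
# mode, which is the default for a freshly constructed nn.Module; in eval
# mode nn.Dropout is an identity op and these assertions would not hold.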