# Imports assumed by the tests below (module paths follow mmcls conventions;
# adjust to your tree). The `self`-style tests additionally assume they live
# in a unittest TestCase class whose definition is elided here.
# `check_norm_state` and `timm_resize_pos_embed` are small test helpers whose
# import paths depend on the test package layout and are also elided.
import math
import os
import tempfile
from copy import deepcopy
from itertools import chain
from math import ceil

import numpy as np
import pytest
import torch
from mmcv.runner import load_checkpoint, save_checkpoint

from mmcls.models.backbones import SwinTransformer
from mmcls.models.backbones.swin_transformer import SwinBlock


def test_arch(self):
    # Test invalid default arch
    with self.assertRaisesRegex(AssertionError, 'not in default archs'):
        cfg = deepcopy(self.cfg)
        cfg['arch'] = 'unknown'
        SwinTransformer(**cfg)

    # Test invalid custom arch
    with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
        cfg = deepcopy(self.cfg)
        cfg['arch'] = {
            'embed_dims': 96,
            'num_heads': [3, 6, 12, 16],
        }
        SwinTransformer(**cfg)

    # Test custom arch
    cfg = deepcopy(self.cfg)
    depths = [2, 2, 4, 2]
    num_heads = [6, 12, 6, 12]
    cfg['arch'] = {
        'embed_dims': 256,
        'depths': depths,
        'num_heads': num_heads,
    }
    model = SwinTransformer(**cfg)
    for i, stage in enumerate(model.stages):
        self.assertEqual(stage.embed_dims, 256 * (2**i))
        self.assertEqual(len(stage.blocks), depths[i])
        self.assertEqual(stage.blocks[0].attn.w_msa.num_heads, num_heads[i])
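# The `self`-style test methods in this file read their defaults from
# `self.cfg`, whose fixture is not shown. A minimal setUp sketch consistent
# with the shapes asserted below (base-arch widths 128/256/512/1024); the
# exact defaults are an assumption, not the original fixture:
def setUp(self):
    self.cfg = dict(arch='base', img_size=224)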
def test_load_checkpoint():
    model = SwinTransformer(arch='tiny')
    ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')

    assert model._version == 2

    # test load v2 checkpoint
    save_checkpoint(model, ckpt_path)
    load_checkpoint(model, ckpt_path, strict=True)

    # test load v1 checkpoint
    setattr(model, 'norm', model.norm3)
    model._version = 1
    del model.norm3
    save_checkpoint(model, ckpt_path)

    model = SwinTransformer(arch='tiny')
    load_checkpoint(model, ckpt_path, strict=True)
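# Context for the rename above: v2 models expose the final norm layer as
# `norm3` where v1 used `norm`, so the test fakes a v1 checkpoint by renaming
# the attribute before saving. A minimal key-renaming shim (our sketch, not
# the library's actual migration hook) for loading v1 weights into a v2 model:
def upgrade_v1_state_dict(state_dict):
    """Rename top-level `norm.*` keys to `norm3.*` (v1 -> v2 layout)."""
    return {
        ('norm3.' + k[len('norm.'):] if k.startswith('norm.') else k): v
        for k, v in state_dict.items()
    }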
def test_swin_transformer():
    """Test Swin Transformer backbone."""
    with pytest.raises(AssertionError):
        # Swin Transformer arch string should be one of the default archs.
        SwinTransformer(arch='unknown')

    with pytest.raises(AssertionError):
        # Swin Transformer arch dict should include 'embed_dims',
        # 'depths' and 'num_heads' keys.
        SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2]))

    # Test tiny arch forward
    model = SwinTransformer(arch='Tiny')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 768, 49)

    # Test small arch forward
    model = SwinTransformer(arch='small')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 768, 49)

    # Test base arch forward
    model = SwinTransformer(arch='B')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 1024, 49)

    # Test large arch forward
    model = SwinTransformer(arch='l')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 1536, 49)

    # Test base arch with window_size=12, image_size=384
    model = SwinTransformer(
        arch='base',
        img_size=384,
        stage_cfgs=dict(block_cfgs=dict(window_size=12)))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 384, 384)
    output = model(imgs)
    assert output.shape == (1, 1024, 144)

    # Test small with use_abs_pos_embed = True
    model = SwinTransformer(arch='small', use_abs_pos_embed=True)
    model.init_weights()
    model.train()

    assert model.absolute_pos_embed.shape == (1, 3136, 96)

    # Test small with use_abs_pos_embed = False
    with pytest.raises(AttributeError):
        model = SwinTransformer(arch='small', use_abs_pos_embed=False)
        model.absolute_pos_embed

    # Test small with auto_pad = True
    model = SwinTransformer(
        arch='small',
        auto_pad=True,
        stage_cfgs=dict(
            block_cfgs={'window_size': 7},
            downsample_cfg={'kernel_size': (3, 2)}))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)

    # stage 1
    input_h = int(224 / 4 / 3)
    expect_h = ceil(input_h / 7) * 7
    input_w = int(224 / 4 / 2)
    expect_w = ceil(input_w / 7) * 7
    assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w

    # stage 2
    input_h = int(224 / 4 / 3 / 3)
    # input_h is smaller than window_size, shrink the window_size to input_h.
    expect_h = input_h
    input_w = int(224 / 4 / 2 / 2)
    expect_w = ceil(input_w / input_h) * input_h
    assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w

    # stage 3
    input_h = int(224 / 4 / 3 / 3 / 3)
    expect_h = input_h
    input_w = int(224 / 4 / 2 / 2 / 2)
    expect_w = ceil(input_w / input_h) * input_h
    assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w

    # Test small with auto_pad = False
    with pytest.raises(AssertionError):
        model = SwinTransformer(
            arch='small',
            auto_pad=False,
            stage_cfgs=dict(
                block_cfgs={'window_size': 7},
                downsample_cfg={'kernel_size': (3, 2)}))

    # Test drop_path_rate decay
    model = SwinTransformer(arch='small', drop_path_rate=0.2)
    depths = model.arch_settings['depths']
    pos = 0
    for i, depth in enumerate(depths):
        for j in range(depth):
            block = model.stages[i].blocks[j]
            expect_prob = 0.2 / (sum(depths) - 1) * pos
            assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
            assert np.isclose(block.attn.drop.drop_prob, expect_prob)
            pos += 1
def test_structure():
    # Test small with use_abs_pos_embed = True
    model = SwinTransformer(arch='small', use_abs_pos_embed=True)
    assert model.absolute_pos_embed.shape == (1, 3136, 96)

    # Test small with use_abs_pos_embed = False
    model = SwinTransformer(arch='small', use_abs_pos_embed=False)
    assert not hasattr(model, 'absolute_pos_embed')

    # Test small with auto_pad = True
    model = SwinTransformer(
        arch='small',
        auto_pad=True,
        stage_cfgs=dict(
            block_cfgs={'window_size': 7},
            downsample_cfg={'kernel_size': (3, 2)}))

    # stage 1
    input_h = int(224 / 4 / 3)
    expect_h = ceil(input_h / 7) * 7
    input_w = int(224 / 4 / 2)
    expect_w = ceil(input_w / 7) * 7
    assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w

    # stage 2
    input_h = int(224 / 4 / 3 / 3)
    # input_h is smaller than window_size, shrink the window_size to input_h.
    expect_h = input_h
    input_w = int(224 / 4 / 2 / 2)
    expect_w = ceil(input_w / input_h) * input_h
    assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w

    # stage 3
    input_h = int(224 / 4 / 3 / 3 / 3)
    expect_h = input_h
    input_w = int(224 / 4 / 2 / 2 / 2)
    expect_w = ceil(input_w / input_h) * input_h
    assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w

    # Test small with auto_pad = False
    with pytest.raises(AssertionError):
        model = SwinTransformer(
            arch='small',
            auto_pad=False,
            stage_cfgs=dict(
                block_cfgs={'window_size': 7},
                downsample_cfg={'kernel_size': (3, 2)}))

    # Test drop_path_rate decay
    model = SwinTransformer(arch='small', drop_path_rate=0.2)
    depths = model.arch_settings['depths']
    pos = 0
    for i, depth in enumerate(depths):
        for j in range(depth):
            block = model.stages[i].blocks[j]
            expect_prob = 0.2 / (sum(depths) - 1) * pos
            assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
            assert np.isclose(block.attn.drop.drop_prob, expect_prob)
            pos += 1
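# Sketch of the padding rule the auto_pad assertions above encode (assuming,
# per the in-test comment, that the window shrinks to the smaller feature-map
# side before padding; the helper name is ours, not the library's):
def expected_pad(input_h, input_w, window_size):
    ws = min(window_size, input_h, input_w)  # shrink window to fit the map
    pad_b = ceil(input_h / ws) * ws - input_h  # pad bottom to a multiple
    pad_r = ceil(input_w / ws) * ws - input_w  # pad right to a multiple
    return pad_b, pad_r

# e.g. stage 1 above: expected_pad(18, 28, 7) == (3, 0)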
def test_forward():
    # Test tiny arch forward
    model = SwinTransformer(arch='Tiny')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert len(output) == 1
    assert output[0].shape == (1, 768, 7, 7)

    # Test small arch forward
    model = SwinTransformer(arch='small')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert len(output) == 1
    assert output[0].shape == (1, 768, 7, 7)

    # Test base arch forward
    model = SwinTransformer(arch='B')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert len(output) == 1
    assert output[0].shape == (1, 1024, 7, 7)

    # Test large arch forward
    model = SwinTransformer(arch='l')
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert len(output) == 1
    assert output[0].shape == (1, 1536, 7, 7)

    # Test base arch with window_size=12, image_size=384
    model = SwinTransformer(
        arch='base',
        img_size=384,
        stage_cfgs=dict(block_cfgs=dict(window_size=12)))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 384, 384)
    output = model(imgs)
    assert len(output) == 1
    assert output[0].shape == (1, 1024, 12, 12)
def test_init_weights(self):
    # test weight init cfg
    cfg = deepcopy(self.cfg)
    cfg['use_abs_pos_embed'] = True
    cfg['init_cfg'] = [
        dict(
            type='Kaiming',
            layer='Conv2d',
            mode='fan_in',
            nonlinearity='linear')
    ]
    model = SwinTransformer(**cfg)
    ori_weight = model.patch_embed.projection.weight.clone().detach()
    # The pos_embed is all zero before initialization.
    self.assertTrue(
        torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))

    model.init_weights()
    initialized_weight = model.patch_embed.projection.weight
    self.assertFalse(torch.allclose(ori_weight, initialized_weight))
    self.assertFalse(
        torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))

    pretrain_pos_embed = model.absolute_pos_embed.clone().detach()

    tmpdir = tempfile.gettempdir()
    # Save v3 checkpoint
    checkpoint_v3 = os.path.join(tmpdir, 'v3.pth')
    save_checkpoint(model, checkpoint_v3)
    # Save v1 checkpoint
    setattr(model, 'norm', model.norm3)
    setattr(model.stages[0].blocks[1].attn, 'attn_mask',
            torch.zeros(64, 49, 49))
    model._version = 1
    del model.norm3
    checkpoint_v1 = os.path.join(tmpdir, 'v1.pth')
    save_checkpoint(model, checkpoint_v1)

    # test load v1 checkpoint
    cfg = deepcopy(self.cfg)
    cfg['use_abs_pos_embed'] = True
    model = SwinTransformer(**cfg)
    load_checkpoint(model, checkpoint_v1, strict=True)

    # test load v3 checkpoint
    cfg = deepcopy(self.cfg)
    cfg['use_abs_pos_embed'] = True
    model = SwinTransformer(**cfg)
    load_checkpoint(model, checkpoint_v3, strict=True)

    # test load v3 checkpoint with different img_size
    cfg = deepcopy(self.cfg)
    cfg['img_size'] = 384
    cfg['use_abs_pos_embed'] = True
    model = SwinTransformer(**cfg)
    load_checkpoint(model, checkpoint_v3, strict=True)
    resized_pos_embed = timm_resize_pos_embed(
        pretrain_pos_embed, model.absolute_pos_embed, num_tokens=0)
    self.assertTrue(
        torch.allclose(model.absolute_pos_embed, resized_pos_embed))

    os.remove(checkpoint_v1)
    os.remove(checkpoint_v3)
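# Sketch of the position-embedding resize verified above, assuming timm-style
# bicubic interpolation over the 2D token grid (the helper below is ours;
# `timm_resize_pos_embed` is the reference implementation used by the test):
import torch.nn.functional as F

def resize_abs_pos_embed(pos_embed, src_hw, dst_hw):
    """Interpolate a (1, H*W, C) absolute pos embed from src_hw to dst_hw."""
    B, _, C = pos_embed.shape
    grid = pos_embed.reshape(B, src_hw[0], src_hw[1], C).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid, size=dst_hw, mode='bicubic', align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(B, -1, C)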
def test_structure(self):
    # test drop_path_rate decay
    cfg = deepcopy(self.cfg)
    cfg['drop_path_rate'] = 0.2
    model = SwinTransformer(**cfg)
    depths = model.arch_settings['depths']
    blocks = chain(*[stage.blocks for stage in model.stages])
    for i, block in enumerate(blocks):
        expect_prob = 0.2 / (sum(depths) - 1) * i
        self.assertAlmostEqual(block.ffn.dropout_layer.drop_prob, expect_prob)
        self.assertAlmostEqual(block.attn.drop.drop_prob, expect_prob)

    # test Swin-Transformer with norm_eval=True
    cfg = deepcopy(self.cfg)
    cfg['norm_eval'] = True
    cfg['norm_cfg'] = dict(type='BN')
    cfg['stage_cfgs'] = dict(block_cfgs=dict(norm_cfg=dict(type='BN')))
    model = SwinTransformer(**cfg)
    model.init_weights()
    model.train()
    self.assertTrue(check_norm_state(model.modules(), False))

    # test Swin-Transformer with the first stage frozen.
    cfg = deepcopy(self.cfg)
    frozen_stages = 0
    cfg['frozen_stages'] = frozen_stages
    cfg['out_indices'] = (0, 1, 2, 3)
    model = SwinTransformer(**cfg)
    model.init_weights()
    model.train()

    # the patch_embed and first stage should not require grad.
    self.assertFalse(model.patch_embed.training)
    for param in model.patch_embed.parameters():
        self.assertFalse(param.requires_grad)
    for i in range(frozen_stages + 1):
        stage = model.stages[i]
        for param in stage.parameters():
            self.assertFalse(param.requires_grad)
    for param in model.norm0.parameters():
        self.assertFalse(param.requires_grad)

    # the remaining stages should require grad.
    for i in range(frozen_stages + 1, 4):
        stage = model.stages[i]
        for param in stage.parameters():
            self.assertTrue(param.requires_grad)
        norm = getattr(model, f'norm{i}')
        for param in norm.parameters():
            self.assertTrue(param.requires_grad)
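# The decay checked above is the standard linear stochastic-depth schedule:
# per-block drop rates grow from 0 to drop_path_rate across all blocks. A
# minimal equivalent (the helper name is ours) reproducing expect_prob:
def drop_path_schedule(drop_path_rate, depths):
    """Per-block drop probabilities, linearly increasing with block index."""
    return [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

# e.g. for Swin-S depths [2, 2, 18, 2], entry i equals 0.2 / 23 * i
# (up to float precision), matching the expect_prob formula in the test.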
def test_forward(self):
    imgs = torch.randn(3, 3, 224, 224)

    cfg = deepcopy(self.cfg)
    model = SwinTransformer(**cfg)
    outs = model(imgs)
    self.assertIsInstance(outs, tuple)
    self.assertEqual(len(outs), 1)
    feat = outs[-1]
    self.assertEqual(feat.shape, (3, 1024, 7, 7))

    # test with window_size=12
    cfg = deepcopy(self.cfg)
    cfg['window_size'] = 12
    model = SwinTransformer(**cfg)
    outs = model(torch.randn(3, 3, 384, 384))
    self.assertIsInstance(outs, tuple)
    self.assertEqual(len(outs), 1)
    feat = outs[-1]
    self.assertEqual(feat.shape, (3, 1024, 12, 12))
    with self.assertRaisesRegex(AssertionError, r'the window size \(12\)'):
        model(torch.randn(3, 3, 224, 224))

    # test with pad_small_map=True
    cfg = deepcopy(self.cfg)
    cfg['window_size'] = 12
    cfg['pad_small_map'] = True
    model = SwinTransformer(**cfg)
    outs = model(torch.randn(3, 3, 224, 224))
    self.assertIsInstance(outs, tuple)
    self.assertEqual(len(outs), 1)
    feat = outs[-1]
    self.assertEqual(feat.shape, (3, 1024, 7, 7))

    # test multiple output indices
    cfg = deepcopy(self.cfg)
    cfg['out_indices'] = (0, 1, 2, 3)
    model = SwinTransformer(**cfg)
    outs = model(imgs)
    self.assertIsInstance(outs, tuple)
    self.assertEqual(len(outs), 4)
    for stride, out in zip([2, 4, 8, 8], outs):
        self.assertEqual(out.shape,
                         (3, 128 * stride, 56 // stride, 56 // stride))

    # test with checkpoint forward
    cfg = deepcopy(self.cfg)
    cfg['with_cp'] = True
    model = SwinTransformer(**cfg)
    for m in model.modules():
        if isinstance(m, SwinBlock):
            self.assertTrue(m.with_cp)
    model.init_weights()
    model.train()

    outs = model(imgs)
    self.assertIsInstance(outs, tuple)
    self.assertEqual(len(outs), 1)
    feat = outs[-1]
    self.assertEqual(feat.shape, (3, 1024, 7, 7))

    # test with dynamic input shape
    imgs1 = torch.randn(3, 3, 224, 224)
    imgs2 = torch.randn(3, 3, 256, 256)
    imgs3 = torch.randn(3, 3, 256, 309)
    cfg = deepcopy(self.cfg)
    model = SwinTransformer(**cfg)
    for imgs in [imgs1, imgs2, imgs3]:
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        feat = outs[-1]
        expect_feat_shape = (math.ceil(imgs.shape[2] / 32),
                             math.ceil(imgs.shape[3] / 32))
        self.assertEqual(feat.shape, (3, 1024, *expect_feat_shape))
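# Note on the dynamic-shape expectation above: the overall output stride is
# 32 = 4 (patch embed) * 2**3 (three patch-merging downsamples), hence the
# ceil(H / 32) x ceil(W / 32) feature-map size.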