def test_arch(self):
        """Check ``arch`` validation and that a custom arch dict is applied.

        Covers three cases: an unknown arch name, a custom arch dict with a
        missing key, and a complete custom arch dict whose settings must be
        reflected in every stage of the built model.
        """
        # An arch name that is not predefined must be rejected.
        unknown_cfg = deepcopy(self.cfg)
        unknown_cfg['arch'] = 'unknown'
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            SwinTransformer(**unknown_cfg)

        # A custom arch dict without 'depths' must be rejected as well.
        incomplete_cfg = deepcopy(self.cfg)
        incomplete_cfg['arch'] = dict(embed_dims=96, num_heads=[3, 6, 12, 16])
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            SwinTransformer(**incomplete_cfg)

        # A complete custom arch dict is accepted and drives the structure.
        base_dims = 256
        depths = [2, 2, 4, 2]
        num_heads = [6, 12, 6, 12]
        custom_cfg = deepcopy(self.cfg)
        custom_cfg['arch'] = dict(
            embed_dims=base_dims, depths=depths, num_heads=num_heads)
        model = SwinTransformer(**custom_cfg)

        for idx, stage in enumerate(model.stages):
            # The test expects the embedding width to double at every stage.
            self.assertEqual(stage.embed_dims, base_dims * (2**idx))
            self.assertEqual(len(stage.blocks), depths[idx])
            first_block_heads = stage.blocks[0].attn.w_msa.num_heads
            self.assertEqual(first_block_heads, num_heads[idx])
def test_assertion():
    """Invalid ``arch`` arguments must raise ``AssertionError``."""
    # A string that is not one of the known arch names.
    with pytest.raises(AssertionError):
        SwinTransformer(arch='unknown')

    # A custom arch dict that only provides 'embed_dims' and 'depths';
    # the constructor is expected to require 'num_heads' too.
    incomplete_arch = {'embed_dims': 96, 'depths': [2, 2, 18, 2]}
    with pytest.raises(AssertionError):
        SwinTransformer(arch=incomplete_arch)
def test_load_checkpoint():
    """Check that both v2 and legacy v1 checkpoints load with ``strict=True``.

    The checkpoint is written into a private temporary directory that is
    removed when the test finishes (pass or fail), so no ``ckpt.pth`` is left
    behind in the shared temp dir and concurrent runs cannot clobber each
    other's file.
    """
    model = SwinTransformer(arch='tiny')
    assert model._version == 2

    with tempfile.TemporaryDirectory() as tmpdir:
        ckpt_path = os.path.join(tmpdir, 'ckpt.pth')

        # test load v2 checkpoint
        save_checkpoint(model, ckpt_path)
        load_checkpoint(model, ckpt_path, strict=True)

        # test load v1 checkpoint: emulate the old layout by exposing the
        # final norm as `norm` instead of `norm3` and tagging the model as
        # version 1 before saving.
        setattr(model, 'norm', model.norm3)
        model._version = 1
        del model.norm3
        save_checkpoint(model, ckpt_path)

        # A freshly built (v2) model must accept the v1 checkpoint.
        model = SwinTransformer(arch='tiny')
        load_checkpoint(model, ckpt_path, strict=True)
def test_swin_transformer():
    """Test Swin Transformer backbone.

    Covers arch validation, forward output shapes for the predefined archs,
    the absolute position embedding switch, automatic padding of small
    feature maps, and the linear drop-path-rate schedule.
    """
    with pytest.raises(AssertionError):
        # The arch string must be one of the predefined arch names.
        SwinTransformer(arch='unknown')

    with pytest.raises(AssertionError):
        # A custom arch dict providing only 'embed_dims' and 'depths'
        # (no 'num_heads') must be rejected.
        SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2]))

    # Forward pass of every predefined arch (given by full name, mixed case
    # or single-letter alias) on a 224x224 input.
    arch_channels = (('Tiny', 768), ('small', 768), ('B', 1024), ('l', 1536))
    for arch, channels in arch_channels:
        model = SwinTransformer(arch=arch)
        model.init_weights()
        model.train()

        imgs = torch.randn(1, 3, 224, 224)
        output = model(imgs)
        assert output.shape == (1, channels, 49)

    # Test base arch with window_size=12, image_size=384
    model = SwinTransformer(arch='base',
                            img_size=384,
                            stage_cfgs=dict(block_cfgs=dict(window_size=12)))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 384, 384)
    output = model(imgs)
    assert output.shape == (1, 1024, 144)

    # Test small with use_abs_pos_embed = True
    model = SwinTransformer(arch='small', use_abs_pos_embed=True)
    model.init_weights()
    model.train()

    assert model.absolute_pos_embed.shape == (1, 3136, 96)

    # Test small with use_abs_pos_embed = False. The model is built outside
    # of any "expect an exception" scope so that an error raised by the
    # constructor itself cannot be mistaken for the missing attribute.
    model = SwinTransformer(arch='small', use_abs_pos_embed=False)
    assert not hasattr(model, 'absolute_pos_embed')

    # Test small with auto_pad = True
    window_size = 7
    model = SwinTransformer(arch='small',
                            auto_pad=True,
                            stage_cfgs=dict(
                                block_cfgs={'window_size': window_size},
                                downsample_cfg={
                                    'kernel_size': (3, 2),
                                }))
    model.init_weights()
    model.train()

    # Stages 1-3: the patch grid starts at 224 / 4 = 56 per side and each
    # downsample (kernel (3, 2)) shrinks the height by 3 and the width by 2.
    for stage_idx in (1, 2, 3):
        input_h = int(224 / 4 / 3**stage_idx)
        input_w = int(224 / 4 / 2**stage_idx)
        # When input_h is smaller than the window size, the window size is
        # expected to shrink to input_h.
        effective_window = min(window_size, input_h)
        expect_h = ceil(input_h / effective_window) * effective_window
        expect_w = ceil(input_w / effective_window) * effective_window
        attn = model.stages[stage_idx].blocks[0].attn
        assert attn.pad_b == expect_h - input_h
        assert attn.pad_r == expect_w - input_w

    # Test small with auto_pad = False: the same config must be rejected.
    with pytest.raises(AssertionError):
        SwinTransformer(arch='small',
                        auto_pad=False,
                        stage_cfgs=dict(block_cfgs={'window_size': 7},
                                        downsample_cfg={
                                            'kernel_size': (3, 2),
                                        }))

    # Test drop_path_rate decay: the rate is expected to grow linearly from
    # 0 at the first block to drop_path_rate at the last block.
    model = SwinTransformer(
        arch='small',
        drop_path_rate=0.2,
    )
    depths = model.arch_settings['depths']
    pos = 0
    for i, depth in enumerate(depths):
        for j in range(depth):
            block = model.stages[i].blocks[j]
            expect_prob = 0.2 / (sum(depths) - 1) * pos
            assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
            assert np.isclose(block.attn.drop.drop_prob, expect_prob)
            pos += 1
def test_structure():
    """Check structural options of the Swin Transformer backbone.

    Verifies the absolute position embedding switch, the padding computed
    with ``auto_pad=True``, the rejection of the same config with
    ``auto_pad=False``, and the linear drop-path-rate schedule.
    """
    # Absolute position embedding enabled: the parameter exists.
    with_ape = SwinTransformer(arch='small', use_abs_pos_embed=True)
    assert with_ape.absolute_pos_embed.shape == (1, 3136, 96)

    # Absolute position embedding disabled: the attribute is absent.
    without_ape = SwinTransformer(arch='small', use_abs_pos_embed=False)
    assert not hasattr(without_ape, 'absolute_pos_embed')

    # auto_pad=True with an asymmetric downsample kernel.
    window_size = 7
    pad_stage_cfgs = dict(
        block_cfgs={'window_size': 7},
        downsample_cfg={
            'kernel_size': (3, 2),
        })
    model = SwinTransformer(
        arch='small', auto_pad=True, stage_cfgs=pad_stage_cfgs)

    # Stages 1-3: the patch grid starts at 224 / 4 = 56 per side and each
    # downsample shrinks the height by 3 and the width by 2.
    for stage_idx in (1, 2, 3):
        input_h = int(224 / 4 / 3**stage_idx)
        input_w = int(224 / 4 / 2**stage_idx)
        # When input_h is smaller than the window size, the window size is
        # expected to shrink to input_h.
        effective_window = min(window_size, input_h)
        expect_h = ceil(input_h / effective_window) * effective_window
        expect_w = ceil(input_w / effective_window) * effective_window
        attn = model.stages[stage_idx].blocks[0].attn
        assert attn.pad_b == expect_h - input_h
        assert attn.pad_r == expect_w - input_w

    # The same configuration with auto_pad=False must be rejected.
    with pytest.raises(AssertionError):
        SwinTransformer(
            arch='small',
            auto_pad=False,
            stage_cfgs=dict(
                block_cfgs={'window_size': 7},
                downsample_cfg={
                    'kernel_size': (3, 2),
                }))

    # drop_path_rate decay: walk all blocks in stage order and compare
    # each block's drop probability with the expected linear schedule.
    model = SwinTransformer(arch='small', drop_path_rate=0.2)
    depths = model.arch_settings['depths']
    ordered_blocks = (model.stages[stage_idx].blocks[block_idx]
                      for stage_idx, depth in enumerate(depths)
                      for block_idx in range(depth))
    for pos, block in enumerate(ordered_blocks):
        expect_prob = 0.2 / (sum(depths) - 1) * pos
        assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
        assert np.isclose(block.attn.drop.drop_prob, expect_prob)
def test_forward():
    """Check forward output shapes of the predefined Swin Transformer archs.

    Every model is initialised, switched to training mode and fed a random
    single-image batch; the backbone must return exactly one feature map
    with the expected shape.
    """
    # (arch name or alias, expected channels of the output feature map)
    # for a 224x224 input.
    expected_channels = (
        ('Tiny', 768),
        ('small', 768),
        ('B', 1024),
        ('l', 1536),
    )
    for arch, channels in expected_channels:
        model = SwinTransformer(arch=arch)
        model.init_weights()
        model.train()

        feats = model(torch.randn(1, 3, 224, 224))
        assert len(feats) == 1
        assert feats[0].shape == (1, channels, 7, 7)

    # Base arch with window_size=12 and image_size=384.
    large_window_cfg = dict(block_cfgs=dict(window_size=12))
    model = SwinTransformer(
        arch='base', img_size=384, stage_cfgs=large_window_cfg)
    model.init_weights()
    model.train()

    feats = model(torch.randn(1, 3, 384, 384))
    assert len(feats) == 1
    assert feats[0].shape == (1, 1024, 12, 12)
    def test_init_weights(self):
        """Check weight initialisation and checkpoint loading.

        Verifies that ``init_weights`` changes the patch-embedding projection
        and the absolute position embedding, then saves the model in the
        current ("v3") and an emulated legacy ("v1") layout and checks that
        both load with ``strict=True``, including into a model built for a
        different ``img_size`` (the position embedding must be resized).

        Checkpoints are written into a private temporary directory whose
        removal is registered with ``addCleanup``, so the files are deleted
        even when an assertion fails midway.
        """
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['use_abs_pos_embed'] = True
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = SwinTransformer(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()
        # The pos_embed is all zero before initialize
        self.assertTrue(
            torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))

        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
        self.assertFalse(
            torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))

        pretrain_pos_embed = model.absolute_pos_embed.clone().detach()

        # A unique directory avoids name clashes with other runs that share
        # the system temp dir; cleanup runs regardless of the test outcome.
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)

        # Save v3 checkpoints
        checkpoint_v3 = os.path.join(tmpdir.name, 'v3.pth')
        save_checkpoint(model, checkpoint_v3)
        # Save v1 checkpoints: emulate the legacy layout by exposing the
        # final norm as `norm`, setting an `attn_mask` tensor on one
        # attention module and tagging the model as version 1.
        setattr(model, 'norm', model.norm3)
        setattr(model.stages[0].blocks[1].attn, 'attn_mask',
                torch.zeros(64, 49, 49))
        model._version = 1
        del model.norm3
        checkpoint_v1 = os.path.join(tmpdir.name, 'v1.pth')
        save_checkpoint(model, checkpoint_v1)

        # test load v1 checkpoint
        cfg = deepcopy(self.cfg)
        cfg['use_abs_pos_embed'] = True
        model = SwinTransformer(**cfg)
        load_checkpoint(model, checkpoint_v1, strict=True)

        # test load v3 checkpoint
        cfg = deepcopy(self.cfg)
        cfg['use_abs_pos_embed'] = True
        model = SwinTransformer(**cfg)
        load_checkpoint(model, checkpoint_v3, strict=True)

        # test load v3 checkpoint with different img_size
        cfg = deepcopy(self.cfg)
        cfg['img_size'] = 384
        cfg['use_abs_pos_embed'] = True
        model = SwinTransformer(**cfg)
        load_checkpoint(model, checkpoint_v3, strict=True)
        # The loaded position embedding must equal the saved one resized to
        # the new model's embedding shape.
        resized_pos_embed = timm_resize_pos_embed(
            pretrain_pos_embed, model.absolute_pos_embed, num_tokens=0)
        self.assertTrue(
            torch.allclose(model.absolute_pos_embed, resized_pos_embed))
    def test_structure(self):
        """Check drop-path decay, ``norm_eval`` and ``frozen_stages``."""
        drop_path_rate = 0.2

        # drop_path_rate decay: walk all blocks in stage order and compare
        # each block's drop probability with the expected linear schedule.
        decay_cfg = deepcopy(self.cfg)
        decay_cfg['drop_path_rate'] = 0.2
        model = SwinTransformer(**decay_cfg)
        last_index = sum(model.arch_settings['depths']) - 1
        all_blocks = chain.from_iterable(
            stage.blocks for stage in model.stages)
        for index, block in enumerate(all_blocks):
            expect_prob = drop_path_rate / last_index * index
            self.assertAlmostEqual(block.ffn.dropout_layer.drop_prob,
                                   expect_prob)
            self.assertAlmostEqual(block.attn.drop.drop_prob, expect_prob)

        # norm_eval=True: after model.train() the norm layers are checked
        # with check_norm_state(..., False).
        bn_cfg = dict(type='BN')
        norm_eval_cfg = deepcopy(self.cfg)
        norm_eval_cfg['norm_eval'] = True
        norm_eval_cfg['norm_cfg'] = dict(type='BN')
        norm_eval_cfg['stage_cfgs'] = dict(block_cfgs=dict(norm_cfg=bn_cfg))
        model = SwinTransformer(**norm_eval_cfg)
        model.init_weights()
        model.train()
        norms_ok = check_norm_state(model.modules(), False)
        self.assertTrue(norms_ok)

        # frozen_stages=0: the patch embedding and the first stage are frozen.
        frozen_stages = 0
        frozen_cfg = deepcopy(self.cfg)
        frozen_cfg['frozen_stages'] = frozen_stages
        frozen_cfg['out_indices'] = (0, 1, 2, 3)
        model = SwinTransformer(**frozen_cfg)
        model.init_weights()
        model.train()

        # The patch_embed, the frozen stage(s) and norm0 must not require
        # grad, and patch_embed must stay out of training mode.
        self.assertFalse(model.patch_embed.training)
        for param in model.patch_embed.parameters():
            self.assertFalse(param.requires_grad)
        for stage_idx in range(frozen_stages + 1):
            for param in model.stages[stage_idx].parameters():
                self.assertFalse(param.requires_grad)
        for param in model.norm0.parameters():
            self.assertFalse(param.requires_grad)

        # Every remaining stage and its output norm must still require grad.
        for stage_idx in range(frozen_stages + 1, 4):
            for param in model.stages[stage_idx].parameters():
                self.assertTrue(param.requires_grad)
            out_norm = getattr(model, f'norm{stage_idx}')
            for param in out_norm.parameters():
                self.assertTrue(param.requires_grad)
    def test_forward(self):
        """Check forward outputs under several backbone configurations.

        Covers the default config, a larger window size, ``pad_small_map``,
        multiple output indices, checkpointed forward (``with_cp``) and
        inputs of varying spatial size.
        """

        def check_single_feat(outs, expected_shape):
            # The backbone must return a 1-tuple holding one feature map
            # of the expected shape.
            self.assertIsInstance(outs, tuple)
            self.assertEqual(len(outs), 1)
            self.assertEqual(outs[-1].shape, expected_shape)

        imgs = torch.randn(3, 3, 224, 224)

        # Default configuration.
        model = SwinTransformer(**deepcopy(self.cfg))
        check_single_feat(model(imgs), (3, 1024, 7, 7))

        # window_size=12 works on 384x384 input ...
        window_cfg = deepcopy(self.cfg)
        window_cfg['window_size'] = 12
        model = SwinTransformer(**window_cfg)
        check_single_feat(
            model(torch.randn(3, 3, 384, 384)), (3, 1024, 12, 12))
        # ... but is rejected for 224x224 input.
        with self.assertRaisesRegex(AssertionError, r'the window size \(12\)'):
            model(torch.randn(3, 3, 224, 224))

        # pad_small_map=True makes the same window size accept 224x224.
        padded_cfg = deepcopy(self.cfg)
        padded_cfg['window_size'] = 12
        padded_cfg['pad_small_map'] = True
        model = SwinTransformer(**padded_cfg)
        check_single_feat(
            model(torch.randn(3, 3, 224, 224)), (3, 1024, 7, 7))

        # Multiple output indices: one feature map per stage.
        multi_cfg = deepcopy(self.cfg)
        multi_cfg['out_indices'] = (0, 1, 2, 3)
        model = SwinTransformer(**multi_cfg)
        stage_outs = model(imgs)
        self.assertIsInstance(stage_outs, tuple)
        self.assertEqual(len(stage_outs), 4)
        for stride, out in zip([2, 4, 8, 8], stage_outs):
            side = 56 // stride
            self.assertEqual(out.shape, (3, 128 * stride, side, side))

        # Checkpointed forward: every SwinBlock must have with_cp enabled.
        cp_cfg = deepcopy(self.cfg)
        cp_cfg['with_cp'] = True
        model = SwinTransformer(**cp_cfg)
        swin_blocks = (m for m in model.modules() if isinstance(m, SwinBlock))
        for swin_block in swin_blocks:
            self.assertTrue(swin_block.with_cp)
        model.init_weights()
        model.train()
        check_single_feat(model(imgs), (3, 1024, 7, 7))

        # Dynamic input shape: the expected output side is the input side
        # divided by 32, rounded up.
        dynamic_sizes = [(224, 224), (256, 256), (256, 309)]
        dynamic_inputs = [torch.randn(3, 3, h, w) for h, w in dynamic_sizes]
        model = SwinTransformer(**deepcopy(self.cfg))
        for batch in dynamic_inputs:
            feat_h = math.ceil(batch.shape[2] / 32)
            feat_w = math.ceil(batch.shape[3] / 32)
            check_single_feat(model(batch), (3, 1024, feat_h, feat_w))