import os
import tempfile
from copy import deepcopy
from unittest import TestCase

import torch
from mmcv.runner import load_checkpoint, save_checkpoint

from mmcls.models.backbones import SwinTransformer
# `timm_resize_pos_embed` is assumed to come from the test package's shared
# helpers (a copy of timm's `resize_pos_embed`).
from .utils import timm_resize_pos_embed


def test_load_checkpoint():
    model = SwinTransformer(arch='tiny')
    ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')

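    # SwinTransformer sets `_version = 2` at class level; state_dict()
    # records it in the checkpoint metadata, which is how load_checkpoint()
    # can tell old checkpoints from new ones.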
    assert model._version == 2

    # test load v2 checkpoint
    save_checkpoint(model, ckpt_path)
    load_checkpoint(model, ckpt_path, strict=True)

    # test load v1 checkpoint: emulate v1 by renaming `norm3` back to `norm`
    setattr(model, 'norm', model.norm3)
    model._version = 1
    del model.norm3
    save_checkpoint(model, ckpt_path)
    model = SwinTransformer(arch='tiny')
    load_checkpoint(model, ckpt_path, strict=True)
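

# `load_checkpoint` can restore the renamed key because PyTorch records each
# module's class-level `_version` in the checkpoint metadata and hands it
# back to `_load_from_state_dict` at load time. A minimal sketch of that
# mechanism, with an assumed `norm` -> `norm3` rename (not the actual
# SwinTransformer implementation):
class VersionedBackboneSketch(torch.nn.Module):
    _version = 2  # recorded into state_dict metadata automatically

    def __init__(self):
        super().__init__()
        self.norm3 = torch.nn.LayerNorm(96)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              *args, **kwargs):
        # Checkpoints written before version 2 stored the layer as `norm`.
        if local_metadata.get('version', 1) < 2:
            old_prefix = prefix + 'norm.'
            for old_key in list(state_dict):
                if old_key.startswith(old_prefix):
                    new_key = prefix + 'norm3.' + old_key[len(old_prefix):]
                    state_dict[new_key] = state_dict.pop(old_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      *args, **kwargs)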


class TestSwinTransformer(TestCase):

    def setUp(self):
        # Minimal backbone config shared by the tests below; the exact
        # values are assumed (the original suite defines this in setUp).
        self.cfg = dict(arch='tiny', img_size=224, drop_path_rate=0.1)

    def test_init_weights(self):
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['use_abs_pos_embed'] = True
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
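        # The Kaiming cfg targets the Conv2d patch-embed projection, so its
        # weight should be re-drawn by init_weights() (checked below).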
        model = SwinTransformer(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()
        # The absolute pos_embed is all zeros before initialization
        self.assertTrue(
            torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))

        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
        self.assertFalse(
            torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))

        pretrain_pos_embed = model.absolute_pos_embed.clone().detach()

        tmpdir = tempfile.gettempdir()
        # Save a v3 checkpoint
        checkpoint_v3 = os.path.join(tmpdir, 'v3.pth')
        save_checkpoint(model, checkpoint_v3)
        # Save a v1 checkpoint: v1 stored the final norm layer under the key
        # `norm` and kept the attention masks in the state dict.
        setattr(model, 'norm', model.norm3)
        setattr(model.stages[0].blocks[1].attn, 'attn_mask',
                torch.zeros(64, 49, 49))
        model._version = 1
        del model.norm3
        checkpoint_v1 = os.path.join(tmpdir, 'v1.pth')
        save_checkpoint(model, checkpoint_v1)

        # test load v1 checkpoint
        cfg = deepcopy(self.cfg)
        cfg['use_abs_pos_embed'] = True
        model = SwinTransformer(**cfg)
        load_checkpoint(model, checkpoint_v1, strict=True)

        # test load v3 checkpoint
        cfg = deepcopy(self.cfg)
        cfg['use_abs_pos_embed'] = True
        model = SwinTransformer(**cfg)
        load_checkpoint(model, checkpoint_v3, strict=True)

        # test load v3 checkpoint with different img_size
        cfg = deepcopy(self.cfg)
        cfg['img_size'] = 384
        cfg['use_abs_pos_embed'] = True
        model = SwinTransformer(**cfg)
        load_checkpoint(model, checkpoint_v3, strict=True)
        resized_pos_embed = timm_resize_pos_embed(
            pretrain_pos_embed, model.absolute_pos_embed, num_tokens=0)
        self.assertTrue(
            torch.allclose(model.absolute_pos_embed, resized_pos_embed))

        os.remove(checkpoint_v1)
        os.remove(checkpoint_v3)
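

# `timm_resize_pos_embed` above interpolates the pretrained absolute position
# embedding to the larger 384-input grid. A minimal sketch of the same idea
# for a (1, N, C) embedding with no extra class tokens (num_tokens=0); an
# assumed simplification, not timm's exact implementation:
def resize_pos_embed_sketch(posemb, posemb_new):
    import math

    import torch.nn.functional as F

    gs_old = int(math.sqrt(posemb.shape[1]))      # old grid side, e.g. 56
    gs_new = int(math.sqrt(posemb_new.shape[1]))  # new grid side, e.g. 96
    # (1, N, C) -> (1, C, H, W) so the embedding can be resized spatially.
    grid = posemb.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid, size=(gs_new, gs_new), mode='bicubic', align_corners=False)
    # Back to (1, N', C).
    return grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)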