def test_vit_backbone():
    with pytest.raises(TypeError):
        # pretrained must be a str or None, so passing an int raises TypeError
        model = VisionTransformer()
        model.init_weights(pretrained=0)

    # Test ViT base model with input size of 224
    # and patch size of 16
    model = VisionTransformer()
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size((1, 768))
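These excerpts appear to come from MMClassification's VisionTransformer backbone unit tests and are shown without their imports. Below is a minimal sketch of the shared setup they assume; the commented-out helper imports (check_norm_state, timm_resize_pos_embed, VGG) are assumptions, since their exact module paths are not shown in the snippets.

# Hypothetical shared imports for the test snippets in this section.
import math
import os
import tempfile
from copy import deepcopy

import pytest
import torch

from mmcv.runner import load_checkpoint, save_checkpoint
from mmcls.models.backbones import VisionTransformer
# Test-local / third-party helpers; exact paths are assumptions:
# from .utils import check_norm_state, timm_resize_pos_embed
# from mmcls.models.backbones import VGG  # used by the hybrid example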
Example #2
def test_vit_weight_init():
    # test weight init cfg
    pretrain_cfg = dict(
        arch='b',
        img_size=224,
        patch_size=16,
        init_cfg=[dict(type='Constant', val=1., layer='Conv2d')])
    pretrain_model = VisionTransformer(**pretrain_cfg)
    pretrain_model.init_weights()
    assert torch.allclose(pretrain_model.patch_embed.projection.weight,
                          torch.tensor(1.))
    assert pretrain_model.pos_embed.abs().sum() > 0

    pos_embed_weight = pretrain_model.pos_embed.detach()
    tmpdir = tempfile.gettempdir()
    checkpoint = os.path.join(tmpdir, 'test.pth')
    torch.save(pretrain_model.state_dict(), checkpoint)

    # test load checkpoint
    finetune_cfg = dict(arch='b',
                        img_size=224,
                        patch_size=16,
                        init_cfg=dict(type='Pretrained',
                                      checkpoint=checkpoint))
    finetune_model = VisionTransformer(**finetune_cfg)
    finetune_model.init_weights()
    assert torch.allclose(finetune_model.pos_embed, pos_embed_weight)

    # test load checkpoint with different img_size
    finetune_cfg = dict(arch='b',
                        img_size=384,
                        patch_size=16,
                        init_cfg=dict(type='Pretrained',
                                      checkpoint=checkpoint))
    finetune_model = VisionTransformer(**finetune_cfg)
    finetune_model.init_weights()
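    # Loading the 224-px checkpoint into a 384-px model requires interpolating the
    # position embedding to the larger patch grid; timm's resize is used here as
    # the reference implementation to check the result.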
    resized_pos_embed = timm_resize_pos_embed(pos_embed_weight,
                                              finetune_model.pos_embed)
    assert torch.allclose(finetune_model.pos_embed, resized_pos_embed)

    os.remove(checkpoint)
Example #3
    def test_init_weights(self):
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(type='Kaiming',
                 layer='Conv2d',
                 mode='fan_in',
                 nonlinearity='linear')
        ]
        model = VisionTransformer(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()
        # The pos_embed is all zeros before initialization
        self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.)))

        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
        self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.)))

        # test load checkpoint
        pretrain_pos_embed = model.pos_embed.clone().detach()
        tmpdir = tempfile.gettempdir()
        checkpoint = os.path.join(tmpdir, 'test.pth')
        save_checkpoint(model, checkpoint)
        cfg = deepcopy(self.cfg)
        model = VisionTransformer(**cfg)
        load_checkpoint(model, checkpoint, strict=True)
        self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))

        # test load checkpoint with different img_size
        cfg = deepcopy(self.cfg)
        cfg['img_size'] = 384
        model = VisionTransformer(**cfg)
        load_checkpoint(model, checkpoint, strict=True)
        resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
                                                  model.pos_embed)
        self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))

        os.remove(checkpoint)
Example #4
def test_vit_hybrid_backbone():

    # Test VGG11+ViT-B/16 hybrid model
    backbone = VGG(11, norm_eval=True)
    backbone.init_weights()
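    # In the hybrid variant, patch embeddings are computed from the CNN feature
    # map produced by the VGG stem rather than from raw image patches.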
    model = VisionTransformer(hybrid_backbone=backbone)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size((1, 768))
Example #5
    def test_structure(self):
        # Test invalid default arch
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'unknown'
            VisionTransformer(**cfg)

        # Test invalid custom arch
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }
            VisionTransformer(**cfg)

        # Test custom arch
        cfg = deepcopy(self.cfg)
        cfg['arch'] = {
            'embed_dims': 128,
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 1024
        }
        model = VisionTransformer(**cfg)
        self.assertEqual(model.embed_dims, 128)
        self.assertEqual(model.num_layers, 24)
        for layer in model.layers:
            self.assertEqual(layer.attn.num_heads, 16)
            self.assertEqual(layer.ffn.feedforward_channels, 1024)

        # Test out_indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = {1: 1}
        with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
            VisionTransformer(**cfg)
        cfg['out_indices'] = [0, 13]
        with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
            VisionTransformer(**cfg)

        # Test model structure
        cfg = deepcopy(self.cfg)
        model = VisionTransformer(**cfg)
        self.assertEqual(len(model.layers), 12)
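        # The per-layer drop probability follows a linear schedule (stochastic
        # depth) from 0 in the first layer up to 0.1 in the last of the 12 layers.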
        dpr_inc = 0.1 / (12 - 1)
        dpr = 0
        for layer in model.layers:
            self.assertEqual(layer.attn.embed_dims, 768)
            self.assertEqual(layer.attn.num_heads, 12)
            self.assertEqual(layer.ffn.feedforward_channels, 3072)
            self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
            self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
            dpr += dpr_inc
Example #6
def test_vit_backbone():

    cfg_ori = dict(arch='b',
                   img_size=224,
                   patch_size=16,
                   drop_rate=0.1,
                   init_cfg=[
                       dict(type='Kaiming',
                            layer='Conv2d',
                            mode='fan_in',
                            nonlinearity='linear')
                   ])

    with pytest.raises(AssertionError):
        # test invalid arch
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = 'unknown'
        VisionTransformer(**cfg)

    with pytest.raises(AssertionError):
        # test arch without essential keys
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = {
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 4096
        }
        VisionTransformer(**cfg)

    # Test ViT base model with input size of 224
    # and patch size of 16
    model = VisionTransformer(**cfg_ori)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

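    # Each selected output stage yields a (patch_token, cls_token) pair; patch
    # tokens are reshaped into an (N, C, H/16, W/16) feature map.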
    imgs = torch.randn(3, 3, 224, 224)
    patch_token, cls_token = model(imgs)[-1]
    assert cls_token.shape == (3, 768)
    assert patch_token.shape == (3, 768, 14, 14)

    # Test custom arch ViT without output cls token
    cfg = deepcopy(cfg_ori)
    cfg['arch'] = {
        'embed_dims': 128,
        'num_layers': 24,
        'num_heads': 16,
        'feedforward_channels': 1024
    }
    cfg['output_cls_token'] = False
    model = VisionTransformer(**cfg)
    patch_token = model(imgs)[-1]
    assert patch_token.shape == (3, 128, 14, 14)

    # Test ViT with multi out indices
    cfg = deepcopy(cfg_ori)
    cfg['out_indices'] = [-3, -2, -1]
    model = VisionTransformer(**cfg)
    for out in model(imgs):
        assert out[0].shape == (3, 768, 14, 14)
        assert out[1].shape == (3, 768)
Example #7
    def test_forward(self):
        imgs = torch.randn(3, 3, 224, 224)

        # test with_cls_token=False
        cfg = deepcopy(self.cfg)
        cfg['with_cls_token'] = False
        cfg['output_cls_token'] = True
        with self.assertRaisesRegex(AssertionError, 'but got False'):
            VisionTransformer(**cfg)

        cfg = deepcopy(self.cfg)
        cfg['with_cls_token'] = False
        cfg['output_cls_token'] = False
        model = VisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
        self.assertEqual(patch_token.shape, (3, 768, 14, 14))

        # test with output_cls_token
        cfg = deepcopy(self.cfg)
        model = VisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token, cls_token = outs[-1]
        self.assertEqual(patch_token.shape, (3, 768, 14, 14))
        self.assertEqual(cls_token.shape, (3, 768))

        # test without output_cls_token
        cfg = deepcopy(self.cfg)
        cfg['output_cls_token'] = False
        model = VisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
        self.assertEqual(patch_token.shape, (3, 768, 14, 14))

        # Test forward with multi out indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = [-3, -2, -1]
        model = VisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 3)
        for out in outs:
            patch_token, cls_token = out
            self.assertEqual(patch_token.shape, (3, 768, 14, 14))
            self.assertEqual(cls_token.shape, (3, 768))

        # Test forward with dynamic input size
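        # The position embedding is interpolated on the fly to match each input's
        # patch grid, so non-square and non-224 resolutions are accepted; the
        # expected grid size is ceil(H / 16) x ceil(W / 16).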
        imgs1 = torch.randn(3, 3, 224, 224)
        imgs2 = torch.randn(3, 3, 256, 256)
        imgs3 = torch.randn(3, 3, 256, 309)
        cfg = deepcopy(self.cfg)
        model = VisionTransformer(**cfg)
        for imgs in [imgs1, imgs2, imgs3]:
            outs = model(imgs)
            self.assertIsInstance(outs, tuple)
            self.assertEqual(len(outs), 1)
            patch_token, cls_token = outs[-1]
            expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
                                 math.ceil(imgs.shape[3] / 16))
            self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape))
            self.assertEqual(cls_token.shape, (3, 768))