def test_vit_hybrid_backbone():
    """Smoke-test a ViT that consumes features from a VGG11 hybrid backbone.

    Builds the hybrid model (VGG11 feeding ViT-B/16), puts it in training
    mode, checks the norm-layer state, and checks the forward output shape
    for a single 224x224 RGB image.
    """
    # Test VGG11+ViT-B/16 hybrid model
    vgg_backbone = VGG(11, norm_eval=True)
    vgg_backbone.init_weights()

    hybrid_vit = VisionTransformer(hybrid_backbone=vgg_backbone)
    hybrid_vit.init_weights()
    hybrid_vit.train()
    assert check_norm_state(hybrid_vit.modules(), True)

    batch = torch.randn(1, 3, 224, 224)
    features = hybrid_vit(batch)
    expected_shape = torch.Size((1, 768))
    assert features.shape == expected_shape


def test_vit_backbone():
    """Smoke-test the default ViT: argument validation and output shape.

    Checks that ``init_weights`` rejects a non-string ``pretrained`` value
    with ``TypeError``. It then checks that a default model (input 224,
    patch 16) in training mode maps one RGB image to a ``(1, 768)`` feature.
    """
    # Construct the model outside ``pytest.raises`` so that only the
    # ``init_weights`` call can satisfy the expected TypeError; otherwise a
    # TypeError from the constructor would make this check pass spuriously.
    model = VisionTransformer()
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model.init_weights(pretrained=0)

    # Test ViT base model with input size of 224
    # and patch size of 16
    model = VisionTransformer()
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size((1, 768))
Example #3
0
    def test_init_weights(self):
        """Check ``init_cfg`` initialization and checkpoint loading of ``pos_embed``.

        Verifies four things:

        - ``init_weights`` changes the patch-embedding projection weight.
        - ``init_weights`` changes the position embedding, which starts all
          zero.
        - A saved checkpoint restores ``pos_embed`` exactly.
        - Loading the same checkpoint into a model with ``img_size=384``
          yields the embedding computed by ``timm_resize_pos_embed``.
        """
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(type='Kaiming',
                 layer='Conv2d',
                 mode='fan_in',
                 nonlinearity='linear')
        ]
        model = VisionTransformer(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()
        # The pos_embed is all zero before initialize
        self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.)))

        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
        self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.)))

        # test load checkpoint
        pretrain_pos_embed = model.pos_embed.clone().detach()
        # Use a private temporary directory instead of a fixed file name in
        # the shared temp dir. Concurrent runs then cannot clobber each
        # other, and the checkpoint is removed even if an assertion fails.
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint = os.path.join(tmpdir, 'test.pth')
            save_checkpoint(model, checkpoint)
            cfg = deepcopy(self.cfg)
            model = VisionTransformer(**cfg)
            load_checkpoint(model, checkpoint, strict=True)
            self.assertTrue(
                torch.allclose(model.pos_embed, pretrain_pos_embed))

            # test load checkpoint with different img_size
            cfg = deepcopy(self.cfg)
            cfg['img_size'] = 384
            model = VisionTransformer(**cfg)
            # NOTE(review): strict=True with a different img_size presumably
            # relies on the model resizing pos_embed at load time — confirm.
            load_checkpoint(model, checkpoint, strict=True)
            resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
                                                      model.pos_embed)
            self.assertTrue(
                torch.allclose(model.pos_embed, resized_pos_embed))
Example #4
0
def test_vit_weight_init():
    """Check ``init_cfg`` handling and ``Pretrained`` checkpoint loading.

    Verifies three things:

    - A ``Constant`` init sets the patch-embedding projection weight to 1,
      and the position embedding is non-zero after ``init_weights``.
    - A ``Pretrained`` init_cfg pointing at the saved state dict restores
      ``pos_embed`` exactly.
    - The same checkpoint loaded into a model with ``img_size=384`` yields
      the embedding computed by ``timm_resize_pos_embed``.
    """
    # test weight init cfg
    pretrain_cfg = dict(
        arch='b',
        img_size=224,
        patch_size=16,
        init_cfg=[dict(type='Constant', val=1., layer='Conv2d')])
    pretrain_model = VisionTransformer(**pretrain_cfg)
    pretrain_model.init_weights()
    assert torch.allclose(pretrain_model.patch_embed.projection.weight,
                          torch.tensor(1.))
    assert pretrain_model.pos_embed.abs().sum() > 0

    # Snapshot into fresh storage so the reference value cannot change with
    # the model's parameter.
    pos_embed_weight = pretrain_model.pos_embed.clone().detach()

    # A private temporary directory avoids a fixed-name file in the shared
    # temp dir (clashes between concurrent runs), and it is cleaned up even
    # when an assertion below fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint = os.path.join(tmpdir, 'test.pth')
        torch.save(pretrain_model.state_dict(), checkpoint)

        # test load checkpoint
        finetune_cfg = dict(arch='b',
                            img_size=224,
                            patch_size=16,
                            init_cfg=dict(type='Pretrained',
                                          checkpoint=checkpoint))
        finetune_model = VisionTransformer(**finetune_cfg)
        finetune_model.init_weights()
        assert torch.allclose(finetune_model.pos_embed, pos_embed_weight)

        # test load checkpoint with different img_size
        finetune_cfg = dict(arch='b',
                            img_size=384,
                            patch_size=16,
                            init_cfg=dict(type='Pretrained',
                                          checkpoint=checkpoint))
        finetune_model = VisionTransformer(**finetune_cfg)
        finetune_model.init_weights()
        resized_pos_embed = timm_resize_pos_embed(pos_embed_weight,
                                                  finetune_model.pos_embed)
        assert torch.allclose(finetune_model.pos_embed, resized_pos_embed)
Example #5
0
def test_vit_backbone():
    """Check arch validation and output shapes of configurable ViT backbones.

    Covers four cases:

    - An unknown arch name, or an arch dict missing ``embed_dims``, raises
      ``AssertionError``.
    - The ViT-B/16 model returns a ``(patch_token, cls_token)`` pair for the
      last output index.
    - A custom arch with ``output_cls_token=False`` returns patch tokens only.
    - Multiple ``out_indices`` yield one output per requested index.
    """
    cfg_ori = dict(arch='b',
                   img_size=224,
                   patch_size=16,
                   drop_rate=0.1,
                   init_cfg=[
                       dict(type='Kaiming',
                            layer='Conv2d',
                            mode='fan_in',
                            nonlinearity='linear')
                   ])

    with pytest.raises(AssertionError):
        # test invalid arch
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = 'unknown'
        VisionTransformer(**cfg)

    with pytest.raises(AssertionError):
        # test arch without essential keys
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = {
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 4096
        }
        VisionTransformer(**cfg)

    # Test ViT base model with input size of 224
    # and patch size of 16
    model = VisionTransformer(**cfg_ori)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(3, 3, 224, 224)
    patch_token, cls_token = model(imgs)[-1]
    assert cls_token.shape == (3, 768)
    assert patch_token.shape == (3, 768, 14, 14)

    # Test custom arch ViT without output cls token
    cfg = deepcopy(cfg_ori)
    cfg['arch'] = {
        'embed_dims': 128,
        'num_layers': 24,
        'num_heads': 16,
        'feedforward_channels': 1024
    }
    cfg['output_cls_token'] = False
    model = VisionTransformer(**cfg)
    patch_token = model(imgs)[-1]
    assert patch_token.shape == (3, 128, 14, 14)

    # Test ViT with multi out indices
    cfg = deepcopy(cfg_ori)
    cfg['out_indices'] = [-3, -2, -1]
    model = VisionTransformer(**cfg)
    outs = model(imgs)
    # Without a count check, the shape loop below would pass vacuously if
    # the model returned fewer outputs than requested.
    assert len(outs) == len(cfg['out_indices'])
    for out in outs:
        assert out[0].shape == (3, 768, 14, 14)
        assert out[1].shape == (3, 768)