def test_vit_hybrid_backbone():
    """Smoke-test a VGG11 + ViT-B/16 hybrid: build, init, and forward."""

    # Build and initialize the convolutional stem used as patch embedding.
    vgg_stem = VGG(11, norm_eval=True)
    vgg_stem.init_weights()

    # Wrap the stem into a ViT and put the model in training mode.
    vit = VisionTransformer(hybrid_backbone=vgg_stem)
    vit.init_weights()
    vit.train()

    # All norm layers are expected to report training state here.
    assert check_norm_state(vit.modules(), True)

    dummy = torch.randn(1, 3, 224, 224)
    out = vit(dummy)
    assert out.shape == torch.Size((1, 768))
def test_vit_backbone():
    """Smoke-test the default ViT-B/16 backbone with a 224x224 input."""

    # init_weights must reject a non-string ``pretrained`` argument.
    with pytest.raises(TypeError):
        vit = VisionTransformer()
        vit.init_weights(pretrained=0)

    # Default configuration: base model, 224 input size, patch size 16.
    vit = VisionTransformer()
    vit.init_weights()
    vit.train()

    # All norm layers are expected to report training state here.
    assert check_norm_state(vit.modules(), True)

    batch = torch.randn(1, 3, 224, 224)
    out = vit(batch)
    assert out.shape == torch.Size((1, 768))
# Example 3 (originally "Esempio n. 3" — a scraped example separator; the
# stray "0" was a vote count from the source page, not code).
# NOTE(review): the function below redefines test_vit_backbone; if both
# definitions are kept in one module, pytest collects only the later one.
def test_vit_backbone():
    """Exercise the arch-dict based VisionTransformer API.

    Covers: invalid arch rejection, forward-pass output shapes,
    disabling the cls-token output, and multiple out_indices.
    """

    base_cfg = dict(
        arch='b',
        img_size=224,
        patch_size=16,
        drop_rate=0.1,
        init_cfg=[
            dict(type='Kaiming',
                 layer='Conv2d',
                 mode='fan_in',
                 nonlinearity='linear')
        ])

    # An unrecognised arch string must be rejected.
    with pytest.raises(AssertionError):
        bad_cfg = deepcopy(base_cfg)
        bad_cfg['arch'] = 'unknown'
        VisionTransformer(**bad_cfg)

    # An arch dict missing an essential key must also be rejected.
    with pytest.raises(AssertionError):
        bad_cfg = deepcopy(base_cfg)
        bad_cfg['arch'] = {
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 4096
        }
        VisionTransformer(**bad_cfg)

    # ViT-base, 224 input, patch size 16: forward-pass sanity check.
    vit = VisionTransformer(**base_cfg)
    vit.init_weights()
    vit.train()

    assert check_norm_state(vit.modules(), True)

    batch = torch.randn(3, 3, 224, 224)
    patch_token, cls_token = vit(batch)[-1]
    assert cls_token.shape == (3, 768)
    assert patch_token.shape == (3, 768, 14, 14)

    # A custom arch with cls-token output disabled yields patch tokens only.
    custom_cfg = deepcopy(base_cfg)
    custom_cfg['arch'] = {
        'embed_dims': 128,
        'num_layers': 24,
        'num_heads': 16,
        'feedforward_channels': 1024
    }
    custom_cfg['output_cls_token'] = False
    vit = VisionTransformer(**custom_cfg)
    patch_token = vit(batch)[-1]
    assert patch_token.shape == (3, 128, 14, 14)

    # With several out_indices every stage output shares the same shapes.
    multi_cfg = deepcopy(base_cfg)
    multi_cfg['out_indices'] = [-3, -2, -1]
    vit = VisionTransformer(**multi_cfg)
    for stage_out in vit(batch):
        assert stage_out[0].shape == (3, 768, 14, 14)
        assert stage_out[1].shape == (3, 768)