from copy import deepcopy

import pytest
import torch

# Assumed mmcls-style imports; adjust the module paths to your project layout.
from mmcls.models.backbones import VGG, VisionTransformer

# check_norm_state is a local test helper; a sketch is provided after this test.


def test_vit_hybrid_backbone():
    # Test VGG11+ViT-B/16 hybrid model
    backbone = VGG(11, norm_eval=True)
    backbone.init_weights()
    model = VisionTransformer(hybrid_backbone=backbone)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size((1, 768))
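
# The tests in this file call a `check_norm_state` helper that is not defined
# here. A minimal sketch, assuming it only needs to verify the `training` flag
# of BatchNorm layers against the expected state:
from torch.nn.modules.batchnorm import _BatchNorm


def check_norm_state(modules, train_state):
    """Return True if every BatchNorm layer matches the expected train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm) and mod.training != train_state:
            return False
    return True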

def test_vit_backbone():
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model = VisionTransformer()
        model.init_weights(pretrained=0)

    # Test ViT base model with input size of 224
    # and patch size of 16
    model = VisionTransformer()
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size((1, 768))

def test_vit_backbone():
    cfg_ori = dict(
        arch='b',
        img_size=224,
        patch_size=16,
        drop_rate=0.1,
        init_cfg=[
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ])

    with pytest.raises(AssertionError):
        # test invalid arch
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = 'unknown'
        VisionTransformer(**cfg)

    with pytest.raises(AssertionError):
        # test arch without essential keys
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = {
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 4096
        }
        VisionTransformer(**cfg)

    # Test ViT base model with input size of 224
    # and patch size of 16
    model = VisionTransformer(**cfg_ori)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(3, 3, 224, 224)
    patch_token, cls_token = model(imgs)[-1]
    assert cls_token.shape == (3, 768)
    assert patch_token.shape == (3, 768, 14, 14)

    # Test custom arch ViT without output cls token
    cfg = deepcopy(cfg_ori)
    cfg['arch'] = {
        'embed_dims': 128,
        'num_layers': 24,
        'num_heads': 16,
        'feedforward_channels': 1024
    }
    cfg['output_cls_token'] = False
    model = VisionTransformer(**cfg)
    patch_token = model(imgs)[-1]
    assert patch_token.shape == (3, 128, 14, 14)

    # Test ViT with multi out indices
    cfg = deepcopy(cfg_ori)
    cfg['out_indices'] = [-3, -2, -1]
    model = VisionTransformer(**cfg)
    for out in model(imgs):
        assert out[0].shape == (3, 768, 14, 14)
        assert out[1].shape == (3, 768)