import math
import os
import tempfile
from copy import deepcopy

import pytest
import torch

# Import paths below assume the mmcls repo layout; the test helpers
# `check_norm_state` and `timm_resize_pos_embed` are sketched further down.
from mmcls.models.backbones import VGG, VisionTransformer
from mmcv.runner import load_checkpoint, save_checkpoint


def test_vit_hybrid_backbone():
    # Test VGG11+ViT-B/16 hybrid model.
    # NOTE: this exercises the older VisionTransformer interface, where a CNN
    # `hybrid_backbone` could be passed and forward returned a single tensor.
    backbone = VGG(11, norm_eval=True)
    backbone.init_weights()
    model = VisionTransformer(hybrid_backbone=backbone)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size((1, 768))
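# `check_norm_state` is not defined in this file. A minimal sketch of what it
# is assumed to verify (hypothetical helper, matching how the tests use it):
# every BatchNorm-style layer's `.training` flag equals the expected state.
from torch.nn.modules.batchnorm import _BatchNorm


def check_norm_state(modules, train_state):
    """Return True if all norm layers are in the expected train/eval state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm) and mod.training != train_state:
            return False
    return True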
def test_structure(self):
    # Test invalid default arch
    with self.assertRaisesRegex(AssertionError, 'not in default archs'):
        cfg = deepcopy(self.cfg)
        cfg['arch'] = 'unknown'
        VisionTransformer(**cfg)

    # Test invalid custom arch
    with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
        cfg = deepcopy(self.cfg)
        cfg['arch'] = {
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 4096
        }
        VisionTransformer(**cfg)

    # Test custom arch
    cfg = deepcopy(self.cfg)
    cfg['arch'] = {
        'embed_dims': 128,
        'num_layers': 24,
        'num_heads': 16,
        'feedforward_channels': 1024
    }
    model = VisionTransformer(**cfg)
    self.assertEqual(model.embed_dims, 128)
    self.assertEqual(model.num_layers, 24)
    for layer in model.layers:
        self.assertEqual(layer.attn.num_heads, 16)
        self.assertEqual(layer.ffn.feedforward_channels, 1024)

    # Test out_indices
    cfg = deepcopy(self.cfg)
    cfg['out_indices'] = {1: 1}
    with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
        VisionTransformer(**cfg)
    cfg['out_indices'] = [0, 13]
    with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
        VisionTransformer(**cfg)

    # Test model structure
    cfg = deepcopy(self.cfg)
    model = VisionTransformer(**cfg)
    self.assertEqual(len(model.layers), 12)
    # Stochastic depth: the drop path rate grows linearly from 0 to the
    # configured drop_path_rate (0.1) across the 12 layers.
    dpr_inc = 0.1 / (12 - 1)
    dpr = 0
    for layer in model.layers:
        self.assertEqual(layer.attn.embed_dims, 768)
        self.assertEqual(layer.attn.num_heads, 12)
        self.assertEqual(layer.ffn.feedforward_channels, 3072)
        self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
        self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
        dpr += dpr_inc
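# `self.cfg` in the unittest-style methods comes from the enclosing TestCase.
# A setUp consistent with the assertions above (ViT-B: 12 layers, 768 dims,
# 12 heads, 3072 FFN channels, and a stochastic-depth rate of 0.1) could look
# like this sketch; the exact fixture in the real suite is an assumption.
from unittest import TestCase


class TestVisionTransformer(TestCase):

    def setUp(self):
        self.cfg = dict(
            arch='b', img_size=224, patch_size=16, drop_path_rate=0.1)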
def test_init_weights(self):
    # test weight init cfg
    cfg = deepcopy(self.cfg)
    cfg['init_cfg'] = [
        dict(
            type='Kaiming',
            layer='Conv2d',
            mode='fan_in',
            nonlinearity='linear')
    ]
    model = VisionTransformer(**cfg)
    ori_weight = model.patch_embed.projection.weight.clone().detach()
    # The pos_embed is all zero before initialize
    self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.)))

    model.init_weights()
    initialized_weight = model.patch_embed.projection.weight
    self.assertFalse(torch.allclose(ori_weight, initialized_weight))
    self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.)))

    # test load checkpoint
    pretrain_pos_embed = model.pos_embed.clone().detach()
    tmpdir = tempfile.gettempdir()
    checkpoint = os.path.join(tmpdir, 'test.pth')
    save_checkpoint(model, checkpoint)
    cfg = deepcopy(self.cfg)
    model = VisionTransformer(**cfg)
    load_checkpoint(model, checkpoint, strict=True)
    self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))

    # test load checkpoint with different img_size
    cfg = deepcopy(self.cfg)
    cfg['img_size'] = 384
    model = VisionTransformer(**cfg)
    load_checkpoint(model, checkpoint, strict=True)
    resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
                                              model.pos_embed)
    self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))

    os.remove(checkpoint)
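# `timm_resize_pos_embed` mirrors timm's `resize_pos_embed`: it interpolates a
# pretrained position embedding to the patch grid of a new input resolution.
# A minimal sketch under the assumptions of one prepended cls token and square
# grids (the real helper may differ in interpolation settings):
import torch.nn.functional as F


def timm_resize_pos_embed(posemb, posemb_new):
    # Split off the cls token; only the patch grid is resized.
    posemb_tok, posemb_grid = posemb[:, :1], posemb[:, 1:]
    gs_old = int(math.sqrt(posemb_grid.shape[1]))
    gs_new = int(math.sqrt(posemb_new.shape[1] - 1))
    # (1, N, C) -> (1, C, H, W) so F.interpolate can resize spatially.
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(
        posemb_grid, size=(gs_new, gs_new), mode='bicubic',
        align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(
        1, gs_new * gs_new, -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)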
def test_vit_backbone_legacy():
    # Legacy test: exercises the older interface where `init_weights`
    # accepted a `pretrained` argument and forward returned a single tensor.
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model = VisionTransformer()
        model.init_weights(pretrained=0)

    # Test ViT base model with input size of 224
    # and patch size of 16
    model = VisionTransformer()
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size((1, 768))
def test_vit_backbone():
    cfg_ori = dict(
        arch='b',
        img_size=224,
        patch_size=16,
        drop_rate=0.1,
        init_cfg=[
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ])

    with pytest.raises(AssertionError):
        # test invalid arch
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = 'unknown'
        VisionTransformer(**cfg)

    with pytest.raises(AssertionError):
        # test arch without essential keys
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = {
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 4096
        }
        VisionTransformer(**cfg)

    # Test ViT base model with input size of 224
    # and patch size of 16
    model = VisionTransformer(**cfg_ori)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(3, 3, 224, 224)
    patch_token, cls_token = model(imgs)[-1]
    assert cls_token.shape == (3, 768)
    assert patch_token.shape == (3, 768, 14, 14)

    # Test custom arch ViT without output cls token
    cfg = deepcopy(cfg_ori)
    cfg['arch'] = {
        'embed_dims': 128,
        'num_layers': 24,
        'num_heads': 16,
        'feedforward_channels': 1024
    }
    cfg['output_cls_token'] = False
    model = VisionTransformer(**cfg)
    patch_token = model(imgs)[-1]
    assert patch_token.shape == (3, 128, 14, 14)

    # Test ViT with multi out indices
    cfg = deepcopy(cfg_ori)
    cfg['out_indices'] = [-3, -2, -1]
    model = VisionTransformer(**cfg)
    for out in model(imgs):
        assert out[0].shape == (3, 768, 14, 14)
        assert out[1].shape == (3, 768)
def test_vit_weight_init():
    # test weight init cfg
    pretrain_cfg = dict(
        arch='b',
        img_size=224,
        patch_size=16,
        init_cfg=[dict(type='Constant', val=1., layer='Conv2d')])
    pretrain_model = VisionTransformer(**pretrain_cfg)
    pretrain_model.init_weights()
    assert torch.allclose(pretrain_model.patch_embed.projection.weight,
                          torch.tensor(1.))
    assert pretrain_model.pos_embed.abs().sum() > 0
    pos_embed_weight = pretrain_model.pos_embed.detach()
    tmpdir = tempfile.gettempdir()
    checkpoint = os.path.join(tmpdir, 'test.pth')
    torch.save(pretrain_model.state_dict(), checkpoint)

    # test load checkpoint
    finetune_cfg = dict(
        arch='b',
        img_size=224,
        patch_size=16,
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint))
    finetune_model = VisionTransformer(**finetune_cfg)
    finetune_model.init_weights()
    assert torch.allclose(finetune_model.pos_embed, pos_embed_weight)

    # test load checkpoint with different img_size
    finetune_cfg = dict(
        arch='b',
        img_size=384,
        patch_size=16,
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint))
    finetune_model = VisionTransformer(**finetune_cfg)
    finetune_model.init_weights()
    resized_pos_embed = timm_resize_pos_embed(pos_embed_weight,
                                              finetune_model.pos_embed)
    assert torch.allclose(finetune_model.pos_embed, resized_pos_embed)

    os.remove(checkpoint)
def test_forward(self):
    imgs = torch.randn(3, 3, 224, 224)

    # test with_cls_token=False
    cfg = deepcopy(self.cfg)
    cfg['with_cls_token'] = False
    cfg['output_cls_token'] = True
    with self.assertRaisesRegex(AssertionError, 'but got False'):
        VisionTransformer(**cfg)

    cfg = deepcopy(self.cfg)
    cfg['with_cls_token'] = False
    cfg['output_cls_token'] = False
    model = VisionTransformer(**cfg)
    outs = model(imgs)
    self.assertIsInstance(outs, tuple)
    self.assertEqual(len(outs), 1)
    patch_token = outs[-1]
    self.assertEqual(patch_token.shape, (3, 768, 14, 14))

    # test with output_cls_token
    cfg = deepcopy(self.cfg)
    model = VisionTransformer(**cfg)
    outs = model(imgs)
    self.assertIsInstance(outs, tuple)
    self.assertEqual(len(outs), 1)
    patch_token, cls_token = outs[-1]
    self.assertEqual(patch_token.shape, (3, 768, 14, 14))
    self.assertEqual(cls_token.shape, (3, 768))

    # test without output_cls_token
    cfg = deepcopy(self.cfg)
    cfg['output_cls_token'] = False
    model = VisionTransformer(**cfg)
    outs = model(imgs)
    self.assertIsInstance(outs, tuple)
    self.assertEqual(len(outs), 1)
    patch_token = outs[-1]
    self.assertEqual(patch_token.shape, (3, 768, 14, 14))

    # Test forward with multi out indices
    cfg = deepcopy(self.cfg)
    cfg['out_indices'] = [-3, -2, -1]
    model = VisionTransformer(**cfg)
    outs = model(imgs)
    self.assertIsInstance(outs, tuple)
    self.assertEqual(len(outs), 3)
    for out in outs:
        patch_token, cls_token = out
        self.assertEqual(patch_token.shape, (3, 768, 14, 14))
        self.assertEqual(cls_token.shape, (3, 768))

    # Test forward with dynamic input size
    imgs1 = torch.randn(3, 3, 224, 224)
    imgs2 = torch.randn(3, 3, 256, 256)
    imgs3 = torch.randn(3, 3, 256, 309)
    cfg = deepcopy(self.cfg)
    model = VisionTransformer(**cfg)
    for imgs in [imgs1, imgs2, imgs3]:
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token, cls_token = outs[-1]
        expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
                             math.ceil(imgs.shape[3] / 16))
        self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape))
        self.assertEqual(cls_token.shape, (3, 768))