def test_custom_arch(arch_name, arch, out_channels):
    """Check RegNet output shapes when the model is built from ``arch``.

    Args:
        arch_name: Not used by this test; kept as part of the signature.
        arch: Architecture specification passed straight to ``RegNet``.
        out_channels: Expected channel count of each of the four stages.
    """
    # Default construction: the forward pass is expected to yield a single
    # tensor whose channel count matches the last stage.
    net = RegNet(arch)
    net.init_weights()
    inputs = torch.randn(1, 3, 224, 224)
    output = net(inputs)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (1, out_channels[-1], 7, 7)

    # Requesting every stage: four feature maps are expected, with spatial
    # sizes 56, 28, 14 and 7 for a 224x224 input.
    net = RegNet(arch, out_indices=(0, 1, 2, 3))
    net.init_weights()
    inputs = torch.randn(1, 3, 224, 224)
    outputs = net(inputs)
    assert len(outputs) == 4
    for stage, size in enumerate((56, 28, 14, 7)):
        assert outputs[stage].shape == (1, out_channels[stage], size, size)
def test_exception():
    """Check that ``RegNet`` rejects an ``arch`` that is neither str nor dict."""
    # An integer arch is expected to raise TypeError at construction time.
    with pytest.raises(TypeError):
        RegNet(50)
def test_regnet_backbone(arch_name, arch, out_channels):
    """Check RegNet output shapes when the model is built from ``arch_name``.

    Args:
        arch_name: Architecture name passed straight to ``RegNet``.
        arch: Not used by this test; kept as part of the signature.
        out_channels: Expected channel count of each of the four stages.
    """
    # A name with '233' appended is expected to be rejected with an
    # AssertionError at construction time.
    bogus_name = arch_name + '233'
    with pytest.raises(AssertionError):
        RegNet(bogus_name)

    # Default construction: the forward pass is expected to yield a single
    # tensor whose channel count matches the last stage.
    net = RegNet(arch_name)
    net.init_weights()
    net.train()
    inputs = torch.randn(1, 3, 224, 224)
    output = net(inputs)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (1, out_channels[-1], 7, 7)

    # Requesting every stage: four feature maps are expected, with spatial
    # sizes 56, 28, 14 and 7 for a 224x224 input.
    net = RegNet(arch_name, out_indices=(0, 1, 2, 3))
    net.init_weights()
    net.train()
    inputs = torch.randn(1, 3, 224, 224)
    outputs = net(inputs)
    assert len(outputs) == 4
    for stage, size in enumerate((56, 28, 14, 7)):
        assert outputs[stage].shape == (1, out_channels[stage], size, size)