def __init__(self,
             inplanes,
             planes,
             stride=1,
             downsample=None,
             conv_cfg=None,
             norm_cfg=None):
    """Run both parent initializers explicitly, one after the other.

    ``spconv.SparseModule.__init__`` is called first with no extra
    arguments; ``BasicBlock.__init__`` is then called with every argument
    of this constructor passed through unchanged.

    Args:
        inplanes: Forwarded positionally to ``BasicBlock.__init__``.
        planes: Forwarded positionally to ``BasicBlock.__init__``.
        stride: Forwarded by keyword. Defaults to 1.
        downsample: Forwarded by keyword. Defaults to None.
        conv_cfg: Forwarded by keyword. Defaults to None.
        norm_cfg: Forwarded by keyword. Defaults to None.
    """
    # NOTE(review): presumably the enclosing class inherits from both
    # spconv.SparseModule and BasicBlock -- confirm against the class header.
    spconv.SparseModule.__init__(self)

    # Optional settings are handed over by keyword, exactly as received.
    block_options = {
        'stride': stride,
        'downsample': downsample,
        'conv_cfg': conv_cfg,
        'norm_cfg': norm_cfg,
    }
    BasicBlock.__init__(self, inplanes, planes, **block_options)
def test_resnet_basic_block():
    """Exercise BasicBlock: rejected options, layer layout, output shape."""
    # Not implemented yet: each of these keyword sets must make the
    # BasicBlock constructor raise an AssertionError.
    context_plugins = [
        dict(
            cfg=dict(type='ContextBlock', ratio=1. / 16),
            position='after_conv3')
    ]
    attention_plugins = [
        dict(
            cfg=dict(
                type='GeneralizedAttention',
                spatial_range=-1,
                num_heads=8,
                attention_type='0010',
                kv_stride=2),
            position='after_conv2')
    ]
    unsupported_kwargs = (
        dict(
            dcn=dict(
                type='DCN', deformable_groups=1, fallback_on_stride=False)),
        dict(plugins=context_plugins),
        dict(plugins=attention_plugins),
    )
    for bad_kwargs in unsupported_kwargs:
        with pytest.raises(AssertionError):
            BasicBlock(64, 64, **bad_kwargs)

    input_dims = (1, 64, 56, 56)
    expected_shape = torch.Size([1, 64, 56, 56])

    # test BasicBlock structure and forward
    block = BasicBlock(64, 64)
    for layer_name in ('conv1', 'conv2'):
        conv = getattr(block, layer_name)
        assert conv.in_channels == 64
        assert conv.out_channels == 64
        assert conv.kernel_size == (3, 3)
    feats = torch.randn(*input_dims)
    out = block(feats)
    assert out.shape == expected_shape

    # Test BasicBlock with checkpoint forward
    block = BasicBlock(64, 64, with_cp=True)
    assert block.with_cp
    feats = torch.randn(*input_dims)
    out = block(feats)
    assert out.shape == expected_shape