def test_ill_arg(self):
    """Invalid ResBlock arguments (norm_name/num_groups API) raise the expected errors."""
    # Each entry: (exception the constructor must raise, extra constructor kwargs).
    # Checked in order; the test stops at the first case that does not raise.
    bad_cases = (
        # presumably rejected because the kernel size is even
        (AssertionError, {"kernel_size": 2, "num_groups": 8}),
        # presumably rejected because "norm" is not a recognised norm name
        (ValueError, {"norm_name": "norm", "num_groups": 8}),
        # presumably rejected because 8 channels cannot be split into 3 groups
        (AssertionError, {"num_groups": 3}),
    )
    for expected_exc, extra_kwargs in bad_cases:
        with self.assertRaises(expected_exc):
            ResBlock(spatial_dims=3, in_channels=8, **extra_kwargs)
def test_ill_arg(self):
    """Invalid ResBlock arguments (norm API) raise the expected errors."""
    # Each entry: (exception the constructor must raise, extra constructor kwargs).
    # Checked in order; the test stops at the first case that does not raise.
    bad_cases = (
        # presumably rejected because the kernel size is even
        (AssertionError, {"norm": "group", "kernel_size": 2}),
        # presumably rejected because "norm" is not a recognised norm type
        (ValueError, {"norm": "norm"}),
    )
    for expected_exc, extra_kwargs in bad_cases:
        with self.assertRaises(expected_exc):
            ResBlock(spatial_dims=3, in_channels=8, **extra_kwargs)
def _make_up_layers(self):
    """Build the decoder's residual stages and their upsampling modules.

    For stage ``i`` (of ``n_up = len(self.blocks_up)`` stages) the incoming
    channel count is ``self.init_filters * 2 ** (n_up - i)``. The matching
    upsampling module applies a 1x1 (kernel_size=1) conv that halves the
    channels, followed by ``get_upsample_layer``. The residual stage stacks
    ``self.blocks_up[i]`` ``ResBlock``s at the halved channel count.

    Returns:
        ``(up_layers, up_samples)``: two ``nn.ModuleList``s, each holding one
        ``nn.Sequential`` per entry of ``self.blocks_up``.
    """
    # Read configuration once, in a fixed order, before building anything.
    upsample_mode = self.upsample_mode
    blocks_up = self.blocks_up
    spatial_dims = self.spatial_dims
    filters = self.init_filters
    norm = self.norm

    up_layers = nn.ModuleList()
    up_samples = nn.ModuleList()
    n_up = len(blocks_up)
    for stage, num_blocks in enumerate(blocks_up):
        in_ch = filters * 2**(n_up - stage)
        out_ch = in_ch // 2
        # Residual blocks are built before the upsampling modules of the same
        # stage, matching the original construction (and parameter-init) order.
        res_blocks = [ResBlock(spatial_dims, out_ch, norm=norm) for _ in range(num_blocks)]
        up_layers.append(nn.Sequential(*res_blocks))
        up_samples.append(
            nn.Sequential(
                get_conv_layer(spatial_dims, in_ch, out_ch, kernel_size=1),
                get_upsample_layer(spatial_dims, out_ch, upsample_mode=upsample_mode),
            )
        )
    return up_layers, up_samples
def _make_down_layers(self):
    """Build the encoder stages.

    Stage ``i`` operates at ``self.init_filters * 2**i`` channels. Every stage
    except the first begins with a stride-2 conv from half that channel count;
    the first stage begins with ``nn.Identity`` instead. Each stage then stacks
    ``self.blocks_down[i]`` ``ResBlock``s configured with ``self.norm_name``
    and ``self.num_groups``.

    Returns:
        ``nn.ModuleList`` with one ``nn.Sequential`` per entry of
        ``self.blocks_down``.
    """
    # Read configuration once, in a fixed order, before building anything.
    blocks_down = self.blocks_down
    spatial_dims = self.spatial_dims
    filters = self.init_filters
    norm_name = self.norm_name
    num_groups = self.num_groups

    down_layers = nn.ModuleList()
    for idx, n_blocks in enumerate(blocks_down):
        channels = filters * 2**idx
        if idx == 0:
            # The first stage keeps the incoming tensor as-is.
            entry = nn.Identity()
        else:
            entry = get_conv_layer(spatial_dims, channels // 2, channels, stride=2)
        res_blocks = (
            ResBlock(spatial_dims, channels, norm_name=norm_name, num_groups=num_groups)
            for _ in range(n_blocks)
        )
        down_layers.append(nn.Sequential(entry, *res_blocks))
    return down_layers
def test_shape(self, input_param, input_shape, expected_shape):
    """A ResBlock built from ``input_param`` maps ``input_shape`` to ``expected_shape``."""
    block = ResBlock(**input_param)
    with eval_mode(block):
        # Input is drawn after the network is built, so weight init and the
        # random input consume the RNG in the same order as before.
        output = block(torch.randn(input_shape))
    self.assertEqual(output.shape, expected_shape)
def test_shape(self, input_param, input_shape, expected_shape):
    """A ResBlock built from ``input_param`` maps ``input_shape`` to ``expected_shape``."""
    # Module.eval() returns the module itself, so it can be chained here.
    block = ResBlock(**input_param).eval()
    with torch.no_grad():
        output = block(torch.randn(input_shape))
    self.assertEqual(output.shape, expected_shape)