Example #1
0
 def __init__(
     self,
     spatial_dims: int,
     in_channels: int,
     out_channels: int,
     kernel_size: Union[Sequence[int], int],
     stride: Union[Sequence[int], int],
     num_groups: int,
     norm_name: str,
     is_prunable: bool = False,
 ):
     """Set up the block's convolutions, activation, norms and downsample flag.

     The parent initialiser is called first with the arguments it shares
     with this class; ``conv1``/``conv2``/``conv3``, ``lrelu`` and
     ``norm1``/``norm2``/``norm3`` are then assigned here.

     Args:
         spatial_dims: forwarded to the parent, ``get_conv_layer`` and
             ``get_norm_layer``.
         in_channels: input channel count of ``conv1`` and ``conv3``.
         out_channels: output channel count of every convolution and the
             channel count handed to every norm layer.
         kernel_size: kernel size of ``conv1`` and ``conv2`` (``conv3``
             always uses 1).
         stride: stride of ``conv1`` and ``conv3`` (``conv2`` always uses 1).
         num_groups: forwarded to ``get_conv_layer`` only; presumably the
             convolution group count -- confirm against that helper.
         norm_name: norm selector forwarded to the parent and to
             ``get_norm_layer``.
         is_prunable: forwarded to ``get_conv_layer`` only; its effect is
             defined by that helper.
     """
     super(UnetResBlockEx, self).__init__(
         spatial_dims=spatial_dims,
         in_channels=in_channels,
         out_channels=out_channels,
         kernel_size=kernel_size,
         stride=stride,
         norm_name=norm_name,
     )

     # Keyword options common to all three convolutions.
     conv_opts = dict(
         conv_only=True,
         num_groups=num_groups,
         is_prunable=is_prunable,
     )

     # Creation order is kept as conv1 -> conv2 -> conv3 so submodule
     # registration order stays the same.
     self.conv1 = get_conv_layer(
         spatial_dims, in_channels, out_channels,
         kernel_size=kernel_size, stride=stride, **conv_opts)
     self.conv2 = get_conv_layer(
         spatial_dims, out_channels, out_channels,
         kernel_size=kernel_size, stride=1, **conv_opts)
     # NOTE(review): conv3/norm3 look like the shortcut projection paired
     # with `downsample`; the forward pass is not visible here -- confirm.
     self.conv3 = get_conv_layer(
         spatial_dims, in_channels, out_channels,
         kernel_size=1, stride=stride, **conv_opts)

     lrelu_spec = ("leakyrelu", {"inplace": True, "negative_slope": 0.01})
     self.lrelu = get_act_layer(lrelu_spec)

     self.norm1 = get_norm_layer(spatial_dims, out_channels, norm_name)
     self.norm2 = get_norm_layer(spatial_dims, out_channels, norm_name)
     self.norm3 = get_norm_layer(spatial_dims, out_channels, norm_name)

     # The flag is set when the channel count changes or when any stride
     # component differs from 1 (stride may be a scalar or a sequence).
     channels_differ = in_channels != out_channels
     has_non_unit_stride = not np.all(np.atleast_1d(stride) == 1)
     self.downsample = True if has_non_unit_stride else channels_differ
Example #2
0
 def test_suggested(self):
     """A misspelled norm name must raise ValueError with a 'GROUP' hint."""
     # The pattern is applied as a regular expression, so the trailing
     # "'?" makes the closing quote optional instead of matching a
     # literal question mark.
     hint_pattern = "did you mean 'GROUP'?"
     with self.assertRaisesRegex(ValueError, hint_pattern):
         get_norm_layer(name="grop", spatial_dims=2)
Example #3
0
 def test_norm_layer(self, input_param, expected):
     """Check the text rendering of the layer built from ``input_param``.

     ``input_param`` is unpacked as keyword arguments to ``get_norm_layer``;
     ``expected`` is the exact string the resulting layer must format to.
     """
     built = get_norm_layer(**input_param)
     rendered = format(built)
     self.assertEqual(rendered, expected)