def test_shape(self, input_param, input_shape, expected_shape):
    """Check that both block variants produce ``expected_shape``.

    A residual block and a basic block are built from the same
    ``input_param``. Each is run once on a random tensor of ``input_shape``
    inside ``eval_mode``, and its output shape is compared with
    ``expected_shape``.
    """
    # Both networks are constructed up front, before any forward pass.
    candidates = [UnetResBlock(**input_param), UnetBasicBlock(**input_param)]
    for block in candidates:
        with eval_mode(block):
            output = block(torch.randn(input_shape))
            self.assertEqual(output.shape, expected_shape)
def test_shape(self, input_param, input_shape, expected_shape):
    """Check that both block variants produce ``expected_shape``.

    A residual block and a basic block are built from the same
    ``input_param``. Each is put in eval mode and run once on a random
    tensor of ``input_shape`` with autograd disabled.
    """
    residual = UnetResBlock(**input_param)
    basic = UnetBasicBlock(**input_param)
    for model in (residual, basic):
        model.eval()
        with torch.no_grad():
            prediction = model(torch.randn(input_shape))
        self.assertEqual(prediction.shape, expected_shape)
def __init__(
    self,
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int],
    upsample_kernel_size: Union[Sequence[int], int],
    norm_name: Union[Tuple, str],
    res_block: bool = False,
) -> None:
    """
    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        upsample_kernel_size: convolution kernel size for transposed convolution layers.
        norm_name: feature normalization type and arguments.
        res_block: bool argument to determine if residual block is used.
    """
    super().__init__()
    # The transposed convolution uses upsample_kernel_size as both its
    # kernel size and its stride.
    self.transp_conv = get_conv_layer(
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=upsample_kernel_size,
        stride=upsample_kernel_size,
        conv_only=True,
        is_transposed=True,
    )
    # The conv block takes twice out_channels as input. That presumably
    # means the upsampled features concatenated with a same-width skip
    # tensor; forward() is not visible here, so confirm there.
    merged_channels = out_channels + out_channels
    block_type = UnetResBlock if res_block else UnetBasicBlock
    self.conv_block = block_type(  # type: ignore
        spatial_dims,
        merged_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=1,
        norm_name=norm_name,
    )
def __init__(
    self,
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int],
    stride: Union[Sequence[int], int],
    norm_name: Union[Tuple, str],
    res_block: bool = False,
) -> None:
    """
    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm_name: feature normalization type and arguments.
        res_block: bool argument to determine if residual block is used.
    """
    super().__init__()
    # Both block variants are constructed with exactly the same keyword
    # arguments; res_block only decides which class is instantiated.
    block_kwargs = dict(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        norm_name=norm_name,
    )
    block_type = UnetResBlock if res_block else UnetBasicBlock
    self.layer = block_type(**block_kwargs)  # type: ignore
def test_ill_arg(self):
    """Invalid constructor arguments must be rejected by both block types."""
    # "norm" is expected to be an unrecognised normalization name.
    with self.assertRaises(ValueError):
        UnetBasicBlock(
            spatial_dims=3,
            in_channels=4,
            out_channels=2,
            kernel_size=3,
            stride=1,
            norm_name="norm",
        )
    # A kernel (1) smaller than the stride (4) is expected to trip an
    # internal assertion, presumably in the padding computation.
    with self.assertRaises(AssertionError):
        UnetResBlock(
            spatial_dims=3,
            in_channels=4,
            out_channels=2,
            kernel_size=1,
            stride=4,
            norm_name="batch",
        )
def __init__(
    self,
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    num_layer: int,
    kernel_size: Union[Sequence[int], int],
    stride: Union[Sequence[int], int],
    upsample_kernel_size: Union[Sequence[int], int],
    norm_name: Union[Tuple, str],
    conv_block: bool = False,
    res_block: bool = False,
) -> None:
    """
    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_layer: number of upsampling blocks.
        kernel_size: convolution kernel size (only used when ``conv_block`` is True).
        stride: convolution stride (only used when ``conv_block`` is True).
        upsample_kernel_size: convolution kernel size for transposed convolution layers.
        norm_name: feature normalization type and arguments.
        conv_block: bool argument to determine if convolutional block is used.
        res_block: bool argument to determine if residual block is used.
    """
    super().__init__()
    upsample_stride = upsample_kernel_size

    def _transp_conv(in_ch: int) -> nn.Module:
        # One transposed convolution from ``in_ch`` to ``out_channels``.
        # Its kernel size and stride are both upsample_kernel_size.
        return get_conv_layer(
            spatial_dims,
            in_ch,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

    # First stage maps in_channels -> out_channels; all later stages stay at out_channels.
    self.transp_conv_init = _transp_conv(in_channels)

    if conv_block:
        # The residual and basic variants differ only in the block class.
        block_type = UnetResBlock if res_block else UnetBasicBlock
        # Inside each nn.Sequential the transposed conv is still built
        # before its conv block, matching the original construction order.
        self.blocks = nn.ModuleList(
            [
                nn.Sequential(
                    _transp_conv(out_channels),
                    block_type(  # type: ignore
                        spatial_dims=spatial_dims,
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        norm_name=norm_name,
                    ),
                )
                for _ in range(num_layer)
            ]
        )
    else:
        self.blocks = nn.ModuleList([_transp_conv(out_channels) for _ in range(num_layer)])