def test_unet2d_decode(num_patches: int, image_shape: TupleInt3) -> None:
    """
    Test if the Decode block of a UNet3D creates tensors of the expected size
    when the kernels only operate in X and Y.
    """
    set_random_seed(1234)
    channels_in = image_shape[0]
    channels_out = channels_in // 2
    decode = UNet2D.UNetDecodeBlock((channels_in, channels_out),
                                    upsample_kernel_size=(1, 4, 4),
                                    upsampling_stride=(1, 2, 2))
    synthesis = UNet2D.UNetEncodeBlockSynthesis(channels=(channels_out, channels_out),
                                                kernel_size=(1, 3, 3))
    # UNet2D runs over degenerate 3D images with a single Z slice.
    dim_z = 1
    batch = torch.rand(num_patches, channels_in, dim_z,
                       image_shape[1], image_shape[2]).float()
    # The skip connection already lives at the upsampled (doubled) resolution.
    skip = torch.zeros((num_patches, channels_out, dim_z,
                        image_shape[1] * 2, image_shape[2] * 2))
    result = synthesis(decode(batch), skip)
    # Expected output shape:
    # - the patches dimension is retained unchanged,
    # - we get as many output channels as requested,
    # - the degenerate Z=1 dimension is preserved,
    # - both trailing image dimensions are doubled by the upsampling stride.
    expected = (num_patches, channels_out, dim_z,
                image_shape[1] * 2, image_shape[2] * 2)
    assert result.shape == expected
def build_net(args: SegmentationModelBase) -> BaseSegmentationModel:
    """
    Build network architectures

    :param args: Network configuration arguments
    """
    channels = [args.number_of_image_channels,
                *args.feature_channels,
                args.number_of_classes]
    initial_fcn = [BasicLayer] * 2
    residual_blocks = [[BasicLayer, BasicLayer]] * 3
    basic_network_definition = initial_fcn + residual_blocks  # type: ignore
    # No dilation for the initial FCN, then a constant 1-neighbourhood dilation
    # for the remaining residual blocks.
    basic_dilations = [1] * len(initial_fcn) + [2, 2] * len(basic_network_definition)
    # Crop size must be at least 29 because all architectures (apart from UNets)
    # shrink the input image by 28.
    crop_size_constraints = CropSizeConstraints(minimum_size=basic_size_shrinkage + 1)
    run_weight_initialization = True

    network: BaseSegmentationModel
    arch = args.architecture
    if arch == ModelArchitectureConfig.Basic:
        network = ComplexModel(args,
                               channels,
                               basic_dilations,
                               basic_network_definition,
                               crop_size_constraints)  # type: ignore
    elif arch == ModelArchitectureConfig.UNet3D:
        network = UNet3D(input_image_channels=args.number_of_image_channels,
                         initial_feature_channels=args.feature_channels[0],
                         num_classes=args.number_of_classes,
                         kernel_size=args.kernel_size,
                         num_downsampling_paths=args.num_downsampling_paths)
        # UNets come with their own initialization scheme.
        run_weight_initialization = False
    elif arch == ModelArchitectureConfig.UNet2D:
        network = UNet2D(input_image_channels=args.number_of_image_channels,
                         initial_feature_channels=args.feature_channels[0],
                         num_classes=args.number_of_classes,
                         padding_mode=PaddingMode.Edge,
                         num_downsampling_paths=args.num_downsampling_paths)
        run_weight_initialization = False
    else:
        raise ValueError(f"Unknown model architecture {args.architecture}")

    network.validate_crop_size(args.crop_size, "Training crop size")
    network.validate_crop_size(args.test_crop_size, "Test crop size")  # type: ignore
    # Initialize network weights
    if run_weight_initialization:
        network.apply(init_weights)  # type: ignore
    return network