Beispiel #1
0
def test_unet2d_decode(num_patches: int,
                       image_shape: TupleInt3) -> None:
    """
    Check the output shape obtained by chaining a UNet2D.UNetDecodeBlock and a UNet2D.UNetEncodeBlockSynthesis
    whose kernels have size 1 along the first spatial axis (Z), i.e. that only operate in X and Y.

    :param num_patches: Size of the leading (batch) dimension of the input tensor.
    :param image_shape: Input channels at index 0, followed by the two in-plane image sizes.
    """
    set_random_seed(1234)
    in_channels = image_shape[0]
    out_channels = in_channels // 2
    # The layers are created before the random input so that the random state is consumed in a fixed order.
    decoder = UNet2D.UNetDecodeBlock((in_channels, out_channels),
                                     upsample_kernel_size=(1, 4, 4),
                                     upsampling_stride=(1, 2, 2))
    synthesis = UNet2D.UNetEncodeBlockSynthesis(channels=(out_channels, out_channels),
                                                kernel_size=(1, 3, 3))

    # The blocks consume 5D tensors; the images used here are degenerate in Z.
    depth = 1
    plane_dims = (image_shape[1], image_shape[2])
    # The test expects both in-plane dimensions to be doubled by the decode block.
    upsampled_dims = tuple(d * 2 for d in plane_dims)

    patch_shape = (num_patches, in_channels, depth, *plane_dims)
    patches = torch.rand(patch_shape).float()
    skip = torch.zeros((num_patches, out_channels, depth, *upsampled_dims))
    upsampled = decoder(patches)
    result = synthesis(upsampled, skip)

    # Expected layout of the result:
    #   - number of patches: unchanged
    #   - channels: the requested number of output channels
    #   - Z: still 1
    #   - trailing two dimensions: the doubled in-plane sizes
    expected = (num_patches, out_channels, depth, *upsampled_dims)
    assert result.shape == expected
Beispiel #2
0
def build_net(args: SegmentationModelBase) -> BaseSegmentationModel:
    """
    Create the segmentation network selected by args.architecture.

    After construction, args.crop_size and args.test_crop_size are validated against the network.
    The init_weights pass is applied here only for the Basic architecture; the UNet3D and UNet2D networks
    are returned without it.

    :param args: Network configuration arguments
    :return: The network built for the requested architecture.
    :raises ValueError: If args.architecture is not one of Basic, UNet3D or UNet2D.
    """
    # Channel counts in order: image channels, then every feature channel entry, then the number of classes.
    channels = [args.number_of_image_channels]
    channels.extend(args.feature_channels)
    channels.append(args.number_of_classes)

    # Definition used by the Basic architecture: two plain layers followed by three two-layer blocks.
    # The same pair object is repeated three times.
    fcn_layers = [BasicLayer, BasicLayer]
    residual_pair = [BasicLayer, BasicLayer]
    layer_definition = fcn_layers + [residual_pair] * 3  # type: ignore
    # One dilation of 1 per initial layer, then two dilations of 2 for every entry of the definition.
    dilations = [1 for _ in fcn_layers] + [2] * (2 * len(layer_definition))
    # Minimum crop size is one more than basic_size_shrinkage (defined outside of this function); the
    # constraint object is only handed to the Basic architecture below.
    constraints = CropSizeConstraints(minimum_size=basic_size_shrinkage + 1)

    architecture = args.architecture
    network: BaseSegmentationModel
    if architecture == ModelArchitectureConfig.Basic:
        network = ComplexModel(args, channels, dilations, layer_definition, constraints)  # type: ignore
        needs_weight_init = True
    elif architecture == ModelArchitectureConfig.UNet3D:
        network = UNet3D(
            input_image_channels=args.number_of_image_channels,
            initial_feature_channels=args.feature_channels[0],
            num_classes=args.number_of_classes,
            kernel_size=args.kernel_size,
            num_downsampling_paths=args.num_downsampling_paths,
        )
        needs_weight_init = False
    elif architecture == ModelArchitectureConfig.UNet2D:
        network = UNet2D(
            input_image_channels=args.number_of_image_channels,
            initial_feature_channels=args.feature_channels[0],
            num_classes=args.number_of_classes,
            padding_mode=PaddingMode.Edge,
            num_downsampling_paths=args.num_downsampling_paths,
        )
        needs_weight_init = False
    else:
        raise ValueError(f"Unknown model architecture {architecture}")

    # Both the training and the test crop size must be acceptable to the chosen network.
    network.validate_crop_size(args.crop_size, "Training crop size")
    network.validate_crop_size(args.test_crop_size, "Test crop size")  # type: ignore

    if needs_weight_init:
        # Initialize network weights
        network.apply(init_weights)  # type: ignore
    return network