Example #1
def check(crop_size: TupleInt3,
          multiple_of: Optional[IntOrTuple3] = None,
          minimum: Optional[IntOrTuple3] = None) -> CropSizeConstraints:
    constraints = CropSizeConstraints(minimum_size=minimum,
                                      multiple_of=multiple_of)
    constraints.validate(crop_size)
    return constraints
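For context, a brief usage sketch of this helper; the values are illustrative, and TupleInt3, IntOrTuple3, and CropSizeConstraints come from the surrounding codebase:

# Every dimension of (32, 64, 64) is a multiple of 16, so this validates
constraints = check((32, 64, 64), multiple_of=16)
# A scalar multiple_of is expanded along all dimensions (see Example #13)
assert constraints.multiple_of == (16, 16, 16)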
Example #2
def test_restrict_crop_size_too_small() -> None:
    """Test the modification of crop sizes when the image size is below the minimum."""
    shape = (10, 30, 40)
    crop_size = (20, 40, 20)
    stride = (10, 20, 20)
    constraint = CropSizeConstraints(multiple_of=16)
    with pytest.raises(ValueError) as e:
        constraint.restrict_crop_size_to_image(shape, crop_size, stride)
    # Error message should contain the actual image size
    assert str(shape) in e.value.args[0]
    assert "16" in e.value.args[0]
Example #3
def test_crop_size_constructor() -> None:
    """
    Test error handling in the constructor of CropSizeConstraints
    """
    # Minimum size is given as 3-tuple, but working with 2 dimensions:
    with pytest.raises(ValueError) as err:
        CropSizeConstraints(minimum_size=(1, 2, 3), num_dimensions=2)
    assert "must have length 2" in str(err)
    # Minimum size and multiple_of are not compatible:
    with pytest.raises(ValueError) as err:
        CropSizeConstraints(minimum_size=(1, 2, 3), multiple_of=16)
    assert "The minimum size must be at least as large" in str(err)
Example #4
def test_restrict_crop_size_large_image() -> None:
    """Test the modification of crop sizes when the image size is larger than the crop:
    The crop and stride should be returned unchanged."""
    shape = (30, 50, 50)
    crop_size = (20, 40, 40)
    stride = crop_size
    constraint = CropSizeConstraints(multiple_of=1)
    expected_crop = crop_size
    expected_stride = stride
    check_restrict_crop(constraint, shape, crop_size, stride, expected_crop, expected_stride)
    constraint2 = CropSizeConstraints(multiple_of=1, minimum_size=999)
    with pytest.raises(ValueError) as err:
        check_restrict_crop(constraint2, shape, crop_size, stride, expected_crop, expected_stride)
    assert "at least a size of" in str(err)
    assert "999" in str(err)
Example #5
def create_model(crop_size: TupleInt3,
                 multiple_of: IntOrTuple3) -> BaseModel:
    model = SimpleModel(
        1, [1],
        2,
        2,
        crop_size_constraints=CropSizeConstraints(multiple_of=multiple_of))
    model.validate_crop_size(crop_size)
    return model
Example #6
def test_model_summary_on_minimum_crop_size() -> None:
    """
    Test that a model summary is generated from the constraint's minimum crop size when no explicit crop size is given.
    """
    model = MyFavModel()
    min_crop_size = (5, 6, 7)
    model.crop_size_constraints = CropSizeConstraints(minimum_size=min_crop_size)
    model.generate_model_summary()
    assert model.summary_crop_size == min_crop_size
    assert model.summary is not None
Example #7
def check_restrict_crop(constraint: CropSizeConstraints, shape: TupleInt3,
                        crop_size: TupleInt3, stride: TupleInt3,
                        expected_crop: TupleInt3,
                        expected_stride: TupleInt3) -> None:
    (crop_new, stride_new) = constraint.restrict_crop_size_to_image(
        shape, crop_size, stride)
    assert crop_new == expected_crop
    assert stride_new == expected_stride
    # Stride and crop must be integer tuples
    assert isinstance(crop_new[0], int)
    assert isinstance(stride_new[0], int)
Example #8
def test_restrict_crop_size_uneven() -> None:
    """
    Test a case when the image is larger than the crop in Z, but not in X and Y.
    """
    shape = (20, 30, 30)
    crop_size = (10, 60, 60)
    stride = crop_size
    constraint = CropSizeConstraints(multiple_of=1)
    expected_crop = (10, 30, 30)
    expected_stride = (10, 30, 30)
    check_restrict_crop(constraint, shape, crop_size, stride, expected_crop, expected_stride)
Example #9
def test_restrict_crop_size_tuple() -> None:
    """Test the modification of crop sizes when the image size is smaller than the crop,
    and the crop_multiple is a tuple with element-wise multiples."""
    shape = (20, 35, 40)
    crop_size = (25, 40, 20)
    stride = (10, 20, 20)
    constraint = CropSizeConstraints(multiple_of=(1, 16, 16))
    # Expected new crop size is the elementwise minimum of crop_size and shape,
    # rounded up to the nearest multiple of 16, apart from dimension 0
    expected_crop = (20, 48, 32)
    # Stride keeps (elementwise) the same ratio to the crop as before
    expected_stride = (8, 24, 32)
    check_restrict_crop(constraint, shape, crop_size, stride, expected_crop, expected_stride)
Example #10
def test_restrict_crop_size() -> None:
    """Test the modification of crop sizes when the image size is smaller than the crop."""
    shape = (20, 35, 40)
    crop_size = (25, 40, 20)
    stride = (10, 20, 20)
    constraint = CropSizeConstraints(multiple_of=16)
    # Expected new crop size is the elementwise minimum of crop_size and shape,
    # rounded up to the nearest multiple of 16
    expected_crop = (32, 48, 32)
    # Stride should maintain (elementwise) the same ratio to crop as before
    expected_stride = (12, 24, 32)
    check_restrict_crop(constraint, shape, crop_size, stride, expected_crop, expected_stride)
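Taken together, the three tests pin down the restriction behaviour. Below is a minimal sketch that reproduces all of the expected values above; restrict_crop_sketch is a hypothetical helper written for illustration, not the library's restrict_crop_size_to_image:

import math

def restrict_crop_sketch(shape, crop_size, stride, multiple_of):
    # Shrink the crop to the image (elementwise), then round up to the nearest multiple
    new_crop = tuple(math.ceil(min(c, s) / m) * m
                     for c, s, m in zip(crop_size, shape, multiple_of))
    # The stride keeps (elementwise) the same ratio to the crop as before
    new_stride = tuple(int(st * nc / c)
                       for st, nc, c in zip(stride, new_crop, crop_size))
    return new_crop, new_stride

result = restrict_crop_sketch((20, 35, 40), (25, 40, 20), (10, 20, 20), (16, 16, 16))
assert result == ((32, 48, 32), (12, 24, 32))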
Example #11
def build_net(args: SegmentationModelBase) -> BaseSegmentationModel:
    """
    Build the network architecture that the configuration arguments specify.

    :param args: Network configuration arguments
    """
    full_channels_list = [
        args.number_of_image_channels, *args.feature_channels,
        args.number_of_classes
    ]
    initial_fcn = [BasicLayer] * 2
    residual_blocks = [[BasicLayer, BasicLayer]] * 3
    basic_network_definition = initial_fcn + residual_blocks  # type: ignore
    # No dilation for the initial FCN layers, then a constant dilation of 2 for the remaining residual blocks
    basic_dilations = [1] * len(initial_fcn) + [2, 2] * len(
        basic_network_definition)
    # Crop size must be at least 29 because all architectures (apart from UNets) shrink the input image by 28
    crop_size_constraints = CropSizeConstraints(
        minimum_size=basic_size_shrinkage + 1)
    run_weight_initialization = True

    network: BaseSegmentationModel
    if args.architecture == ModelArchitectureConfig.Basic:
        network_definition = basic_network_definition
        network = ComplexModel(args, full_channels_list, basic_dilations,
                               network_definition,
                               crop_size_constraints)  # type: ignore

    elif args.architecture == ModelArchitectureConfig.UNet3D:
        network = UNet3D(input_image_channels=args.number_of_image_channels,
                         initial_feature_channels=args.feature_channels[0],
                         num_classes=args.number_of_classes,
                         kernel_size=args.kernel_size,
                         num_downsampling_paths=args.num_downsampling_paths)
        run_weight_initialization = False

    elif args.architecture == ModelArchitectureConfig.UNet2D:
        network = UNet2D(input_image_channels=args.number_of_image_channels,
                         initial_feature_channels=args.feature_channels[0],
                         num_classes=args.number_of_classes,
                         padding_mode=PaddingMode.Edge,
                         num_downsampling_paths=args.num_downsampling_paths)
        run_weight_initialization = False

    else:
        raise ValueError(f"Unknown model architecture {args.architecture}")
    network.validate_crop_size(args.crop_size, "Training crop size")
    network.validate_crop_size(args.test_crop_size,
                               "Test crop size")  # type: ignore
    # Initialize network weights
    if run_weight_initialization:
        network.apply(init_weights)  # type: ignore
    return network
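A small illustration of the minimum-size constraint built above. basic_size_shrinkage is defined elsewhere in the codebase; the value 28 below is an assumption taken from the comment in build_net:

basic_size_shrinkage = 28  # assumption, per the comment above; the real constant lives elsewhere
constraints = CropSizeConstraints(minimum_size=basic_size_shrinkage + 1)
# A scalar minimum expands across all three dimensions (see Example #13), so the
# constraint is (29, 29, 29): a crop such as (29, 64, 64) validates, while anything
# smaller in any dimension raises a ValueError.
constraints.validate((29, 64, 64))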
Example #12
def __init__(self, input_channels: Any, channels: Any, n_classes: int, kernel_size: int):
    # Minimum crop size: assuming kernel_size=3, the network first reduces the size by 2,
    # then roughly halves it, then doubles it and adds 1: 64 -> 62 -> 30 -> 61 -> 61
    super().__init__(name='SimpleModel',
                     input_channels=input_channels,
                     crop_size_constraints=CropSizeConstraints(minimum_size=6))
    self.channels = channels
    self.n_classes = n_classes
    self.kernel_size = kernel_size
    self._model = torch.nn.Sequential(
        torch.nn.Conv3d(input_channels, channels[0], kernel_size=self.kernel_size),
        torch.nn.Conv3d(channels[0], channels[1], kernel_size=self.kernel_size, stride=2),
        torch.nn.ConvTranspose3d(channels[1], channels[0], kernel_size=self.kernel_size, stride=2),
        torch.nn.ConvTranspose3d(channels[0], n_classes, kernel_size=1)
    )
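The size arithmetic in the comment can be replayed with the standard output-size formulas for unpadded convolutions, floor((n - k) / stride) + 1, and transposed convolutions, (n - 1) * stride + k. A quick sketch, assuming kernel_size=3 (conv_out and conv_transpose_out are hypothetical helpers):

def conv_out(n: int, k: int, stride: int = 1) -> int:
    # Output size of an unpadded convolution
    return (n - k) // stride + 1

def conv_transpose_out(n: int, k: int, stride: int = 1) -> int:
    # Output size of an unpadded transposed convolution
    return (n - 1) * stride + k

n = conv_out(64, 3)                     # 64 -> 62
n = conv_out(n, 3, stride=2)            # 62 -> 30
n = conv_transpose_out(n, 3, stride=2)  # 30 -> 61
n = conv_out(n, 1)                      # 61 -> 61
assert n == 61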
Example #13
def test_crop_size_constraints() -> None:
    """
    Test the basic logic to validate a crop size inside of a CropSizeConstraints instance.
    """
    def check(crop_size: TupleInt3,
              multiple_of: Optional[IntOrTuple3] = None,
              minimum: Optional[IntOrTuple3] = None) -> CropSizeConstraints:
        constraints = CropSizeConstraints(minimum_size=minimum,
                                          multiple_of=multiple_of)
        constraints.validate(crop_size)
        return constraints

    # crop_size_multiple == 1: Any crop size is allowed.
    c = check((9, 10, 11), multiple_of=1)
    # If a scalar multiple_of is used, it should be stored expanded along dimensions
    assert c.multiple_of == (1, 1, 1)
    # Using a tuple multiple_of: Crop is twice as large as multiple_of in each dimension, this is hence valid.
    c = check((10, 12, 14), multiple_of=(5, 6, 7))
    assert c.multiple_of == (5, 6, 7)
    # Minimum size has not been provided, should default to multiple_of.
    assert c.minimum_size == (5, 6, 7)
    # Try with a couple more common crop sizes
    check((32, 64, 64), multiple_of=16)
    check((1, 64, 64), multiple_of=(1, 16, 16))
    # Crop size is at the allowed minimum: This is valid
    check((3, 4, 5), multiple_of=(3, 4, 5))
    # Provide a scalar minimum: Should be stored expanded into 3 dimensions
    c = check((9, 10, 11), minimum=2)
    assert c.minimum_size == (2, 2, 2)
    assert c.multiple_of is None
    # Provide a tuple minimum
    c = check((9, 10, 11), minimum=(5, 6, 7))
    assert c.minimum_size == (5, 6, 7)
    # A crop size at exactly the minimum is valid
    check((5, 6, 7), minimum=(5, 6, 7))
    # Checking for minimum and multiple at the same time
    check((9, 10, 11), minimum=1, multiple_of=1)
    check((10, 12, 14), minimum=(5, 6, 7), multiple_of=2)

    def assert_fails(crop_size: TupleInt3,
                     multiple_of: Optional[IntOrTuple3] = None,
                     minimum: Optional[IntOrTuple3] = None) -> None:
        with pytest.raises(ValueError) as err:
            check(crop_size, multiple_of, minimum)
        assert str(crop_size) in str(err)
        assert "Crop size is not valid" in str(err)

    # Crop size is not a multiple of 2 in dimensions 0 and 2
    assert_fails((3, 4, 5), 2)
    # Crop size is not a multiple of 6 in dimension 2
    assert_fails((3, 4, 5), (3, 4, 6))
    # Crop size is not a multiple of 16 in dimension 2
    assert_fails((16, 16, 200), 16)
    # Crop size is too small
    assert_fails((1, 2, 3), (10, 10, 10))
    assert_fails((10, 20, 30), multiple_of=10, minimum=20)

    # Minimum size must be at least as large as multiple_of:
    with pytest.raises(ValueError) as err:
        CropSizeConstraints(minimum_size=10, multiple_of=20, num_dimensions=2)
    assert "minimum size must be at least as large as the multiple_of" in str(
        err)
    assert "(10, 10)" in str(err)
    assert "(20, 20)" in str(err)

    # Check that num_dimensions is working as expected (though not presently used in the codebase)
    c = CropSizeConstraints(minimum_size=30, multiple_of=10, num_dimensions=2)
    assert c.multiple_of == (10, 10)
    assert c.minimum_size == (30, 30)
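A hypothetical restatement of the validation rule these tests exercise (a sketch for illustration, not the library's implementation):

def validate_sketch(crop_size, minimum_size, multiple_of=None):
    # A crop is invalid if any dimension is below the minimum ...
    if any(c < m for c, m in zip(crop_size, minimum_size)):
        raise ValueError(f"Crop size is not valid: {crop_size} is below the minimum {minimum_size}")
    # ... or not divisible by the required multiple
    if multiple_of is not None and any(c % m != 0 for c, m in zip(crop_size, multiple_of)):
        raise ValueError(f"Crop size is not valid: {crop_size} is not a multiple of {multiple_of}")

validate_sketch((32, 64, 64), (16, 16, 16), (16, 16, 16))  # passes silently
# validate_sketch((3, 4, 5), (2, 2, 2), (2, 2, 2)) would raise: dimensions 0 and 2 are odd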
Example #14
    def __init__(self,
                 input_image_channels: int,
                 initial_feature_channels: int,
                 num_classes: int,
                 kernel_size: IntOrTuple3,
                 name: str = "UNet3D",
                 num_downsampling_paths: int = 4,
                 downsampling_factor: IntOrTuple3 = 2,
                 downsampling_dilation: IntOrTuple3 = (1, 1, 1),
                 padding_mode: PaddingMode = PaddingMode.Zero):
        """
        Modified 3D UNet.

        :param input_image_channels: Number of image channels (scans) that are fed into the model.
        :param initial_feature_channels: Number of feature maps used in the first layer of the model. Subsequent
            layers contain a multiple of `initial_feature_channels` feature maps
            (e.g. 2^(image_level) * initial_feature_channels).
        :param num_classes: Number of output classes.
        :param kernel_size: Spatial support of the convolution kernels in each spatial axis.
        :param name: Name of the model.
        :param num_downsampling_paths: Number of image levels used in the UNet (in the encoding and decoding paths).
        :param downsampling_factor: Spatial downsampling factor for each tensor axis (depth, width, height). This is
            used as the stride for the first convolution layer in each encoder block.
        :param downsampling_dilation: An additional dilation that is used in the second convolution layer in each
            of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A
            good choice is (1, 2, 2), to increase the receptive field only in X and Y.
        :param padding_mode: The padding mode that is applied in the convolution layers.
        """
        if isinstance(downsampling_factor, int):
            downsampling_factor = (downsampling_factor,) * 3
        # Each downsampling level divides the spatial size by the downsampling factor, so valid
        # crop sizes must be multiples of factor ** num_downsampling_paths in each dimension.
        crop_size_multiple = tuple(factor ** num_downsampling_paths
                                   for factor in downsampling_factor)
        crop_size_constraints = CropSizeConstraints(multiple_of=crop_size_multiple)
        super().__init__(name=name,
                         input_channels=input_image_channels,
                         crop_size_constraints=crop_size_constraints)

        self.num_dimensions = 3
        self._layers = torch.nn.ModuleList()
        self.upsampling_kernel_size = get_upsampling_kernel_size(downsampling_factor, self.num_dimensions)

        # Create forward blocks for the encoding side, including central part
        self._layers.append(UNet3D.UNetEncodeBlock((self.input_channels, self.initial_feature_channels),
                                                   kernel_size=self.kernel_size,
                                                   downsampling_stride=1,
                                                   padding_mode=self.padding_mode,
                                                   depth=0))

        current_channels = self.initial_feature_channels
        for depth in range(1, self.num_downsampling_paths + 1):  # type: ignore
            self._layers.append(UNet3D.UNetEncodeBlock((current_channels, current_channels * 2),  # type: ignore
                                                       kernel_size=self.kernel_size,
                                                       downsampling_stride=self.downsampling_factor,
                                                       dilation=self.downsampling_dilation,
                                                       padding_mode=self.padding_mode,
                                                       depth=depth))
            current_channels *= 2  # type: ignore

        # Create forward blocks and upsampling layers for the decoding side
        for depth in range(self.num_downsampling_paths + 1, 1, -1):  # type: ignore
            channels = (current_channels, current_channels // 2)  # type: ignore
            self._layers.append(UNet3D.UNetDecodeBlock(channels,
                                                       upsample_kernel_size=self.upsampling_kernel_size,
                                                       upsampling_stride=self.downsampling_factor))

            # Use negative depth to distinguish the encode blocks in the decoding pathway.
            self._layers.append(UNet3D.UNetEncodeBlockSynthesis(channels=(channels[1], channels[1]),
                                                                kernel_size=self.kernel_size,
                                                                padding_mode=self.padding_mode,
                                                                depth=-depth))

            current_channels //= 2  # type: ignore

        # Add final fc layer
        self.output_layer = Conv3d(current_channels, self.num_classes, kernel_size=1)  # type: ignore
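The crop-size constraint follows directly from the architecture: each of the num_downsampling_paths encoder blocks divides the spatial size by downsampling_factor, so a valid crop must be divisible by factor ** num_downsampling_paths. With the defaults above (downsampling_factor=2, num_downsampling_paths=4):

downsampling_factor = (2, 2, 2)
num_downsampling_paths = 4
crop_size_multiple = tuple(factor ** num_downsampling_paths for factor in downsampling_factor)
assert crop_size_multiple == (16, 16, 16)
# Hence crops such as (32, 64, 64) validate, matching the common sizes in Example #13.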