def check(crop_size: TupleInt3,
          multiple_of: Optional[IntOrTuple3] = None,
          minimum: Optional[IntOrTuple3] = None) -> CropSizeConstraints:
    """
    Build a CropSizeConstraints object from the given minimum size and multiple, run its validation on the
    given crop size, and hand the constraints object back to the caller.

    :param crop_size: The crop size that should be validated.
    :param multiple_of: Optional multiple that the crop size must adhere to (scalar or per dimension).
    :param minimum: Optional minimum crop size (scalar or per dimension).
    :return: The constraints object that was used for validation.
    """
    result = CropSizeConstraints(minimum_size=minimum, multiple_of=multiple_of)
    result.validate(crop_size)
    return result
def test_restrict_crop_size_too_small() -> None:
    """
    Restricting a crop to an image that is smaller than the crop size multiple (16, which also acts as the
    minimum when no explicit minimum is given) must raise a ValueError. The error message must mention both the
    image shape and the multiple.
    """
    image_shape = (10, 30, 40)
    requested_crop = (20, 40, 20)
    requested_stride = (10, 20, 20)
    constraint = CropSizeConstraints(multiple_of=16)
    with pytest.raises(ValueError) as raised:
        constraint.restrict_crop_size_to_image(image_shape, requested_crop, requested_stride)
    message = raised.value.args[0]
    # The offending image size and the multiple both need to show up in the error message.
    assert str(image_shape) in message
    assert "16" in message
def test_crop_size_constructor() -> None:
    """
    Test error handling in the constructor of CropSizeConstraints.
    """
    # Error text is read from err.value (the exception itself). str(err) on the ExceptionInfo object is only
    # its repr, which escapes special characters and adds traceback metadata.
    # Minimum size is given as 3-tuple, but working with 2 dimensions:
    with pytest.raises(ValueError) as err:
        CropSizeConstraints(minimum_size=(1, 2, 3), num_dimensions=2)
    assert "must have length 2" in str(err.value)
    # Minimum size and multiple_of are not compatible (minimum (1, 2, 3) is below the multiple 16):
    with pytest.raises(ValueError) as err:
        CropSizeConstraints(minimum_size=(1, 2, 3), multiple_of=16)
    assert "The minimum size must be at least as large" in str(err.value)
def test_restrict_crop_size_large_image() -> None:
    """
    Test the modification of crop sizes when the image size is larger than the crop: the crop and stride should
    be returned unchanged. Also test that a minimum size larger than the image is rejected.
    """
    shape = (30, 50, 50)
    crop_size = (20, 40, 40)
    stride = crop_size
    constraint = CropSizeConstraints(multiple_of=1)
    expected_crop = crop_size
    expected_stride = stride
    check_restrict_crop(constraint, shape, crop_size, stride, expected_crop, expected_stride)
    # A minimum size of 999 exceeds every image dimension, so restricting the crop must fail.
    constraint2 = CropSizeConstraints(multiple_of=1, minimum_size=999)
    with pytest.raises(ValueError) as err:
        check_restrict_crop(constraint2, shape, crop_size, stride, expected_crop, expected_stride)
    # Inspect the exception itself, not the repr of the pytest ExceptionInfo wrapper.
    assert "at least a size of" in str(err.value)
    assert "999" in str(err.value)
def create_model(crop_size: TupleInt3, multiple_of: IntOrTuple3) -> BaseModel:
    """
    Create a small SimpleModel whose crop size constraints require a multiple of `multiple_of`, and run the
    model's crop size validation on `crop_size`.

    :param crop_size: The crop size to validate against the model.
    :param multiple_of: The crop size multiple (scalar or per dimension) for the model's constraints.
    :return: The model, after validate_crop_size has been called on it.
    """
    constraints = CropSizeConstraints(multiple_of=multiple_of)
    result = SimpleModel(1, [1], 2, 2, crop_size_constraints=constraints)
    result.validate_crop_size(crop_size)
    return result
def test_model_summary_on_minimum_crop_size() -> None:
    """
    When no specific crop size is given, generating the model summary should use the minimum crop size from the
    model's constraints, and a summary must be produced.
    """
    minimum_crop = (5, 6, 7)
    model = MyFavModel()
    model.crop_size_constraints = CropSizeConstraints(minimum_size=minimum_crop)
    model.generate_model_summary()
    # The summary must have been computed on exactly the minimum crop size.
    assert model.summary_crop_size == minimum_crop
    assert model.summary is not None
def check_restrict_crop(constraint: CropSizeConstraints,
                        shape: TupleInt3,
                        crop_size: TupleInt3,
                        stride: TupleInt3,
                        expected_crop: TupleInt3,
                        expected_stride: TupleInt3) -> None:
    """
    Call constraint.restrict_crop_size_to_image and assert that the adjusted crop size and stride equal the
    expected values, and that every element of both results is a Python int.

    :param constraint: The crop size constraints under test.
    :param shape: The image shape that the crop should be restricted to.
    :param crop_size: The crop size before restriction.
    :param stride: The stride before restriction.
    :param expected_crop: The crop size that restrict_crop_size_to_image should return.
    :param expected_stride: The stride that restrict_crop_size_to_image should return.
    """
    (crop_new, stride_new) = constraint.restrict_crop_size_to_image(shape, crop_size, stride)
    assert crop_new == expected_crop
    assert stride_new == expected_stride
    # Stride and crop must be integer tuples. Every element is checked, not just the first one: the tuple
    # equality above treats 2.0 == 2, so a float in any dimension would otherwise slip through.
    assert all(isinstance(value, int) for value in crop_new), f"Crop size must only contain ints, got {crop_new}"
    assert all(isinstance(value, int) for value in stride_new), f"Stride must only contain ints, got {stride_new}"
def test_restrict_crop_size_uneven() -> None:
    """
    The image is larger than the crop along dimension 0 (Z) but smaller along X and Y. Only the X and Y entries of
    crop and stride should shrink to the image size.
    """
    image_shape = (20, 30, 30)
    requested_crop = (10, 60, 60)
    check_restrict_crop(CropSizeConstraints(multiple_of=1),
                        image_shape,
                        requested_crop,
                        requested_crop,
                        expected_crop=(10, 30, 30),
                        expected_stride=(10, 30, 30))
def test_restrict_crop_size_tuple() -> None:
    """
    Test the modification of crop sizes when the image size is smaller than the crop, and the crop_multiple is a
    tuple with element-wise multiples.
    """
    shape = (20, 35, 40)
    crop_size = (25, 40, 20)
    stride = (10, 20, 20)
    constraint = CropSizeConstraints(multiple_of=(1, 16, 16))
    # Expected new crop size is the elementwise minimum of crop_size and shape, i.e. (20, 35, 20), rounded UP
    # to the nearest multiple of 16 in dimensions 1 and 2 (35 -> 48, 20 -> 32). Dimension 0 has a multiple of 1
    # and therefore stays at 20.
    expected_crop = (20, 48, 32)
    # Stride should maintain (elementwise) the same ratio to crop as before: 10/25 * 20 = 8,
    # 20/40 * 48 = 24, 20/20 * 32 = 32.
    expected_stride = (8, 24, 32)
    check_restrict_crop(constraint, shape, crop_size, stride, expected_crop, expected_stride)
def test_restrict_crop_size() -> None:
    """
    Test how crop sizes are adjusted when the image is smaller than the requested crop, using a scalar multiple
    of 16 for all dimensions.
    """
    image_shape = (20, 35, 40)
    requested_crop = (25, 40, 20)
    requested_stride = (10, 20, 20)
    # The element-wise minimum of crop and image is (20, 35, 20). Each entry is then rounded up to the next
    # multiple of 16, giving (32, 48, 32).
    # The stride keeps its element-wise ratio to the crop (10/25, 20/40, 20/20); 0.4 * 32 = 12.8 ends up as 12.
    check_restrict_crop(CropSizeConstraints(multiple_of=16),
                        image_shape,
                        requested_crop,
                        requested_stride,
                        expected_crop=(32, 48, 32),
                        expected_stride=(12, 24, 32))
def build_net(args: SegmentationModelBase) -> BaseSegmentationModel:
    """
    Build the network architecture selected by args.architecture and validate the training and test crop sizes
    against it. Weight initialization is applied for the Basic architecture only.

    :param args: Network configuration arguments
    :return: The constructed network.
    :raises ValueError: If args.architecture is not Basic, UNet3D or UNet2D.
    """
    channels = [args.number_of_image_channels, *args.feature_channels, args.number_of_classes]
    # Two initial layers, followed by three residual blocks made of two layers each.
    fcn_layers = [BasicLayer] * 2
    residual_layers = [[BasicLayer, BasicLayer]] * 3
    layer_definition = fcn_layers + residual_layers  # type: ignore
    # Dilation 1 for each initial FCN layer, followed by repeated (2, 2) pairs for the rest.
    # NOTE(review): the number of (2, 2) pairs is len(layer_definition) (FCN layers + residual blocks = 5), not
    # len(residual_layers) (= 3), so the list is longer than the number of layers. Confirm how ComplexModel
    # consumes it.
    dilations = [1] * len(fcn_layers) + [2, 2] * len(layer_definition)
    # Crop size must be at least 29 because all architectures (apart from UNets) shrink the input image by 28.
    # basic_size_shrinkage is defined outside this function (presumably 28 - TODO confirm).
    constraints = CropSizeConstraints(minimum_size=basic_size_shrinkage + 1)

    network: BaseSegmentationModel
    if args.architecture == ModelArchitectureConfig.Basic:
        network = ComplexModel(args, channels, dilations, layer_definition, constraints)  # type: ignore
        needs_weight_init = True
    elif args.architecture == ModelArchitectureConfig.UNet3D:
        # The UNets do not receive the constraints built above and are not passed through init_weights.
        network = UNet3D(input_image_channels=args.number_of_image_channels,
                         initial_feature_channels=args.feature_channels[0],
                         num_classes=args.number_of_classes,
                         kernel_size=args.kernel_size,
                         num_downsampling_paths=args.num_downsampling_paths)
        needs_weight_init = False
    elif args.architecture == ModelArchitectureConfig.UNet2D:
        network = UNet2D(input_image_channels=args.number_of_image_channels,
                         initial_feature_channels=args.feature_channels[0],
                         num_classes=args.number_of_classes,
                         padding_mode=PaddingMode.Edge,
                         num_downsampling_paths=args.num_downsampling_paths)
        needs_weight_init = False
    else:
        raise ValueError(f"Unknown model architecture {args.architecture}")

    network.validate_crop_size(args.crop_size, "Training crop size")
    network.validate_crop_size(args.test_crop_size, "Test crop size")  # type: ignore
    if needs_weight_init:
        network.apply(init_weights)  # type: ignore
    return network
def __init__(self, input_channels: Any, channels: Any, n_classes: int, kernel_size: int):
    """
    A small 3D convolutional model: an unstrided convolution, a stride-2 convolution, a stride-2 transposed
    convolution, and a final 1x1x1 transposed convolution that maps to n_classes channels.

    :param input_channels: Number of input channels of the first convolution.
    :param channels: Intermediate channel counts. Only channels[0] and channels[1] are used.
    :param n_classes: Number of output channels of the final layer.
    :param kernel_size: Kernel size of the first three layers.
    """
    # Spatial size bookkeeping (no padding is passed to any layer):
    #   Conv3d:                       n -> n - (kernel_size - 1)
    #   Conv3d, stride 2:             n -> floor((n - kernel_size) / 2) + 1
    #   ConvTranspose3d, stride 2:    n -> 2 * n + kernel_size - 2
    #   ConvTranspose3d, kernel 1:    n -> n
    # For kernel_size=3: 64 -> 62 -> 30 -> 61 -> 61
    # NOTE(review): the origin of minimum_size=6 is not derivable from this code alone - confirm.
    super().__init__(name='SimpleModel', input_channels=input_channels,
                     crop_size_constraints=CropSizeConstraints(minimum_size=6))
    self.channels = channels
    self.n_classes = n_classes
    self.kernel_size = kernel_size
    self._model = torch.nn.Sequential(
        torch.nn.Conv3d(input_channels, channels[0], kernel_size=self.kernel_size),
        torch.nn.Conv3d(channels[0], channels[1], kernel_size=self.kernel_size, stride=2),
        torch.nn.ConvTranspose3d(channels[1], channels[0], kernel_size=self.kernel_size, stride=2),
        torch.nn.ConvTranspose3d(channels[0], n_classes, kernel_size=1)
    )
def test_crop_size_constraints() -> None:
    """
    Test the basic logic to validate a crop size inside of a CropSizeConstraints instance.
    """

    def check(crop_size: TupleInt3,
              multiple_of: Optional[IntOrTuple3] = None,
              minimum: Optional[IntOrTuple3] = None) -> CropSizeConstraints:
        # Build constraints, validate the crop size against them, and return the constraints for inspection.
        constraints = CropSizeConstraints(minimum_size=minimum, multiple_of=multiple_of)
        constraints.validate(crop_size)
        return constraints

    # crop_size_multiple == 1: Any crop size is allowed.
    c = check((9, 10, 11), multiple_of=1)
    # If a scalar multiple_of is used, it should be stored expanded along dimensions
    assert c.multiple_of == (1, 1, 1)
    # Using a tuple multiple_of: Crop is twice as large as multiple_of in each dimension, this is hence valid.
    c = check((10, 12, 14), multiple_of=(5, 6, 7))
    assert c.multiple_of == (5, 6, 7)
    # Minimum size has not been provided, should default to multiple_of.
    assert c.minimum_size == (5, 6, 7)
    # Try with a couple more common crop sizes
    check((32, 64, 64), multiple_of=16)
    check((1, 64, 64), multiple_of=(1, 16, 16))
    # Crop size is at the allowed minimum: This is valid
    check((3, 4, 5), multiple_of=(3, 4, 5))
    # Provide a scalar minimum: Should be stored expanded into 3 dimensions
    c = check((9, 10, 11), minimum=2)
    assert c.minimum_size == (2, 2, 2)
    assert c.multiple_of is None
    # Provide a tuple minimum
    c = check((9, 10, 11), minimum=(5, 6, 7))
    assert c.minimum_size == (5, 6, 7)
    # A crop size at exactly the minimum is valid
    check((5, 6, 7), minimum=(5, 6, 7))
    # Checking for minimum and multiple at the same time
    check((9, 10, 11), minimum=1, multiple_of=1)
    check((10, 12, 14), minimum=(5, 6, 7), multiple_of=2)

    def assert_fails(crop_size: TupleInt3,
                     multiple_of: Optional[IntOrTuple3] = None,
                     minimum: Optional[IntOrTuple3] = None) -> None:
        # Validation must raise, and the message must name the offending crop size. The exception itself
        # (err.value) is inspected, not the repr of the pytest ExceptionInfo wrapper.
        with pytest.raises(ValueError) as err:
            check(crop_size, multiple_of, minimum)
        assert str(crop_size) in str(err.value)
        assert "Crop size is not valid" in str(err.value)

    # Crop size is not a multiple of 2 in dimensions 0 and 2 (3 and 5 are odd)
    assert_fails((3, 4, 5), 2)
    # Crop size is not a multiple of 6 in dimension 2
    assert_fails((3, 4, 5), (3, 4, 6))
    # Crop size is not a multiple of 16 in dimension 2
    assert_fails((16, 16, 200), 16)
    # Crop size is too small (the minimum defaults to multiple_of when not given)
    assert_fails((1, 2, 3), (10, 10, 10))
    assert_fails((10, 20, 30), multiple_of=10, minimum=20)
    # Minimum size must be at least as large as multiple_of:
    with pytest.raises(ValueError) as err:
        CropSizeConstraints(minimum_size=10, multiple_of=20, num_dimensions=2)
    assert "minimum size must be at least as large as the multiple_of" in str(err.value)
    assert "(10, 10)" in str(err.value)
    assert "(20, 20)" in str(err.value)
    # Check that num_dimensions is working as expected (though not presently used in the codebase)
    c = CropSizeConstraints(minimum_size=30, multiple_of=10, num_dimensions=2)
    assert c.multiple_of == (10, 10)
    assert c.minimum_size == (30, 30)
def __init__(self,
             input_image_channels: int,
             initial_feature_channels: int,
             num_classes: int,
             kernel_size: IntOrTuple3,
             name: str = "UNet3D",
             num_downsampling_paths: int = 4,
             downsampling_factor: IntOrTuple3 = 2,
             downsampling_dilation: IntOrTuple3 = (1, 1, 1),
             padding_mode: PaddingMode = PaddingMode.Zero):
    """
    Modified 3D-Unet Class

    :param input_image_channels: Number of image channels (scans) that are fed into the model.
    :param initial_feature_channels: Number of feature-maps used in the model - Subsequent layers will contain
        number of featuremaps that is multiples of `initial_feature_channels`
        (e.g. 2^(image_level) * initial_feature_channels)
    :param num_classes: Number of output classes
    :param kernel_size: Spatial support of conv kernels in each spatial axis.
    :param name: Name of the model, passed on to the base class.
    :param num_downsampling_paths: Number of image levels used in Unet (in encoding and decoding paths)
    :param downsampling_factor: Spatial downsampling factor for each tensor axis (depth, width, height). This will
        be used as the stride for the first convolution layer in each encoder block. An int is expanded to the
        same factor in all three dimensions. Crop sizes must be a multiple of
        downsampling_factor ** num_downsampling_paths in each dimension.
    :param downsampling_dilation: An additional dilation that is used in the second convolution layer in each
        of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A good
        choice is (1, 2, 2), to increase the receptive field only in X and Y.
    :param padding_mode: Padding mode passed to the encoding blocks and to the synthesis blocks of the decoding
        path.
    """
    if isinstance(downsampling_factor, int):
        downsampling_factor = (downsampling_factor,) * 3
    # Each of the num_downsampling_paths encoder levels strides by downsampling_factor, hence the crop size
    # constraint below.
    crop_size_multiple = tuple(factor ** num_downsampling_paths for factor in downsampling_factor)
    crop_size_constraints = CropSizeConstraints(multiple_of=crop_size_multiple)
    super().__init__(name=name,
                     input_channels=input_image_channels,
                     crop_size_constraints=crop_size_constraints)
    self.num_dimensions = 3
    # Store the constructor arguments: the layer construction below (and presumably other methods) reads them
    # from self. Without these assignments, construction fails with AttributeError.
    self.initial_feature_channels = initial_feature_channels
    self.num_classes = num_classes
    self.kernel_size = kernel_size
    self.num_downsampling_paths = num_downsampling_paths
    self.downsampling_factor = downsampling_factor
    self.downsampling_dilation = downsampling_dilation
    self.padding_mode = padding_mode
    self._layers = torch.nn.ModuleList()
    self.upsampling_kernel_size = get_upsampling_kernel_size(downsampling_factor, self.num_dimensions)

    # Create forward blocks for the encoding side, including central part.
    # self.input_channels is presumably set by the base class from input_channels - TODO confirm.
    self._layers.append(UNet3D.UNetEncodeBlock((self.input_channels, self.initial_feature_channels),
                                               kernel_size=self.kernel_size,
                                               downsampling_stride=1,
                                               padding_mode=self.padding_mode,
                                               depth=0))
    current_channels = self.initial_feature_channels
    for depth in range(1, self.num_downsampling_paths + 1):  # type: ignore
        self._layers.append(UNet3D.UNetEncodeBlock((current_channels, current_channels * 2),  # type: ignore
                                                   kernel_size=self.kernel_size,
                                                   downsampling_stride=self.downsampling_factor,
                                                   dilation=self.downsampling_dilation,
                                                   padding_mode=self.padding_mode,
                                                   depth=depth))
        current_channels *= 2  # type: ignore

    # Create forward blocks and upsampling layers for the decoding side
    for depth in range(self.num_downsampling_paths + 1, 1, -1):  # type: ignore
        channels = (current_channels, current_channels // 2)  # type: ignore
        self._layers.append(UNet3D.UNetDecodeBlock(channels,
                                                   upsample_kernel_size=self.upsampling_kernel_size,
                                                   upsampling_stride=self.downsampling_factor))
        # Use negative depth to distinguish the encode blocks in the decoding pathway.
        self._layers.append(UNet3D.UNetEncodeBlockSynthesis(channels=(channels[1], channels[1]),
                                                            kernel_size=self.kernel_size,
                                                            padding_mode=self.padding_mode,
                                                            depth=-depth))
        current_channels //= 2  # type: ignore

    # Add final fc layer
    self.output_layer = Conv3d(current_channels, self.num_classes, kernel_size=1)  # type: ignore