Ejemplo n.º 1
0
def test_unet3_num_downsampling_paths() -> None:
    """
    Checks that a UNet3D built via build_net takes its number of downsampling paths from the model
    configuration, and that building fails with a ValueError when the crop size is smaller than the
    minimum allowed for that number of downsampling paths.
    """
    for num_downsampling_paths in range(1, 5):
        # The test expects the smallest valid crop edge to be 2 ** num_downsampling_paths.
        min_edge = 2 ** num_downsampling_paths

        # Test that num_downsampling_paths for built UNet3D
        # is set via model configuration
        crop_size = (min_edge, min_edge, min_edge)
        config = SegmentationModelBase(
            architecture=ModelArchitectureConfig.UNet3D,
            image_channels=["ct"],
            feature_channels=[1],
            crop_size=crop_size,
            num_downsampling_paths=num_downsampling_paths,
            should_validate=False)
        network = build_net(config)
        assert network.num_downsampling_paths == num_downsampling_paths

        # Test that exception is raised if crop size is smaller than is allowed
        # by num_downsampling_paths
        too_small_crop_size = (min_edge // 2, min_edge // 2, min_edge // 2)
        ex_msg = f"Crop size is not valid. The required minimum is {crop_size}"
        config = SegmentationModelBase(
            architecture=ModelArchitectureConfig.UNet3D,
            image_channels=["ct"],
            feature_channels=[1],
            crop_size=too_small_crop_size,
            num_downsampling_paths=num_downsampling_paths,
            should_validate=False)
        with pytest.raises(ValueError) as ex:
            build_net(config)
        # Inspect the raised exception itself: str() of pytest's ExceptionInfo is a repr that may be
        # truncated or escaped, which would make a substring check unreliable.
        assert ex_msg in str(ex.value)
Ejemplo n.º 2
0
 def create_model(self) -> Any:
     """
     Builds the network described by the settings held in this configuration object.
     :return: The network model, as a torch.nn.Module object.
     """
     # The import is deferred to call time so that merely loading this module does not pull in pytorch.
     # The precise return type would be BaseModel, which is why Any is used here instead: naming
     # BaseModel would bring the pytorch dependency back in.
     from InnerEye.ML.utils.model_util import build_net as _build_net
     model = _build_net(self)
     return model
Ejemplo n.º 3
0
def test_crop_size_multiple_in_build_net() -> None:
    """
    Tests that the crop size validation is really called in the model creation code, both for the
    training crop size and for the test crop size.
    """
    config = SegmentationModelBase(architecture=ModelArchitectureConfig.UNet3D,
                                   image_channels=["ct"],
                                   feature_channels=[1],
                                   kernel_size=3,
                                   crop_size=(17, 16, 16),
                                   should_validate=False)
    # Crop size of 17 in dimension 0 is invalid for a UNet3D, should be multiple of 16.
    # This should raise a ValueError.
    with pytest.raises(ValueError) as ex:
        build_net(config)
    # Check the exception's own message: str() of pytest's ExceptionInfo is a repr that may be
    # truncated or escaped, which would make a substring check unreliable.
    assert "Training crop size: Crop size is not valid" in str(ex.value)
    # With a valid training crop size, an invalid test crop size must still be rejected.
    config.crop_size = (16, 16, 16)
    config.test_crop_size = (17, 18, 19)
    with pytest.raises(ValueError) as ex:
        build_net(config)
    assert "Test crop size: Crop size is not valid" in str(ex.value)