def create_transforms_from_config(
        config: CfgNode,
        apply_augmentations: bool,
        expand_channels: bool = True) -> ImageTransformationPipeline:
    """
    Build the image transformation pipeline described by a config.

    The pipeline was designed with Chest X-Ray images in mind but is not restricted to them: the
    config decides which augmentations are enabled and how strong they are. Channel expansion is
    required for gray (single-channel) images.

    :param config: config node (e.g. parsed from a YAML file). ``config.preprocess`` supplies the
        resize and center-crop sizes; ``config.augmentation`` (only read when ``apply_augmentations``
        is True) selects and parameterises each augmentation.
    :param apply_augmentations: if True, the enabled augmentations are included in the pipeline.
        Otherwise augmentations are disabled and the image is only resized and center cropped.
    :param expand_channels: if True, ``ExpandChannels`` (from
        InnerEye.ML.augmentations.image_transforms) is placed first, ahead of the transforms
        selected by the config. Needed for single-channel images such as CXR.
    :return: an ``ImageTransformationPipeline`` applying the collected transforms in order.
    """
    steps: List[Any] = [ExpandChannels()] if expand_channels else []

    if not apply_augmentations:
        # No augmentation: deterministic resize followed by center crop.
        steps += [
            Resize(size=config.preprocess.resize),
            CenterCrop(config.preprocess.center_crop_size),
        ]
        return ImageTransformationPipeline(steps)

    aug = config.augmentation

    if aug.use_random_affine:
        affine_cfg = aug.random_affine
        steps.append(RandomAffine(
            degrees=affine_cfg.max_angle,
            translate=(affine_cfg.max_horizontal_shift, affine_cfg.max_vertical_shift),
            shear=affine_cfg.max_shear))

    # Exactly one resizing step: a random resized crop when enabled, a plain resize otherwise.
    resize_step = (RandomResizedCrop(scale=aug.random_crop.scale, size=config.preprocess.resize)
                   if aug.use_random_crop
                   else Resize(size=config.preprocess.resize))
    steps.append(resize_step)

    if aug.use_random_horizontal_flip:
        steps.append(RandomHorizontalFlip(p=aug.random_horizontal_flip.prob))

    if aug.use_gamma_transform:
        steps.append(RandomGamma(scale=aug.gamma.scale))

    if aug.use_random_color:
        color_cfg = aug.random_color
        steps.append(ColorJitter(brightness=color_cfg.brightness,
                                 contrast=color_cfg.contrast,
                                 saturation=color_cfg.saturation))

    if aug.use_elastic_transform:
        elastic_cfg = aug.elastic_transform
        steps.append(ElasticTransform(alpha=elastic_cfg.alpha,
                                      sigma=elastic_cfg.sigma,
                                      p_apply=elastic_cfg.p_apply))

    # The center crop always runs; random erasing and gaussian noise come after it.
    steps.append(CenterCrop(config.preprocess.center_crop_size))

    if aug.use_random_erasing:
        erasing_cfg = aug.random_erasing
        steps.append(RandomErasing(scale=erasing_cfg.scale, ratio=erasing_cfg.ratio))

    if aug.add_gaussian_noise:
        noise_cfg = aug.gaussian_noise
        steps.append(AddGaussianNoise(p_apply=noise_cfg.p_apply, std=noise_cfg.std))

    return ImageTransformationPipeline(steps)
# Example no. 2 (0)
def _make_fake_cxr_image() -> PIL.Image.Image:
    """Build a 256x256 grayscale ("L") PIL image: value 255 everywhere except a 50x100 block set to 1."""
    fake_cxr_as_array = np.ones([256, 256]) * 255.
    fake_cxr_as_array[100:150, 100:200] = 1
    return PIL.Image.fromarray(fake_cxr_as_array).convert("L")


def _seed_all_rngs(seed: int) -> None:
    """Seed the numpy, torch and Python ``random`` generators so random transforms can be replayed."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


def test_create_transform_pipeline_from_config() -> None:
    """
    Tests that the pipeline returned by create_cxr_transforms_from_config gives the same result as applying
    the expected list of transforms one by one, both with augmentations (same RNG seed for both runs)
    and without augmentations (evaluation pipeline).
    """
    transformation_pipeline = create_cxr_transforms_from_config(
        cxr_augmentation_config, apply_augmentations=True)
    fake_cxr_image = _make_fake_cxr_image()

    all_transforms = [
        ExpandChannels(),
        RandomAffine(degrees=180, translate=(0, 0), shear=40),
        RandomResizedCrop(scale=(0.4, 1.0), size=256),
        RandomHorizontalFlip(p=0.5),
        RandomGamma(scale=(0.5, 1.5)),
        ColorJitter(saturation=0, brightness=0.2, contrast=0.2),
        ElasticTransform(sigma=4, alpha=34, p_apply=0.4),
        CenterCrop(size=224),
        RandomErasing(scale=(0.15, 0.4), ratio=(0.33, 3)),
        AddGaussianNoise(std=0.05, p_apply=0.5)
    ]

    # Same seed before both runs so the random transforms draw identical parameters.
    _seed_all_rngs(3)
    transformed_image = transformation_pipeline(fake_cxr_image)
    assert isinstance(transformed_image, torch.Tensor)

    # Expected pipeline.
    # In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
    image = ToTensor()(_make_fake_cxr_image()).reshape([1, 1, 256, 256])

    _seed_all_rngs(3)
    expected_transformed = image
    for t in all_transforms:
        expected_transformed = t(expected_transformed)
    # The pipeline takes as input [C, Z, H, W] and returns [C, Z, H, W]
    # But the transforms list expect [Z, C, H, W] and returns [Z, C, H, W] so need to permute dimension to compare
    expected_transformed = torch.transpose(expected_transformed, 1, 0).squeeze(1)
    assert torch.isclose(expected_transformed, transformed_image).all()

    # Test the evaluation pipeline (deterministic, so no re-seeding is needed).
    transformation_pipeline = create_cxr_transforms_from_config(
        cxr_augmentation_config, apply_augmentations=False)
    # NOTE(review): this pipeline is fed the 4D tensor `image`, whereas the training pipeline above was fed the
    # PIL image; torch.isclose broadcasts, so a shape difference between the two results would not be caught here.
    # Confirm this is intentional.
    transformed_image = transformation_pipeline(image)
    assert isinstance(transformed_image, torch.Tensor)
    all_transforms = [ExpandChannels(), Resize(size=256), CenterCrop(size=224)]
    expected_transformed = image
    for t in all_transforms:
        expected_transformed = t(expected_transformed)
    expected_transformed = torch.transpose(expected_transformed, 1, 0).squeeze(1)
    assert torch.isclose(expected_transformed, transformed_image).all()
# Example no. 3 (0)
def test_invalid_tensors(invalid_test_tensor: torch.Tensor) -> None:
    """
    Check that ExpandChannels and RandomGamma reject a malformed input with a ValueError.
    The transforms expect 4-dimensional input; the fixture provides an input that does not satisfy this.
    """
    # Each transform is built inside its own `raises` context, exactly as a direct call would be.
    transform_factories = (ExpandChannels, lambda: RandomGamma(scale=(0.3, 3)))
    for build_transform in transform_factories:
        with pytest.raises(ValueError):
            build_transform()(invalid_test_tensor)
# Example no. 4 (0)
def test_expand_channels() -> None:
    """
    Check ExpandChannels: a malformed tensor raises ValueError, and the module-level single-channel,
    single-slice test tensor comes out with shape [1, 3, *image_size].

    NOTE(review): a fixture-based function with the same name is defined further down this file and
    shadows this one at import time; consider renaming one of them.
    """
    with pytest.raises(ValueError):
        ExpandChannels()(invalid_test_tensor)

    # Transform a copy, presumably so the shared module-level tensor cannot be modified in place.
    single_channel_input = test_tensor_1channel_1slice.clone()
    expanded = ExpandChannels()(single_channel_input)
    expected_shape = torch.Size([1, 3, *image_size])
    assert expanded.shape == expected_shape
# Example no. 5 (0)
def test_expand_channels(tensor_1channel_1slice: torch.Tensor) -> None:
    """Check that ExpandChannels maps the (presumably single-channel) fixture to shape [1, 3, *image_size]."""
    expected_shape = torch.Size([1, 3, *image_size])
    expanded = ExpandChannels()(tensor_1channel_1slice)
    assert expanded.shape == expected_shape