Code example #1
    def _get_transforms(self, augmentation_config: Optional[CfgNode],
                        dataset_name: str,
                        is_ssl_encoder_module: bool) -> Tuple[Any, Any]:

        # is_ssl_encoder_module is True for SSL training, False for linear-head training
        train_transforms = ImageTransformationPipeline([Lambda(lambda x: x)])  # identity: do nothing
        val_transforms = ImageTransformationPipeline([Lambda(lambda x: x + 1)])  # add 1 to every pixel

        if is_ssl_encoder_module:
            train_transforms = DualViewTransformWrapper(train_transforms)  # type: ignore
            val_transforms = DualViewTransformWrapper(val_transforms)  # type: ignore
        return train_transforms, val_transforms
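For context, DualViewTransformWrapper is what turns a single augmentation pipeline into the two-view input that contrastive SSL objectives (SimCLR-style) expect. A minimal sketch of such a wrapper, assuming it simply applies the wrapped pipeline twice with independent randomness (the class in the repository may differ in detail):

from typing import Any, Callable, Tuple

class DualViewTransformWrapper:
    """Apply the wrapped transform twice, yielding two independently augmented views."""

    def __init__(self, transform: Callable) -> None:
        self.transform = transform

    def __call__(self, sample: Any) -> Tuple[Any, Any]:
        # Each call re-samples the random parameters, so the two views differ.
        return self.transform(sample), self.transform(sample)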
Code example #2
def test_custom_tf_on_various_input(
        use_different_transformation_per_channel: bool) -> None:
    """
    This tests that we can run the transformation pipeline with our custom transforms on various types
    of input: PIL image, 3D tensor, 4D tensor. Tests that use_different_transformation_per_channel has the correct
    behavior. The transforms are tested individually in test_image_transforms.py.
    """
    pipeline = ImageTransformationPipeline([
        ElasticTransform(sigma=4, alpha=34, p_apply=1),
        AddGaussianNoise(p_apply=1, std=0.05),
        RandomGamma(scale=(0.3, 3))
    ], use_different_transformation_per_channel)

    # Test PIL image input
    transformed = pipeline(test_image_as_pil)
    assert transformed.shape == test_2d_image_as_CHW_tensor.shape

    # Test image as [C, H, W] tensor
    transformed = pipeline(test_2d_image_as_CHW_tensor)
    assert transformed.shape == test_2d_image_as_CHW_tensor.shape

    # Test image as [1, 1, H, W]
    transformed = pipeline(test_2d_image_as_ZCHW_tensor)
    assert isinstance(transformed, torch.Tensor)
    assert transformed.shape == torch.Size([1, 1, *image_size])

    # Test with a fake scan [C, Z, H, W] -> [25, 34, 32, 32]
    transformed = pipeline(test_4d_scan_as_tensor)
    assert isinstance(transformed, torch.Tensor)
    assert transformed.shape == test_4d_scan_as_tensor.shape

    # Same transformation should be applied to all slices and channels.
    assert torch.isclose(
        transformed[0, 0],
        transformed[1, 1]).all() != use_different_transformation_per_channel
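The fixtures referenced above (test_image_as_pil, test_2d_image_as_CHW_tensor, test_2d_image_as_ZCHW_tensor, test_4d_scan_as_tensor) are defined elsewhere in the test module. A minimal sketch of how they could be constructed, assuming image_size = (32, 32) as suggested by the shape comments (the exact construction in the original module may differ):

import numpy as np
import torch
from PIL import Image

image_size = (32, 32)
array = np.random.randint(0, 255, size=image_size, dtype=np.uint8)

test_image_as_pil = Image.fromarray(array)                                  # grayscale PIL image
test_2d_image_as_CHW_tensor = torch.from_numpy(array).unsqueeze(0).float()  # [1, H, W]
test_2d_image_as_ZCHW_tensor = test_2d_image_as_CHW_tensor.unsqueeze(0)     # [1, 1, H, W]
test_4d_scan_as_tensor = torch.rand(25, 34, *image_size)                    # [C, Z, H, W]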
Code example #3
def get_segmentation_transform(self) -> ModelTransformsPerExecutionMode:
    if self.imaging_feature_type in [ImagingFeatureType.Segmentation, ImagingFeatureType.ImageAndSegmentation]:
        return ModelTransformsPerExecutionMode(
            train=ImageTransformationPipeline(
                transforms=[RandomAffine(10), ColorJitter(0.2)],
                use_different_transformation_per_channel=True))
    return ModelTransformsPerExecutionMode()
Code example #4
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
    if self.use_combined_model:
        return ModelTransformsPerExecutionMode(
            train=ImageTransformationPipeline(
                transforms=[RandomAffine(degrees=30, translate=(0.1, 0.1), shear=15),
                            ColorJitter(brightness=0.2)]))
    else:
        return ModelTransformsPerExecutionMode()
Code example #5
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
    """
    Get transforms to perform on image samples for each model execution mode.
    """
    if self.imaging_feature_type in [ImagingFeatureType.Image, ImagingFeatureType.ImageAndSegmentation]:
        return ModelTransformsPerExecutionMode(
            train=ImageTransformationPipeline(
                transforms=[RandomAffine(10), ColorJitter(0.2)],
                use_different_transformation_per_channel=True))
    return ModelTransformsPerExecutionMode()
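Code examples #3 to #5 all return a ModelTransformsPerExecutionMode, a small container holding one optional transform per execution mode. A plausible shape for that container, inferred from the keyword arguments used above (the real class in the repository may carry extra fields or validation):

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class ModelTransformsPerExecutionMode:
    """One optional transform per execution mode; None means no augmentation."""
    train: Optional[Callable] = None
    val: Optional[Callable] = None
    test: Optional[Callable] = None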
Code example #6
def test_torchvision_on_various_input(
        use_different_transformation_per_channel: bool,
) -> None:
    """
    This tests that we can run the transformation pipeline with out-of-the-box torchvision transforms on various
    types of input: PIL image, 3D tensor, 4D tensor. Tests that use_different_transformation_per_channel has the
    correct behavior.
    """
    image_as_pil, image_2d_as_CHW_tensor, image_2d_as_ZCHW_tensor, scan_4d_as_tensor = create_test_images()
    transform = ImageTransformationPipeline(
        [
            CenterCrop(crop_size),
            RandomErasing(),
            RandomAffine(degrees=(10, 12), shear=15, translate=(0.1, 0.3)),
        ],
        use_different_transformation_per_channel,
    )

    # Test PIL image input
    transformed = transform(image_as_pil)
    assert isinstance(transformed, torch.Tensor)
    assert transformed.shape == torch.Size([1, crop_size, crop_size])

    # Test image as [C, H, W] tensor
    transformed = transform(image_2d_as_CHW_tensor.clone())
    assert isinstance(transformed, torch.Tensor)
    assert transformed.shape == torch.Size([1, crop_size, crop_size])

    # Test image as [1, 1, H, W]
    transformed = transform(image_2d_as_ZCHW_tensor)
    assert isinstance(transformed, torch.Tensor)
    assert transformed.shape == torch.Size([1, 1, crop_size, crop_size])

    # Test with a fake 4D scan [C, Z, H, W] -> [5, 4, 32, 32]
    transformed = transform(scan_4d_as_tensor)
    assert isinstance(transformed, torch.Tensor)
    assert transformed.shape == torch.Size([5, 4, crop_size, crop_size])

    # Same transformation should be applied to all slices and channels.
    assert (
            torch.isclose(transformed[0, 0], transformed[1, 1]).all()
            != use_different_transformation_per_channel
    )
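Both test_custom_tf_on_various_input and this test take use_different_transformation_per_channel as an argument, which suggests they are parametrized by pytest in the original module. A sketch of the presumed decorator (an assumption; the original file may use different parameter ids):

import pytest

@pytest.mark.parametrize("use_different_transformation_per_channel", [True, False])
def test_torchvision_on_various_input(use_different_transformation_per_channel: bool) -> None:
    ...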
Code example #7
def test_image_encoder(
        test_output_dirs: OutputFolderForTests, encode_channels_jointly: bool,
        use_non_imaging_features: bool,
        kernel_size_per_encoding_block: Optional[Union[TupleInt3,
                                                       List[TupleInt3]]],
        stride_size_per_encoding_block: Optional[Union[TupleInt3,
                                                       List[TupleInt3]]],
        reduction_factor: float, expected_num_reduced_features: int,
        aggregation_type: AggregationType) -> None:
    """
    Test if the image encoder networks can be trained without errors (including GradCam computation and data
    augmentation).
    """
    logging_to_stdout()
    set_random_seed(0)
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    scan_size = (6, 64, 60)
    scan_files: List[str] = []
    for s in range(4):
        random_scan = np.random.uniform(0, 1, scan_size)
        scan_file_name = f"scan{s + 1}{NumpyFile.NUMPY.value}"
        np.save(str(dataset_folder / scan_file_name), random_scan)
        scan_files.append(scan_file_name)

    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
S1,week0,scan1.npy,,1,10,Male,Val1
S1,week1,scan2.npy,True,2,20,Female,Val2
S2,week0,scan3.npy,,3,30,Female,Val3
S2,week1,scan4.npy,False,4,40,Female,Val1
S3,week0,scan1.npy,,5,50,Male,Val2
S3,week1,scan3.npy,True,6,60,Male,Val2
"""
    (dataset_folder / "dataset.csv").write_text(dataset_contents)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}
    config_for_dataset = ScalarModelBase(
        local_dataset=dataset_folder,
        image_channels=["week0", "week1"],
        image_file_column="path",
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=non_image_feature_channels,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        should_validate=False)
    config_for_dataset.read_dataset_into_dataframe_and_pre_process()

    dataset = ScalarDataset(
        config_for_dataset,
        sample_transform=ScalarItemAugmentation(
            ImageTransformationPipeline(
                [RandomAffine(10), ColorJitter(0.2)],
                use_different_transformation_per_channel=True)))
    assert len(dataset) == 3

    config = ImageEncoder(
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=config_for_dataset.categorical_feature_encoder,
        encoder_dimensionality_reduction_factor=reduction_factor,
        aggregation_type=aggregation_type,
        scan_size=scan_size)

    if kernel_size_per_encoding_block:
        config.kernel_size_per_encoding_block = kernel_size_per_encoding_block
    if stride_size_per_encoding_block:
        config.stride_size_per_encoding_block = stride_size_per_encoding_block

    config.set_output_to(test_output_dirs.root_dir)
    config.max_batch_grad_cam = 1
    model = create_model_with_temperature_scaling(config)
    input_size: List[Tuple] = [(len(config.image_channels), *scan_size)]
    if use_non_imaging_features:
        input_size.append(
            (config.get_total_number_of_non_imaging_features(), ))

        # The original (unreduced) number of output channels is
        # num_initial_channels * (num_encoder_blocks - 1) = 4 * (3 - 1) = 8
        if encode_channels_jointly:
            # reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()
        else:
            # num_img_channels * reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == len(config.image_channels) * expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()

    summarizer = ModelSummary(model)
    summarizer.generate_summary(input_sizes=input_size)
    config.local_dataset = dataset_folder
    config.validate()
    model_train_unittest(config, dirs=test_output_dirs)
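The assertions on final_num_feature_channels follow the arithmetic sketched in the inline comments. A worked example with illustrative values (the actual reduction_factor and expected_num_reduced_features come from the test's parameter list, which is not shown here):

num_initial_channels = 4
num_encoder_blocks = 3
unreduced = num_initial_channels * (num_encoder_blocks - 1)        # 4 * (3 - 1) = 8

reduction_factor = 0.25                                            # illustrative value
expected_num_reduced_features = int(unreduced * reduction_factor)  # 2

num_image_channels = 2      # ["week0", "week1"]
num_non_image_features = 5  # hypothetical count from the non-imaging columns

# encode_channels_jointly=True: reduced channels + non-imaging features
jointly = expected_num_reduced_features + num_non_image_features                          # 7
# encode_channels_jointly=False: per-image-channel reduced channels + non-imaging features
separately = num_image_channels * expected_num_reduced_features + num_non_image_features  # 9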