def test_invalid_crop_size(crop_size: Any) -> None:
    """
    Random cropping of an otherwise valid sample must raise for the given (invalid) crop size.

    :param crop_size: a crop size expected to be rejected; presumably supplied by a pytest parametrization
        that is not visible here.
    """
    with pytest.raises(Exception):
        augmentation.random_crop(
            sample=Sample(metadata=DummyPatientMetadata,
                          image=valid_image_4d,
                          labels=valid_labels,
                          mask=valid_mask),
            crop_size=crop_size,
            class_weights=valid_class_weights)
def test_valid_class_weights(class_weights: List[float]) -> None:
    """
    Produce a large number of crops and make sure the crop center class proportions respect class weights.

    :param class_weights: per-class sampling weights, or None for uniform sampling over the classes that are
        present in the sample.
    """
    ml_util.set_random_seed(1)
    num_classes = len(valid_labels)
    image = np.zeros_like(valid_image_4d)
    labels = np.zeros_like(valid_labels)
    class0, class1, class2 = non_empty_classes = [0, 2, 4]
    # class0 covers every voxel except two. Each of those two voxels is assigned to exactly one other class,
    # so every voxel carries exactly one foreground label.
    labels[class0] = 1
    labels[class0][3, 3, 3] = 0
    labels[class0][3, 2, 3] = 0
    labels[class1][3, 3, 3] = 1
    labels[class2][3, 2, 3] = 1
    mask = np.ones_like(valid_mask)
    sample = Sample(image=image, labels=labels, mask=mask, metadata=DummyPatientMetadata)

    crop_size = (1, 1, 1)
    total_crops = 200
    sampled_label_center_distribution = np.zeros(num_classes)

    # If there is no class that has a non-zero weight and is present in the sample, there is no valid
    # way to select a class, so we expect an exception to be thrown.
    if class_weights is not None and sum(class_weights[c] for c in non_empty_classes) == 0:
        with pytest.raises(ValueError):
            augmentation.random_crop(sample, crop_size, class_weights)
        return

    for _ in range(total_crops):
        _, center, _ = augmentation.random_crop(sample, crop_size, class_weights)
        # Each voxel has exactly one label set, so its class is the index of that single 1.
        sampled_class = list(labels[:, center[0], center[1], center[2]]).index(1)
        sampled_label_center_distribution[sampled_class] += 1

    sampled_label_center_distribution /= total_crops

    if class_weights is None:
        # Without weights, every class present in the sample is expected to be equally likely.
        weight = 1.0 / len(non_empty_classes)
        expected_label_center_distribution = [
            weight if c in non_empty_classes else 0.0
            for c in range(num_classes)
        ]
    else:
        # With weights, classes absent from the sample are dropped and the remaining weights renormalized.
        total = sum(class_weights[c] for c in non_empty_classes)
        expected_label_center_distribution = [
            class_weights[c] / total if c in non_empty_classes else 0.0
            for c in range(num_classes)
        ]
    assert np.allclose(sampled_label_center_distribution, expected_label_center_distribution, atol=0.1)
def test_random_crop_no_fg() -> None:
    """
    Random cropping must raise when the mask is all zeros, and again when all labels are zero.
    """
    # Each entry is a (labels, mask) pair in which one of the two arrays is entirely zero.
    degenerate_inputs = [
        (valid_labels, np.zeros_like(valid_mask)),
        (np.zeros_like(valid_labels), valid_mask),
    ]
    for labels, mask in degenerate_inputs:
        with pytest.raises(Exception):
            augmentation.random_crop(
                Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=labels, mask=mask),
                valid_crop_size,
                valid_class_weights)
def test_invalid_arrays(image: Any, labels: Any, mask: Any, class_weights: Any) -> None:
    """
    Tests failure cases of the random_crop function for invalid image, labels, mask or class weights arguments.

    Every argument combination except the one made entirely of the valid fixtures must raise.
    """
    all_valid = (np.array_equal(image, valid_image_4d)
                 and np.array_equal(labels, valid_labels)
                 and np.array_equal(mask, valid_mask)
                 and class_weights == valid_class_weights)
    # The fully valid combination is expected to succeed, so there is nothing to check for it here.
    if all_valid:
        return
    with pytest.raises(Exception):
        sample = Sample(metadata=DummyPatientMetadata, image=image, labels=labels, mask=mask)
        augmentation.random_crop(sample, valid_crop_size, class_weights)
def test_valid_full_crop() -> None:
    """
    Cropping with valid_full_crop_size must return image, labels, mask and metadata equal to the input.
    """
    expected_metadata = DummyPatientMetadata
    full_sample = Sample(image=valid_image_4d, labels=valid_labels, mask=valid_mask, metadata=expected_metadata)
    cropped, _, _ = augmentation.random_crop(sample=full_sample,
                                             crop_size=valid_full_crop_size,
                                             class_weights=valid_class_weights)
    assert np.array_equal(cropped.image, valid_image_4d)
    assert np.array_equal(cropped.labels, valid_labels)
    assert np.array_equal(cropped.mask, valid_mask)
    assert cropped.metadata == expected_metadata
def test_random_crop(crop_size: Any) -> None:
    """
    Random cropping returns image, labels and mask with the requested spatial crop size, even when one class
    has no foreground voxels at all.

    :param crop_size: the spatial crop size to request; presumably supplied by a pytest parametrization.
    """
    # Work on a copy: valid_labels is module-level test data shared with other tests and must not be mutated.
    labels = valid_labels.copy()
    # Create labels such that there are no foreground voxels in a particular class.
    # This should be handled gracefully (the class is ignored during sampling).
    labels[0] = 1
    labels[1] = 0
    sample, _, _ = augmentation.random_crop(
        Sample(image=valid_image_4d, labels=labels, mask=valid_mask, metadata=DummyPatientMetadata),
        crop_size,
        valid_class_weights)
    expected_img_crop_size = (valid_image_4d.shape[0], *crop_size)
    expected_labels_crop_size = (labels.shape[0], *crop_size)
    assert sample.image.shape == expected_img_crop_size
    assert sample.labels.shape == expected_labels_crop_size
    assert sample.mask.shape == tuple(crop_size)
def create_random_cropped_sample(
        sample: Sample,
        crop_size: TupleInt3,
        center_size: TupleInt3,
        class_weights: Optional[List[float]] = None) -> CroppedSample:
    """
    Creates an instance of a cropped sample extracted from full 3D images.

    :param sample: the full size 3D sample to use for extracting a cropped sample.
    :param crop_size: the size of the crop to extract.
    :param center_size: the size of the center of the crop (this should be the same as the spatial dimensions
    of the posteriors that the model produces)
    :param class_weights: the distribution to use for the crop center class.
    :return: CroppedSample
    """
    # Crop the original raw sample. random_crop returns a 3-tuple (cropped sample, crop center, slicers);
    # the slicers are not needed here.
    sample, center_point, _ = augmentation.random_crop(
        sample=sample,
        crop_size=crop_size,
        class_weights=class_weights)

    # crop the mask and label centers if required
    if center_size == crop_size:
        mask_center_crop = sample.mask
        labels_center_crop = sample.labels
    else:
        mask_center_crop = image_util.get_center_crop(image=sample.mask, crop_shape=center_size)
        # One center crop per label channel, stacked along the first axis.
        labels_center_crop = np.zeros(
            shape=[len(sample.labels)] + list(center_size),  # type: ignore
            dtype=ImageDataType.SEGMENTATION.value)
        for c, channel_labels in enumerate(sample.labels):  # type: ignore
            labels_center_crop[c] = image_util.get_center_crop(image=channel_labels, crop_shape=center_size)

    return CroppedSample(image=sample.image,
                         mask=sample.mask,
                         labels=sample.labels,
                         mask_center_crop=mask_center_crop,
                         labels_center_crop=labels_center_crop,
                         center_indices=center_point,
                         metadata=sample.metadata)
def main(args: CheckPatchSamplingConfig) -> None:
    """
    For the first `args.number_samples` training samples, repeatedly sample random crops and write a Nifti
    heatmap of the accumulated crop patches, plus the first image channel, into `args.output_folder`.

    :param args: script configuration (dataset location, model name, output folder, number of samples and
        number of crop iterations per sample).
    :raises ValueError: if a sample has no image header to write the Nifti files with.
    """
    # Identify paths to inputs and outputs
    commandline_args = {
        "train_batch_size": 1,
        "local_dataset": Path(args.local_dataset)
    }
    output_folder = Path(args.output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Create a config file
    config = ModelConfigLoader[SegmentationModelBase]().create_model_config_from_name(
        args.model_name, overrides=commandline_args)

    # Set a random seed
    ml_util.set_random_seed(config.random_seed)

    # Get a dataloader object that checks csv
    dataset_splits = config.get_dataset_splits()

    # Load a sample using the full image data loader
    full_image_dataset = FullImageDataset(config, dataset_splits.train)

    for sample_index in range(args.number_samples):
        sample = CroppingDataset.create_possibly_padded_sample_for_cropping(
            sample=full_image_dataset.get_samples_at_index(index=sample_index)[0],
            crop_size=config.crop_size,
            padding_mode=config.padding_mode)
        print("Processing sample: ", sample.patient_id)

        # The header is required to write the outputs; fail before spending time on the sampling loop.
        if not sample.metadata.image_header:
            raise ValueError("None header expected some header")

        # Exhaustively sample with random crop function.
        # NOTE(review): the heatmap is uint16, so in-place additions wrap around silently if a voxel's count
        # exceeds 65535 (presumably only possible with very large number_crop_iterations).
        heatmap = np.zeros(sample.mask.shape, dtype=np.uint16)
        for _ in range(args.number_crop_iterations):
            # random_crop returns (cropped sample, crop center, slicers); only the crop center is needed here.
            _, center_point, _ = augmentation.random_crop(
                sample=sample,
                crop_size=config.crop_size,
                class_weights=config.class_weights)
            patch_mask = create_mask_for_patch(output_shape=heatmap.shape,
                                               output_dtype=heatmap.dtype,
                                               center=center_point,
                                               crop_size=config.crop_size)
            heatmap += patch_mask

        ct_output_name = str(output_folder / "{}_ct.nii.gz".format(int(sample.patient_id)))
        heatmap_output_name = str(output_folder / "{}_sampled_patches.nii.gz".format(int(sample.patient_id)))
        io_util.store_as_nifti(image=heatmap,
                               header=sample.metadata.image_header,
                               file_name=heatmap_output_name,
                               image_type=heatmap.dtype,
                               scale=False)
        io_util.store_as_nifti(image=sample.image[0],
                               header=sample.metadata.image_header,
                               file_name=ct_output_name,
                               image_type=sample.image.dtype,
                               scale=False)
def visualize_random_crops(sample: Sample,
                           config: SegmentationModelBase,
                           output_folder: Path) -> np.ndarray:
    """
    Simulate the effect of sampling random crops (as is done for training segmentation models), and store the
    results as a Nifti heatmap and as 3 axial/sagittal/coronal slices. The heatmap and the slices are stored in
    the given output folder, with filenames that contain the patient ID as the prefix. For images with a
    degenerate first dimension (effectively 2D), no Nifti files are written and only a single thumbnail is
    produced.

    :param sample: The patient information from the dataset, with scans and ground truth labels.
    :param config: The model configuration.
    :param output_folder: The folder into which the heatmap and thumbnails should be written.
    :return: A numpy array that has the same size as the (possibly padded) first image channel, containing how
        often each voxel was contained in one of the sampled crops.
    """
    output_folder.mkdir(exist_ok=True, parents=True)
    sample = CroppingDataset.create_possibly_padded_sample_for_cropping(
        sample=sample,
        crop_size=config.crop_size,
        padding_mode=config.padding_mode)
    print(f"Processing sample: {sample.patient_id}")
    # Exhaustively sample with random crop function
    image_channel0 = sample.image[0]
    heatmap = np.zeros(image_channel0.shape, dtype=np.uint16)
    # Number of repeats should fit into the range of UInt16, because we will later save the heatmap as an integer
    # Nifti file of that datatype.
    repeats = 1000
    for _ in range(repeats):
        _, _, slicers = augmentation.random_crop(sample=sample,
                                                 crop_size=config.crop_size,
                                                 class_weights=config.class_weights)
        heatmap[slicers[0], slicers[1], slicers[2]] += 1
    is_3dim = heatmap.shape[0] > 1
    header = sample.metadata.image_header
    if not header:
        logging.warning(f"No image header found for patient {sample.patient_id}. Using default header.")
        header = get_unit_image_header()
    if is_3dim:
        ct_output_name = str(output_folder / f"{sample.patient_id}_ct.nii.gz")
        heatmap_output_name = str(output_folder / f"{sample.patient_id}_sampled_patches.nii.gz")
        io_util.store_as_nifti(image=heatmap,
                               header=header,
                               file_name=heatmap_output_name,
                               image_type=heatmap.dtype,
                               scale=False)
        io_util.store_as_nifti(image=image_channel0,
                               header=header,
                               file_name=ct_output_name,
                               image_type=sample.image.dtype,
                               scale=False)
    # np.float was removed in NumPy 1.24; np.float64 is the equivalent explicit type.
    heatmap_scaled = heatmap.astype(dtype=np.float64) / heatmap.max()
    # If the incoming image is effectively a 2D image with degenerate Z dimension, then only plot a single
    # axial thumbnail. Otherwise, plot thumbnails for all 3 dimensions.
    dimensions = list(range(3)) if is_3dim else [0]
    # Center the 3 thumbnails at one of the points where the heatmap attains a maximum. This should ensure that
    # the thumbnails are in an area where many of the organs of interest are located.
    max_heatmap_index = np.unravel_index(heatmap.argmax(), heatmap.shape) if is_3dim else (0, 0, 0)
    for dimension in dimensions:
        plt.clf()
        scan_with_transparent_overlay(scan=image_channel0,
                                      overlay=heatmap_scaled,
                                      dimension=dimension,
                                      position=max_heatmap_index[dimension] if is_3dim else 0,
                                      spacing=header.spacing)
        # Construct a filename that has a dimension suffix if we are generating 3 of them. For 2dim images, skip
        # the suffix.
        thumbnail = f"{sample.patient_id}_sampled_patches"
        if is_3dim:
            thumbnail += f"_dim{dimension}"
        thumbnail += ".png"
        resize_and_save(width_inch=5, height_inch=5, filename=output_folder / thumbnail)
    return heatmap