def create_and_set_torch_datasets(self, for_training: bool = True, for_inference: bool = True) -> None:
    """
    Builds the torch datasets for the requested execution modes and stores them on this object.

    :param for_training: If True, create cropping datasets for the train and validation splits and
        store them in `self._datasets_for_training`.
    :param for_inference: If True, create one full image dataset for every execution mode whose
        dataset split is non-empty, and store them in `self._datasets_for_inference`.
    """
    # Imported locally, as in the original implementation.
    from InnerEye.ML.dataset.cropping_dataset import CroppingDataset
    from InnerEye.ML.dataset.full_image_dataset import FullImageDataset

    splits = self.get_dataset_splits()
    cropped_transforms = self.get_cropped_image_sample_transforms()
    full_transforms = self.get_full_image_sample_transforms()

    if for_training:
        train_dataset = CroppingDataset(
            self,
            splits.train,
            cropped_sample_transforms=cropped_transforms.train,  # type: ignore
            full_image_sample_transforms=full_transforms.train)  # type: ignore
        val_dataset = CroppingDataset(
            self,
            splits.val,
            cropped_sample_transforms=cropped_transforms.val,  # type: ignore
            full_image_sample_transforms=full_transforms.val)  # type: ignore
        self._datasets_for_training = {
            ModelExecutionMode.TRAIN: train_dataset,
            ModelExecutionMode.VAL: val_dataset,
        }

    if for_inference:
        inference_datasets = {}
        for mode in ModelExecutionMode:
            # Modes without any data do not get a dataset at all.
            if len(splits[mode]) > 0:
                # All inference datasets use the test-time full image transforms, whatever the mode.
                inference_datasets[mode] = FullImageDataset(
                    self,
                    splits[mode],
                    full_image_sample_transforms=full_transforms.test)  # type: ignore
        self._datasets_for_inference = inference_datasets
def test_cropping_dataset_has_reproducible_randomness(cropping_dataset: CroppingDataset,
                                                      num_dataload_workers: int) -> None:
    """
    Checks that resetting the random seed before creating a data loader leads to the same crop
    center indices in every batch, across 3 repeated passes over the dataset.
    """
    cropping_dataset.dataset_indices = [1, 2] * 2
    reference_centers = None
    for _ in range(3):
        # Re-seed before every pass, so that each loader starts from the same random state.
        ml_util.set_random_seed(1)
        data_loader = cropping_dataset.as_data_loader(shuffle=True,
                                                      batch_size=4,
                                                      num_dataload_workers=num_dataload_workers)
        for batch in data_loader:
            cropped = CroppedSample.from_dict(sample=batch)
            if reference_centers is not None:
                assert np.array_equal(reference_centers, cropped.center_indices)
            else:
                # The very first batch that is seen defines the reference for all later batches.
                reference_centers = cropped.center_indices
def test_cropping_dataset_as_data_loader(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Iterates over a shuffled data loader for the cropping dataset and checks for every batch: the shapes
    of image, mask and labels, the shapes of the mask and label center crops, and that each center crop
    is equal to what `image_util.get_center_crop` extracts from the corresponding full crop.
    """
    batch_size = 2
    data_loader = cropping_dataset.as_data_loader(shuffle=True,
                                                  batch_size=batch_size,
                                                  num_dataload_workers=num_dataload_workers)
    args = cropping_dataset.args
    for batch in data_loader:
        cropped = CroppedSample.from_dict(sample=batch)
        assert cropped is not None
        # Leading dimensions for tensors without and with a per-class dimension.
        per_sample = (batch_size,)
        per_class = (batch_size, args.number_of_classes)
        assert cropped.image.shape == (batch_size, args.number_of_image_channels) + args.crop_size  # type: ignore
        assert cropped.mask.shape == per_sample + args.crop_size  # type: ignore
        assert cropped.labels.shape == per_class + args.crop_size  # type: ignore

        # check the mask center crops are as expected
        assert cropped.mask_center_crop.shape == per_sample + args.center_size  # type: ignore
        assert cropped.labels_center_crop.shape == per_class + args.center_size  # type: ignore

        # check the contents of the center crops
        for b in range(batch_size):
            mask_center = image_util.get_center_crop(image=cropped.mask[b], crop_shape=args.center_size)
            assert np.array_equal(cropped.mask_center_crop[b], mask_center)
            label_crops = cropped.labels_center_crop[b]
            for c in range(len(label_crops)):
                label_center = image_util.get_center_crop(image=cropped.labels[b][c],
                                                          crop_shape=args.center_size)
                assert np.array_equal(label_crops[c], label_center)
def test_cropped_sample(use_mask: bool) -> None:
    """
    Checks where random crops are centered when sampling from a small 4x4x4 image.

    :param use_mask: If True, the mask contains a single voxel that coincides with a foreground label, and
        every crop must be centered on that voxel. If False, the mask covers the whole image, and the crop
        center only needs to be a voxel that carries the foreground label.
    """
    ml_util.set_random_seed(1)
    image_size = [4] * 3
    crop_size = (2, 2, 2)
    center_size = (1, 1, 1)

    # create small image sample for random cropping
    image = np.random.uniform(size=[1] + image_size)
    labels = np.zeros(shape=[2] + image_size)
    # Two foreground points in the corners at (0, 0, 0) and (3, 3, 3). Channel 0 is the complement
    # of the foreground channel 1.
    labels[1, 0, 0, 0] = 1
    labels[1, 3, 3, 3] = 1
    labels[0] = 1 - labels[1]

    expected_center: Optional[List[int]] = None
    crop_slicer: Optional[slice] = None
    if use_mask:
        # If mask is used, the cropping center point should be inside the mask.
        # Create a mask that has exactly 1 point of overlap with the labels,
        # that point must then be the center
        mask = np.zeros(shape=image_size, dtype=ImageDataType.MASK.value)
        mask[3, 3, 3] = 1
        expected_center = [3, 3, 3]
        crop_slicer = slice(2, 4)
    else:
        mask = np.ones(shape=image_size, dtype=ImageDataType.MASK.value)

    sample = Sample(image=image, labels=labels, mask=mask, metadata=DummyPatientMetadata)
    for _ in range(100):
        cropped_sample = CroppingDataset.create_random_cropped_sample(sample=sample,
                                                                      crop_size=crop_size,
                                                                      center_size=center_size,
                                                                      class_weights=[0, 1])
        if expected_center is None:
            # The crop center point must be any point that has a positive foreground label
            center = cropped_sample.center_indices
            print("Center point chosen: {}".format(center))
            assert labels[1, center[0], center[1], center[2]] != 0
            continue

        assert list(cropped_sample.center_indices) == expected_center  # type: ignore
        # The same slicer applies to all 3 spatial dimensions; image and labels have a leading channel dimension.
        spatial = (crop_slicer,) * 3
        with_channels = (slice(None),) + spatial
        assert np.array_equal(cropped_sample.image, sample.image[with_channels])
        assert np.array_equal(cropped_sample.labels, sample.labels[with_channels])
        assert np.array_equal(cropped_sample.mask, sample.mask[spatial])
def test_cropping_dataset_padding(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Checks that, with a crop size of (300, 300, 300) and zero padding enabled, the images that come out of
    the data loader have the configured crop size in their last 3 dimensions.
    (Presumably the crop size is chosen to exceed the size of the test images, so that padding is
    required - confirm against the test data.)
    """
    cropping_dataset.args.crop_size = (300, 300, 300)
    cropping_dataset.args.padding_mode = PaddingMode.Zero
    # NOTE(review): the loader always uses 1 worker here, the num_dataload_workers argument is not used.
    # Confirm whether that is intentional before changing it.
    data_loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=2, num_dataload_workers=1)
    for batch in data_loader:
        sample = CroppedSample.from_dict(batch)
        assert sample.image.shape[-3:] == cropping_dataset.args.crop_size
def test_create_possibly_padded_sample_for_cropping(crop_size: Any) -> None:
    """
    Checks that `CroppingDataset.create_possibly_padded_sample_for_cropping`, applied with zero padding to a
    sample with 4x4x4 image, labels and mask, returns a sample where the last 3 dimensions of image,
    labels and mask are all equal to the requested crop size.

    :param crop_size: The crop size to request, and to compare the resulting shapes against.
    """
    image_size = [4] * 3
    image = np.random.uniform(size=[1] + image_size)
    labels = np.zeros(shape=[2] + image_size)
    mask = np.zeros(shape=image_size, dtype=ImageDataType.MASK.value)
    input_sample = Sample(image=image, labels=labels, mask=mask, metadata=DummyPatientMetadata)
    padded_sample = CroppingDataset.create_possibly_padded_sample_for_cropping(sample=input_sample,
                                                                               crop_size=crop_size,
                                                                               padding_mode=PaddingMode.Zero)
    for field_name in ("image", "labels", "mask"):
        assert getattr(padded_sample, field_name).shape[-3:] == crop_size
def test_cropping_dataset_sample_dtype(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Tests the data type of the torch tensors (image, labels, mask, and the center crops of mask and labels)
    that the data loader for the cropping dataset creates.
    """
    data_loader = cropping_dataset.as_data_loader(shuffle=True,
                                                  batch_size=2,
                                                  num_dataload_workers=num_dataload_workers)
    # Each field of the cropped sample, together with the data type that it is expected to have.
    expected_types = [
        ("image", ImageDataType.IMAGE),
        ("labels", ImageDataType.SEGMENTATION),
        ("mask", ImageDataType.MASK),
        ("mask_center_crop", ImageDataType.MASK),
        ("labels_center_crop", ImageDataType.SEGMENTATION),
    ]
    for batch in data_loader:
        cropped = CroppedSample.from_dict(batch)
        for field_name, data_type in expected_types:
            assert getattr(cropped, field_name).numpy().dtype == data_type.value
def cropping_dataset(default_config: SegmentationModelBase, normalize_fn: Callable) -> CroppingDataset:
    """
    Creates a cropping dataset over the training split of the given model configuration.

    :param default_config: The model configuration that supplies the dataset splits, and that is passed
        to the dataset as its arguments.
    :param normalize_fn: Not used in the body of this function (presumably requested only as a test
        fixture dependency - confirm).
    :return: A CroppingDataset for the training split.
    """
    train_split = default_config.get_dataset_splits().train
    return CroppingDataset(args=default_config, data_frame=train_split)
def visualize_random_crops(sample: Sample,
                           config: SegmentationModelBase,
                           output_folder: Path) -> np.ndarray:
    """
    Simulate the effect of sampling random crops (as is done for training segmentation models), and store the
    results as a Nifti heatmap and as 3 axial/sagittal/coronal slices. The heatmap and the slices are stored
    in the given output folder, with filenames that contain the patient ID as the prefix.
    :param sample: The patient information from the dataset, with scans and ground truth labels.
    :param config: The model configuration.
    :param output_folder: The folder into which the heatmap and thumbnails should be written.
    :return: A numpy array that has the same size as the (possibly padded) image, containing how often each
        voxel was contained in a random crop.
    """
    output_folder.mkdir(exist_ok=True, parents=True)
    sample = CroppingDataset.create_possibly_padded_sample_for_cropping(
        sample=sample,
        crop_size=config.crop_size,
        padding_mode=config.padding_mode)
    logging.info(f"Processing sample: {sample.patient_id}")
    # Exhaustively sample with random crop function
    image_channel0 = sample.image[0]
    heatmap = np.zeros(image_channel0.shape, dtype=np.uint16)
    # Number of repeats should fit into the range of UInt16, because we will later save the heatmap as an integer
    # Nifti file of that datatype.
    repeats = 200
    for _ in range(repeats):
        slicers, _ = augmentation.slicers_for_random_crop(sample=sample,
                                                          crop_size=config.crop_size,
                                                          class_weights=config.class_weights)
        heatmap[slicers[0], slicers[1], slicers[2]] += 1
    is_3dim = heatmap.shape[0] > 1
    header = sample.metadata.image_header
    if not header:
        logging.warning(f"No image header found for patient {sample.patient_id}. Using default header.")
        header = get_unit_image_header()
    if is_3dim:
        # Nifti files are only written for images that have more than one slice in the first dimension.
        ct_output_name = str(output_folder / f"{sample.patient_id}_ct.nii.gz")
        heatmap_output_name = str(output_folder / f"{sample.patient_id}_sampled_patches.nii.gz")
        io_util.store_as_nifti(image=heatmap,
                               header=header,
                               file_name=heatmap_output_name,
                               image_type=heatmap.dtype,
                               scale=False)
        io_util.store_as_nifti(image=image_channel0,
                               header=header,
                               file_name=ct_output_name,
                               image_type=sample.image.dtype,
                               scale=False)
    # Use the builtin float here: the np.float alias was deprecated in NumPy 1.20 and removed in NumPy 1.24,
    # where it raises an AttributeError. Both denote the same 64bit floating point dtype.
    heatmap_scaled = heatmap.astype(dtype=float) / heatmap.max()
    # If the incoming image is effectively a 2D image with degenerate Z dimension, then only plot a single
    # axial thumbnail. Otherwise, plot thumbnails for all 3 dimensions.
    dimensions = list(range(3)) if is_3dim else [0]
    # Center the 3 thumbnails at one of the points where the heatmap attains a maximum. This should ensure that
    # the thumbnails are in an area where many of the organs of interest are located.
    max_heatmap_index = np.unravel_index(heatmap.argmax(), heatmap.shape) if is_3dim else (0, 0, 0)
    for dimension in dimensions:
        plt.clf()
        scan_with_transparent_overlay(scan=image_channel0,
                                      overlay=heatmap_scaled,
                                      dimension=dimension,
                                      position=max_heatmap_index[dimension] if is_3dim else 0,
                                      spacing=header.spacing)
        # Construct a filename that has a dimension suffix if we are generating 3 of them. For 2dim images, skip
        # the suffix.
        thumbnail = f"{sample.patient_id}_sampled_patches"
        if is_3dim:
            thumbnail += f"_dim{dimension}"
        thumbnail += ".png"
        resize_and_save(width_inch=5, height_inch=5, filename=output_folder / thumbnail)
    return heatmap
def main(args: CheckPatchSamplingConfig) -> None:
    """
    Draws random crops from the first `args.number_samples` images of the training split of a model, counts
    per voxel how many of the crops covered it, and writes that heatmap and the first image channel as
    Nifti files to the output folder.

    :param args: Provides the model name, the local dataset location, the output folder, the number of
        samples to process and the number of random crops to draw per sample.
    :raises ValueError: If a sample has no image header.
    """
    # Identify paths to inputs and outputs
    output_folder = Path(args.output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)
    overrides = {
        "train_batch_size": 1,
        "local_dataset": Path(args.local_dataset)
    }

    # Create a config file
    config_loader = ModelConfigLoader[SegmentationModelBase]()
    config = config_loader.create_model_config_from_name(args.model_name, overrides=overrides)

    # Set a random seed
    ml_util.set_random_seed(config.random_seed)

    # Get a dataloader object that checks csv
    train_split = config.get_dataset_splits().train

    # Load a sample using the full image data loader
    full_image_dataset = FullImageDataset(config, train_split)

    for sample_index in range(args.number_samples):
        full_sample = full_image_dataset.get_samples_at_index(index=sample_index)[0]
        sample = CroppingDataset.create_possibly_padded_sample_for_cropping(sample=full_sample,
                                                                            crop_size=config.crop_size,
                                                                            padding_mode=config.padding_mode)
        print("Processing sample: ", sample.patient_id)

        # Exhaustively sample with random crop function
        heatmap = np.zeros(sample.mask.shape, dtype=np.uint16)
        for _ in range(args.number_crop_iterations):
            # Only the center of the crop is needed here, the cropped sample itself is discarded.
            _cropped, center_point = augmentation.random_crop(sample=sample,
                                                              crop_size=config.crop_size,
                                                              class_weights=config.class_weights)
            heatmap += create_mask_for_patch(output_shape=heatmap.shape,
                                             output_dtype=heatmap.dtype,
                                             center=center_point,
                                             crop_size=config.crop_size)

        # The patient ID is converted to an integer for use in the file names.
        patient_number = int(sample.patient_id)
        ct_output_name = str(output_folder / "{}_ct.nii.gz".format(patient_number))
        heatmap_output_name = str(output_folder / "{}_sampled_patches.nii.gz".format(patient_number))
        header = sample.metadata.image_header
        if not header:
            raise ValueError("None header expected some header")
        io_util.store_as_nifti(image=heatmap,
                               header=header,
                               file_name=heatmap_output_name,
                               image_type=heatmap.dtype,
                               scale=False)
        io_util.store_as_nifti(image=sample.image[0],
                               header=header,
                               file_name=ct_output_name,
                               image_type=sample.image.dtype,
                               scale=False)