def test_2d(self):
    """Crop a volume with a singleton last axis to a multiple of 4.

    The first spatial axis is expected to shrink from 10 to 8, while the
    axes of size 20 and 1 are expected to be left as they are.
    """
    flat_volume = self.sample_subject.t1.data[..., :1]
    assert flat_volume.shape == (1, 10, 20, 1)

    crop_to_multiple = tio.EnsureShapeMultiple(4, method='crop')
    result = crop_to_multiple(flat_volume)
    assert result.shape == (1, 8, 20, 1)
def get_dataset(
    input_path,
    tta_iterations=0,
    interpolation='bspline',
    tolerance=0.1,
    mni_transform_path=None,
):
    """Build the dataset used to run inference on a single image.

    The image is always reoriented with ``ToCanonical``, histogram-standardized
    with fixed landmarks, Z-normalized and finally cropped so that its shape is
    a multiple of 8. A ``Resample`` step is inserted before the crop when the
    voxel size differs from 1 by more than ``tolerance`` along any axis, or
    when an MNI transform is supplied.

    Args:
        input_path: Path of the image; read both by ``tio.ScalarImage`` and by
            ``nib.load`` (for the voxel sizes).
        tta_iterations: Number of extra, randomly augmented copies of the
            subject to append (presumably test-time augmentation).
        interpolation: Interpolation name forwarded to ``Resample`` and
            ``RandomAffine`` as ``image_interpolation``.
        tolerance: Largest accepted absolute deviation of a voxel size from 1
            before resampling is triggered.
        mni_transform_path: Optional path of a matrix read with
            ``tio.io.read_matrix``. When given, the matrix is attached to the
            image under ``TO_MNI`` and the image is resampled onto the
            Colin27 T1 image using that matrix as pre-affine.

    Returns:
        A ``tio.SubjectsDataset`` with the single preprocessed subject when no
        augmented copies are requested; otherwise a
        ``torch.utils.data.ConcatDataset`` of that dataset followed by the
        augmented one.
    """
    use_mni = mni_transform_path is not None

    if use_mni:
        to_mni_matrix = tio.io.read_matrix(mni_transform_path)
        image = tio.ScalarImage(input_path, **{TO_MNI: to_mni_matrix})
    else:
        image = tio.ScalarImage(input_path)
    subject = tio.Subject({IMAGE_NAME: image})

    landmarks = np.array([
        0., 0.31331614, 0.61505419, 0.76732501, 0.98887953, 1.71169384,
        3.21741126, 13.06931455, 32.70817796, 40.87807389, 47.83508873,
        63.4408591, 100.
    ])
    pipeline = [
        tio.ToCanonical(),
        tio.HistogramStandardization({IMAGE_NAME: landmarks}),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]

    # Resample only when the spacing is not close enough to 1 (per axis),
    # or when the image has to be brought onto the MNI reference.
    spacing = np.array(nib.load(input_path).header.get_zooms())
    off_isotropic = np.any(np.abs(spacing - 1) > tolerance)
    if use_mni or off_isotropic:
        resample_kwargs = {'image_interpolation': interpolation}
        if use_mni:
            resample_kwargs['pre_affine_name'] = TO_MNI
            resample_kwargs['target'] = tio.datasets.Colin27().t1.path
        pipeline.append(tio.Resample(**resample_kwargs))

    pipeline.append(tio.EnsureShapeMultiple(8, method='crop'))
    preprocess = tio.Compose(pipeline)
    plain_dataset = tio.SubjectsDataset([subject], transform=preprocess)

    repeated_subjects = tta_iterations * [subject]
    if not repeated_subjects:
        return plain_dataset

    # Augmented copies go through the same preprocessing first.
    augment = tio.Compose((
        preprocess,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    tta_dataset = tio.SubjectsDataset(repeated_subjects, transform=augment)
    return torch.utils.data.ConcatDataset((plain_dataset, tta_dataset))
def get_transform(self, channels, is_3d=True, labels=True):
    """Compose a long chain of TorchIO transforms for the tests.

    Args:
        channels: Image names. Each one gets 13 evenly spaced landmarks in
            [0, 100] for ``HistogramStandardization``; the first one is the
            target of ``CopyAffine``.
        is_3d: Selects the 3D or the 2D variant of the shape-dependent
            arguments (displacement, crop/pad sizes, axes, patch size).
        labels: When true, ``RandomLabelsToImage(label_key='label')`` is
            appended at the end of the chain.

    Returns:
        A ``tio.Compose`` wrapping all the transforms.
    """
    landmarks_dict = {name: np.linspace(0, 100, 13) for name in channels}

    # Shape-dependent arguments, grouped by dimensionality.
    if is_3d:
        disp = 1
        cp_args = (9, 21, 30)
        resize_args = (10, 20, 30)
        flip_axes = axes_downsample = (0, 1, 2)
        swap_patch = (2, 3, 4)
        pad_args = (1, 2, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4)
    else:
        disp = (1, 1, 0.01)
        cp_args = (21, 30, 1)
        resize_args = (10, 20, 1)
        flip_axes = axes_downsample = (0, 1)
        swap_patch = (3, 4, 1)
        pad_args = (0, 0, 3, 0, 5, 6)
        crop_args = (0, 0, 8, 0, 1, 4)

    # The same elastic instance is used on its own and inside OneOf below.
    elastic = tio.RandomElasticDeformation(max_displacement=disp)
    remapping = {1: 2, 2: 1, 3: 20, 4: 25}

    pipeline = [
        tio.CropOrPad(cp_args),
        tio.EnsureShapeMultiple(2, method='crop'),
        tio.Resize(resize_args),
        tio.ToCanonical(),
        tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
        tio.CopyAffine(channels[0]),
        tio.Resample((1, 1.1, 1.25)),
        tio.RandomFlip(axes=flip_axes, flip_probability=1),
        tio.RandomMotion(),
        tio.RandomGhosting(axes=(0, 1, 2)),
        tio.RandomSpike(),
        tio.RandomNoise(),
        tio.RandomBlur(),
        tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
        tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
        tio.RandomBiasField(),
        tio.RescaleIntensity(out_min_max=(0, 1)),
        tio.ZNormalization(),
        tio.HistogramStandardization(landmarks_dict),
        elastic,
        tio.RandomAffine(),
        tio.OneOf({
            tio.RandomAffine(): 3,
            elastic: 1,
        }),
        tio.RemapLabels(remapping=remapping, masking_method='Left'),
        tio.RemoveLabels([1, 3]),
        tio.SequentialLabels(),
        tio.Pad(pad_args, padding_mode=3),
        tio.Crop(crop_args),
    ]
    if labels:
        pipeline.append(tio.RandomLabelsToImage(label_key='label'))
    return tio.Compose(pipeline)
def test_pad(self):
    """Pad the sample T1 image so every spatial size is a multiple of 4.

    Spatial shape (10, 20, 30) is expected to become (12, 20, 32).
    """
    image = self.sample_subject.t1
    assert image.shape == (1, 10, 20, 30)

    pad_to_multiple = tio.EnsureShapeMultiple(4, method='pad')
    padded = pad_to_multiple(image)
    assert padded.shape == (1, 12, 20, 32)
def test_bad_method(self):
    """Constructing the transform with an unknown ``method`` raises ``ValueError``."""
    self.assertRaises(ValueError, tio.EnsureShapeMultiple, 1, method='bad')
def predict_whole_image(self, image_channels: np.ndarray,
                        voxel_spacing_mm: TupleFloat3,
                        mask: Optional[np.ndarray] = None,
                        patient_id: int = 0) -> InferencePipeline.Result:
    """
    Performs a single inference pass through the pipeline for the provided image.

    The image is split into patches on a grid, the model is run on each batch of patches, and the
    patch posteriors are aggregated back into a whole-image result, which is then post-processed.

    :param image_channels: The input image channels to perform inference on in format: Channels x Z x Y x X.
    :param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
    :param mask: A binary image used to ignore results outside it in format: Z x Y x X.
    :param patient_id: The identifier of the patient this image belongs to (defaults to 0 if None provided).
    :return InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
    :raises Exception: if image_channels is None.
    :raises NotImplementedError: if image_channels does not have exactly 4 dimensions.
    """
    if image_channels is None:
        raise Exception("image_channels cannot be None")
    if image_channels.ndim != 4:
        # The two literals are concatenated by the parser, so the separator must be explicit.
        raise NotImplementedError(
            "image_channels must be in shape: Channels x Z x Y x X, "
            "found image_channels shape: {}".format(image_channels.shape))
    if mask is not None:
        ml_util.check_size_matches(image_channels, mask, 4, 3, [-1, -2, -3])
    self.model.eval()

    image = tio.ScalarImage(tensor=image_channels)
    INPUT = 'input_image'
    MASK = 'mask'

    subject_dict: Dict[str, tio.Image] = {INPUT: image}
    if mask is not None:
        # Add a leading channel axis so the mask is 4D like the image.
        subject_dict[MASK] = tio.LabelMap(tensor=mask[np.newaxis])
    subject = tio.Subject(subject_dict)

    constraints = self.model.model.crop_size_constraints

    # Make sure the image size is compatible with the model
    multiple_constraints = constraints.multiple_of  # type: ignore
    if multiple_constraints is not None:
        ensure_shape_multiple = tio.EnsureShapeMultiple(multiple_constraints)  # type: ignore
        subject = ensure_shape_multiple(subject)  # type: ignore

    # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
    # to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
    restrict_patch_size = constraints.restrict_crop_size_to_image  # type: ignore
    effective_patch_size, effective_stride = restrict_patch_size(
        subject.spatial_shape,  # type: ignore
        self.model_config.test_crop_size,
        self.model_config.inference_stride_size)

    patch_overlap = np.array(effective_patch_size) - np.array(effective_stride)
    grid_sampler = tio.inference.GridSampler(
        subject,
        effective_patch_size,
        patch_overlap,
        padding_mode=self.model_config.padding_mode.value,
    )
    batch_size = self.model_config.inference_batch_size
    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size)  # type: ignore
    aggregator = tio.inference.GridAggregator(grid_sampler)

    logging.debug(
        f"Inference on image size {subject.spatial_shape} will run "
        f"with crop size {effective_patch_size} and stride {effective_stride}"
    )
    for patches_batch in patch_loader:
        input_tensor = patches_batch[INPUT][tio.DATA].float()
        if self.model_config.use_gpu:
            input_tensor = input_tensor.cuda()
        locations = patches_batch[tio.LOCATION]

        # perform the forward pass; no autograd graph is needed for inference
        with torch.no_grad():
            patches_posteriors = self.model(input_tensor).detach()

        # pad posteriors if they are smaller than the input
        input_shape = input_tensor.shape[-3:]
        patches_posteriors_shape = patches_posteriors.shape[-3:]
        if input_shape != patches_posteriors_shape:
            difference = np.array(input_shape) - np.array(patches_posteriors_shape)
            assert not np.any(difference % 2)  # the differences in shape are expected to be even
            # torch.nn.functional.pad reads (before, after) pairs starting from the LAST dimension,
            # so the per-axis amounts (ordered Z, Y, X here) must be reversed before being repeated.
            half_difference = (difference // 2)[::-1]
            padding = tuple(int(amount) for amount in np.repeat(half_difference, 2))
            patches_posteriors = torch.nn.functional.pad(patches_posteriors, padding)

        # collect the predictions over each of the batches
        aggregator.add_batch(patches_posteriors, locations)

    posteriors = aggregator.get_output_tensor().numpy()
    posteriors_mask = None if mask is None else subject[MASK].numpy()[0]
    posteriors, segmentation = self.post_process_posteriors(posteriors, mask=posteriors_mask)

    image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")

    # Make sure the final shape matches the input shape by undoing the padding in EnsureShapeMultiple (if any)
    posteriors_image = tio.ScalarImage(tensor=posteriors, affine=image.affine)
    segmentation_image = tio.LabelMap(tensor=segmentation[np.newaxis], affine=image.affine)
    subject.add_image(posteriors_image, 'posteriors')
    subject.add_image(segmentation_image, 'segmentation')
    # Remove some images to avoid unnecessary computations
    subject.remove_image(INPUT)
    if mask is not None:
        subject.remove_image(MASK)
    subject_original_space = subject.apply_inverse_transform() if subject.applied_transforms else subject
    posteriors = subject_original_space.posteriors.numpy()  # type: ignore
    segmentation = subject_original_space.segmentation.numpy()[0]  # type: ignore

    # prepare pipeline results from the processed batch
    return InferencePipeline.Result(patient_id=patient_id,
                                    segmentation=segmentation,
                                    posteriors=posteriors,
                                    voxel_spacing_mm=voxel_spacing_mm)