Example #1
 def test_2d(self):
     sample_t1 = self.sample_subject.t1
     sample_2d = sample_t1.data[..., :1]
     assert sample_2d.shape == (1, 10, 20, 1)
     transform = tio.EnsureShapeMultiple(4, method='crop')
     transformed = transform(sample_2d)
     assert transformed.shape == (1, 8, 20, 1)
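
For reference, the same crop behaviour outside a test class: a minimal sketch assuming only torch and torchio, with a hypothetical random tensor standing in for the sample image.

import torch
import torchio as tio

tensor_2d = torch.rand(1, 10, 20, 1)  # one channel, a single slice in the last dimension
crop = tio.EnsureShapeMultiple(4, method='crop')
print(crop(tensor_2d).shape)  # torch.Size([1, 8, 20, 1]); singleton dimensions are kept
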
Example #2
def get_dataset(
    input_path,
    tta_iterations=0,
    interpolation='bspline',
    tolerance=0.1,
    mni_transform_path=None,
):
    # Attach the affine to the image when registration to MNI space is requested
    if mni_transform_path is None:
        image = tio.ScalarImage(input_path)
    else:
        affine = tio.io.read_matrix(mni_transform_path)
        image = tio.ScalarImage(input_path, **{TO_MNI: affine})
    subject = tio.Subject({IMAGE_NAME: image})
    landmarks = np.array([
        0., 0.31331614, 0.61505419, 0.76732501, 0.98887953, 1.71169384,
        3.21741126, 13.06931455, 32.70817796, 40.87807389, 47.83508873,
        63.4408591, 100.
    ])
    hist_std = tio.HistogramStandardization({IMAGE_NAME: landmarks})
    preprocess_transforms = [
        tio.ToCanonical(),
        hist_std,
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]
    # Resample if the spacing deviates from 1 mm isotropic by more than
    # `tolerance`, or if the image must be registered to MNI space
    zooms = nib.load(input_path).header.get_zooms()
    pixdim = np.array(zooms)
    diff_to_1_iso = np.abs(pixdim - 1)
    if np.any(diff_to_1_iso > tolerance) or mni_transform_path is not None:
        kwargs = {'image_interpolation': interpolation}
        if mni_transform_path is not None:
            kwargs['pre_affine_name'] = TO_MNI
            kwargs['target'] = tio.datasets.Colin27().t1.path
        resample_transform = tio.Resample(**kwargs)
        preprocess_transforms.append(resample_transform)
    preprocess_transforms.append(tio.EnsureShapeMultiple(8, method='crop'))
    preprocess_transform = tio.Compose(preprocess_transforms)
    no_aug_dataset = tio.SubjectsDataset([subject],
                                         transform=preprocess_transform)

    # Test-time augmentation: repeat the subject once per requested iteration
    aug_subjects = tta_iterations * [subject]
    if not aug_subjects:
        return no_aug_dataset
    augment_transform = tio.Compose((
        preprocess_transform,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    aug_dataset = tio.SubjectsDataset(aug_subjects,
                                      transform=augment_transform)
    dataset = torch.utils.data.ConcatDataset((no_aug_dataset, aug_dataset))
    return dataset
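
A hedged usage sketch for the function above; the file name is hypothetical, and IMAGE_NAME is the module-level constant used inside get_dataset.

dataset = get_dataset('t1.nii.gz', tta_iterations=2)
subject = dataset[0]  # first, non-augmented subject
print(subject[IMAGE_NAME].spatial_shape)  # every spatial dimension is a multiple of 8
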
Example #3
 def get_transform(self, channels, is_3d=True, labels=True):
     landmarks_dict = {
         channel: np.linspace(0, 100, 13)
         for channel in channels
     }
     disp = 1 if is_3d else (1, 1, 0.01)
     elastic = tio.RandomElasticDeformation(max_displacement=disp)
     cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
     resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
     flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
     swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
     pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
     crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
     remapping = {1: 2, 2: 1, 3: 20, 4: 25}
     transforms = [
         tio.CropOrPad(cp_args),
         tio.EnsureShapeMultiple(2, method='crop'),
         tio.Resize(resize_args),
         tio.ToCanonical(),
         tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
         tio.CopyAffine(channels[0]),
         tio.Resample((1, 1.1, 1.25)),
         tio.RandomFlip(axes=flip_axes, flip_probability=1),
         tio.RandomMotion(),
         tio.RandomGhosting(axes=(0, 1, 2)),
         tio.RandomSpike(),
         tio.RandomNoise(),
         tio.RandomBlur(),
         tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
         tio.RandomBiasField(),
         tio.RescaleIntensity(out_min_max=(0, 1)),
         tio.ZNormalization(),
         tio.HistogramStandardization(landmarks_dict),
         elastic,
         tio.RandomAffine(),
         tio.OneOf({
             tio.RandomAffine(): 3,
             elastic: 1,
         }),
         tio.RemapLabels(remapping=remapping, masking_method='Left'),
         tio.RemoveLabels([1, 3]),
         tio.SequentialLabels(),
         tio.Pad(pad_args, padding_mode=3),
         tio.Crop(crop_args),
     ]
     if labels:
         transforms.append(tio.RandomLabelsToImage(label_key='label'))
     return tio.Compose(transforms)
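
Note that the weights passed to tio.OneOf above are relative and normalized internally; a minimal sketch:

import torchio as tio

elastic = tio.RandomElasticDeformation()
# RandomAffine is drawn with probability 3/4, the elastic deformation with 1/4
one_of = tio.OneOf({tio.RandomAffine(): 3, elastic: 1})
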
Example #4
 def test_pad(self):
     sample_t1 = self.sample_subject.t1
     assert sample_t1.shape == (1, 10, 20, 30)
     transform = tio.EnsureShapeMultiple(4, method='pad')
     transformed = transform(sample_t1)
     assert transformed.shape == (1, 12, 20, 32)
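
The padding counterpart of the crop example above; a minimal sketch with a hypothetical random tensor:

import torch
import torchio as tio

tensor = torch.rand(1, 10, 20, 30)
pad = tio.EnsureShapeMultiple(4, method='pad')
print(pad(tensor).shape)  # torch.Size([1, 12, 20, 32]); 20 is already a multiple of 4
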
Example #5
 def test_bad_method(self):
     with self.assertRaises(ValueError):
         tio.EnsureShapeMultiple(1, method='bad')
Example #6
    def predict_whole_image(self,
                            image_channels: np.ndarray,
                            voxel_spacing_mm: TupleFloat3,
                            mask: Optional[np.ndarray] = None,
                            patient_id: int = 0) -> InferencePipeline.Result:
        """
        Performs a single inference pass through the pipeline for the provided image
        :param image_channels: The input image channels to perform inference on in format: Channels x Z x Y x X.
        :param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
        :param mask: A binary image used to ignore results outside it in format: Z x Y x X.
        :param patient_id: The identifier of the patient this image belongs to (defaults to 0 if None provided).
        :return InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
        """
        if image_channels is None:
            raise ValueError("image_channels cannot be None")
        if image_channels.ndim != 4:
            raise NotImplementedError(
                "image_channels must be in shape: Channels x Z x Y x X; "
                "found image_channels shape: {}".format(image_channels.shape))
        if mask is not None:
            ml_util.check_size_matches(image_channels, mask, 4, 3,
                                       [-1, -2, -3])
        self.model.eval()

        image = tio.ScalarImage(tensor=image_channels)
        INPUT = 'input_image'
        MASK = 'mask'

        subject_dict: Dict[str, tio.Image] = {INPUT: image}
        if mask is not None:
            subject_dict[MASK] = tio.LabelMap(tensor=mask[np.newaxis])
        subject = tio.Subject(subject_dict)

        constraints = self.model.model.crop_size_constraints

        # Make sure the image size is compatible with the model
        multiple_constraints = constraints.multiple_of  # type: ignore
        if multiple_constraints is not None:
            ensure_shape_multiple = tio.EnsureShapeMultiple(
                multiple_constraints)  # type: ignore
            subject = ensure_shape_multiple(subject)  # type: ignore

        # There may be cases where the test image is smaller than the test_crop_size. Adjust the crop
        # size so it always fits into the image; if test_crop_size already fits, it remains unchanged.
        restrict_patch_size = constraints.restrict_crop_size_to_image  # type: ignore
        effective_patch_size, effective_stride = restrict_patch_size(
            subject.spatial_shape,  # type: ignore
            self.model_config.test_crop_size,
            self.model_config.inference_stride_size)

        patch_overlap = np.array(effective_patch_size) - np.array(
            effective_stride)
        grid_sampler = tio.inference.GridSampler(
            subject,
            effective_patch_size,
            patch_overlap,
            padding_mode=self.model_config.padding_mode.value,
        )
        batch_size = self.model_config.inference_batch_size
        patch_loader = torch.utils.data.DataLoader(
            grid_sampler, batch_size=batch_size)  # type: ignore
        aggregator = tio.inference.GridAggregator(grid_sampler)

        logging.debug(
            f"Inference on image size {subject.spatial_shape} will run "
            f"with crop size {effective_patch_size} and stride {effective_stride}"
        )
        for patches_batch in patch_loader:
            input_tensor = patches_batch[INPUT][tio.DATA].float()
            if self.model_config.use_gpu:
                input_tensor = input_tensor.cuda()
            locations = patches_batch[tio.LOCATION]
            # perform the forward pass
            patches_posteriors = self.model(input_tensor).detach()
            # pad posteriors if they are smaller than the input
            input_shape = input_tensor.shape[-3:]
            patches_posteriors_shape = patches_posteriors.shape[-3:]
            if input_shape != patches_posteriors_shape:
                difference = np.array(input_shape) - np.array(
                    patches_posteriors_shape)
                # The differences in shape are expected to be even
                assert not np.any(difference % 2)
                padding = tuple(np.repeat(difference // 2, 2))
                patches_posteriors = torch.nn.functional.pad(
                    patches_posteriors, padding)
            # collect the predictions over each of the batches
            aggregator.add_batch(patches_posteriors, locations)
        posteriors = aggregator.get_output_tensor().numpy()
        posteriors_mask = None if mask is None else subject[MASK].numpy()[0]
        posteriors, segmentation = self.post_process_posteriors(
            posteriors, mask=posteriors_mask)

        image_util.check_array_range(posteriors,
                                     error_prefix="Whole image posteriors")

        # Make sure the final shape matches the input shape by undoing the padding in EnsureShapeMultiple (if any)
        posteriors_image = tio.ScalarImage(tensor=posteriors,
                                           affine=image.affine)
        segmentation_image = tio.LabelMap(tensor=segmentation[np.newaxis],
                                          affine=image.affine)
        subject.add_image(posteriors_image, 'posteriors')
        subject.add_image(segmentation_image, 'segmentation')
        # Remove some images to avoid unnecessary computations
        subject.remove_image(INPUT)
        if mask is not None:
            subject.remove_image(MASK)
        subject_original_space = (subject.apply_inverse_transform()
                                  if subject.applied_transforms else subject)
        posteriors = subject_original_space.posteriors.numpy()  # type: ignore
        segmentation = subject_original_space.segmentation.numpy()[0]  # type: ignore

        # prepare pipeline results from the processed batch
        return InferencePipeline.Result(patient_id=patient_id,
                                        segmentation=segmentation,
                                        posteriors=posteriors,
                                        voxel_spacing_mm=voxel_spacing_mm)
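
The patch-based loop above follows torchio's standard GridSampler/GridAggregator pattern. A minimal, self-contained sketch of that pattern with a stand-in identity model:

import torch
import torchio as tio

subject = tio.Subject(input_image=tio.ScalarImage(tensor=torch.rand(1, 32, 32, 32)))
sampler = tio.inference.GridSampler(subject, patch_size=16, patch_overlap=4)
loader = torch.utils.data.DataLoader(sampler, batch_size=2)
aggregator = tio.inference.GridAggregator(sampler)
model = torch.nn.Identity()  # stand-in for a real segmentation network
with torch.no_grad():
    for batch in loader:
        logits = model(batch['input_image'][tio.DATA])
        aggregator.add_batch(logits, batch[tio.LOCATION])
output = aggregator.get_output_tensor()  # same spatial shape as the input image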