Example #1
    def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
        """
        Save data into a Nifti file.
        The meta_data could optionally have the following keys:

            - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object.
            - ``'original_affine'`` -- for data orientation handling, defaulting to an identity matrix.
            - ``'affine'`` -- for data output affine, defaulting to an identity matrix.
            - ``'spatial_shape'`` -- for data output shape.
            - ``'patch_index'`` -- if the data is a patch of a big image, append the patch index to the filename.

        When meta_data is specified, the saver will try to resample batch data from the space
        defined by "affine" to the space defined by "original_affine".

        If meta_data is None, use the default index (starting from 0) as the filename.

        Args:
            data: target data content to be saved as a NIfTI file.
                Assumes the data shape starts with a channel dimension followed by spatial dimensions.
            meta_data: the meta data information corresponding to the data.

        See Also
            :py:meth:`monai.data.nifti_writer.write_nifti`
        """
        filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
        self._data_index += 1
        original_affine = meta_data.get("original_affine", None) if meta_data else None
        affine = meta_data.get("affine", None) if meta_data else None
        spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None
        patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None

        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()

        path = create_file_basename(self.output_postfix, filename, self.output_dir, self.data_root_dir, patch_index)
        path = f"{path}{self.output_ext}"
        # change data shape to be (channel, h, w, d)
        while len(data.shape) < 4:
            data = np.expand_dims(data, -1)
        # change data to "channel last" format and write to nifti format file
        data = np.moveaxis(np.asarray(data), 0, -1)

        # if desired, remove trailing singleton dimensions
        if self.squeeze_end_dims:
            while data.shape[-1] == 1:
                data = np.squeeze(data, -1)

        write_nifti(
            data,
            file_name=path,
            affine=affine,
            target_affine=original_affine,
            resample=self.resample,
            output_spatial_shape=spatial_shape,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
            dtype=self.dtype,
            output_dtype=self.output_dtype,
        )
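
The shape handling above can be isolated into a few NumPy lines; a minimal sketch (independent of MONAI) of padding to four dimensions, moving the channel axis last, and squeezing trailing singleton axes:

import numpy as np

# start from a channel-first 2D prediction, e.g. (channel, H, W)
data = np.zeros((1, 64, 64), dtype=np.float32)

# pad with trailing singleton axes until the array is (channel, H, W, D)
while data.ndim < 4:
    data = np.expand_dims(data, -1)      # -> (1, 64, 64, 1)

# NIfTI expects spatial axes first, so move the channel axis to the end
data = np.moveaxis(data, 0, -1)          # -> (64, 64, 1, 1)

# optional squeeze_end_dims behaviour: drop trailing singleton dimensions
while data.ndim > 1 and data.shape[-1] == 1:
    data = np.squeeze(data, -1)          # -> (64, 64)

print(data.shape)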
Example #2
    def __call__(self, engine):
        """
        This method assumes:
            - the 3rd output of engine.state.batch is a meta data dict that has the keys:
              'filename_or_obj' -- for output file name creation,
              and optionally 'original_affine' and 'affine' for data orientation handling.
            - the output file datatype comes from `engine.state.output.dtype`.
        """
        # assuming the 3rd output of the input dataset is a meta data dict
        meta_data = engine.state.batch[2]
        filenames = meta_data['filename_or_obj']
        original_affine = meta_data.get('original_affine', None)
        affine = meta_data.get('affine', None)
        engine_output = self.output_transform(engine.state.output)
        for batch_id, filename in enumerate(filenames):  # save a batch of files
            seg_output = engine_output[batch_id]
            _affine = affine[batch_id]
            _original_affine = original_affine[batch_id]
            if isinstance(seg_output, torch.Tensor):
                seg_output = seg_output.detach().cpu().numpy()
            output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
            output_filename = '{}{}'.format(output_filename, self.output_ext)
            write_nifti(seg_output,
                        _affine,
                        output_filename,
                        _original_affine,
                        dtype=seg_output.dtype)
            print('saved: {}'.format(output_filename))
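
For orientation, the final write can be approximated without MONAI by handing the channel-last array and its affine to nibabel; a rough sketch with no resampling (the helper name save_nifti_simple is mine, not from the example):

import numpy as np
import nibabel as nib

def save_nifti_simple(data, affine, file_name):
    """Write a channel-last array to a NIfTI file with nibabel (no resampling)."""
    if affine is None:
        affine = np.eye(4)  # fall back to an identity affine, as the docstrings describe
    nib.save(nib.Nifti1Image(np.asarray(data), affine), file_name)

# example call with random data
save_nifti_simple(np.random.rand(64, 64, 32).astype(np.float32), np.eye(4), "example_seg.nii.gz")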
Example #3
    def __call__(self, engine):
        """
        This method assumes self.batch_transform will extract metadata from the input batch.
        The metadata should have the following keys:

            - ``'filename_or_obj'`` -- for output file name creation
            - ``'original_affine'`` (optional) for data orientation handling
            - ``'affine'`` (optional) for data output affine.

        output file datatype is determined from ``engine.state.output.dtype``.
        """
        meta_data = self.batch_transform(engine.state.batch)
        filenames = meta_data['filename_or_obj']
        original_affine = meta_data.get('original_affine', None)
        affine = meta_data.get('affine', None)

        engine_output = self.output_transform(engine.state.output)
        for batch_id, filename in enumerate(
                filenames):  # save a batch of files
            seg_output = engine_output[batch_id]
            affine_ = affine[batch_id]
            original_affine_ = original_affine[batch_id]
            if isinstance(seg_output, torch.Tensor):
                seg_output = seg_output.detach().cpu().numpy()
            output_filename = self._create_file_basename(
                self.output_postfix, filename, self.output_path)
            output_filename = '{}{}'.format(output_filename, self.output_ext)
            # change output to "channel last" format and write to nifti format file
            to_save = np.moveaxis(seg_output, 0, -1)
            write_nifti(to_save,
                        affine_,
                        output_filename,
                        original_affine_,
                        dtype=seg_output.dtype)
            self.logger.info('saved: {}'.format(output_filename))
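
Both handlers build one output name per batch element from 'filename_or_obj' via self._create_file_basename. A hypothetical stand-in that only illustrates the stem + postfix + extension pattern (it does not reproduce MONAI's create_file_basename, which can also create per-image subfolders):

import os

def make_output_filename(input_filename, output_path, output_postfix="seg", output_ext=".nii.gz"):
    """Build '<output_path>/<stem>_<postfix><ext>' from an input file path."""
    stem = os.path.basename(input_filename)
    # strip common NIfTI suffixes so the postfix is appended to the bare stem
    for suffix in (".nii.gz", ".nii"):
        if stem.endswith(suffix):
            stem = stem[: -len(suffix)]
            break
    return os.path.join(output_path, f"{stem}_{output_postfix}{output_ext}")

print(make_output_filename("/data/case_001.nii.gz", "./out"))  # ./out/case_001_seg.nii.gz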
Example #4
    def save(self,
             data: Union[torch.Tensor, np.ndarray],
             meta_data: Optional[Dict] = None) -> None:
        """
        Save data into a Nifti file.
        The meta_data could optionally have the following keys:

            - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object.
            - ``'original_affine'`` -- for data orientation handling, defaulting to an identity matrix.
            - ``'affine'`` -- for data output affine, defaulting to an identity matrix.
            - ``'spatial_shape'`` -- for data output shape.

        When meta_data is specified, the saver will try to resample batch data from the space
        defined by "affine" to the space defined by "original_affine".

        If meta_data is None, use the default index (starting from 0) as the filename.

        Args:
            data: target data content to be saved as a NIfTI file.
                Assumes the data shape starts with a channel dimension followed by spatial dimensions.
            meta_data: the meta data information corresponding to the data.

        See Also
            :py:meth:`monai.data.nifti_writer.write_nifti`
        """
        filename = meta_data["filename_or_obj"] if meta_data else str(
            self._data_index)
        for _ in range(self.output_name_uplevel):
            filename = os.path.dirname(filename)
        self._data_index += 1
        original_affine = meta_data.get("original_affine",
                                        None) if meta_data else None
        affine = meta_data.get("affine", None) if meta_data else None
        spatial_shape = meta_data.get("spatial_shape",
                                      None) if meta_data else None

        if torch.is_tensor(data):
            data = data.detach().cpu().numpy()

        filename = create_file_basename(self.output_postfix, filename,
                                        self.output_dir)
        filename = f"{filename}{self.output_ext}"
        # change data shape to be (channel, h, w, d)
        while len(data.shape) < 4:
            data = np.expand_dims(data, -1)
        # change data to "channel last" format and write to nifti format file
        data = np.moveaxis(data, 0, -1)
        write_nifti(
            data,
            file_name=filename,
            affine=affine,
            target_affine=original_affine,
            resample=self.resample,
            output_spatial_shape=spatial_shape,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
            dtype=self.dtype,
        )
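
Example #4 differs from Example #1 mainly in the output_name_uplevel loop, which walks the input path up a number of directory levels before the basename is derived. A quick illustration of what repeated os.path.dirname does:

import os

filename = "/data/subject_01/scans/ct.nii.gz"
for levels in range(3):
    name = filename
    for _ in range(levels):
        name = os.path.dirname(name)
    print(levels, "->", name)
# 0 -> /data/subject_01/scans/ct.nii.gz
# 1 -> /data/subject_01/scans
# 2 -> /data/subject_01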
Example #5
    def save(self, data: Union[torch.Tensor, np.ndarray], meta_data=None):
        """
        Save data into a Nifti file.
        The metadata could optionally have the following keys:

            - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object.
            - ``'original_affine'`` -- for data orientation handling, defaulting to an identity matrix.
            - ``'affine'`` -- for data output affine, defaulting to an identity matrix.
            - ``'spatial_shape'`` -- for data output shape.

        If meta_data is None, use the default index (starting from 0) as the filename instead.

        Args:
            data (Tensor or ndarray): target data content to be saved as a NIfTI file.
                Assumes the data shape starts with a channel dimension followed by spatial dimensions.
            meta_data (dict): the meta data information corresponding to the data.

        See Also
            :py:meth:`monai.data.nifti_writer.write_nifti`
        """
        filename = meta_data["filename_or_obj"] if meta_data else str(
            self._data_index)
        self._data_index += 1
        original_affine = meta_data.get("original_affine",
                                        None) if meta_data else None
        affine = meta_data.get("affine", None) if meta_data else None
        spatial_shape = meta_data.get("spatial_shape",
                                      None) if meta_data else None

        if torch.is_tensor(data):
            data = data.detach().cpu().numpy()
        filename = create_file_basename(self.output_postfix, filename,
                                        self.output_dir)
        filename = f"{filename}{self.output_ext}"
        # change data to "channel last" format and write to nifti format file
        data = np.moveaxis(data, 0, -1)
        write_nifti(
            data,
            file_name=filename,
            affine=affine,
            target_affine=original_affine,
            resample=self.resample,
            output_shape=spatial_shape,
            interp_order=self.interp_order,
            mode=self.mode,
            cval=self.cval,
            dtype=self.dtype or data.dtype,
        )
Example #6
    def _iteration(self, engine: Engine,
                   batchdata: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """
        Callback function for the Supervised Evaluation processing logic of one iteration in an Ignite Engine.
        Returns the item below in a dictionary:
            - PRED: the prediction result of the model.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually a dictionary or a tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device,
                                   engine.non_blocking)
        if len(batch) == 2:
            inputs, _ = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            inputs, _, args, kwargs = batch

        def _compute_pred():
            ct = 1.0
            pred = self.inferer(inputs, self.network, *args, **kwargs).cpu()
            pred = nn.functional.softmax(pred, dim=1)
            if not self.tta_val:
                return pred
            else:
                for dims in [[2], [3], [4], (2, 3), (2, 4), (3, 4), (2, 3, 4)]:
                    flip_inputs = torch.flip(inputs, dims=dims)
                    flip_pred = torch.flip(self.inferer(
                        flip_inputs, self.network).cpu(),
                                           dims=dims)
                    flip_pred = nn.functional.softmax(flip_pred, dim=1)
                    del flip_inputs
                    pred += flip_pred
                    del flip_pred
                    ct += 1
                return pred / ct

        # execute forward computation
        with eval_mode(self.network):
            if self.amp:
                with torch.cuda.amp.autocast():
                    predictions = _compute_pred()
            else:
                predictions = _compute_pred()

        inputs = inputs.cpu()
        predictions = self.post_pred(predictions)

        affine = batchdata["image_meta_dict"]["affine"].numpy()[0]
        resample_flag = batchdata["resample_flag"]
        anisotrophy_flag = batchdata["anisotrophy_flag"]
        crop_shape = batchdata["crop_shape"][0].tolist()
        original_shape = batchdata["original_shape"][0].tolist()

        if resample_flag:
            # convert the prediction back to the original (after cropped) shape
            predictions = recovery_prediction(predictions.numpy()[0],
                                              [self.n_classes, *crop_shape],
                                              anisotrophy_flag)
        else:
            predictions = predictions.numpy()

        predictions = predictions[0]
        predictions = np.argmax(predictions, axis=0)

        # pad the prediction back to the original shape
        predictions_org = np.zeros([*original_shape])
        box_start, box_end = batchdata["bbox"][0]
        h_start, w_start, d_start = box_start
        h_end, w_end, d_end = box_end
        predictions_org[h_start:h_end, w_start:w_end,
                        d_start:d_end] = predictions
        del predictions

        filename = batchdata["image_meta_dict"]["filename_or_obj"][0].split(
            "/")[-1]

        print("save {} with shape: {}, mean values: {}".format(
            filename, predictions_org.shape, predictions_org.mean()))
        write_nifti(
            data=predictions_org,
            file_name=os.path.join(self.output_dir, filename),
            affine=affine,
            resample=False,
            output_dtype=np.uint8,
        )
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        return {"pred": predictions_org}
Example #7
def transform_and_copy(data, cache_dir):
    copy_dir = os.path.join(cache_dir, 'copied_images')
    if not os.path.exists(copy_dir):
        os.mkdir(copy_dir)
    copy_list_path = os.path.join(copy_dir, 'copied_images.npy')
    if not os.path.exists(copy_list_path):
        print("transforming and copying images...")
        imageLoader = LoadImage()
        to_copy_list = [x for x in data if int(x['_label']) == 1]
        mul = 1  #int(len(data)/len(to_copy_list) - 1)

        rand_x_flip = RandFlip(spatial_axis=0, prob=0.50)
        rand_y_flip = RandFlip(spatial_axis=1, prob=0.50)
        rand_z_flip = RandFlip(spatial_axis=2, prob=0.50)
        rand_affine = RandAffine(prob=1.0,
                                 rotate_range=(0, 0, np.pi / 10),
                                 shear_range=(0.12, 0.12, 0.0),
                                 translate_range=(0, 0, 0),
                                 scale_range=(0.12, 0.12, 0.0),
                                 padding_mode="zeros")
        rand_gaussian_noise = RandGaussianNoise(prob=0.5, mean=0.0, std=0.05)
        transform = Compose([
            AddChannel(),
            rand_x_flip,
            rand_y_flip,
            rand_z_flip,
            rand_affine,
            SqueezeDim(),
        ])
        copy_list = []
        n = len(to_copy_list)
        for i in range(len(to_copy_list)):
            print(f'Copying image {i+1}/{n}', end="\r")
            to_copy = to_copy_list[i]
            image_file = to_copy['image']
            _image_file = replace_suffix(image_file, '.nii.gz', '')
            label = to_copy['label']
            _label = to_copy['_label']
            image_data, _ = imageLoader(image_file)
            seg_file = to_copy['seg']
            seg_data, _ = nrrd.read(seg_file)

            for j in range(mul):
                rand_seed = np.random.randint(1e8)
                transform.set_random_state(seed=rand_seed)
                new_image_data = rand_gaussian_noise(
                    np.array(transform(image_data)))
                transform.set_random_state(seed=rand_seed)
                new_seg_data = np.array(transform(seg_data))
                #multi_slice_viewer(image_data, image_file)
                #multi_slice_viewer(seg_data, seg_file)
                #seg_image = MaskIntensity(seg_data)(image_data)
                #multi_slice_viewer(seg_image, seg_file)
                image_basename = os.path.basename(_image_file)
                seg_basename = image_basename + f'_seg_{j}.nrrd'
                image_basename = image_basename + f'_{j}.nii.gz'

                new_image_file = os.path.join(copy_dir, image_basename)
                write_nifti(new_image_data, new_image_file, resample=False)
                new_seg_file = os.path.join(copy_dir, seg_basename)
                nrrd.write(new_seg_file, new_seg_data)
                copy_list.append({
                    'image': new_image_file,
                    'seg': new_seg_file,
                    'label': label,
                    '_label': _label
                })

        np.save(copy_list_path, copy_list)
        print("done transforming and copying!")

    copy_list = np.load(copy_list_path, allow_pickle=True)
    return copy_list
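
The image and its segmentation stay aligned above because the same Compose pipeline is re-seeded with one random seed before each of the two calls. The same idea in plain NumPy, as a conceptual sketch (random_flip is a stand-in, not a MONAI transform):

import numpy as np

def random_flip(volume, seed):
    """Flip each spatial axis with probability 0.5, driven by an explicit seed."""
    rng = np.random.default_rng(seed)
    out = volume
    for axis in range(volume.ndim):
        if rng.random() < 0.5:
            out = np.flip(out, axis=axis)
    return out

image = np.random.rand(32, 32, 16)
seg = (image > 0.5).astype(np.uint8)

seed = 12345
aug_image = random_flip(image, seed)
aug_seg = random_flip(seg, seed)  # same seed -> identical flips

# the augmented pair stays aligned
assert np.array_equal(aug_seg, (aug_image > 0.5).astype(np.uint8))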
Example #8
def large_image_splitter(data, cache_dir, num_splits, only_label_one=False):
    print("Splitting large images...")
    len_old = len(data)
    print("original data len:", len_old)
    split_images_dir = os.path.join(cache_dir, 'split_images')
    split_images = os.path.join(split_images_dir, 'split_images.npy')

    def _replace_in_data(split_images, num_splits):
        new_images = []
        for image in data:
            new_images.append(image)
            for s in split_images:
                source_image = s['source']
                if image['_label'] == 0 and only_label_one is True:
                    break
                if image['image'] == source_image:
                    #new_images.pop()
                    for i in range(min(num_splits, len(s["splits"]))):
                        new_images.append(s["splits"][i])
                    break
        return new_images

    if os.path.exists(split_images):
        new_images = np.load(split_images, allow_pickle=True)
        """for s in new_images:
            print("split image:", s["source"], end='\r')"""
        out_data = _replace_in_data(new_images, num_splits)
    else:
        if not os.path.exists(split_images_dir):
            os.mkdir(split_images_dir)
        new_images = []
        imageLoader = LoadImage()
        for image in data:
            image_data, _ = imageLoader(image["image"])
            seg_data, _ = nrrd.read(image['seg'])
            label = image['_label']
            z_len = image_data.shape[2]
            if z_len > 200:
                count = z_len // 80
                print("splitting image:",
                      image["image"],
                      f"into {count} parts",
                      "shape:",
                      image_data.shape,
                      end='\r')
                split_image_list = [
                    image_data[:, :, idz::count] for idz in range(count)
                ]
                split_seg_list = [
                    seg_data[:, :, idz::count] for idz in range(count)
                ]
                new_image = {'source': image["image"], 'splits': []}
                for i in range(count):
                    image_file = os.path.basename(
                        replace_suffix(image["image"], '.nii.gz', ''))
                    image_file = os.path.join(split_images_dir,
                                              image_file + f'_{i}.nii.gz')
                    seg_file = os.path.basename(
                        replace_suffix(image["seg"], '.nrrd', ''))
                    seg_file = os.path.join(split_images_dir,
                                            seg_file + f'_seg_{i}.nrrd')
                    split_image = np.array(split_image_list[i])
                    split_seg = np.array(split_seg_list[i], dtype=np.uint8)

                    rand_affine = RandAffine(prob=1.0,
                                             rotate_range=(0, 0, np.pi / 16),
                                             shear_range=(0.07, 0.07, 0.0),
                                             translate_range=(0, 0, 0),
                                             scale_range=(0.07, 0.07, 0.0),
                                             padding_mode="zeros")
                    transform = Compose([
                        AddChannel(),
                        rand_affine,
                        SqueezeDim(),
                    ])
                    rand_seed = np.random.randint(1e8)
                    transform.set_random_state(seed=rand_seed)
                    split_image = transform(split_image).detach().cpu().numpy()
                    transform.set_random_state(seed=rand_seed)
                    split_seg = transform(split_seg).detach().cpu().numpy()

                    write_nifti(split_image, image_file, resample=False)
                    nrrd.write(seg_file, split_seg)
                    new_image['splits'].append({
                        'image': image_file,
                        'label': image['label'],
                        '_label': image['_label'],
                        'seg': seg_file,
                        'w': False
                    })
                new_images.append(new_image)
        np.save(split_images, new_images)
        out_data = _replace_in_data(new_images, num_splits)

    print("new data len:", len(out_data))
    return out_data
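
The splitter keeps full in-plane resolution and interleaves slices along z: part idz contains every count-th slice starting at idz. A tiny NumPy illustration of that strided slicing:

import numpy as np

volume = np.arange(4 * 4 * 9).reshape(4, 4, 9)  # pretend z has 9 slices
count = 3                                       # split into 3 interleaved parts

parts = [volume[:, :, idz::count] for idz in range(count)]
for idz, part in enumerate(parts):
    print(idz, part.shape)                      # every part is (4, 4, 3)

# part idz holds the original slices idz, idz + count, idz + 2 * count, ...
assert np.array_equal(parts[1][:, :, 0], volume[:, :, 1])
assert np.array_equal(parts[1][:, :, 1], volume[:, :, 4])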
Example #9
def image_mixing(data, seed=None):
    #random.seed(seed)

    file_list = [x for x in data if int(x['_label']) == 1]
    random.shuffle(file_list)

    crop_foreground = CropForegroundd(keys=["image"],
                                      source_key="image",
                                      margin=(0, 0, 0),
                                      select_fn=lambda x: x != 0)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    resize2 = Resized(keys=["image"],
                      spatial_size=(int(512 * 0.75), int(512 * 0.75), -1),
                      mode="area")
    resize1 = Resized(keys=["image"],
                      spatial_size=(-1, -1, 40),
                      mode="nearest")
    gauss = GaussianSmooth(sigma=(1., 1., 0))
    gauss2 = GaussianSmooth(sigma=(2.0, 2.0, 0))
    affine = Affined(keys=["image"],
                     scale_params=(1.0, 2.0, 1.0),
                     padding_mode='zeros')

    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        affine,
        crop_foreground,
        resize1,
        resize2,
        SqueezeDimd(keys=["image"]),
    ])

    dirs = setup_directories()
    data_dir = dirs['data']
    mixed_images_dir = os.path.join(data_dir, 'mixed_images')
    if not os.path.exists(mixed_images_dir):
        os.mkdir(mixed_images_dir)

    for img1, img2 in itertools.combinations(file_list, 2):

        img1 = {'image': img1["image"], 'seg': img1['seg']}
        img2 = {'image': img2["image"], 'seg': img2['seg']}

        img1_data = common_transform(img1)["image"]
        img2_data = common_transform(img2)["image"]
        img1_mask, img2_mask = (img1_data > 0), (img2_data > 0)
        img_presek = np.logical_and(img1_mask, img2_mask)
        img = np.maximum(img_presek * img1_data, img_presek * img2_data)

        multi_slice_viewer(img, "img1")

        loop = True
        while loop:
            save = input("Save image [y/n/e]: ")
            if save.lower() == 'y':
                loop = False
                k = str(time.time()).encode('utf-8')
                h = blake2b(key=k, digest_size=16)
                name = h.hexdigest() + '.nii.gz'
                out_path = os.path.join(mixed_images_dir, name)
                write_nifti(img, out_path, resample=False)
            elif save.lower() == 'n':
                loop = False
                break
            elif save.lower() == 'e':
                print("exeting")
                exit()
            else:
                print("wrong input!")