Beispiel #1
0
def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs):
    """
    Create the multi-worker train/val/test batch generation and augmentation pipeline.
    :param cf: config object; da_kwargs, patch_size, dim, roi_items, class_specific_seg and n_workers are read here.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
    :param do_aug: (optional) if True, apply the mirror (when enabled in cf.da_kwargs) and spatial transforms;
        if False, only center-crop to the patch size.
    :param kwargs: forwarded unchanged to BatchGenerator.
    :return: multithreaded_generator (MultiThreadedAugmenter wrapping the batch generator and the composed transforms)
    """
    # the batch generator is the source element of the pipeline.
    batch_gen = BatchGenerator(cf, patient_data, **kwargs)

    pipeline = []
    if not do_aug:
        # no augmentation: just bring every sample to the configured patch size.
        pipeline.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
    else:
        aug = cf.da_kwargs
        if aug["mirror"]:
            pipeline.append(Mirror(axes=aug['mirror_axes']))

        # these da_kwargs entries carry the same name as the SpatialTransform argument they feed.
        same_name_args = ('do_elastic_deform', 'alpha', 'sigma',
                          'do_rotation', 'angle_x', 'angle_y', 'angle_z',
                          'do_scale', 'scale', 'random_crop')
        pipeline.append(SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                         patch_center_dist_from_border=aug['rand_crop_dist'],
                                         **{arg: aug[arg] for arg in same_name_args}))

    # always the last step: add bounding-box targets to the batch.
    pipeline.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))

    # one seed per worker process.
    return MultiThreadedAugmenter(batch_gen, Compose(pipeline),
                                  num_processes=cf.n_workers, seeds=range(cf.n_workers))
Beispiel #2
0
    def generate_train_batch(self, pid=None):
        """
        Build the whole-image batch of a single patient and advance the internal patient cursor.
        :param pid: (optional) key into self._data; if None, the patient at position self.patient_ix is used.
        :return: batch dictionary holding 'data', 'seg', one entry per roi item, whatever
            ConvertSegToBoundingBoxCoordinates adds, "patient_"-prefixed copies of the roi items,
            'patient_bb_target', 'original_img_shape' and 'pid'.
        """
        patient = self._data[self.dataset_pids[self.patient_ix] if pid is None else pid]

        # the stored array is read as [image, segmentation] along its first axis.
        # (original note: dimensions were already swapped in preprocessing to ease 2D/3D-case handling)
        stored = np.load(patient['data'], mmap_mode='r')
        img = stored[0].astype('float16')[np.newaxis]
        seg = stored[1].astype('uint8')[np.newaxis]

        # keep only the configured input channels.
        img = img[self.chans]
        # spatial dims need to be in order x,y,z
        assert img[0].shape == seg[0].shape, "spatial shape incongruence betw. data and seg"

        # prepend the batch dimension.
        out_data, out_seg = img[None], seg[None]
        n_samples = len(out_data)

        batch = {'data': out_data, 'seg': out_seg}
        batch.update({item: np.repeat(np.array([patient[item]]), n_samples, axis=0)
                      for item in self.cf.roi_items})
        to_boxes = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
        batch = to_boxes(**batch)

        # patient-level entries simply mirror the batch-level ones here.
        batch['patient_bb_target'] = batch['bb_target']
        batch['original_img_shape'] = out_data.shape
        for item in self.cf.roi_items:
            batch["patient_" + item] = batch[item]
        batch['pid'] = np.array([patient['pid']] * n_samples)

        # step to the next patient; start over after the last one.
        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

        return batch
Beispiel #3
0
    def generate_train_batch(self, pid=None):
        """
        Assemble the batch of one whole patient, tile it into patches if the image exceeds self.patch_size,
        and advance the internal patient cursor.

        :param pid: (optional) key into self._data; if None, the patient at position self.patient_ix is used.
        :return: batch dictionary. Without tiling it holds the whole-patient 'data' and 'seg', one entry per roi
            item, "patient_"-prefixed copies of the roi items, 'patient_bb_target', 'original_img_shape', 'pid'
            and, when available, 'plot_bg' (an array, or an int channel index into 'data').
            With tiling, 'data'/'seg' hold the stacked patches, 'patch_crop_coords' their crop coordinates, and
            the whole-patient information is kept under 'patient_data', 'patient_seg', 'patient_bb_target',
            "patient_"-prefixed roi items (plus "_2d" variants if cf.dim == 2) and 'patient_plot_bg'.
        """

        if pid is None:
            pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]

        # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling
        data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
        seg = np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis]

        data_shp_raw = data.shape
        # if the plotting-background channel is not among the selected channels, keep it as a separate array.
        plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None
        data = data[self.chans]
        # number of dropped channels that precede the plotting-background channel (used below to shift its index).
        discarded_chans = len(
            [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan])
        spatial_shp = data[0].shape  # spatial dims need to be in order x,y,z
        assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg"

        # pad up to the patch size wherever the image is smaller than one patch.
        if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]):
            new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))]
            data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
            seg = dutils.pad_nd_image(seg, new_shape)
            if plot_bg is not None:
                plot_bg = dutils.pad_nd_image(plot_bg, new_shape)

        # the 3D batch is also built in 2D mode when 2D predictions are later merged to 3D,
        # because the 2D batch then takes its patient-level targets from it.
        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
            # adds the batch dim here bc won't go through MTaugmenter
            out_data = data[np.newaxis]
            out_seg = seg[np.newaxis]
            if plot_bg is not None:
                out_plot_bg = plot_bg[np.newaxis]
            # data and seg shape: (1,c,x,y,z), where c=1 for seg

            batch_3D = {'data': out_data, 'seg': out_seg}
            for o in self.cf.roi_items:
                batch_3D[o] = np.array([patient[self.gt_prefix+o]])
            converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
            batch_3D = converter(**batch_3D)
            batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
            for o in self.cf.roi_items:
                batch_3D["patient_" + o] = batch_3D[o]

        if self.cf.dim == 2:
            # note: out_data / out_seg / out_plot_bg from the 3D branch above are overwritten here.
            out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32')  # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim
            out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8')  # (c,y,x,z) to (b=z,c,x,y)

            batch_2D = {'data': out_data, 'seg': out_seg}
            for o in self.cf.roi_items:
                # one copy of the patient's roi item per slice.
                batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0)
            converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
            batch_2D = converter(**batch_2D)

            if plot_bg is not None:
                out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32')

            if self.cf.merge_2D_to_3D_preds:
                # patient-level targets come from the 3D batch.
                batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                 'original_img_shape': out_data.shape})
                for o in self.cf.roi_items:
                    batch_2D["patient_" + o] = batch_3D[o]
            else:
                batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                 'original_img_shape': out_data.shape})
                for o in self.cf.roi_items:
                    batch_2D["patient_" + o] = batch_2D[o]

        out_batch = batch_3D if self.cf.dim == 3 else batch_2D
        out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})

        # the background channel is part of 'data' but its index moved because earlier channels were dropped:
        # pass the shifted channel index instead of an array.
        if self.cf.plot_bg_chan in self.chans and discarded_chans > 0:  # len(self.chans[:self.cf.plot_bg_chan])<data_shp_raw[0]:
            assert plot_bg is None
            plot_bg = int(self.cf.plot_bg_chan - discarded_chans)
            out_plot_bg = plot_bg
        if plot_bg is not None:
            out_batch['plot_bg'] = out_plot_bg

        # eventual tiling into patches
        spatial_shp = out_batch["data"].shape[2:]
        if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
            patient_batch = out_batch
            print("patientiterator produced patched batch!")
            # patches are cut from the (possibly padded) (c,x,y,z) arrays, not from the transposed 2D batch.
            patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
            new_img_batch, new_seg_batch = [], []

            for c in patch_crop_coords_list:
                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])
                seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                new_seg_batch.append(seg_patch)

            data = np.array(new_img_batch)  # (patches, c, x, y, z)
            seg = np.array(new_seg_batch)
            if self.cf.dim == 2:
                # all patches have z dimension 1 (slices). discard dimension
                data = data[..., 0]
                seg = seg[..., 0]
            patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                           'pid': np.array([patient['pid']] * data.shape[0])}
            for o in self.cf.roi_items:
                patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0)
            #patient-wise (orig) batch info for putting the patches back together after prediction
            for o in self.cf.roi_items:
                patch_batch["patient_"+o] = patient_batch["patient_"+o]
                if self.cf.dim == 2:
                    # this could also be named "unpatched_2d_roi_items"
                    patch_batch["patient_" + o + "_2d"] = patient_batch[o]
            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
            patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
            if self.cf.dim == 2:
                patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
            patch_batch['patient_data'] = patient_batch['data']
            patch_batch['patient_seg'] = patient_batch['seg']
            patch_batch['original_img_shape'] = patient_batch['original_img_shape']
            if plot_bg is not None:
                patch_batch['patient_plot_bg'] = patient_batch['plot_bg']

            # run the converter again on the patch-level batch.
            converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False,
                                                           class_specific_seg=self.cf.class_specific_seg)

            patch_batch = converter(**patch_batch)
            out_batch = patch_batch

        # step to the next patient; start over after the last one.
        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

        return out_batch