def create_data_gen_pipeline(patient_data, cf, is_training=True):
    """
    create mutli-threaded train/val/test batch generation and augmentation pipeline.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
    :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: multithreaded_generator
    """

    # create instance of batch generator as first element in pipeline.
    data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size, cf=cf)

    # add transformations to pipeline.
    my_transforms = []
    if is_training:
        mirror_transform = Mirror(axes=np.arange(cf.dim))
        my_transforms.append(mirror_transform)
        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])

        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=cf.class_specific_seg_flag))
    all_transforms = Compose(my_transforms)
    # for debugging, a single-threaded version can be swapped in:
    # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
    return multithreaded_generator
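
A minimal usage sketch (hypothetical): the returned MultiThreadedAugmenter is consumed like a plain Python iterator and yields ready-augmented batch dictionaries. `train_data`, `cf`, and `cf.num_train_batches` are assumptions standing in for objects built by the surrounding training script.

# hypothetical consumption of the pipeline; `train_data` and `cf` are
# placeholders for objects created by the surrounding training script.
train_gen = create_data_gen_pipeline(train_data, cf, is_training=True)
for _ in range(cf.num_train_batches):
    batch = next(train_gen)  # dict with 'data', 'seg', 'bb_target', 'roi_labels', ...
    # ... forward/backward pass on batch['data'] goes here ...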
Example #2
    def generate_train_batch(self):

        pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]
        all_data = np.load(patient['data'], mmap_mode='r')
        data = all_data[0]
        seg = all_data[1].astype('uint8')
        batch_class_targets = np.array([patient['class_target']])

        out_data = data[None, None]
        out_seg = seg[None, None]

        print('check patient data loader', out_data.shape, out_seg.shape)
        batch_2D = {'data': out_data, 'seg': out_seg, 'class_target': batch_class_targets, 'pid': pid}
        converter = ConvertSegToBoundingBoxCoordinates(dim=2, get_rois_from_seg_flag=False, class_specific_seg_flag=self.cf.class_specific_seg_flag)
        batch_2D = converter(**batch_2D)

        batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                         'patient_roi_labels': batch_2D['roi_labels'],
                         'original_img_shape': out_data.shape})

        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

        return batch_2D
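
A quick check of the indexing above: `data[None, None]` prepends a batch and a channel axis, producing the (b, c, y, x) layout that ConvertSegToBoundingBoxCoordinates expects. A self-contained numpy sketch:

import numpy as np

slice_2d = np.zeros((128, 128), dtype=np.float32)  # (y, x)
out = slice_2d[None, None]                         # prepend batch and channel axes
assert out.shape == (1, 1, 128, 128)               # (b, c, y, x)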
Example #3
    def generate_train_batch(self):
        pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]
        # data shape: from (z, y, x, c) to (c, y, x, z).
        data = np.transpose(np.load(patient['data'], mmap_mode='r'),
                            axes=(3, 1, 2, 0)).copy()
        seg = np.transpose(np.load(patient['seg'], mmap_mode='r'),
                           axes=(3, 1, 2, 0))[0].copy()
        batch_class_targets = np.array([patient['class_target']])

        # pad data if smaller than patch_size seen during training.
        if np.any([
                data.shape[dim + 1] < ps
                for dim, ps in enumerate(self.patch_size)
        ]):
            new_shape = [data.shape[0]] + [
                np.max([data.shape[dim + 1], self.patch_size[dim]])
                for dim, ps in enumerate(self.patch_size)
            ]
            data = dutils.pad_nd_image(
                data, new_shape
            )  # use 'return_slicer' to crop the image back to its original shape.
            # seg has no channel axis: drop it from the padding target shape.
            seg = dutils.pad_nd_image(seg, new_shape[1:])

        # get 3D targets for evaluation even if the network operates in 2D; 2D predictions are merged back to 3D in the predictor.
        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
            out_data = data[np.newaxis]
            out_seg = seg[np.newaxis, np.newaxis]
            out_targets = batch_class_targets

            batch_3D = {
                'data': out_data,
                'seg': out_seg,
                'class_target': out_targets,
                'pid': pid
            }
            converter = ConvertSegToBoundingBoxCoordinates(
                dim=3,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            batch_3D = converter(**batch_3D)
            batch_3D.update({
                'patient_bb_target': batch_3D['bb_target'],
                'patient_roi_labels': batch_3D['class_target'],
                'original_img_shape': out_data.shape
            })

        if self.cf.dim == 2:
            out_data = np.transpose(data, axes=(3, 0, 1, 2))  # (z, c, y, x)
            out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis]
            out_targets = np.array(
                np.repeat(batch_class_targets, out_data.shape[0], axis=0))

            # if n_3D_context is set, add neighbouring slices to each selected slice in the channel dimension.
            if self.cf.n_3D_context is not None:
                slice_range = range(self.cf.n_3D_context,
                                    out_data.shape[0] + self.cf.n_3D_context)
                out_data = np.pad(
                    out_data, ((self.cf.n_3D_context, self.cf.n_3D_context),
                               (0, 0), (0, 0), (0, 0)),
                    'constant',
                    constant_values=0)
                out_data = np.array([
                    np.concatenate([
                        out_data[ii]
                        for ii in range(slice_id -
                                        self.cf.n_3D_context, slice_id +
                                        self.cf.n_3D_context + 1)
                    ],
                                   axis=0) for slice_id in slice_range
                ])

            batch_2D = {
                'data': out_data,
                'seg': out_seg,
                'class_target': out_targets,
                'pid': pid
            }
            converter = ConvertSegToBoundingBoxCoordinates(
                dim=2,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            batch_2D = converter(**batch_2D)

            if self.cf.merge_2D_to_3D_preds:
                batch_2D.update({
                    'patient_bb_target':
                    batch_3D['patient_bb_target'],
                    'patient_roi_labels':
                    batch_3D['patient_roi_labels'],
                    'original_img_shape':
                    out_data.shape
                })
            else:
                batch_2D.update({
                    'patient_bb_target': batch_2D['bb_target'],
                    'patient_roi_labels': batch_2D['class_target'],
                    'original_img_shape': out_data.shape
                })

        # keep a reference to the full-patient batch; it may be replaced by a patch batch below.
        out_batch = batch_3D if self.cf.dim == 3 else batch_2D
        patient_batch = out_batch

        # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
        # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
        if np.any(
            [data.shape[dim + 1] > self.patch_size[dim] for dim in range(3)]):
            patch_crop_coords_list = dutils.get_patch_crop_coords(
                data[0], self.patch_size)
            new_img_batch, new_seg_batch, new_class_targets_batch = [], [], []

            for cix, c in enumerate(patch_crop_coords_list):

                seg_patch = seg[c[0]:c[1], c[2]:c[3], c[4]:c[5]]
                new_seg_batch.append(seg_patch)

                # if n_3D_context is set, add neighbouring slices to each selected slice in the channel dimension.
                # correct the patch-crop coordinates for the added 3D-context slices.
                if self.cf.dim == 2 and self.cf.n_3D_context is not None:
                    tmp_c_5 = c[5] + (self.cf.n_3D_context * 2)
                    if cix == 0:
                        data = np.pad(
                            data,
                            ((0, 0), (0, 0), (0, 0),
                             (self.cf.n_3D_context, self.cf.n_3D_context)),
                            'constant',
                            constant_values=0)
                else:
                    tmp_c_5 = c[5]

                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3],
                                          c[4]:tmp_c_5])

            data = np.array(new_img_batch)  # (n_patches, c, x, y, z)
            seg = np.array(
                new_seg_batch)[:, np.newaxis]  # (n_patches, 1, x, y, z)
            batch_class_targets = np.repeat(batch_class_targets,
                                            len(patch_crop_coords_list),
                                            axis=0)

            if self.cf.dim == 2:
                if self.cf.n_3D_context is not None:
                    data = np.transpose(data[:, 0], axes=(0, 3, 1, 2))
                else:
                    # all patches have z dimension 1 (slices). discard dimension
                    data = data[..., 0]
                seg = seg[..., 0]

            patch_batch = {
                'data': data,
                'seg': seg,
                'class_target': batch_class_targets,
                'pid': pid
            }
            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
            patch_batch['patient_bb_target'] = patient_batch[
                'patient_bb_target']
            patch_batch['patient_roi_labels'] = patient_batch[
                'patient_roi_labels']
            patch_batch['original_img_shape'] = patient_batch[
                'original_img_shape']

            converter = ConvertSegToBoundingBoxCoordinates(
                self.cf.dim,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            patch_batch = converter(**patch_batch)
            out_batch = patch_batch

        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

        # zero out channels configured to be dropped at test time.
        out_batch['data'][:, self.cf.drop_channels_test] = 0.
        return out_batch
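
The 3D-context block in the 2D branch above is the densest part of this iterator. The sketch below isolates the same idea with plain numpy (the helper name add_3d_context is hypothetical): each slice gets its n_context neighbours on both sides stacked into the channel dimension, with zero-padding at the volume borders.

import numpy as np

def add_3d_context(slices, n_context):
    """Stack n_context neighbouring slices on each side of every slice into the
    channel dimension. Input: (z, c, y, x); output: (z, c * (2 * n_context + 1), y, x)."""
    # zero-pad along z so that border slices also get a full context window.
    padded = np.pad(slices, ((n_context, n_context), (0, 0), (0, 0), (0, 0)),
                    mode='constant', constant_values=0)
    return np.array([
        np.concatenate([padded[i + offset] for offset in range(2 * n_context + 1)],
                       axis=0)
        for i in range(slices.shape[0])
    ])

stack = np.random.rand(10, 1, 32, 32).astype(np.float32)  # (z, c, y, x)
assert add_3d_context(stack, n_context=1).shape == (10, 3, 32, 32)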
Example #4
    def generate_train_batch(self):

        pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]
        # data shape: from (z, y, x) to (y, x, z); np.newaxis adds the channel axis -> (c, y, x, z).
        data = np.transpose(np.load(patient['data'], mmap_mode='r'),
                            axes=(1, 2, 0))[np.newaxis]
        seg = np.transpose(np.load(patient['seg'], mmap_mode='r'),
                           axes=(1, 2, 0))
        print('patient', patient)
        print('data', data.shape)
        batch_class_targets = np.array([patient['class_target']])

        # pad data if smaller than patch_size seen during training.
        if np.any([
                data.shape[dim + 1] < ps
                for dim, ps in enumerate(self.patch_size)
        ]):
            new_shape = [data.shape[0]] + [
                np.max([data.shape[dim + 1], self.patch_size[dim]])
                for dim, ps in enumerate(self.patch_size)
            ]
            data = dutils.pad_nd_image(
                data, new_shape
            )  # use 'return_slicer' to crop image back to original shape.
            if len(new_shape) == 4:
                # seg has no channel axis: drop it from the padding target shape.
                new_shape = new_shape[1:]
            seg = dutils.pad_nd_image(seg, new_shape)
        # get 3D targets for evaluation even if the network operates in 2D; 2D predictions are merged back to 3D in the predictor.
        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:  # merge_2D_to_3D_preds defaults to True.
            out_data = data[np.newaxis]
            out_seg = seg[np.newaxis, np.newaxis]
            out_targets = batch_class_targets

            batch_3D = {
                'data': out_data,
                'seg': out_seg,
                'class_target': out_targets,
                'pid': pid
            }
            converter = ConvertSegToBoundingBoxCoordinates(
                dim=3,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=False)  # defaults to False.
            batch_3D = converter(**batch_3D)
            batch_3D.update({
                'patient_bb_target': batch_3D['bb_target'],
                'patient_roi_labels': batch_3D['roi_labels'],
                'original_img_shape': out_data.shape
            })

        # note: unlike the previous example, this iterator never builds a 2D batch,
        # so cf.dim == 3 (or merge_2D_to_3D_preds == True) is assumed here.
        out_batch = batch_3D
        patient_batch = out_batch

        # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
        # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
        if np.any(
            [data.shape[dim + 1] > self.patch_size[dim] for dim in range(3)]):
            patch_crop_coords_list = dutils.get_patch_crop_coords_stride(
                data[0], self.patch_size, self.testing_patch_stride)
            new_img_batch, new_seg_batch, new_class_targets_batch = [], [], []

            for cix, c in enumerate(patch_crop_coords_list):

                seg_patch = seg[c[0]:c[1], c[2]:c[3], c[4]:c[5]]
                new_seg_batch.append(seg_patch)

                # if n_3D_context is set, add neighbouring slices to each selected slice in the channel dimension.
                # correct the patch-crop coordinates for the added 3D-context slices.
                if self.cf.dim == 2 and self.cf.n_3D_context is not None:
                    tmp_c_5 = c[5] + (self.cf.n_3D_context * 2)
                    if cix == 0:
                        data = np.pad(
                            data,
                            ((0, 0), (0, 0), (0, 0),
                             (self.cf.n_3D_context, self.cf.n_3D_context)),
                            'constant',
                            constant_values=0)
                else:
                    tmp_c_5 = c[5]

                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3],
                                          c[4]:tmp_c_5])

            data = np.array(new_img_batch)  # (n_patches, c, x, y, z)
            seg = np.array(
                new_seg_batch)[:, np.newaxis]  # (n_patches, 1, x, y, z)
            batch_class_targets = np.repeat(batch_class_targets,
                                            len(patch_crop_coords_list),
                                            axis=0)

            patch_batch = {
                'data': data,
                'seg': seg,
                'class_target': batch_class_targets,
                'pid': pid
            }  # class_target is repeated once per cropped patch.
            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
            patch_batch['patient_bb_target'] = patient_batch[
                'patient_bb_target']  # ground-truth boxes of the full patient volume.
            patch_batch['patient_roi_labels'] = patient_batch[
                'patient_roi_labels']
            patch_batch['original_img_shape'] = patient_batch[
                'original_img_shape']

            converter = ConvertSegToBoundingBoxCoordinates(
                self.cf.dim,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            patch_batch = converter(**patch_batch)
            out_batch = patch_batch

        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0
        # binarize patient ROI labels: any positive class becomes foreground (1).
        if out_batch['patient_roi_labels'][0][0] > 0:
            out_batch['patient_roi_labels'][0] = [1]

        return out_batch
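
Both patient iterators pad volumes that are smaller than the training patch size, and the inline comments mention 'return_slicer' for cropping the image back afterwards. A short sketch of that round trip, assuming dutils.pad_nd_image mirrors pad_nd_image from batchgenerators (on which the toolkit's copy is based):

import numpy as np
from batchgenerators.augmentations.utils import pad_nd_image

img = np.random.rand(1, 100, 100, 40)  # (c, y, x, z), smaller than a 128x128x64 patch
padded, slicer = pad_nd_image(img, new_shape=(1, 128, 128, 64), return_slicer=True)
assert padded.shape == (1, 128, 128, 64)
restored = padded[tuple(slicer)]  # crop back to the original shape
assert restored.shape == img.shape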