Example 1
    def generate_train_batch(self):
        pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]
        # data shape: from (z, y, x, c) on disk to (c, y, x, z).
        data = np.transpose(np.load(patient['data'], mmap_mode='r'),
                            axes=(3, 1, 2, 0)).copy()
        seg = np.transpose(np.load(patient['seg'], mmap_mode='r'),
                           axes=(3, 1, 2, 0))[0].copy()
        batch_class_targets = np.array([patient['class_target']])

        # pad data if smaller than patch_size seen during training.
        if np.any([
                data.shape[dim + 1] < ps
                for dim, ps in enumerate(self.patch_size)
        ]):
            new_shape = [data.shape[0]] + [
                np.max([data.shape[dim + 1], self.patch_size[dim]])
                for dim, ps in enumerate(self.patch_size)
            ]
            data = dutils.pad_nd_image(
                data, new_shape
            )  # use 'return_slicer' to crop image back to original shape.
            seg = dutils.pad_nd_image(seg, new_shape)

        # get 3D targets for evaluation, even if the network operates in 2D. 2D predictions will be merged to 3D in the predictor.
        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
            out_data = data[np.newaxis]
            out_seg = seg[np.newaxis, np.newaxis]
            out_targets = batch_class_targets

            batch_3D = {
                'data': out_data,
                'seg': out_seg,
                'class_target': out_targets,
                'pid': pid
            }
            converter = ConvertSegToBoundingBoxCoordinates(
                dim=3,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            batch_3D = converter(**batch_3D)
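            # keep the converter's outputs under patient_-prefixed keys so patch-level batches can still reference the whole-patient targets later.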
            batch_3D.update({
                'patient_bb_target': batch_3D['bb_target'],
                'patient_roi_labels': batch_3D['class_target'],
                'original_img_shape': out_data.shape
            })

        if self.cf.dim == 2:
            out_data = np.transpose(data, axes=(3, 0, 1, 2))  # (c, y, x, z) to (z, c, y, x)
            out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis]
            out_targets = np.array(
                np.repeat(batch_class_targets, out_data.shape[0], axis=0))

            # if n_3D_context is set, stack neighbouring slices onto each slice along the channel dimension.
            if self.cf.n_3D_context is not None:
                slice_range = range(self.cf.n_3D_context,
                                    out_data.shape[0] + self.cf.n_3D_context)
                out_data = np.pad(
                    out_data, ((self.cf.n_3D_context, self.cf.n_3D_context),
                               (0, 0), (0, 0), (0, 0)),
                    'constant',
                    constant_values=0)
                out_data = np.array([
                    np.concatenate([
                        out_data[ii]
                        for ii in range(slice_id -
                                        self.cf.n_3D_context, slice_id +
                                        self.cf.n_3D_context + 1)
                    ],
                                   axis=0) for slice_id in slice_range
                ])

            batch_2D = {
                'data': out_data,
                'seg': out_seg,
                'class_target': out_targets,
                'pid': pid
            }
            converter = ConvertSegToBoundingBoxCoordinates(
                dim=2,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            batch_2D = converter(**batch_2D)

            if self.cf.merge_2D_to_3D_preds:
                batch_2D.update({
                    'patient_bb_target':
                    batch_3D['patient_bb_target'],
                    'patient_roi_labels':
                    batch_3D['patient_roi_labels'],
                    'original_img_shape':
                    out_data.shape
                })
            else:
                batch_2D.update({
                    'patient_bb_target': batch_2D['bb_target'],
                    'patient_roi_labels': batch_2D['class_target'],
                    'original_img_shape': out_data.shape
                })

        out_batch = batch_3D if self.cf.dim == 3 else batch_2D
        patient_batch = out_batch

        # crop the patient volume into patches of the patch_size used during training and stack the patches along the batch dimension.
        # here, 2D is treated as a special case of 3D with patch_size[z] = 1.
        if np.any(
            [data.shape[dim + 1] > self.patch_size[dim] for dim in range(3)]):
            patch_crop_coords_list = dutils.get_patch_crop_coords(
                data[0], self.patch_size)
            new_img_batch, new_seg_batch, new_class_targets_batch = [], [], []

            for cix, c in enumerate(patch_crop_coords_list):

                seg_patch = seg[c[0]:c[1], c[2]:c[3], c[4]:c[5]]
                new_seg_batch.append(seg_patch)

                # if n_3D_context is set, neighbouring slices are stacked along the channel dimension;
                # extend the z crop coordinate to cover the added context slices.
                if self.cf.dim == 2 and self.cf.n_3D_context is not None:
                    tmp_c_5 = c[5] + (self.cf.n_3D_context * 2)
                    if cix == 0:
                        data = np.pad(
                            data,
                            ((0, 0), (0, 0), (0, 0),
                             (self.cf.n_3D_context, self.cf.n_3D_context)),
                            'constant',
                            constant_values=0)
                else:
                    tmp_c_5 = c[5]

                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3],
                                          c[4]:tmp_c_5])

            data = np.array(new_img_batch)  # (n_patches, c, x, y, z)
            seg = np.array(
                new_seg_batch)[:, np.newaxis]  # (n_patches, 1, x, y, z)
            batch_class_targets = np.repeat(batch_class_targets,
                                            len(patch_crop_coords_list),
                                            axis=0)

            if self.cf.dim == 2:
                if self.cf.n_3D_context is not None:
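                    # keep only the first channel; the z axis holds the 2*n_3D_context+1 stacked context slices and becomes the channel axis -> (n_patches, 2*n_3D_context+1, y, x)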
                    data = np.transpose(data[:, 0], axes=(0, 3, 1, 2))
                else:
                    # all patches have z dimension 1 (slices). discard dimension
                    data = data[..., 0]
                seg = seg[..., 0]

            patch_batch = {
                'data': data,
                'seg': seg,
                'class_target': batch_class_targets,
                'pid': pid
            }
            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
            patch_batch['patient_bb_target'] = patient_batch[
                'patient_bb_target']
            patch_batch['patient_roi_labels'] = patient_batch[
                'patient_roi_labels']
            patch_batch['original_img_shape'] = patient_batch[
                'original_img_shape']

            converter = ConvertSegToBoundingBoxCoordinates(
                self.cf.dim,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            patch_batch = converter(**patch_batch)
            out_batch = patch_batch

        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

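        # zero out channels configured to be dropped at test time.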
        out_batch['data'][:, self.cf.drop_channels_test] = 0.
        return out_batch
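The least obvious step in Example 1 is the n_3D_context block: the slice stack is zero-padded along z, and for every original slice its n_3D_context neighbours on each side are concatenated along the channel axis. Below is a minimal, self-contained sketch of that idea; the helper name and toy shapes are hypothetical and not part of the original code.

import numpy as np

def stack_3d_context(vol, n_context):
    # vol: (z, c, y, x) -> (z, c * (2 * n_context + 1), y, x)
    padded = np.pad(vol, ((n_context, n_context), (0, 0), (0, 0), (0, 0)),
                    mode='constant', constant_values=0)
    # for output slice i, gather padded slices i .. i + 2*n_context,
    # i.e. original slices i - n_context .. i + n_context (zeros at the borders)
    return np.array([
        np.concatenate([padded[i + k] for k in range(2 * n_context + 1)], axis=0)
        for i in range(vol.shape[0])
    ])

slices = np.random.rand(5, 1, 8, 8)       # 5 single-channel slices
print(stack_3d_context(slices, 1).shape)  # (5, 3, 8, 8)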
Example 2
    def generate_train_batch(self, pid=None):

        if pid is None:
            pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]

        # dimensions were already swapped in preprocessing from (c,) z, y, x to (c,) y, x, z (i.e. h, w, d) to ease handling of the 2D/3D cases
        data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
        seg = np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis]

        data_shp_raw = data.shape
        plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None
        data = data[self.chans]
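        # count dropped channels that sit below plot_bg_chan, so the plot-channel index can be shifted accordingly later.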
        discarded_chans = len(
            [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan])
        spatial_shp = data[0].shape  # spatial dims need to be in order x,y,z
        assert spatial_shp == seg[0].shape, "spatial shape mismatch between data and seg"

        if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]):
            new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))]
            data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
            seg = dutils.pad_nd_image(seg, new_shape)
            if plot_bg is not None:
                plot_bg = dutils.pad_nd_image(plot_bg, new_shape)

        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
            # add the batch dim here because this batch will not go through the MTaugmenter
            out_data = data[np.newaxis]
            out_seg = seg[np.newaxis]
            if plot_bg is not None:
                out_plot_bg = plot_bg[np.newaxis]
            # data and seg shape: (1,c,x,y,z), where c=1 for seg

            batch_3D = {'data': out_data, 'seg': out_seg}
            for o in self.cf.roi_items:
                batch_3D[o] = np.array([patient[self.gt_prefix+o]])
            converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
            batch_3D = converter(**batch_3D)
            batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
            for o in self.cf.roi_items:
                batch_3D["patient_" + o] = batch_3D[o]

        if self.cf.dim == 2:
            out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32')  # (c, y, x, z) to (b=z, c, y, x), using z as the batch dim
            out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8')  # (c, y, x, z) to (b=z, c, y, x)

            batch_2D = {'data': out_data, 'seg': out_seg}
            for o in self.cf.roi_items:
                batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0)
            converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
            batch_2D = converter(**batch_2D)

            if plot_bg is not None:
                out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32')

            if self.cf.merge_2D_to_3D_preds:
                batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                 'original_img_shape': out_data.shape})
                for o in self.cf.roi_items:
                    batch_2D["patient_" + o] = batch_3D[o]
            else:
                batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                 'original_img_shape': out_data.shape})
                for o in self.cf.roi_items:
                    batch_2D["patient_" + o] = batch_2D[o]

        out_batch = batch_3D if self.cf.dim == 3 else batch_2D
        out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})

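        # if the plot channel was kept among self.chans, store its (index-shifted) position instead of a separate array.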
        if self.cf.plot_bg_chan in self.chans and discarded_chans > 0:
            assert plot_bg is None
            plot_bg = int(self.cf.plot_bg_chan - discarded_chans)
            out_plot_bg = plot_bg
        if plot_bg is not None:
            out_batch['plot_bg'] = out_plot_bg

        # tile into patches if the volume is larger than patch_size
        spatial_shp = out_batch["data"].shape[2:]
        if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
            patient_batch = out_batch
            print("patientiterator produced patched batch!")
            patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
            new_img_batch, new_seg_batch = [], []

            for c in patch_crop_coords_list:
                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])
                seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                new_seg_batch.append(seg_patch)

            data = np.array(new_img_batch)  # (patches, c, x, y, z)
            seg = np.array(new_seg_batch)
            if self.cf.dim == 2:
                # all patches have z dimension 1 (slices). discard dimension
                data = data[..., 0]
                seg = seg[..., 0]
            patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                           'pid': np.array([patient['pid']] * data.shape[0])}
            for o in self.cf.roi_items:
                patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0)
            # patient-wise (original) batch info for putting the patches back together after prediction
            for o in self.cf.roi_items:
                patch_batch["patient_"+o] = patient_batch["patient_"+o]
                if self.cf.dim == 2:
                    # this could also be named "unpatched_2d_roi_items"
                    patch_batch["patient_" + o + "_2d"] = patient_batch[o]
            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
            patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
            if self.cf.dim == 2:
                patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
            patch_batch['patient_data'] = patient_batch['data']
            patch_batch['patient_seg'] = patient_batch['seg']
            patch_batch['original_img_shape'] = patient_batch['original_img_shape']
            if plot_bg is not None:
                patch_batch['patient_plot_bg'] = patient_batch['plot_bg']

            converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False,
                                                           class_specific_seg=self.cf.class_specific_seg)

            patch_batch = converter(**patch_batch)
            out_batch = patch_batch

        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

        return out_batch
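Both examples rely on dutils.get_patch_crop_coords to tile an oversized volume into training-size patches. The sketch below is a hypothetical, simplified stand-in (not the library's implementation) that produces coordinates in the same paired start/end layout consumed by the slicing above, assuming plain tiling with the last patch in each dimension shifted back so it still fits inside the volume.

import numpy as np

def simple_patch_crop_coords(img, patch_size):
    # img: spatial array, e.g. data[0] with shape (y, x, z); assumes img.shape[dim] >= patch_size[dim]
    # (both iterators pad smaller volumes to patch_size before tiling).
    starts_per_dim = []
    for dim, ps in enumerate(patch_size):
        starts = list(range(0, img.shape[dim] - ps + 1, ps))
        if starts[-1] + ps < img.shape[dim]:
            starts.append(img.shape[dim] - ps)  # shift the last patch back to cover the border
        starts_per_dim.append(starts)
    coords = []
    for s0 in starts_per_dim[0]:
        for s1 in starts_per_dim[1]:
            for s2 in starts_per_dim[2]:
                coords.append([s0, s0 + patch_size[0],
                               s1, s1 + patch_size[1],
                               s2, s2 + patch_size[2]])
    return np.array(coords)  # one row of (start, end) pairs per spatial dim for each patch

vol = np.zeros((320, 320, 40))
print(simple_patch_crop_coords(vol, (288, 288, 32)).shape)  # (8, 6): 2 x 2 x 2 patches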