Beispiel #1
0
    def generate_train_batch(self):
        """Assemble one training batch from the patients returned by ``get_batch_pids``.

        Everything in here runs once per batch; print statements in here get
        confusing because batches are generated by several threads.

        For every patient the stored array ``patient['data']`` is loaded
        memory-mapped: entry 0 is used as the image (given a leading channel
        axis of size 1) and entry 1 as the segmentation. Both are padded up to
        ``cf.pre_crop_size`` wherever they are smaller.

        Returns:
            dict with
            ``'data'``: stacked images (float16), one leading channel axis per patient;
            ``'seg'``: stacked uint8 segmentations with the same layout;
            ``'pid'``: the pids returned by ``get_batch_pids``;
            ``'roi_counts'`` / ``'empty_counts'``: uint32 arrays of length
            ``len(self.unique_ts)`` counting, per balance-target class, the rois of
            that class and the samples without any roi of that class;
            plus one stacked entry per name in ``cf.roi_items``.
        """
        batch_pids = self.get_batch_pids()

        batch_data, batch_segs = [], []
        batch_roi_items = {name: [] for name in self.cf.roi_items}
        # record roi count and empty count of classes in batch.
        # empty count: no presence of the resp. class in the whole sample
        # (empty slices in 2D / patients in 3D).
        n_targets = len(self.unique_ts)
        batch_roi_counts = np.zeros((n_targets,), dtype='uint32')
        batch_empty_counts = np.zeros((n_targets,), dtype='uint32')

        for pid in batch_pids:
            patient = self._data[pid]

            all_data = np.load(patient['data'], mmap_mode='r')
            data = all_data[0].astype('float16')[np.newaxis]
            seg = all_data[1].astype('uint8')

            spatial_shp = data[0].shape
            assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg"
            if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
                new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
                data = dutils.pad_nd_image(data, (len(data), *new_shape))
                seg = dutils.pad_nd_image(seg, new_shape)

            batch_data.append(data)
            batch_segs.append(seg[np.newaxis])

            # after the loop, holds every entry of every batch patient per observable
            for o in batch_roi_items:
                batch_roi_items[o].append(patient[o])

            # Per-patient invariants, computed once instead of once per class:
            # roi ids present in seg, shifted to 0-based indices into
            # patient[self.balance_target] (presumably roi id k belongs to entry
            # k - 1 -- TODO confirm), and whether the segmentation is empty.
            roi_ixs = np.unique(seg[seg > 0]) - 1
            roi_targets = patient[self.balance_target][roi_ixs]
            seg_is_empty = not np.any(seg)
            for tix, target in enumerate(self.unique_ts):
                non_zero = np.count_nonzero(roi_targets == target)
                batch_roi_counts[tix] += non_zero
                batch_empty_counts[tix] += int(non_zero == 0)
                # todo remove assert when checked
                if seg_is_empty:
                    assert non_zero == 0

        batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'),
                 'pid': batch_pids,
                 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts}
        # extend the batch dict by one stacked entry per observable
        for key, val in batch_roi_items.items():
            batch[key] = np.array(val)

        return batch
Beispiel #2
0
    def generate_train_batch(self):
        """Build the batch for the next patient in ``dataset_pids`` (one patient per call).

        The patient volume is padded up to ``self.patch_size`` where it is
        smaller. A full-patient 3D batch is built when ``cf.dim == 3`` or
        ``cf.merge_2D_to_3D_preds`` is set, and a slice-wise 2D batch (slices
        stacked in the batch dimension) when ``cf.dim == 2``. If the volume is
        larger than ``self.patch_size`` in any spatial dim, it is tiled into
        patches of ``patch_size`` that are stacked in the batch dimension, and
        the patient-level targets are carried over into the patch batch.

        Side effects: ``self.patient_ix`` advances by one and wraps around at
        the end of ``dataset_pids``.

        Returns:
            dict produced by ``ConvertSegToBoundingBoxCoordinates`` plus
            ``'patient_bb_target'``, ``'patient_roi_labels'``,
            ``'original_img_shape'`` and, when tiled, ``'patch_crop_coords'``.
            Channels listed in ``cf.drop_channels_test`` are zeroed in ``'data'``.
        """
        pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]
        # axes=(3, 1, 2, 0) exchanges the first and last stored axes. The code
        # below treats the result as (c, y, x, z), so the arrays on disk are
        # presumably stored as (z, y, x, c) -- TODO confirm. seg additionally
        # drops its leading axis and is (y, x, z) from here on.
        data = np.transpose(np.load(patient['data'], mmap_mode='r'),
                            axes=(3, 1, 2, 0)).copy()
        seg = np.transpose(np.load(patient['seg'], mmap_mode='r'),
                           axes=(3, 1, 2, 0))[0].copy()
        batch_class_targets = np.array([patient['class_target']])

        # pad data if smaller than patch_size seen during training.
        if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.patch_size)]):
            new_shape = [data.shape[0]] + [
                np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.patch_size)
            ]
            # use 'return_slicer' to crop image back to original shape.
            data = dutils.pad_nd_image(data, new_shape)
            # seg has no channel axis: pad it to the spatial part of new_shape only.
            seg = dutils.pad_nd_image(seg, new_shape[1:])

        # get 3D targets for evaluation, even if network operates in 2D.
        # 2D predictions will be merged to 3D in predictor.
        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
            out_data = data[np.newaxis]
            out_seg = seg[np.newaxis, np.newaxis]
            out_targets = batch_class_targets

            batch_3D = {
                'data': out_data,
                'seg': out_seg,
                'class_target': out_targets,
                'pid': pid
            }
            converter = ConvertSegToBoundingBoxCoordinates(
                dim=3,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            batch_3D = converter(**batch_3D)
            batch_3D.update({
                'patient_bb_target': batch_3D['bb_target'],
                'patient_roi_labels': batch_3D['class_target'],
                'original_img_shape': out_data.shape
            })

        if self.cf.dim == 2:
            out_data = np.transpose(data, axes=(3, 0, 1, 2))  # (z, c, y, x)
            out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis]
            out_targets = np.array(
                np.repeat(batch_class_targets, out_data.shape[0], axis=0))

            # if set to not None, add neighbouring slices to each selected slice in channel dimension.
            if self.cf.n_3D_context is not None:
                n_ctx = self.cf.n_3D_context
                # indices of the original slices after zero-padding n_ctx slices on each side
                slice_range = range(n_ctx, out_data.shape[0] + n_ctx)
                out_data = np.pad(out_data,
                                  ((n_ctx, n_ctx), (0, 0), (0, 0), (0, 0)),
                                  'constant',
                                  constant_values=0)
                out_data = np.array([
                    np.concatenate(
                        [out_data[ii] for ii in range(slice_id - n_ctx, slice_id + n_ctx + 1)],
                        axis=0)
                    for slice_id in slice_range
                ])

            batch_2D = {
                'data': out_data,
                'seg': out_seg,
                'class_target': out_targets,
                'pid': pid
            }
            converter = ConvertSegToBoundingBoxCoordinates(
                dim=2,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            batch_2D = converter(**batch_2D)

            if self.cf.merge_2D_to_3D_preds:
                batch_2D.update({
                    'patient_bb_target': batch_3D['patient_bb_target'],
                    'patient_roi_labels': batch_3D['patient_roi_labels'],
                    'original_img_shape': out_data.shape
                })
            else:
                batch_2D.update({
                    'patient_bb_target': batch_2D['bb_target'],
                    'patient_roi_labels': batch_2D['class_target'],
                    'original_img_shape': out_data.shape
                })

        out_batch = batch_3D if self.cf.dim == 3 else batch_2D
        patient_batch = out_batch

        # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
        # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
        if np.any([data.shape[dim + 1] > self.patch_size[dim] for dim in range(3)]):
            patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
            new_img_batch, new_seg_batch = [], []

            for cix, c in enumerate(patch_crop_coords_list):
                new_seg_batch.append(seg[c[0]:c[1], c[2]:c[3], c[4]:c[5]])

                # if set to not None, add neighbouring slices to each selected slice in channel dimension.
                # correct patch_crop coordinates by added slices of 3D context
                # (z is zero-padded once, before the first patch is cut).
                if self.cf.dim == 2 and self.cf.n_3D_context is not None:
                    tmp_c_5 = c[5] + (self.cf.n_3D_context * 2)
                    if cix == 0:
                        data = np.pad(
                            data,
                            ((0, 0), (0, 0), (0, 0),
                             (self.cf.n_3D_context, self.cf.n_3D_context)),
                            'constant',
                            constant_values=0)
                else:
                    tmp_c_5 = c[5]

                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5])

            data = np.array(new_img_batch)  # (n_patches, c, x, y, z)
            seg = np.array(new_seg_batch)[:, np.newaxis]  # (n_patches, 1, x, y, z)
            batch_class_targets = np.repeat(batch_class_targets,
                                            len(patch_crop_coords_list),
                                            axis=0)

            if self.cf.dim == 2:
                if self.cf.n_3D_context is not None:
                    # move the context slices of the single channel into the channel axis
                    data = np.transpose(data[:, 0], axes=(0, 3, 1, 2))
                else:
                    # all patches have z dimension 1 (slices). discard dimension
                    data = data[..., 0]
                seg = seg[..., 0]

            patch_batch = {
                'data': data,
                'seg': seg,
                'class_target': batch_class_targets,
                'pid': pid
            }
            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
            # patient-level targets are kept for evaluation on the whole volume
            patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
            patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels']
            patch_batch['original_img_shape'] = patient_batch['original_img_shape']

            converter = ConvertSegToBoundingBoxCoordinates(
                self.cf.dim,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            patch_batch = converter(**patch_batch)
            out_batch = patch_batch

        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

        out_batch['data'][:, self.cf.drop_channels_test, ] = 0.
        return out_batch
Beispiel #3
0
    def generate_train_batch(self):
        """Sample a random training batch of ``self.batch_size`` patients (with replacement).

        Data and segmentation are loaded memory-mapped and transposed with
        ``axes=(3, 1, 2, 0)``; the code below treats the result as
        (c, y, x, z), so the arrays on disk are presumably stored as
        (z, y, x, c) -- TODO confirm. ``seg`` keeps its leading (channel) axis
        throughout, so spatial axis ``ii`` is axis ``ii + 1`` for both ``data``
        and ``seg``.

        In 2D (``cf.dim == 2``) one slice is drawn per patient, with the slices
        listed in ``patient['fg_slices']`` receiving total probability
        ``self.p_fg``; optionally ``cf.n_3D_context`` neighbouring slices are
        stacked into the channel axis. Samples smaller than ``cf.pre_crop_size``
        are padded, larger ones are cropped to exactly ``cf.pre_crop_size``;
        with probability ``self.p_fg`` the crop is placed near a random
        foreground pixel.

        Returns:
            dict with ``'data'`` (stacked samples), ``'seg'`` (stacked uint8
            segmentations), ``'pid'`` (list of patient ids) and
            ``'class_target'`` (stacked ``patient['class_target']`` entries).
        """
        batch_data, batch_segs, batch_pids, batch_targets = [], [], [], []
        class_targets_list = [v['class_target'] for v in self._data.values()]

        # Class-balanced sampling is deliberately switched off: the author notes
        # it is problematic with their class 20. Original condition:
        # self.cf.head_classes > 2
        if False:  # self.cf.head_classes > 2:
            # samples patients towards equilibrium of foreground classes on a roi-level
            # (after randomly sampling the ratio "batch_sample_slack").
            batch_ixs = dutils.get_class_balanced_patients(
                class_targets_list,
                self.batch_size,
                self.cf.head_classes - 1,
                slack_factor=self.cf.batch_sample_slack)
        else:
            batch_ixs = np.random.choice(len(class_targets_list), self.batch_size)

        patients = list(self._data.items())

        for b in batch_ixs:
            patient = patients[b][1]

            # axes=(3, 1, 2, 0) exchanges the first and last stored axes; the
            # result is used as (c, y, x, z) for both data and seg.
            data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(3, 1, 2, 0))
            seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(3, 1, 2, 0))
            batch_pids.append(patient['pid'])
            batch_targets.append(patient['class_target'])

            if self.cf.dim == 2:
                # draw random slice from patient while oversampling slices containing foreground objects with p_fg.
                n_slices = data.shape[3]
                n_fg = len(patient['fg_slices'])
                if 0 < n_fg < n_slices:
                    fg_slices = set(patient['fg_slices'])
                    fg_prob = self.p_fg / n_fg
                    bg_prob = (1 - self.p_fg) / (n_slices - n_fg)
                    slices_prob = [fg_prob if ix in fg_slices else bg_prob
                                   for ix in range(n_slices)]
                    slice_id = np.random.choice(n_slices, p=slices_prob)
                else:
                    # no foreground slices, or every slice is a foreground slice:
                    # weighting is meaningless (and the background share would be
                    # divided by zero), so sample uniformly.
                    slice_id = np.random.choice(n_slices)

                # if set to not None, add neighbouring slices to each selected slice in channel dimension.
                if self.cf.n_3D_context is not None:
                    padded_data = dutils.pad_nd_image(
                        data[0],
                        [(data.shape[-1] + (self.cf.n_3D_context * 2))],
                        mode='constant')
                    padded_slice_id = slice_id + self.cf.n_3D_context
                    data = np.concatenate([
                        padded_data[..., ii][np.newaxis]
                        for ii in range(padded_slice_id - self.cf.n_3D_context,
                                        padded_slice_id + self.cf.n_3D_context + 1)
                    ], axis=0)
                else:
                    data = data[..., slice_id]
                seg = seg[..., slice_id]

            # pad data if smaller than pre_crop_size.
            if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
                new_shape = [np.max([data.shape[dim + 1], ps])
                             for dim, ps in enumerate(self.cf.pre_crop_size)]
                data = dutils.pad_nd_image(data, new_shape, mode='constant')
                seg = dutils.pad_nd_image(seg, new_shape, mode='constant')

            # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
            crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size)
                         if data.shape[dim + 1] > ps]
            if len(crop_dims) > 0:
                fg_prob_sample = np.random.rand(1)
                # with p_fg: sample random pixel from random ROI and shift center by random value.
                if fg_prob_sample < self.p_fg and np.sum(seg) > 0:
                    seg_ixs = np.argwhere(seg == np.random.choice(np.unique(seg)[1:], 1))
                    roi_anchor_pixel = seg_ixs[np.random.choice(seg_ixs.shape[0], 1)][0]
                    assert seg[tuple(roi_anchor_pixel)] > 0
                    # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
                    # distance to the desired ROI < patch_size /2.
                    # (here final patch size to account for center_crop after data augmentation).
                    sample_seg_center = {}
                    for ii in crop_dims:
                        half_crop = self.cf.pre_crop_size[ii] // 2
                        # seg carries a leading channel axis, so the anchor's
                        # coordinate along spatial axis ii is entry ii + 1.
                        anchor = roi_anchor_pixel[ii + 1]
                        max_shift = self.cf.patch_size[ii] // 2 - self.crop_margin[ii]
                        low = np.max((half_crop, anchor - max_shift))
                        high = np.min((data.shape[ii + 1] - half_crop, anchor + max_shift))
                        # happens if lesion on the edge of the image. dont care about roi anymore,
                        # just make sure pre-crop is inside image (same bounds as a random crop).
                        if low >= high:
                            low = half_crop
                            high = data.shape[ii + 1] - half_crop
                        sample_seg_center[ii] = np.random.randint(low=low, high=high)

                else:
                    # not guaranteed to be empty. probability of emptiness depends on the data.
                    sample_seg_center = {
                        ii: np.random.randint(low=self.cf.pre_crop_size[ii] // 2,
                                              high=data.shape[ii + 1] - self.cf.pre_crop_size[ii] // 2)
                        for ii in crop_dims
                    }

                for ii in crop_dims:
                    min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                    # min_crop + size keeps odd crop sizes from coming out one voxel short.
                    max_crop = min_crop + int(self.cf.pre_crop_size[ii])
                    crop_ixs = range(min_crop, max_crop)
                    data = np.take(data, indices=crop_ixs, axis=ii + 1)
                    # seg has the same leading channel axis as data.
                    seg = np.take(seg, indices=crop_ixs, axis=ii + 1)

            batch_data.append(data)
            batch_segs.append(seg)

        data = np.array(batch_data)
        seg = np.array(batch_segs).astype(np.uint8)
        class_target = np.array(batch_targets)
        return {
            'data': data,
            'seg': seg,
            'pid': batch_pids,
            'class_target': class_target
        }
Beispiel #4
0
    def generate_train_batch(self):
        """Build the 3D test-time batch for the next patient in ``dataset_pids``.

        Data is loaded memory-mapped, transposed with ``axes=(1, 2, 0)`` and
        given a leading channel axis, giving (c, y, x, z); seg gets the same
        transpose without the channel axis, giving (y, x, z). Both are padded
        up to ``self.patch_size`` where smaller. Patient-level 3D targets are
        computed with ``ConvertSegToBoundingBoxCoordinates``. If the volume is
        larger than ``patch_size``, it is tiled with stride
        ``self.testing_patch_stride`` into patches stacked in the batch dimension.

        Side effects: ``self.patient_ix`` advances by one and wraps around at
        the end of ``dataset_pids``.

        Returns:
            dict with the converter outputs plus ``'patient_bb_target'``,
            ``'patient_roi_labels'``, ``'original_img_shape'`` and, when tiled,
            ``'patch_crop_coords'``.

        Raises:
            NotImplementedError: if ``cf.dim != 3``; this method never builds
                a 2D batch.
        """
        if self.cf.dim != 3:
            raise NotImplementedError(
                'generate_train_batch only supports 3D batches (cf.dim == 3)')

        pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]
        data = np.transpose(np.load(patient['data'], mmap_mode='r'),
                            axes=(1, 2, 0))[np.newaxis]  # (c, y, x, z)
        seg = np.transpose(np.load(patient['seg'], mmap_mode='r'),
                           axes=(1, 2, 0))  # (y, x, z)
        batch_class_targets = np.array([patient['class_target']])

        # pad data if smaller than patch_size seen during training.
        if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.patch_size)]):
            new_shape = [data.shape[0]] + [
                np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.patch_size)
            ]
            # use 'return_slicer' to crop image back to original shape.
            data = dutils.pad_nd_image(data, new_shape)
            # seg has no channel axis: pad it to the spatial part of new_shape only.
            seg = dutils.pad_nd_image(seg, new_shape[1:])

        # 3D patient-level targets for evaluation.
        out_data = data[np.newaxis]
        out_seg = seg[np.newaxis, np.newaxis]
        out_targets = batch_class_targets

        batch_3D = {
            'data': out_data,
            'seg': out_seg,
            'class_target': out_targets,
            'pid': pid
        }
        # NOTE(review): class_specific_seg_flag is hard-coded to False here but
        # taken from cf for the patch batch below -- confirm this is intended.
        converter = ConvertSegToBoundingBoxCoordinates(
            dim=3,
            get_rois_from_seg_flag=False,
            class_specific_seg_flag=False)
        batch_3D = converter(**batch_3D)
        batch_3D.update({
            'patient_bb_target': batch_3D['bb_target'],
            'patient_roi_labels': batch_3D['roi_labels'],
            'original_img_shape': out_data.shape
        })

        out_batch = batch_3D
        patient_batch = out_batch

        # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
        if np.any([data.shape[dim + 1] > self.patch_size[dim] for dim in range(3)]):
            patch_crop_coords_list = dutils.get_patch_crop_coords_stride(
                data[0], self.patch_size, self.testing_patch_stride)
            new_img_batch, new_seg_batch = [], []

            for c in patch_crop_coords_list:
                new_seg_batch.append(seg[c[0]:c[1], c[2]:c[3], c[4]:c[5]])
                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])

            data = np.array(new_img_batch)  # (n_patches, c, x, y, z)
            seg = np.array(new_seg_batch)[:, np.newaxis]  # (n_patches, 1, x, y, z)
            # one class-target row per patch
            batch_class_targets = np.repeat(batch_class_targets,
                                            len(patch_crop_coords_list),
                                            axis=0)

            patch_batch = {
                'data': data,
                'seg': seg,
                'class_target': batch_class_targets,
                'pid': pid
            }
            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
            # patient-level (whole volume) ground truth is carried over
            patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
            patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels']
            patch_batch['original_img_shape'] = patient_batch['original_img_shape']

            converter = ConvertSegToBoundingBoxCoordinates(
                self.cf.dim,
                get_rois_from_seg_flag=False,
                class_specific_seg_flag=self.cf.class_specific_seg_flag)
            patch_batch = converter(**patch_batch)
            out_batch = patch_batch

        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

        # Replace the first entry of patient_roi_labels with [1] when its first
        # label is > 0 (presumably binarizing labels -- TODO confirm). Patients
        # without any roi are left untouched instead of raising IndexError.
        first_labels = out_batch['patient_roi_labels'][0]
        if len(first_labels) > 0 and first_labels[0] > 0:
            out_batch['patient_roi_labels'][0] = [1]

        return out_batch
Beispiel #5
0
    def generate_train_batch(self):
        """Sample a random training batch of ``self.batch_size`` patients (with replacement).

        If ``cf.head_classes > 2``, patients are drawn class-balanced via
        ``dutils.get_class_balanced_patients``; otherwise they are drawn
        uniformly. Data is loaded memory-mapped, transposed with
        ``axes=(1, 2, 0)`` and given a leading channel axis, giving
        (c, y, x, z); seg gets the same transpose without the channel axis,
        giving (y, x, z). Samples smaller than ``cf.pre_crop_size`` are padded,
        larger ones are cropped to exactly ``cf.pre_crop_size``; with
        probability ``self.p_fg`` the crop is placed near a random foreground
        pixel.

        Returns:
            dict with ``'data'`` (stacked samples), ``'seg'`` (stacked uint8
            segmentations with an added channel axis), ``'pid'`` (list of
            patient ids) and ``'class_target'`` (stacked
            ``patient['class_target']`` entries).
        """
        batch_data, batch_segs, batch_pids, batch_targets = [], [], [], []
        class_targets_list = [v['class_target'] for v in self._data.values()]

        if self.cf.head_classes > 2:
            # samples patients towards equilibrium of foreground classes on a roi-level
            # (after randomly sampling the ratio "batch_sample_slack").
            batch_ixs = dutils.get_class_balanced_patients(
                class_targets_list,
                self.batch_size,
                self.cf.head_classes - 1,
                slack_factor=self.cf.batch_sample_slack)
        else:
            batch_ixs = np.random.choice(len(class_targets_list), self.batch_size)

        patients = list(self._data.items())

        for b in batch_ixs:
            patient = patients[b][1]
            data = np.transpose(np.load(patient['data'], mmap_mode='r'),
                                axes=(1, 2, 0))[np.newaxis]  # (c, y, x, z)
            seg = np.transpose(np.load(patient['seg'], mmap_mode='r'),
                               axes=(1, 2, 0))  # (y, x, z)
            batch_pids.append(patient['pid'])
            batch_targets.append(patient['class_target'])

            # pad data if smaller than pre_crop_size.
            if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
                new_shape = [np.max([data.shape[dim + 1], ps])
                             for dim, ps in enumerate(self.cf.pre_crop_size)]
                data = dutils.pad_nd_image(data, new_shape, mode='constant')
                seg = dutils.pad_nd_image(seg, new_shape, mode='constant')

            # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
            crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size)
                         if data.shape[dim + 1] > ps]
            if len(crop_dims) > 0:
                fg_prob_sample = np.random.rand(1)
                # with p_fg: sample random pixel from random ROI and shift center by random value.
                if fg_prob_sample < self.p_fg and np.sum(seg) > 0:
                    # voxel coordinates of one randomly chosen non-zero seg label
                    seg_ixs = np.argwhere(seg == np.random.choice(np.unique(seg)[1:], 1))
                    roi_anchor_pixel = seg_ixs[np.random.choice(seg_ixs.shape[0], 1)][0]
                    assert seg[tuple(roi_anchor_pixel)] > 0
                    # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
                    # distance to the desired ROI < patch_size /2.
                    # (here final patch size to account for center_crop after data augmentation).
                    sample_seg_center = {}
                    for ii in crop_dims:
                        half_crop = self.cf.pre_crop_size[ii] // 2
                        max_shift = self.cf.patch_size[ii] // 2 - self.crop_margin[ii]
                        low = np.max((half_crop, roi_anchor_pixel[ii] - max_shift))
                        high = np.min((data.shape[ii + 1] - half_crop,
                                       roi_anchor_pixel[ii] + max_shift))
                        # happens if lesion on the edge of the image. dont care about roi anymore,
                        # just make sure pre-crop is inside image (same bounds as a random crop).
                        if low >= high:
                            low = half_crop
                            high = data.shape[ii + 1] - half_crop
                        sample_seg_center[ii] = np.random.randint(low=low, high=high)

                else:
                    # not guaranteed to be empty. probability of emptiness depends on the data.
                    sample_seg_center = {
                        ii: np.random.randint(low=self.cf.pre_crop_size[ii] // 2,
                                              high=data.shape[ii + 1] - self.cf.pre_crop_size[ii] // 2)
                        for ii in crop_dims
                    }

                for ii in crop_dims:
                    min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                    # min_crop + size keeps odd crop sizes from coming out one voxel short.
                    max_crop = min_crop + int(self.cf.pre_crop_size[ii])
                    crop_ixs = range(min_crop, max_crop)
                    data = np.take(data, indices=crop_ixs, axis=ii + 1)
                    # seg has no channel axis, so spatial axis ii is seg axis ii.
                    seg = np.take(seg, indices=crop_ixs, axis=ii)

            batch_data.append(data)
            batch_segs.append(seg[np.newaxis])

        data = np.array(batch_data)
        seg = np.array(batch_segs).astype(np.uint8)
        class_target = np.array(batch_targets)
        return {
            'data': data,
            'seg': seg,
            'pid': batch_pids,
            'class_target': class_target
        }
Beispiel #6
0
    def generate_train_batch(self, pid=None):

        if pid is None:
            pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]

        # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling
        data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
        seg =  np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis]

        data_shp_raw = data.shape
        plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None
        data = data[self.chans]
        discarded_chans = len(
            [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan])
        spatial_shp = data[0].shape  # spatial dims need to be in order x,y,z
        assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg"

        if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]):
            new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))]
            data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
            seg = dutils.pad_nd_image(seg, new_shape)
            if plot_bg is not None:
                plot_bg = dutils.pad_nd_image(plot_bg, new_shape)

        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
            # adds the batch dim here bc won't go through MTaugmenter
            out_data = data[np.newaxis]
            out_seg = seg[np.newaxis]
            if plot_bg is not None:
               out_plot_bg = plot_bg[np.newaxis]
            # data and seg shape: (1,c,x,y,z), where c=1 for seg

            batch_3D = {'data': out_data, 'seg': out_seg}
            for o in self.cf.roi_items:
                batch_3D[o] = np.array([patient[self.gt_prefix+o]])
            converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
            batch_3D = converter(**batch_3D)
            batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
            for o in self.cf.roi_items:
                batch_3D["patient_" + o] = batch_3D[o]

        if self.cf.dim == 2:
            out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32')  # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim
            out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8')  # (c,y,x,z) to (b=z,c,x,y)

            batch_2D = {'data': out_data, 'seg': out_seg}
            for o in self.cf.roi_items:
                batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0)
            converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
            batch_2D = converter(**batch_2D)

            if plot_bg is not None:
                out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32')

            if self.cf.merge_2D_to_3D_preds:
                batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                 'original_img_shape': out_data.shape})
                for o in self.cf.roi_items:
                    batch_2D["patient_" + o] = batch_3D[o]
            else:
                batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                 'original_img_shape': out_data.shape})
                for o in self.cf.roi_items:
                    batch_2D["patient_" + o] = batch_2D[o]

        out_batch = batch_3D if self.cf.dim == 3 else batch_2D
        out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})

        if self.cf.plot_bg_chan in self.chans and discarded_chans > 0:  # len(self.chans[:self.cf.plot_bg_chan])<data_shp_raw[0]:
            assert plot_bg is None
            plot_bg = int(self.cf.plot_bg_chan - discarded_chans)
            out_plot_bg = plot_bg
        if plot_bg is not None:
            out_batch['plot_bg'] = out_plot_bg

        # eventual tiling into patches
        spatial_shp = out_batch["data"].shape[2:]
        if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
            patient_batch = out_batch
            print("patientiterator produced patched batch!")
            patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
            new_img_batch, new_seg_batch = [], []

            for c in patch_crop_coords_list:
                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])
                seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                new_seg_batch.append(seg_patch)
            shps = []
            for arr in new_img_batch:
                shps.append(arr.shape)

            data = np.array(new_img_batch)  # (patches, c, x, y, z)
            seg = np.array(new_seg_batch)
            if self.cf.dim == 2:
                # all patches have z dimension 1 (slices). discard dimension
                data = data[..., 0]
                seg = seg[..., 0]
            patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                           'pid': np.array([patient['pid']] * data.shape[0])}
            for o in self.cf.roi_items:
                patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0)
            #patient-wise (orig) batch info for putting the patches back together after prediction
            for o in self.cf.roi_items:
                patch_batch["patient_"+o] = patient_batch["patient_"+o]
                if self.cf.dim == 2:
                    # this could also be named "unpatched_2d_roi_items"
                    patch_batch["patient_" + o + "_2d"] = patient_batch[o]
            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
            patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
            if self.cf.dim == 2:
                patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
            patch_batch['patient_data'] = patient_batch['data']
            patch_batch['patient_seg'] = patient_batch['seg']
            patch_batch['original_img_shape'] = patient_batch['original_img_shape']
            if plot_bg is not None:
                patch_batch['patient_plot_bg'] = patient_batch['plot_bg']

            converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False,
                                                           class_specific_seg=self.cf.class_specific_seg)

            patch_batch = converter(**patch_batch)
            out_batch = patch_batch

        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

        return out_batch
Beispiel #7
0
    def generate_train_batch(self):
        """Assemble one training batch of padded / pre-cropped patient samples.

        For every pid returned by ``self.get_batch_pids()`` the image (``patient['data']``) and
        segmentation (``patient['seg']``) arrays are loaded from disk; a leading channel axis of
        size 1 is added to the image.

        * 2D (``cf.dim == 2``): one z-slice per patient is picked. If the patient has foreground
          slices (``patient['fg_slices']``), a slice containing an ROI of the target class that is
          currently least represented in the batch is chosen with probability ``self.p_fg`` (or
          always, once every class reached ``self.empty_samples_max_ratio`` empty samples);
          otherwise a background slice is drawn.
        * Spatial dims smaller than ``cf.pre_crop_size`` are padded; larger dims are cropped to
          exactly ``cf.pre_crop_size``. When foreground is chosen (and the sample has ROIs), the
          crop centre is drawn close to a pixel of a randomly selected ROI.

        Returns:
            dict with keys
              'data': stacked images, float16, shape (b, 1, *spatial),
              'seg': stacked segmentations, uint8, shape (b, 1, *spatial),
              'pid': the batch pids as returned by ``get_batch_pids``,
              'roi_counts': per class in ``self.unique_ts``, number of ROIs present in the
                            (sliced / cropped) batch samples,
              'empty_counts': per class, number of samples containing no ROI of that class,
              plus one key per name in ``cf.roi_items`` holding the stacked patient entries.
        """
        # everything done in here is per batch
        # print statements in here get confusing due to multithreading

        batch_pids = self.get_batch_pids()

        batch_data, batch_segs = [], []
        batch_roi_items = {name: [] for name in self.cf.roi_items}
        # record roi count and empty count of classes in batch.
        # empty count: number of samples (slices in 2D / crops in 3D) without any roi of resp. class
        batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
        batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32')

        for pid in batch_pids:
            patient = self._data[pid]

            data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
            seg = np.load(patient['seg'], mmap_mode='r').astype('uint8')

            _, _, _, z = data.shape  # data: (c=1, y, x, z)
            if self.cf.dim == 2:
                elig_slices, choose_fg = [], False
                if len(patient['fg_slices']) > 0:
                    # NOTE(review): 2D uses np.all over classes here, the 3D branch below uses np.any;
                    # confirm which is intended.
                    if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                            np.random.rand(1) <= self.p_fg:
                        # fg is to be picked: try target classes from least to most represented in batch
                        for tix in np.argsort(batch_roi_counts):
                            # pick slices of patient that have roi of sought-for target
                            # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                            elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
                                self.unique_ts[tix]) > 0]
                            if len(elig_slices) > 0:
                                choose_fg = True
                                break
                    else:
                        # pick bg
                        elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])
                if len(elig_slices) > 0:
                    sl_pick_ix = np.random.choice(elig_slices, size=None)
                else:
                    sl_pick_ix = np.random.choice(z, size=None)
                data = data[..., sl_pick_ix]
                seg = seg[..., sl_pick_ix]

            spatial_shp = data[0].shape
            assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg"
            if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
                new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
                data = dutils.pad_nd_image(data, (len(data), *new_shape))
                seg = dutils.pad_nd_image(seg, new_shape)

            # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center,
            # if possible, to that pixel, so that img still contains ROI after pre-cropping
            dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))]
            if np.any(dim_cropflags):
                if self.cf.dim == 3:
                    choose_fg = np.any(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                                np.random.rand(1) <= self.p_fg
                if choose_fg and np.any(seg):
                    available_roi_ids = np.unique(seg[seg > 0])
                    seg_ics = None
                    for tix in np.argsort(batch_roi_counts):
                        elig_roi_ids = available_roi_ids[
                            patient[self.balance_target][available_roi_ids - 1] == self.unique_ts[tix]]
                        if len(elig_roi_ids) > 0:
                            seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                            break
                    if seg_ics is None:
                        # no roi of any sought-for target in this sample: anchor at any roi pixel instead
                        # of crashing / silently reusing coordinates left over from the previous patient.
                        seg_ics = np.argwhere(seg > 0)
                    roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                    assert seg[tuple(roi_anchor_pixel)] > 0

                    # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and
                    # distance to the selected ROI < patch_size /2
                    def get_cropped_centercoords(dim):
                        low = np.max((self.cf.pre_crop_size[dim] // 2,
                                      roi_anchor_pixel[dim] - (
                                              self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                        high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2,
                                       roi_anchor_pixel[dim] + (
                                               self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                        if low >= high:  # happens if lesion on the edge of the image.
                            low = self.cf.pre_crop_size[dim] // 2
                            high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2

                        assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format(
                            dim,
                            spatial_shp, patient['pid'], low, high)
                        return np.random.randint(low=low, high=high)
                else:
                    # sample crop center regardless of ROIs, not guaranteed to be empty
                    def get_cropped_centercoords(dim):
                        return np.random.randint(low=self.cf.pre_crop_size[dim] // 2,
                                                 high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2)

                for dim in np.where(dim_cropflags)[0]:
                    center = get_cropped_centercoords(dim)
                    min_ = int(center - self.cf.pre_crop_size[dim] // 2)
                    # take exactly pre_crop_size voxels; center +- size//2 would lose one voxel for odd sizes.
                    # stays in bounds since center < spatial_shp[dim] - pre_crop_size[dim] // 2.
                    max_ = min_ + int(self.cf.pre_crop_size[dim])
                    data = np.take(data, indices=range(min_, max_), axis=dim + 1)  # +1 for channeldim
                    seg = np.take(seg, indices=range(min_, max_), axis=dim)

            batch_data.append(data)
            batch_segs.append(seg[np.newaxis])

            for o in batch_roi_items:  # after loop, holds every entry of every batchpatient per observable
                batch_roi_items[o].append(patient[o])

            # count, per target class, the rois actually present in the final (sliced / cropped) sample,
            # so that an empty sample always counts as empty. roi id k in seg maps to entry k-1 of
            # patient[self.balance_target].
            present_targets = np.asarray(patient[self.balance_target])[np.unique(seg[seg > 0]) - 1]
            for tix, target in enumerate(self.unique_ts):
                non_zero = np.count_nonzero(present_targets == target)
                batch_roi_counts[tix] += non_zero
                batch_empty_counts[tix] += int(non_zero == 0)

        batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'),
                 'pid': batch_pids,
                 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts}
        for key, val in batch_roi_items.items():  # extend batch dic by entries of observables dic
            batch[key] = np.array(val)

        return batch