def generate_train_batch(self):
        """Assemble one training batch from per-patient stacked ``.npy`` arrays.

        Patient indices are produced by ``dutils.get_class_balanced_patients`` (per the
        original comment: sampling towards equilibrium of foreground classes on a
        ROI level, after randomly sampling the ratio ``batch_sample_slack``). For every
        sampled patient the file at ``patient['data']`` is memory-mapped; element 0 of
        that array is used as the image and element 1, cast to uint8, as the
        segmentation. Each gets a new leading axis of size 1 before stacking.

        Returns:
            dict with 'data' (stacked images), 'seg' (stacked segmentations, uint8),
            'pid' (list of patient ids) and 'class_target' (array of the patients'
            'class_target' values).
        """
        targets_per_patient = [entry['class_target'] for _, entry in self._data.items()]

        # NOTE(review): the third argument is head_classes - 1; presumably the number of
        # foreground classes -- confirm against dutils.get_class_balanced_patients.
        sampled_ixs = dutils.get_class_balanced_patients(
            targets_per_patient,
            self.batch_size,
            self.cf.head_classes - 1,
            slack_factor=self.cf.batch_sample_slack)
        patients = list(self._data.items())

        images, masks, pids, targets = [], [], [], []
        for ix in sampled_ixs:
            record = patients[ix][1]
            stacked = np.load(record['data'], mmap_mode='r')
            pids.append(record['pid'])
            targets.append(record['class_target'])
            images.append(stacked[0][np.newaxis])
            masks.append(stacked[1].astype('uint8')[np.newaxis])

        return {
            'data': np.array(images),
            'seg': np.array(masks).astype(np.uint8),
            'pid': pids,
            'class_target': np.array(targets),
        }
Esempio n. 2
0
    def generate_train_batch(self):
        """Sample a random batch of patients and return pre-cropped (optionally 2D) patches.

        For each sampled patient the image (``patient['data']``) and segmentation
        (``patient['seg']``) ``.npy`` files are memory-mapped and transposed with
        ``axes=(3, 1, 2, 0)``. Axis 0 of both results is treated as a channel axis and
        axes ``1..len(pre_crop_size)`` as the spatial axes that ``self.cf.pre_crop_size``,
        ``self.cf.patch_size`` and ``self.crop_margin`` refer to.

        If ``self.cf.dim == 2`` one slice along the last axis is drawn; slices listed in
        ``patient['fg_slices']`` share a total probability of ``self.p_fg``. With
        ``self.cf.n_3D_context`` set, that many neighbouring slices on each side (taken
        from channel 0 of the image) are stacked along the image's channel axis.

        Volumes smaller than ``pre_crop_size`` are padded via ``dutils.pad_nd_image``;
        larger ones are cropped to exactly ``pre_crop_size``. With probability
        ``self.p_fg`` (and a non-empty segmentation) the crop centre is sampled near a
        random pixel of a random ROI label, otherwise uniformly.

        Returns:
            dict with 'data' (stacked images), 'seg' (stacked segmentations, uint8),
            'pid' (list of patient ids) and 'class_target' (array of class targets).
        """
        class_targets_list = [v['class_target'] for v in self._data.values()]

        # Class-balanced sampling (dutils.get_class_balanced_patients, previously gated on
        # self.cf.head_classes > 2) was switched off by the author because it was
        # "problematic with my class 20"; patients are drawn uniformly with replacement.
        batch_ixs = np.random.choice(len(class_targets_list), self.batch_size)

        patients = list(self._data.items())
        batch_data, batch_segs, batch_pids, batch_targets = [], [], [], []

        for b in batch_ixs:
            patient = patients[b][1]

            # NOTE(review): the original comment read "from (c, z, y, x) to (c, y, x, z)", but
            # axes=(3, 1, 2, 0) only yields (c, y, x, z) if the files are stored as
            # (z, y, x, c) -- confirm the on-disk layout.
            data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(3, 1, 2, 0))
            seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(3, 1, 2, 0))
            batch_pids.append(patient['pid'])
            batch_targets.append(patient['class_target'])

            if self.cf.dim == 2:
                n_slices = data.shape[3]
                fg_slices = set(patient['fg_slices'])
                if 0 < len(fg_slices) < n_slices:
                    # oversample foreground slices: together they get p_fg, background 1 - p_fg.
                    fg_prob = self.p_fg / len(fg_slices)
                    bg_prob = (1 - self.p_fg) / (n_slices - len(fg_slices))
                    slices_prob = [
                        fg_prob if ix in fg_slices else bg_prob for ix in range(n_slices)
                    ]
                    slice_id = np.random.choice(n_slices, p=slices_prob)
                else:
                    # no foreground slices, or every slice is foreground: the weighting is
                    # moot (and bg_prob would divide by zero), so sample uniformly.
                    slice_id = np.random.choice(n_slices)

                # if set to not None, add neighbouring slices to the selected slice in the
                # channel dimension.
                if self.cf.n_3D_context is not None:
                    n_ctx = self.cf.n_3D_context
                    padded_data = dutils.pad_nd_image(
                        data[0], [data.shape[-1] + n_ctx * 2], mode='constant')
                    padded_slice_id = slice_id + n_ctx
                    data = np.concatenate(
                        [padded_data[..., ii][np.newaxis]
                         for ii in range(padded_slice_id - n_ctx, padded_slice_id + n_ctx + 1)],
                        axis=0)
                else:
                    data = data[..., slice_id]
                seg = seg[..., slice_id]

            pre_crop_size = self.cf.pre_crop_size

            # pad data if smaller than pre_crop_size.
            if any(data.shape[dim + 1] < ps for dim, ps in enumerate(pre_crop_size)):
                new_shape = [max(data.shape[dim + 1], ps) for dim, ps in enumerate(pre_crop_size)]
                data = dutils.pad_nd_image(data, new_shape, mode='constant')
                seg = dutils.pad_nd_image(seg, new_shape, mode='constant')

            # crop patches of size pre_crop_size, sampling patches containing foreground with p_fg.
            crop_dims = [dim for dim, ps in enumerate(pre_crop_size) if data.shape[dim + 1] > ps]
            if crop_dims:
                if np.random.rand(1) < self.p_fg and np.sum(seg) > 0:
                    # pick a random ROI label, then a random pixel of that ROI as anchor.
                    seg_ixs = np.argwhere(seg == np.random.choice(np.unique(seg)[1:], 1))
                    roi_anchor_pixel = seg_ixs[np.random.choice(seg_ixs.shape[0], 1)][0]
                    assert seg[tuple(roi_anchor_pixel)] > 0
                    # sample the patch centre: constrained by the image edges (minus
                    # pre_crop_size / 2) and by distance to the anchor (< patch_size / 2 -
                    # crop_margin), patch_size being the final size after the center crop
                    # applied after data augmentation.
                    sample_seg_center = {}
                    for ii in crop_dims:
                        half_crop = pre_crop_size[ii] // 2
                        max_shift = self.cf.patch_size[ii] // 2 - self.crop_margin[ii]
                        # seg keeps its leading channel axis (like data), so spatial dim ii
                        # is seg axis ii + 1.
                        anchor = roi_anchor_pixel[ii + 1]
                        low = max(half_crop, anchor - max_shift)
                        high = min(data.shape[ii + 1] - half_crop, anchor + max_shift)
                        if low >= high:
                            # ROI at the image edge: ignore it and just keep the pre-crop
                            # inside the image (same bounds as the uniform branch below).
                            low, high = half_crop, data.shape[ii + 1] - half_crop
                        sample_seg_center[ii] = np.random.randint(low=low, high=high)
                else:
                    # not guaranteed to be empty; probability of emptiness depends on the data.
                    sample_seg_center = {
                        ii: np.random.randint(low=pre_crop_size[ii] // 2,
                                              high=data.shape[ii + 1] - pre_crop_size[ii] // 2)
                        for ii in crop_dims
                    }

                for ii in crop_dims:
                    # take exactly pre_crop_size[ii] indices; the centre bounds above keep the
                    # window inside the image for both even and odd sizes.
                    min_crop = int(sample_seg_center[ii] - pre_crop_size[ii] // 2)
                    crop_ixs = range(min_crop, min_crop + int(pre_crop_size[ii]))
                    data = np.take(data, indices=crop_ixs, axis=ii + 1)
                    seg = np.take(seg, indices=crop_ixs, axis=ii + 1)

            batch_data.append(data)
            batch_segs.append(seg)

        return {
            'data': np.array(batch_data),
            'seg': np.array(batch_segs).astype(np.uint8),
            'pid': batch_pids,
            'class_target': np.array(batch_targets),
        }
Esempio n. 3
0
    def generate_train_batch(self):
        """Sample a batch of patients and return 3D patches cropped to ``pre_crop_size``.

        Patients are chosen with ``dutils.get_class_balanced_patients`` when
        ``self.cf.head_classes > 2`` and uniformly with replacement otherwise. For each
        patient the image (``patient['data']``) and segmentation (``patient['seg']``)
        ``.npy`` files are memory-mapped and transposed with ``axes=(1, 2, 0)``. The image
        additionally gets a leading channel axis; the segmentation only receives one when
        it is appended to the batch. Spatial dim ``ii`` of ``pre_crop_size`` is therefore
        image axis ``ii + 1`` and segmentation axis ``ii``.

        Volumes smaller than ``pre_crop_size`` are padded via ``dutils.pad_nd_image``;
        larger ones are cropped to exactly ``pre_crop_size``. With probability
        ``self.p_fg`` (and a non-empty segmentation) the crop centre is sampled near a
        random pixel of a random ROI label, otherwise uniformly.

        Returns:
            dict with 'data' (stacked images), 'seg' (stacked segmentations with a channel
            axis, uint8), 'pid' (list of patient ids) and 'class_target' (array).
        """
        class_targets_list = [v['class_target'] for v in self._data.values()]
        if self.cf.head_classes > 2:
            # samples patients towards equilibrium of foreground classes on a roi-level
            # (after randomly sampling the ratio "batch_sample_slack").
            batch_ixs = dutils.get_class_balanced_patients(
                class_targets_list,
                self.batch_size,
                self.cf.head_classes - 1,
                slack_factor=self.cf.batch_sample_slack)
        else:
            batch_ixs = np.random.choice(len(class_targets_list), self.batch_size)

        patients = list(self._data.items())
        batch_data, batch_segs, batch_pids, batch_targets = [], [], [], []

        for b in batch_ixs:
            patient = patients[b][1]
            # the original comment labels the transposed image (c, y, x, z).
            data = np.transpose(np.load(patient['data'], mmap_mode='r'),
                                axes=(1, 2, 0))[np.newaxis]
            seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
            batch_pids.append(patient['pid'])
            batch_targets.append(patient['class_target'])

            pre_crop_size = self.cf.pre_crop_size

            # pad data if smaller than pre_crop_size.
            if any(data.shape[dim + 1] < ps for dim, ps in enumerate(pre_crop_size)):
                new_shape = [max(data.shape[dim + 1], ps) for dim, ps in enumerate(pre_crop_size)]
                data = dutils.pad_nd_image(data, new_shape, mode='constant')
                seg = dutils.pad_nd_image(seg, new_shape, mode='constant')

            # crop patches of size pre_crop_size, sampling patches containing foreground with p_fg.
            crop_dims = [dim for dim, ps in enumerate(pre_crop_size) if data.shape[dim + 1] > ps]
            if crop_dims:
                if np.random.rand(1) < self.p_fg and np.sum(seg) > 0:
                    # pick a random ROI label, then a random pixel of that ROI as anchor.
                    seg_ixs = np.argwhere(seg == np.random.choice(np.unique(seg)[1:], 1))
                    roi_anchor_pixel = seg_ixs[np.random.choice(seg_ixs.shape[0], 1)][0]
                    assert seg[tuple(roi_anchor_pixel)] > 0
                    # sample the patch centre: constrained by the image edges (minus
                    # pre_crop_size / 2) and by distance to the anchor (< patch_size / 2 -
                    # crop_margin), patch_size being the final size after the center crop
                    # applied after data augmentation.
                    sample_seg_center = {}
                    for ii in crop_dims:
                        half_crop = pre_crop_size[ii] // 2
                        max_shift = self.cf.patch_size[ii] // 2 - self.crop_margin[ii]
                        # seg has no channel axis here, so spatial dim ii is seg axis ii.
                        anchor = roi_anchor_pixel[ii]
                        low = max(half_crop, anchor - max_shift)
                        high = min(data.shape[ii + 1] - half_crop, anchor + max_shift)
                        if low >= high:
                            # ROI at the image edge: ignore it and just keep the pre-crop
                            # inside the image (same bounds as the uniform branch below).
                            low, high = half_crop, data.shape[ii + 1] - half_crop
                        sample_seg_center[ii] = np.random.randint(low=low, high=high)
                else:
                    # not guaranteed to be empty; probability of emptiness depends on the data.
                    sample_seg_center = {
                        ii: np.random.randint(low=pre_crop_size[ii] // 2,
                                              high=data.shape[ii + 1] - pre_crop_size[ii] // 2)
                        for ii in crop_dims
                    }

                for ii in crop_dims:
                    # take exactly pre_crop_size[ii] indices; the centre bounds above keep the
                    # window inside the image for both even and odd sizes.
                    min_crop = int(sample_seg_center[ii] - pre_crop_size[ii] // 2)
                    crop_ixs = range(min_crop, min_crop + int(pre_crop_size[ii]))
                    data = np.take(data, indices=crop_ixs, axis=ii + 1)
                    seg = np.take(seg, indices=crop_ixs, axis=ii)

            batch_data.append(data)
            batch_segs.append(seg[np.newaxis])

        return {
            'data': np.array(batch_data),
            'seg': np.array(batch_segs).astype(np.uint8),
            'pid': batch_pids,
            'class_target': np.array(batch_targets),
        }