Example #1
 def generate_train_batch(self):
     ids = np.random.choice(list(self._data.keys()), self.BATCH_SIZE)
     data = np.zeros((self.BATCH_SIZE, 5, self._patch_size[0],
                      self._patch_size[1], self._patch_size[2]),
                     dtype=np.float32)
     seg = np.zeros((self.BATCH_SIZE, 1, self._patch_size[0],
                     self._patch_size[1], self._patch_size[2]),
                    dtype=np.float32)
     types = []
     patient_names = []
     identifiers = []
     ages = []
     survivals = []
     for j, i in enumerate(ids):
         types.append(self._data[i]['type'])
         patient_names.append(self._data[i]['name'])
         identifiers.append(self._data[i]['idx'])
         # construct a batch, not very efficient
         data_all = self._data[i]['data'][None]
         if np.any(
                 np.array(data_all.shape[2:]) -
                 np.array(self._patch_size) < 0):
             new_shp = np.max(
                 np.vstack((np.array(data_all.shape[2:])[None],
                            np.array(self._patch_size)[None])), 0)
             data_all = resize_image_by_padding_batched(
                 data_all, new_shp, 0)
         data_all = random_crop_3D_image_batched(data_all, self._patch_size)
         data[j, :] = data_all[0, :5]
         if self.convert_labels:
             seg[j, 0] = convert_brats_seg(data_all[0, 5])
         else:
             seg[j, 0] = data_all[0, 5]
         if 'survival' in self._data[i]:
             survivals.append(self._data[i]['survival'])
         else:
             survivals.append(np.nan)
         if 'age' in self._data[i]:
             ages.append(self._data[i]['age'])
         else:
             ages.append(np.nan)
     return {
         'data': data,
         'seg': seg,
         "idx": ids,
         "grades": types,
         "identifiers": identifiers,
         "patient_names": patient_names,
         'survival': survivals,
         'age': ages
     }
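The shape check in Example #1 (the np.vstack / np.max pair) simply grows every spatial dimension that is smaller than the patch size to the elementwise maximum of data shape and patch size, so the subsequent random crop always fits. A minimal NumPy-only sketch of that step; the concrete shapes are assumptions, not values from the original code:

import numpy as np

# spatial part of a hypothetical data_all of shape (1, 6, 100, 160, 140)
data_shape = np.array((100, 160, 140))
patch_size = np.array((128, 128, 128))

# pad only if some spatial dimension is smaller than the patch size
if np.any(data_shape - patch_size < 0):
    # elementwise maximum of current shape and patch size -> target shape for padding
    new_shp = np.max(np.vstack((data_shape[None], patch_size[None])), 0)
    print(new_shp)  # [128 160 140]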
Example #2
 def generate_train_batch(self):
     data = np.zeros((self.BATCH_SIZE, 1, self.PATCH_SIZE[0],
                      self.PATCH_SIZE[1], self.PATCH_SIZE[2]),
                     dtype=np.float32)
     seg = np.zeros((self.BATCH_SIZE, 1, self.PATCH_SIZE[0],
                     self.PATCH_SIZE[1], self.PATCH_SIZE[2]),
                    dtype=np.float32)
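     # pick a phase key ('ed'/'es', presumably end-diastole/end-systole) and a patient for every batch entry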
     types = np.random.choice(['ed', 'es'], self.BATCH_SIZE, True)
     patients = np.random.choice(list(self._data.keys()), self.BATCH_SIZE, True)
     pathologies = []
     for nb in range(self.BATCH_SIZE):
         if np.any(
                 np.array(self._data[patients[nb]][types[nb] + '_data'].
                          shape) < np.array(self.PATCH_SIZE)):
             shp = self._data[patients[nb]][types[nb] + '_data'].shape
             tmp_data = resize_image_by_padding(
                 self._data[patients[nb]][types[nb] + '_data'],
                 (max(shp[0],
                      self.PATCH_SIZE[0]), max(shp[1], self.PATCH_SIZE[1]),
                  max(shp[2], self.PATCH_SIZE[2])),
                 pad_value=0)
             tmp_seg = resize_image_by_padding(
                 self._data[patients[nb]][types[nb] + '_gt'],
                 (max(shp[0],
                      self.PATCH_SIZE[0]), max(shp[1], self.PATCH_SIZE[1]),
                  max(shp[2], self.PATCH_SIZE[2])),
                 pad_value=0)
         else:
             tmp_data = self._data[patients[nb]][types[nb] + '_data']
             tmp_seg = self._data[patients[nb]][types[nb] + '_gt']
         # not the most efficient way but whatever...
         tmp = np.zeros((1, 2, tmp_data.shape[0], tmp_data.shape[1],
                         tmp_data.shape[2]))
         tmp[0, 0] = tmp_data
         tmp[0, 1] = tmp_seg
         if self._random_crop:
             tmp = random_crop_3D_image_batched(tmp, self.PATCH_SIZE)
         else:
             tmp = center_crop_3D_image_batched(tmp, self.PATCH_SIZE)
         data[nb, 0] = tmp[0, 0]
         seg[nb, 0] = tmp[0, 1]
         pathologies.append(self._data[patients[nb]]['pathology'])
     return {
         'data': data,
         'seg': seg,
         'types': types,
         'patient_ids': patients,
         'pathologies': pathologies
     }
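Example #2 stacks image and ground truth into one (1, 2, x, y, z) array before cropping so that both receive exactly the same crop. Below is a toy center crop for such a batched array; the helper name and the shapes are assumptions, and the random variant differs only in how the start indices are drawn:

import numpy as np

def center_crop_batched(batch, patch_size):
    # batch has shape (b, c, x, y, z); crop the spatial axes around their centers
    starts = [(s - p) // 2 for s, p in zip(batch.shape[2:], patch_size)]
    sl = (slice(None), slice(None)) + tuple(
        slice(st, st + p) for st, p in zip(starts, patch_size))
    return batch[sl]

tmp = np.zeros((1, 2, 16, 224, 224))  # image + segmentation stacked as in Example #2
print(center_crop_batched(tmp, (10, 192, 192)).shape)  # (1, 2, 10, 192, 192)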
Example #3
    def generate_train_batch(self):
        selected_keys = np.random.choice(self.list_of_keys, self.batch_size,
                                         True, None)
        data = []
        seg = []
        case_properties = []
        for j, i in enumerate(selected_keys):
            properties = self._data[i]['properties']
            case_properties.append(properties)

            if self.get_do_oversample(j):
                force_fg = True
            else:
                force_fg = False

            if not isfile(self._data[i]['data_file'][:-4] + ".npy"):
                # let's hope you know what you're doing
                case_all_data = np.load(self._data[i]['data_file'][:-4] +
                                        ".npz")['data']
            else:
                case_all_data = np.load(
                    self._data[i]['data_file'][:-4] + ".npy", self.memmap_mode)

            # this is for when there is just a 2d slice in case_all_data (2d support)
            if len(case_all_data.shape) == 3:
                case_all_data = case_all_data[:, None]

            if self.transpose is not None:
                leading_axis = self.transpose[0]
            else:
                leading_axis = 0

            if not force_fg:
                random_slice = np.random.randint(
                    1, case_all_data.shape[leading_axis + 1])
            else:
                classes_in_slice_per_axis = properties.get(
                    "classes_in_slice_per_axis")
                possible_classes = np.unique(properties['classes'])
                possible_classes = possible_classes[possible_classes > 0]
                if len(possible_classes) > 0 and not (
                        0 in possible_classes.shape):
                    selected_class = np.random.choice(possible_classes)
                else:
                    selected_class = 0
                if classes_in_slice_per_axis is not None:
                    valid_slices = classes_in_slice_per_axis[leading_axis][
                        selected_class]
                else:
                    valid_slices = np.where(
                        np.sum(case_all_data[-1] == selected_class,
                               axis=[i for i in range(3)
                                     if i != leading_axis]))[0]
                if len(valid_slices) != 0:
                    random_slice = np.random.choice(valid_slices)
                else:
                    random_slice = np.random.choice(
                        case_all_data.shape[leading_axis + 1])

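            # slice extraction: with pseudo_3d_slices == 1 take the selected slice (three neighbouring
            # slices when leading_axis == 0), otherwise build a zero-padded pseudo-3D window around
            # random_slice and fold its slices into the channel axis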
            if self.pseudo_3d_slices == 1:
                if leading_axis == 0:
                    case_all_data = case_all_data[:, random_slice -
                                                  1:random_slice + 2]
                elif leading_axis == 1:
                    case_all_data = case_all_data[:, :, random_slice]
                else:
                    case_all_data = case_all_data[:, :, :, random_slice]
                if self.transpose is not None and self.transpose[
                        1] > self.transpose[2]:
                    case_all_data = case_all_data.transpose(0, 2, 1)
            else:
                assert leading_axis == 0, "pseudo_3d_slices works only without transpose for now!"
                mn = random_slice - (self.pseudo_3d_slices - 1) // 2
                mx = random_slice + (self.pseudo_3d_slices - 1) // 2 + 1
                valid_mn = max(mn, 0)
                valid_mx = min(mx, case_all_data.shape[1])
                case_all_seg = case_all_data[-1:]
                case_all_data = case_all_data[:-1]
                case_all_data = case_all_data[:, valid_mn:valid_mx]
                case_all_seg = case_all_seg[:, random_slice]
                need_to_pad_below = valid_mn - mn
                need_to_pad_above = mx - valid_mx
                if need_to_pad_below > 0:
                    shp_for_pad = np.array(case_all_data.shape)
                    shp_for_pad[1] = need_to_pad_below
                    case_all_data = np.concatenate(
                        (np.zeros(shp_for_pad), case_all_data), 1)
                if need_to_pad_above > 0:
                    shp_for_pad = np.array(case_all_data.shape)
                    shp_for_pad[1] = need_to_pad_above
                    case_all_data = np.concatenate(
                        (case_all_data, np.zeros(shp_for_pad)), 1)
                case_all_data = case_all_data.reshape(
                    (-1, case_all_data.shape[-2], case_all_data.shape[-1]))
                case_all_data = np.concatenate((case_all_data, case_all_seg),
                                               0)
            num_seg = 1

            # why we need this is a little complicated. It has to do with downstream random cropping during data
            # augmentation. Basically we will rotate the patch and then do a center crop of size self.final_patch_size.
            # Depending on the rotation, scaling and elastic deformation parameters, self.patch_size has to be large
            # enough to prevent border artifacts from being present in the final patch
            new_shp = None
            # print(case_all_data.shape)
            if np.any(np.array(self.need_to_pad) > 0):
                # use a local copy so self.need_to_pad is not mutated on every loop iteration
                need_to_pad = np.array(
                    [0, self.need_to_pad[0], self.need_to_pad[1]])
                new_shp = np.array(case_all_data.shape[1:]) + need_to_pad
                if np.any(new_shp -
                          np.array([3, self.patch_size[0], self.patch_size[1]])
                          < 0):
                    new_shp = np.max(
                        np.vstack(
                            (new_shp[None],
                             np.array(
                                 [3, self.patch_size[0],
                                  self.patch_size[1]])[None])), 0)
            else:
                if np.any(
                        np.array(case_all_data.shape[1:]) -
                        np.array([3, self.patch_size[0], self.patch_size[1]]) <
                        0):
                    new_shp = np.max(
                        np.vstack(
                            (np.array(case_all_data.shape[1:])[None],
                             np.array(
                                 [3, self.patch_size[0],
                                  self.patch_size[1]])[None])), 0)
            if new_shp is not None:
                case_all_data_donly = pad_nd_image(case_all_data[:-num_seg],
                                                   new_shp,
                                                   self.pad_mode,
                                                   kwargs=self.pad_kwargs_data)
                case_all_data_segnonly = pad_nd_image(
                    case_all_data[-num_seg:],
                    new_shp,
                    'constant',
                    kwargs={'constant_values': -1})
                case_all_data = np.vstack(
                    (case_all_data_donly, case_all_data_segnonly))[None]
            else:
                case_all_data = case_all_data[None]
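            # note: force_fg and selected_class are hard-coded here, so the foreground-forcing crop below is always used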
            force_fg = True
            selected_class = 1
            if not force_fg:
                case_all_data = random_crop_3D_image_batched(
                    case_all_data,
                    tuple([3, self.patch_size[0], self.patch_size[1]]))
            else:
                case_all_data = crop_3D_image_force_fg(case_all_data[0],
                                                       tuple(self.patch_size),
                                                       selected_class)[None]
            data.append(case_all_data[0, :-num_seg])
            seg.append(case_all_data[0, -num_seg:])

        data = np.vstack(data)
        seg = np.vstack(seg)
        keys = selected_keys
        return {
            'data': data,
            'seg': seg,
            'properties': case_properties,
            "keys": keys
        }
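The pseudo-3D branch of Example #3 collects a window of neighbouring slices around random_slice, zero-pads the window where it would run past the volume border, and folds the slices into the channel axis. A self-contained toy version of that step; the function name and shapes are assumptions, not code from the original loader:

import numpy as np

def pseudo3d_window(stack, center, n_slices):
    # stack has shape (c, z, y, x); take n_slices slices centred on `center` along axis 1
    half = (n_slices - 1) // 2
    mn, mx = center - half, center + half + 1
    lo, hi = max(mn, 0), min(mx, stack.shape[1])
    window = stack[:, lo:hi]
    # zero-pad at the volume borders, like the np.concatenate calls in the loader
    window = np.pad(window, ((0, 0), (lo - mn, mx - hi), (0, 0), (0, 0)))
    # fold the slice axis into the channel axis: (c, n_slices, y, x) -> (c * n_slices, y, x)
    return window.reshape(-1, *window.shape[2:])

stack = np.random.rand(2, 30, 64, 64)  # one modality plus segmentation, 30 slices
print(pseudo3d_window(stack, center=1, n_slices=5).shape)  # (10, 64, 64)

Note that the original loader folds only the data channels this way and keeps a single central slice for the segmentation.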
Example #4
    def generate_train_batch(self):
        indices = self.get_indices()
        imgs = []
        labels = []
        properties = []
        for idx in indices:
            item = self._data[idx]
            if not self.cfg.is_test:
                img, lbl, _property = item[IMG], item[LABEL], item[PROPERTIES]
                stacked_volume = np.stack([img, lbl])  # (2, 512, 512)

                assert len(img.shape) == len(
                    self.cfg.patch_size
                ), "len(patch_size) must be equal to len(img.shape)"

                padded_stacked_volume = pad_nd_image(
                    stacked_volume, self.cfg.patch_size
                )  # in case the img_size is smaller than patch_size
                padded_stacked_volume = np.expand_dims(padded_stacked_volume,
                                                       axis=0)  # (1, 2, *size)
                if self.cfg.three_dim:
                    cropped_stacked_volume = random_crop_3D_image_batched(
                        padded_stacked_volume, self.cfg.patch_size)
                else:
                    cropped_stacked_volume = random_crop_2D_image_batched(
                        padded_stacked_volume, self.cfg.patch_size)
                cropped_stacked_volume = np.squeeze(
                    cropped_stacked_volume)  # (2, *patch_size)
                img, lbl = cropped_stacked_volume[0], cropped_stacked_volume[1]
                imgs.append(img)
                labels.append(lbl)
                properties.append(_property)
            else:
                img, _property = item[IMG], item[PROPERTIES]

                assert len(img.shape) == len(
                    self.cfg.patch_size
                ), "len(patch_size) must be equal to len(img.shape)"

                padded_stacked_volume = pad_nd_image(
                    img, self.cfg.patch_size
                )  # in case the img_size is smaller than patch_size
                if self.cfg.three_dim:
                    cropped_stacked_volume = random_crop_3D_image(
                        padded_stacked_volume, self.cfg.patch_size)
                else:
                    cropped_stacked_volume = random_crop_2D_image(
                        padded_stacked_volume, self.cfg.patch_size)
                imgs.append(cropped_stacked_volume)
                properties.append(_property)

        batch_img = np.expand_dims(np.stack(imgs),
                                   axis=1)  # (b, c, *patch_size)
        if self.cfg.is_test:
            return {
                IMG: batch_img,
                BATCH_KEYS: indices,
                PROPERTIES: properties
            }
        batch_label = np.expand_dims(np.stack(labels),
                                     axis=1)  # (b, c, *patch_size)
        return {
            IMG: batch_img,
            LABEL: batch_label,
            BATCH_KEYS: indices,
            PROPERTIES: properties
        }
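Example #4 applies the same pattern per 2D sample: stack image and label, pad up to the patch size, take one aligned random crop, then split again. A pure-NumPy sketch of that non-test branch; the helper name, the dummy shapes and the symmetric padding are assumptions:

import numpy as np

def prepare_2d_sample(img, lbl, patch_size, rng=np.random):
    stacked = np.stack([img, lbl])  # (2, H, W), keeps image and label aligned
    # pad symmetrically up to patch_size in case the image is smaller
    target = [max(s, p) for s, p in zip(stacked.shape[1:], patch_size)]
    pad = [(0, 0)] + [((t - s) // 2, t - s - (t - s) // 2)
                      for t, s in zip(target, stacked.shape[1:])]
    stacked = np.pad(stacked, pad)
    # one random crop position shared by image and label
    y, x = (rng.randint(0, t - p + 1) for t, p in zip(target, patch_size))
    crop = stacked[:, y:y + patch_size[0], x:x + patch_size[1]]
    return crop[0], crop[1]

img = np.random.rand(300, 280).astype(np.float32)
lbl = (np.random.rand(300, 280) > 0.5).astype(np.float32)
patch_img, patch_lbl = prepare_2d_sample(img, lbl, (512, 512))
print(patch_img.shape, patch_lbl.shape)  # (512, 512) (512, 512)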