Example #1
0
    def generate_train_batch(self):
        """Assemble one random training batch of 2D cardiac slices.

        For every sample: pick a phase ('ed' or 'es') and a patient (both
        with replacement), pick a random slice, pad it up to at least
        PATCH_SIZE, then take one random crop that keeps image and
        segmentation aligned.

        Returns a dict with 'data' and 'seg' of shape (B, 1, H, W) float32,
        plus per-sample 'types', 'patient_ids' and 'pathologies'.
        """
        batch_shape = (self.BATCH_SIZE, 1, self.PATCH_SIZE[0],
                       self.PATCH_SIZE[1])
        data = np.zeros(batch_shape, dtype=np.float32)
        seg = np.zeros(batch_shape, dtype=np.float32)
        # Randomly choose end-diastole / end-systole per sample.
        types = np.random.choice(['ed', 'es'], self.BATCH_SIZE, True)
        # Randomly choose BATCH_SIZE patients (with replacement).
        patients = np.random.choice(list(self._data.keys()), self.BATCH_SIZE,
                                    True)
        pathologies = []
        for idx in range(self.BATCH_SIZE):
            patient = self._data[patients[idx]]
            volume = patient[types[idx] + '_data']
            # One random slice out of this patient's stack.
            slice_idx = np.random.choice(volume.shape[0])
            # Pad so the slice is at least PATCH_SIZE in each dimension.
            target_size = (max(volume.shape[1], self.PATCH_SIZE[0]),
                           max(volume.shape[2], self.PATCH_SIZE[1]))
            padded_img = resize_image_by_padding(volume[slice_idx],
                                                 target_size,
                                                 pad_value=0)
            padded_gt = resize_image_by_padding(
                patient[types[idx] + '_gt'][slice_idx], target_size,
                pad_value=0)

            # Stack image and segmentation so a single random crop keeps
            # them spatially aligned (not the most efficient way, but fine).
            stacked = np.zeros((1, 2) + padded_img.shape)
            stacked[0, 0] = padded_img
            stacked[0, 1] = padded_gt
            cropped = random_crop_2D_image_batched(stacked, self.PATCH_SIZE)
            data[idx, 0] = cropped[0, 0]
            seg[idx, 0] = cropped[0, 1]
            pathologies.append(patient['pathology'])
        return {
            'data': data,
            'seg': seg,
            'types': types,
            'patient_ids': patients,
            'pathologies': pathologies
        }
Example #2
0
    def generate_train_batch(self):
        """Assemble one random training batch of 2D cardiac slices.

        For every sample: pick a phase ('ed' or 'es') and a patient (both
        with replacement), pick a random slice, pad it up to at least
        PATCH_SIZE, then take one aligned random crop of image + seg.

        Returns a dict with 'data' and 'seg' of shape (B, 1, H, W) float32,
        plus per-sample 'types', 'patient_ids' and 'pathologies'.
        """
        data = np.zeros(
            (self.BATCH_SIZE, 1, self.PATCH_SIZE[0], self.PATCH_SIZE[1]),
            dtype=np.float32)
        seg = np.zeros(
            (self.BATCH_SIZE, 1, self.PATCH_SIZE[0], self.PATCH_SIZE[1]),
            dtype=np.float32)
        types = np.random.choice(['ed', 'es'], self.BATCH_SIZE, True)
        # BUGFIX: np.random.choice requires a 1-D array-like; a Python-3
        # dict view raises "a must be 1-dimensional", so materialize the
        # keys as a list first.
        patients = np.random.choice(list(self._data.keys()), self.BATCH_SIZE,
                                    True)
        pathologies = []
        for nb in range(self.BATCH_SIZE):
            shp = self._data[patients[nb]][types[nb] + '_data'].shape
            # One random slice out of this patient's stack.
            slice_id = np.random.choice(shp[0])
            # Pad so the slice is at least PATCH_SIZE in each dimension.
            tmp_data = resize_image_by_padding(
                self._data[patients[nb]][types[nb] + '_data'][slice_id],
                (max(shp[1],
                     self.PATCH_SIZE[0]), max(shp[2], self.PATCH_SIZE[1])),
                pad_value=0)
            tmp_seg = resize_image_by_padding(
                self._data[patients[nb]][types[nb] + '_gt'][slice_id],
                (max(shp[1],
                     self.PATCH_SIZE[0]), max(shp[2], self.PATCH_SIZE[1])),
                pad_value=0)

            # Stack image and segmentation so a single random crop keeps
            # them spatially aligned (not the most efficient way, but fine).
            tmp = np.zeros((1, 2, tmp_data.shape[0], tmp_data.shape[1]))
            tmp[0, 0] = tmp_data
            tmp[0, 1] = tmp_seg
            tmp = random_crop_2D_image_batched(tmp, self.PATCH_SIZE)
            data[nb, 0] = tmp[0, 0]
            seg[nb, 0] = tmp[0, 1]
            pathologies.append(self._data[patients[nb]]['pathology'])
        return {
            'data': data,
            'seg': seg,
            'types': types,
            'patient_ids': patients,
            'pathologies': pathologies
        }
Example #3
0
    def generate_train_batch(self):
        """Sample ``batch_size`` cases and return one cropped 2D training batch.

        For each selected case this (1) optionally forces a foreground-
        containing slice (oversampling), (2) extracts a single slice or a
        pseudo-3D stack of neighboring slices, (3) pads so the patch fits,
        and (4) random-crops to ``self.patch_size``.

        Returns a dict with 'data' (b, c, *patch_size), 'seg'
        (b, 1, *patch_size), per-case 'properties' and the sampled 'keys'.
        """
        # Sample with replacement (True); probabilities None = uniform.
        selected_keys = np.random.choice(self.list_of_keys, self.batch_size,
                                         True, None)
        data = []
        seg = []
        case_properties = []
        for j, i in enumerate(selected_keys):
            properties = self._data[i]['properties']
            case_properties.append(properties)

            # Decide whether this batch position must contain foreground
            # (foreground oversampling).
            if self.get_do_oversample(j):
                force_fg = True
            else:
                force_fg = False

            # Prefer the unpacked .npy (loadable with memmap); otherwise fall
            # back to reading the compressed .npz archive.
            if not isfile(self._data[i]['data_file'][:-4] + ".npy"):
                # lets hope you know what you're doing
                case_all_data = np.load(self._data[i]['data_file'][:-4] +
                                        ".npz")['data']
            else:
                case_all_data = np.load(
                    self._data[i]['data_file'][:-4] + ".npy", self.memmap_mode)

            # this is for when there is just a 2d slice in case_all_data (2d support)
            if len(case_all_data.shape) == 3:
                case_all_data = case_all_data[:, None]

            # Axis along which slices are taken (first spatial axis unless a
            # transpose was configured).
            if self.transpose is not None:
                leading_axis = self.transpose[0]
            else:
                leading_axis = 0

            if not force_fg:
                # Any slice will do.
                random_slice = np.random.choice(
                    case_all_data.shape[leading_axis + 1])
            else:
                # select one class, then select a slice that contains that class
                classes_in_slice_per_axis = properties.get(
                    "classes_in_slice_per_axis")
                possible_classes = np.unique(properties['classes'])
                possible_classes = possible_classes[possible_classes > 0]
                if len(possible_classes) > 0 and not (
                        0 in possible_classes.shape):
                    selected_class = np.random.choice(possible_classes)
                else:
                    # No foreground classes present; fall back to background.
                    selected_class = 0
                if classes_in_slice_per_axis is not None:
                    # Use the precomputed per-axis lookup if available.
                    valid_slices = classes_in_slice_per_axis[leading_axis][
                        selected_class]
                else:
                    # Otherwise scan the seg channel (last channel) for slices
                    # containing the selected class.
                    valid_slices = np.where(
                        np.sum(case_all_data[-1] == selected_class,
                               axis=[i for i in range(3)
                                     if i != leading_axis]))[0]
                if len(valid_slices) != 0:
                    random_slice = np.random.choice(valid_slices)
                else:
                    random_slice = np.random.choice(
                        case_all_data.shape[leading_axis + 1])

            if self.pseudo_3d_slices == 1:
                # Plain 2D: take the chosen slice along the leading axis.
                if leading_axis == 0:
                    case_all_data = case_all_data[:, random_slice]
                elif leading_axis == 1:
                    case_all_data = case_all_data[:, :, random_slice]
                else:
                    case_all_data = case_all_data[:, :, :, random_slice]
                # Swap the two in-plane axes if the transpose demands it.
                if self.transpose is not None and self.transpose[
                        1] > self.transpose[2]:
                    case_all_data = case_all_data.transpose(0, 2, 1)
            else:
                # Pseudo-3D: stack neighboring slices into the channel dim;
                # the segmentation comes from the center slice only.
                assert leading_axis == 0, "pseudo_3d_slices works only without transpose for now!"
                mn = random_slice - (self.pseudo_3d_slices - 1) // 2
                mx = random_slice + (self.pseudo_3d_slices - 1) // 2 + 1
                valid_mn = max(mn, 0)
                valid_mx = min(mx, case_all_data.shape[1])
                case_all_seg = case_all_data[-1:]
                case_all_data = case_all_data[:-1]
                case_all_data = case_all_data[:, valid_mn:valid_mx]
                case_all_seg = case_all_seg[:, random_slice]
                # Zero-pad where the window ran off either end of the volume.
                need_to_pad_below = valid_mn - mn
                need_to_pad_above = mx - valid_mx
                if need_to_pad_below > 0:
                    shp_for_pad = np.array(case_all_data.shape)
                    shp_for_pad[1] = need_to_pad_below
                    case_all_data = np.concatenate(
                        (np.zeros(shp_for_pad), case_all_data), 1)
                if need_to_pad_above > 0:
                    shp_for_pad = np.array(case_all_data.shape)
                    shp_for_pad[1] = need_to_pad_above
                    case_all_data = np.concatenate(
                        (case_all_data, np.zeros(shp_for_pad)), 1)
                # Fold (modalities, slices) into one channel axis, then
                # re-append the seg as the last channel.
                case_all_data = case_all_data.reshape(
                    (-1, case_all_data.shape[-2], case_all_data.shape[-1]))
                case_all_data = np.concatenate((case_all_data, case_all_seg),
                                               0)

            # Number of trailing segmentation channels.
            num_seg = 1

            # why we need this is a little complicated. It has to do with downstream random cropping during data
            # augmentation. Basically we will rotate the patch and then to a center crop of size self.final_patch_size.
            # Depending on the rotation, scaling and elastic deformation parameters, self.patch_size has to be large
            # enough to prevent border artifacts from being present in the final patch
            new_shp = None
            # NOTE(review): `np.any(self.need_to_pad) > 0` compares a bool to
            # 0; the likely intent is `np.any(self.need_to_pad > 0)`. They
            # agree for non-negative need_to_pad — confirm before changing.
            if np.any(self.need_to_pad) > 0:
                new_shp = np.array(case_all_data.shape[1:] + self.need_to_pad)
                # Never pad to smaller than patch_size in any dimension.
                if np.any(new_shp - np.array(self.patch_size) < 0):
                    new_shp = np.max(
                        np.vstack(
                            (new_shp[None], np.array(self.patch_size)[None])),
                        0)
            else:
                # Only pad if the case is smaller than the patch.
                if np.any(
                        np.array(case_all_data.shape[1:]) -
                        np.array(self.patch_size) < 0):
                    new_shp = np.max(
                        np.vstack((np.array(case_all_data.shape[1:])[None],
                                   np.array(self.patch_size)[None])), 0)
            if new_shp is not None:
                # Pad data with the configured mode, but pad seg with -1 so
                # padded voxels can be identified/ignored downstream.
                case_all_data_donly = pad_nd_image(case_all_data[:-num_seg],
                                                   new_shp,
                                                   self.pad_mode,
                                                   kwargs=self.pad_kwargs_data)
                case_all_data_segnonly = pad_nd_image(
                    case_all_data[-num_seg:],
                    new_shp,
                    'constant',
                    kwargs={'constant_values': -1})
                case_all_data = np.vstack(
                    (case_all_data_donly, case_all_data_segnonly))[None]
            else:
                case_all_data = case_all_data[None]

            # Crop: plain random crop, or a crop guaranteed to contain the
            # selected foreground class when oversampling.
            if not force_fg:
                case_all_data = random_crop_2D_image_batched(
                    case_all_data, tuple(self.patch_size))
            else:
                case_all_data = crop_2D_image_force_fg(case_all_data[0],
                                                       tuple(self.patch_size),
                                                       selected_class)[None]
            data.append(case_all_data[:, :-num_seg])
            seg.append(case_all_data[:, -num_seg:])
        data = np.vstack(data)
        seg = np.vstack(seg)
        keys = selected_keys
        return {
            'data': data,
            'seg': seg,
            'properties': case_properties,
            "keys": keys
        }
Example #4
0
    def generate_train_batch(self):
        """Build one batch from the indices returned by ``self.get_indices()``.

        Training mode: each sample's image and label are stacked so one
        random crop keeps them aligned, padded up to patch_size, cropped,
        and returned as (b, 1, *patch_size) arrays under IMG / LABEL.
        Test mode: only images are padded/cropped; no labels are returned.
        """
        indices = self.get_indices()
        imgs = []
        labels = []
        properties = []
        for idx in indices:
            item = self._data[idx]
            # BUGFIX: use truthiness instead of `is False` so this guard
            # agrees with the `if self.cfg.is_test:` check below. With the
            # old check, a falsy non-bool (e.g. None) took the test branch
            # here but the training return path below, crashing on
            # np.stack([]) for the empty labels list.
            if not self.cfg.is_test:
                img, lbl, _property = item[IMG], item[LABEL], item[PROPERTIES]
                # Stack image and label so a single random crop keeps them
                # spatially aligned.
                stacked_volume = np.stack([img, lbl])  # (2, *img.shape)

                assert len(img.shape) == len(
                    self.cfg.patch_size
                ), "len(patch_size) must be equal to len(img.shape)"

                # Pad in case the image is smaller than patch_size.
                padded_stacked_volume = pad_nd_image(stacked_volume,
                                                     self.cfg.patch_size)
                padded_stacked_volume = np.expand_dims(padded_stacked_volume,
                                                       axis=0)  # (1, 2, *size)
                if self.cfg.three_dim:
                    cropped_stacked_volume = random_crop_3D_image_batched(
                        padded_stacked_volume, self.cfg.patch_size)
                else:
                    cropped_stacked_volume = random_crop_2D_image_batched(
                        padded_stacked_volume, self.cfg.patch_size)
                cropped_stacked_volume = np.squeeze(
                    cropped_stacked_volume)  # (2, *patch_size)
                img, lbl = cropped_stacked_volume[0], cropped_stacked_volume[1]
                imgs.append(img)
                labels.append(lbl)
                properties.append(_property)
            else:
                img, _property = item[IMG], item[PROPERTIES]

                assert len(img.shape) == len(
                    self.cfg.patch_size
                ), "len(patch_size) must be equal to len(img.shape)"

                # Pad in case the image is smaller than patch_size.
                padded_stacked_volume = pad_nd_image(img, self.cfg.patch_size)
                if self.cfg.three_dim:
                    cropped_stacked_volume = random_crop_3D_image(
                        padded_stacked_volume, self.cfg.patch_size)
                else:
                    cropped_stacked_volume = random_crop_2D_image(
                        padded_stacked_volume, self.cfg.patch_size)
                imgs.append(cropped_stacked_volume)
                properties.append(_property)

        # Add the channel axis: (b, 1, *patch_size).
        batch_img = np.expand_dims(np.stack(imgs), axis=1)
        if self.cfg.is_test:
            return {
                IMG: batch_img,
                BATCH_KEYS: indices,
                PROPERTIES: properties
            }
        batch_label = np.expand_dims(np.stack(labels),
                                     axis=1)  # (b, c, *patch_size)
        return {
            IMG: batch_img,
            LABEL: batch_label,
            BATCH_KEYS: indices,
            PROPERTIES: properties
        }