Exemplo n.º 1
0
    def generate_train_batch(self):
        """Sample one training batch of 2D slices from a randomly chosen subject.

        A subject is drawn at random from ``self._data[0]`` and loaded with
        ``load_training_data``. ``self.batch_size`` slice indices are drawn along
        ``self.Config.TRAINING_SLICE_DIRECTION``: without replacement when the volume has
        at least ``batch_size`` slices, with replacement otherwise. Slices are extracted
        with ``data_utils.sample_Xslices`` (when ``NR_SLICES > 1``) or
        ``data_utils.sample_slices``. They are then either cropped to ``INPUT_DIM``
        (``PAD_TO_SQUARE``) or padded with ``shape_must_be_divisible_by=(16, 16)``.

        Returns:
            dict: ``"data"`` and ``"seg"`` (cast to float32; author-documented layout
            ``(batch_size, channels, x, y, [z])``) and ``"slice_dir"`` (the int returned
            by ``data_utils.slice_dir_to_int``).
        """
        subjects = self._data[0]
        # int() truncates the uniform float to obtain a subject index.
        subject_idx = int(random.uniform(0, len(subjects)))

        data, seg = load_training_data(self.Config, subjects[subject_idx])

        # Convert peaks to tensors if tensor model
        if self.Config.NR_OF_GRADIENTS == 18 * self.Config.NR_SLICES:
            data = peak_utils.peaks_to_tensors(data)

        slice_direction = data_utils.slice_dir_to_int(self.Config.TRAINING_SLICE_DIRECTION)
        nr_of_slices = data.shape[slice_direction]
        # Sampling without replacement is valid whenever there are at least batch_size
        # slices (including the equal case). Only fall back to replacement when there are
        # strictly fewer slices than requested.
        replace = nr_of_slices < self.batch_size
        if replace:
            print("INFO: Batch size bigger than nr of slices. Therefore sampling with replacement.")
        slice_idxs = np.random.choice(nr_of_slices, self.batch_size, replace, None)

        if self.Config.NR_SLICES > 1:
            x, y = data_utils.sample_Xslices(data, seg, slice_idxs, slice_direction=slice_direction,
                                             labels_type=self.Config.LABELS_TYPE,
                                             slice_window=self.Config.NR_SLICES)
        else:
            x, y = data_utils.sample_slices(data, seg, slice_idxs, slice_direction=slice_direction,
                                            labels_type=self.Config.LABELS_TYPE)

        # NOTE: resampling e.g. 1.25mm (HCP) images to 2mm via self._zoom_x_and_y was tried
        # but is very slow; a spatial transform might be a faster alternative.

        if self.Config.PAD_TO_SQUARE:
            # Crop and pad to input size
            x, y = crop(x, y, crop_size=self.Config.INPUT_DIM)  # does not work with img with batches and channels
        else:
            # Pad so the padded axes are multiples of 16. Batches can therefore end up with
            # different dimensions, and x and y can differ as well.
            # This is needed for the Schizo dataset.
            x = pad_nd_image(x, shape_must_be_divisible_by=(16, 16), mode='constant', kwargs={'constant_values': 0})
            y = pad_nd_image(y, shape_must_be_divisible_by=(16, 16), mode='constant', kwargs={'constant_values': 0})

        # Does not make it slower
        x = x.astype(np.float32)
        y = y.astype(np.float32)

        # possible optimization: sample slices from different patients and pad all to same size (size of biggest)

        data_dict = {"data": x,  # (batch_size, channels, x, y, [z])
                     "seg": y,  # (batch_size, channels, x, y, [z])
                     "slice_dir": slice_direction}
        return data_dict
Exemplo n.º 2
0
    def generate_train_batch(self):
        """Return the next sequential batch of slices from one preloaded volume.

        Walks through ``self._data`` (``data``, ``seg``) along
        ``self.Config.SLICE_DIRECTION`` in steps of ``self.batch_size``, starting at
        ``self.global_idx``. The final batch is shortened so it ends exactly at the last
        slice.

        Returns:
            dict: ``"data"`` and ``"seg"`` with the sampled slices (author-documented
            layout ``(batch_size, channels, x, y, [z])``).

        Raises:
            StopIteration: when all slices have been returned; ``self.global_idx`` is
                reset to 0 first, so a new pass starts from the beginning.
            ValueError: if ``self.Config.SLICE_DIRECTION`` is not "x", "y" or "z"
                (previously this surfaced as an unbound-local error).
        """
        data = self._data[0]
        seg = self._data[1]

        if self.Config.SLICE_DIRECTION == "x":
            end = data.shape[0]
        elif self.Config.SLICE_DIRECTION == "y":
            end = data.shape[1]
        elif self.Config.SLICE_DIRECTION == "z":
            end = data.shape[2]
        else:
            raise ValueError(
                f"Unsupported SLICE_DIRECTION {self.Config.SLICE_DIRECTION!r}; expected 'x', 'y' or 'z'.")

        # Stop iterating if we reached end of data
        if self.global_idx >= end:
            self.global_idx = 0
            raise StopIteration

        # Shrink the last batch so it fits exactly into the remaining slices. Clamp to
        # `end` (not end - 1) because it is the exclusive upper bound of range().
        new_global_idx = min(self.global_idx + self.batch_size, end)

        slice_idxs = list(range(self.global_idx, new_global_idx))
        slice_direction = data_utils.slice_dir_to_int(self.Config.SLICE_DIRECTION)

        if self.Config.NR_SLICES > 1:
            x, y = data_utils.sample_Xslices(
                data,
                seg,
                slice_idxs,
                slice_direction=slice_direction,
                labels_type=self.Config.LABELS_TYPE,
                slice_window=self.Config.NR_SLICES)
        else:
            x, y = data_utils.sample_slices(
                data,
                seg,
                slice_idxs,
                slice_direction=slice_direction,
                labels_type=self.Config.LABELS_TYPE)

        data_dict = {
            "data": x,  # (batch_size, channels, x, y, [z])
            "seg": y,  # (batch_size, channels, x, y, [z])
        }
        # Only advance the cursor once sampling succeeded.
        self.global_idx = new_global_idx
        return data_dict
Exemplo n.º 3
0
    def generate_train_batch(self):
        """Sample one training batch of 2D slices from a randomly chosen subject.

        A subject is drawn at random from ``self._data[0]`` and loaded with
        ``load_training_data``. ``self.batch_size`` slice indices are drawn along
        ``self.Config.TRAINING_SLICE_DIRECTION``. Sampling is without replacement when the
        volume has at least ``batch_size`` slices. Otherwise it falls back to sampling with
        replacement instead of letting ``np.random.choice`` raise ``ValueError``. The
        extracted slices are cropped to ``self.Config.INPUT_DIM``.

        Returns:
            dict: ``"data"`` and ``"seg"`` (cast to float32; author-documented layout
            ``(batch_size, channels, x, y, [z])``) and ``"slice_dir"`` (the int returned
            by ``data_utils.slice_dir_to_int``).
        """
        subjects = self._data[0]
        # len(subjects)-1 not needed because int() truncates the uniform float
        subject_idx = int(random.uniform(0, len(subjects)))

        data, seg = load_training_data(self.Config, subjects[subject_idx])

        # Convert peaks to tensors if tensor model
        if self.Config.NR_OF_GRADIENTS == 18 * self.Config.NR_SLICES:
            data = peak_utils.peaks_to_tensors(data)

        slice_direction = data_utils.slice_dir_to_int(self.Config.TRAINING_SLICE_DIRECTION)
        nr_of_slices = data.shape[slice_direction]
        # np.random.choice(..., replace=False) raises when asked for more samples than
        # exist, so only sample with replacement when there are strictly fewer slices.
        replace = nr_of_slices < self.batch_size
        if replace:
            print("INFO: Batch size bigger than nr of slices. Therefore sampling with replacement.")
        slice_idxs = np.random.choice(nr_of_slices, self.batch_size, replace, None)

        if self.Config.NR_SLICES > 1:
            x, y = data_utils.sample_Xslices(data, seg, slice_idxs, slice_direction=slice_direction,
                                             labels_type=self.Config.LABELS_TYPE,
                                             slice_window=self.Config.NR_SLICES)
        else:
            x, y = data_utils.sample_slices(data, seg, slice_idxs, slice_direction=slice_direction,
                                            labels_type=self.Config.LABELS_TYPE)

        # Crop and pad to input size
        x, y = crop(x, y, crop_size=self.Config.INPUT_DIM)  # does not work with img with batches and channels

        # TODO: padding to a multiple of 16 via pad_nd_image (needed for the Schizo dataset)
        # was tried as an alternative to cropping; if adopted, do the same for inference.

        # Does not make it slower
        x = x.astype(np.float32)
        # If not doing this: during validation: ConnectionResetError: [Errno 104] Connection reset by peer
        y = y.astype(np.float32)

        # possible optimization: sample slices from different patients and pad all to same size (size of biggest)

        data_dict = {"data": x,  # (batch_size, channels, x, y, [z])
                     "seg": y,  # (batch_size, channels, x, y, [z])
                     "slice_dir": slice_direction}
        return data_dict