def generate_train_batch(self):
        """Assemble one training batch of slices from one randomly picked subject.

        Steps, in order: pick a subject from ``self._data[0]``, load its data
        and segmentation, optionally convert peaks to tensors, draw
        ``self.batch_size`` slice indices along the configured slice
        direction, sample those slices, then either crop to
        ``Config.INPUT_DIM`` or pad every axis to a multiple of 16, and cast
        both arrays to float32.

        Returns:
            dict with the keys ``"data"`` and ``"seg"`` (float32 arrays; laid
            out as (batch_size, channels, x, y, [z]) according to the original
            author's note) and ``"slice_dir"`` (the integer slice direction).
        """
        cfg = self.Config

        candidates = self._data[0]
        # int() truncates the float draw into a list index.
        chosen = candidates[int(random.uniform(0, len(candidates)))]
        data, seg = load_training_data(cfg, chosen)

        # Tensor models are recognised by 18 gradients per slice; they need
        # tensors instead of peaks as input.
        if cfg.NR_OF_GRADIENTS == 18 * cfg.NR_SLICES:
            data = peak_utils.peaks_to_tensors(data)

        slice_direction = data_utils.slice_dir_to_int(cfg.TRAINING_SLICE_DIRECTION)
        nr_slices = data.shape[slice_direction]

        # Unique indices are only drawn while there are more slices than
        # requested samples.
        # NOTE(review): equality also switches to replacement although unique
        # sampling would still be possible -- confirm this is intended.
        with_replacement = nr_slices <= self.batch_size
        if with_replacement:
            print("INFO: Batch size bigger than nr of slices. Therefore sampling with replacement.")
        slice_idxs = np.random.choice(nr_slices, self.batch_size, replace=with_replacement)

        shared_kwargs = dict(slice_direction=slice_direction, labels_type=cfg.LABELS_TYPE)
        if cfg.NR_SLICES > 1:
            # Multi-slice input: every sample is a window of NR_SLICES slices.
            x, y = data_utils.sample_Xslices(data, seg, slice_idxs,
                                             slice_window=cfg.NR_SLICES, **shared_kwargs)
        else:
            x, y = data_utils.sample_slices(data, seg, slice_idxs, **shared_kwargs)

        # Original author's notes: converting e.g. 1.25mm (HCP) images to 2mm
        # via zooming was very slow; a spatial transform might be faster.

        if cfg.PAD_TO_SQUARE:
            # Crop and pad to the network input size.
            # (original note: does not work with img with batches and channels)
            x, y = crop(x, y, crop_size=cfg.INPUT_DIM)
        else:
            # Pad each axis up to a multiple of 16. Samples (and x versus y)
            # may end up with different dimensions. Needed for the Schizo
            # dataset according to the original author.
            x, y = (
                pad_nd_image(arr, shape_must_be_divisible_by=(16, 16), mode='constant',
                             kwargs={'constant_values': 0})
                for arr in (x, y)
            )

        # Possible optimization (original note): sample slices from different
        # patients and pad all of them to the size of the biggest one.

        # The float32 cast does not slow things down (original author's note).
        return {"data": x.astype(np.float32),  # (batch_size, channels, x, y, [z])
                "seg": y.astype(np.float32),  # (batch_size, channels, x, y, [z])
                "slice_dir": slice_direction}
 def process_bundle(idx):
     """Average the three orientation variants of bundle ``idx`` in tensor space.

     For each of the 3 orientations (last axis of ``img``) the bundle's three
     channels ``idx * 3 .. idx * 3 + 2`` on the fourth axis are converted from
     a peak to a tensor. The tensors are averaged element-wise and the mean is
     converted back to a peak, which is returned.

     NOTE(review): ``img`` is not a parameter; it is read from the surrounding
     scope, which is not visible in this excerpt.
     """
     print("idx: {}".format(idx))

     first_channel = idx * 3
     tensors_per_orientation = [
         peak_utils.peaks_to_tensors(img[:, :, :, first_channel:first_channel + 3, orientation])
         for orientation in range(3)  # 3 orientations
     ]

     # Averaging happens on tensors, not on the raw peak values.
     mean_tensor = np.array(tensors_per_orientation).mean(axis=0)
     return peak_utils.tensors_to_peaks(mean_tensor)
Exemple #3
0
    def get_batch_generator(self, batch_size=1):
        """Create the ordered batch generator for prediction and augment it.

        The volume comes from ``self.data`` when that is set (paired with an
        all-zero placeholder segmentation); otherwise it is loaded for
        ``self.subject``, either from ``.npy`` files (``Config.TYPE ==
        "combined"``) or through ``load_training_data``.

        Args:
            batch_size: forwarded to the batch generator constructor.

        Returns:
            The result of ``self._augment_data`` applied to the 2D or 3D
            ordered batch generator (chosen by ``Config.DIM``).

        Raises:
            ValueError: if both ``self.data`` and ``self.subject`` are None.
        """
        cfg = self.Config

        if self.data is not None:
            exp_utils.print_verbose(cfg.VERBOSE, "Loading data from PREDICT_IMG input file")
            data = np.nan_to_num(self.data)
            # Placeholder mask for the case where we only predict and have no
            # ground truth. All three spatial axes use INPUT_DIM[0].
            edge = cfg.INPUT_DIM[0]
            seg = np.zeros((edge, edge, edge, cfg.NR_OF_CLASSES)).astype(cfg.LABELS_TYPE)
        elif self.subject is None:
            raise ValueError("Neither 'data' nor 'subject' set.")
        elif cfg.TYPE == "combined":
            # Fusion: features and labels are stored as npy files per subject.
            features_path = join(C.DATA_PATH, cfg.DATASET_FOLDER, self.subject,
                                 cfg.FEATURES_FILENAME + ".npy")
            data = np.load(features_path, mmap_mode="r")
            labels_path = join(C.DATA_PATH, cfg.DATASET_FOLDER, self.subject,
                               cfg.LABELS_FILENAME + ".npy")
            seg = np.load(labels_path, mmap_mode="r")

            data = np.nan_to_num(data)
            seg = np.nan_to_num(seg)
            # Merge the last two axes of the 5D feature array into one.
            data = data.reshape(data.shape[:3] + (data.shape[3] * data.shape[4],))
        else:
            # Imported here rather than at module level, as in the original.
            from tractseg.data.data_loader_training import load_training_data
            data, seg = load_training_data(cfg, self.subject)

            # Convert peaks to tensors if tensor model
            if cfg.NR_OF_GRADIENTS == 18 * cfg.NR_SLICES:
                data = peak_utils.peaks_to_tensors(data)

            # The returned transformation is not needed here.
            target = cfg.INPUT_DIM[0]
            data, _ = data_utils.pad_and_scale_img_to_square_img(
                data, target_size=target, nr_cpus=1)
            seg, _ = data_utils.pad_and_scale_img_to_square_img(
                seg, target_size=target, nr_cpus=1)

        generator_cls = (BatchGenerator2D_data_ordered_standalone if cfg.DIM == "2D"
                         else BatchGenerator3D_data_ordered_standalone)
        batch_gen = generator_cls((data, seg), batch_size=batch_size)
        batch_gen.Config = cfg

        # NOTE(review): no local named `type` exists in this method, so unless
        # the module rebinds the name this passes the builtin `type` --
        # confirm what `_augment_data` expects here.
        return self._augment_data(batch_gen, type=type)
    def generate_train_batch(self):
        """Build one training batch of slices from one randomly chosen subject.

        Steps, in order: pick a subject from ``self._data[0]``, load its data
        and segmentation, optionally convert peaks to tensors, draw
        ``self.batch_size`` slice indices along the configured slice
        direction, sample those slices, crop/pad them to ``Config.INPUT_DIM``
        and cast both arrays to float32.

        Slice indices are unique whenever the volume has at least
        ``self.batch_size`` slices. If it has fewer, indices are drawn with
        replacement instead of failing (``np.random.choice`` raises
        ``ValueError`` for ``replace=False`` in that situation).

        Returns:
            dict with the keys ``"data"`` and ``"seg"`` (float32 arrays; laid
            out as (batch_size, channels, x, y, [z]) according to the original
            author's note) and ``"slice_dir"`` (the integer slice direction).
        """
        subjects = self._data[0]
        # len(subjects)-1 is not needed because int() always rounds to floor.
        subject_idx = int(random.uniform(0, len(subjects)))

        data, seg = load_training_data(self.Config, subjects[subject_idx])

        # Convert peaks to tensors if tensor model
        if self.Config.NR_OF_GRADIENTS == 18*self.Config.NR_SLICES:
            data = peak_utils.peaks_to_tensors(data)

        slice_direction = data_utils.slice_dir_to_int(self.Config.TRAINING_SLICE_DIRECTION)
        nr_slices = data.shape[slice_direction]

        # Sampling without replacement is impossible when more indices are
        # requested than slices exist; only then fall back to replacement so
        # that all previously working inputs keep their behavior.
        with_replacement = nr_slices < self.batch_size
        if with_replacement:
            print("INFO: Batch size bigger than nr of slices. Therefore sampling with replacement.")
        slice_idxs = np.random.choice(nr_slices, self.batch_size, with_replacement, None)

        if self.Config.NR_SLICES > 1:
            # Multi-slice input: every sample is a window of NR_SLICES slices.
            x, y = data_utils.sample_Xslices(data, seg, slice_idxs, slice_direction=slice_direction,
                                             labels_type=self.Config.LABELS_TYPE, slice_window=self.Config.NR_SLICES)
        else:
            x, y = data_utils.sample_slices(data, seg, slice_idxs, slice_direction=slice_direction,
                                            labels_type=self.Config.LABELS_TYPE)

        # Crop and pad to input size
        # (original note: does not work with img with batches and channels)
        x, y = crop(x, y, crop_size=self.Config.INPUT_DIM)

        # Alternative noted by the original author (needed for the Schizo
        # dataset): pad each axis to a multiple of 16 with pad_nd_image
        # instead of cropping. todo: make the same way for inference!

        # Does not make it slower
        x = x.astype(np.float32)
        y = y.astype(np.float32)  # if not doing this: during validation: ConnectionResetError: [Errno 104] Connection
                                  # reset by peer

        # possible optimization: sample slices from different patients and pad all to same size (size of biggest)

        data_dict = {"data": x,     # (batch_size, channels, x, y, [z])
                     "seg": y,      # (batch_size, channels, x, y, [z])
                     "slice_dir": slice_direction}
        return data_dict