Example #1
0
 def get_basic_generators(self):
     """Load and split the dataset, then build the (train, validation) 3D loaders.

     ``self.load_dataset()`` and ``self.do_split()`` always run first. Only the
     3D case is implemented. The training loader receives
     ``self.basic_generator_patch_size`` as its first size argument, and the
     validation loader receives ``self.patch_size`` there. Both receive
     ``self.patch_size`` as the second size argument, ``self.batch_size``, the
     positional flag ``True`` and ``self.oversample_foreground_percent``.

     Returns:
         tuple: ``(dl_tr, dl_val)`` built with ``DataLoader3D``.

     Raises:
         NotImplementedError: if ``self.threeD`` is falsy.
     """
     self.load_dataset()
     self.do_split()

     # Guard clause: there is no 2D implementation for this trainer.
     if not self.threeD:
         raise NotImplementedError

     fg_percent = self.oversample_foreground_percent
     # NOTE(review): the positional ``True`` is presumably DataLoader3D's
     # ``has_prev_stage`` flag -- confirm against the DataLoader3D signature.
     train_loader = DataLoader3D(self.dataset_tr,
                                 self.basic_generator_patch_size,
                                 self.patch_size,
                                 self.batch_size,
                                 True,
                                 oversample_foreground_percent=fg_percent)
     val_loader = DataLoader3D(self.dataset_val,
                               self.patch_size,
                               self.patch_size,
                               self.batch_size,
                               True,
                               oversample_foreground_percent=fg_percent)
     return train_loader, val_loader
Example #2
0
    def get_basic_generators(self):
        """Load and split the dataset, then build the training and validation loaders.

        Both loaders receive ``self.patch_size`` as their first AND second size
        argument. The training and validation loaders are therefore configured
        identically except for the dataset split they wrap. Every loader pads
        with ``pad_mode="constant"`` and ``pad_sides=self.pad_all_sides`` and
        uses ``self.oversample_foreground_percent``.

        ``DataLoader3D`` (with the positional flag ``False``) is used when
        ``self.threeD`` is truthy. Otherwise ``DataLoader2D`` is used and is
        additionally given ``self.plans.get('transpose_forward')``.

        Returns:
            tuple: ``(dl_tr, dl_val)``.
        """
        self.load_dataset()
        self.do_split()

        # NOTE(review): the training loader deliberately(?) uses patch_size rather
        # than a larger sampling window -- confirm this trainer skips spatial
        # augmentation.
        size = self.patch_size
        common_kwargs = dict(oversample_foreground_percent=self.oversample_foreground_percent,
                             pad_mode="constant",
                             pad_sides=self.pad_all_sides)

        if not self.threeD:
            transpose = self.plans.get('transpose_forward')
            dl_tr = DataLoader2D(self.dataset_tr, size, size, self.batch_size,
                                 transpose=transpose, **common_kwargs)
            dl_val = DataLoader2D(self.dataset_val, size, size, self.batch_size,
                                  transpose=transpose, **common_kwargs)
            return dl_tr, dl_val

        dl_tr = DataLoader3D(self.dataset_tr, size, size, self.batch_size, False,
                             **common_kwargs)
        dl_val = DataLoader3D(self.dataset_val, size, size, self.batch_size, False,
                              **common_kwargs)
        return dl_tr, dl_val
Example #3
0
    def get_basic_generators(self):
        """Load and split the dataset, then build the training and validation loaders.

        The training loader receives ``self.basic_generator_patch_size`` as its
        first size argument, and the validation loader receives
        ``self.patch_size`` there. Both receive ``self.patch_size`` as the second
        size argument. Every loader pads with ``pad_mode="constant"`` and
        ``pad_sides=self.pad_all_sides`` and uses
        ``self.oversample_foreground_percent``.

        ``DataLoader3D`` (with the positional flag ``False``) is used when
        ``self.threeD`` is truthy. Otherwise ``DataLoader2D`` is used and is
        additionally given ``self.plans.get('transpose_forward')``.

        Returns:
            tuple: ``(dl_tr, dl_val)``.
        """
        self.load_dataset()
        self.do_split()

        if self.threeD:
            dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                                 False, oversample_foreground_percent=self.oversample_foreground_percent,
                                 pad_mode="constant", pad_sides=self.pad_all_sides)
            dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  pad_mode="constant", pad_sides=self.pad_all_sides)
        else:
            dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                                 transpose=self.plans.get('transpose_forward'),
                                 oversample_foreground_percent=self.oversample_foreground_percent,
                                 pad_mode="constant", pad_sides=self.pad_all_sides)
            dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size,
                                  transpose=self.plans.get('transpose_forward'),
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  pad_mode="constant", pad_sides=self.pad_all_sides)
        return dl_tr, dl_val
Example #4
0
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val


if __name__ == "__main__":
    # Manual smoke run: load one preprocessed task, build a 3D loader from its
    # stage-0 plans and wrap it with get_default_augmentation.
    import os
    import pickle

    from nnunet.paths import preprocessing_output_dir
    from nnunet.training.dataloading.dataset_loading import DataLoader3D, load_dataset

    task_name = "Task002_Heart"
    task_dir = os.path.join(preprocessing_output_dir, task_name)
    dataset = load_dataset(task_dir, 0)

    plans_path = os.path.join(task_dir, "plans.pkl")
    # NOTE(review): unpickling is only safe because plans.pkl is produced locally
    # by the preprocessing step; never point this at untrusted files.
    with open(plans_path, 'rb') as plans_file:
        plans = pickle.load(plans_file)

    # Read the stage-0 patch size once; each consumer still gets its own array.
    stage0_patch_size = plans['stage_properties'][0].patch_size
    aug_params = default_3D_augmentation_params
    basic_patch_size = get_patch_size(np.array(stage0_patch_size),
                                      aug_params['rotation_x'],
                                      aug_params['rotation_y'],
                                      aug_params['rotation_z'],
                                      aug_params['scale_range'])

    dl = DataLoader3D(dataset, basic_patch_size,
                      np.array(stage0_patch_size).astype(int), 1)
    tr, val = get_default_augmentation(dl, dl,
                                       np.array(stage0_patch_size).astype(int))