def get_basic_generators(self):
        """Build the training and validation loaders for the 3D configuration.

        The dataset is loaded and split first. Both loaders are
        ``DataLoader3D`` instances; the training loader samples with
        ``self.basic_generator_patch_size`` while the validation loader
        samples with ``self.patch_size``.

        Returns:
            Tuple ``(dl_tr, dl_val)``: the training and validation loaders.

        Raises:
            NotImplementedError: if ``self.threeD`` is falsy.
        """
        self.load_dataset()
        self.do_split()

        # Guard clause: only the 3D setup is supported by this variant.
        if not self.threeD:
            raise NotImplementedError("2D has no cascade")

        def build_loader(dataset, generator_patch_size):
            # The positional ``True`` is forwarded unchanged; presumably it is
            # the loader's "has previous stage" flag -- confirm against
            # DataLoader3D's signature.
            return DataLoader3D(
                dataset,
                generator_patch_size,
                self.patch_size,
                self.batch_size,
                True,
                oversample_foreground_percent=self.oversample_foreground_percent,
                pad_mode="constant",
                pad_sides=self.pad_all_sides,
            )

        train_loader = build_loader(self.dataset_tr,
                                    self.basic_generator_patch_size)
        val_loader = build_loader(self.dataset_val, self.patch_size)
        return train_loader, val_loader
Esempio n. 2
0
    def get_basic_generators(self):
        """Create the training and validation data loaders.

        The dataset is loaded and split first. ``DataLoader3D`` is used when
        ``self.threeD`` is truthy, ``DataLoader2D`` otherwise. The training
        loader samples with ``self.basic_generator_patch_size`` and the
        validation loader with ``self.patch_size``; both are opened with
        ``memmap_mode='r'``.

        Returns:
            Tuple ``(dl_tr, dl_val)``: the training and validation loaders.
        """
        self.load_dataset()
        self.do_split()

        # Choose the loader class once. Only the 3D loader receives the extra
        # positional ``False`` right after the batch size.
        if self.threeD:
            loader_cls, extra_positional = DataLoader3D, (False,)
        else:
            loader_cls, extra_positional = DataLoader2D, ()

        def build_loader(dataset, generator_patch_size):
            return loader_cls(
                dataset,
                generator_patch_size,
                self.patch_size,
                self.batch_size,
                *extra_positional,
                oversample_foreground_percent=self.oversample_foreground_percent,
                pad_mode="constant",
                pad_sides=self.pad_all_sides,
                memmap_mode='r'
            )

        train_loader = build_loader(self.dataset_tr,
                                    self.basic_generator_patch_size)
        val_loader = build_loader(self.dataset_val, self.patch_size)
        return train_loader, val_loader
 def get_basic_generators(self):
     """Build the 3D training and validation loaders.

     The dataset is loaded and split first. Both loaders are
     ``DataLoader3D`` instances; the training loader samples with
     ``self.basic_generator_patch_size`` and the validation loader with
     ``self.patch_size``.

     Returns:
         Tuple ``(dl_tr, dl_val)``: the training and validation loaders.

     Raises:
         NotImplementedError: if ``self.threeD`` is falsy.
     """
     self.load_dataset()
     self.do_split()

     # Guard clause: this variant has no 2D implementation.
     if not self.threeD:
         raise NotImplementedError

     def build_loader(dataset, generator_patch_size):
         # The positional ``True`` is forwarded unchanged; presumably it is
         # the loader's "has previous stage" flag -- confirm against
         # DataLoader3D's signature.
         return DataLoader3D(
             dataset,
             generator_patch_size,
             self.patch_size,
             self.batch_size,
             True,
             oversample_foreground_percent=self.oversample_foreground_percent,
         )

     train_loader = build_loader(self.dataset_tr,
                                 self.basic_generator_patch_size)
     val_loader = build_loader(self.dataset_val, self.patch_size)
     return train_loader, val_loader
Esempio n. 4
0
    def get_basic_generators(self):
        """Create the training and validation data loaders.

        The dataset is loaded and split first. Both loaders sample with
        ``self.patch_size`` as generator and final patch size. The 3D branch
        uses ``DataLoader3D`` with a positional ``False`` after the batch
        size; the 2D branch uses ``DataLoader2D`` and forwards
        ``self.plans.get('transpose_forward')`` as ``transpose``.

        Returns:
            Tuple ``(dl_tr, dl_val)``: the training and validation loaders.
        """
        self.load_dataset()
        self.do_split()

        # Define the per-dimensionality factory once, then apply it to both
        # splits so the two loaders cannot drift apart.
        if self.threeD:
            def build_loader(dataset):
                return DataLoader3D(
                    dataset,
                    self.patch_size,
                    self.patch_size,
                    self.batch_size,
                    False,
                    oversample_foreground_percent=self.oversample_foreground_percent,
                    pad_mode="constant",
                    pad_sides=self.pad_all_sides,
                )
        else:
            def build_loader(dataset):
                return DataLoader2D(
                    dataset,
                    self.patch_size,
                    self.patch_size,
                    self.batch_size,
                    transpose=self.plans.get('transpose_forward'),
                    oversample_foreground_percent=self.oversample_foreground_percent,
                    pad_mode="constant",
                    pad_sides=self.pad_all_sides,
                )

        train_loader = build_loader(self.dataset_tr)
        val_loader = build_loader(self.dataset_val)
        return train_loader, val_loader
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val


if __name__ == "__main__":
    # Manual smoke test: load one preprocessed task, build a 3D loader from
    # its stage-0 plan and wrap it with the default augmentation pipeline.
    from tuframework.training.dataloading.dataset_loading import DataLoader3D, load_dataset
    from tuframework.paths import preprocessing_output_dir
    import os
    import pickle

    t = "Task002_Heart"
    # Build paths with os.path.join instead of manual "/" concatenation so
    # separators are handled by the standard library.
    p = os.path.join(preprocessing_output_dir, t)
    dataset = load_dataset(p, 0)
    # NOTE(review): pickle.load can execute arbitrary code; plans.pkl is
    # presumably produced by this project's own preprocessing -- never point
    # this at an untrusted file.
    with open(os.path.join(p, "plans.pkl"), 'rb') as f:
        plans = pickle.load(f)

    # Look the stage-0 patch size up once; each consumer below still gets its
    # own freshly built array, exactly as before.
    stage_patch_size = plans['stage_properties'][0].patch_size

    basic_patch_size = get_patch_size(np.array(stage_patch_size),
                                      default_3D_augmentation_params['rotation_x'],
                                      default_3D_augmentation_params['rotation_y'],
                                      default_3D_augmentation_params['rotation_z'],
                                      default_3D_augmentation_params['scale_range'])

    # The same loader instance is passed as both the training and the
    # validation source.
    dl = DataLoader3D(dataset, basic_patch_size, np.array(stage_patch_size).astype(int), 1)
    tr, val = get_default_augmentation(dl, dl, np.array(stage_patch_size).astype(int))