def get_train_dataflow(add_mask=True):
    """Build the training dataflow.

    Image records are loaded for ``config.TRAIN_DATASET`` under
    ``config.BASEDIR``: with ``BRATS_SEG.load_from_file`` when
    ``config.CROSS_VALIDATION`` is set, otherwise with
    ``BRATS_SEG.load_many``. All records are kept (no filtering), wrapped in
    a shuffling ``DataFromList``, mapped through ``sampler3d`` and grouped
    into batches of ``config.BATCH_SIZE``.

    Args:
        add_mask: Forwarded to ``BRATS_SEG.load_many``. It is not used on the
            cross-validation path, which calls ``load_from_file`` instead.

    Returns:
        A ``PrefetchDataZMQ`` dataflow wrapping the batched
        ``[images, weights, labels]`` datapoints.
    """
    if config.CROSS_VALIDATION:
        imgs = BRATS_SEG.load_from_file(config.BASEDIR, config.TRAIN_DATASET)
    else:
        imgs = BRATS_SEG.load_many(
            config.BASEDIR, config.TRAIN_DATASET,
            add_gt=False, add_mask=add_mask)
    # No filtering is applied for training; just materialize the records.
    imgs = list(imgs)

    ds = DataFromList(imgs, shuffle=True)

    def preprocess(data):
        """Turn one image record into ``[images, weights, labels]``."""
        if config.NO_CACHE:
            # Nothing was precomputed for this record: crop it on the fly.
            gt, im = data['gt'], data['image_data']
            cropped = crop_brain_region(im, gt)
        else:
            cropped = data['preprocessed']
        # Both sources provide a 5-tuple; the last two fields are not needed
        # for training-time sampling.
        volume_list, label, weight, _, _ = cropped
        batch = sampler3d(volume_list, label, weight)
        return [batch['images'], batch['weights'], batch['labels']]

    ds = BatchData(MapData(ds, preprocess), config.BATCH_SIZE)
    # Second argument (6) is the count handed to PrefetchDataZMQ.
    ds = PrefetchDataZMQ(ds, 6)
    return ds
def get_eval_dataflow():
    """Build the evaluation dataflow.

    Records are loaded with ``BRATS_SEG.load_from_file`` for
    ``config.VAL_DATASET`` under ``config.BASEDIR`` and exposed through
    ``DataFromListOfDict`` using the keys ``file_name``, ``id`` and
    ``preprocessed``. The component at index 2 is then mapped through
    ``sampler3d_whole``.

    Returns:
        A ``PrefetchDataZMQ`` dataflow wrapping the mapped datapoints.
    """
    records = BRATS_SEG.load_from_file(config.BASEDIR, config.VAL_DATASET)
    component_keys = ['file_name', 'id', 'preprocessed']

    def sample_whole_volume(preprocessed):
        """Unpack the 5-field preprocessed entry and sample it as a whole."""
        volume_list, label, weight, original_shape, bbox = preprocessed
        return sampler3d_whole(volume_list, label, weight, original_shape, bbox)

    source = DataFromListOfDict(records, component_keys)
    # Index 2 matches the position of 'preprocessed' in component_keys.
    mapped = MapDataComponent(source, sample_whole_volume, 2)
    return PrefetchDataZMQ(mapped, 1)