def get_train_dataflow(add_mask=True):
    """Build the training dataflow for BraTS segmentation.

    Loads the training records, shuffles them, turns each record into a
    sampled training example via ``sampler3d``, groups the results with
    ``BatchData`` and prefetches them with ``PrefetchDataZMQ``.

    Args:
        add_mask: forwarded to ``BRATS_SEG.load_many`` when
            ``config.CROSS_VALIDATION`` is false. It has no effect in the
            cross-validation branch, which loads records with
            ``BRATS_SEG.load_from_file`` instead.

    Returns:
        A dataflow whose datapoints are ``[images, weights, labels]``,
        grouped ``config.BATCH_SIZE`` at a time.
    """
    if config.CROSS_VALIDATION:
        imgs = BRATS_SEG.load_from_file(config.BASEDIR, config.TRAIN_DATASET)
    else:
        imgs = BRATS_SEG.load_many(config.BASEDIR,
                                   config.TRAIN_DATASET,
                                   add_gt=False,
                                   add_mask=add_mask)
    # No filtering is applied to training records; materialize them as a
    # list before handing them to DataFromList.
    imgs = list(imgs)

    ds = DataFromList(imgs, shuffle=True)

    def preprocess(data):
        """Turn one record into ``[images, weights, labels]``."""
        if config.NO_CACHE:
            # Crop on the fly from the record's raw image and ground truth.
            im, gt = data['image_data'], data['gt']
            volume_list, label, weight, _, _ = crop_brain_region(im, gt)
        else:
            # Reuse the cropped result already stored on the record.
            volume_list, label, weight, _, _ = data['preprocessed']
        # Both branches sample the same way, so do it once here.
        batch = sampler3d(volume_list, label, weight)
        return [batch['images'], batch['weights'], batch['labels']]

    ds = BatchData(MapData(ds, preprocess), config.BATCH_SIZE)
    # Prefetch with 6 workers (presumably tensorpack's nr_proc; confirm).
    ds = PrefetchDataZMQ(ds, 6)
    return ds
# Example #2
# 0
def get_eval_dataflow():
    """Build the evaluation dataflow for BraTS segmentation.

    Loads the validation records named by ``config.VAL_DATASET``. For each
    record, the cached ``preprocessed`` tuple is replaced with the output of
    ``sampler3d_whole``.

    Returns:
        A dataflow yielding ``[file_name, id, batch]`` datapoints, where
        ``batch`` is whatever ``sampler3d_whole`` returns for that record.
    """
    imgs = BRATS_SEG.load_from_file(config.BASEDIR, config.VAL_DATASET)
    # Every validation record is used as-is (no filtering).
    keys = ['file_name', 'id', 'preprocessed']
    ds = DataFromListOfDict(imgs, keys)

    def f(data):
        """Sample the whole volume from one record's cached preprocessing."""
        volume_list, label, weight, original_shape, bbox = data
        return sampler3d_whole(volume_list, label, weight, original_shape, bbox)

    # Map only the 'preprocessed' component. Its index is derived from
    # ``keys`` so the position cannot drift from the key list.
    ds = MapDataComponent(ds, f, keys.index('preprocessed'))
    ds = PrefetchDataZMQ(ds, 1)
    return ds