예제 #1
0
파일: loader.py 프로젝트: haelkemary/xy_net
def train_generator(ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=16, nr_procs=8):
    """Build the training dataflow: optional augmentation, batching, prefetching.

    Each augmentation stage is applied only when its augmentor argument is
    not None; otherwise ``ds`` passes through that stage unchanged.

    Args:
        ds: Source dataflow. Per the original comments, index 0 of each
            yielded datapoint is the input and index 1 is the label.
        shape_aug: Augmentors applied to components (0, 1) together.
        input_aug: Augmentors applied to component 0 only.
        label_aug: Augmentors applied to component 1 only.
        batch_size: Forwarded to ``BatchDataByShape``.
        nr_procs: Forwarded to ``PrefetchDataZMQ``
            (presumably the number of worker processes -- confirm).

    Returns:
        The ``PrefetchDataZMQ`` dataflow wrapping the batched pipeline.
    """
    # Joint stage: the same augmentors are given both input and label.
    if shape_aug is not None:
        ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=True)

    # Input-only stage (index 0); note this one is created with copy=False.
    if input_aug is not None:
        ds = AugmentImageComponent(ds, input_aug, index=0, copy=False)

    # Label-only stage (index 1).
    if label_aug is not None:
        ds = AugmentImageComponent(ds, label_aug, index=1, copy=True)

    # Batch keyed on component 0, then hand the result to the prefetcher.
    batched = BatchDataByShape(ds, batch_size, idx=0)
    return PrefetchDataZMQ(batched, nr_procs)
예제 #2
0
def train_generator_class(ds, shape_aug=None, input_aug=None, batch_size=16, nr_procs=8):
    """Build the training dataflow when only the input component is augmented.

    Both augmentation stages target index 0 of each datapoint; a stage is
    skipped when its augmentor argument is None.

    Args:
        ds: Source dataflow.
        shape_aug: Augmentors applied to component 0 (wrapper created with
            ``copy=True``).
        input_aug: Augmentors applied to component 0 afterwards (wrapper
            created with ``copy=False``).
        batch_size: Forwarded to ``BatchDataByShape``.
        nr_procs: Forwarded to ``PrefetchDataZMQ``
            (presumably the number of worker processes -- confirm).

    Returns:
        The ``PrefetchDataZMQ`` dataflow wrapping the batched pipeline.
    """
    # First stage on the input component.
    if shape_aug is not None:
        ds = AugmentImageComponent(ds, shape_aug, index=0, copy=True)

    # Second stage on the same component, stacked on top of the first.
    if input_aug is not None:
        ds = AugmentImageComponent(ds, input_aug, index=0, copy=False)

    # Batch keyed on component 0, then hand the result to the prefetcher.
    batched = BatchDataByShape(ds, batch_size, idx=0)
    return PrefetchDataZMQ(batched, nr_procs)