def train_generator(ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=16, nr_procs=8):
    """Wrap a (input, label) dataflow with optional augmentation and prefetching.

    Each datapoint is expected to carry the input image at index 0 and the
    label image at index 1.

    Args:
        ds: source dataflow.
        shape_aug: augmentors applied jointly to components 0 and 1, so that
            input and label receive the same geometric transform. Skipped
            when None.
        input_aug: augmentors applied to the input (component 0) only.
            Skipped when None.
        label_aug: augmentors applied to the label (component 1) only.
            Skipped when None.
        batch_size: currently unused; batching is not performed here.
        nr_procs: process count passed to ``PrefetchDataZMQ``.

    Returns:
        The ``PrefetchDataZMQ`` wrapper around the (possibly augmented) dataflow.
    """
    # Joint transform first so input and label stay spatially aligned.
    if shape_aug is not None:
        ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=True)

    # Input-only augmentation. copy=False skips the pre-augmentation copy;
    # NOTE(review): assumes input_aug does not mutate arrays owned upstream.
    if input_aug is not None:
        ds = AugmentImageComponent(ds, input_aug, index=0, copy=False)

    # Label-only augmentation, on a copy of component 1.
    if label_aug is not None:
        ds = AugmentImageComponent(ds, label_aug, index=1, copy=True)

    return PrefetchDataZMQ(ds, nr_procs)
def train_generator_class(ds, shape_aug=None, input_aug=None, batch_size=16, nr_procs=8):
    """Wrap a classification dataflow with optional input augmentation and prefetching.

    Only the input image (component 0 of each datapoint) is augmented; any
    other components are passed through untouched.

    Args:
        ds: source dataflow.
        shape_aug: augmentors applied to component 0 on a copy of the data.
            Skipped when None.
        input_aug: augmentors applied to component 0 after ``shape_aug``,
            without copying first. Skipped when None.
        batch_size: currently unused; batching is not performed here.
        nr_procs: process count passed to ``PrefetchDataZMQ``.

    Returns:
        The ``PrefetchDataZMQ`` wrapper around the (possibly augmented) dataflow.
    """
    # Ordered (augmentor, copy-before-augment) stages, all targeting component 0.
    stages = ((shape_aug, True), (input_aug, False))
    for augmentor, make_copy in stages:
        if augmentor is not None:
            ds = AugmentImageComponent(ds, augmentor, index=0, copy=make_copy)
    return PrefetchDataZMQ(ds, nr_procs)