def _make_dataflow(batch_size=1, use_prefetch=False):
    """Build the batched two-domain image dataflow backed by two LMDB stores.

    The database paths are ``<dir.data>/<dataset.db_a>`` and
    ``<dir.data>/<dataset.db_b>``. Each store feeds a
    ``flow.KVStoreRandomSampleDataFlow``. The two flows are combined by
    ``DiscoGANSplitDataFlow`` and then batched. Each batch uses float32
    templates of shape ``(batch_size, H, W, 3)`` for the keys ``'img_a'``
    and ``'img_b'``, where ``(H, W)`` comes from ``dataset.img_shape``
    (default ``(64, 64)``).

    Args:
        batch_size: Number of samples per batch.
        use_prefetch: If true, wrap the batched flow in
            ``flow.MPPrefetchDataFlow`` with 2 workers.

    Returns:
        The assembled dataflow.

    Raises:
        AssertionError: If either database path does not exist.
    """
    img_shape = get_env('dataset.img_shape', (64, 64))
    data_dir = get_env('dir.data')
    db_a = osp.join(data_dir, get_env('dataset.db_a'))
    db_b = osp.join(data_dir, get_env('dataset.db_b'))

    # An explicit raise (instead of a bare `assert`) keeps the check under
    # `python -O`, while still raising the same AssertionError type callers saw.
    if not (osp.exists(db_a) and osp.exists(db_b)):
        raise AssertionError(
            'Unknown database: {} and {}. If you haven\'t downloaded them, '
            'please run scripts in TensorArtist/scripts/dataset-tools/pix2pix and put the generated dataset lmdb '
            'in the corresponding position.'.format(db_a, db_b))

    # The stores are passed as factories rather than opened here, presumably
    # so each consumer/worker opens its own LMDB handle -- TODO confirm.
    dfa = flow.KVStoreRandomSampleDataFlow(lambda: kvstore.LMDBKVStore(db_a))
    dfb = flow.KVStoreRandomSampleDataFlow(lambda: kvstore.LMDBKVStore(db_b))
    df = DiscoGANSplitDataFlow(dfa, dfb)

    img_h, img_w = img_shape[0], img_shape[1]
    df = flow.BatchDataFlow(df, batch_size, sample_dict={
        'img_a': np.empty(shape=(batch_size, img_h, img_w, 3), dtype='float32'),
        'img_b': np.empty(shape=(batch_size, img_h, img_w, 3), dtype='float32'),
    })
    if use_prefetch:
        df = flow.MPPrefetchDataFlow(df, nr_workers=2)
    return df
def make_dataflow_train(env):
    """Batch the samples that ``env.data_queue`` yields into training batches.

    The batch size is read from the ``trainer.batch_size`` config entry.
    ``flow.BatchDataFlow`` receives one ``np.empty`` template per key
    (presumably used as pre-allocated batch buffers -- confirm in ``flow``):

    * ``state``         -- float32, shape ``(batch_size,) + get_input_shape()``
    * ``action``        -- int64,   shape ``(batch_size,) + get_action_shape()``
    * ``future_reward`` -- float32, shape ``(batch_size,)``
    """
    bs = get_env('trainer.batch_size')
    source = flow.QueueDataFlow(env.data_queue)

    lead = (bs, )
    state_tpl = np.empty(lead + get_input_shape(), dtype='float32')
    action_tpl = np.empty(lead + get_action_shape(), dtype='int64')
    reward_tpl = np.empty(lead, dtype='float32')

    return flow.BatchDataFlow(
        source,
        bs,
        sample_dict=dict(state=state_tpl, action=action_tpl, future_reward=reward_tpl),
    )
def make_dataflow_train(env):
    """Build the MNIST training dataflow: random samples from ``_mnist[0]``, batched.

    ``_mnist[0]`` is presumably the training split (``_mnist[1]`` is used as
    the validation split elsewhere). The batch size is read from
    ``trainer.batch_size``. Each batch uses a float32 ``'img'`` template of
    shape ``(batch_size, 28, 28, 1)``. ``env`` is accepted for interface
    compatibility but is not used.
    """
    ensure_load()
    bs = get_env('trainer.batch_size')
    sampler = flow.DOARandomSampleDataFlow(_mnist[0])
    img_tpl = np.empty(shape=(bs, 28, 28, 1), dtype='float32')
    return flow.BatchDataFlow(sampler, bs, sample_dict=dict(img=img_tpl))
def make_dataflow_inference(env):
    """Build the MNIST inference dataflow over ``_mnist[1]``.

    The split is wrapped with ``flow.DictOfArrayDataFlow``, then cycled with
    ``flow.tools.cycle``. It is batched with ``inference.batch_size`` using a
    float32 ``'img'`` template of shape ``(batch_size, 28, 28, 1)``. Finally
    it is wrapped in ``flow.EpochDataFlow`` with ``inference.epoch_size``.
    ``env`` is accepted for interface compatibility but is not used.
    """
    ensure_load()
    bs = get_env('inference.batch_size')
    ep_size = get_env('inference.epoch_size')

    # Index 1 is the validation split; it stands in for a test set here.
    stream = flow.tools.cycle(flow.DictOfArrayDataFlow(_mnist[1]))
    img_tpl = np.empty(shape=(bs, 28, 28, 1), dtype='float32')
    stream = flow.BatchDataFlow(stream, bs, sample_dict=dict(img=img_tpl))
    return flow.EpochDataFlow(stream, ep_size)
def make_dataflow_train(env):
    """Build the CIFAR training dataflow: random samples from ``_cifar[0]``, batched.

    ``ensure_load`` is called with the ``dataset.nr_classes`` config value
    before the split is read. The batch size comes from
    ``trainer.batch_size``. Batch templates are:

    * ``img``   -- float32, shape ``(batch_size, D, D, 3)`` with ``D = _cifar_img_dim``
    * ``label`` -- int32,   shape ``(batch_size,)``

    ``env`` is accepted for interface compatibility but is not used.
    """
    n_cls = get_env('dataset.nr_classes')
    ensure_load(n_cls)
    bs = get_env('trainer.batch_size')
    sampler = flow.DOARandomSampleDataFlow(_cifar[0])

    dim = _cifar_img_dim
    templates = dict(
        img=np.empty(shape=(bs, dim, dim, 3), dtype='float32'),
        label=np.empty(shape=(bs, ), dtype='int32'),
    )
    return flow.BatchDataFlow(sampler, bs, sample_dict=templates)
def make_dataflow_inference(env):
    """Build the CIFAR inference dataflow over ``_cifar[1]``.

    ``ensure_load`` is called with ``dataset.nr_classes`` first. The split is
    wrapped with ``flow.DictOfArrayDataFlow`` and cycled with
    ``flow.tools.cycle``. It is batched with ``inference.batch_size`` using
    ``img`` (float32, ``(batch_size, D, D, 3)``, ``D = _cifar_img_dim``) and
    ``label`` (int32, ``(batch_size,)``) templates. Finally it is wrapped in
    ``flow.EpochDataFlow`` with ``inference.epoch_size``. ``env`` is accepted
    for interface compatibility but is not used.
    """
    n_cls = get_env('dataset.nr_classes')
    ensure_load(n_cls)
    bs = get_env('inference.batch_size')
    ep_size = get_env('inference.epoch_size')

    # Index 1 is the validation split; it stands in for a test set here.
    stream = flow.tools.cycle(flow.DictOfArrayDataFlow(_cifar[1]))

    dim = _cifar_img_dim
    templates = dict(
        img=np.empty(shape=(bs, dim, dim, 3), dtype='float32'),
        label=np.empty(shape=(bs, ), dtype='int32'),
    )
    stream = flow.BatchDataFlow(stream, bs, sample_dict=templates)
    return flow.EpochDataFlow(stream, ep_size)