def _make_dataflow(batch_size=1, use_prefetch=False):
    """Build the batched two-domain image dataflow backed by two LMDB stores.

    The database paths come from ``dataset.db_a`` / ``dataset.db_b`` joined onto
    ``dir.data``. Each store is wrapped in a ``flow.KVStoreRandomSampleDataFlow``,
    the pair is combined with ``DiscoGANSplitDataFlow``, and the result is
    batched into ``img_a`` / ``img_b`` float32 buffers of shape
    ``(batch_size, H, W, 3)`` where ``(H, W)`` is ``dataset.img_shape``
    (default ``(64, 64)``).

    Args:
        batch_size: number of samples per batch.
        use_prefetch: if true, wrap the batched flow in
            ``flow.MPPrefetchDataFlow`` with two workers.

    Returns:
        The constructed dataflow.

    Raises:
        AssertionError: if either database path does not exist.
    """
    img_shape = get_env('dataset.img_shape', (64, 64))

    data_dir = get_env('dir.data')
    db_a = osp.join(data_dir, get_env('dataset.db_a'))
    db_b = osp.join(data_dir, get_env('dataset.db_b'))

    # Explicit raise (instead of a bare ``assert``) so the check survives
    # ``python -O``; AssertionError is kept for compatibility with callers.
    # Adjacent literals are joined before ``.format``; each piece needs its
    # own trailing space or the words run together.
    if not (osp.exists(db_a) and osp.exists(db_b)):
        raise AssertionError(
            'Unknown database: {} and {}. If you haven\'t downloaded them, '
            'please run scripts in TensorArtist/scripts/dataset-tools/pix2pix and put the generated dataset lmdb '
            'in the corresponding position.'.format(db_a, db_b))

    # The stores are opened lazily through factories rather than here.
    dfa = flow.KVStoreRandomSampleDataFlow(lambda: kvstore.LMDBKVStore(db_a))
    dfb = flow.KVStoreRandomSampleDataFlow(lambda: kvstore.LMDBKVStore(db_b))
    df = DiscoGANSplitDataFlow(dfa, dfb)

    buffer_shape = (batch_size, img_shape[0], img_shape[1], 3)
    df = flow.BatchDataFlow(df,
                            batch_size,
                            sample_dict={
                                'img_a': np.empty(shape=buffer_shape, dtype='float32'),
                                'img_b': np.empty(shape=buffer_shape, dtype='float32'),
                            })
    if use_prefetch:
        df = flow.MPPrefetchDataFlow(df, nr_workers=2)
    return df
Exemplo n.º 2
0
def make_dataflow_train(env):
    """Create the training dataflow that batches samples read from ``env.data_queue``.

    Each batch holds a float32 ``state`` array, an int64 ``action`` array and a
    float32 ``future_reward`` vector, all with leading dimension
    ``trainer.batch_size``. The trailing dimensions come from
    ``get_input_shape()`` and ``get_action_shape()``.
    """
    bs = get_env('trainer.batch_size')
    lead = (bs, )

    source = flow.QueueDataFlow(env.data_queue)
    template = {
        'state': np.empty(lead + get_input_shape(), dtype='float32'),
        'action': np.empty(lead + get_action_shape(), dtype='int64'),
        'future_reward': np.empty(lead, dtype='float32'),
    }
    batched = flow.BatchDataFlow(source, bs, sample_dict=template)
    return batched
Exemplo n.º 3
0
def make_dataflow_train(env):
    """Create the MNIST training dataflow.

    Random samples are drawn from the training split (``_mnist[0]``) and
    grouped into ``(batch_size, 28, 28, 1)`` float32 ``img`` batches, where
    ``batch_size`` is ``trainer.batch_size``.
    """
    ensure_load()
    bs = get_env('trainer.batch_size')

    sampler = flow.DOARandomSampleDataFlow(_mnist[0])
    img_buffer = np.empty(shape=(bs, 28, 28, 1), dtype='float32')
    return flow.BatchDataFlow(sampler, bs, sample_dict={'img': img_buffer})
Exemplo n.º 4
0
def make_dataflow_inference(env):
    """Create the MNIST inference dataflow.

    The validation split (``_mnist[1]``) is wrapped in a dict-of-array flow and
    passed through ``flow.tools.cycle``. It is then batched into
    ``(batch_size, 28, 28, 1)`` float32 ``img`` batches and wrapped in
    ``flow.EpochDataFlow``. Both the batch size and the epoch size come from
    the ``inference.*`` config.
    """
    ensure_load()
    bs = get_env('inference.batch_size')
    epoch_len = get_env('inference.epoch_size')

    # The validation split is deliberately used for inference.
    stream = flow.tools.cycle(flow.DictOfArrayDataFlow(_mnist[1]))
    img_buffer = np.empty(shape=(bs, 28, 28, 1), dtype='float32')
    batched = flow.BatchDataFlow(stream, bs, sample_dict={'img': img_buffer})
    return flow.EpochDataFlow(batched, epoch_len)
Exemplo n.º 5
0
def make_dataflow_train(env):
    """Create the CIFAR training dataflow.

    The dataset is loaded for ``dataset.nr_classes`` classes. Samples are drawn
    at random from the training split (``_cifar[0]``) and batched into a
    float32 ``img`` array of shape ``(batch_size, D, D, 3)``
    (``D = _cifar_img_dim``) plus an int32 ``label`` vector.
    """
    nr_classes = get_env('dataset.nr_classes')
    ensure_load(nr_classes)
    bs = get_env('trainer.batch_size')

    sampler = flow.DOARandomSampleDataFlow(_cifar[0])
    img_shape = (bs, _cifar_img_dim, _cifar_img_dim, 3)
    template = {
        'img': np.empty(shape=img_shape, dtype='float32'),
        'label': np.empty(shape=(bs, ), dtype='int32'),
    }
    return flow.BatchDataFlow(sampler, bs, sample_dict=template)
Exemplo n.º 6
0
def make_dataflow_inference(env):
    """Create the CIFAR inference dataflow.

    The validation split (``_cifar[1]``) is wrapped in a dict-of-array flow and
    passed through ``flow.tools.cycle``. It is batched into a float32 ``img``
    array of shape ``(batch_size, D, D, 3)`` (``D = _cifar_img_dim``) plus an
    int32 ``label`` vector, then wrapped in ``flow.EpochDataFlow``.
    """
    nr_classes = get_env('dataset.nr_classes')
    ensure_load(nr_classes)
    bs = get_env('inference.batch_size')
    epoch_len = get_env('inference.epoch_size')

    # The validation split is deliberately used for inference.
    stream = flow.tools.cycle(flow.DictOfArrayDataFlow(_cifar[1]))
    img_shape = (bs, _cifar_img_dim, _cifar_img_dim, 3)
    template = {
        'img': np.empty(shape=img_shape, dtype='float32'),
        'label': np.empty(shape=(bs, ), dtype='int32'),
    }
    batched = flow.BatchDataFlow(stream, bs, sample_dict=template)
    return flow.EpochDataFlow(batched, epoch_len)