Beispiel #1
0
def sample(logdir, subset):
    """Executes the sampling loop.

    Builds an iterator over the ground-truth dataset for `subset`. When
    `config.sample.gen_data_dir` is set, it also builds an iterator over the
    generated dataset in that directory. Both iterators are passed to
    `store_samples`; the generated one is None when no directory is configured.

    Args:
      logdir: directory forwarded unchanged to `store_samples`.
      subset: dataset split name forwarded to `datasets.get_dataset`.
    """
    logging.info('Beginning sampling loop...')
    config = FLAGS.config
    batch_size = config.sample.get('batch_size', 1)
    # Number of leading batches to skip in every dataset; used to parallelize
    # sampling jobs.
    skip_batches = config.sample.get('skip_batches', 0)
    gen_data_dir = config.sample.get('gen_data_dir', None)
    is_gen = gen_data_dir is not None

    # `.get('name')` yields None when the model has no name; fall back to ''
    # so the substring test below does not raise TypeError on None.
    model_name = config.model.get('name') or ''
    if not is_gen and 'upsampler' in model_name:
        logging.info('Generated low resolution not provided, using ground '
                     'truth input.')

    # Get ground truth dataset for grayscale image.
    tf_dataset = datasets.get_dataset(name=config.dataset,
                                      config=config,
                                      batch_size=batch_size,
                                      subset=subset)
    tf_dataset = tf_dataset.skip(skip_batches)
    data_iter = iter(tf_dataset)

    # Creates dataset from generated TFRecords.
    # This is used as low resolution input to the upsamplers.
    gen_iter = None
    if is_gen:
        gen_tf_dataset = datasets.get_gen_dataset(data_dir=gen_data_dir,
                                                  batch_size=batch_size)
        # Skip the same number of batches so the generated inputs stay aligned
        # with the ground-truth iterator above.
        gen_tf_dataset = gen_tf_dataset.skip(skip_batches)
        gen_iter = iter(gen_tf_dataset)

    store_samples(data_iter, config, logdir, gen_iter)
Beispiel #2
0
    def input_fn(input_context=None):
        """Returns the 'train' dataset, optionally configured for sharded reads.

        Args:
          input_context: optional input context (presumably a
            `tf.distribute.InputContext` — confirm against the caller). When
            given, it is wrapped in a `tfds.ReadConfig` and forwarded to the
            dataset builder. When None, no read config is passed through.

        Returns:
          The value of `datasets.get_dataset` for the 'train' subset, batched
          with `config.batch_size`.
        """
        if input_context is None:
            read_config = None
        else:
            read_config = tfds.ReadConfig(input_context=input_context)
        return datasets.get_dataset(
            name=config.dataset,
            config=config,
            batch_size=config.batch_size,
            subset='train',
            read_config=read_config,
        )
Beispiel #3
0
 def input_fn(_=None):
     """Returns the `subset` dataset batched with `config.eval_batch_size`.

     The single positional argument is accepted only for signature
     compatibility with the caller and is ignored.
     """
     dataset_kwargs = dict(
         name=config.dataset,
         config=config,
         batch_size=config.eval_batch_size,
         subset=subset,
     )
     return datasets.get_dataset(**dataset_kwargs)
Beispiel #4
0
 def input_fn(_=None):
     """Returns the 'train' dataset batched with `config.batch_size`.

     The single positional argument is accepted only for signature
     compatibility with the caller and is ignored.
     """
     return datasets.get_dataset(
         name=config.dataset,
         config=config,
         batch_size=config.batch_size,
         subset='train',
     )