def sample(logdir, subset):
  """Executes the sampling loop."""
  logging.info('Beginning sampling loop...')
  config = FLAGS.config
  batch_size = config.sample.get('batch_size', 1)

  # Used to parallelize sampling jobs.
  skip_batches = config.sample.get('skip_batches', 0)
  gen_data_dir = config.sample.get('gen_data_dir', None)
  is_gen = gen_data_dir is not None

  model_name = config.model.get('name')
  if not is_gen and 'upsampler' in model_name:
    logging.info('Generated low resolution not provided, using ground '
                 'truth input.')

  # Get ground truth dataset for grayscale image.
  tf_dataset = datasets.get_dataset(
      name=config.dataset,
      config=config,
      batch_size=batch_size,
      subset=subset)
  tf_dataset = tf_dataset.skip(skip_batches)
  data_iter = iter(tf_dataset)

  # Creates dataset from generated TFRecords.
  # This is used as low resolution input to the upsamplers.
  gen_iter = None
  if is_gen:
    gen_tf_dataset = datasets.get_gen_dataset(
        data_dir=gen_data_dir, batch_size=batch_size)
    gen_tf_dataset = gen_tf_dataset.skip(skip_batches)
    gen_iter = iter(gen_tf_dataset)

  store_samples(data_iter, config, logdir, gen_iter)
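
# Usage sketch (illustrative; not part of the original script). The sampling
# loop has two modes, selected by `config.sample.gen_data_dir`, and parallel
# jobs can split work via `config.sample.skip_batches`. The paths and subset
# name below are hypothetical.
#
#   # Mode 1: no generated data; upsamplers fall back to ground-truth input.
#   config.sample.gen_data_dir = None
#   sample(logdir='/tmp/samples_run0', subset='valid')
#
#   # Mode 2: upsample previously generated low-resolution TFRecords,
#   # skipping batches already processed by another parallel job.
#   config.sample.gen_data_dir = '/tmp/samples_run0'
#   config.sample.skip_batches = 100
#   sample(logdir='/tmp/samples_run1', subset='valid')
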
def input_fn(input_context=None):
  read_config = None
  if input_context is not None:
    read_config = tfds.ReadConfig(input_context=input_context)

  dataset = datasets.get_dataset(
      name=config.dataset,
      config=config,
      batch_size=config.batch_size,
      subset='train',
      read_config=read_config)
  return dataset
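
# Sketch of how an `input_fn` with this signature is typically consumed
# (an assumption for illustration; the surrounding script sets up its own
# `tf.distribute.Strategy`). The strategy calls `input_fn` once per worker
# with a `tf.distribute.InputContext`, and `tfds.ReadConfig(input_context=...)`
# shards the source files so each replica reads a disjoint slice of the data.
def _distribute_input_fn_example(strategy, input_fn):
  """Hypothetical helper: builds a per-replica dataset from `input_fn`."""
  # `distribute_datasets_from_function` is the tf.distribute entry point for
  # dataset functions that accept an `input_context` argument.
  return strategy.distribute_datasets_from_function(input_fn)
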
def input_fn(_=None):
  return datasets.get_dataset(
      name=config.dataset,
      config=config,
      batch_size=config.eval_batch_size,
      subset=subset)
def input_fn(_=None):
  dataset = datasets.get_dataset(
      name=config.dataset,
      config=config,
      batch_size=config.batch_size,
      subset='train')
  return dataset