def main(_):
  """Colorizes the grayscale images in img_dir and writes results to store_dir."""
  config, store_dir, img_dir = FLAGS.config, FLAGS.store_dir, FLAGS.img_dir
  assert store_dir is not None
  assert img_dir is not None

  model_name, gen_data_dir = config.model.name, FLAGS.gen_data_dir
  # The two upsamplers condition on outputs generated by the previous stage,
  # so they additionally require gen_data_dir.
  needs_gen = model_name in ['color_upsampler', 'spatial_upsampler']
  batch_size = get_batch_size(model_name)
  store_dir = get_store_dir(model_name, store_dir)
  num_files = len(tf.io.gfile.listdir(img_dir))

  if needs_gen:
    assert gen_data_dir is not None
    gen_dataset = create_gen_dataset_from_images(gen_data_dir, batch_size)
    gen_dataset_iter = iter(gen_dataset)

  dataset = create_grayscale_dataset_from_images(FLAGS.img_dir, batch_size)
  dataset_iter = iter(dataset)

  model, optimizer, ema = build_model(config)
  checkpoints = train_utils.create_checkpoint(model, optimizer=optimizer,
                                              ema=ema)
  train_utils.restore(model, checkpoints, FLAGS.logdir, ema)
  num_steps_v = optimizer.iterations.numpy()
  logging.info('Producing sample after %d training steps.', num_steps_v)

  num_epochs = int(np.ceil(num_files / batch_size))
  logging.info(num_epochs)

  for _ in range(num_epochs):
    gray, gray_64, child_paths = next(dataset_iter)
    if needs_gen:
      prev_gen = next(gen_dataset_iter)

    # Each model exposes its samples under a different output key.
    if model_name == 'coltran_core':
      out = model.sample(gray_64, mode='sample')
      samples = out['auto_sample']
    elif model_name == 'color_upsampler':
      prev_gen = base_utils.convert_bits(prev_gen, n_bits_in=8, n_bits_out=3)
      out = model.sample(bit_cond=prev_gen, gray_cond=gray_64)
      samples = out['bit_up_argmax']
    else:
      prev_gen = datasets_utils.change_resolution(prev_gen, 256)
      out = model.sample(gray_cond=gray, inputs=prev_gen, mode='argmax')
      samples = out['high_res_argmax']

    child_paths = child_paths.numpy()
    child_paths = [child_path.decode('utf-8') for child_path in child_paths]
    logging.info(child_paths)

    for sample, child_path in zip(samples, child_paths):
      write_path = os.path.join(store_dir, child_path)
      logging.info(write_path)
      sample = sample.numpy().astype(np.uint8)
      logging.info(sample.shape)
      with tf.io.gfile.GFile(write_path, 'wb') as f:
        plt.imsave(f, sample)
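
# A hypothetical sketch of the grayscale input pipeline that main() consumes
# (the real create_grayscale_dataset_from_images helper is not shown in this
# listing; its exact preprocessing, resolutions, and return types are
# assumptions). It illustrates the expected (gray, gray_64, child_path)
# contract: a full-resolution grayscale image, a 64x64 grayscale image, and
# the file name reused when writing the colorized output.
def _grayscale_dataset_sketch(img_dir, batch_size, resolution=256):
  child_paths = tf.io.gfile.listdir(img_dir)

  def _load(child_path):
    data = tf.io.read_file(tf.strings.join([img_dir, '/', child_path]))
    image = tf.image.decode_image(data, channels=3, expand_animations=False)
    image = tf.image.resize(image, (resolution, resolution))
    gray = tf.cast(tf.image.rgb_to_grayscale(image), tf.uint8)
    gray_64 = tf.cast(tf.image.resize(gray, (64, 64)), tf.uint8)
    return gray, gray_64, child_path

  ds = tf.data.Dataset.from_tensor_slices(child_paths)
  return ds.map(_load).batch(batch_size)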
def evaluate(logdir, subset):
  """Executes the evaluation loop."""
  config = FLAGS.config
  strategy, batch_size = train_utils.setup_strategy(
      config, FLAGS.master, FLAGS.devices_per_worker, FLAGS.mode,
      FLAGS.accelerator_type)

  def input_fn(_=None):
    return datasets.get_dataset(
        name=config.dataset,
        config=config,
        batch_size=config.eval_batch_size,
        subset=subset)

  model, optimizer, ema = train_utils.with_strategy(
      lambda: build(config, batch_size, False), strategy)

  metric_keys = ['loss', 'total_loss']
  # metric_keys += model.metric_keys
  metrics = {}
  for metric_key in metric_keys:
    func = functools.partial(tf.keras.metrics.Mean, metric_key)
    curr_metric = train_utils.with_strategy(func, strategy)
    metrics[metric_key] = curr_metric

  checkpoints = train_utils.with_strategy(
      lambda: train_utils.create_checkpoint(model, optimizer, ema),
      strategy)
  dataset = train_utils.dataset_with_strategy(input_fn, strategy)

  def step_fn(batch):
    _, extra = loss_on_batch(batch, model, config, training=False)

    for metric_key in metric_keys:
      curr_metric = metrics[metric_key]
      curr_scalar = extra['scalar'][metric_key]
      curr_metric.update_state(curr_scalar)

  num_examples = config.eval_num_examples
  eval_step = train_utils.step_with_strategy(step_fn, strategy)
  ckpt_path = None
  wait_max = config.get(
      'eval_checkpoint_wait_secs', config.save_checkpoint_secs * 100)
  is_ema = True if ema else False

  eval_summary_dir = os.path.join(
      logdir, 'eval_{}_summaries_pyk_{}'.format(subset, is_ema))
  writer = tf.summary.create_file_writer(eval_summary_dir)

  while True:
    ckpt_path = train_utils.wait_for_checkpoint(logdir, ckpt_path, wait_max)
    logging.info(ckpt_path)
    if ckpt_path is None:
      logging.info('Timed out waiting for checkpoint.')
      break

    train_utils.with_strategy(
        lambda: train_utils.restore(model, checkpoints, logdir, ema),
        strategy)
    data_iterator = iter(dataset)
    num_steps = num_examples // batch_size

    for metric_key, metric in metrics.items():
      metric.reset_states()

    logging.info('Starting evaluation.')
    done = False
    for i in range(0, num_steps, FLAGS.steps_per_summaries):
      start_run = time.time()
      for k in range(min(num_steps - i, FLAGS.steps_per_summaries)):
        try:
          if k % 10 == 0:
            logging.info('Step: %d', (i + k + 1))
          eval_step(data_iterator)
        except (StopIteration, tf.errors.OutOfRangeError):
          done = True
          break
      if done:
        break
      bits_per_dim = metrics['loss'].result()
      logging.info(
          'Bits/Dim: %.3f, Speed: %.3f seconds/step, Step: %d/%d',
          bits_per_dim,
          (time.time() - start_run) / FLAGS.steps_per_summaries,
          i + k + 1, num_steps)

    # logging.info('Final Bits/Dim: %.3f', bits_per_dim)
    with writer.as_default():
      for metric_key, metric in metrics.items():
        curr_scalar = metric.result().numpy()
        tf.summary.scalar(metric_key, curr_scalar, step=optimizer.iterations)
def store_samples(data, config, logdir, gen_dataset=None):
  """Stores the generated samples."""
  downsample_res = config.get('downsample_res', 64)
  num_samples = config.sample.num_samples
  num_outputs = config.sample.num_outputs
  batch_size = config.sample.get('batch_size', 1)
  sample_mode = config.sample.get('mode', 'argmax')
  gen_file = config.sample.get('gen_file', 'gen')

  model, optimizer, ema = build(config, 1, False)
  checkpoints = train_utils.create_checkpoint(model, optimizer, ema)

  sample_dir = create_sample_dir(logdir, config)
  record_path = os.path.join(sample_dir, '%s.tfrecords' % gen_file)
  writer = tf.io.TFRecordWriter(record_path)

  train_utils.restore(model, checkpoints, logdir, ema)
  num_steps_v = optimizer.iterations.numpy()
  logging.info('Producing sample after %d training steps.', num_steps_v)
  logging.info(gen_dataset)

  for batch_ind in range(num_outputs // batch_size):
    next_data = data.next()
    if gen_dataset is not None:
      next_gen_data = gen_dataset.next()

    # Gets grayscale image based on the model.
    curr_gray = get_grayscale_at_sample_time(next_data, downsample_res,
                                             config.model.name)
    curr_output = collections.defaultdict(list)
    for sample_ind in range(num_samples):
      logging.info('Batch no: %d, Sample no: %d', batch_ind, sample_ind)

      if config.model.name == 'color_upsampler':
        if gen_dataset is not None:
          # Provide generated coarse color inputs.
          scaled_rgb = next_gen_data['targets']
        else:
          # Provide coarse color ground truth inputs.
          scaled_rgb = next_data['targets_%d' % downsample_res]
        bit_rgb = base_utils.convert_bits(scaled_rgb, n_bits_in=8,
                                          n_bits_out=3)
        output = model.sample(gray_cond=curr_gray, bit_cond=bit_rgb,
                              mode=sample_mode)
      elif config.model.name == 'spatial_upsampler':
        if gen_dataset is not None:
          # Provide low resolution generated image.
          low_res = next_gen_data['targets']
          low_res = datasets_utils.change_resolution(low_res, 256)
        else:
          # Provide low resolution ground truth image.
          low_res = next_data['targets_%d_up_back' % downsample_res]
        output = model.sample(gray_cond=curr_gray, inputs=low_res,
                              mode=sample_mode)
      else:
        output = model.sample(gray_cond=curr_gray, mode=sample_mode)
      logging.info('Done sampling')

      for out_key, out_item in output.items():
        curr_output[out_key].append(out_item.numpy())

    # concatenate samples across width.
    for out_key, out_val in curr_output.items():
      curr_out_val = np.concatenate(out_val, axis=2)
      curr_output[out_key] = curr_out_val

      if ('sample' in out_key or 'argmax' in out_key):
        save_str = f'Saving {(batch_ind + 1) * batch_size} samples'
        logging.info(save_str)
        for single_ex in curr_out_val:
          serialized = array_to_tf_example(single_ex)
          writer.write(serialized)

  writer.close()
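
# A hypothetical sketch of the array_to_tf_example helper used by
# store_samples() above (the real helper is defined elsewhere in the project;
# its feature names and encoding are assumptions). One common encoding stores
# the raw bytes of the uint8 array together with its shape so each sample can
# be reconstructed when the TFRecord file is read back.
def _array_to_tf_example_sketch(image):
  image = np.asarray(image, dtype=np.uint8)
  features = {
      'image': tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[image.tobytes()])),
      'shape': tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(image.shape))),
  }
  example = tf.train.Example(features=tf.train.Features(feature=features))
  return example.SerializeToString()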