Example #1
0
def main(_):
  """Colorizes the images in --img_dir and writes the results to disk.

  Restores the model named by `FLAGS.config.model.name` from the latest
  checkpoint in --logdir, then iterates over the grayscale dataset built from
  --img_dir one batch at a time. The two upsampler models additionally consume
  one batch per step from the dataset built from --gen_data_dir. Each sample
  is cast to uint8 and written with `plt.imsave` to
  `os.path.join(store_dir, child_path)`, where `store_dir` comes from
  `get_store_dir` and `child_path` is the path yielded by the grayscale
  dataset for that image.

  Raises:
    ValueError: if --store_dir or --img_dir is unset, if the model name is not
      one of 'coltran_core', 'color_upsampler' or 'spatial_upsampler', or if
      --gen_data_dir is unset for an upsampler model.
  """
  config, store_dir, img_dir = FLAGS.config, FLAGS.store_dir, FLAGS.img_dir
  model_name, gen_data_dir = config.model.name, FLAGS.gen_data_dir

  # Explicit checks rather than `assert`, which is removed under `python -O`.
  if store_dir is None:
    raise ValueError('--store_dir must be set.')
  if img_dir is None:
    raise ValueError('--img_dir must be set.')
  supported = ('coltran_core', 'color_upsampler', 'spatial_upsampler')
  if model_name not in supported:
    # Fail before loading data/model; otherwise an unknown name reaches the
    # sampling branch with `prev_gen` unbound.
    raise ValueError('Unsupported model name %r; expected one of %s.' %
                     (model_name, ', '.join(supported)))
  needs_gen = model_name in ('color_upsampler', 'spatial_upsampler')
  if needs_gen and gen_data_dir is None:
    raise ValueError('--gen_data_dir must be set for %s.' % model_name)

  batch_size = get_batch_size(model_name)
  store_dir = get_store_dir(model_name, store_dir)
  num_files = len(tf.io.gfile.listdir(img_dir))

  gen_dataset_iter = None
  if needs_gen:
    gen_dataset = create_gen_dataset_from_images(gen_data_dir, batch_size)
    gen_dataset_iter = iter(gen_dataset)

  dataset = create_grayscale_dataset_from_images(img_dir, batch_size)
  dataset_iter = iter(dataset)

  model, optimizer, ema = build_model(config)
  checkpoints = train_utils.create_checkpoint(model, optimizer=optimizer,
                                              ema=ema)
  train_utils.restore(model, checkpoints, FLAGS.logdir, ema)
  num_steps_v = optimizer.iterations.numpy()
  logging.info('Producing sample after %d training steps.', num_steps_v)

  # Rounded up so a trailing partial batch of files is still processed.
  num_batches = int(np.ceil(num_files / batch_size))
  logging.info('Number of batches: %d', num_batches)

  for _ in range(num_batches):
    gray, gray_64, child_paths = next(dataset_iter)
    # Only the upsampler models condition on previously generated images.
    prev_gen = next(gen_dataset_iter) if needs_gen else None

    if model_name == 'coltran_core':
      out = model.sample(gray_64, mode='sample')
      samples = out['auto_sample']
    elif model_name == 'color_upsampler':
      # Convert the generated colors from 8 bits to 3 bits for conditioning.
      prev_gen = base_utils.convert_bits(prev_gen, n_bits_in=8, n_bits_out=3)
      out = model.sample(bit_cond=prev_gen, gray_cond=gray_64)
      samples = out['bit_up_argmax']
    else:  # 'spatial_upsampler' -- model_name was validated above.
      prev_gen = datasets_utils.change_resolution(prev_gen, 256)
      out = model.sample(gray_cond=gray, inputs=prev_gen, mode='argmax')
      samples = out['high_res_argmax']

    child_paths = [path.decode('utf-8') for path in child_paths.numpy()]
    logging.info(child_paths)

    for sample, child_path in zip(samples, child_paths):
      write_path = os.path.join(store_dir, child_path)
      logging.info(write_path)
      sample = sample.numpy().astype(np.uint8)
      logging.info(sample.shape)
      with tf.io.gfile.GFile(write_path, 'wb') as f:
        plt.imsave(f, sample)
Example #2
0
def evaluate(logdir, subset):
    """Executes the evaluation loop.

    Repeatedly waits (via `train_utils.wait_for_checkpoint`) for a new
    checkpoint in `logdir` and restores it. It then runs up to
    `config.eval_num_examples // batch_size` evaluation steps over `subset`,
    stopping early if the iterator raises StopIteration/OutOfRangeError. The
    mean of each tracked metric is written as a scalar summary, stepped by
    `optimizer.iterations`, to
    `<logdir>/eval_<subset>_summaries_pyk_<is_ema>`. Returns when no new
    checkpoint appears within the configured wait time.

    Args:
        logdir: directory holding training checkpoints; the eval summary
            directory is created underneath it.
        subset: dataset split name forwarded to `datasets.get_dataset`.
    """
    config = FLAGS.config
    strategy, batch_size = train_utils.setup_strategy(config, FLAGS.master,
                                                      FLAGS.devices_per_worker,
                                                      FLAGS.mode,
                                                      FLAGS.accelerator_type)

    def input_fn(_=None):
        return datasets.get_dataset(name=config.dataset,
                                    config=config,
                                    batch_size=config.eval_batch_size,
                                    subset=subset)

    model, optimizer, ema = train_utils.with_strategy(
        lambda: build(config, batch_size, False), strategy)

    metric_keys = ['loss', 'total_loss']
    metrics = {}
    for metric_key in metric_keys:
        func = functools.partial(tf.keras.metrics.Mean, metric_key)
        metrics[metric_key] = train_utils.with_strategy(func, strategy)

    checkpoints = train_utils.with_strategy(
        lambda: train_utils.create_checkpoint(model, optimizer, ema), strategy)
    dataset = train_utils.dataset_with_strategy(input_fn, strategy)

    def step_fn(batch):
        # Accumulates this batch's scalars into the running-mean metrics.
        _, extra = loss_on_batch(batch, model, config, training=False)
        for metric_key in metric_keys:
            metrics[metric_key].update_state(extra['scalar'][metric_key])

    num_examples = config.eval_num_examples
    eval_step = train_utils.step_with_strategy(step_fn, strategy)
    steps_per_summary = FLAGS.steps_per_summaries
    ckpt_path = None
    wait_max = config.get('eval_checkpoint_wait_secs',
                          config.save_checkpoint_secs * 100)
    is_ema = bool(ema)

    eval_summary_dir = os.path.join(
        logdir, 'eval_{}_summaries_pyk_{}'.format(subset, is_ema))
    writer = tf.summary.create_file_writer(eval_summary_dir)

    while True:
        ckpt_path = train_utils.wait_for_checkpoint(logdir, ckpt_path,
                                                    wait_max)
        logging.info(ckpt_path)
        if ckpt_path is None:
            logging.info('Timed out waiting for checkpoint.')
            break

        train_utils.with_strategy(
            lambda: train_utils.restore(model, checkpoints, logdir, ema),
            strategy)
        data_iterator = iter(dataset)
        num_steps = num_examples // batch_size

        for metric in metrics.values():
            metric.reset_states()

        logging.info('Starting evaluation.')
        done = False
        for i in range(0, num_steps, steps_per_summary):
            start_run = time.time()
            # The final chunk may be shorter than steps_per_summary.
            chunk_steps = min(num_steps - i, steps_per_summary)
            for k in range(chunk_steps):
                if k % 10 == 0:
                    logging.info('Step: %d', (i + k + 1))
                try:
                    eval_step(data_iterator)
                except (StopIteration, tf.errors.OutOfRangeError):
                    done = True
                    break
            if done:
                break
            bits_per_dim = metrics['loss'].result()
            logging.info(
                'Bits/Dim: %.3f, Speed: %.3f seconds/step, Step: %d/%d',
                bits_per_dim,
                # Divide by the steps actually run in this chunk, not by
                # steps_per_summary, so the last chunk's speed is correct.
                (time.time() - start_run) / chunk_steps,
                i + chunk_steps, num_steps)

        with writer.as_default():
            for metric_key, metric in metrics.items():
                curr_scalar = metric.result().numpy()
                tf.summary.scalar(metric_key,
                                  curr_scalar,
                                  step=optimizer.iterations)
        # Make this checkpoint's results visible immediately instead of
        # waiting for the writer's periodic flush.
        writer.flush()

    writer.close()
Example #3
0
def store_samples(data, config, logdir, gen_dataset=None):
    """Stores the generated samples.

    Restores the model from the latest checkpoint in `logdir`. For each of
    `num_outputs // batch_size` input batches, it draws
    `config.sample.num_samples` samples and concatenates the samples of every
    output key along axis 2. It then writes each example of every key whose
    name contains 'sample' or 'argmax' to `<sample_dir>/<gen_file>.tfrecords`.

    Args:
        data: iterator over input batches, advanced with `.next()`; the
            batches are indexed with keys such as 'targets_<res>' (see below).
        config: experiment config; reads `downsample_res`, `model.name` and
            the `sample.*` fields.
        logdir: directory holding the checkpoints to restore.
        gen_dataset: optional iterator over previously generated images. When
            given, the upsampler models condition on its 'targets' instead of
            the ground-truth entries of `data`.
    """
    downsample_res = config.get('downsample_res', 64)
    num_samples = config.sample.num_samples
    num_outputs = config.sample.num_outputs
    batch_size = config.sample.get('batch_size', 1)
    sample_mode = config.sample.get('mode', 'argmax')
    gen_file = config.sample.get('gen_file', 'gen')
    model_name = config.model.name

    model, optimizer, ema = build(config, 1, False)
    checkpoints = train_utils.create_checkpoint(model, optimizer, ema)
    sample_dir = create_sample_dir(logdir, config)
    record_path = os.path.join(sample_dir, '%s.tfrecords' % gen_file)

    train_utils.restore(model, checkpoints, logdir, ema)
    num_steps_v = optimizer.iterations.numpy()
    logging.info('Producing sample after %d training steps.', num_steps_v)

    logging.info(gen_dataset)
    # The context manager closes the record file even if sampling raises;
    # it is opened after the restore so a failed restore leaves no file.
    with tf.io.TFRecordWriter(record_path) as writer:
        for batch_ind in range(num_outputs // batch_size):
            next_data = data.next()
            next_gen_data = (gen_dataset.next()
                             if gen_dataset is not None else None)

            # Gets grayscale image based on the model.
            curr_gray = get_grayscale_at_sample_time(next_data,
                                                     downsample_res,
                                                     model_name)

            curr_output = collections.defaultdict(list)
            for sample_ind in range(num_samples):
                logging.info('Batch no: %d, Sample no: %d',
                             batch_ind, sample_ind)

                if model_name == 'color_upsampler':
                    if next_gen_data is not None:
                        # Provide generated coarse color inputs.
                        scaled_rgb = next_gen_data['targets']
                    else:
                        # Provide coarse color ground truth inputs.
                        scaled_rgb = next_data['targets_%d' % downsample_res]
                    bit_rgb = base_utils.convert_bits(scaled_rgb,
                                                      n_bits_in=8,
                                                      n_bits_out=3)
                    output = model.sample(gray_cond=curr_gray,
                                          bit_cond=bit_rgb,
                                          mode=sample_mode)
                elif model_name == 'spatial_upsampler':
                    if next_gen_data is not None:
                        # Provide low resolution generated image.
                        low_res = datasets_utils.change_resolution(
                            next_gen_data['targets'], 256)
                    else:
                        # Provide low resolution ground truth image.
                        low_res = next_data['targets_%d_up_back' %
                                            downsample_res]
                    output = model.sample(gray_cond=curr_gray,
                                          inputs=low_res,
                                          mode=sample_mode)
                else:
                    output = model.sample(gray_cond=curr_gray,
                                          mode=sample_mode)
                logging.info('Done sampling')

                for out_key, out_item in output.items():
                    curr_output[out_key].append(out_item.numpy())

            # Concatenate the samples of each key along axis 2 (presumably
            # the width axis of (batch, H, W, C) arrays -- TODO confirm).
            for out_key, out_val in curr_output.items():
                curr_out_val = np.concatenate(out_val, axis=2)
                if 'sample' in out_key or 'argmax' in out_key:
                    save_str = f'Saving {(batch_ind + 1) * batch_size} samples'
                    logging.info(save_str)
                    for single_ex in curr_out_val:
                        writer.write(array_to_tf_example(single_ex))