예제 #1
0
def main(_):
  """Colorizes every image in `FLAGS.img_dir` and writes the results to disk.

  The model named by `FLAGS.config.model.name` is restored from
  `FLAGS.logdir`. Batches of grayscale images are read from `FLAGS.img_dir`;
  for the 'color_upsampler' and 'spatial_upsampler' models a second dataset of
  previously generated images is read from `FLAGS.gen_data_dir` in lockstep.
  Each sample is saved under the directory returned by `get_store_dir`, using
  the child path that the dataset yields alongside the image.

  Args:
    _: Unused positional argument.
  """
  config = FLAGS.config
  store_dir = FLAGS.store_dir
  img_dir = FLAGS.img_dir
  assert store_dir is not None
  assert img_dir is not None

  model_name = config.model.name
  gen_data_dir = FLAGS.gen_data_dir
  # Both upsamplers condition on the output of an earlier generation stage.
  needs_gen = model_name in ('color_upsampler', 'spatial_upsampler')

  batch_size = get_batch_size(model_name)
  store_dir = get_store_dir(model_name, store_dir)
  num_files = len(tf.io.gfile.listdir(img_dir))

  if needs_gen:
    assert gen_data_dir is not None
    gen_batches = create_gen_dataset_from_images(gen_data_dir, batch_size)
    gen_batches_iter = iter(gen_batches)

  gray_batches = create_grayscale_dataset_from_images(
      FLAGS.img_dir, batch_size)
  gray_batches_iter = iter(gray_batches)

  model, optimizer, ema = build_model(config)
  checkpoints = train_utils.create_checkpoint(
      model, optimizer=optimizer, ema=ema)
  train_utils.restore(model, checkpoints, FLAGS.logdir, ema)
  logging.info('Producing sample after %d training steps.',
               optimizer.iterations.numpy())

  # One iteration per batch; the last batch may be partial, hence the ceil.
  num_epochs = int(np.ceil(num_files / batch_size))
  logging.info(num_epochs)

  for _ in range(num_epochs):
    gray, gray_64, child_paths = next(gray_batches_iter)

    if needs_gen:
      prev_gen = next(gen_batches_iter)

    if model_name == 'coltran_core':
      samples = model.sample(gray_64, mode='sample')['auto_sample']
    elif model_name == 'color_upsampler':
      prev_gen = base_utils.convert_bits(prev_gen, n_bits_in=8, n_bits_out=3)
      samples = model.sample(
          bit_cond=prev_gen, gray_cond=gray_64)['bit_up_argmax']
    else:
      prev_gen = datasets_utils.change_resolution(prev_gen, 256)
      samples = model.sample(
          gray_cond=gray, inputs=prev_gen, mode='argmax')['high_res_argmax']

    # The dataset yields paths as byte strings; decode them for os.path.join.
    child_paths = [raw.decode('utf-8') for raw in child_paths.numpy()]
    logging.info(child_paths)

    for sample, child_path in zip(samples, child_paths):
      write_path = os.path.join(store_dir, child_path)
      logging.info(write_path)
      sample = sample.numpy().astype(np.uint8)
      logging.info(sample.shape)
      with tf.io.gfile.GFile(write_path, 'wb') as f:
        plt.imsave(f, sample)
예제 #2
0
    def load_and_preprocess_image(path, child_path):
        """Reads an image file and returns it at two resolutions.

        The image is decoded with one channel when `FLAGS.mode` is
        'colorize' and with three channels otherwise. When `FLAGS.mode` is
        'recolorize', both returned images are converted to grayscale.

        Args:
          path: Location of the image file to read.
          child_path: Passed through unchanged as the third return value.

        Returns:
          A tuple `(image, image_64, child_path)`: the square image produced
          with `resolution=256`, its `res=64` version, and `child_path`.
        """
        mode = FLAGS.mode
        raw_bytes = tf.io.read_file(path)
        channels = 1 if mode == 'colorize' else 3
        decoded = tf.image.decode_image(raw_bytes, channels=channels)

        # Deterministic (train=False) crop to a square, resized to 256.
        high_res = datasets.resize_to_square(
            decoded, resolution=256, train=False)

        # Low resolution counterpart of the square image.
        low_res = datasets_utils.change_resolution(high_res, res=64)

        if mode == 'recolorize':
            high_res = tf.image.rgb_to_grayscale(high_res)
            low_res = tf.image.rgb_to_grayscale(low_res)
        return high_res, low_res, child_path
예제 #3
0
def resize_to_square(image, resolution=32, train=True):
  """Crops `image` to a square and rescales it to `resolution`.

  The square's side is the shorter of the image's first two dimensions, so
  only the longer side is shortened.

  Args:
    image: Tensor with at least three dimensions; the first two are treated
      as height and width and the third as channels.
    resolution: Target resolution handed to `datasets_utils.change_resolution`.
    train: If True, the square is taken at a random location with
      `tf.image.random_crop`; otherwise `tf.image.resize_with_crop_or_pad`
      is used.

  Returns:
    The square image after `change_resolution` with `method='area'`.
  """
  dims = tf.shape(image)
  num_rows, num_cols, depth = dims[0], dims[1], dims[2]
  side = tf.minimum(num_rows, num_cols)

  if not train:
    square = tf.image.resize_with_crop_or_pad(
        image, target_height=side, target_width=side)
  else:
    square = tf.image.random_crop(image, tf.stack([side, side, depth]))

  return datasets_utils.change_resolution(
      square, res=resolution, method='area')
예제 #4
0
def store_samples(data, config, logdir, gen_dataset=None):
    """Stores the generated samples.

    Restores the model from `logdir`, draws `config.sample.num_samples`
    samples for each of `config.sample.num_outputs // batch_size` batches and
    writes every model output whose key contains 'sample' or 'argmax' to
    `<sample_dir>/<gen_file>.tfrecords`, one record per example.

    Args:
      data: Iterator exposing `.next()` that yields dicts of ground-truth
        inputs.
      config: Config providing `model.name`, the `sample` sub-config
        (`num_samples`, `num_outputs` and optional `batch_size`, `mode`,
        `gen_file`) and an optional `downsample_res` (default 64).
      logdir: Directory the checkpoint is restored from; also passed to
        `create_sample_dir`.
      gen_dataset: Optional iterator exposing `.next()` that yields dicts
        with a 'targets' entry. When given, the upsampler models condition
        on these generated images instead of the ground truth in `data`.
    """
    downsample_res = config.get('downsample_res', 64)
    num_samples = config.sample.num_samples
    num_outputs = config.sample.num_outputs
    batch_size = config.sample.get('batch_size', 1)
    sample_mode = config.sample.get('mode', 'argmax')
    gen_file = config.sample.get('gen_file', 'gen')

    model, optimizer, ema = build(config, 1, False)
    checkpoints = train_utils.create_checkpoint(model, optimizer, ema)
    sample_dir = create_sample_dir(logdir, config)
    record_path = os.path.join(sample_dir, '%s.tfrecords' % gen_file)

    # The context manager closes the writer even when restoring or sampling
    # raises, so the record file handle is never leaked.
    with tf.io.TFRecordWriter(record_path) as writer:
        train_utils.restore(model, checkpoints, logdir, ema)
        num_steps_v = optimizer.iterations.numpy()
        logging.info('Producing sample after %d training steps.', num_steps_v)

        logging.info(gen_dataset)
        model_name = config.model.name
        for batch_ind in range(num_outputs // batch_size):
            next_data = data.next()

            if gen_dataset is not None:
                next_gen_data = gen_dataset.next()

            # Gets grayscale image based on the model.
            curr_gray = get_grayscale_at_sample_time(
                next_data, downsample_res, model_name)

            curr_output = collections.defaultdict(list)
            for sample_ind in range(num_samples):
                logging.info('Batch no: %d, Sample no: %d',
                             batch_ind, sample_ind)

                if model_name == 'color_upsampler':
                    if gen_dataset is not None:
                        # Provide generated coarse color inputs.
                        scaled_rgb = next_gen_data['targets']
                    else:
                        # Provide coarse color ground truth inputs.
                        scaled_rgb = next_data['targets_%d' % downsample_res]
                    bit_rgb = base_utils.convert_bits(
                        scaled_rgb, n_bits_in=8, n_bits_out=3)
                    output = model.sample(gray_cond=curr_gray,
                                          bit_cond=bit_rgb,
                                          mode=sample_mode)

                elif model_name == 'spatial_upsampler':
                    if gen_dataset is not None:
                        # Provide low resolution generated image.
                        low_res = datasets_utils.change_resolution(
                            next_gen_data['targets'], 256)
                    else:
                        # Provide low resolution ground truth image.
                        low_res = next_data[
                            'targets_%d_up_back' % downsample_res]
                    output = model.sample(gray_cond=curr_gray,
                                          inputs=low_res,
                                          mode=sample_mode)
                else:
                    output = model.sample(gray_cond=curr_gray,
                                          mode=sample_mode)
                logging.info('Done sampling')

                for out_key, out_item in output.items():
                    curr_output[out_key].append(out_item.numpy())

            # Concatenate the repeated samples across width.
            # NOTE(review): axis=2 is the width only if outputs are laid out
            # as (batch, height, width, ...) -- confirm against the model.
            for out_key, out_val in curr_output.items():
                if 'sample' not in out_key and 'argmax' not in out_key:
                    continue
                curr_out_val = np.concatenate(out_val, axis=2)

                save_str = f'Saving {(batch_ind + 1) * batch_size} samples'
                logging.info(save_str)
                for single_ex in curr_out_val:
                    serialized = array_to_tf_example(single_ex)
                    writer.write(serialized)