Example #1
  def post_process_image(self, image):
    """Post-processes an (H, W) map of 512-way bin indices into a coarse RGB image."""
    # Unpack each bin index back into a 3-bit (R, G, B) triplet.
    image = base_utils.bins_to_labels(
        image, num_symbols_per_channel=self.num_symbols_per_channel)
    # Dequantize 3-bit channels to 8-bit intensities.
    image = base_utils.convert_bits(image, n_bits_in=3, n_bits_out=8)
    image = tf.cast(image, dtype=tf.uint8)
    return image
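Both bit-conversion directions used throughout these examples reduce to a shift by the bit-depth difference. The following is a minimal sketch of those semantics, assuming base_utils.convert_bits is a pure requantization (illustrative only; the function in the source tree is authoritative):

import tensorflow as tf

def convert_bits_sketch(x, n_bits_in, n_bits_out):
  """Requantizes an integer tensor from n_bits_in to n_bits_out bits (sketch)."""
  if n_bits_in == n_bits_out:
    return x
  if n_bits_in > n_bits_out:
    # Quantize: drop the (n_bits_in - n_bits_out) least significant bits.
    return x // 2 ** (n_bits_in - n_bits_out)
  # Dequantize: scale up so 0 stays 0 and the spacing is uniform.
  return x * 2 ** (n_bits_out - n_bits_in)

# A 3-bit value of 7 maps to 7 * 32 = 224 in 8-bit space.
print(convert_bits_sketch(tf.constant([0, 3, 7]), 3, 8).numpy())  # [  0  96 224]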
Example #2
    def loss(self, targets, logits, train_config, training, aux_output=None):
        """Converts targets to coarse colors and computes log-likelihood."""
        downsample = train_config.get('downsample', False)
        downsample_res = train_config.get('downsample_res', 64)
        if downsample:
            labels = targets['targets_%d' % downsample_res]
        else:
            labels = targets['targets']

        if aux_output is None:
            aux_output = {}

        # quantize labels.
        labels = base_utils.convert_bits(labels, n_bits_in=8, n_bits_out=3)

        # bin each channel triplet.
        labels = base_utils.labels_to_bins(labels,
                                           self.num_symbols_per_channel)

        loss = self.image_loss(logits, labels)
        enc_logits = aux_output.get('encoder_logits')
        if enc_logits is None:
            return loss, {}

        enc_loss = self.image_loss(enc_logits, labels)
        return loss, {'encoder': enc_loss}
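The binning step above packs each 3-bit (R, G, B) triplet into one of 8**3 = 512 symbols. A sketch of that mapping and its inverse, assuming a positional hash of the form r*k**2 + g*k + b (the names below are illustrative, not the actual base_utils implementation):

import tensorflow as tf

def labels_to_bins_sketch(labels, num_symbols_per_channel=8):
  """Maps (..., 3) channel triplets to a single bin in [0, k**3)."""
  k = num_symbols_per_channel
  r, g, b = labels[..., 0], labels[..., 1], labels[..., 2]
  return r * k**2 + g * k + b

def bins_to_labels_sketch(bins, num_symbols_per_channel=8):
  """Inverse mapping: recovers the (R, G, B) triplet from a bin index."""
  k = num_symbols_per_channel
  return tf.stack([bins // k**2, (bins // k) % k, bins % k], axis=-1)

# The 3-bit triplet (1, 2, 3) lands in bin 1*64 + 2*8 + 3 = 83.
print(labels_to_bins_sketch(tf.constant([[1, 2, 3]])).numpy())  # [83]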
Example #3
    def call(self, inputs, inputs_slice, channel_index=None, training=True):
        """Upsamples the coarsely colored input into a RGB image.

    Args:
      inputs: size (B, 64, 64, 3).
      inputs_slice: batch of randomly sliced channels, i.e (B, 64, 64, 1)
                    each element of the batch is either a R, G or B channel.
      channel_index: size (B,) Each element is (0, 1, or 2) denoting a
                     R, G or B channel.
      training: used only for dropout.
    Returns:
      logits: size (B, 64, 64, 3, 256) during training or
              size (B, 64, 64, 1, 256) during evaluation or sampling.
    """
        grayscale = tf.image.rgb_to_grayscale(inputs)
        # Quantize the sliced channel from 8 to 3 bits (coarse colors).
        inputs_slice = base_utils.convert_bits(inputs_slice,
                                               n_bits_in=8,
                                               n_bits_out=3)

        logits = self.upsampler(inputs_slice,
                                grayscale,
                                training=training,
                                channel_index=channel_index)
        return logits, {}
Example #4
    def test_spatial_upsampler_attention_num_channels_1(self):
        config = self.get_config()
        spatial_upsampler = upsampler.SpatialUpsampler(config=config)

        inputs = tf.random.uniform(shape=(8, 64, 64, 3),
                                   minval=0,
                                   maxval=256,
                                   dtype=tf.int32)
        inputs_slice = tf.random.uniform(shape=(8, 64, 64, 1),
                                         minval=0,
                                         maxval=256,
                                         dtype=tf.int32)
        grayscale = tf.image.rgb_to_grayscale(inputs)
        channel_index = tf.random.uniform(shape=(8,), minval=0, maxval=3,
                                          dtype=tf.int32)

        logits = spatial_upsampler(inputs=inputs,
                                   inputs_slice=inputs_slice,
                                   channel_index=channel_index)
        logits = logits[0]
        self.assertEqual(logits.shape, (8, 64, 64, 1, 256))

        inputs = base_utils.convert_bits(inputs, n_bits_in=8, n_bits_out=3)
        output = spatial_upsampler.sample(gray_cond=grayscale, inputs=inputs)
        self.assertEqual(output['high_res_argmax'].shape, (8, 64, 64, 3))
Example #5
def main(_):
  config, store_dir, img_dir = FLAGS.config, FLAGS.store_dir, FLAGS.img_dir
  assert store_dir is not None
  assert img_dir is not None
  model_name, gen_data_dir = config.model.name, FLAGS.gen_data_dir
  needs_gen = model_name in ['color_upsampler', 'spatial_upsampler']

  batch_size = get_batch_size(model_name)
  store_dir = get_store_dir(model_name, store_dir)
  num_files = len(tf.io.gfile.listdir(img_dir))

  if needs_gen:
    assert gen_data_dir is not None
    gen_dataset = create_gen_dataset_from_images(gen_data_dir, batch_size)
    gen_dataset_iter = iter(gen_dataset)

  dataset = create_grayscale_dataset_from_images(FLAGS.img_dir, batch_size)
  dataset_iter = iter(dataset)

  model, optimizer, ema = build_model(config)
  checkpoints = train_utils.create_checkpoint(model, optimizer=optimizer,
                                              ema=ema)
  train_utils.restore(model, checkpoints, FLAGS.logdir, ema)
  num_steps_v = optimizer.iterations.numpy()
  logging.info('Producing sample after %d training steps.', num_steps_v)

  num_batches = int(np.ceil(num_files / batch_size))
  logging.info('Number of batches: %d', num_batches)

  for _ in range(num_batches):
    gray, gray_64, child_paths = next(dataset_iter)

    if needs_gen:
      prev_gen = next(gen_dataset_iter)

    if model_name == 'coltran_core':
      out = model.sample(gray_64, mode='sample')
      samples = out['auto_sample']
    elif model_name == 'color_upsampler':
      prev_gen = base_utils.convert_bits(prev_gen, n_bits_in=8, n_bits_out=3)
      out = model.sample(bit_cond=prev_gen, gray_cond=gray_64)
      samples = out['bit_up_argmax']
    else:
      prev_gen = datasets_utils.change_resolution(prev_gen, 256)
      out = model.sample(gray_cond=gray, inputs=prev_gen, mode='argmax')
      samples = out['high_res_argmax']

    child_paths = child_paths.numpy()
    child_paths = [child_path.decode('utf-8') for child_path in child_paths]
    logging.info(child_paths)

    for sample, child_path in zip(samples, child_paths):
      write_path = os.path.join(store_dir, child_path)
      logging.info(write_path)
      sample = sample.numpy().astype(np.uint8)
      logging.info(sample.shape)
      with tf.io.gfile.GFile(write_path, 'wb') as f:
        plt.imsave(f, sample)
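main() above relies on two dataset helpers that are not shown. A minimal sketch of what create_gen_dataset_from_images might look like, assuming one decodable image file per directory entry (the decode step is an assumption; the actual helper may shuffle, resize, or decode differently):

import os
import tensorflow as tf

def create_gen_dataset_from_images(image_dir, batch_size):
  """Builds a batched dataset from every image file in image_dir (sketch)."""
  files = sorted(tf.io.gfile.listdir(image_dir))
  paths = [os.path.join(image_dir, f) for f in files]

  def read_image(path):
    # Decode to an (H, W, 3) uint8 tensor.
    image = tf.image.decode_image(tf.io.read_file(path), channels=3,
                                  expand_animations=False)
    return tf.cast(image, tf.int32)

  dataset = tf.data.Dataset.from_tensor_slices(paths)
  dataset = dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE)
  return dataset.batch(batch_size)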
Example #6
  def test_bit_upsampler_attention_num_channels_3(self):
    config = self.get_config()
    bit_upsampler = upsampler.ColorUpsampler(config=config)

    inputs = tf.random.uniform(shape=(8, 32, 32, 3), minval=0, maxval=256,
                               dtype=tf.int32)
    grayscale = tf.image.rgb_to_grayscale(inputs)

    logits = bit_upsampler(inputs=inputs, inputs_slice=inputs)[0]
    self.assertEqual(logits.shape, (8, 32, 32, 3, 256))

    inputs = base_utils.convert_bits(inputs, n_bits_in=8, n_bits_out=3)
    output = bit_upsampler.sample(gray_cond=grayscale, bit_cond=inputs)
    self.assertEqual(output['bit_up_argmax'].shape, (8, 32, 32, 3))
Example #7
  def sample(self, gray_cond, bit_cond, mode='argmax'):
    output = dict()
    bit_cond_viz = base_utils.convert_bits(bit_cond, n_bits_in=3, n_bits_out=8)
    output['bit_cond'] = tf.cast(bit_cond_viz, dtype=tf.uint8)

    logits = self.upsampler(bit_cond, gray_cond, training=False)

    if mode == 'argmax':
      samples = tf.argmax(logits, axis=-1)
    elif mode == 'sample':
      batch_size, height, width, channels = logits.shape[:-1]
      logits = tf.reshape(logits, (batch_size*height*width*channels, -1))
      samples = tf.random.categorical(logits, num_samples=1,
                                      dtype=tf.int32)[:, 0]
      samples = tf.reshape(samples, (batch_size, height, width, channels))
    else:
      raise ValueError(f'Unsupported sampling mode: {mode}')

    samples = tf.cast(samples, dtype=tf.uint8)
    output[f'bit_up_{mode}'] = samples
    return output
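A small hypothetical usage of sample(), following the ColorUpsampler test setup above; bit_upsampler and the imports are assumed to exist exactly as in the tests:

rgb = tf.random.uniform(shape=(2, 64, 64, 3), minval=0, maxval=256,
                        dtype=tf.int32)
gray = tf.image.rgb_to_grayscale(rgb)
bit_cond = base_utils.convert_bits(rgb, n_bits_in=8, n_bits_out=3)

out = bit_upsampler.sample(gray_cond=gray, bit_cond=bit_cond, mode='sample')
samples = out['bit_up_sample']  # (2, 64, 64, 3) uint8 coarse colors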
Example #8
    def decoder(self, inputs, z, training):
        """Decodes grayscale representation and masked colors into logits."""
        # Quantize 8-bit inputs to 3 bits per channel.
        labels = base_utils.convert_bits(inputs, n_bits_in=8, n_bits_out=3)

        # Bin each channel triplet: (H, W, 3) with 8 symbols per channel
        # becomes (H, W) with 512 possible symbols.
        labels = base_utils.labels_to_bins(labels,
                                           self.num_symbols_per_channel)

        # One-hot encode: (H, W) with 512 symbols to (H, W, 512).
        labels = tf.one_hot(labels, depth=self.num_symbols)

        h_dec = self.pixel_embed_layer(labels)
        h_upper = self.outer_decoder((h_dec, z), training=training)
        h_inner = self.inner_decoder((h_dec, h_upper, z), training=training)

        activations = self.final_norm(h_inner)
        logits = self.final_dense(activations)
        return tf.expand_dims(logits, axis=-2)
Example #9
  def test_dequantize(self):
    x = tf.range(0, 32, dtype=tf.int32)
    actual = base_utils.convert_bits(x, n_bits_in=5, n_bits_out=8).numpy()
    expected = np.arange(0, 256, 8)
    self.assertTrue(np.allclose(expected, actual))
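The quantizing direction can be checked the same way. A hypothetical companion test, mirroring the dequantize test above:

  def test_quantize(self):
    x = tf.range(0, 256, dtype=tf.int32)
    actual = base_utils.convert_bits(x, n_bits_in=8, n_bits_out=5).numpy()
    # Dropping 3 bits maps every run of 8 consecutive values to one symbol.
    expected = np.repeat(np.arange(0, 32), 8)
    self.assertTrue(np.allclose(expected, actual))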
Example #10
def store_samples(data, config, logdir, gen_dataset=None):
    """Stores the generated samples."""
    downsample_res = config.get('downsample_res', 64)
    num_samples = config.sample.num_samples
    num_outputs = config.sample.num_outputs
    batch_size = config.sample.get('batch_size', 1)
    sample_mode = config.sample.get('mode', 'argmax')
    gen_file = config.sample.get('gen_file', 'gen')

    model, optimizer, ema = build(config, 1, False)
    checkpoints = train_utils.create_checkpoint(model, optimizer, ema)
    sample_dir = create_sample_dir(logdir, config)
    record_path = os.path.join(sample_dir, '%s.tfrecords' % gen_file)
    writer = tf.io.TFRecordWriter(record_path)

    train_utils.restore(model, checkpoints, logdir, ema)
    num_steps_v = optimizer.iterations.numpy()
    logging.info('Producing sample after %d training steps.', num_steps_v)

    logging.info(gen_dataset)
    for batch_ind in range(num_outputs // batch_size):
        next_data = data.next()

        if gen_dataset is not None:
            next_gen_data = gen_dataset.next()

        # Gets grayscale image based on the model.
        curr_gray = get_grayscale_at_sample_time(next_data, downsample_res,
                                                 config.model.name)

        curr_output = collections.defaultdict(list)
        for sample_ind in range(num_samples):
            logging.info('Batch no: %d, Sample no: %d', batch_ind, sample_ind)

            if config.model.name == 'color_upsampler':

                if gen_dataset is not None:
                    # Provide generated coarse color inputs.
                    scaled_rgb = next_gen_data['targets']
                else:
                    # Provide coarse color ground truth inputs.
                    scaled_rgb = next_data['targets_%d' % downsample_res]
                bit_rgb = base_utils.convert_bits(scaled_rgb,
                                                  n_bits_in=8,
                                                  n_bits_out=3)
                output = model.sample(gray_cond=curr_gray,
                                      bit_cond=bit_rgb,
                                      mode=sample_mode)

            elif config.model.name == 'spatial_upsampler':
                if gen_dataset is not None:
                    # Provide low resolution generated image.
                    low_res = next_gen_data['targets']
                    low_res = datasets_utils.change_resolution(low_res, 256)
                else:
                    # Provide low resolution ground truth image.
                    low_res = next_data['targets_%d_up_back' % downsample_res]
                output = model.sample(gray_cond=curr_gray,
                                      inputs=low_res,
                                      mode=sample_mode)
            else:
                output = model.sample(gray_cond=curr_gray, mode=sample_mode)
            logging.info('Done sampling')

            for out_key, out_item in output.items():
                curr_output[out_key].append(out_item.numpy())

        # concatenate samples across width.
        for out_key, out_val in curr_output.items():
            curr_out_val = np.concatenate(out_val, axis=2)
            curr_output[out_key] = curr_out_val

            if 'sample' in out_key or 'argmax' in out_key:
                save_str = f'Saving {(batch_ind + 1) * batch_size} samples'
                logging.info(save_str)
                for single_ex in curr_out_val:
                    serialized = array_to_tf_example(single_ex)
                    writer.write(serialized)

    writer.close()
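store_samples serializes each sample through array_to_tf_example, which is not shown. A minimal sketch of such a helper, assuming the array is stored as raw bytes plus its shape (the feature names here are illustrative and may not match what the downstream reader expects):

import numpy as np
import tensorflow as tf

def array_to_tf_example(array):
  """Serializes a uint8 numpy array into a tf.train.Example string (sketch)."""
  array = np.asarray(array, dtype=np.uint8)
  feature = {
      'image': tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[array.tobytes()])),
      'shape': tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(array.shape))),
  }
  example = tf.train.Example(features=tf.train.Features(feature=feature))
  return example.SerializeToString()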