def post_process_image(self, image):
    """Convert a binned image of size (H, W, 512) into a coarse RGB image.

    The bins are first expanded back to per-channel labels using
    `self.num_symbols_per_channel`, the labels are then rescaled from 3 bits
    to 8 bits, and the result is cast to `tf.uint8`.

    Args:
      image: binned image to post-process.

    Returns:
      The post-processed image as a `tf.uint8` tensor.
    """
    channel_labels = base_utils.bins_to_labels(
        image, num_symbols_per_channel=self.num_symbols_per_channel)
    rgb_8bit = base_utils.convert_bits(
        channel_labels, n_bits_in=3, n_bits_out=8)
    return tf.cast(rgb_8bit, dtype=tf.uint8)
def loss(self, targets, logits, train_config, training, aux_output=None):
    """Compute the log-likelihood of the targets after coarse quantization.

    Args:
      targets: mapping holding the ground truth under 'targets', or under
        'targets_<res>' when `train_config['downsample']` is set.
      logits: decoder logits passed to `self.image_loss`.
      train_config: config exposing `.get`; reads 'downsample' (default
        False) and 'downsample_res' (default 64).
      training: not used by this method.
      aux_output: optional mapping; when it holds 'encoder_logits' an
        additional encoder loss is computed against the same labels.

    Returns:
      A `(loss, aux_losses)` pair. `aux_losses` is `{'encoder': enc_loss}`
      when encoder logits are available and `{}` otherwise.
    """
    use_downsampled = train_config.get('downsample', False)
    resolution = train_config.get('downsample_res', 64)
    target_key = 'targets_%d' % resolution if use_downsampled else 'targets'
    raw_labels = targets[target_key]

    aux = {} if aux_output is None else aux_output

    # Quantize every channel from 8 bits down to 3 bits.
    coarse_labels = base_utils.convert_bits(
        raw_labels, n_bits_in=8, n_bits_out=3)
    # Collapse each channel triplet into a single bin index.
    binned_labels = base_utils.labels_to_bins(
        coarse_labels, self.num_symbols_per_channel)

    decoder_loss = self.image_loss(logits, binned_labels)

    encoder_logits = aux.get('encoder_logits')
    if encoder_logits is not None:
        encoder_loss = self.image_loss(encoder_logits, binned_labels)
        return decoder_loss, {'encoder': encoder_loss}
    return decoder_loss, {}
def call(self, inputs, inputs_slice, channel_index=None, training=True):
    """Upsample a coarsely colored input into an RGB image.

    Args:
      inputs: size (B, 64, 64, 3).
      inputs_slice: batch of randomly sliced channels, i.e. (B, 64, 64, 1);
        each element of the batch is either an R, G or B channel.
      channel_index: size (B,). Each element is 0, 1 or 2, denoting an R, G
        or B channel.
      training: used only for dropout.

    Returns:
      A `(logits, {})` pair where logits has size (B, 64, 64, 3, 256) during
      training, or (B, 64, 64, 1, 256) during evaluation or sampling.
    """
    gray_cond = tf.image.rgb_to_grayscale(inputs)
    # Reduce the sliced channels from 8 bits to the coarse 3-bit range.
    coarse_slice = base_utils.convert_bits(
        inputs_slice, n_bits_in=8, n_bits_out=3)
    upsampled_logits = self.upsampler(
        coarse_slice,
        gray_cond,
        training=training,
        channel_index=channel_index)
    return upsampled_logits, {}
def test_color_upsampler_attention_num_channels_1(self):
    """Checks SpatialUpsampler logits and sample shapes for one channel."""
    model = upsampler.SpatialUpsampler(config=self.get_config())

    rgb = tf.random.uniform(
        shape=(8, 64, 64, 3), minval=0, maxval=256, dtype=tf.int32)
    rgb_slice = tf.random.uniform(
        shape=(8, 64, 64, 1), minval=0, maxval=256, dtype=tf.int32)
    gray = tf.image.rgb_to_grayscale(rgb)
    channel_index = tf.random.uniform(
        shape=[8], minval=0, maxval=3, dtype=tf.int32)

    # The forward pass returns a sequence whose first element is the logits.
    call_result = model(
        inputs=rgb, inputs_slice=rgb_slice, channel_index=channel_index)
    self.assertEqual(call_result[0].shape, (8, 64, 64, 1, 256))

    # Sampling is fed the 3-bit version of the same image.
    coarse_rgb = base_utils.convert_bits(rgb, n_bits_in=8, n_bits_out=3)
    sampled = model.sample(gray_cond=gray, inputs=coarse_rgb)
    self.assertEqual(sampled['high_res_argmax'].shape, (8, 64, 64, 3))
def main(_):
    """Run a restored model over a directory of images and save the outputs.

    Reads `config`, `store_dir`, `img_dir`, `gen_data_dir` and `logdir` from
    FLAGS, restores the model from the checkpoint in `logdir`, then processes
    ceil(num_files / batch_size) batches. Every produced sample is written
    with `plt.imsave` to the store directory under the child path that came
    with the corresponding input.
    """
    config = FLAGS.config
    store_dir = FLAGS.store_dir
    img_dir = FLAGS.img_dir
    assert store_dir is not None
    assert img_dir is not None

    model_name = config.model.name
    gen_data_dir = FLAGS.gen_data_dir
    # The two upsamplers consume previously generated images as extra input.
    needs_gen = model_name in ('color_upsampler', 'spatial_upsampler')

    batch_size = get_batch_size(model_name)
    store_dir = get_store_dir(model_name, store_dir)
    num_files = len(tf.io.gfile.listdir(img_dir))

    if needs_gen:
        assert gen_data_dir is not None
        gen_dataset_iter = iter(
            create_gen_dataset_from_images(gen_data_dir, batch_size))

    dataset_iter = iter(
        create_grayscale_dataset_from_images(img_dir, batch_size))

    model, optimizer, ema = build_model(config)
    checkpoints = train_utils.create_checkpoint(
        model, optimizer=optimizer, ema=ema)
    train_utils.restore(model, checkpoints, FLAGS.logdir, ema)
    logging.info('Producing sample after %d training steps.',
                 optimizer.iterations.numpy())

    num_epochs = int(np.ceil(num_files / batch_size))
    logging.info(num_epochs)

    for _epoch in range(num_epochs):
        gray, gray_64, path_tensor = next(dataset_iter)
        if needs_gen:
            prev_gen = next(gen_dataset_iter)

        if model_name == 'coltran_core':
            result = model.sample(gray_64, mode='sample')
            samples = result['auto_sample']
        elif model_name == 'color_upsampler':
            coarse_gen = base_utils.convert_bits(
                prev_gen, n_bits_in=8, n_bits_out=3)
            result = model.sample(bit_cond=coarse_gen, gray_cond=gray_64)
            samples = result['bit_up_argmax']
        else:
            resized_gen = datasets_utils.change_resolution(prev_gen, 256)
            result = model.sample(
                gray_cond=gray, inputs=resized_gen, mode='argmax')
            samples = result['high_res_argmax']

        child_paths = [raw.decode('utf-8') for raw in path_tensor.numpy()]
        logging.info(child_paths)

        for sample, child_path in zip(samples, child_paths):
            write_path = os.path.join(store_dir, child_path)
            logging.info(write_path)
            pixels = sample.numpy().astype(np.uint8)
            logging.info(pixels.shape)
            with tf.io.gfile.GFile(write_path, 'wb') as out_file:
                plt.imsave(out_file, pixels)
def test_bit_upsampler_attention_num_channels_3(self):
    """Checks ColorUpsampler logits and sample shapes for three channels."""
    model = upsampler.ColorUpsampler(config=self.get_config())

    rgb = tf.random.uniform(
        shape=(8, 32, 32, 3), minval=0, maxval=256, dtype=tf.int32)
    gray = tf.image.rgb_to_grayscale(rgb)

    # The same tensor is used both as the input and as the channel slice.
    call_result = model(inputs=rgb, inputs_slice=rgb)
    self.assertEqual(call_result[0].shape, (8, 32, 32, 3, 256))

    # Sampling is conditioned on the 3-bit version of the image.
    coarse_rgb = base_utils.convert_bits(rgb, n_bits_in=8, n_bits_out=3)
    sampled = model.sample(gray_cond=gray, bit_cond=coarse_rgb)
    self.assertEqual(sampled['bit_up_argmax'].shape, (8, 32, 32, 3))
def sample(self, gray_cond, bit_cond, mode='argmax'):
    """Upsample a coarse 3-bit image, by argmax or by sampling.

    Args:
      gray_cond: grayscale conditioning passed to `self.upsampler`.
      bit_cond: coarse 3-bit image passed to `self.upsampler`.
      mode: 'argmax' takes the most likely symbol at every position;
        'sample' draws each symbol from the categorical distribution
        defined by the logits.

    Returns:
      A dict with two entries:
        'bit_cond': `bit_cond` rescaled to 8 bits and cast to `tf.uint8`.
        'bit_up_<mode>': the upsampled image cast to `tf.uint8`.

    Raises:
      ValueError: if `mode` is neither 'argmax' nor 'sample'.
    """
    # Reject unknown modes before the forward pass. Previously an unsupported
    # mode ran the whole upsampler and then failed with an unbound-local
    # error on `samples`.
    if mode not in ('argmax', 'sample'):
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'argmax' or 'sample'.")

    output = {}
    bit_cond_viz = base_utils.convert_bits(bit_cond, n_bits_in=3, n_bits_out=8)
    output['bit_cond'] = tf.cast(bit_cond_viz, dtype=tf.uint8)

    logits = self.upsampler(bit_cond, gray_cond, training=False)

    if mode == 'argmax':
        samples = tf.argmax(logits, axis=-1)
    else:
        # tf.random.categorical expects 2-D logits, so flatten every leading
        # axis into one, draw a single symbol per row, then restore the shape.
        batch_size, height, width, channels = logits.shape[:-1]
        flat_logits = tf.reshape(
            logits, (batch_size * height * width * channels, -1))
        flat_samples = tf.random.categorical(
            flat_logits, num_samples=1, dtype=tf.int32)[:, 0]
        samples = tf.reshape(
            flat_samples, (batch_size, height, width, channels))

    output[f'bit_up_{mode}'] = tf.cast(samples, dtype=tf.uint8)
    return output
def decoder(self, inputs, z, training):
    """Decode the grayscale representation and masked colors into logits.

    Args:
      inputs: 8-bit color image that is quantized, binned and embedded.
      z: conditioning representation fed to both decoders.
      training: forwarded to the outer and inner decoders.

    Returns:
      Logits with an extra axis inserted at position -2.
    """
    # Quantize every channel to 3 bits.
    coarse = base_utils.convert_bits(inputs, n_bits_in=8, n_bits_out=3)
    # Merge each channel triplet (8 possible symbols per channel) into one
    # bin index per pixel.
    binned = base_utils.labels_to_bins(coarse, self.num_symbols_per_channel)
    # One-hot encode the bin indices with depth `self.num_symbols`.
    one_hot_bins = tf.one_hot(binned, depth=self.num_symbols)

    embedded = self.pixel_embed_layer(one_hot_bins)
    outer_out = self.outer_decoder((embedded, z), training=training)
    inner_out = self.inner_decoder((embedded, outer_out, z), training=training)

    normalized = self.final_norm(inner_out)
    return tf.expand_dims(self.final_dense(normalized), axis=-2)
def test_dequantize(self):
    """Checks that 5-bit values 0..31 map onto multiples of 8 in 8 bits."""
    five_bit = tf.range(0, 32, dtype=tf.int32)
    eight_bit = base_utils.convert_bits(five_bit, n_bits_in=5, n_bits_out=8)
    actual = eight_bit.numpy()
    expected = np.arange(0, 256, 8)
    self.assertTrue(np.allclose(expected, actual))
def store_samples(data, config, logdir, gen_dataset=None):
    """Generate samples from a restored model and store them in a TFRecord.

    Args:
      data: iterator exposing `.next()` that yields dicts of input tensors.
      config: experiment config. Reads `downsample_res`, `model.name`,
        `sample.num_samples`, `sample.num_outputs`, and the optional
        `sample.batch_size` (default 1), `sample.mode` (default 'argmax')
        and `sample.gen_file` (default 'gen').
      logdir: directory the checkpoint is restored from; also used to derive
        the sample directory.
      gen_dataset: optional iterator exposing `.next()` that yields dicts
        with previously generated images under 'targets'. When given, the
        upsamplers are conditioned on these instead of the ground truth.

    The record is written to `<sample_dir>/<gen_file>.tfrecords`. Only
    outputs whose key contains 'sample' or 'argmax' are serialized.
    """
    downsample_res = config.get('downsample_res', 64)
    num_samples = config.sample.num_samples
    num_outputs = config.sample.num_outputs
    batch_size = config.sample.get('batch_size', 1)
    sample_mode = config.sample.get('mode', 'argmax')
    gen_file = config.sample.get('gen_file', 'gen')
    model_name = config.model.name

    model, optimizer, ema = build(config, 1, False)
    checkpoints = train_utils.create_checkpoint(model, optimizer, ema)
    sample_dir = create_sample_dir(logdir, config)
    record_path = os.path.join(sample_dir, '%s.tfrecords' % gen_file)

    # The context manager closes the writer even when restoring or sampling
    # raises; a trailing `writer.close()` alone only runs on success.
    with tf.io.TFRecordWriter(record_path) as writer:
        train_utils.restore(model, checkpoints, logdir, ema)
        num_steps_v = optimizer.iterations.numpy()
        logging.info('Producing sample after %d training steps.', num_steps_v)

        logging.info(gen_dataset)
        for batch_ind in range(num_outputs // batch_size):
            next_data = data.next()
            if gen_dataset is not None:
                next_gen_data = gen_dataset.next()

            # Gets grayscale image based on the model.
            curr_gray = get_grayscale_at_sample_time(
                next_data, downsample_res, model_name)

            curr_output = collections.defaultdict(list)
            for sample_ind in range(num_samples):
                logging.info('Batch no: %d, Sample no: %d',
                             batch_ind, sample_ind)

                if model_name == 'color_upsampler':
                    if gen_dataset is not None:
                        # Provide generated coarse color inputs.
                        scaled_rgb = next_gen_data['targets']
                    else:
                        # Provide coarse color ground truth inputs.
                        scaled_rgb = next_data['targets_%d' % downsample_res]
                    bit_rgb = base_utils.convert_bits(
                        scaled_rgb, n_bits_in=8, n_bits_out=3)
                    output = model.sample(
                        gray_cond=curr_gray, bit_cond=bit_rgb,
                        mode=sample_mode)
                elif model_name == 'spatial_upsampler':
                    if gen_dataset is not None:
                        # Provide low resolution generated image.
                        low_res = next_gen_data['targets']
                        low_res = datasets_utils.change_resolution(
                            low_res, 256)
                    else:
                        # Provide low resolution ground truth image.
                        low_res = next_data[
                            'targets_%d_up_back' % downsample_res]
                    output = model.sample(
                        gray_cond=curr_gray, inputs=low_res, mode=sample_mode)
                else:
                    output = model.sample(
                        gray_cond=curr_gray, mode=sample_mode)
                logging.info('Done sampling')

                for out_key, out_item in output.items():
                    curr_output[out_key].append(out_item.numpy())

            # Concatenate samples across width. The result is not stored back
            # into `curr_output`, which is rebuilt for every batch.
            for out_key, out_val in curr_output.items():
                curr_out_val = np.concatenate(out_val, axis=2)

                if 'sample' in out_key or 'argmax' in out_key:
                    logging.info('Saving %s samples',
                                 (batch_ind + 1) * batch_size)
                    for single_ex in curr_out_val:
                        writer.write(array_to_tf_example(single_ex))