def main(_):
    """Runs the configured model over `FLAGS.img_dir` and saves its samples.

    The model named by `FLAGS.config.model.name` is restored from
    `FLAGS.logdir`, applied batch by batch to the grayscale dataset built from
    `FLAGS.img_dir`, and each resulting sample is written below the store
    directory under the `child_path` yielded with its input.

    Args:
      _: Unused.

    Raises:
      ValueError: If `FLAGS.store_dir` or `FLAGS.img_dir` is None, or if the
        model is 'color_upsampler' / 'spatial_upsampler' and
        `FLAGS.gen_data_dir` is None.
    """
    config, store_dir, img_dir = FLAGS.config, FLAGS.store_dir, FLAGS.img_dir
    # Explicit checks rather than `assert`: assertions are removed when Python
    # runs with -O, which would let a missing flag fail much later and less
    # clearly.
    if store_dir is None:
        raise ValueError('FLAGS.store_dir must be set.')
    if img_dir is None:
        raise ValueError('FLAGS.img_dir must be set.')

    model_name, gen_data_dir = config.model.name, FLAGS.gen_data_dir
    # The two upsamplers consume the output of a previous generation stage.
    needs_gen = model_name in ['color_upsampler', 'spatial_upsampler']
    # Validate before any directory or dataset work is started.
    if needs_gen and gen_data_dir is None:
        raise ValueError(
            'FLAGS.gen_data_dir must be set for model %r.' % model_name)

    batch_size = get_batch_size(model_name)
    store_dir = get_store_dir(model_name, store_dir)
    num_files = len(tf.io.gfile.listdir(img_dir))

    gen_dataset_iter = None
    if needs_gen:
        gen_dataset = create_gen_dataset_from_images(gen_data_dir, batch_size)
        gen_dataset_iter = iter(gen_dataset)

    dataset = create_grayscale_dataset_from_images(img_dir, batch_size)
    dataset_iter = iter(dataset)

    model, optimizer, ema = build_model(config)
    checkpoints = train_utils.create_checkpoint(
        model, optimizer=optimizer, ema=ema)
    train_utils.restore(model, checkpoints, FLAGS.logdir, ema)
    num_steps_v = optimizer.iterations.numpy()
    logging.info('Producing sample after %d training steps.', num_steps_v)

    # Number of batches needed to cover every entry listed in `img_dir`.
    num_batches = int(np.ceil(num_files / batch_size))
    logging.info(num_batches)

    for _ in range(num_batches):
        gray, gray_64, child_paths = next(dataset_iter)

        if needs_gen:
            prev_gen = next(gen_dataset_iter)

        if model_name == 'coltran_core':
            out = model.sample(gray_64, mode='sample')
            samples = out['auto_sample']
        elif model_name == 'color_upsampler':
            prev_gen = base_utils.convert_bits(
                prev_gen, n_bits_in=8, n_bits_out=3)
            out = model.sample(bit_cond=prev_gen, gray_cond=gray_64)
            samples = out['bit_up_argmax']
        else:
            prev_gen = datasets_utils.change_resolution(prev_gen, 256)
            out = model.sample(gray_cond=gray, inputs=prev_gen, mode='argmax')
            samples = out['high_res_argmax']

        # The paths arrive as a tensor of byte strings; decode them for
        # os.path.join.
        child_paths = child_paths.numpy()
        child_paths = [child_path.decode('utf-8') for child_path in child_paths]
        logging.info(child_paths)

        for sample, child_path in zip(samples, child_paths):
            write_path = os.path.join(store_dir, child_path)
            logging.info(write_path)
            sample = sample.numpy().astype(np.uint8)
            logging.info(sample.shape)
            with tf.io.gfile.GFile(write_path, 'wb') as f:
                plt.imsave(f, sample)
def load_and_preprocess_image(path, child_path):
    """Reads an image file and produces its 256 and 64 resolution versions.

    Args:
      path: Location of the image file, handed to `tf.io.read_file`.
      child_path: Passed through unchanged so callers can pair outputs with
        their source file.

    Returns:
      A tuple `(image, image_64, child_path)`. When `FLAGS.mode` is
      'recolorize' both images are converted with
      `tf.image.rgb_to_grayscale` before being returned.
    """
    raw_bytes = tf.io.read_file(path)

    # 'colorize' inputs are decoded with one channel, everything else with
    # three.
    if FLAGS.mode == 'colorize':
        decode_channels = 1
    else:
        decode_channels = 3
    decoded = tf.image.decode_image(raw_bytes, channels=decode_channels)

    # Square 256x256 version (non-training path of resize_to_square).
    full_res = datasets.resize_to_square(decoded, resolution=256, train=False)

    # Low resolution counterpart derived from the 256x256 image.
    low_res = datasets_utils.change_resolution(full_res, res=64)

    if FLAGS.mode != 'recolorize':
        return full_res, low_res, child_path

    full_res = tf.image.rgb_to_grayscale(full_res)
    low_res = tf.image.rgb_to_grayscale(low_res)
    return full_res, low_res, child_path
def resize_to_square(image, resolution=32, train=True):
    """Crops `image` to a square and rescales it to `resolution`.

    The square's side equals the shorter of the image's first two dimensions.

    Args:
      image: Image tensor; its first three dimensions are read as height,
        width and channels.
      resolution: Value forwarded as `res` to
        `datasets_utils.change_resolution`.
      train: If True the square is taken with `tf.image.random_crop`;
        otherwise `tf.image.resize_with_crop_or_pad` is used.

    Returns:
      The result of `datasets_utils.change_resolution` (method 'area') applied
      to the squared image.
    """
    dims = tf.shape(image)
    rows, cols, depth = dims[0], dims[1], dims[2]

    # Shorten the longer side so that both sides match the shorter one.
    short_side = tf.minimum(rows, cols)
    square_shape = tf.stack([short_side, short_side, depth])

    if not train:
        squared = tf.image.resize_with_crop_or_pad(
            image, target_height=short_side, target_width=short_side)
    else:
        squared = tf.image.random_crop(image, square_shape)

    return datasets_utils.change_resolution(
        squared, res=resolution, method='area')
def _sample_from_model(model, model_name, sample_mode, curr_gray, next_data,
                       next_gen_data, downsample_res):
    """Draws one sample from `model` using the inputs its `model_name` needs.

    `next_gen_data` is None when no generated dataset was supplied; the
    conditioning inputs are then read from `next_data` instead.
    """
    if model_name == 'color_upsampler':
        if next_gen_data is not None:
            # Provide generated coarse color inputs.
            scaled_rgb = next_gen_data['targets']
        else:
            # Provide coarse color ground truth inputs.
            scaled_rgb = next_data['targets_%d' % downsample_res]
        bit_rgb = base_utils.convert_bits(
            scaled_rgb, n_bits_in=8, n_bits_out=3)
        return model.sample(
            gray_cond=curr_gray, bit_cond=bit_rgb, mode=sample_mode)

    if model_name == 'spatial_upsampler':
        if next_gen_data is not None:
            # Provide low resolution generated image.
            low_res = next_gen_data['targets']
            low_res = datasets_utils.change_resolution(low_res, 256)
        else:
            # Provide low resolution ground truth image.
            low_res = next_data['targets_%d_up_back' % downsample_res]
        return model.sample(
            gray_cond=curr_gray, inputs=low_res, mode=sample_mode)

    return model.sample(gray_cond=curr_gray, mode=sample_mode)


def store_samples(data, config, logdir, gen_dataset=None):
    """Stores the generated samples.

    Restores the model from `logdir`, draws `config.sample.num_samples`
    samples for each of `config.sample.num_outputs // batch_size` batches and
    writes every output whose key contains 'sample' or 'argmax' to
    `<sample_dir>/<gen_file>.tfrecords`.

    Args:
      data: Object whose `.next()` returns the next batch; the batch is
        indexed with keys such as 'targets_<downsample_res>'.
      config: Configuration read via `config.get('downsample_res', 64)`,
        `config.sample.*` and `config.model.name`, and forwarded to `build`
        and `create_sample_dir`.
      logdir: Directory passed to `create_sample_dir` and
        `train_utils.restore`.
      gen_dataset: Optional object whose `.next()` returns a batch with a
        'targets' entry. When given, it supplies the upsampler inputs in place
        of the ground truth taken from `data`.
    """
    downsample_res = config.get('downsample_res', 64)
    num_samples = config.sample.num_samples
    num_outputs = config.sample.num_outputs
    batch_size = config.sample.get('batch_size', 1)
    sample_mode = config.sample.get('mode', 'argmax')
    gen_file = config.sample.get('gen_file', 'gen')
    # Loop invariant, so read it once.
    model_name = config.model.name

    model, optimizer, ema = build(config, 1, False)
    checkpoints = train_utils.create_checkpoint(model, optimizer, ema)
    sample_dir = create_sample_dir(logdir, config)
    record_path = os.path.join(sample_dir, '%s.tfrecords' % gen_file)

    train_utils.restore(model, checkpoints, logdir, ema)
    num_steps_v = optimizer.iterations.numpy()
    logging.info('Producing sample after %d training steps.', num_steps_v)

    logging.info(gen_dataset)

    # The context manager closes the record file even if sampling raises.
    # Opening it only after the restore avoids leaving an open, empty file
    # behind when the checkpoint cannot be loaded.
    with tf.io.TFRecordWriter(record_path) as writer:
        for batch_ind in range(num_outputs // batch_size):
            next_data = data.next()
            next_gen_data = None
            if gen_dataset is not None:
                next_gen_data = gen_dataset.next()

            # Gets grayscale image based on the model.
            curr_gray = get_grayscale_at_sample_time(
                next_data, downsample_res, model_name)

            curr_output = collections.defaultdict(list)
            for sample_ind in range(num_samples):
                logging.info(
                    'Batch no: %d, Sample no: %d', batch_ind, sample_ind)
                output = _sample_from_model(
                    model, model_name, sample_mode, curr_gray, next_data,
                    next_gen_data, downsample_res)
                logging.info('Done sampling')

                for out_key, out_item in output.items():
                    curr_output[out_key].append(out_item.numpy())

            # concatenate samples across width.
            for out_key, out_val in curr_output.items():
                curr_out_val = np.concatenate(out_val, axis=2)

                if 'sample' in out_key or 'argmax' in out_key:
                    save_str = f'Saving {(batch_ind + 1) * batch_size} samples'
                    logging.info(save_str)
                    for single_ex in curr_out_val:
                        serialized = array_to_tf_example(single_ex)
                        writer.write(serialized)