def test_blend_images_in_stable_stage(self):
  """Checks `networks.blend_images` when `progress` is 0.0.

  With two resolution blocks and a scale base of 2, the blended output is
  expected to equal a 2x downscale followed by a 2x upscale of the input.
  """
  images_np = np.random.normal(size=[2, 8, 8, 3])
  images = tf.constant(images_np, tf.float32)
  schedule = networks.ResolutionSchedule(scale_base=2, num_resolutions=2)
  blended = networks.blend_images(
      images,
      progress=tf.constant(0.0),
      resolution_schedule=schedule,
      num_blocks=2)
  with self.cached_session(use_gpu=True) as sess:
    # Reference: round-trip the input through a 2x down/up resampling.
    round_trip = layers.upscale(layers.downscale(images, 2), 2)
    actual_np = sess.run(blended)
    expected_np = sess.run(round_trip)
    self.assertNDArrayNear(actual_np, expected_np, 1.0e-6)
def build_model(stage_id, batch_size, real_images, **kwargs):
  """Builds progressive GAN model.

  Args:
    stage_id: An integer of training stage index.
    batch_size: Number of training images in each minibatch.
    real_images: A 4D `Tensor` of NHWC format.
    **kwargs: A dictionary of hyperparameters. Keys read directly by this
      function are 'kernel_size', 'colors', 'stable_stage_num_images',
      'transition_stage_num_images', 'to_rgb_use_tanh_activation',
      'fmap_base', 'fmap_decay' and 'fmap_max'. The whole dictionary is
      also forwarded to `make_resolution_schedule`, `get_stage_info`,
      `make_latent_vectors`, `define_loss` and `define_train_ops`, which
      may consume further keys. The expected keys are:
        'start_height': An integer of start image height.
        'start_width': An integer of start image width.
        'scale_base': An integer of resolution multiplier.
        'num_resolutions': An integer of number of progressive resolutions.
        'stable_stage_num_images': An integer of number of training images in
          the stable stage.
        'transition_stage_num_images': An integer of number of training images
          in the transition stage.
        'total_num_images': An integer of total number of training images.
        'kernel_size': Convolution kernel size.
        'colors': Number of image channels.
        'to_rgb_use_tanh_activation': Whether to apply tanh activation when
          output rgb.
        'fmap_base': Base number of filters.
        'fmap_decay': Decay of number of filters.
        'fmap_max': Max number of filters.
        'latent_vector_size': An integer of latent vector size.
        'gradient_penalty_weight': A float of gradient penalty weight for
          wasserstein loss.
        'gradient_penalty_target': A float of gradient norm target for
          wasserstein loss.
        'real_score_penalty_weight': A float of additional penalty to keep
          the scores from drifting too far from zero.
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    An internal object that wraps all information about the model. Its
    attributes are stage_id, batch_size, resolution_schedule, num_images,
    num_blocks, current_image_id, progress, num_filters_fn, generator_fn,
    discriminator_fn, gan_model, gan_loss, gan_train_ops,
    optimizer_var_list, generator_ema and generator_vars_to_restore.
  """
  kernel_size = kwargs['kernel_size']
  colors = kwargs['colors']
  resolution_schedule = make_resolution_schedule(**kwargs)

  num_blocks, num_images = get_stage_info(stage_id, **kwargs)

  # The global step is repurposed as a counter of training images seen so
  # far. It is advanced by `batch_size` per step (see the
  # `global_step_inc_op` replacement below), not by 1.
  current_image_id = tf.compat.v1.train.get_or_create_global_step()
  current_image_id_inc_op = current_image_id.assign_add(batch_size)
  tf.compat.v1.summary.scalar('current_image_id', current_image_id)

  progress = networks.compute_progress(
      current_image_id, kwargs['stable_stage_num_images'],
      kwargs['transition_stage_num_images'], num_blocks)
  tf.compat.v1.summary.scalar('progress', progress)

  # Blend real images according to `progress`. Presumably this keeps them
  # consistent with the resolution the networks currently produce; see
  # `networks.blend_images`.
  real_images = networks.blend_images(
      real_images, progress, resolution_schedule, num_blocks=num_blocks)

  def _num_filters_fn(block_id):
    """Computes number of filters of block `block_id`."""
    return networks.num_filters(block_id, kwargs['fmap_base'],
                                kwargs['fmap_decay'], kwargs['fmap_max'])

  def _generator_fn(z):
    """Builds generator network."""
    to_rgb_act = tf.tanh if kwargs['to_rgb_use_tanh_activation'] else None
    return networks.generator(
        z,
        progress,
        _num_filters_fn,
        resolution_schedule,
        num_blocks=num_blocks,
        kernel_size=kernel_size,
        colors=colors,
        to_rgb_activation=to_rgb_act)

  def _discriminator_fn(x):
    """Builds discriminator network."""
    return networks.discriminator(
        x,
        progress,
        _num_filters_fn,
        resolution_schedule,
        num_blocks=num_blocks,
        kernel_size=kernel_size)

  ########## Define model.
  z = make_latent_vectors(batch_size, **kwargs)
  # The network builders return a sequence. Only element [0] is handed to
  # TF-GAN; presumably the remaining elements are auxiliary outputs
  # (verify against `networks`).
  gan_model = tfgan.gan_model(
      generator_fn=lambda z: _generator_fn(z)[0],
      discriminator_fn=lambda x, unused_z: _discriminator_fn(x)[0],
      real_data=real_images,
      generator_inputs=z)

  ########## Define loss.
  gan_loss = define_loss(gan_model, **kwargs)

  ########## Define train ops.
  gan_train_ops, optimizer_var_list = define_train_ops(gan_model, gan_loss,
                                                       **kwargs)
  # Make the training step advance the image counter by `batch_size`
  # instead of incrementing the global step by one.
  gan_train_ops = gan_train_ops._replace(
      global_step_inc_op=current_image_id_inc_op)

  ########## Generator smoothing.
  # Use the compat.v1 namespace, consistent with the rest of this function.
  generator_ema = tf.compat.v1.train.ExponentialMovingAverage(decay=0.999)
  gan_train_ops, generator_vars_to_restore = add_generator_smoothing_ops(
      generator_ema, gan_model, gan_train_ops)

  class Model(object):
    pass

  model = Model()
  model.stage_id = stage_id
  model.batch_size = batch_size
  model.resolution_schedule = resolution_schedule
  model.num_images = num_images
  model.num_blocks = num_blocks
  model.current_image_id = current_image_id
  model.progress = progress
  model.num_filters_fn = _num_filters_fn
  model.generator_fn = _generator_fn
  model.discriminator_fn = _discriminator_fn
  model.gan_model = gan_model
  model.gan_loss = gan_loss
  model.gan_train_ops = gan_train_ops
  model.optimizer_var_list = optimizer_var_list
  model.generator_ema = generator_ema
  model.generator_vars_to_restore = generator_vars_to_restore
  return model