def test_compute_progress(self):
  """Checks `networks.compute_progress` over stable, transition and capped ids.

  The same image ids are evaluated in both eager and graph mode; in graph mode
  a single int32 placeholder is fed each id in turn.
  """
  image_ids = [0, 3, 6, 7, 8, 10, 15, 29, 100]
  expected_progress = [0.0, 0.0, 0.0, 0.0, 0.125, 0.375, 1.0, 1.0, 1.0]
  schedule_kwargs = dict(
      stable_stage_num_images=7, transition_stage_num_images=8, num_blocks=2)

  if not tf.executing_eagerly():
    # Graph mode: build the op once and feed every id through a placeholder.
    image_id_ph = tf.compat.v1.placeholder(tf.int32, [])
    progress_op = networks.compute_progress(image_id_ph, **schedule_kwargs)
    with self.cached_session(use_gpu=True) as sess:
      observed = [
          sess.run(progress_op, feed_dict={image_id_ph: image_id})
          for image_id in image_ids
      ]
  else:
    # Eager mode: placeholders are unavailable, so compute per id.
    observed = []
    for image_id in image_ids:
      progress_op = networks.compute_progress(image_id, **schedule_kwargs)
      with self.cached_session(use_gpu=True) as sess:
        observed.append(sess.run(progress_op))

  self.assertArrayNear(observed, expected_progress, 1.0e-6)
def test_generator_grad_norm_progress(self):
  """Checks when each generator block starts receiving gradients.

  With 2 stable-stage images, 3 transition-stage images and 3 blocks:
  block 1 has gradients from image 0, block 2 from image 3 (after one stable
  stage), and block 3 from image 8 (after two stable stages and one
  transition stage). Also sanity-checks the generator output shape.
  """
  if tf.executing_eagerly():
    # tf.placeholder() is not compatible with eager execution.
    return
  stable_stage_num_images = 2
  transition_stage_num_images = 3
  current_image_id_ph = tf.compat.v1.placeholder(tf.int32, [])
  progress = networks.compute_progress(
      current_image_id_ph,
      stable_stage_num_images,
      transition_stage_num_images,
      num_blocks=3)
  z = tf.random.normal([2, 10], dtype=tf.float32)
  x, _ = networks.generator(
      z, progress, _num_filters_stub,
      networks.ResolutionSchedule(
          start_resolutions=(4, 4), scale_base=2, num_resolutions=3))
  fake_loss = tf.reduce_sum(input_tensor=tf.square(x))
  # One gradient norm per progressive block, ordered block 1, 2, 3.
  grad_norms = [
      _get_grad_norm(
          fake_loss,
          tf.compat.v1.trainable_variables(
              '.*/progressive_gan_block_{}/.*'.format(block_id)))
      for block_id in (1, 2, 3)
  ]

  with self.cached_session(use_gpu=True) as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # The placeholder is int32, so feed integer image ids; fractional values
    # would just be truncated by the feed's numpy conversion. Id 0 lies in the
    # first stable stage; id 9 is after block 3's gradients switch on (step 8).
    x1_np = sess.run(x, feed_dict={current_image_id_ph: 0})
    x2_np = sess.run(x, feed_dict={current_image_id_ph: 9})
    grad_norms_output = np.array([
        sess.run(grad_norms, feed_dict={current_image_id_ph: i})
        for i in range(15)  # total num of images
    ])

  self.assertEqual((2, 16, 16, 3), x1_np.shape)
  self.assertEqual((2, 16, 16, 3), x2_np.shape)
  # np.argmax over a boolean array returns the first True index, i.e. the
  # first image id at which that block's gradient norm is non-zero.
  # The gradient of block_1 is always on.
  self.assertEqual(
      np.argmax(grad_norms_output[:, 0] > 0), 0,
      'gradient norms {} for block 1 is not always on'.format(
          grad_norms_output[:, 0]))
  # The gradient of block_2 is on after 1 stable stage.
  self.assertEqual(
      np.argmax(grad_norms_output[:, 1] > 0), 3,
      'gradient norms {} for block 2 is not on at step 3'.format(
          grad_norms_output[:, 1]))
  # The gradient of block_3 is on after 2 stable stage + 1 transition stage.
  self.assertEqual(
      np.argmax(grad_norms_output[:, 2] > 0), 8,
      'gradient norms {} for block 3 is not on at step 8'.format(
          grad_norms_output[:, 2]))
def test_discriminator_grad_norm_progress(self):
  """Checks when each discriminator block starts receiving gradients.

  With 2 stable-stage images, 3 transition-stage images and 3 blocks:
  block 1 has gradients from image 0, block 2 from image 3 (after one stable
  stage), and block 3 from image 8 (after two stable stages and one
  transition stage).
  """
  if tf.executing_eagerly():
    # tf.placeholder() is not compatible with eager execution.
    return
  stable_stage_num_images = 2
  transition_stage_num_images = 3
  current_image_id_ph = tf.compat.v1.placeholder(tf.int32, [])
  progress = networks.compute_progress(
      current_image_id_ph,
      stable_stage_num_images,
      transition_stage_num_images,
      num_blocks=3)
  x = tf.random.normal([2, 16, 16, 3])
  logits, _ = networks.discriminator(
      x, progress, _num_filters_stub,
      networks.ResolutionSchedule(
          start_resolutions=(4, 4), scale_base=2, num_resolutions=3))
  fake_loss = tf.reduce_sum(input_tensor=tf.square(logits))
  # One gradient norm per progressive block, ordered block 1, 2, 3.
  grad_norms = [
      _get_grad_norm(
          fake_loss,
          tf.compat.v1.trainable_variables(
              '.*/progressive_gan_block_{}/.*'.format(block_id)))
      for block_id in (1, 2, 3)
  ]

  with self.cached_session(use_gpu=True) as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    grad_norms_output = np.array([
        sess.run(grad_norms, feed_dict={current_image_id_ph: i})
        for i in range(15)  # total num of images
    ])

  # np.argmax over a boolean array returns the first True index, i.e. the
  # first image id at which that block's gradient norm is non-zero.
  # The gradient of block_1 is always on.
  self.assertEqual(
      np.argmax(grad_norms_output[:, 0] > 0), 0,
      'gradient norms {} for block 1 is not always on'.format(
          grad_norms_output[:, 0]))
  # The gradient of block_2 is on after 1 stable stage.
  self.assertEqual(
      np.argmax(grad_norms_output[:, 1] > 0), 3,
      'gradient norms {} for block 2 is not on at step 3'.format(
          grad_norms_output[:, 1]))
  # The gradient of block_3 is on after 2 stable stage + 1 transition stage.
  self.assertEqual(
      np.argmax(grad_norms_output[:, 2] > 0), 8,
      'gradient norms {} for block 3 is not on at step 8'.format(
          grad_norms_output[:, 2]))
def build_model(stage_id, batch_size, real_images, **kwargs):
  """Builds progressive GAN model.

  Args:
    stage_id: An integer of training stage index.
    batch_size: Number of training images in each minibatch.
    real_images: A 4D `Tensor` of NHWC format.
    **kwargs: A dictionary of
        'start_height': An integer of start image height.
        'start_width': An integer of start image width.
        'scale_base': An integer of resolution multiplier.
        'num_resolutions': An integer of number of progressive resolutions.
        'stable_stage_num_images': An integer of number of training images in
          the stable stage.
        'transition_stage_num_images': An integer of number of training images
          in the transition stage.
        'total_num_images': An integer of total number of training images.
        'kernel_size': Convolution kernel size.
        'colors': Number of image channels.
        'to_rgb_use_tanh_activation': Whether to apply tanh activation when
          output rgb.
        'fmap_base': Base number of filters.
        'fmap_decay': Decay of number of filters.
        'fmap_max': Max number of filters.
        'latent_vector_size': An integer of latent vector size.
        'gradient_penalty_weight': A float of gradient penalty weight for
          wasserstein loss.
        'gradient_penalty_target': A float of gradient norm target for
          wasserstein loss.
        'real_score_penalty_weight': A float of additional penalty to keep the
          scores from drifting too far from zero.
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    An internal object that wraps all information about the model.
  """
  kernel_size = kwargs['kernel_size']
  colors = kwargs['colors']
  resolution_schedule = make_resolution_schedule(**kwargs)

  num_blocks, num_images = get_stage_info(stage_id, **kwargs)

  # The global step is reused as the image counter: it advances by
  # `batch_size` (not by 1) per train step, via the op swapped into
  # `gan_train_ops` below.
  current_image_id = tf.compat.v1.train.get_or_create_global_step()
  current_image_id_inc_op = current_image_id.assign_add(batch_size)
  tf.compat.v1.summary.scalar('current_image_id', current_image_id)

  progress = networks.compute_progress(current_image_id,
                                       kwargs['stable_stage_num_images'],
                                       kwargs['transition_stage_num_images'],
                                       num_blocks)
  tf.compat.v1.summary.scalar('progress', progress)

  # Real images are blended with the same `progress` signal the networks use.
  real_images = networks.blend_images(
      real_images, progress, resolution_schedule, num_blocks=num_blocks)

  def _num_filters_fn(block_id):
    """Computes number of filters of block `block_id`."""
    return networks.num_filters(block_id, kwargs['fmap_base'],
                                kwargs['fmap_decay'], kwargs['fmap_max'])

  def _generator_fn(z):
    """Builds generator network."""
    to_rgb_act = tf.tanh if kwargs['to_rgb_use_tanh_activation'] else None
    return networks.generator(
        z,
        progress,
        _num_filters_fn,
        resolution_schedule,
        num_blocks=num_blocks,
        kernel_size=kernel_size,
        colors=colors,
        to_rgb_activation=to_rgb_act)

  def _discriminator_fn(x):
    """Builds discriminator network."""
    return networks.discriminator(
        x,
        progress,
        _num_filters_fn,
        resolution_schedule,
        num_blocks=num_blocks,
        kernel_size=kernel_size)

  ########## Define model.
  z = make_latent_vectors(batch_size, **kwargs)

  # The network builders return a tuple; only its first element (the output
  # tensor) is handed to TF-GAN.
  gan_model = tfgan.gan_model(
      generator_fn=lambda z: _generator_fn(z)[0],
      discriminator_fn=lambda x, unused_z: _discriminator_fn(x)[0],
      real_data=real_images,
      generator_inputs=z)

  ########## Define loss.
  gan_loss = define_loss(gan_model, **kwargs)

  ########## Define train ops.
  gan_train_ops, optimizer_var_list = define_train_ops(gan_model, gan_loss,
                                                       **kwargs)
  # Use the image-counter increment (+batch_size) as the global-step op.
  gan_train_ops = gan_train_ops._replace(
      global_step_inc_op=current_image_id_inc_op)

  ########## Generator smoothing.
  generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
  gan_train_ops, generator_vars_to_restore = add_generator_smoothing_ops(
      generator_ema, gan_model, gan_train_ops)

  class Model(object):
    pass

  model = Model()
  model.stage_id = stage_id
  model.batch_size = batch_size
  model.resolution_schedule = resolution_schedule
  model.num_images = num_images
  model.num_blocks = num_blocks
  model.current_image_id = current_image_id
  model.progress = progress
  model.num_filters_fn = _num_filters_fn
  model.generator_fn = _generator_fn
  model.discriminator_fn = _discriminator_fn
  model.gan_model = gan_model
  model.gan_loss = gan_loss
  model.gan_train_ops = gan_train_ops
  model.optimizer_var_list = optimizer_var_list
  model.generator_ema = generator_ema
  model.generator_vars_to_restore = generator_vars_to_restore
  return model