def test_discriminator_only_sees_pool(self):
    """Verifies the discriminator is fed pooled tensors, not raw constants."""
    # The check below inspects `.op`, which doesn't work in eager mode.
    if tf.executing_eagerly():
        return

    def constant_generator(_):
        # The generator output is a `Const` op, which is what the checking
        # discriminator below looks for.
        return tf.constant(0.0)

    def tensor_pool_fn(_):
        # The "pool" ignores its input and hands back random (non-constant)
        # tensors instead.
        return (tf.random.uniform([]), tf.random.uniform([]))

    def checker_dis_fn(inputs, _):
        """Asserts `inputs` is not produced by a `Const` op, then echoes it."""
        self.assertFalse(inputs.op.type == 'Const')
        return inputs

    base_model = tfgan.gan_model(
        constant_generator,
        discriminator_model,
        real_data=tf.zeros([]),
        generator_inputs=tf.random.normal([]))
    # Swap in the checking discriminator only for the loss computation.
    checked_model = base_model._replace(discriminator_fn=checker_dis_fn)
    tfgan.gan_loss(checked_model, tensor_pool_fn=tensor_pool_fn)
def define_model(self, images_x, images_y):
    """Builds a TF-GAN model that maps between images_x and images_y.

    The mapping direction is selected by `self._swap_inputs`: when it is
    truthy the generator consumes `images_y` and `images_x` is used as the
    real data; otherwise the generator consumes `images_x` and `images_y` is
    the real data.

    Args:
      images_x: A 4D float `Tensor` of NHWC format. Images in set X.
      images_y: A 4D float `Tensor` of NHWC format. Images in set Y.

    Returns:
      The model object returned by `tfgan.gan_model`.
    """
    source_images, target_images = (
        (images_y, images_x) if self._swap_inputs else (images_x, images_y))

    # Image summaries are currently not added:
    # tfgan.eval.add_cyclegan_image_summaries(gan_model)
    return tfgan.gan_model(
        generator_fn=_shadowdata_generator_model,
        discriminator_fn=_shadowdata_discriminator_model,
        generator_inputs=source_images,
        real_data=target_images)
def test_no_shape_check(self):
    """Checks that `check_shapes=False` lets `gan_model` skip shape checks.

    The dummy networks return non-Tensor values. Building the model with
    `check_shapes=True` is expected to raise an `AttributeError`, while
    building it with `check_shapes=False` must not raise.
    """
    if tf.executing_eagerly():
        # None of the usual utilities work in eager.
        return

    def dummy_generator_model(_):
        return (None, None)

    def dummy_discriminator_model(data, conditioning):  # pylint: disable=unused-argument
        return 1

    # `assertRaisesRegexp` is a deprecated alias that was removed in
    # Python 3.12; `assertRaisesRegex` is the supported spelling.
    with self.assertRaisesRegex(AttributeError, 'object has no attribute'):
        tfgan.gan_model(
            dummy_generator_model,
            dummy_discriminator_model,
            real_data=tf.zeros([1, 2]),
            generator_inputs=tf.zeros([1]),
            check_shapes=True)

    # The same construction must succeed once shape checking is disabled.
    tfgan.gan_model(
        dummy_generator_model,
        dummy_discriminator_model,
        real_data=tf.zeros([1, 2]),
        generator_inputs=tf.zeros([1]),
        check_shapes=False)
def test_doesnt_crash_when_in_nested_scope(self):
    """Builds a gradient-penalty loss both inside and outside a variable scope."""
    # None of the usual utilities work in eager.
    if tf.executing_eagerly():
        return

    penalty_kwargs = dict(gradient_penalty_weight=1.0)

    with tf.compat.v1.variable_scope('outer_scope'):
        real = tf.zeros([1, 2])
        noise = tf.random.normal([1, 2])
        gan_model = tfgan.gan_model(
            generator_model,
            discriminator_model,
            real_data=real,
            generator_inputs=noise)

        # Loss construction should work inside the scope...
        tfgan.gan_loss(gan_model, **penalty_kwargs)

    # ...and also once the scope has been exited.
    tfgan.gan_loss(gan_model, **penalty_kwargs)
def build_model(stage_id, batch_size, real_images, **kwargs):
    """Builds the progressive GAN graph for one training stage.

    Args:
      stage_id: An integer of training stage index.
      batch_size: Number of training images in each minibatch.
      real_images: A 4D `Tensor` of NHWC format.
      **kwargs: A dictionary of
          'start_height': An integer of start image height.
          'start_width': An integer of start image width.
          'scale_base': An integer of resolution multiplier.
          'num_resolutions': An integer of number of progressive resolutions.
          'stable_stage_num_images': An integer of number of training images
              in the stable stage.
          'transition_stage_num_images': An integer of number of training
              images in the transition stage.
          'total_num_images': An integer of total number of training images.
          'kernel_size': Convolution kernel size.
          'colors': Number of image channels.
          'to_rgb_use_tanh_activation': Whether to apply tanh activation when
              output rgb.
          'fmap_base': Base number of filters.
          'fmap_decay': Decay of number of filters.
          'fmap_max': Max number of filters.
          'latent_vector_size': An integer of latent vector size.
          'gradient_penalty_weight': A float of gradient penalty weight for
              wasserstein loss.
          'gradient_penalty_target': A float of gradient norm target for
              wasserstein loss.
          'real_score_penalty_weight': A float of additional penalty to keep
              the scores from drifting too far from zero.
          'adam_beta1': A float of Adam optimizer beta1.
          'adam_beta2': A float of Adam optimizer beta2.
          'generator_learning_rate': A float of generator learning rate.
          'discriminator_learning_rate': A float of discriminator learning
              rate.

    Returns:
      An internal `Model` object that wraps all information about the model as
      plain attributes: `stage_id`, `batch_size`, `resolution_schedule`,
      `num_images`, `num_blocks`, `current_image_id`, `progress`,
      `num_filters_fn`, `generator_fn`, `discriminator_fn`, `gan_model`,
      `gan_loss`, `gan_train_ops`, `optimizer_var_list`, `generator_ema` and
      `generator_vars_to_restore`.
    """
    conv_kernel_size = kwargs['kernel_size']
    num_channels = kwargs['colors']

    resolution_schedule = make_resolution_schedule(**kwargs)
    num_blocks, num_images = get_stage_info(stage_id, **kwargs)

    # The global step is advanced by `batch_size`, so it counts images rather
    # than minibatches.
    current_image_id = tf.compat.v1.train.get_or_create_global_step()
    current_image_id_inc_op = current_image_id.assign_add(batch_size)
    tf.compat.v1.summary.scalar('current_image_id', current_image_id)

    progress = networks.compute_progress(
        current_image_id,
        kwargs['stable_stage_num_images'],
        kwargs['transition_stage_num_images'],
        num_blocks)
    tf.compat.v1.summary.scalar('progress', progress)

    real_images = networks.blend_images(
        real_images, progress, resolution_schedule, num_blocks=num_blocks)

    def _num_filters_fn(block_id):
        """Computes number of filters of block `block_id`."""
        return networks.num_filters(
            block_id,
            kwargs['fmap_base'],
            kwargs['fmap_decay'],
            kwargs['fmap_max'])

    def _generator_fn(z):
        """Builds generator network."""
        if kwargs['to_rgb_use_tanh_activation']:
            to_rgb_activation = tf.tanh
        else:
            to_rgb_activation = None
        return networks.generator(
            z,
            progress,
            _num_filters_fn,
            resolution_schedule,
            num_blocks=num_blocks,
            kernel_size=conv_kernel_size,
            colors=num_channels,
            to_rgb_activation=to_rgb_activation)

    def _discriminator_fn(x):
        """Builds discriminator network."""
        return networks.discriminator(
            x,
            progress,
            _num_filters_fn,
            resolution_schedule,
            num_blocks=num_blocks,
            kernel_size=conv_kernel_size)

    # ---- Model. Only the first element returned by each network builder is
    # handed to tfgan.gan_model.
    latent_vectors = make_latent_vectors(batch_size, **kwargs)
    gan_model = tfgan.gan_model(
        generator_fn=lambda z: _generator_fn(z)[0],
        discriminator_fn=lambda x, unused_z: _discriminator_fn(x)[0],
        real_data=real_images,
        generator_inputs=latent_vectors)

    # ---- Loss.
    gan_loss = define_loss(gan_model, **kwargs)

    # ---- Train ops; the global-step increment is replaced by the
    # image-count increment defined above.
    train_ops, optimizer_var_list = define_train_ops(
        gan_model, gan_loss, **kwargs)
    train_ops = train_ops._replace(
        global_step_inc_op=current_image_id_inc_op)

    # ---- Generator smoothing.
    generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
    train_ops, generator_vars_to_restore = add_generator_smoothing_ops(
        generator_ema, gan_model, train_ops)

    class Model(object):
        pass

    model = Model()
    model_attributes = (
        ('stage_id', stage_id),
        ('batch_size', batch_size),
        ('resolution_schedule', resolution_schedule),
        ('num_images', num_images),
        ('num_blocks', num_blocks),
        ('current_image_id', current_image_id),
        ('progress', progress),
        ('num_filters_fn', _num_filters_fn),
        ('generator_fn', _generator_fn),
        ('discriminator_fn', _discriminator_fn),
        ('gan_model', gan_model),
        ('gan_loss', gan_loss),
        ('gan_train_ops', train_ops),
        ('optimizer_var_list', optimizer_var_list),
        ('generator_ema', generator_ema),
        ('generator_vars_to_restore', generator_vars_to_restore),
    )
    for attr_name, attr_value in model_attributes:
        setattr(model, attr_name, attr_value)
    return model
def __init__(self, stage_id, batch_size, config):
    """Build graph stage from config dictionary.

    Stage_id and batch_size change during training so they are kept separate
    from the global config. This function is also called by 'load_from_path()'.

    The graph is built in this order: input data and labels, the training
    progress value, the TF-GAN model, the base GAN loss plus auxiliary
    classification losses (and, for Fourier-type data helpers, G&L
    consistency losses), the train ops, generator EMA smoothing, and a
    placeholder-driven path for generating samples. The pieces are then
    stored as attributes on `self`, together with a new `tf.train.Saver` and
    a new `tf.Session`.

    Note that `config` is modified in place: the keys 'resolution_schedule',
    'num_blocks', 'num_images', 'progress' and 'num_tokens' are added to it.
    When 'train_time_limit' is set, the computed stage times are also printed
    to stdout.

    Args:
      stage_id: (int) Build generator/discriminator with this many stages.
      batch_size: (int) Build graph with fixed batch size.
      config: (dict) All the global state.
    """
    # The data helper class is looked up by config['data_type'].
    data_helper = data_helpers.registry[config['data_type']](config)
    real_images, real_one_hot_labels = data_helper.provide_data(batch_size)

    # gen_one_hot_labels = real_one_hot_labels
    gen_one_hot_labels = data_helper.provide_one_hot_labels(batch_size)
    # Size of dimension 1 of the real one-hot labels.
    num_tokens = real_one_hot_labels.shape[1].value

    # The global step is advanced by `batch_size`, so it counts images rather
    # than minibatches.
    current_image_id = tf.train.get_or_create_global_step()
    current_image_id_inc_op = current_image_id.assign_add(batch_size)
    tf.summary.scalar('current_image_id', current_image_id)

    # Non-trainable variable; it is only created and summarized here, and is
    # exposed as `self.train_time`.
    train_time = tf.Variable(0., dtype=tf.float32, trainable=False)
    tf.summary.scalar('train_time', train_time)

    resolution_schedule = train_util.make_resolution_schedule(**config)
    num_blocks, num_images = train_util.get_stage_info(stage_id, **config)

    num_stages = (2 * config['num_resolutions']) - 1
    if config['train_time_limit'] is not None:
        # Relative stage durations grow geometrically by
        # 'train_time_stage_multiplier', are rescaled so they sum to
        # 'train_time_limit', and are then turned into cumulative times.
        stage_times = np.zeros(num_stages, dtype='float32')
        stage_times[0] = 1.
        for i in range(1, num_stages):
            stage_times[i] = (stage_times[i - 1] *
                              config['train_time_stage_multiplier'])
        stage_times *= config['train_time_limit'] / np.sum(stage_times)
        stage_times = np.cumsum(stage_times)
        print('Stage times:')
        for t in stage_times:
            print('\t{}'.format(t))

    if config['train_progressive']:
        # Progress is driven by elapsed training time when a time limit is
        # configured, and by the image counter otherwise.
        if config['train_time_limit'] is not None:
            progress = networks.compute_progress_from_time(
                train_time, config['num_resolutions'], num_blocks, stage_times)
        else:
            progress = networks.compute_progress(
                current_image_id, config['stable_stage_num_images'],
                config['transition_stage_num_images'], num_blocks)
    else:
        progress = num_blocks - 1.  # Maximum value, must be float.
        # Without progressive training, num_images is replaced by the sum of
        # the per-stage image counts over all stage ids.
        num_images = 0
        for stage_id_idx in train_util.get_stage_ids(**config):
            _, n = train_util.get_stage_info(stage_id_idx, **config)
            num_images += n

    # Add to config
    # (this mutates the caller's dict; the network functions below receive
    # these values through **config).
    config['resolution_schedule'] = resolution_schedule
    config['num_blocks'] = num_blocks
    config['num_images'] = num_images
    config['progress'] = progress
    config['num_tokens'] = num_tokens
    tf.summary.scalar('progress', progress)

    real_images = networks.blend_images(
        real_images, progress, resolution_schedule, num_blocks=num_blocks)

    ########## Define model.
    noises = train_util.make_latent_vectors(batch_size, **config)

    # Get network functions and wrap with hparams
    g_fn = lambda x: net_fns.g_fn_registry[config['g_fn']](x, **config)
    d_fn = lambda x: net_fns.d_fn_registry[config['d_fn']](x, **config)

    # Extra lambda functions to conform to tfgan.gan_model interface
    # (only the first element returned by each network function is used).
    gan_model = tfgan.gan_model(
        generator_fn=lambda inputs: g_fn(inputs)[0],
        discriminator_fn=lambda images, unused_cond: d_fn(images)[0],
        real_data=real_images,
        generator_inputs=(noises, gen_one_hot_labels))

    ########## Define loss.
    gan_loss = train_util.define_loss(gan_model, **config)

    ########## Auxiliary loss functions
    def _compute_ac_loss(images, target_one_hot_labels):
        """Auxiliary classification loss of the discriminator on `images`.

        Re-applies the discriminator inside its existing variable scope with
        `reuse=True` and returns the mean softmax cross-entropy between its
        'classification_logits' end point and the target labels. Gradients
        are stopped through the labels.
        """
        with tf.variable_scope(gan_model.discriminator_scope, reuse=True):
            _, end_points = d_fn(images)
        return tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=tf.stop_gradient(target_one_hot_labels),
                logits=end_points['classification_logits']))

    def _compute_gl_consistency_loss(data):
        """G&L consistency loss.

        Converts `data` to STFTs, then to waves and back to STFTs, and returns
        the mean squared difference between the two STFT magnitudes after
        dividing by a per-example scale (at least 1.0).
        """
        sh = data_helper.specgrams_helper
        is_mel = isinstance(data_helper, data_helpers.DataMelHelper)
        if is_mel:
            stfts = sh.melspecgrams_to_stfts(data)
        else:
            stfts = sh.specgrams_to_stfts(data)
        waves = sh.stfts_to_waves(stfts)
        new_stfts = sh.waves_to_stfts(waves)
        # Magnitude loss
        mag = tf.abs(stfts)
        new_mag = tf.abs(new_stfts)
        # Normalize loss to max
        # (the max is taken over axes 1 and 2; the scale is floored at 1.0).
        get_max = lambda x: tf.reduce_max(x, axis=(1, 2), keepdims=True)
        mag_max = get_max(mag)
        new_mag_max = get_max(new_mag)
        mag_scale = tf.maximum(1.0, tf.maximum(mag_max, new_mag_max))
        mag_diff = (mag - new_mag) / mag_scale
        mag_loss = tf.reduce_mean(tf.square(mag_diff))
        return mag_loss

    with tf.name_scope('losses'):
        # Loss weights
        gen_ac_loss_weight = config['generator_ac_loss_weight']
        dis_ac_loss_weight = config['discriminator_ac_loss_weight']
        gen_gl_consistency_loss_weight = config[
            'gen_gl_consistency_loss_weight']

        # AC losses.
        fake_ac_loss = _compute_ac_loss(gan_model.generated_data,
                                        gen_one_hot_labels)
        real_ac_loss = _compute_ac_loss(gan_model.real_data,
                                        real_one_hot_labels)

        # GL losses.
        is_fourier = isinstance(data_helper, (data_helpers.DataSTFTHelper,
                                              data_helpers.DataSTFTNoIFreqHelper,
                                              data_helpers.DataMelHelper))
        # Wave helpers are always treated as non-Fourier, even if the check
        # above matched (presumably via subclassing -- TODO confirm).
        if isinstance(data_helper, data_helpers.DataWaveHelper):
            is_fourier = False

        # fake_gl_loss / real_gl_loss only exist when is_fourier is true;
        # every later use is guarded by the same flag.
        if is_fourier:
            fake_gl_loss = _compute_gl_consistency_loss(
                gan_model.generated_data)
            real_gl_loss = _compute_gl_consistency_loss(
                gan_model.real_data)

        # Total losses.
        wx_fake_ac_loss = gen_ac_loss_weight * fake_ac_loss
        wx_real_ac_loss = dis_ac_loss_weight * real_ac_loss
        wx_fake_gl_loss = 0.0
        # The weighted G&L loss is only applied for Fourier data, with a
        # positive weight, and at the last stage.
        if (is_fourier and gen_gl_consistency_loss_weight > 0 and
                stage_id == train_util.get_total_num_stages(**config) - 1):
            wx_fake_gl_loss = fake_gl_loss * gen_gl_consistency_loss_weight
        # Update the loss functions
        gan_loss = gan_loss._replace(
            generator_loss=(gan_loss.generator_loss + wx_fake_ac_loss +
                            wx_fake_gl_loss),
            discriminator_loss=(gan_loss.discriminator_loss + wx_real_ac_loss))

        tf.summary.scalar('fake_ac_loss', fake_ac_loss)
        tf.summary.scalar('real_ac_loss', real_ac_loss)
        tf.summary.scalar('wx_fake_ac_loss', wx_fake_ac_loss)
        tf.summary.scalar('wx_real_ac_loss', wx_real_ac_loss)
        tf.summary.scalar('total_gen_loss', gan_loss.generator_loss)
        tf.summary.scalar('total_dis_loss', gan_loss.discriminator_loss)

        if is_fourier:
            tf.summary.scalar('fake_gl_loss', fake_gl_loss)
            tf.summary.scalar('real_gl_loss', real_gl_loss)
            tf.summary.scalar('wx_fake_gl_loss', wx_fake_gl_loss)

    ########## Define train ops.
    gan_train_ops, optimizer_var_list = train_util.define_train_ops(
        gan_model, gan_loss, **config)
    # Use the image-count increment defined above as the global step op.
    gan_train_ops = gan_train_ops._replace(
        global_step_inc_op=current_image_id_inc_op)

    ########## Generator smoothing.
    generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
    gan_train_ops, generator_vars_to_restore = \
        train_util.add_generator_smoothing_ops(generator_ema, gan_model,
                                               gan_train_ops)
    # Scope that reuses the generator variables through a custom getter
    # built from the EMA object.
    load_scope = tf.variable_scope(
        gan_model.generator_scope,
        reuse=True,
        custom_getter=train_util.make_var_scope_custom_getter_for_ema(
            generator_ema))

    ########## Separate path for generating samples with a placeholder (ph)
    # Mapping of pitches to one-hot labels
    # (each pitch maps to its index in the sorted list of pitches).
    pitch_counts = data_helper.get_pitch_counts()
    pitch_to_label_dict = {}
    for i, pitch in enumerate(sorted(pitch_counts.keys())):
        pitch_to_label_dict[pitch] = i

    # (label_ph, noise_ph) -> fake_wave_ph
    labels_ph = tf.placeholder(tf.int32, [batch_size])
    noises_ph = tf.placeholder(tf.float32, [batch_size,
                                            config['latent_vector_size']])
    num_pitches = len(pitch_counts)
    one_hot_labels_ph = tf.one_hot(labels_ph, num_pitches)
    with load_scope:
        fake_data_ph, _ = g_fn((noises_ph, one_hot_labels_ph))
        fake_waves_ph = data_helper.data_to_waves(fake_data_ph)

    if config['train_time_limit'] is not None:
        # Cumulative time value computed above for this stage.
        stage_train_time_limit = stage_times[stage_id]
        # config['train_time_limit'] * \
        # (float(stage_id+1) / ((2*config['num_resolutions'])-1))
    else:
        stage_train_time_limit = None

    ########## Add variables as properties
    self.stage_id = stage_id
    self.batch_size = batch_size
    self.config = config
    self.data_helper = data_helper
    self.resolution_schedule = resolution_schedule
    self.num_images = num_images
    self.num_blocks = num_blocks
    self.current_image_id = current_image_id
    self.progress = progress
    self.generator_fn = g_fn
    self.discriminator_fn = d_fn
    self.gan_model = gan_model
    self.fake_ac_loss = fake_ac_loss
    self.real_ac_loss = real_ac_loss
    self.gan_loss = gan_loss
    self.gan_train_ops = gan_train_ops
    self.optimizer_var_list = optimizer_var_list
    self.generator_ema = generator_ema
    self.generator_vars_to_restore = generator_vars_to_restore
    self.real_images = real_images
    self.real_one_hot_labels = real_one_hot_labels
    self.load_scope = load_scope
    self.pitch_counts = pitch_counts
    self.pitch_to_label_dict = pitch_to_label_dict
    self.labels_ph = labels_ph
    self.noises_ph = noises_ph
    self.fake_waves_ph = fake_waves_ph
    # A new saver and session are created for every instance.
    self.saver = tf.train.Saver()
    self.sess = tf.Session()
    self.train_time = train_time
    self.stage_train_time_limit = stage_train_time_limit
def train(hparams):
    """Trains an MNIST GAN.

    Builds the input pipeline, the GAN model selected by `hparams.gan_type`,
    its loss and train ops, and then runs the alternating training loop. When
    `hparams.max_number_of_steps` is 0 the graph is still constructed but the
    training loop is skipped (used for graph construction tests).

    Args:
      hparams: An HParams instance containing the hyperparameters for
        training. The attributes read here are `train_log_dir`, `batch_size`,
        `gan_type` (one of 'unconditional', 'conditional' or 'infogan'),
        `noise_dims`, `grid_size` and `max_number_of_steps`.

    Raises:
      ValueError: If `hparams.gan_type` is not one of the supported values.
    """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # Force all input processing onto CPU in order to reserve the GPU for
    # the forward inference and back-propagation.
    with tf.name_scope('inputs'), tf.device('/cpu:0'):
        images, one_hot_labels = data_provider.provide_data(
            'train', hparams.batch_size, num_parallel_calls=4)

    # Define the GANModel tuple. Optionally, condition the GAN on the label or
    # use an InfoGAN to learn a latent representation.
    if hparams.gan_type == 'unconditional':
        gan_model = tfgan.gan_model(
            generator_fn=networks.unconditional_generator,
            discriminator_fn=networks.unconditional_discriminator,
            real_data=images,
            generator_inputs=tf.random.normal(
                [hparams.batch_size, hparams.noise_dims]))
    elif hparams.gan_type == 'conditional':
        noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
        gan_model = tfgan.gan_model(
            generator_fn=networks.conditional_generator,
            discriminator_fn=networks.conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    elif hparams.gan_type == 'infogan':
        cat_dim, cont_dim = 10, 2
        generator_fn = functools.partial(
            networks.infogan_generator, categorical_dim=cat_dim)
        discriminator_fn = functools.partial(
            networks.infogan_discriminator,
            categorical_dim=cat_dim,
            continuous_dim=cont_dim)
        unstructured_inputs, structured_inputs = util.get_infogan_noise(
            hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
        gan_model = tfgan.infogan_model(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            real_data=images,
            unstructured_generator_inputs=unstructured_inputs,
            structured_generator_inputs=structured_inputs)
    else:
        # Without this branch an unknown type would fall through and fail
        # below with an unhelpful UnboundLocalError on `gan_model`.
        raise ValueError(
            'Unsupported gan_type: {!r}. Expected one of `unconditional`, '
            '`conditional` or `infogan`.'.format(hparams.gan_type))
    tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

    # Get the GANLoss tuple. You can pass a custom function, use one of the
    # already-implemented losses from the losses library, or use the defaults.
    with tf.name_scope('loss'):
        if hparams.gan_type == 'infogan':
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.modified_generator_loss,
                discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
                mutual_information_penalty_weight=1.0,
                add_summaries=True)
        else:
            gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
        tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = _learning_rate(hparams.gan_type)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests). The status op is created before the
    # early return so that it is part of the graph in either case.
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.train.get_or_create_global_step())
    ], name='status_message')
    if hparams.max_number_of_steps == 0:
        return
    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        logdir=hparams.train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks(),
        save_checkpoint_secs=60)
def create_callable_gan_model():
    """Returns a GAN model built from `Generator()` and `Discriminator()`.

    The real data is a `[1, 2]` tensor of zeros and the generator input is
    `[1, 2]` random normal noise.
    """
    generator = Generator()
    discriminator = Discriminator()
    real_data = tf.zeros([1, 2])
    noise = tf.random.normal([1, 2])
    return tfgan.gan_model(
        generator,
        discriminator,
        real_data=real_data,
        generator_inputs=noise)
def create_gan_model():
    """Returns a GAN model built from `generator_model`/`discriminator_model`.

    The real data is a `[1, 2]` tensor of zeros and the generator input is
    `[1, 2]` random normal noise.
    """
    real_data = tf.zeros([1, 2])
    noise = tf.random.normal([1, 2])
    return tfgan.gan_model(
        generator_model,
        discriminator_model,
        real_data=real_data,
        generator_inputs=noise)
def train(hparams):
    """Trains a CIFAR10 GAN.

    The graph is always constructed; the training loop itself is skipped when
    `hparams.max_number_of_steps` is 0 (used for graph construction tests).

    Args:
      hparams: An HParams instance containing the hyperparameters for
        training. The attributes read here are `train_log_dir`,
        `ps_replicas`, `batch_size`, `max_number_of_steps`, `master` and
        `task`; `hparams` is also passed on to `_get_optimizers`.
    """
    log_dir = hparams.train_log_dir
    if not tf.io.gfile.exists(log_dir):
        tf.io.gfile.makedirs(log_dir)

    device_setter = tf.compat.v1.train.replica_device_setter(
        hparams.ps_replicas)
    with tf.device(device_setter):
        # Keep all input processing on the CPU so the GPU is reserved for the
        # forward inference and back-propagation.
        with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
            images, _ = data_provider.provide_data(
                'train', hparams.batch_size, num_parallel_calls=4)

        # GANModel tuple.
        noise = tf.random.normal([hparams.batch_size, 64])
        gan_model = tfgan.gan_model(
            networks.generator,
            networks.discriminator,
            real_data=images,
            generator_inputs=noise)
        tfgan.eval.add_gan_model_image_summaries(gan_model)

        # GANLoss tuple, using the selected GAN loss functions.
        with tf.compat.v1.name_scope('loss'):
            gan_loss = tfgan.gan_loss(
                gan_model, gradient_penalty_weight=1.0, add_summaries=True)

        # GANTrain ops, using the custom optimizers and optional
        # discriminator weight clipping.
        with tf.compat.v1.name_scope('train'):
            generator_optimizer, discriminator_optimizer = _get_optimizers(
                hparams)
            train_ops = tfgan.gan_train_ops(
                gan_model,
                gan_loss,
                generator_optimizer=generator_optimizer,
                discriminator_optimizer=discriminator_optimizer,
                summarize_gradients=True)

        # The status op is created before the early return so that it is part
        # of the graph even when no training steps are requested.
        global_step = tf.compat.v1.train.get_or_create_global_step()
        status_message = tf.strings.join(
            ['Starting train step: ', tf.as_string(global_step)],
            name='status_message')
        if hparams.max_number_of_steps == 0:
            return

        # Run the alternating training loop.
        training_hooks = [
            tf.estimator.StopAtStepHook(
                num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10),
        ]
        tfgan.gan_train(
            train_ops,
            hooks=training_hooks,
            logdir=log_dir,
            master=hparams.master,
            is_chief=hparams.task == 0)