def test_sync_replicas(self, create_gan_model_fn, create_global_step):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  num_trainable_vars = len(variables_lib.get_trainable_variables())

  if create_global_step:
    gstep = variable_scope.get_variable(
        'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
    ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = train.gan_train_ops(
      model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
  self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
  # No new trainable variables should have been added.
  self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars)

  # Sync hooks should be populated in the GANTrainOps.
  self.assertLen(train_ops.train_hooks, 2)
  for hook in train_ops.train_hooks:
    self.assertIsInstance(
        hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
  sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
  self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

  g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
  d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

  # Check that update op is run properly.
  global_step = training_util.get_or_create_global_step()
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    variables.local_variables_initializer().run()

    g_opt.chief_init_op.run()
    d_opt.chief_init_op.run()

    gstep_before = global_step.eval()

    # Start required queue runner for SyncReplicasOptimizer.
    coord = coordinator.Coordinator()
    g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
    d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

    g_sync_init_op.run()
    d_sync_init_op.run()

    train_ops.generator_train_op.eval()
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, global_step.eval())

    train_ops.discriminator_train_op.eval()
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, global_step.eval())

    coord.request_stop()
    coord.join(g_threads + d_threads)
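# `get_sync_optimizer` is defined elsewhere in this test module. A minimal
# sketch of what it plausibly looks like, assuming
# `from tensorflow.python.training import gradient_descent`; the wrapped
# optimizer and the argument values here are illustrative assumptions:
def get_sync_optimizer():
  return sync_replicas_optimizer.SyncReplicasOptimizer(
      gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
      replicas_to_aggregate=1)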
def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator."""
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in inspect.getargspec(generator_fn).args:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
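# Illustrative call (toy tensors; assumes TF 1.x graph mode). At prediction
# time only the generator subgraph is rebuilt, reusing the training-time
# variable scope. `_toy_generator` is a hypothetical stand-in:
import tensorflow as tf

def _toy_generator(inputs, targets):
  # Ignore `targets` and return a same-shape transformation of `inputs`.
  del targets
  return tf.identity(inputs)

images = tf.zeros([4, 8, 8, 3])
target_domains = tf.one_hot([0, 1, 2, 0], depth=3)
predicted_model = _make_prediction_gan_model(
    images, target_domains, _toy_generator, 'Generator')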
def _make_prediction_gan_model(generator_inputs, generator_fn,
                               generator_scope):
  """Make a `GANModel` from just the generator."""
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in inspect.getargspec(generator_fn).args:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
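# Why the `mode` check above matters: a generator can declare a `mode`
# argument to toggle train/inference behavior (e.g. batch norm), and this
# helper binds it to PREDICT automatically. A hypothetical example, assuming
# TF 1.x:
import tensorflow as tf

def generator_with_mode(generator_inputs, mode):
  is_training = (mode == model_fn_lib.ModeKeys.TRAIN)
  net = tf.layers.dense(generator_inputs, 64, activation=tf.nn.relu)
  # Batch norm falls back to moving statistics when `mode` is PREDICT.
  net = tf.layers.batch_normalization(net, training=is_training)
  return tf.layers.dense(net, 28 * 28, activation=tf.tanh)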
def combine_adversarial_loss(main_loss,
                             adversarial_loss,
                             weight_factor=None,
                             gradient_ratio=None,
                             gradient_ratio_epsilon=1e-6,
                             variables=None,
                             scalar_summaries=True,
                             gradient_summaries=True,
                             scope=None):
  """Utility to combine main and adversarial losses.

  This utility combines the main and adversarial losses in one of two ways.
  1) Fixed coefficient on adversarial loss. Use `weight_factor` in this case.
  2) Fixed ratio of gradients. Use `gradient_ratio` in this case. This is often
    used to make sure both losses affect weights roughly equally, as in
    https://arxiv.org/pdf/1705.05823.

  One can optionally also visualize the scalar and gradient behavior of the
  losses.

  Args:
    main_loss: A floating scalar Tensor indicating the main loss.
    adversarial_loss: A floating scalar Tensor indicating the adversarial loss.
    weight_factor: If not `None`, the coefficient by which to multiply the
      adversarial loss. Exactly one of this and `gradient_ratio` must be
      non-None.
    gradient_ratio: If not `None`, the ratio of the magnitude of the gradients.
      Specifically,
        gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss)
      Exactly one of this and `weight_factor` must be non-None.
    gradient_ratio_epsilon: An epsilon to add to the adversarial loss
      coefficient denominator, to avoid division-by-zero.
    variables: List of variables to calculate gradients with respect to. If not
      present, defaults to all trainable variables.
    scalar_summaries: Create scalar summaries of losses.
    gradient_summaries: Create gradient summaries of losses.
    scope: Optional name scope.

  Returns:
    A floating scalar Tensor indicating the desired combined loss.

  Raises:
    ValueError: Malformed input.
  """
  _validate_args([main_loss, adversarial_loss], weight_factor, gradient_ratio)
  if variables is None:
    variables = contrib_variables_lib.get_trainable_variables()

  with ops.name_scope(
      scope, 'adversarial_loss', values=[main_loss, adversarial_loss]):
    # Compute gradients if we will need them.
    if gradient_summaries or gradient_ratio is not None:
      main_loss_grad_mag = _numerically_stable_global_norm(
          gradients_impl.gradients(main_loss, variables))
      adv_loss_grad_mag = _numerically_stable_global_norm(
          gradients_impl.gradients(adversarial_loss, variables))

    # Add summaries, if applicable.
    if scalar_summaries:
      summary.scalar('main_loss', main_loss)
      summary.scalar('adversarial_loss', adversarial_loss)
    if gradient_summaries:
      summary.scalar('main_loss_gradients', main_loss_grad_mag)
      summary.scalar('adversarial_loss_gradients', adv_loss_grad_mag)

    # Combine losses in the appropriate way.
    # If `weight_factor` is always `0`, avoid computing the adversarial loss
    # tensor entirely.
    if _used_weight((weight_factor, gradient_ratio)) == 0:
      final_loss = main_loss
    elif weight_factor is not None:
      final_loss = (main_loss +
                    array_ops.stop_gradient(weight_factor) * adversarial_loss)
    elif gradient_ratio is not None:
      grad_mag_ratio = main_loss_grad_mag / (
          adv_loss_grad_mag + gradient_ratio_epsilon)
      adv_coeff = grad_mag_ratio / gradient_ratio
      summary.scalar('adversarial_coefficient', adv_coeff)
      final_loss = (main_loss +
                    array_ops.stop_gradient(adv_coeff) * adversarial_loss)

  return final_loss
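# Illustrative usage sketch (assumes TF 1.x graph mode; the tiny model below
# exists only so that gradients are defined):
import tensorflow as tf

w = tf.get_variable('w', initializer=1.0)
main_loss = tf.square(w - 2.0)
adv_loss = tf.abs(w)

# 1) Fixed coefficient: total = main_loss + 0.1 * adv_loss.
loss_fixed = combine_adversarial_loss(main_loss, adv_loss, weight_factor=0.1)

# 2) Fixed gradient ratio: scale the adversarial term so that
#    grad_mag(main_loss) is ~4x grad_mag(adv_loss).
loss_ratio = combine_adversarial_loss(main_loss, adv_loss, gradient_ratio=4.0)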
def get_estimator_spec(self, real_features, real_class_labels, mode):
  with tf.variable_scope(self.D_scope):
    is_training = mode in ('train', 'infer')
    real_score, real_logits = self.discriminator(real_features, is_training)
    tf.summary.scalar('accuracy', accuracy(real_class_labels, real_logits))

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=real_logits)

  with tf.variable_scope('Latent'):
    noise = tf.random_normal(
        [tf.shape(real_class_labels)[0], self.nlatent],
        dtype=tf.float32,
        name='Z')

  with tf.name_scope('fake_class_labels'):
    shape = tf.shape(real_class_labels)
    random = tf.random_uniform(shape=[shape[0], shape[1]], maxval=1)
    fake_class_labels = tf.cast(random < 0.096, dtype=tf.float32)

  with tf.variable_scope(self.G_scope):
    fake_features = self.generator(noise, fake_class_labels)

  with tf.variable_scope(self.D_scope, reuse=True):
    fake_score, fake_logits = self.discriminator(fake_features, is_training)

  with ops.name_scope('losses'):
    loss_tuple = gan_loss(
        discriminator_fn=self.discriminator,
        discriminator_scope=self.D_scope,
        real_features=real_features,
        fake_features=fake_features,
        disc_real_score=real_score,
        disc_fake_score=fake_score,
        disc_real_logits=real_logits,
        disc_fake_logits=fake_logits,
        real_class_labels=real_class_labels,
        fake_class_labels=fake_class_labels)
    total_loss = loss_tuple.discriminator_loss + loss_tuple.generator_loss

  generator_variables = variables_lib.get_trainable_variables(self.G_scope)
  discriminator_variables = variables_lib.get_trainable_variables(self.D_scope)

  G_train_op = tf.train.AdamOptimizer(
      self.learning_rate, self.beta1, self.beta2,
      name='generator_optimizer').minimize(
          loss_tuple.generator_loss, var_list=generator_variables)
  D_train_op = tf.train.AdamOptimizer(
      self.learning_rate, self.beta1, self.beta2,
      name='discriminator_optimizer').minimize(
          loss_tuple.discriminator_loss, var_list=discriminator_variables)

  train_hook = PGTrainHook(
      G_train_op, D_train_op, self.alpha, self.res, self.stablize_increment,
      self.fade_increment, self.res_increment, self.reset_alpha)

  eval_metric_ops = get_eval_metric_ops(real_class_labels, real_logits)

  return tf.estimator.EstimatorSpec(
      loss=total_loss,
      mode=mode,
      train_op=self.global_step_inc,
      training_hooks=[train_hook],
      eval_metric_ops=eval_metric_ops)
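# Illustrative wiring (the `PGGAN` class name below is hypothetical): a
# tf.estimator `model_fn` typically just delegates to `get_estimator_spec`.
def model_fn(features, labels, mode, params):
  gan = PGGAN(**params)  # hypothetical constructor holding the method above
  return gan.get_estimator_spec(features, labels, mode)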
def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Returns InfoGAN model outputs and variables.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: A python lambda that takes a list of Tensors as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
      `logits` are in the range [-inf, inf], and `distribution_list` is a list
      of Tensorflow distributions representing the predicted noise distribution
      of the ith structured noise.
    real_data: A Tensor representing the real data.
    unstructured_generator_inputs: A list of Tensors to the generator. These
      tensors represent the unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors to the generator. These
      tensors must have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    ValueError: If the discriminator output is malformed.
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    unstructured_generator_inputs = _convert_tensor_or_l_or_d(
        unstructured_generator_inputs)
    structured_generator_inputs = _convert_tensor_or_l_or_d(
        structured_generator_inputs)
    generator_inputs = (
        unstructured_generator_inputs + structured_generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as disc_scope:
    dis_gen_outputs, predicted_distributions = discriminator_fn(
        generated_data, generator_inputs)
  _validate_distributions(predicted_distributions, structured_generator_inputs)
  with variable_scope.variable_scope(disc_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

  if not generated_data.get_shape().is_compatible_with(real_data.get_shape()):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (generated_data.get_shape(), real_data.get_shape()))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(disc_scope)

  return namedtuples.InfoGANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      dis_real_outputs,
      dis_gen_outputs,
      discriminator_variables,
      disc_scope,
      lambda x, y: discriminator_fn(x, y)[0],  # conform to non-InfoGAN API
      structured_generator_inputs,
      predicted_distributions,
      discriminator_fn)
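# Minimal usage sketch (assumes TF 1.x; the toy networks below are stand-ins
# for real models). The discriminator must return (logits, [distributions]),
# with one distribution per structured input:
import tensorflow as tf

def _infogan_generator(inputs):
  # `inputs` is the list of unstructured + structured tensors.
  net = tf.concat(inputs, axis=1)
  return tf.layers.dense(net, 2)

def _infogan_discriminator(data, unused_conditioning):
  net = tf.layers.dense(data, 10)
  logits = tf.layers.dense(net, 1)
  # Predict a distribution over the single structured code.
  loc = tf.layers.dense(net, 1)
  q_c = tf.distributions.Normal(loc=loc, scale=tf.ones_like(loc))
  return logits, [q_c]

real_data = tf.zeros([16, 2])
unstructured = [tf.random_normal([16, 62])]
structured = [tf.random_normal([16, 1])]

infogan = infogan_model(
    _infogan_generator, _infogan_discriminator, real_data, unstructured,
    structured)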
def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Returns StarGAN model outputs and variables.

  See https://arxiv.org/abs/1711.09020 for more details.

  Args:
    generator_fn: A python lambda that takes `inputs` and `targets` as inputs
      and returns `generated_data` as the transformed version of `input` based
      on the `target`. `input` has shape (n, h, w, c), `targets` has shape
      (n, num_domains), and `generated_data` has the same shape as `input`.
    discriminator_fn: A python lambda that takes `inputs` and `num_domains` as
      inputs and returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` represents the source (real/generated) prediction by
      the discriminator, and `domain_prediction` represents the domain
      prediction/classification by the discriminator. `source_prediction` has
      shape (n,) and `domain_prediction` has shape (n, num_domains).
    input_data: Tensor or a list of tensors of shape (n, h, w, c) representing
      the real input images.
    input_data_domain_label: Tensor or a list of tensors of shape
      (batch_size, num_domains) representing the domain label associated with
      the real images.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    A StarGANModel namedtuple containing the tensors needed to compute the
    losses.

  Raises:
    ValueError: If the shape of `input_data_domain_label` is not rank 2 or
      fully defined in every dimension.
  """
  # Convert to tensor.
  input_data = _convert_tensor_or_l_or_d(input_data)
  input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label)

  # Convert list of tensors to a single tensor if applicable.
  if isinstance(input_data, (list, tuple)):
    input_data = array_ops.concat(
        [ops.convert_to_tensor(x) for x in input_data], 0)
  if isinstance(input_data_domain_label, (list, tuple)):
    input_data_domain_label = array_ops.concat(
        [ops.convert_to_tensor(x) for x in input_data_domain_label], 0)

  # Get batch_size, num_domains from the labels.
  input_data_domain_label.shape.assert_has_rank(2)
  input_data_domain_label.shape.assert_is_fully_defined()
  batch_size, num_domains = input_data_domain_label.shape.as_list()

  # Transform input_data to random target domains.
  with variable_scope.variable_scope(generator_scope) as generator_scope:
    generated_data_domain_target = _generate_stargan_random_domain_target(
        batch_size, num_domains)
    generated_data = generator_fn(input_data, generated_data_domain_target)

  # Transform generated_data back to the original input_data domain.
  with variable_scope.variable_scope(generator_scope, reuse=True):
    reconstructed_data = generator_fn(generated_data, input_data_domain_label)

  # Predict source and domain for the generated_data using the discriminator.
  with variable_scope.variable_scope(
      discriminator_scope) as discriminator_scope:
    disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
        generated_data, num_domains)

  # Predict source and domain for the input_data using the discriminator.
  with variable_scope.variable_scope(discriminator_scope, reuse=True):
    disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
        input_data, num_domains)

  # Collect trainable variables from the neural networks.
  generator_variables = variables_lib.get_trainable_variables(generator_scope)
  discriminator_variables = variables_lib.get_trainable_variables(
      discriminator_scope)

  # Create the StarGANModel namedtuple.
  return namedtuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=input_data_domain_label,
      generated_data=generated_data,
      generated_data_domain_target=generated_data_domain_target,
      reconstructed_data=reconstructed_data,
      discriminator_input_data_source_predication=disc_input_data_source_pred,
      discriminator_generated_data_source_predication=disc_gen_data_source_pred,
      discriminator_input_data_domain_predication=disc_input_data_domain_pred,
      discriminator_generated_data_domain_predication=disc_gen_data_domain_pred,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=discriminator_variables,
      discriminator_scope=discriminator_scope,
      discriminator_fn=discriminator_fn)
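# Minimal usage sketch (assumes TF 1.x; toy networks with hypothetical sizes).
# The generator must map (image, target domain) to a same-shape image, and
# the discriminator must return (source_prediction, domain_prediction):
import tensorflow as tf

def _stargan_generator(inputs, targets):
  # Broadcast the target domain over the image and regress a same-shape image.
  h, w, c = inputs.shape.as_list()[1:4]
  tiled = tf.tile(targets[:, None, None, :], [1, h, w, 1])
  net = tf.concat([inputs, tiled], axis=3)
  return tf.layers.conv2d(net, c, 3, padding='same')

def _stargan_discriminator(inputs, num_domains):
  net = tf.layers.flatten(inputs)
  source = tf.squeeze(tf.layers.dense(net, 1), axis=1)  # shape (n,)
  domain = tf.layers.dense(net, num_domains)            # shape (n, num_domains)
  return source, domain

images = tf.zeros([4, 8, 8, 3])
labels = tf.one_hot([0, 1, 2, 0], depth=3)

stargan = stargan_model(_stargan_generator, _stargan_discriminator, images,
                        labels)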
def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns GAN model outputs and variables.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that generator produces Tensors that are the
      same shape as real data. Otherwise, skip this check.

  Returns:
    A GANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    discriminator_gen_outputs = discriminator_fn(generated_data,
                                                 generator_inputs)
  with variable_scope.variable_scope(dis_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      discriminator_real_outputs,
      discriminator_gen_outputs,
      discriminator_variables,
      dis_scope,
      discriminator_fn)
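# Minimal usage sketch (assumes TF 1.x; toy networks for illustration only):
import tensorflow as tf

def _gan_generator(noise):
  return tf.layers.dense(noise, 2)

def _gan_discriminator(data, unused_generator_inputs):
  return tf.layers.dense(tf.layers.dense(data, 10), 1)

noise = tf.random_normal([16, 8])
real_data = tf.zeros([16, 2])

model = gan_model(_gan_generator, _gan_discriminator, real_data, noise)
# `model.generator_variables` / `model.discriminator_variables` now hold the
# trainable variables collected from each scope.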
def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns an ACGANModel containing all the pieces needed for ACGAN training.

  The `acgan_model` is the same as the `gan_model` with the only difference
  being that the discriminator additionally outputs logits to classify the
  input (real or generated). Therefore, an explicit field holding
  one_hot_labels is necessary, as well as a discriminator_fn that outputs a
  2-tuple holding the logits for real/fake and classification.

  See https://arxiv.org/abs/1610.09585 for more details.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a tuple consisting of two Tensors:
      (1) real/fake logits in the range [-inf, inf]
      (2) classification logits in the range [-inf, inf]
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    one_hot_labels: A Tensor holding one-hot-labels for the batch. Needed by
      acgan_loss.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that generator produces Tensors that are the
      same shape as real data. Otherwise, skip this check.

  Returns:
    An ACGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    TypeError: If the discriminator does not output a tuple consisting of
      (discrimination logits, classification logits).
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    (discriminator_gen_outputs, discriminator_gen_classification_logits
    ) = _validate_acgan_discriminator_outputs(
        discriminator_fn(generated_data, generator_inputs))
  with variable_scope.variable_scope(dis_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    (discriminator_real_outputs, discriminator_real_classification_logits
    ) = _validate_acgan_discriminator_outputs(
        discriminator_fn(real_data, generator_inputs))
  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.ACGANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      discriminator_real_outputs,
      discriminator_gen_outputs,
      discriminator_variables,
      dis_scope,
      discriminator_fn,
      one_hot_labels,
      discriminator_real_classification_logits,
      discriminator_gen_classification_logits)
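# Minimal usage sketch (assumes TF 1.x; toy networks, 10 classes). The
# discriminator returns a 2-tuple: real/fake logits and class logits:
import tensorflow as tf

def _acgan_generator(noise):
  return tf.layers.dense(noise, 2)

def _acgan_discriminator(data, unused_conditioning):
  net = tf.layers.dense(data, 10)
  gan_logits = tf.layers.dense(net, 1)
  class_logits = tf.layers.dense(net, 10)
  return gan_logits, class_logits

noise = tf.random_normal([16, 8])
real_data = tf.zeros([16, 2])
one_hot_labels = tf.one_hot(tf.zeros([16], dtype=tf.int32), depth=10)

acgan = acgan_model(
    _acgan_generator, _acgan_discriminator, real_data, noise, one_hot_labels)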