예제 #1
0
    def _test_sync_replicas_helper(self,
                                   create_gan_model_fn,
                                   create_global_step=False):
        """Checks `train.gan_train_ops` with sync-replica optimizers.

        Builds a GAN model and its loss, optionally registers a custom global
        step, and creates generator/discriminator train ops whose optimizers
        come from `get_sync_optimizer()`. It then verifies two things:
        (a) building the train ops adds no trainable variables;
        (b) running each train op once leaves the global step unchanged.

        Args:
          create_gan_model_fn: Zero-argument callable returning the GAN model.
          create_global_step: If True, create an int32 variable named
            'custom_gstep' and add it to the GLOBAL_STEP collection before
            the train ops are built.
        """
        model = create_gan_model_fn()
        loss = train.gan_loss(model)
        num_trainable_vars = len(variables_lib.get_trainable_variables())

        if create_global_step:
            gstep = variable_scope.get_variable('custom_gstep',
                                                dtype=dtypes.int32,
                                                initializer=0,
                                                trainable=False)
            ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

        g_opt = get_sync_optimizer()
        d_opt = get_sync_optimizer()
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        generator_optimizer=g_opt,
                                        discriminator_optimizer=d_opt)
        # assertIsInstance names the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only says "False is not true".
        self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
        # No new trainable variables should have been added.
        self.assertEqual(num_trainable_vars,
                         len(variables_lib.get_trainable_variables()))

        # Token-initialization ops for both sync optimizers (one token each).
        g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
        d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

        # Check that update op is run properly.
        global_step = training_util.get_or_create_global_step()
        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            variables.local_variables_initializer().run()

            g_opt.chief_init_op.run()
            d_opt.chief_init_op.run()

            gstep_before = global_step.eval()

            # Start required queue runner for SyncReplicasOptimizer.
            # NOTE(review): create_threads() is called without start=True. If
            # it defaults to not starting the threads (as tf.train.QueueRunner
            # does), these threads never run, and the train ops below rely
            # only on the tokens enqueued by the init ops. Confirm this is
            # intended.
            coord = coordinator.Coordinator()
            g_threads = g_opt.get_chief_queue_runner().create_threads(
                sess, coord)
            d_threads = d_opt.get_chief_queue_runner().create_threads(
                sess, coord)

            g_sync_init_op.run()
            d_sync_init_op.run()

            train_ops.generator_train_op.eval()
            # Check that global step wasn't incremented.
            self.assertEqual(gstep_before, global_step.eval())

            train_ops.discriminator_train_op.eval()
            # Check that global step wasn't incremented.
            self.assertEqual(gstep_before, global_step.eval())

            coord.request_stop()
            coord.join(g_threads + d_threads)
예제 #2
0
  def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    """GAN train ops built with sync-replica optimizers are wired correctly.

    Verifies that `train.gan_train_ops`, given two sync optimizers:
      * adds no trainable variables,
      * attaches exactly one `_SyncReplicasOptimizerHook` per optimizer,
      * leaves the global step unchanged when either train op runs once.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
      create_global_step: If True, register an int32 'custom_gstep' variable
        in the GLOBAL_STEP collection before the train ops are built.
    """
    gan_model = create_gan_model_fn()
    gan_losses = train.gan_loss(gan_model)
    trainable_count = len(variables_lib.get_trainable_variables())

    if create_global_step:
      ops.add_to_collection(
          ops.GraphKeys.GLOBAL_STEP,
          variable_scope.get_variable(
              'custom_gstep', dtype=dtypes.int32, initializer=0,
              trainable=False))

    gen_opt = get_sync_optimizer()
    dis_opt = get_sync_optimizer()
    sync_opts = (gen_opt, dis_opt)
    train_ops = train.gan_train_ops(
        gan_model,
        gan_losses,
        generator_optimizer=gen_opt,
        discriminator_optimizer=dis_opt)
    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
    # Building the train ops must not introduce trainable variables.
    self.assertLen(variables_lib.get_trainable_variables(), trainable_count)

    # One sync hook per optimizer should be attached to the GANTrainOps.
    hooks = train_ops.train_hooks
    self.assertLen(hooks, 2)
    for sync_hook in hooks:
      self.assertIsInstance(
          sync_hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    hooked_opts = frozenset(sync_hook._sync_optimizer for sync_hook in hooks)
    self.assertSetEqual(hooked_opts, frozenset(sync_opts))

    token_init_ops = [opt.get_init_tokens_op(num_tokens=1) for opt in sync_opts]

    # Check that update op is run properly.
    global_step = training_util.get_or_create_global_step()
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      for opt in sync_opts:
        opt.chief_init_op.run()

      step_at_start = global_step.eval()

      # Start required queue runner for SyncReplicasOptimizer.
      coord = coordinator.Coordinator()
      queue_threads = []
      for opt in sync_opts:
        queue_threads.extend(
            opt.get_chief_queue_runner().create_threads(sess, coord))

      for init_op in token_init_ops:
        init_op.run()

      # Neither the generator nor the discriminator step may advance the
      # global step.
      for train_op in (train_ops.generator_train_op,
                       train_ops.discriminator_train_op):
        train_op.eval()
        self.assertEqual(step_at_start, global_step.eval())

      coord.request_stop()
      coord.join(queue_threads)
예제 #3
0
def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
    """Make a `StarGANModel` from just the generator.

    Only the generator-side fields are populated. Every discriminator field,
    `reconstructed_data` and `input_data_domain_label` are None.

    Args:
      input_data: Generator input, converted with
        `tfgan_train._convert_tensor_or_l_or_d` before use.
      input_data_domain_label: Target domain label, converted the same way.
        It is passed to `generator_fn` and stored as
        `generated_data_domain_target`.
      generator_fn: Callable `(input_data, input_data_domain_label)` that
        returns the generated data. If it declares a `mode` parameter
        (positional or keyword-only), that parameter is bound to
        `ModeKeys.PREDICT`.
      generator_scope: Variable scope (or its name) the generator is built in.

    Returns:
      A `tfgan_tuples.StarGANModel`.
    """
    # If `generator_fn` has an argument `mode`, pass mode to it.
    # `getfullargspec` replaces `getargspec`, which is deprecated, raises on
    # keyword-only arguments and annotations, and was removed in Python 3.11.
    argspec = inspect.getfullargspec(generator_fn)
    if 'mode' in argspec.args or 'mode' in argspec.kwonlyargs:
        generator_fn = functools.partial(generator_fn,
                                         mode=model_fn_lib.ModeKeys.PREDICT)
    with variable_scope.variable_scope(generator_scope) as gen_scope:
        # pylint:disable=protected-access
        input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
        input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
            input_data_domain_label)
        # pylint:enable=protected-access
        generated_data = generator_fn(input_data, input_data_domain_label)
    generator_variables = variable_lib.get_trainable_variables(gen_scope)

    return tfgan_tuples.StarGANModel(
        input_data=input_data,
        input_data_domain_label=None,
        generated_data=generated_data,
        generated_data_domain_target=input_data_domain_label,
        reconstructed_data=None,
        discriminator_input_data_source_predication=None,
        discriminator_generated_data_source_predication=None,
        discriminator_input_data_domain_predication=None,
        discriminator_generated_data_domain_predication=None,
        generator_variables=generator_variables,
        # NOTE(review): this stores the caller-supplied `generator_scope`
        # rather than the entered `gen_scope` object. Confirm that is intended.
        generator_scope=generator_scope,
        generator_fn=generator_fn,
        discriminator_variables=None,
        discriminator_scope=None,
        discriminator_fn=None)
def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator.

  Only the generator-side fields are populated. Every discriminator field,
  `reconstructed_data` and `input_data_domain_label` are None.

  Args:
    input_data: Generator input, converted with
      `tfgan_train._convert_tensor_or_l_or_d` before use.
    input_data_domain_label: Target domain label, converted the same way. It
      is passed to `generator_fn` and stored as `generated_data_domain_target`.
    generator_fn: Callable `(input_data, input_data_domain_label)` that
      returns the generated data. If it declares a `mode` parameter
      (positional or keyword-only), that parameter is bound to
      `ModeKeys.PREDICT`.
    generator_scope: Variable scope (or its name) the generator is built in.

  Returns:
    A `tfgan_tuples.StarGANModel`.
  """
  # If `generator_fn` has an argument `mode`, pass mode to it.
  # `getfullargspec` replaces `getargspec`, which is deprecated, raises on
  # keyword-only arguments and annotations, and was removed in Python 3.11.
  argspec = inspect.getfullargspec(generator_fn)
  if 'mode' in argspec.args or 'mode' in argspec.kwonlyargs:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      # NOTE(review): this stores the caller-supplied `generator_scope`
      # rather than the entered `gen_scope` object. Confirm that is intended.
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
예제 #5
0
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Build a generator-only `GANModel` suitable for prediction.

  Args:
    generator_inputs: Generator input, converted with
      `tfgan_train._convert_tensor_or_l_or_d` before use.
    generator_fn: Callable mapping the converted inputs to generated data.
    generator_scope: Variable scope (or its name) the generator is built in.

  Returns:
    A `tfgan_tuples.GANModel` whose `real_data` and discriminator fields are
    all None.
  """
  with variable_scope.variable_scope(generator_scope) as scope:
    # pylint:disable=protected-access
    converted_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)
    # pylint:enable=protected-access
    outputs = generator_fn(converted_inputs)
  trainable = variable_lib.get_trainable_variables(scope)

  # Prediction never touches the discriminator or real data.
  no_discriminator = dict(
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
  return tfgan_tuples.GANModel(
      converted_inputs, outputs, trainable, scope, generator_fn,
      **no_discriminator)
def _make_prediction_gan_model(generator_inputs, generator_fn,
                               generator_scope):
    """Make a `GANModel` from just the generator.

    Builds the generator under `generator_scope`. `real_data` and every
    discriminator field of the returned tuple are None.

    Args:
      generator_inputs: Generator input, converted with
        `tfgan_train._convert_tensor_or_l_or_d` before use.
      generator_fn: Callable mapping the converted inputs to generated data.
      generator_scope: Variable scope (or its name) the generator is built in.

    Returns:
      A `tfgan_tuples.GANModel`.
    """
    with variable_scope.variable_scope(generator_scope) as gen_scope:
        # Bracket the protected call so the pragma covers the line that
        # actually accesses the protected member.
        # pylint:disable=protected-access
        generator_inputs = tfgan_train._convert_tensor_or_l_or_d(
            generator_inputs)
        # pylint:enable=protected-access
        generated_data = generator_fn(generator_inputs)
    generator_variables = variable_lib.get_trainable_variables(gen_scope)

    return tfgan_tuples.GANModel(generator_inputs,
                                 generated_data,
                                 generator_variables,
                                 gen_scope,
                                 generator_fn,
                                 real_data=None,
                                 discriminator_real_outputs=None,
                                 discriminator_gen_outputs=None,
                                 discriminator_variables=None,
                                 discriminator_scope=None,
                                 discriminator_fn=None)
예제 #7
0
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator.

  Args:
    generator_inputs: Generator input, converted with
      `tfgan_train._convert_tensor_or_l_or_d` before use.
    generator_fn: Callable mapping the converted inputs to generated data. If
      it declares a `mode` parameter (positional or keyword-only), that
      parameter is bound to `ModeKeys.PREDICT`.
    generator_scope: Variable scope (or its name) the generator is built in.

  Returns:
    A `tfgan_tuples.GANModel` whose `real_data` and discriminator fields are
    None.
  """
  # If `generator_fn` has an argument `mode`, pass mode to it.
  # `getfullargspec` replaces `getargspec`, which is deprecated, raises on
  # keyword-only arguments and annotations, and was removed in Python 3.11.
  argspec = inspect.getfullargspec(generator_fn)
  if 'mode' in argspec.args or 'mode' in argspec.kwonlyargs:
    generator_fn = functools.partial(generator_fn,
                                     mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
예제 #8
0
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator.

  Args:
    generator_inputs: Generator input, converted with
      `tfgan_train._convert_tensor_or_l_or_d` before use.
    generator_fn: Callable mapping the converted inputs to generated data. If
      it declares a `mode` parameter (positional or keyword-only), that
      parameter is bound to `ModeKeys.PREDICT`.
    generator_scope: Variable scope (or its name) the generator is built in.

  Returns:
    A `tfgan_tuples.GANModel` whose `real_data` and discriminator fields are
    None.
  """
  # If `generator_fn` has an argument `mode`, pass mode to it.
  # `getfullargspec` replaces `getargspec`, which is deprecated, raises on
  # keyword-only arguments and annotations, and was removed in Python 3.11.
  argspec = inspect.getfullargspec(generator_fn)
  if 'mode' in argspec.args or 'mode' in argspec.kwonlyargs:
    generator_fn = functools.partial(generator_fn,
                                     mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
예제 #9
0
def infogan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        # Real data and conditioning.
        real_data,
        unstructured_generator_inputs,
        structured_generator_inputs,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator'):
    """Build an InfoGAN model and collect its outputs, variables and scopes.

    See https://arxiv.org/abs/1606.03657 for more details.

    Args:
      generator_fn: Callable that takes the list of generator input Tensors
        and returns the generator output.
      discriminator_fn: Callable that takes (`real_data` or generated data,
        generator inputs) and returns a 2-tuple (logits, distribution_list).
        `logits` lie in [-inf, inf]. `distribution_list` holds TensorFlow
        distributions, where the i-th one predicts the i-th structured noise.
      real_data: A Tensor representing the real data.
      unstructured_generator_inputs: A list of Tensors for the generator,
        representing unstructured noise or conditioning.
      structured_generator_inputs: A list of Tensors for the generator that
        should have high mutual information with the recognizer.
      generator_scope: Optional generator variable scope. Useful for reusing
        a subgraph that has already been built.
      discriminator_scope: Optional discriminator variable scope. Useful for
        reusing a subgraph that has already been built.

    Returns:
      An InfoGANModel namedtuple.

    Raises:
      ValueError: If the generator output is not shape-compatible with
        `real_data`.
      ValueError: If the discriminator output is malformed.
    """
    # Generator: feed unstructured followed by structured inputs.
    with variable_scope.variable_scope(generator_scope) as gen_scope:
        noise_unstructured = _convert_tensor_or_l_or_d(
            unstructured_generator_inputs)
        noise_structured = _convert_tensor_or_l_or_d(
            structured_generator_inputs)
        gen_inputs = noise_unstructured + noise_structured
        fake_data = generator_fn(gen_inputs)

    # First discriminator pass (on generated data) creates its variables.
    with variable_scope.variable_scope(discriminator_scope) as disc_scope:
        fake_logits, predicted_distributions = discriminator_fn(
            fake_data, gen_inputs)
    _validate_distributions(predicted_distributions, noise_structured)

    # Second pass (on real data) reuses the variables created above.
    with variable_scope.variable_scope(disc_scope, reuse=True):
        real_tensor = ops.convert_to_tensor(real_data)
        real_logits, _ = discriminator_fn(real_tensor, gen_inputs)

    fake_shape = fake_data.get_shape()
    real_shape = real_tensor.get_shape()
    if not fake_shape.is_compatible_with(real_shape):
        raise ValueError(
            'Generator output shape (%s) must be the same shape as real data '
            '(%s).' % (fake_shape, real_shape))

    def _logits_only(data, inputs):
        # Expose the InfoGAN discriminator through the plain GAN API, which
        # expects just the logits.
        return discriminator_fn(data, inputs)[0]

    return namedtuples.InfoGANModel(
        gen_inputs,
        fake_data,
        variables_lib.get_trainable_variables(gen_scope),
        gen_scope,
        generator_fn,
        real_tensor,
        real_logits,
        fake_logits,
        variables_lib.get_trainable_variables(disc_scope),
        disc_scope,
        _logits_only,
        noise_structured,
        predicted_distributions)
예제 #10
0
def combine_adversarial_loss(main_loss,
                             adversarial_loss,
                             weight_factor=None,
                             gradient_ratio=None,
                             gradient_ratio_epsilon=1e-6,
                             variables=None,
                             scalar_summaries=True,
                             gradient_summaries=True,
                             scope=None):
  """Combine a main loss with an adversarial loss.

  One of two weighting schemes is applied:
    1) A fixed coefficient on the adversarial loss (`weight_factor`).
    2) A fixed ratio between gradient magnitudes (`gradient_ratio`). This is
       often used so both losses affect the weights roughly equally, as in
       https://arxiv.org/pdf/1705.05823.
  Scalar summaries of the losses and of their gradient magnitudes can be
  emitted optionally.

  Args:
    main_loss: A floating scalar Tensor, the main loss.
    adversarial_loss: A floating scalar Tensor, the adversarial loss.
    weight_factor: If not `None`, the multiplier applied to the adversarial
      loss. Exactly one of this and `gradient_ratio` must be non-None.
    gradient_ratio: If not `None`, the target ratio
        grad_mag(main_loss) / grad_mag(adversarial_loss).
      Exactly one of this and `weight_factor` must be non-None.
    gradient_ratio_epsilon: Added to the adversarial gradient magnitude in the
      coefficient's denominator to avoid division by zero.
    variables: Variables to differentiate with respect to. Defaults to all
      trainable variables.
    scalar_summaries: Whether to emit scalar summaries of both losses.
    gradient_summaries: Whether to emit summaries of both losses' gradient
      magnitudes.
    scope: Optional name scope.

  Returns:
    A floating scalar Tensor holding the combined loss.

  Raises:
    ValueError: Malformed input.
  """
  _validate_args([main_loss, adversarial_loss], weight_factor, gradient_ratio)
  grad_vars = (contrib_variables_lib.get_trainable_variables()
               if variables is None else variables)

  with ops.name_scope(scope, 'adversarial_loss',
                      values=[main_loss, adversarial_loss]):
    # Gradient magnitudes are needed only for the gradient summaries or for
    # the ratio-based coefficient.
    if gradient_summaries or gradient_ratio is not None:
      main_grad_norm, adv_grad_norm = [
          _numerically_stable_global_norm(
              gradients_impl.gradients(loss, grad_vars))
          for loss in (main_loss, adversarial_loss)]

    if scalar_summaries:
      summary.scalar('main_loss', main_loss)
      summary.scalar('adversarial_loss', adversarial_loss)
    if gradient_summaries:
      summary.scalar('main_loss_gradients', main_grad_norm)
      summary.scalar('adversarial_loss_gradients', adv_grad_norm)

    # A weight of exactly 0 drops the adversarial term from the graph
    # entirely, instead of multiplying it by zero.
    if _used_weight((weight_factor, gradient_ratio)) == 0:
      final_loss = main_loss
    else:
      if weight_factor is not None:
        coeff = weight_factor
      elif gradient_ratio is not None:
        norm_ratio = main_grad_norm / (adv_grad_norm + gradient_ratio_epsilon)
        coeff = norm_ratio / gradient_ratio
        summary.scalar('adversarial_coefficient', coeff)
      # The coefficient is treated as a constant w.r.t. backprop.
      final_loss = main_loss + array_ops.stop_gradient(coeff) * adversarial_loss

  return final_loss
예제 #11
0
def combine_adversarial_loss(main_loss,
                             adversarial_loss,
                             weight_factor=None,
                             gradient_ratio=None,
                             gradient_ratio_epsilon=1e-6,
                             variables=None,
                             scalar_summaries=True,
                             gradient_summaries=True,
                             scope=None):
    """Combine a main loss with an adversarial loss.

    One of two weighting schemes is applied:
      1) A fixed coefficient on the adversarial loss (`weight_factor`).
      2) A fixed ratio between gradient magnitudes (`gradient_ratio`). This
         is often used so both losses affect the weights roughly equally, as
         in https://arxiv.org/pdf/1705.05823.
    Scalar summaries of the losses and of their gradient magnitudes can be
    emitted optionally.

    Args:
      main_loss: A floating scalar Tensor, the main loss.
      adversarial_loss: A floating scalar Tensor, the adversarial loss.
      weight_factor: If not `None`, the multiplier applied to the adversarial
        loss. Exactly one of this and `gradient_ratio` must be non-None.
      gradient_ratio: If not `None`, the target ratio
          grad_mag(main_loss) / grad_mag(adversarial_loss).
        Exactly one of this and `weight_factor` must be non-None.
      gradient_ratio_epsilon: Added to the adversarial gradient magnitude in
        the coefficient's denominator to avoid division by zero.
      variables: Variables to differentiate with respect to. Defaults to all
        trainable variables.
      scalar_summaries: Whether to emit scalar summaries of both losses.
      gradient_summaries: Whether to emit summaries of both losses' gradient
        magnitudes.
      scope: Optional name scope.

    Returns:
      A floating scalar Tensor holding the combined loss.

    Raises:
      ValueError: Malformed input.
    """
    _validate_args([main_loss, adversarial_loss], weight_factor,
                   gradient_ratio)
    if variables is None:
        variables = contrib_variables_lib.get_trainable_variables()

    def _grad_norm(loss):
        # Numerically stable global norm of d(loss)/d(variables).
        return _numerically_stable_global_norm(
            gradients_impl.gradients(loss, variables))

    need_norms = gradient_summaries or gradient_ratio is not None

    with ops.name_scope(scope,
                        'adversarial_loss',
                        values=[main_loss, adversarial_loss]):
        if need_norms:
            main_norm = _grad_norm(main_loss)
            adv_norm = _grad_norm(adversarial_loss)

        if scalar_summaries:
            summary.scalar('main_loss', main_loss)
            summary.scalar('adversarial_loss', adversarial_loss)
        if gradient_summaries:
            summary.scalar('main_loss_gradients', main_norm)
            summary.scalar('adversarial_loss_gradients', adv_norm)

        # A weight of exactly 0 keeps the adversarial term out of the graph.
        chosen_weight = _used_weight((weight_factor, gradient_ratio))
        if chosen_weight == 0:
            final_loss = main_loss
        elif weight_factor is not None:
            scaled_adv = array_ops.stop_gradient(weight_factor) * adversarial_loss
            final_loss = main_loss + scaled_adv
        elif gradient_ratio is not None:
            adv_coeff = (main_norm /
                         (adv_norm + gradient_ratio_epsilon)) / gradient_ratio
            summary.scalar('adversarial_coefficient', adv_coeff)
            scaled_adv = array_ops.stop_gradient(adv_coeff) * adversarial_loss
            final_loss = main_loss + scaled_adv

    return final_loss
예제 #12
0
파일: model.py 프로젝트: jenkspt/tsa-pggan
    def get_estimator_spec(self, real_features, real_class_labels, mode):
        """Build the `tf.estimator.EstimatorSpec` for this GAN.

        In PREDICT mode only the discriminator runs on `real_features`, and
        its logits are returned as the predictions. In the other modes the
        method does the following:
          * samples latent noise and random fake class labels;
          * runs the generator to get fake features;
          * scores real and fake features with the discriminator;
          * computes the losses and builds separate Adam train ops for the
            generator and the discriminator.
        Those train ops are handed to a `PGTrainHook`, while the spec's own
        `train_op` is `self.global_step_inc`.

        Args:
          real_features: Real inputs scored by the discriminator.
          real_class_labels: Labels for `real_features`. It must be at least
            rank 2, because its first two dynamic dimensions size the sampled
            fake labels and its first dimension sizes the noise batch.
          mode: The estimator mode (compared against
            `tf.estimator.ModeKeys.PREDICT` and the strings 'train'/'infer').

        Returns:
          A `tf.estimator.EstimatorSpec`.
        """
        with tf.variable_scope(self.D_scope):
            # NOTE(review): tf.estimator.ModeKeys.PREDICT is the string
            # 'infer', so is_training is True in TRAIN and PREDICT modes and
            # False only in EVAL. Confirm that prediction is meant to run the
            # discriminator in training mode.
            is_training = mode in ('train', 'infer')
            real_score, real_logits = self.discriminator(
                real_features, is_training)
            # NOTE(review): this summary is also built in PREDICT mode, where
            # the labels may be None. Confirm `accuracy` tolerates that.
            tf.summary.scalar(
                'accuracy', accuracy(real_class_labels, real_logits))

        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=real_logits)

        with tf.variable_scope('Latent'):
            noise = tf.random_normal(
                [tf.shape(real_class_labels)[0], self.nlatent],
                dtype=tf.float32,
                name='Z')

            with tf.name_scope('fake_class_labels'):
                label_shape = tf.shape(real_class_labels)
                uniform_sample = tf.random_uniform(
                    shape=[label_shape[0], label_shape[1]], maxval=1)
                # Each fake label entry is 1.0 when its uniform [0, 1) sample
                # falls below 0.096, otherwise 0.0.
                fake_class_labels = tf.cast(uniform_sample < 0.096,
                                            dtype=tf.float32)

        with tf.variable_scope(self.G_scope):
            fake_features = self.generator(noise, fake_class_labels)

        # Reuse the discriminator variables created for the real features.
        with tf.variable_scope(self.D_scope, reuse=True):
            fake_score, fake_logits = self.discriminator(
                fake_features, is_training)

        with ops.name_scope('losses'):
            loss_tuple = gan_loss(
                discriminator_fn=self.discriminator,
                discriminator_scope=self.D_scope,
                real_features=real_features,
                fake_features=fake_features,
                disc_real_score=real_score,
                disc_fake_score=fake_score,
                disc_real_logits=real_logits,
                disc_fake_logits=fake_logits,
                real_class_labels=real_class_labels,
                fake_class_labels=fake_class_labels)

            total_loss = (loss_tuple.discriminator_loss +
                          loss_tuple.generator_loss)

        generator_variables = variables_lib.get_trainable_variables(
            self.G_scope)
        discriminator_variables = variables_lib.get_trainable_variables(
            self.D_scope)

        # Each optimizer updates only its own network's variables.
        G_train_op = tf.train.AdamOptimizer(
            self.learning_rate,
            self.beta1,
            self.beta2,
            name='generator_optimizer').minimize(
                loss_tuple.generator_loss,
                var_list=generator_variables)

        D_train_op = tf.train.AdamOptimizer(
            self.learning_rate,
            self.beta1,
            self.beta2,
            name='discriminator_optimizer').minimize(
                loss_tuple.discriminator_loss,
                var_list=discriminator_variables)

        # The spec's train_op only increments the global step. The G/D train
        # ops go to PGTrainHook, which presumably runs them (TODO confirm in
        # PGTrainHook).
        train_hook = PGTrainHook(
            G_train_op,
            D_train_op,
            self.alpha,
            self.res,
            self.stablize_increment,
            self.fade_increment,
            self.res_increment,
            self.reset_alpha)
        eval_metric_ops = get_eval_metric_ops(real_class_labels, real_logits)

        return tf.estimator.EstimatorSpec(
            loss=total_loss,
            mode=mode,
            train_op=self.global_step_inc,
            training_hooks=[train_hook],
            # Report the metric ops built just above instead of discarding
            # them.
            eval_metric_ops=eval_metric_ops)
        """
예제 #13
0
def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Build an InfoGAN model and collect its outputs, variables and scopes.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: Callable that takes the list of generator input Tensors and
      returns the generator output.
    discriminator_fn: Callable that takes (`real_data` or generated data,
      generator inputs) and returns a 2-tuple (logits, distribution_list).
      `logits` lie in [-inf, inf]. `distribution_list` holds TensorFlow
      distributions, where the i-th one predicts the i-th structured noise.
    real_data: A Tensor representing the real data.
    unstructured_generator_inputs: A list of Tensors for the generator,
      representing unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors for the generator that
      should have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Useful for reusing a
      subgraph that has already been built.
    discriminator_scope: Optional discriminator variable scope. Useful for
      reusing a subgraph that has already been built.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If the generator output is not shape-compatible with
      `real_data`.
    ValueError: If the discriminator output is malformed.
  """
  # Generator: feed unstructured followed by structured inputs.
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    noise_unstructured = _convert_tensor_or_l_or_d(
        unstructured_generator_inputs)
    noise_structured = _convert_tensor_or_l_or_d(structured_generator_inputs)
    gen_inputs = noise_unstructured + noise_structured
    fake_data = generator_fn(gen_inputs)

  # First discriminator pass (on generated data) creates its variables.
  with variable_scope.variable_scope(discriminator_scope) as disc_scope:
    fake_logits, predicted_distributions = discriminator_fn(
        fake_data, gen_inputs)
  _validate_distributions(predicted_distributions, noise_structured)

  # Second pass (on real data) reuses the variables created above.
  with variable_scope.variable_scope(disc_scope, reuse=True):
    real_tensor = ops.convert_to_tensor(real_data)
    real_logits, _ = discriminator_fn(real_tensor, gen_inputs)

  fake_shape = fake_data.get_shape()
  real_shape = real_tensor.get_shape()
  if not fake_shape.is_compatible_with(real_shape):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (fake_shape, real_shape))

  def _logits_only(data, inputs):
    # Expose the InfoGAN discriminator through the plain GAN API, which
    # expects just the logits.
    return discriminator_fn(data, inputs)[0]

  return namedtuples.InfoGANModel(
      gen_inputs,
      fake_data,
      variables_lib.get_trainable_variables(gen_scope),
      gen_scope,
      generator_fn,
      real_tensor,
      real_logits,
      fake_logits,
      variables_lib.get_trainable_variables(disc_scope),
      disc_scope,
      _logits_only,
      noise_structured,
      predicted_distributions,
      # The unwrapped discriminator, returning (logits, distributions).
      discriminator_fn)
예제 #14
0
def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
    """Builds a StarGAN model and returns its outputs and variables.

    See https://arxiv.org/abs/1711.09020 for more details.

    Args:
      generator_fn: A python lambda that takes `inputs` and `targets` as inputs
        and returns 'generated_data' as the transformed version of `input`
        based on the `target`. `input` has shape (n, h, w, c), `targets` has
        shape (n, num_domains), and `generated_data` has the same shape as
        `input`.
      discriminator_fn: A python lambda that takes `inputs` and `num_domains`
        as inputs and returns a tuple (`source_prediction`,
        `domain_prediction`). `source_prediction` represents the
        source (real/generated) prediction by the discriminator, and
        `domain_prediction` represents the domain prediction/classification
        by the discriminator. `source_prediction` has shape (n) and
        `domain_prediction` has shape (n, num_domains).
      input_data: Tensor or a list of tensors of shape (n, h, w, c)
        representing the real input images. A list/tuple is concatenated
        along axis 0.
      input_data_domain_label: Tensor or a list of tensors of shape
        (batch_size, num_domains) representing the domain label associated
        with the real images. A list/tuple is concatenated along axis 0.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful if
        you want to reuse a subgraph that has already been created.

    Returns:
      A StarGANModel namedtuple holding the tensors needed to compute the
      loss.

    Raises:
      ValueError: If `input_data_domain_label` is not rank 2 or its shape is
        not fully defined in every dimension.
    """

    def _as_single_tensor(value):
        # Lists/tuples of tensors are joined along axis 0; anything else is
        # passed through unchanged.
        if not isinstance(value, (list, tuple)):
            return value
        return array_ops.concat([ops.convert_to_tensor(t) for t in value], 0)

    # Convert both arguments first and merge afterwards, so ops are created
    # in the same order as before.
    images = _convert_tensor_or_l_or_d(input_data)
    labels = _convert_tensor_or_l_or_d(input_data_domain_label)
    images = _as_single_tensor(images)
    labels = _as_single_tensor(labels)

    # The (static) label shape determines batch_size and num_domains.
    label_shape = labels.shape
    label_shape.assert_has_rank(2)
    label_shape.assert_is_fully_defined()
    batch_size, num_domains = label_shape.as_list()

    # First generator pass: translate the real images to random target domains.
    with variable_scope.variable_scope(generator_scope) as gen_vs:
        random_targets = _generate_stargan_random_domain_target(
            batch_size, num_domains)
        fake_images = generator_fn(images, random_targets)

    # Second generator pass (shared weights): translate the fakes back to the
    # original domains.
    with variable_scope.variable_scope(gen_vs, reuse=True):
        cycled_images = generator_fn(fake_images, labels)

    # Discriminator on the generated images.
    with variable_scope.variable_scope(discriminator_scope) as disc_vs:
        fake_source_pred, fake_domain_pred = discriminator_fn(
            fake_images, num_domains)

    # Discriminator (shared weights) on the real images.
    with variable_scope.variable_scope(disc_vs, reuse=True):
        real_source_pred, real_domain_pred = discriminator_fn(
            images, num_domains)

    gen_vars = variables_lib.get_trainable_variables(gen_vs)
    disc_vars = variables_lib.get_trainable_variables(disc_vs)

    return namedtuples.StarGANModel(
        input_data=images,
        input_data_domain_label=labels,
        generated_data=fake_images,
        generated_data_domain_target=random_targets,
        reconstructed_data=cycled_images,
        discriminator_input_data_source_predication=real_source_pred,
        discriminator_generated_data_source_predication=fake_source_pred,
        discriminator_input_data_domain_predication=real_domain_pred,
        discriminator_generated_data_domain_predication=fake_domain_pred,
        generator_variables=gen_vars,
        generator_scope=gen_vs,
        generator_fn=generator_fn,
        discriminator_variables=disc_vars,
        discriminator_scope=disc_vs,
        discriminator_fn=discriminator_fn)
예제 #15
0
def acgan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        # Real data and conditioning.
        real_data,
        generator_inputs,
        one_hot_labels,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        check_shapes=True):
    """Returns an ACGANModel with all the pieces needed for ACGAN training.

    `acgan_model` mirrors `gan_model`, except that the discriminator also emits
    logits classifying its input (real or generated). The model therefore
    carries an explicit `one_hot_labels` field, and `discriminator_fn` must
    return a 2-tuple of (real/fake logits, classification logits).

    See https://arxiv.org/abs/1610.09585 for more details.

    Args:
      generator_fn: A python lambda that takes `generator_inputs` as inputs and
        returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated data`
        and `generator_inputs`. Outputs a tuple consisting of two Tensors:
          (1) real/fake logits in the range [-inf, inf]
          (2) classification logits in the range [-inf, inf]
      real_data: A Tensor representing the real data.
      generator_inputs: A Tensor or list of Tensors to the generator. In the
        vanilla GAN case, this might be a single noise Tensor. In the
        conditional GAN case, this might be the generator's conditioning.
      one_hot_labels: A Tensor holding one-hot labels for the batch. Needed by
        acgan_loss.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful if
        you want to reuse a subgraph that has already been created.
      check_shapes: If `True`, check that the generator produces Tensors that
        are the same shape as the real data. Otherwise, skip this check.

    Returns:
      An ACGANModel namedtuple.

    Raises:
      ValueError: If `check_shapes` is set and the generator output shape is
        not compatible with the shape of `real_data`.
      TypeError: If the discriminator does not output a tuple consisting of
        (discrimination logits, classification logits).
    """
    # Generator.
    with variable_scope.variable_scope(generator_scope) as gen_vs:
        gen_in = _convert_tensor_or_l_or_d(generator_inputs)
        fake = generator_fn(gen_in)

    # Discriminator on generated data; this call creates its variables.
    with variable_scope.variable_scope(discriminator_scope) as disc_vs:
        fake_out = _validate_acgan_discriminator_outputs(
            discriminator_fn(fake, gen_in))
    fake_logits, fake_class_logits = fake_out

    # Discriminator on real data, reusing the weights created above.
    with variable_scope.variable_scope(disc_vs, reuse=True):
        real = ops.convert_to_tensor(real_data)
        real_out = _validate_acgan_discriminator_outputs(
            discriminator_fn(real, gen_in))
    real_logits, real_class_logits = real_out

    if check_shapes and not fake.shape.is_compatible_with(real.shape):
        raise ValueError(
            'Generator output shape (%s) must be the same shape as real data '
            '(%s).' % (fake.shape, real.shape))

    # Trainable variables belonging to each sub-network.
    gen_vars = variables_lib.get_trainable_variables(gen_vs)
    disc_vars = variables_lib.get_trainable_variables(disc_vs)

    return namedtuples.ACGANModel(
        gen_in, fake, gen_vars, gen_vs, generator_fn,
        real, real_logits, fake_logits, disc_vars, disc_vs,
        discriminator_fn, one_hot_labels,
        real_class_logits, fake_class_logits)
예제 #16
0
def gan_model(
        # Lambdas defining models.
        generator_fn,
        discriminator_fn,
        # Real data and conditioning.
        real_data,
        generator_inputs,
        # Optional scopes.
        generator_scope='Generator',
        discriminator_scope='Discriminator',
        # Options.
        check_shapes=True):
    """Builds a GAN and returns its outputs and variables.

    Args:
      generator_fn: A python lambda that takes `generator_inputs` as inputs and
        returns the outputs of the GAN generator.
      discriminator_fn: A python lambda that takes `real_data`/`generated data`
        and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
      real_data: A Tensor representing the real data.
      generator_inputs: A Tensor or list of Tensors to the generator. In the
        vanilla GAN case, this might be a single noise Tensor. In the
        conditional GAN case, this might be the generator's conditioning.
      generator_scope: Optional generator variable scope. Useful if you want to
        reuse a subgraph that has already been created.
      discriminator_scope: Optional discriminator variable scope. Useful if
        you want to reuse a subgraph that has already been created.
      check_shapes: If `True`, check that the generator produces Tensors that
        are the same shape as the real data. Otherwise, skip this check.

    Returns:
      A GANModel namedtuple.

    Raises:
      ValueError: If `check_shapes` is set and the generator output shape is
        not compatible with the shape of `real_data`.
    """
    # Generator.
    with variable_scope.variable_scope(generator_scope) as gen_vs:
        gen_in = _convert_tensor_or_l_or_d(generator_inputs)
        fake = generator_fn(gen_in)

    # Discriminator on generated data; this call creates its variables.
    with variable_scope.variable_scope(discriminator_scope) as disc_vs:
        fake_out = discriminator_fn(fake, gen_in)

    # Discriminator on real data, reusing the weights created above.
    with variable_scope.variable_scope(disc_vs, reuse=True):
        real = ops.convert_to_tensor(real_data)
        real_out = discriminator_fn(real, gen_in)

    if check_shapes and not fake.shape.is_compatible_with(real.shape):
        raise ValueError(
            'Generator output shape (%s) must be the same shape as real data '
            '(%s).' % (fake.shape, real.shape))

    # Trainable variables belonging to each sub-network.
    gen_vars = variables_lib.get_trainable_variables(gen_vs)
    disc_vars = variables_lib.get_trainable_variables(disc_vs)

    return namedtuples.GANModel(
        gen_in, fake, gen_vars, gen_vs, generator_fn,
        real, real_out, fake_out, disc_vars, disc_vs,
        discriminator_fn)
예제 #17
0
def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns an ACGANModel with all the pieces needed for ACGAN training.

  `acgan_model` mirrors `gan_model`, except that the discriminator also emits
  logits classifying its input (real or generated). The model therefore
  carries an explicit `one_hot_labels` field, and `discriminator_fn` must
  return a 2-tuple of (real/fake logits, classification logits).

  See https://arxiv.org/abs/1610.09585 for more details.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a tuple consisting of two Tensors:
        (1) real/fake logits in the range [-inf, inf]
        (2) classification logits in the range [-inf, inf]
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    one_hot_labels: A Tensor holding one-hot labels for the batch. Needed by
      acgan_loss.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that are
      the same shape as the real data. Otherwise, skip this check.

  Returns:
    An ACGANModel namedtuple.

  Raises:
    ValueError: If `check_shapes` is set and the generator output shape is not
      compatible with the shape of `real_data`.
    TypeError: If the discriminator does not output a tuple consisting of
      (discrimination logits, classification logits).
  """
  # Generator.
  with variable_scope.variable_scope(generator_scope) as gen_vs:
    gen_in = _convert_tensor_or_l_or_d(generator_inputs)
    fake = generator_fn(gen_in)

  # Discriminator on generated data; this call creates its variables.
  with variable_scope.variable_scope(discriminator_scope) as disc_vs:
    fake_out = _validate_acgan_discriminator_outputs(
        discriminator_fn(fake, gen_in))
  fake_logits, fake_class_logits = fake_out

  # Discriminator on real data, reusing the weights created above.
  with variable_scope.variable_scope(disc_vs, reuse=True):
    real = ops.convert_to_tensor(real_data)
    real_out = _validate_acgan_discriminator_outputs(
        discriminator_fn(real, gen_in))
  real_logits, real_class_logits = real_out

  if check_shapes and not fake.shape.is_compatible_with(real.shape):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (fake.shape, real.shape))

  # Trainable variables belonging to each sub-network.
  gen_vars = variables_lib.get_trainable_variables(gen_vs)
  disc_vars = variables_lib.get_trainable_variables(disc_vs)

  return namedtuples.ACGANModel(
      gen_in, fake, gen_vars, gen_vs, generator_fn,
      real, real_logits, fake_logits, disc_vars, disc_vs,
      discriminator_fn, one_hot_labels,
      real_class_logits, fake_class_logits)
예제 #18
0
def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Builds a GAN and returns its outputs and variables.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that are
      the same shape as the real data. Otherwise, skip this check.

  Returns:
    A GANModel namedtuple.

  Raises:
    ValueError: If `check_shapes` is set and the generator output shape is not
      compatible with the shape of `real_data`.
  """
  # Generator.
  with variable_scope.variable_scope(generator_scope) as gen_vs:
    gen_in = _convert_tensor_or_l_or_d(generator_inputs)
    fake = generator_fn(gen_in)

  # Discriminator on generated data; this call creates its variables.
  with variable_scope.variable_scope(discriminator_scope) as disc_vs:
    fake_out = discriminator_fn(fake, gen_in)

  # Discriminator on real data, reusing the weights created above.
  with variable_scope.variable_scope(disc_vs, reuse=True):
    real = ops.convert_to_tensor(real_data)
    real_out = discriminator_fn(real, gen_in)

  if check_shapes and not fake.shape.is_compatible_with(real.shape):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (fake.shape, real.shape))

  # Trainable variables belonging to each sub-network.
  gen_vars = variables_lib.get_trainable_variables(gen_vs)
  disc_vars = variables_lib.get_trainable_variables(disc_vs)

  return namedtuples.GANModel(
      gen_in, fake, gen_vars, gen_vs, generator_fn,
      real, real_out, fake_out, disc_vars, disc_vs,
      discriminator_fn)
예제 #19
0
def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Builds a StarGAN model and returns its outputs and variables.

  See https://arxiv.org/abs/1711.09020 for more details.

  Args:
    generator_fn: A python lambda that takes `inputs` and `targets` as inputs
      and returns 'generated_data' as the transformed version of `input` based
      on the `target`. `input` has shape (n, h, w, c), `targets` has shape (n,
      num_domains), and `generated_data` has the same shape as `input`.
    discriminator_fn: A python lambda that takes `inputs` and `num_domains` as
      inputs and returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` represents the source (real/generated) prediction by
      the discriminator, and `domain_prediction` represents the domain
      prediction/classification by the discriminator. `source_prediction` has
      shape (n) and `domain_prediction` has shape (n, num_domains).
    input_data: Tensor or a list of tensors of shape (n, h, w, c) representing
      the real input images. A list/tuple is concatenated along axis 0.
    input_data_domain_label: Tensor or a list of tensors of shape (batch_size,
      num_domains) representing the domain label associated with the real
      images. A list/tuple is concatenated along axis 0.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    A StarGANModel namedtuple holding the tensors needed to compute the loss.

  Raises:
    ValueError: If `input_data_domain_label` is not rank 2 or its shape is not
      fully defined in every dimension.
  """

  def _as_single_tensor(value):
    # Lists/tuples of tensors are joined along axis 0; anything else is passed
    # through unchanged.
    if not isinstance(value, (list, tuple)):
      return value
    return array_ops.concat([ops.convert_to_tensor(t) for t in value], 0)

  # Convert both arguments first and merge afterwards, so ops are created in
  # the same order as before.
  images = _convert_tensor_or_l_or_d(input_data)
  labels = _convert_tensor_or_l_or_d(input_data_domain_label)
  images = _as_single_tensor(images)
  labels = _as_single_tensor(labels)

  # The (static) label shape determines batch_size and num_domains.
  label_shape = labels.shape
  label_shape.assert_has_rank(2)
  label_shape.assert_is_fully_defined()
  batch_size, num_domains = label_shape.as_list()

  # First generator pass: translate the real images to random target domains.
  with variable_scope.variable_scope(generator_scope) as gen_vs:
    random_targets = _generate_stargan_random_domain_target(
        batch_size, num_domains)
    fake_images = generator_fn(images, random_targets)

  # Second generator pass (shared weights): translate the fakes back to the
  # original domains.
  with variable_scope.variable_scope(gen_vs, reuse=True):
    cycled_images = generator_fn(fake_images, labels)

  # Discriminator on the generated images.
  with variable_scope.variable_scope(discriminator_scope) as disc_vs:
    fake_source_pred, fake_domain_pred = discriminator_fn(
        fake_images, num_domains)

  # Discriminator (shared weights) on the real images.
  with variable_scope.variable_scope(disc_vs, reuse=True):
    real_source_pred, real_domain_pred = discriminator_fn(images, num_domains)

  gen_vars = variables_lib.get_trainable_variables(gen_vs)
  disc_vars = variables_lib.get_trainable_variables(disc_vs)

  return namedtuples.StarGANModel(
      input_data=images,
      input_data_domain_label=labels,
      generated_data=fake_images,
      generated_data_domain_target=random_targets,
      reconstructed_data=cycled_images,
      discriminator_input_data_source_predication=real_source_pred,
      discriminator_generated_data_source_predication=fake_source_pred,
      discriminator_input_data_domain_predication=real_domain_pred,
      discriminator_generated_data_domain_predication=fake_domain_pred,
      generator_variables=gen_vars,
      generator_scope=gen_vs,
      generator_fn=generator_fn,
      discriminator_variables=disc_vars,
      discriminator_scope=disc_vs,
      discriminator_fn=discriminator_fn)