def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator.

  Only the generator is built; every discriminator-related field of the
  returned tuple, as well as `reconstructed_data` and
  `input_data_domain_label`, is set to `None`.

  Args:
    input_data: Generator input. Converted with
      `tfgan_train._convert_tensor_or_l_or_d` before being used.
    input_data_domain_label: Domain label handed to the generator as its
      second argument. Converted the same way as `input_data`, and stored in
      the result as `generated_data_domain_target`.
    generator_fn: Callable invoked as
      `generator_fn(input_data, input_data_domain_label)`. If it declares an
      argument named `mode`, `ModeKeys.PREDICT` is bound to it first.
    generator_scope: Passed to `variable_scope.variable_scope`; the generator
      is built inside that scope. Stored in the result exactly as passed in.

  Returns:
    A `tfgan_tuples.StarGANModel` holding the converted inputs, the generated
    data, the trainable variables of the generator scope and the (possibly
    `mode`-bound) `generator_fn`.
  """
  # `inspect.getargspec` was removed in Python 3.11 and raises on functions
  # with keyword-only arguments or annotations. Prefer `getfullargspec` and
  # fall back only on interpreters that do not provide it.
  get_argspec = getattr(inspect, 'getfullargspec', None)
  if get_argspec is None:
    get_argspec = inspect.getargspec
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in get_argspec(generator_fn).args:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  # Collected after leaving the scope, using the scope object that was entered.
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
예제 #2
0
def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
    """Make a `StarGANModel` from just the generator.

    Only the generator is built; every discriminator-related field of the
    returned tuple, as well as `reconstructed_data` and
    `input_data_domain_label`, is set to `None`.

    Args:
        input_data: Generator input. Converted with
            `tfgan_train._convert_tensor_or_l_or_d` before being used.
        input_data_domain_label: Domain label handed to the generator as its
            second argument. Converted the same way as `input_data`, and
            stored in the result as `generated_data_domain_target`.
        generator_fn: Callable invoked as
            `generator_fn(input_data, input_data_domain_label)`. If it
            declares an argument named `mode`, `ModeKeys.PREDICT` is bound to
            it first.
        generator_scope: Passed to `variable_scope.variable_scope`; the
            generator is built inside that scope. Stored in the result
            exactly as passed in.

    Returns:
        A `tfgan_tuples.StarGANModel` holding the converted inputs, the
        generated data, the trainable variables of the generator scope and
        the (possibly `mode`-bound) `generator_fn`.
    """
    # `inspect.getargspec` was removed in Python 3.11 and raises on functions
    # with keyword-only arguments or annotations. Prefer `getfullargspec` and
    # fall back only on interpreters that do not provide it.
    get_argspec = getattr(inspect, 'getfullargspec', None)
    if get_argspec is None:
        get_argspec = inspect.getargspec
    # If `generator_fn` has an argument `mode`, pass mode to it.
    if 'mode' in get_argspec(generator_fn).args:
        generator_fn = functools.partial(generator_fn,
                                         mode=model_fn_lib.ModeKeys.PREDICT)
    with variable_scope.variable_scope(generator_scope) as gen_scope:
        # pylint:disable=protected-access
        input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
        input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
            input_data_domain_label)
        # pylint:enable=protected-access
        generated_data = generator_fn(input_data, input_data_domain_label)
    # Collected after leaving the scope, using the scope object that was
    # entered.
    generator_variables = variable_lib.get_trainable_variables(gen_scope)

    return tfgan_tuples.StarGANModel(
        input_data=input_data,
        input_data_domain_label=None,
        generated_data=generated_data,
        generated_data_domain_target=input_data_domain_label,
        reconstructed_data=None,
        discriminator_input_data_source_predication=None,
        discriminator_generated_data_source_predication=None,
        discriminator_input_data_domain_predication=None,
        discriminator_generated_data_domain_predication=None,
        generator_variables=generator_variables,
        generator_scope=generator_scope,
        generator_fn=generator_fn,
        discriminator_variables=None,
        discriminator_scope=None,
        discriminator_fn=None)
예제 #3
0
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Build a generator-only `GANModel`.

  Args:
    generator_inputs: Generator input. Converted with
      `tfgan_train._convert_tensor_or_l_or_d` before being used.
    generator_fn: Callable invoked with the converted inputs to produce the
      generated data.
    generator_scope: Passed to `variable_scope.variable_scope`; the generator
      is built inside that scope.

  Returns:
    A `tfgan_tuples.GANModel` holding the converted inputs, the generated
    data, the trainable variables of the generator scope, the entered scope
    object and `generator_fn`. `real_data` and every discriminator field are
    `None`.
  """
  with variable_scope.variable_scope(generator_scope) as scope:
    # pylint:disable=protected-access
    converted_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)
    # pylint:enable=protected-access
    outputs = generator_fn(converted_inputs)

  # Nothing on the discriminator side is built here, so those fields stay empty.
  empty_fields = dict(
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)

  return tfgan_tuples.GANModel(
      converted_inputs,
      outputs,
      variable_lib.get_trainable_variables(scope),
      scope,
      generator_fn,
      **empty_fields)
def _make_prediction_gan_model(generator_inputs, generator_fn,
                               generator_scope):
    """Build a generator-only `GANModel`.

    Args:
        generator_inputs: Generator input. Converted with
            `tfgan_train._convert_tensor_or_l_or_d` before being used.
        generator_fn: Callable invoked with the converted inputs to produce
            the generated data.
        generator_scope: Passed to `variable_scope.variable_scope`; the
            generator is built inside that scope.

    Returns:
        A `tfgan_tuples.GANModel` holding the converted inputs, the generated
        data, the trainable variables of the generator scope, the entered
        scope object and `generator_fn`. `real_data` and every discriminator
        field are `None`.
    """
    with variable_scope.variable_scope(generator_scope) as scope:
        # pylint:disable=protected-access
        converted_inputs = tfgan_train._convert_tensor_or_l_or_d(
            generator_inputs)
        # pylint:enable=protected-access
        outputs = generator_fn(converted_inputs)

    # Nothing on the discriminator side is built here, so those fields stay
    # empty.
    empty_fields = dict(
        real_data=None,
        discriminator_real_outputs=None,
        discriminator_gen_outputs=None,
        discriminator_variables=None,
        discriminator_scope=None,
        discriminator_fn=None,
    )

    return tfgan_tuples.GANModel(
        converted_inputs,
        outputs,
        variable_lib.get_trainable_variables(scope),
        scope,
        generator_fn,
        **empty_fields)
예제 #5
0
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator.

  Args:
    generator_inputs: Generator input. Converted with
      `tfgan_train._convert_tensor_or_l_or_d` before being used.
    generator_fn: Callable invoked with the converted inputs to produce the
      generated data. If it declares an argument named `mode`,
      `ModeKeys.PREDICT` is bound to it first.
    generator_scope: Passed to `variable_scope.variable_scope`; the generator
      is built inside that scope.

  Returns:
    A `tfgan_tuples.GANModel` holding the converted inputs, the generated
    data, the trainable variables of the generator scope, the entered scope
    object and the (possibly `mode`-bound) `generator_fn`. `real_data` and
    every discriminator field are `None`.
  """
  # `inspect.getargspec` was removed in Python 3.11 and raises on functions
  # with keyword-only arguments or annotations. Prefer `getfullargspec` and
  # fall back only on interpreters that do not provide it.
  get_argspec = getattr(inspect, 'getfullargspec', None)
  if get_argspec is None:
    get_argspec = inspect.getargspec
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in get_argspec(generator_fn).args:
    generator_fn = functools.partial(generator_fn,
                                     mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  # Collected after leaving the scope, using the scope object that was entered.
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
예제 #6
0
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator.

  Args:
    generator_inputs: Generator input. Converted with
      `tfgan_train._convert_tensor_or_l_or_d` before being used.
    generator_fn: Callable invoked with the converted inputs to produce the
      generated data. If it declares an argument named `mode`,
      `ModeKeys.PREDICT` is bound to it first.
    generator_scope: Passed to `variable_scope.variable_scope`; the generator
      is built inside that scope.

  Returns:
    A `tfgan_tuples.GANModel` holding the converted inputs, the generated
    data, the trainable variables of the generator scope, the entered scope
    object and the (possibly `mode`-bound) `generator_fn`. `real_data` and
    every discriminator field are `None`.
  """
  # `inspect.getargspec` was removed in Python 3.11 and raises on functions
  # with keyword-only arguments or annotations. Prefer `getfullargspec` and
  # fall back only on interpreters that do not provide it.
  get_argspec = getattr(inspect, 'getfullargspec', None)
  if get_argspec is None:
    get_argspec = inspect.getargspec
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in get_argspec(generator_fn).args:
    generator_fn = functools.partial(generator_fn,
                                     mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  # Collected after leaving the scope, using the scope object that was entered.
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)