Esempio n. 1
0
    def model_fn(features, labels, mode, params):
        """Build an `EstimatorSpec` whose only trainable input is the latent `z`.

        The generator and discriminator come from the enclosing scope. The
        optimizer is given `var_list=[z]`, so training searches latent space for
        a `z` that lowers `loss_fn` rather than updating network weights.
        (Originally described as an inpainting estimator — presumably the main
        use case; the code itself is not inpainting-specific.)

        Args:
          features: Passed through to `loss_fn`.
          labels: Used as the GAN's `real_data` and passed through to `loss_fn`.
          mode: Estimator mode; bound as the `mode` keyword of both networks.
          params: Dict read for 'batch_size', 'z_shape' (must be a list, since
            it is concatenated onto `[batch_size]`), 'add_summaries',
            'input_clip', 'learning_rate' and 'opt_kwargs'.

        Returns:
          A `tf.estimator.EstimatorSpec` with the generated data as
          predictions, plus the loss and the train op.
        """
        latent_shape = [params['batch_size']] + params['z_shape']
        summaries_on = params['add_summaries']
        clip = params['input_clip']

        def _clip_latent(value):
            # Constraint for the latent variable: keep values in [-clip, clip].
            return tf.clip_by_value(value, -clip, clip)

        latent = tf.compat.v1.get_variable(
            name=INPUT_NAME,
            initializer=tf.random.truncated_normal(latent_shape),
            constraint=_clip_latent,
            use_resource=False)

        gan_model = tfgan_train.gan_model(
            generator_fn=functools.partial(generator_fn, mode=mode),
            discriminator_fn=functools.partial(discriminator_fn, mode=mode),
            real_data=labels,
            generator_inputs=latent,
            check_shapes=False)

        loss = loss_fn(gan_model, features, labels, summaries_on)

        # Optimizer variables live under their own variable scope so they do
        # not clash with estimator variables on checkpoint save/restore.
        with tf.compat.v1.variable_scope(OPTIMIZER_NAME):
            latent_opt = optimizer(learning_rate=params['learning_rate'],
                                   **params['opt_kwargs'])
            step = tf.compat.v1.train.get_or_create_global_step()
            train_op = latent_opt.minimize(loss=loss,
                                           global_step=step,
                                           var_list=[latent])

        if summaries_on:
            grad_norm = tf.linalg.global_norm(tf.gradients(ys=loss, xs=latent))
            tf.compat.v1.summary.scalar('z_loss/z_grads', grad_norm)
            tf.compat.v1.summary.scalar('z_loss/loss', loss)

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=gan_model.generated_data,
            loss=loss,
            train_op=train_op)
Esempio n. 2
0
def _make_gan_model(generator_fn, discriminator_fn, real_data,
                    generator_inputs, generator_scope, discriminator_scope,
                    add_summaries, mode):
    """Construct a `GANModel`, and optionally pass in `mode`.

    A network function that declares a parameter named `mode` (positional-or-
    keyword or keyword-only) gets `mode` bound via `functools.partial`; other
    network functions are used unchanged.

    Args:
      generator_fn: Generator network callable.
      discriminator_fn: Discriminator network callable.
      real_data: Forwarded to `tfgan_train.gan_model` as the real data.
      generator_inputs: Forwarded to `tfgan_train.gan_model`.
      generator_scope: Forwarded as `generator_scope`.
      discriminator_scope: Forwarded as `discriminator_scope`.
      add_summaries: Falsy to skip summaries; otherwise a single key, or a
        list/tuple of keys, into `summary_type_map`. Each selected summary
        function is called with the constructed model.
      mode: Value bound as the `mode` keyword on network functions that accept
        it (presumably a `tf.estimator.ModeKeys` value — confirm with callers).

    Returns:
      The `GANModel` produced by `tfgan_train.gan_model`.
    """
    bindable_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)

    def _bind_mode_if_accepted(network_fn):
        # Returns `network_fn` with `mode` bound if it declares such a parameter.
        # `inspect.getargspec` was removed in Python 3.11 and raises ValueError
        # for callables with annotations or keyword-only parameters (including
        # partials that already bind keywords); `inspect.signature` handles all
        # of these. Positional-only `mode` is skipped because it cannot be
        # bound by keyword.
        mode_param = inspect.signature(network_fn).parameters.get('mode')
        if mode_param is not None and mode_param.kind in bindable_kinds:
            return functools.partial(network_fn, mode=mode)
        return network_fn

    gan_model = tfgan_train.gan_model(_bind_mode_if_accepted(generator_fn),
                                      _bind_mode_if_accepted(discriminator_fn),
                                      real_data,
                                      generator_inputs,
                                      generator_scope=generator_scope,
                                      discriminator_scope=discriminator_scope,
                                      check_shapes=False)
    if add_summaries:
        if not isinstance(add_summaries, (tuple, list)):
            add_summaries = [add_summaries]
        with tf.compat.v1.name_scope(''):  # Clear name scope.
            for summary_type in add_summaries:
                summary_type_map[summary_type](gan_model)

    return gan_model