def get_estimator_spec(mode, gan_model, loss_fn, get_eval_metric_ops_fn,
                       generator_optimizer, discriminator_optimizer,
                       get_hooks_fn=None):
    """Build the EstimatorSpec that corresponds to ``mode``.

    PREDICT: predictions are ``gan_model.generated_data``; no loss is built.
    EVAL: the loss from ``loss_fn(gan_model)`` is passed to
        ``_get_eval_estimator_spec`` together with ``get_eval_metric_ops_fn``.
    Any other mode is treated as TRAIN: each optimizer argument is called if it
        is callable (a zero-argument factory) and used as-is otherwise, and
        ``get_hooks_fn`` falls back to
        ``tfgan_train.get_sequential_train_hooks()`` when falsy.

    Returns:
      The EstimatorSpec produced for the mode.
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=gan_model.generated_data)

    gan_loss = loss_fn(gan_model)

    if mode == tf.estimator.ModeKeys.EVAL:
        return _get_eval_estimator_spec(
            gan_model, gan_loss, get_eval_metric_ops_fn)

    # Training path. Resolve generator first, then discriminator, so any
    # factory side effects happen in the same order as before.
    gen_opt, disc_opt = (
        candidate() if callable(candidate) else candidate
        for candidate in (generator_optimizer, discriminator_optimizer))
    hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
    return _get_train_estimator_spec(
        gan_model, gan_loss, gen_opt, disc_opt, hooks_fn)
def get_train_estimator_spec(
        gan_model, gan_loss, optimizers, get_hooks_fn,
        train_op_fn=tfgan_train.gan_train_ops, is_chief=True):
    """Assemble the TRAIN-mode EstimatorSpec for a GAN.

    Args:
      gan_model: The GAN model passed to ``train_op_fn``.
      gan_loss: The GAN loss; its ``discriminator_loss`` is reported as the
        spec's loss.
      optimizers: Passed through ``_maybe_construct_optimizers``; the result
        must expose ``gopt`` and ``dopt``.
      get_hooks_fn: Maps the train ops to training hooks. Falls back to
        ``tfgan_train.get_sequential_train_hooks()`` when falsy.
      train_op_fn: Builds the train ops from
        ``(gan_model, gan_loss, gopt, dopt, is_chief=...)``.
      is_chief: Forwarded to ``train_op_fn``.

    Returns:
      A ``tf.estimator.EstimatorSpec`` whose ``train_op`` is the train ops'
      ``global_step_inc_op``.
    """
    hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
    resolved = _maybe_construct_optimizers(optimizers)
    ops = train_op_fn(
        gan_model, gan_loss, resolved.gopt, resolved.dopt, is_chief=is_chief)
    hooks = hooks_fn(ops)

    spec_kwargs = {
        'loss': gan_loss.discriminator_loss,
        'mode': tf.estimator.ModeKeys.TRAIN,
        'train_op': ops.global_step_inc_op,
        'training_hooks': hooks,
    }
    return tf.estimator.EstimatorSpec(**spec_kwargs)
def estimator_model_fn(cls, hparams, features, labels, mode, config=None,
                       params=None, decode_hparams=None, use_tpu=False):
    """Build a TF-GAN EstimatorSpec whose networks are wrapped in TF-Hub modules.

    The generator and discriminator of the model built from ``cls`` are each
    packaged as a trainable ``hub.Module``. In PREDICT mode both modules are
    registered for export. The generator wrapper optionally convolves its
    output with a PSF and always adds noise shaped by ``features['ps']``.

    Args:
      cls: Model class. It is instantiated as ``cls(hparams, mode,
        data_parallelism=..., decode_hparams=..., _reuse=...)``. The instance
        must provide ``sample_noise``, ``generator``, ``discriminator``,
        ``generator_loss``, ``discriminator_loss`` and ``summaries``.
      hparams: Hyper-parameters. They are copied before use. Fields read here:
        ``batch_size``, ``generator_lr``, ``discriminator_lr``, ``beta1``,
        ``apply_psf``, ``gen_steps``, ``disc_steps``.
      features: Dict of input tensors. ``'inputs'`` holds the real images and
        ``'ps'`` is read by the generator wrapper. ``'psf'`` is used when
        present and ``hparams.apply_psf`` is truthy.
      labels: Unused.
      mode: One of ModeKeys TRAIN / EVAL / PREDICT.
      config: Optional run config. Its ``data_parallelism`` is used unless
        ``use_tpu`` is set.
      params: Unused.
      decode_hparams: Forwarded to ``cls``.
      use_tpu: When True, ``config.data_parallelism`` is ignored.

    Returns:
      The EstimatorSpec for ``mode``.

    Raises:
      ValueError: If ``mode`` is not TRAIN, EVAL or PREDICT.
    """
    valid_modes = (model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
                   model_fn_lib.ModeKeys.PREDICT)
    if mode not in valid_modes:
        raise ValueError('Mode not recognized: %s' % mode)

    hparams = hparams_lib.copy_hparams(hparams)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
        data_parallelism = config.data_parallelism
    reuse = tf.get_variable_scope().reuse
    self = cls(hparams, mode, data_parallelism=data_parallelism,
               decode_hparams=decode_hparams, _reuse=reuse)
    generator_inputs = self.sample_noise()

    # Rename inputs for clarity and pin the static batch size.
    real_data = features['inputs']
    img_shape = common_layers.shape_list(real_data)[1:4]
    real_data.set_shape([hparams.batch_size] + img_shape)

    # To satisfy the TF-GAN API, real data is set to None on predict.
    if mode == tf.estimator.ModeKeys.PREDICT:
        real_data = None

    optimizers = Optimizers(
        tf.compat.v1.train.AdamOptimizer(hparams.generator_lr, hparams.beta1),
        tf.compat.v1.train.AdamOptimizer(hparams.discriminator_lr,
                                         hparams.beta1))

    # Creates tfhub module specs for both generator and discriminator.
    def make_discriminator_spec():
        input_layer = tf.placeholder(tf.float32, shape=[None] + img_shape)
        disc_output = self.discriminator(input_layer, None, mode)
        hub.add_signature(inputs=input_layer, outputs=disc_output)

    disc_spec = hub.create_module_spec(make_discriminator_spec)

    def make_generator_spec():
        input_layer = tf.placeholder(
            tf.float32,
            shape=[None] + common_layers.shape_list(generator_inputs)[1:])
        gen_output = self.generator(input_layer, mode)
        hub.add_signature(inputs=input_layer, outputs=gen_output)

    gen_spec = hub.create_module_spec(make_generator_spec)

    # Create the modules.
    discriminator_module = hub.Module(disc_spec, name="Discriminator_Module",
                                      trainable=True)
    generator_module = hub.Module(gen_spec, name="Generator_Module",
                                  trainable=True)

    # Wrap the modules into the functions expected by TF-GAN.
    def generator(code, mode):
        """Run the generator module, then apply the optional PSF and noise.

        ``mode`` is accepted for interface compatibility but is not used.
        """
        out = generator_module(code)
        # Apply the PSF by convolution when requested and available.
        if hparams.apply_psf and 'psf' in features:
            out = convolve(out, tf.cast(features['psf'][..., 0], tf.complex64))
        # Add noise coloured by the provided power spectrum. White noise is
        # taken to Fourier space and scaled by sqrt(exp(ps)), with the scale
        # zeroed where ps >= 9.
        # NOTE(review): presumably 'ps' is a log power spectrum; confirm.
        # shape_list keeps static dims as ints and falls back to dynamic dims,
        # so this also works when the batch size is not statically known.
        noise_shape = common_layers.shape_list(out)[:3]
        noise = tf.spectral.rfft2d(tf.random_normal(noise_shape))
        thresholded_ps = tf.where(features['ps'] >= 9,
                                  tf.zeros_like(features['ps']),
                                  tf.sqrt(tf.exp(features['ps'])))
        noise = noise * tf.cast(thresholded_ps, tf.complex64)
        out = out + tf.expand_dims(tf.spectral.irfft2d(noise), axis=-1)
        return out

    def discriminator(image, conditioning, mode):
        """Run the discriminator module.

        ``conditioning`` and ``mode`` are part of the expected signature but
        are not used.
        """
        return discriminator_module(image)

    # Make GANModel, which encapsulates the GAN model architectures.
    gan_model = get_gan_model(mode, generator, discriminator, real_data,
                              generator_inputs, add_summaries=self.summaries)

    # Make GANLoss, which encapsulates the losses (not needed for PREDICT).
    if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
        gan_loss = tfgan_train.gan_loss(gan_model,
                                        self.generator_loss,
                                        self.discriminator_loss,
                                        add_summaries=True)

    # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
    # metrics, and optimizers (if required).
    if mode == tf.estimator.ModeKeys.TRAIN:
        get_hooks_fn = tfgan_train.get_sequential_train_hooks(
            namedtuples.GANTrainSteps(hparams.gen_steps, hparams.disc_steps))
        estimator_spec = get_train_estimator_spec(gan_model, gan_loss,
                                                  optimizers, get_hooks_fn,
                                                  is_chief=True)
    elif mode == tf.estimator.ModeKeys.EVAL:
        estimator_spec = get_eval_estimator_spec(gan_model, gan_loss)
    else:  # tf.estimator.ModeKeys.PREDICT
        # Register hub modules for export.
        hub.register_module_for_export(generator_module, "generator")
        hub.register_module_for_export(discriminator_module, "discriminator")
        estimator_spec = get_predict_estimator_spec(gan_model)
    return estimator_spec
def get_estimator_spec(mode, gan_model, loss_fn, get_eval_metric_ops_fn,
                       generator_optimizer, discriminator_optimizer,
                       get_hooks_fn=None, cls_model=None, cls_checkpoint=None):
    """Get the EstimatorSpec for the current mode.

    PREDICT: when the module-level ``PREDICT_FLAG`` is truthy, predictions
        are a dict of the StarGAN-style model outputs. The dict keys equal
        the ``gan_model`` attribute names. Otherwise predictions are just
        ``gan_model.generated_data``. The predictions are only built in this
        mode, so TRAIN/EVAL never touch predict-only attributes.
    EVAL: the loss from ``loss_fn(gan_model)`` is passed to
        ``_get_eval_estimator_spec``.
    Any other mode is treated as TRAIN: optimizer arguments are called if
        callable, ``get_hooks_fn`` falls back to
        ``tfgan_train.get_sequential_train_hooks()``, and ``cls_model`` /
        ``cls_checkpoint`` are forwarded to ``_get_train_estimator_spec``.

    Returns:
      The EstimatorSpec produced for the mode.
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        if PREDICT_FLAG:
            # Keys deliberately match the gan_model attribute names, including
            # the "predication" spelling those attributes use.
            prediction_fields = (
                'input_data',
                'input_data_domain_label',
                'generated_data',
                'generated_data_domain_target',
                'reconstructed_data',
                'discriminator_input_data_source_predication',
                'discriminator_generated_data_source_predication',
                'discriminator_input_data_domain_predication',
                'discriminator_generated_data_domain_predication',
            )
            predictions = {field: getattr(gan_model, field)
                           for field in prediction_fields}
        else:
            predictions = gan_model.generated_data
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    gan_loss = loss_fn(gan_model)
    if mode == tf.estimator.ModeKeys.EVAL:
        return _get_eval_estimator_spec(gan_model, gan_loss,
                                        get_eval_metric_ops_fn)

    # tf.estimator.ModeKeys.TRAIN
    def _maybe_callable(x):
        # Optimizers may be passed as instances or as zero-argument factories.
        return x() if callable(x) else x

    gopt = _maybe_callable(generator_optimizer)
    dopt = _maybe_callable(discriminator_optimizer)
    get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
    return _get_train_estimator_spec(
        gan_model, gan_loss, gopt, dopt, get_hooks_fn,
        cls_model=cls_model, cls_checkpoint=cls_checkpoint)
def gan_train(train_ops,
              logdir,
              get_hooks_fn=get_sequential_train_hooks(),
              master='',
              is_chief=True,
              scaffold=None,
              hooks=None,
              chief_only_hooks=None,
              save_checkpoint_secs=USE_DEFAULT,
              save_summaries_steps=USE_DEFAULT,
              save_checkpoint_steps=USE_DEFAULT,
              max_wait_secs=7200,
              config=None):
    """Run GAN training in a `MonitoredTrainingSession` with GAN hooks.

    The hooks produced by ``get_hooks_fn(train_ops)`` are appended to
    ``hooks``. The session then repeatedly runs
    ``train_ops.global_step_inc_op`` until it reports that it should stop.

    Args:
      train_ops: A GANTrainOps named tuple.
      logdir: The directory where the graph and checkpoints are saved.
      get_hooks_fn: A function that takes a GANTrainOps tuple and returns a
        list of hooks.
      master: The URL of the master.
      is_chief: Specifies whether or not the training is being run by the
        primary replica during replica training.
      scaffold: An tf.train.Scaffold instance.
      hooks: List of `tf.train.SessionRunHook` callbacks which are run inside
        the training loop.
      chief_only_hooks: List of `tf.train.SessionRunHook` instances which are
        run inside the training loop for the chief trainer only.
      save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
        saved using a default checkpoint saver. Passed straight to
        `MonitoredTrainingSession`. The default checkpoint saver is only
        disabled when both this and `save_checkpoint_steps` are `None`.
      save_summaries_steps: The frequency, in number of global steps, that
        the summaries are written to disk using a default summary saver. If
        `save_summaries_steps` is set to `None`, then the default summary
        saver isn't used.
      save_checkpoint_steps: The frequency, in number of global steps, that
        a checkpoint is saved using a default checkpoint saver. Passed
        straight to `MonitoredTrainingSession`; see its documentation for how
        this interacts with `save_checkpoint_secs`.
      max_wait_secs: Maximum time workers should wait for the session to
        become available. This should be kept relatively short to help detect
        incorrect code, but sometimes may need to be increased if the chief
        takes a while to start up.
      config: An instance of `tf.ConfigProto`.

    Returns:
      The value of `train_ops.global_step_inc_op` from the last training
      step, or `None` if the session stopped before any step was run.
    """
    # NOTE(review): save_checkpoint_steps is not passed to the validator, so
    # it is not checked against a missing logdir here.
    _validate_gan_train_inputs(logdir, is_chief, save_summaries_steps,
                               save_checkpoint_secs)
    new_hooks = get_hooks_fn(train_ops)
    if hooks is not None:
        hooks = list(hooks) + list(new_hooks)
    else:
        hooks = new_hooks
    with tf.compat.v1.train.MonitoredTrainingSession(
            master=master,
            is_chief=is_chief,
            checkpoint_dir=logdir,
            scaffold=scaffold,
            hooks=hooks,
            chief_only_hooks=chief_only_hooks,
            save_checkpoint_secs=save_checkpoint_secs,
            save_summaries_steps=save_summaries_steps,
            save_checkpoint_steps=save_checkpoint_steps,
            config=config,
            max_wait_secs=max_wait_secs) as session:
        gstep = None
        while not session.should_stop():
            gstep = session.run(train_ops.global_step_inc_op)
    return gstep