def _get_loss_for_train(gan_model, loss_fns, gan_loss_kwargs, add_summaries):
    """Builds the GAN losses for the training path via `tfgan_train.gan_loss`.

    Args:
      gan_model: Model object handed straight to `tfgan_train.gan_loss`.
      loss_fns: Object providing `g_loss_fn` (generator loss function) and
        `d_loss_fn` (discriminator loss function).
      gan_loss_kwargs: Optional mapping of extra keyword arguments for
        `tfgan_train.gan_loss`; any falsy value (e.g. `None`) means no extras.
      add_summaries: Forwarded unchanged as `add_summaries`.

    Returns:
      Whatever `tfgan_train.gan_loss` returns.
    """
    extra_loss_kwargs = gan_loss_kwargs if gan_loss_kwargs else {}
    generator_loss = loss_fns.g_loss_fn
    discriminator_loss = loss_fns.d_loss_fn
    return tfgan_train.gan_loss(
        gan_model,
        generator_loss,
        discriminator_loss,
        add_summaries=add_summaries,
        **extra_loss_kwargs)
def _model_fn(features, labels, mode, params):
    """GANEstimator model function.

    Validates `mode`, builds the GANModel from `features` (generator inputs)
    and `labels` (real data), and returns the EstimatorSpec for that mode.
    GAN losses are computed only for TRAIN and EVAL.

    This function relies on free names defined outside it: `generator_fn`,
    `discriminator_fn`, `add_summaries`, `generator_loss_fn`,
    `discriminator_loss_fn`, `use_loss_summaries`, `optimizers`,
    `get_hooks_fn`, `is_chief` and `get_eval_metric_ops_fn`.

    Raises:
      ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
    """
    mode_keys = tf.estimator.ModeKeys
    if mode not in (mode_keys.TRAIN, mode_keys.EVAL, mode_keys.PREDICT):
        raise ValueError('Mode not recognized: %s' % mode)

    # The estimator feeds generator inputs as `features` and real samples as
    # `labels`; the GANModel wraps both architectures.
    model = get_gan_model(mode, generator_fn, discriminator_fn, labels,
                          features, add_summaries)

    # Prediction needs neither losses nor optimizers.
    if mode == mode_keys.PREDICT:
        return get_predict_estimator_spec(model)

    loss_kwargs = extract_gan_loss_args_from_params(params) or {}
    losses = tfgan_train.gan_loss(
        model,
        generator_loss_fn,
        discriminator_loss_fn,
        add_summaries=use_loss_summaries,
        **loss_kwargs)

    if mode == mode_keys.TRAIN:
        return get_train_estimator_spec(
            model, losses, optimizers, get_hooks_fn, is_chief=is_chief)
    return get_eval_estimator_spec(model, losses, get_eval_metric_ops_fn)
def get_eval_estimator_spec(gan_model_fns, loss_fns, gan_loss_kwargs,
                            get_eval_metric_ops_fn, add_summaries):
    """Estimator spec for eval case.

    Builds the single GAN model, computes GAN losses without reduction (the
    metrics need the batch dimension), and wraps the result in a
    `contrib.TPUEstimatorSpec`.

    Args:
      gan_model_fns: Sequence of zero-argument callables that each build a
        GAN model. Exactly one is allowed in eval mode.
      loss_fns: Object exposing `g_loss_fn` and `d_loss_fn`, passed to
        `tfgan_train.gan_loss` as the generator/discriminator loss functions.
      gan_loss_kwargs: Optional dict of extra keyword arguments forwarded to
        `tfgan_train.gan_loss`; a falsy value means no extras.
      get_eval_metric_ops_fn: Optional callable for custom eval metrics. When
        `None`, the default metric function and tensors are used instead.
      add_summaries: Forwarded to `_maybe_add_summaries` and to
        `tfgan_train.gan_loss`.

    Returns:
      A `contrib.TPUEstimatorSpec` in EVAL mode whose `loss` is the
      discriminator loss reduced with `SUM_BY_NONZERO_WEIGHTS`.

    Raises:
      ValueError: If `gan_model_fns` does not contain exactly one element.
    """
    # Explicit check rather than `assert`, so the validation still runs when
    # Python is started with -O.
    if len(gan_model_fns) != 1:
        raise ValueError(
            '`gan_models` must be length 1 in eval mode. Got length %d' %
            len(gan_model_fns))
    gan_model = gan_model_fns[0]()
    _maybe_add_summaries(gan_model, add_summaries)

    # Eval losses for metrics must preserve batch dimension.
    kwargs = gan_loss_kwargs or {}
    gan_loss_no_reduction = tfgan_train.gan_loss(
        gan_model,
        loss_fns.g_loss_fn,
        loss_fns.d_loss_fn,
        add_summaries=add_summaries,
        reduction=tf.compat.v1.losses.Reduction.NONE,
        **kwargs)

    # Make the metric function and tensor names.
    if get_eval_metric_ops_fn is not None:
        metric_fn = _make_custom_metric_fn(get_eval_metric_ops_fn)
        tensors_for_metric_fn = _make_custom_metric_tensors(
            gan_model, gan_loss_no_reduction)
    else:
        metric_fn = _make_default_metric_fn()
        tensors_for_metric_fn = _make_default_metric_tensors(
            gan_loss_no_reduction)

    # The spec's scalar loss is derived from the discriminator loss only.
    scalar_loss = tf.compat.v1.losses.compute_weighted_loss(
        gan_loss_no_reduction.discriminator_loss,
        loss_collection=None,
        reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

    return contrib.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        predictions=_predictions_from_generator_output(
            gan_model.generated_data),
        loss=scalar_loss,
        eval_metrics=(metric_fn, tensors_for_metric_fn))
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
    """Builds the TF-GAN EstimatorSpec for this model class.

    The generator and discriminator are built inside TF-Hub module specs and
    then wrapped into the callables TF-GAN expects. In PREDICT mode both
    modules are registered for export.

    Args:
      cls: Model class, instantiated as `cls(hparams, mode,
        data_parallelism=..., decode_hparams=..., _reuse=...)`.
      hparams: Hyperparameters; copied before use. Fields read here:
        `batch_size`, `generator_lr`, `discriminator_lr`, `beta1`,
        `apply_psf`, `gen_steps` and `disc_steps`.
      features: Dict of input tensors. `'inputs'` holds the real data (dims
        1..3 are taken as the image shape). `'ps'` is read unconditionally to
        build the noise added to generated samples. `'psf'` is optional and is
        convolved with the generator output when present and
        `hparams.apply_psf` is truthy.
      labels: Unused in this function.
      mode: One of TRAIN, EVAL or PREDICT.
      config: Optional run config. Its `data_parallelism` is used only when
        `use_tpu` is False.
      params: Unused in this function.
      decode_hparams: Forwarded to the model constructor.
      use_tpu: When True, `config.data_parallelism` is ignored.

    Returns:
      The spec from `get_train_estimator_spec`, `get_eval_estimator_spec` or
      `get_predict_estimator_spec`, depending on `mode`.

    Raises:
      ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
    """
    if mode not in [
            model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
            model_fn_lib.ModeKeys.PREDICT
    ]:
        raise ValueError('Mode not recognized: %s' % mode)

    hparams = hparams_lib.copy_hparams(hparams)

    # Instantiate model. Data parallelism is only configured off-TPU.
    data_parallelism = None
    if not use_tpu and config:
        data_parallelism = config.data_parallelism
    reuse = tf.get_variable_scope().reuse
    self = cls(hparams, mode, data_parallelism=data_parallelism,
               decode_hparams=decode_hparams, _reuse=reuse)

    generator_inputs = self.sample_noise()
    real_data = features['inputs']  # rename inputs for clarity
    img_shape = common_layers.shape_list(real_data)[1:4]
    real_data.set_shape([hparams.batch_size] + img_shape)

    # To satisfy the TF-GAN API, real data is set to None on predict.
    if mode == tf.estimator.ModeKeys.PREDICT:
        real_data = None

    optimizers = Optimizers(
        tf.compat.v1.train.AdamOptimizer(hparams.generator_lr, hparams.beta1),
        tf.compat.v1.train.AdamOptimizer(hparams.discriminator_lr,
                                         hparams.beta1))

    # Creates tfhub module specs for both generator and discriminator.
    def make_discriminator_spec():
        input_layer = tf.placeholder(tf.float32, shape=[None] + img_shape)
        disc_output = self.discriminator(input_layer, None, mode)
        hub.add_signature(inputs=input_layer, outputs=disc_output)

    disc_spec = hub.create_module_spec(make_discriminator_spec)

    def make_generator_spec():
        input_layer = tf.placeholder(
            tf.float32,
            shape=[None] + common_layers.shape_list(generator_inputs)[1:])
        gen_output = self.generator(input_layer, mode)
        hub.add_signature(inputs=input_layer, outputs=gen_output)

    gen_spec = hub.create_module_spec(make_generator_spec)

    # Create the modules.
    discriminator_module = hub.Module(disc_spec, name="Discriminator_Module",
                                      trainable=True)
    generator_module = hub.Module(gen_spec, name="Generator_Module",
                                  trainable=True)

    # Wraps the modules into functions expected by TF-GAN.
    def generator(code, mode):
        out = generator_module(code)
        # Convolve the generated output with the PSF, when one is provided.
        if hparams.apply_psf and 'psf' in features:
            out = convolve(out, tf.cast(features['psf'][..., 0], tf.complex64))

        # Adds noise according to the provided power spectrum. Entries of
        # `ps` >= 9 are zeroed; the rest give an amplitude of sqrt(exp(ps)).
        # NOTE(review): `ps` looks like a log power spectrum — confirm.
        # Unlike 'psf', 'ps' is required here (KeyError if absent).
        noise = tf.spectral.rfft2d(tf.random_normal(out.get_shape()[:3]))
        thresholded_ps = tf.where(features['ps'] >= 9,
                                  tf.zeros_like(features['ps']),
                                  tf.sqrt(tf.exp(features['ps'])))
        noise = noise * tf.cast(thresholded_ps, tf.complex64)
        out = out + tf.expand_dims(tf.spectral.irfft2d(noise), axis=-1)
        return out

    def discriminator(image, conditioning, mode):
        # TF-GAN passes (image, conditioning, mode); only the image is fed to
        # the hub module.
        return discriminator_module(image)

    # Make GANModel, which encapsulates the GAN model architectures.
    gan_model = get_gan_model(mode, generator, discriminator, real_data,
                              generator_inputs, add_summaries=self.summaries)

    # Make GANLoss, which encapsulates the losses.
    if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
        gan_loss = tfgan_train.gan_loss(gan_model, self.generator_loss,
                                        self.discriminator_loss,
                                        add_summaries=True)

    # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
    # metrics, and optimizers (if required).
    if mode == tf.estimator.ModeKeys.TRAIN:
        get_hooks_fn = tfgan_train.get_sequential_train_hooks(
            namedtuples.GANTrainSteps(hparams.gen_steps, hparams.disc_steps))
        estimator_spec = get_train_estimator_spec(
            gan_model, gan_loss, optimizers, get_hooks_fn, is_chief=True)
    elif mode == tf.estimator.ModeKeys.EVAL:
        estimator_spec = get_eval_estimator_spec(gan_model, gan_loss)
    else:  # tf.estimator.ModeKeys.PREDICT
        # Register hub modules for export.
        hub.register_module_for_export(generator_module, "generator")
        hub.register_module_for_export(discriminator_module, "discriminator")
        estimator_spec = get_predict_estimator_spec(gan_model)
    return estimator_spec
def get_eval_estimator_spec(gan_model_fns, loss_fns, gan_loss_kwargs,
                            prepare_arguments_for_eval_metric_fn,
                            get_eval_metric_ops_fn, add_summaries):
    """Estimator spec for eval case.

    Builds the single GAN model and computes GAN losses without reduction
    (the metrics need the batch dimension). Returns a TPUEstimatorSpec whose
    eval metrics are the default metrics, plus any custom metrics from
    `get_eval_metric_ops_fn`.

    Args:
      gan_model_fns: Sequence of zero-argument callables that each build a
        GAN model. Exactly one is allowed in eval mode.
      loss_fns: Object exposing `g_loss_fn` and `d_loss_fn`.
      gan_loss_kwargs: Optional dict of extra keyword arguments forwarded to
        `tfgan_train.gan_loss`; a falsy value means no extras.
      prepare_arguments_for_eval_metric_fn: Optional callable. It receives
        the custom metric tensors as keyword arguments and returns the
        mapping of keyword arguments for `get_eval_metric_ops_fn`. When
        `None`, the tensors are passed through unchanged. Only used when
        `get_eval_metric_ops_fn` is not `None`.
      get_eval_metric_ops_fn: Optional callable that must return a dict of
        extra metric ops. Its entries are merged into the default metrics and
        override default entries with the same key.
      add_summaries: Forwarded to `_maybe_add_summaries` and
        `tfgan_train.gan_loss`.

    Returns:
      A `tf.compat.v1.estimator.tpu.TPUEstimatorSpec` in EVAL mode whose
      `loss` is the discriminator loss reduced with `SUM_BY_NONZERO_WEIGHTS`.

    Raises:
      ValueError: If `gan_model_fns` does not contain exactly one element, or
        if any object nested in the metric inputs is not a `tf.Tensor`.
      TypeError: Raised when the metric function runs, if
        `get_eval_metric_ops_fn` does not return a dict.
    """
    # Explicit check rather than `assert`, so the validation still runs when
    # Python is started with -O.
    if len(gan_model_fns) != 1:
        raise ValueError(
            '`gan_models` must be length 1 in eval mode. Got length %d' %
            len(gan_model_fns))
    gan_model = gan_model_fns[0]()
    _maybe_add_summaries(gan_model, add_summaries)

    # Eval losses for metrics must preserve batch dimension.
    kwargs = gan_loss_kwargs or {}
    gan_loss_no_reduction = tfgan_train.gan_loss(
        gan_model,
        loss_fns.g_loss_fn,
        loss_fns.d_loss_fn,
        add_summaries=add_summaries,
        reduction=tf.compat.v1.losses.Reduction.NONE,
        **kwargs)

    if prepare_arguments_for_eval_metric_fn is None:
        # Default: a function that returns its keyword arguments as a dict.
        prepare_arguments_for_eval_metric_fn = (
            lambda **metric_kwargs: metric_kwargs)
    default_metric_fn = _make_default_metric_fn()
    # Prepare tensors needed for calculating the metrics. The first element of
    # `tensors_for_metric_fn` is a dict of arguments for `default_metric_fn`.
    # The second element, present only when `get_eval_metric_ops_fn` is not
    # None, is a dict of arguments for `get_eval_metric_ops_fn`.
    tensors_for_metric_fn = [
        _make_default_metric_tensors(gan_loss_no_reduction)
    ]
    if get_eval_metric_ops_fn is not None:
        tensors_for_metric_fn.append(
            prepare_arguments_for_eval_metric_fn(
                **_make_custom_metric_tensors(gan_model)))

    scalar_loss = tf.compat.v1.losses.compute_weighted_loss(
        gan_loss_no_reduction.discriminator_loss,
        loss_collection=None,
        reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

    # TPUEstimatorSpec.eval_metrics expects a function and a list of tensors.
    # However, some structures in tensors_for_metric_fn might be dictionaries
    # (e.g., generator_inputs and real_data). We therefore flatten
    # tensors_for_metric_fn before passing it to the function, and restore the
    # original structure inside the function.
    def _metric_fn_wrapper(*args):
        """Unflattens the arguments and passes them to the metric functions."""
        unpacked_arguments = tf.nest.pack_sequence_as(tensors_for_metric_fn,
                                                      args)
        # Calculate default metrics.
        metrics = default_metric_fn(**unpacked_arguments[0])
        if get_eval_metric_ops_fn is not None:
            # Append custom metrics.
            custom_eval_metric_ops = get_eval_metric_ops_fn(
                **unpacked_arguments[1])
            if not isinstance(custom_eval_metric_ops, dict):
                raise TypeError('`get_eval_metric_ops_fn` must return a dict, '
                                'received: {}'.format(custom_eval_metric_ops))
            metrics.update(custom_eval_metric_ops)
        return metrics

    flat_tensors = tf.nest.flatten(tensors_for_metric_fn)
    if not all(isinstance(t, tf.Tensor) for t in flat_tensors):
        raise ValueError('All objects nested within the TF-GAN model must be '
                         'tensors. Instead, types are: %s.' %
                         str([type(v) for v in flat_tensors]))
    return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        predictions=_predictions_from_generator_output(
            gan_model.generated_data),
        loss=scalar_loss,
        eval_metrics=(_metric_fn_wrapper, flat_tensors))