def model_fn(features, labels, mode, params):
  """Minimal model_fn: constant loss, no-op train step, test scaffold.

  All inputs other than `mode` are ignored.

  NOTE(review): `self` is not a parameter of this function; it is assumed to
  be captured from an enclosing (test-method) scope not visible here.
  """
  del features, labels, params  # Unused.
  spec_kwargs = dict(
      mode=mode,
      loss=constant_op.constant(_EXPECTED_LOSS),
      train_op=control_flow_ops.no_op(),
      scaffold_fn=self._make_scaffold_fn(mode),
  )
  return tpu_estimator.TPUEstimatorSpec(**spec_kwargs)
def model_fn_with_summary(features, labels, mode, params):
  """model_fn that emits scalar, histogram and image summaries of its loss.

  The loss is the constant `_EXPECTED_LOSS`; no train op is supplied.

  NOTE(review): `summary.image` is handed the scalar loss unchanged; confirm
  this is intentional for whatever behavior the caller is testing.
  """
  del features, labels, params  # Unused.
  constant_loss = constant_op.constant(_EXPECTED_LOSS)
  summary_writers = (
      (summary.scalar, 'loss_scalar_summary'),
      (summary.histogram, 'loss_histogram_summary'),
      (summary.image, 'loss_image_summary'),
  )
  for write_summary, tag in summary_writers:
    write_summary(tag, constant_loss)
  return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=constant_loss)
def eval_model_fn(features, labels, mode, params):
  """EVAL model_fn that feeds `_test_eval_metric_fn` a list of two tensors.

  The list holds `features + labels` and `features - labels`; the reported
  loss is the constant `_EXPECTED_LOSS`.
  """
  del params  # Unused.
  summed = features + labels
  difference = features - labels
  eval_metrics = (_test_eval_metric_fn, [summed, difference])
  expected_loss = constant_op.constant(_EXPECTED_LOSS)
  return tpu_estimator.TPUEstimatorSpec(
      mode=mode, loss=expected_loss, eval_metrics=eval_metrics)
def eval_model_fn(features, labels, mode, params):
  """EVAL model_fn that feeds `_test_eval_metric_fn` a dict of tensors.

  The dict has keys 'eval_tensor_1', 'eval_tensor_2' and 'extra_tensor'; the
  reported loss is the constant `_EXPECTED_LOSS`.

  NOTE(review): this reuses the name `eval_model_fn`; if it shares a scope
  with another definition of that name, the later one wins. Presumably each
  lives in its own enclosing test method; confirm.
  """
  del params  # Unused.
  metric_tensors = {}
  metric_tensors['eval_tensor_1'] = features + labels
  metric_tensors['eval_tensor_2'] = features - labels
  metric_tensors['extra_tensor'] = features * 2 - labels
  expected_loss = constant_op.constant(_EXPECTED_LOSS)
  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      loss=expected_loss,
      eval_metrics=(_test_eval_metric_fn, metric_tensors))
def _get_eval_estimator_spec(gan_model, gan_loss, gan_loss_no_reduction,
                             get_eval_metric_ops_fn):
  """Return a TPUEstimatorSpec for the eval case.

  The spec reports the sum of the generator and discriminator losses as its
  loss, the generated data as its predictions, and an `eval_metrics` pair of
  (metric function, dict of tensors). The metric function always computes
  the mean of each loss. If `get_eval_metric_ops_fn` is given, its metrics
  are merged on top of those.

  Args:
    gan_model: Provides `generator_inputs`, `generated_data`, `real_data`,
      `discriminator_real_outputs` and `discriminator_gen_outputs`. Only the
      custom-metric path forwards these to the metric function.
    gan_loss: Tuple whose `generator_loss` and `discriminator_loss` are summed
      into the spec's loss.
    gan_loss_no_reduction: Tuple whose losses are forwarded to the metric
      function. The caller in this file builds them with
      `losses.Reduction.NONE`.
    get_eval_metric_ops_fn: Optional callable taking (generator_inputs,
      generated_data, real_data, discriminator_real_outputs,
      discriminator_gen_outputs). It must return a dict of metric ops.

  Returns:
    A `tpu_estimator.TPUEstimatorSpec` in EVAL mode.

  Raises:
    TypeError: When the metric function runs, if `get_eval_metric_ops_fn`
      returned something other than a dict.
  """
  # The two loss tensors are forwarded in both configurations.
  tensors = {
      'generator_loss': gan_loss_no_reduction.generator_loss,
      'discriminator_loss': gan_loss_no_reduction.discriminator_loss,
  }

  def loss_metrics(generator_loss, discriminator_loss):
    """Mean-of-loss metrics shared by both metric-function variants."""
    return {
        'generator_loss': metrics_lib.mean(generator_loss),
        'discriminator_loss': metrics_lib.mean(discriminator_loss),
    }

  if get_eval_metric_ops_fn is None:
    metric_fn = loss_metrics
  else:
    # Keys must match the parameter names of `metric_fn` below.
    tensors.update([
        ('generator_inputs', gan_model.generator_inputs),
        ('generated_data', gan_model.generated_data),
        ('real_data', gan_model.real_data),
        ('discriminator_real_outputs', gan_model.discriminator_real_outputs),
        ('discriminator_gen_outputs', gan_model.discriminator_gen_outputs),
    ])

    def metric_fn(generator_inputs, generated_data, real_data,
                  discriminator_real_outputs, discriminator_gen_outputs,
                  generator_loss, discriminator_loss):
      """`metric_fn` used in TPUEstimator to calculate metrics."""
      merged_ops = loss_metrics(generator_loss, discriminator_loss)
      user_ops = get_eval_metric_ops_fn(
          generator_inputs, generated_data, real_data,
          discriminator_real_outputs, discriminator_gen_outputs)
      if not isinstance(user_ops, dict):
        raise TypeError('`get_eval_metric_ops_fn` must return a dict, '
                        'received: {}'.format(user_ops))
      # User-supplied metrics take precedence on key collisions.
      merged_ops.update(user_ops)
      return merged_ops

  total_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
  return tpu_estimator.TPUEstimatorSpec(
      mode=model_fn_lib.ModeKeys.EVAL,
      predictions=gan_model.generated_data,
      loss=total_loss,
      eval_metrics=(metric_fn, tensors))
def _get_estimator_spec(mode, gan_model, generator_loss_fn,
                        discriminator_loss_fn, get_eval_metric_ops_fn,
                        generator_optimizer, discriminator_optimizer,
                        joint_train, is_on_tpu, gan_train_steps):
  """Get the TPUEstimatorSpec for the current mode.

  Behavior by mode:
    * PREDICT: returns a spec whose predictions are
      {'generated_data': gan_model.generated_data}.
    * EVAL: builds reduced losses, and also unreduced losses
      (`losses.Reduction.NONE`) for metrics. It then delegates to
      `_get_eval_estimator_spec`.
    * Anything else is treated as TRAIN. It builds reduced losses, calls the
      optimizer arguments if they are callables, and delegates to
      `_get_train_estimator_spec`.

  Loss summaries are requested only when `is_on_tpu` is false.
  """
  if mode == model_fn_lib.ModeKeys.PREDICT:
    return tpu_estimator.TPUEstimatorSpec(
        mode=mode, predictions={'generated_data': gan_model.generated_data})

  # Reduced losses are needed by both EVAL and TRAIN.
  add_summaries = not is_on_tpu
  gan_loss = tfgan_tuples.GANLoss(
      generator_loss=generator_loss_fn(gan_model, add_summaries=add_summaries),
      discriminator_loss=discriminator_loss_fn(
          gan_model, add_summaries=add_summaries))

  if mode == model_fn_lib.ModeKeys.EVAL:
    # Eval losses for metrics must preserve batch dimension.
    no_reduction = losses.Reduction.NONE
    gan_loss_no_reduction = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(
            gan_model, add_summaries=False, reduction=no_reduction),
        discriminator_loss=discriminator_loss_fn(
            gan_model, add_summaries=False, reduction=no_reduction))
    return _get_eval_estimator_spec(gan_model, gan_loss, gan_loss_no_reduction,
                                    get_eval_metric_ops_fn)

  # TRAIN. Optimizer arguments may be factories; call them to get instances.
  # On TPU they are expected to be `CrossShardOptimizer`s. Nothing here
  # enforces that.
  if callable(generator_optimizer):
    gen_opt = generator_optimizer()
  else:
    gen_opt = generator_optimizer
  if callable(discriminator_optimizer):
    dis_opt = discriminator_optimizer()
  else:
    dis_opt = discriminator_optimizer
  return _get_train_estimator_spec(gan_model, gan_loss, gen_opt, dis_opt,
                                   joint_train, gan_train_steps)
def _get_train_estimator_spec(gan_model, gan_loss, generator_optimizer,
                              discriminator_optimizer, joint_train,
                              gan_train_steps):
  """Return a TPUEstimatorSpec for the train case.

  The spec's loss is the sum of the generator and discriminator losses. The
  two train ops are not built directly here. Each is wrapped in a zero-argument
  callable and passed to `_combine_train_ops`, which combines them
  (sequentially or jointly, according to `joint_train`).

  Args:
    gan_model: Provides `generator_scope` / `discriminator_scope`, whose names
      are used to split update ops. Also provides `generator_variables` /
      `discriminator_variables`, the variables each optimizer trains.
    gan_loss: Tuple with `generator_loss` and `discriminator_loss`.
    generator_optimizer: Optimizer used for the generator train op.
    discriminator_optimizer: Optimizer used for the discriminator train op.
    joint_train: Forwarded to `_combine_train_ops`.
    gan_train_steps: Forwarded to `_combine_train_ops`.

  Returns:
    A `tpu_estimator.TPUEstimatorSpec` in TRAIN mode.
  """
  total_loss = gan_loss.generator_loss + gan_loss.discriminator_loss

  # Split update ops by scope so each runs with exactly one of the two train
  # ops instead of being run twice. `_get_update_ops` is expected to raise if
  # an update op belongs to neither scope. The empty dict is a throwaway
  # `kwargs` argument: any changes `_get_update_ops` makes to it are not
  # read back here.
  generator_scope_name = gan_model.generator_scope.name
  discriminator_scope_name = gan_model.discriminator_scope.name
  generator_update_ops, discriminator_update_ops = (
      tfgan_train._get_update_ops(  # pylint:disable=protected-access
          {}, generator_scope_name, discriminator_scope_name))

  def build_generator_train_op():
    with ops.name_scope('generator_train'):
      return training.create_train_op(
          total_loss=gan_loss.generator_loss,
          optimizer=generator_optimizer,
          variables_to_train=gan_model.generator_variables,
          update_ops=generator_update_ops)

  def build_discriminator_train_op():
    with ops.name_scope('discriminator_train'):
      return training.create_train_op(
          total_loss=gan_loss.discriminator_loss,
          optimizer=discriminator_optimizer,
          variables_to_train=gan_model.discriminator_variables,
          update_ops=discriminator_update_ops)

  combined_train_op = _combine_train_ops(
      build_generator_train_op, build_discriminator_train_op, joint_train,
      gan_train_steps)
  return tpu_estimator.TPUEstimatorSpec(
      loss=total_loss,
      mode=model_fn_lib.ModeKeys.TRAIN,
      train_op=combined_train_op)
def eval_model_fn_no_eval_metrics(features, labels, mode, params):
  """EVAL model_fn that reports a constant loss and no `eval_metrics`.

  All inputs other than `mode` are ignored.
  """
  del features, labels, params  # Unused.
  constant_loss = constant_op.constant(_EXPECTED_LOSS)
  return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=constant_loss)