コード例 #1
0
 def model_fn(features, labels, mode, params):
   """Return a spec with a constant loss, a no-op train op and a scaffold_fn.

   The inputs are ignored; the scaffold_fn comes from the enclosing test's
   `self._make_scaffold_fn(mode)`.
   """
   del features, labels, params  # Unused: the spec is input-independent.
   spec_kwargs = dict(
       mode=mode,
       loss=constant_op.constant(_EXPECTED_LOSS),
       train_op=control_flow_ops.no_op(),
       scaffold_fn=self._make_scaffold_fn(mode),
   )
   return tpu_estimator.TPUEstimatorSpec(**spec_kwargs)
コード例 #2
0
 def model_fn_with_summary(features, labels, mode, params):
   """Return a constant-loss spec after emitting summaries of that loss.

   A scalar, a histogram and an image summary are recorded (in that order),
   all over the same constant loss tensor.
   """
   del features, labels, params  # Unused.
   expected_loss = constant_op.constant(_EXPECTED_LOSS)
   # NOTE(review): `summary.image` is handed a scalar rather than an image
   # batch; presumably deliberate for this fixture -- confirm.
   for summary_fn, tag in ((summary.scalar, 'loss_scalar_summary'),
                           (summary.histogram, 'loss_histogram_summary'),
                           (summary.image, 'loss_image_summary')):
     summary_fn(tag, expected_loss)
   return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=expected_loss)
コード例 #3
0
 def eval_model_fn(features, labels, mode, params):
   """Return an eval spec whose metric fn receives a list of two dummy tensors.

   The tensors are `features + labels` and `features - labels`; the loss is
   the constant `_EXPECTED_LOSS`.
   """
   del params  # Unused.
   sum_tensor = features + labels
   diff_tensor = features - labels
   eval_metrics = (_test_eval_metric_fn, [sum_tensor, diff_tensor])
   return tpu_estimator.TPUEstimatorSpec(
       mode=mode,
       loss=constant_op.constant(_EXPECTED_LOSS),
       eval_metrics=eval_metrics)
コード例 #4
0
 def eval_model_fn(features, labels, mode, params):
   """Return an eval spec whose metric fn receives a dict of dummy tensors.

   Besides 'eval_tensor_1' and 'eval_tensor_2', the dict carries an
   'extra_tensor' entry; the loss is the constant `_EXPECTED_LOSS`.
   """
   del params  # Unused.
   metric_tensors = {}
   metric_tensors['eval_tensor_1'] = features + labels
   metric_tensors['eval_tensor_2'] = features - labels
   metric_tensors['extra_tensor'] = features * 2 - labels
   return tpu_estimator.TPUEstimatorSpec(
       mode=mode,
       loss=constant_op.constant(_EXPECTED_LOSS),
       eval_metrics=(_test_eval_metric_fn, metric_tensors))
コード例 #5
0
def _get_eval_estimator_spec(gan_model, gan_loss, gan_loss_no_reduction,
                             get_eval_metric_ops_fn):
    """Return a TPUEstimatorSpec for the eval case.

    Args:
        gan_model: The GAN model tuple. Its `generated_data` becomes the eval
            predictions. When `get_eval_metric_ops_fn` is given, its inputs,
            generated/real data and discriminator outputs are also forwarded to
            the metric function.
        gan_loss: GANLoss holding the reduced losses; their sum is reported
            as the eval loss.
        gan_loss_no_reduction: GANLoss holding unreduced losses; the mean-loss
            metrics are computed over these.
        get_eval_metric_ops_fn: Optional callable returning a dict of extra
            metric ops, or None to report only the loss metrics.

    Returns:
        A `TPUEstimatorSpec` in EVAL mode.

    Raises:
        TypeError: (from the metric function, when it runs) if
            `get_eval_metric_ops_fn` does not return a dict.
    """

    def _loss_metrics(generator_loss, discriminator_loss):
        """Mean generator/discriminator loss metrics, reported in every eval."""
        return {
            'generator_loss': metrics_lib.mean(generator_loss),
            'discriminator_loss': metrics_lib.mean(discriminator_loss),
        }

    # Tensors handed to `metric_fn` by TPUEstimator. The dict keys match the
    # `metric_fn` parameter names.
    tensors = {
        'generator_loss': gan_loss_no_reduction.generator_loss,
        'discriminator_loss': gan_loss_no_reduction.discriminator_loss,
    }

    if get_eval_metric_ops_fn is None:
        metric_fn = _loss_metrics
    else:

        def metric_fn(generator_inputs, generated_data, real_data,
                      discriminator_real_outputs, discriminator_gen_outputs,
                      generator_loss, discriminator_loss):
            """`metric_fn` used in TPUEstimator to calculate metrics."""
            eval_metric_ops = _loss_metrics(generator_loss, discriminator_loss)
            custom_eval_metric_ops = get_eval_metric_ops_fn(
                generator_inputs, generated_data, real_data,
                discriminator_real_outputs, discriminator_gen_outputs)
            if not isinstance(custom_eval_metric_ops, dict):
                raise TypeError('`get_eval_metric_ops_fn` must return a dict, '
                                'received: {}'.format(custom_eval_metric_ops))
            # Custom metrics are merged last, so on a key collision they
            # replace the default loss metrics.
            eval_metric_ops.update(custom_eval_metric_ops)
            return eval_metric_ops

        tensors.update({
            'generator_inputs': gan_model.generator_inputs,
            'generated_data': gan_model.generated_data,
            'real_data': gan_model.real_data,
            'discriminator_real_outputs': gan_model.discriminator_real_outputs,
            'discriminator_gen_outputs': gan_model.discriminator_gen_outputs,
        })

    scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
    return tpu_estimator.TPUEstimatorSpec(mode=model_fn_lib.ModeKeys.EVAL,
                                          predictions=gan_model.generated_data,
                                          loss=scalar_loss,
                                          eval_metrics=(metric_fn, tensors))
コード例 #6
0
def _get_estimator_spec(mode, gan_model, generator_loss_fn,
                        discriminator_loss_fn, get_eval_metric_ops_fn,
                        generator_optimizer, discriminator_optimizer,
                        joint_train, is_on_tpu, gan_train_steps):
    """Get the TPUEstimatorSpec for the current mode.

    Args:
        mode: A `ModeKeys` value. PREDICT and EVAL are handled explicitly;
            any other value is treated as TRAIN.
        gan_model: The GAN model tuple to build the spec for.
        generator_loss_fn: Callable taking `gan_model` and keyword arguments
            (`add_summaries`, and `reduction` for the eval metrics losses);
            returns the generator loss.
        discriminator_loss_fn: Same contract, for the discriminator loss.
        get_eval_metric_ops_fn: Optional callable forwarded to the eval spec.
        generator_optimizer: An optimizer, or a callable returning one. Only
            used in TRAIN.
        discriminator_optimizer: Same, for the discriminator.
        joint_train: Forwarded to the train spec.
        is_on_tpu: When True, loss summaries are not added.
        gan_train_steps: Forwarded to the train spec.

    Returns:
        A `TPUEstimatorSpec` for `mode`.
    """
    if mode == model_fn_lib.ModeKeys.PREDICT:
        return tpu_estimator.TPUEstimatorSpec(
            mode=mode,
            predictions={'generated_data': gan_model.generated_data})

    # EVAL and TRAIN both need the reduced losses. Summaries are only added
    # when not running on TPU.
    add_summaries = not is_on_tpu
    gan_loss = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(gan_model,
                                         add_summaries=add_summaries),
        discriminator_loss=discriminator_loss_fn(gan_model,
                                                 add_summaries=add_summaries))

    if mode == model_fn_lib.ModeKeys.EVAL:
        # Eval losses for metrics must preserve batch dimension.
        gan_loss_no_reduction = tfgan_tuples.GANLoss(
            generator_loss=generator_loss_fn(gan_model,
                                             add_summaries=False,
                                             reduction=losses.Reduction.NONE),
            discriminator_loss=discriminator_loss_fn(
                gan_model,
                add_summaries=False,
                reduction=losses.Reduction.NONE))
        return _get_eval_estimator_spec(gan_model, gan_loss,
                                        gan_loss_no_reduction,
                                        get_eval_metric_ops_fn)

    # Every remaining mode is treated as model_fn_lib.ModeKeys.TRAIN.
    # Construct optimizers if arguments were callable. For TPUs, they must be
    # `CrossShardOptimizer`.
    gopt = (generator_optimizer()
            if callable(generator_optimizer) else generator_optimizer)
    dopt = (discriminator_optimizer()
            if callable(discriminator_optimizer) else discriminator_optimizer)

    return _get_train_estimator_spec(gan_model, gan_loss, gopt, dopt,
                                     joint_train, gan_train_steps)
コード例 #7
0
def _get_train_estimator_spec(gan_model, gan_loss, generator_optimizer,
                              discriminator_optimizer, joint_train,
                              gan_train_steps):
    """Build the TRAIN-mode TPUEstimatorSpec for a GAN.

    The reported loss is the sum of the generator and discriminator losses.
    The train op comes from `_combine_train_ops`. It receives one zero-argument
    builder per network and optimizes them either sequentially or jointly,
    according to `joint_train` and `gan_train_steps`.
    """
    total_loss = gan_loss.generator_loss + gan_loss.discriminator_loss

    # Split update ops per network so none runs more than once. For now, update
    # ops outside both scopes raise an error. A throwaway `{}` is passed as the
    # kwargs dict because the helper may modify the dict it receives.
    generator_update_ops, discriminator_update_ops = tfgan_train._get_update_ops(  # pylint:disable=protected-access
        {}, gan_model.generator_scope.name, gan_model.discriminator_scope.name)

    def _make_train_op_builder(scope_name, loss, optimizer, variables,
                               update_ops):
        """Return a zero-argument callable that builds one network's train op."""

        def _build():
            with ops.name_scope(scope_name):
                return training.create_train_op(
                    total_loss=loss,
                    optimizer=optimizer,
                    variables_to_train=variables,
                    update_ops=update_ops)

        return _build

    build_generator_train_op = _make_train_op_builder(
        'generator_train', gan_loss.generator_loss, generator_optimizer,
        gan_model.generator_variables, generator_update_ops)
    build_discriminator_train_op = _make_train_op_builder(
        'discriminator_train', gan_loss.discriminator_loss,
        discriminator_optimizer, gan_model.discriminator_variables,
        discriminator_update_ops)

    train_op = _combine_train_ops(build_generator_train_op,
                                  build_discriminator_train_op, joint_train,
                                  gan_train_steps)

    return tpu_estimator.TPUEstimatorSpec(mode=model_fn_lib.ModeKeys.TRAIN,
                                          loss=total_loss,
                                          train_op=train_op)
コード例 #8
0
 def eval_model_fn_no_eval_metrics(features, labels, mode, params):
   """Return a spec carrying only a constant loss (no eval_metrics)."""
   del features, labels, params  # Unused.
   constant_loss = constant_op.constant(_EXPECTED_LOSS)
   return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=constant_loss)