def model_fn(features, labels, mode):
  """Builds a spec in which variable `w` follows the global step.

  TRAIN: increments the global step under a control dependency on `features`,
  then assigns `step + 2` to `w`; that assignment is the train op and the loss
  is the constant 3. EVAL: returns the constant loss 5 (created under a control
  dependency on `features`) together with a mean metric over `w`.

  Args:
    features: Tensor used only as a control dependency.
    labels: Unused.
    mode: An `estimator_lib.ModeKeys` value.

  Returns:
    An `EstimatorSpec` for TRAIN and EVAL. Any other mode falls through and
    yields `None`.
  """
  del labels  # Unused.
  global_step = training.get_global_step()
  w = variable_scope.get_variable(
      'w',
      shape=[],
      dtype=dtypes.int64,
      initializer=init_ops.zeros_initializer())

  if mode == estimator_lib.ModeKeys.TRAIN:
    # Depending on `features` forces the input to be consumed.
    with ops.control_dependencies([features]):
      increment_step = state_ops.assign_add(training.get_global_step(), 1)
    with ops.control_dependencies([increment_step]):
      train_op = w.assign(global_step + 2)
    train_loss = constant_op.constant(3.)
    return estimator_lib.EstimatorSpec(
        mode, loss=train_loss, train_op=train_op)

  if mode == estimator_lib.ModeKeys.EVAL:
    # Depending on `features` forces the input to be consumed.
    with ops.control_dependencies([features]):
      eval_loss = constant_op.constant(5.)
    # w is constant in each step, so the mean equals w itself:
    # w = 0 if step==0 else step+2
    metric_ops = {'mean_of_const': metrics_lib.mean(w)}
    return estimator_lib.EstimatorSpec(
        mode, loss=eval_loss, eval_metric_ops=metric_ops)
Beispiel #2
0
 def model_fn(features, labels, mode):
     """Returns TRAIN/EVAL specs built around an int64 scalar variable `w`.

     In TRAIN mode the global step is bumped by one (after `features` has
     been consumed via a control dependency) and `w` is then set to
     `step + 2`. In EVAL mode a constant loss is produced under the same
     kind of control dependency and the mean of `w` is reported as a metric.

     Args:
       features: Tensor used only as a control dependency.
       labels: Unused.
       mode: An `estimator_lib.ModeKeys` value.

     Returns:
       An `EstimatorSpec` for TRAIN or EVAL; `None` for any other mode.
     """
     del labels  # Unused.
     step_tensor = training.get_global_step()
     w = variable_scope.get_variable(
         'w',
         shape=[],
         dtype=dtypes.int64,
         initializer=init_ops.zeros_initializer())

     if mode == estimator_lib.ModeKeys.TRAIN:
         # The control dependency exists only to consume `features`.
         with ops.control_dependencies([features]):
             bump_step = state_ops.assign_add(
                 training.get_global_step(), 1)
         with ops.control_dependencies([bump_step]):
             set_w = w.assign(step_tensor + 2)
         loss_train = constant_op.constant(3.)
         return estimator_lib.EstimatorSpec(
             mode, loss=loss_train, train_op=set_w)

     if mode == estimator_lib.ModeKeys.EVAL:
         # The control dependency exists only to consume `features`.
         with ops.control_dependencies([features]):
             loss_eval = constant_op.constant(5.)
         # w is constant in each step, so the mean equals w itself:
         # w = 0 if step==0 else step+2
         metrics = {'mean_of_const': metrics_lib.mean(w)}
         return estimator_lib.EstimatorSpec(
             mode, loss=loss_eval, eval_metric_ops=metrics)
Beispiel #3
0
def model_fn_global_step_incrementer(features, labels, mode):
  """Model function whose train op only increments the global step.

  Args:
    features: Unused.
    labels: Unused.
    mode: Passed through to the returned spec.

  Returns:
    An `EstimatorSpec` with a constant loss of 1. and a train op that adds 1
    to the global step.
  """
  del features, labels  # Unused.
  step_var = training.get_global_step()
  unit_loss = constant_op.constant(1.)
  increment = state_ops.assign_add(step_var, 1)
  return model_fn_lib.EstimatorSpec(mode, loss=unit_loss, train_op=increment)
    def _create_global_step(self, graph):
        """Creates a global step suitable for TPUs.

        The step is an int32 scalar resource variable, initialized to zero,
        non-trainable, and registered in both the GLOBAL_VARIABLES and
        GLOBAL_STEP collections.

        Args:
          graph: The graph in which to create the global step. A falsy value
            selects the current default graph.

        Returns:
          A global step `Tensor`.

        Raises:
          ValueError: if the global step tensor is already defined.
        """
        target_graph = graph or ops.get_default_graph()
        existing_step = training.get_global_step(target_graph)
        if existing_step is not None:
            raise ValueError('"global_step" already exists.')

        step_collections = [
            ops.GraphKeys.GLOBAL_VARIABLES,
            ops.GraphKeys.GLOBAL_STEP,
        ]
        # Create in proper graph and base name_scope.
        with target_graph.as_default() as g, g.name_scope(None):
            return variable_scope.get_variable(
                ops.GraphKeys.GLOBAL_STEP,
                shape=[],
                dtype=dtypes.int32,
                initializer=init_ops.zeros_initializer(),
                trainable=False,
                use_resource=True,
                collections=step_collections)
    def _create_head_with_eval_metric_ops(self, mode, loss, eval_metric_ops):
        """Creates a head returning a spec based on mode.

        This version contains eval that will not run on TPUs, where
        eval_metric_ops has not been split into a metrics_fn that runs on
        CPUs. The intent is to test the entire eval (model_fn forward pass)
        and metrics output on CPU.

        Args:
          mode: The mode such as TRAIN, EVAL.
          loss: Training loss `Tensor`. Must be either scalar, or with shape
            `[1]`.
          eval_metric_ops: Dict of metric results keyed by name.

        Returns:
          An EstimatorSpec for EVAL or TPUEstimatorSpec otherwise.
        """
        if mode == _EVAL:
            eval_spec = model_fn_lib.EstimatorSpec(
                mode=mode, eval_metric_ops=eval_metric_ops, loss=loss)
            return eval_spec

        # Every other mode takes the training path: plain gradient descent
        # wrapped in a cross-shard optimizer.
        sgd = training.GradientDescentOptimizer(learning_rate=0.5)
        sharded_optimizer = tf.compat.v1.tpu.CrossShardOptimizer(sgd)
        minimize_op = sharded_optimizer.minimize(
            loss, global_step=training.get_global_step())
        return tpu_estimator.TPUEstimatorSpec(
            mode=mode, train_op=minimize_op, loss=loss)
    def _model_fn(features, labels, mode):
      """Single-unit dense regression model over `features['x']`.

      Args:
        features: Mapping holding the input tensor under key 'x'.
        labels: Regression targets; not read in PREDICT mode.
        mode: A `model_fn_lib.ModeKeys` value.

      Returns:
        An `EstimatorSpec`. PREDICT carries only predictions and export
        outputs; every other mode additionally carries the mean squared
        error loss, a gradient-descent train op and an absolute-error
        metric.
      """
      inputs = features['x']
      predictions = layers.dense(
          inputs, 1, kernel_initializer=init_ops.zeros_initializer())
      regression_output = export_output.RegressionOutput(predictions)
      export_outputs = {'predictions': regression_output}

      # Prediction needs neither loss nor training ops, so return early.
      if mode == model_fn_lib.ModeKeys.PREDICT:
        return model_fn_lib.EstimatorSpec(
            mode, predictions=predictions, export_outputs=export_outputs)

      loss = losses.mean_squared_error(labels, predictions)
      sgd = training.GradientDescentOptimizer(learning_rate=0.5)
      train_op = sgd.minimize(loss, training.get_global_step())
      absolute_error = metrics_lib.mean_absolute_error(labels, predictions)

      return model_fn_lib.EstimatorSpec(
          mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op,
          eval_metric_ops={'absolute_error': absolute_error},
          export_outputs=export_outputs)
Beispiel #7
0
    def _model_fn(features, labels, mode):
      """Regresses a single output from `features['x']` with a dense layer.

      Args:
        features: Mapping holding the input tensor under key 'x'.
        labels: Regression targets; not read in PREDICT mode.
        mode: A `model_fn_lib.ModeKeys` value.

      Returns:
        An `EstimatorSpec`. For PREDICT it holds predictions and export
        outputs only; otherwise it also holds the mean squared error loss,
        a gradient-descent train op and an 'absolute_error' metric.
      """
      dense_input = features['x']
      predictions = layers.dense(
          dense_input, 1, kernel_initializer=init_ops.zeros_initializer())
      export_outputs = dict(
          predictions=export.RegressionOutput(predictions))

      # Prediction needs neither loss nor training ops, so return early.
      if mode == model_fn_lib.ModeKeys.PREDICT:
        return model_fn_lib.EstimatorSpec(
            mode, predictions=predictions, export_outputs=export_outputs)

      loss = losses.mean_squared_error(labels, predictions)
      optimizer = training.GradientDescentOptimizer(learning_rate=0.5)
      train_op = optimizer.minimize(loss, training.get_global_step())
      eval_metric_ops = {
          'absolute_error':
              metrics_lib.mean_absolute_error(labels, predictions),
      }

      return model_fn_lib.EstimatorSpec(
          mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs)
def model_fn_diff_modes(features, labels, mode):
  """Returns a spec whose constant loss and predictions depend on `mode`.

  TRAIN uses loss 105 / predictions [501] and a train op that increments the
  global step and adds 3 to `some_var`; EVAL uses 106 / [502]; any other mode
  uses 107 / [503]. Only TRAIN has a train op.

  Args:
    features: Unused.
    labels: Unused.
    mode: A `model_fn_lib.ModeKeys` value.

  Returns:
    An `EstimatorSpec` that also carries an 'abs_err' metric comparing the
    constant 0 against the predictions.
  """
  del features, labels  # Unused.
  some_var = variables.Variable(21, name='some_var')
  # This initial loss is replaced in every branch below.
  loss = constant_op.constant(104)
  train_op = None

  if mode == model_fn_lib.ModeKeys.TRAIN:
    loss = constant_op.constant(105)
    predictions = constant_op.constant([501])
    increment_step = state_ops.assign_add(training.get_global_step(), 1)
    bump_var = state_ops.assign_add(some_var, 3)
    train_op = control_flow_ops.group(increment_step, bump_var)
  elif mode == model_fn_lib.ModeKeys.EVAL:
    loss = constant_op.constant(106)
    predictions = constant_op.constant([502])
  else:
    loss = constant_op.constant(107)
    predictions = constant_op.constant([503])

  abs_err = metrics_lib.mean_absolute_error(
      constant_op.constant(0), predictions)
  return model_fn_lib.EstimatorSpec(
      mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={'abs_err': abs_err},
      predictions=predictions)
Beispiel #9
0
def model_fn_global_step_incrementer(features, labels, mode):
  """Returns a spec that does nothing but advance the global step.

  Args:
    features: Unused.
    labels: Unused.
    mode: Forwarded unchanged to the `EstimatorSpec`.

  Returns:
    An `EstimatorSpec` whose loss is the constant 1. and whose train op adds
    one to the global step.
  """
  del features, labels  # Unused.
  current_step = training.get_global_step()
  constant_loss = constant_op.constant(1.)
  advance_step = state_ops.assign_add(current_step, 1)
  return model_fn_lib.EstimatorSpec(
      mode, loss=constant_loss, train_op=advance_step)
Beispiel #10
0
  def build_subnetwork(self,
                       features,
                       logits_dimension,
                       training,
                       iteration_step,
                       summary,
                       previous_ensemble=None):
    """Builds a dummy subnetwork while asserting the surrounding context.

    Before building, this checks that the required arguments are present,
    that no trainable variables exist yet, that every global-step accessor
    resolves to the iteration step, and that the summary functions return
    the expected fake values.

    Args:
      features: Input features; only checked for presence.
      logits_dimension: Size of the logits. With `self._multi_head` each of
        the two heads receives `logits_dimension / 2`.
      training: Only checked for presence.
      iteration_step: Only checked for presence.
      summary: Only checked for presence.
      previous_ensemble: Unused.

    Returns:
      A `Subnetwork` with complexity 2 and no persisted tensors.
    """
    for required in (features, training, iteration_step, summary):
      assert required is not None

    # Trainable variables collection should always be empty when
    # build_subnetwork is called.
    trainable_vars = tf_compat.v1.get_collection(
        tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)
    assert not trainable_vars

    # Subnetworks get iteration steps instead of global steps.
    step_name = "subnetwork_test/iteration_step"
    step_getters = (
        tf_compat.v1.train.get_global_step,
        train.get_global_step,
        training_util.get_global_step,
        tf_v1.train.get_global_step,
        tf_compat.v1.train.get_or_create_global_step,
        train.get_or_create_global_step,
        training_util.get_or_create_global_step,
        tf_v1.train.get_or_create_global_step,
    )
    for get_step in step_getters:
      assert step_name == tf_compat.tensor_name(get_step())

    # Subnetworks get scoped summaries.
    summary_lib = tf_compat.v1.summary
    assert "fake_scalar" == summary_lib.scalar("scalar", 1.)
    assert "fake_image" == summary_lib.image("image", 1.)
    assert "fake_histogram" == summary_lib.histogram("histogram", 1.)
    assert "fake_audio" == summary_lib.audio("audio", 1., 1.)
    hidden = tu.dummy_tensor(shape=(2, 3))

    def logits_fn(logits_dim):
      initializer = tf_compat.v1.glorot_uniform_initializer(seed=self._seed)
      return tf_compat.v1.layers.dense(
          hidden, units=logits_dim, kernel_initializer=initializer)

    if not self._multi_head:
      logits = logits_fn(logits_dimension)
      last_layer = hidden
    else:
      head_dim = logits_dimension / 2
      logits = {"head1": logits_fn(head_dim), "head2": logits_fn(head_dim)}
      last_layer = {"head1": hidden, "head2": hidden}

    if self._use_logits_last_layer:
      last_layer = logits
    return Subnetwork(
        last_layer=last_layer,
        logits=logits,
        complexity=2,
        persisted_tensors={})
def _model_fn_train_only(features, labels):
    """Model function that always returns a TRAIN-mode spec.

    Args:
        features: Operand multiplied with `labels` to form the predictions.
        labels: Operand multiplied with `features` to form the predictions.

    Returns:
        An `EstimatorSpec` for `ModeKeys.TRAIN` whose loss is the constant 5
        plus variable `v` (initialized to 23) and whose train op increments
        the global step.
    """
    v = variables.Variable(constant_op.constant(23), name='v')
    predictions = features * labels
    loss = constant_op.constant(5) + v
    increment_step = state_ops.assign_add(training.get_global_step(), 1)
    return model_fn_lib.EstimatorSpec(
        ModeKeys.TRAIN,
        predictions=predictions,
        loss=loss,
        train_op=increment_step)
Beispiel #12
0
def model_fn_diff_modes(features, labels, mode):
    """Builds a spec with mode-specific constant loss and predictions.

    The constants are 105 / [501] for TRAIN, 106 / [502] for EVAL and
    107 / [503] for any other mode. Only TRAIN defines a train op, which
    increments the global step and adds 3 to the variable `some_var`.

    Args:
        features: Unused.
        labels: Unused.
        mode: A `model_fn_lib.ModeKeys` value.

    Returns:
        An `EstimatorSpec` including an 'abs_err' metric that compares the
        constant 0 with the predictions.
    """
    del features, labels  # Unused.
    counter = variables.Variable(21, name='some_var')
    # This placeholder loss is reassigned in every branch below.
    loss = constant_op.constant(104)
    train_op = None

    if mode == model_fn_lib.ModeKeys.TRAIN:
        loss = constant_op.constant(105)
        predictions = constant_op.constant([501])
        step_update = state_ops.assign_add(training.get_global_step(), 1)
        counter_update = state_ops.assign_add(counter, 3)
        train_op = control_flow_ops.group(step_update, counter_update)
    elif mode == model_fn_lib.ModeKeys.EVAL:
        loss = constant_op.constant(106)
        predictions = constant_op.constant([502])
    else:
        loss = constant_op.constant(107)
        predictions = constant_op.constant([503])

    zero = constant_op.constant(0)
    eval_metric_ops = {
        'abs_err': metrics_lib.mean_absolute_error(zero, predictions),
    }
    return model_fn_lib.EstimatorSpec(
        mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
def model_fn(features, labels, mode, params):
    """Linear softmax classifier returning `TPUEstimatorSpec`s.

    `features` is multiplied by a `[1000, 10]` weight matrix `m/W` to obtain
    the logits, and the sparse softmax cross-entropy against `labels` is used
    as the loss. The loss is built exactly once and shared by the TRAIN and
    EVAL specs, so the graph holds a single loss op.

    Args:
        features: Input tensor that is matrix-multiplied with the weights.
        labels: Sparse class labels.
        mode: A `model_fn_lib.ModeKeys` value.
        params: Unused.

    Returns:
        A `TPUEstimatorSpec` for TRAIN (RMSProp wrapped in a
        `CrossShardOptimizer`) or EVAL (recall@1 and recall@5 metrics).
        `None` for any other mode.
    """
    del params  # unused
    with variable_scope.variable_scope('m', reuse=variable_scope.AUTO_REUSE):
        w = variable_scope.get_variable('W', shape=[1000, 10])
    logits = math_ops.matmul(features, w)
    loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    if mode == model_fn_lib.ModeKeys.TRAIN:
        optimizer = training.RMSPropOptimizer(learning_rate=0.01)
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(loss, training.get_global_step())
        return tpu_estimator.TPUEstimatorSpec(
            mode=model_fn_lib.ModeKeys.TRAIN,
            loss=loss,
            train_op=train_op,
        )

    if mode == model_fn_lib.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Computes recall@1 and recall@5 after casting labels to int64."""
            labels = math_ops.cast(labels, dtypes.int64)
            logging.info('LABELS %s %s', labels, logits)
            return {
                'recall@1': metrics_lib.recall_at_k(labels, logits, 1),
                'recall@5': metrics_lib.recall_at_k(labels, logits, 5),
            }

        # Reuse the loss computed above instead of building an identical
        # second cross-entropy op.
        eval_metrics = (metric_fn, [labels, logits])
        return tpu_estimator.TPUEstimatorSpec(mode=model_fn_lib.ModeKeys.EVAL,
                                              loss=loss,
                                              eval_metrics=eval_metrics)

    # No spec is defined for other modes (e.g. PREDICT).
    return None
Beispiel #14
0
def _model_fn_with_x_y(features, labels, mode):
    """Model function over features 'x' and 'y' with mode-dependent graphs.

    PREDICT creates a `name_collision` variable (initial value 36.) and
    exports a classification output. Every other mode multiplies 'x' by 'y'
    for the predictions, reports the mean of `x - y` as a metric, and creates
    the `later_var` and `name_collision` (initial value 3.) variables. Op
    names get an 'eval_' prefix in EVAL mode.

    Args:
        features: Mapping with tensors under keys 'x' and 'y' (read only
            outside PREDICT mode).
        labels: Unused.
        mode: A `model_fn_lib.ModeKeys` value.

    Returns:
        An `EstimatorSpec` for the given mode.
    """
    del labels  # Unused.
    variables.Variable(1., name='weight')
    scores = constant_op.constant([3.])
    classes = constant_op.constant(['wumpus'])

    if mode == model_fn_lib.ModeKeys.PREDICT:
        variables.Variable(36., name='name_collision')
        predictions = constant_op.constant(10.)
        classification = export_output.ClassificationOutput(scores, classes)
        return model_fn_lib.EstimatorSpec(
            mode,
            predictions=predictions,
            export_outputs={'test': classification})

    if mode == model_fn_lib.ModeKeys.EVAL:
        prefix = 'eval_'
    else:
        prefix = ''

    x, y = features['x'], features['y']
    multiplied = math_ops.multiply(x, y, name='{}multiplied'.format(prefix))
    difference = features['x'] - features['y']
    metrics = {
        'mean': metrics_lib.mean(difference, name='{}mean'.format(prefix)),
    }
    variables.Variable(1., name='later_var')
    variables.Variable(3., name='name_collision')
    loss = constant_op.constant(1.)
    train_op = state_ops.assign_add(training.get_global_step(), 1)
    return model_fn_lib.EstimatorSpec(
        mode,
        predictions=multiplied,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics)
Beispiel #15
0
  def _create_global_step(self, graph):
    """Creates a global step suitable for TPUs.

    The created step is a zero-initialized, non-trainable int32 scalar
    resource variable that is added to the GLOBAL_VARIABLES and GLOBAL_STEP
    collections.

    Args:
      graph: The graph in which to create the global step. A falsy value
        selects the current default graph.

    Returns:
      A global step `Tensor`.

    Raises:
      ValueError: if the global step tensor is already defined.
    """
    step_graph = graph or ops.get_default_graph()
    already_defined = training.get_global_step(step_graph) is not None
    if already_defined:
      raise ValueError('"global_step" already exists.')

    collections = [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]
    # Create in proper graph and base name_scope.
    with step_graph.as_default() as g, g.name_scope(None):
      return variable_scope.get_variable(
          ops.GraphKeys.GLOBAL_STEP,
          shape=[],
          dtype=dtypes.int32,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          use_resource=True,
          collections=collections)
Beispiel #16
0
 def model_fn(features, mode):
     """Returns a spec with constant loss/predictions and a step increment.

     Args:
       features: Unused.
       mode: Forwarded unchanged to the `EstimatorSpec`.

     Returns:
       An `EstimatorSpec` with loss `[5.]`, predictions `{'x': [5.]}` and a
       train op that adds 1 to the global step.
     """
     del features  # Unused.
     step_var = training.get_global_step()
     loss = constant_op.constant([5.])
     predictions = {'x': constant_op.constant([5.])}
     increment = step_var.assign_add(1)
     return estimator_lib.EstimatorSpec(
         mode, loss=loss, predictions=predictions, train_op=increment)
Beispiel #17
0
 def model_fn(features, mode):
   """Constant-output model function that advances the global step.

   Args:
     features: Unused.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` whose loss is the constant `[5.]`, whose predictions
     map 'x' to the constant `[5.]`, and whose train op increments the global
     step by one.
   """
   del features  # Unused.
   global_step_var = training.get_global_step()
   constant_loss = constant_op.constant([5.])
   constant_predictions = {'x': constant_op.constant([5.])}
   advance_step = global_step_var.assign_add(1)
   return estimator_lib.EstimatorSpec(
       mode,
       loss=constant_loss,
       predictions=constant_predictions,
       train_op=advance_step)
Beispiel #18
0
 def model_fn(features, labels, mode):
   """Uses element `[1][0]` of `features['x']` as the loss.

   Args:
     features: Mapping holding the input tensor under key 'x'.
     labels: Unused.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` with that loss and a train op that increments the
     global step.
   """
   del labels  # Unused.
   x = features['x']
   # The slice is taken under a control dependency on the whole tensor.
   with ops.control_dependencies([x]):
     loss = x[1][0]
   increment_step = state_ops.assign_add(training.get_global_step(), 1)
   return model_fn_lib.EstimatorSpec(mode, loss=loss, train_op=increment_step)
 def model_fn(features, labels, mode):
   """Model function whose loss is read out of the 'x' feature.

   Args:
     features: Mapping holding the input tensor under key 'x'.
     labels: Unused.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` whose loss is `features['x'][1][0]` and whose train
     op adds one to the global step.
   """
   del labels  # Unused.
   feature_x = features['x']
   # Index under a control dependency on the full 'x' tensor.
   with ops.control_dependencies([feature_x]):
     selected = feature_x[1][0]
   advance_step = state_ops.assign_add(training.get_global_step(), 1)
   return model_fn_lib.EstimatorSpec(
       mode, loss=selected, train_op=advance_step)
Beispiel #20
0
 def model_fn(features, labels, mode):
   """Returns a fully constant spec with a 'test' classification export.

   Args:
     features: Unused.
     labels: Unused.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` with loss `[103]`, predictions `[502]`, a train op
     that increments the global step, and a `ClassificationOutput` built from
     the constant `[[32.]]`.
   """
   del features, labels  # Unused.
   loss = constant_op.constant([103])
   train_op = state_ops.assign_add(training.get_global_step(), 1)
   predictions = constant_op.constant([502])
   classification = export_output.ClassificationOutput(
       constant_op.constant([[32.]]))
   return model_fn_lib.EstimatorSpec(
       mode,
       loss=loss,
       train_op=train_op,
       predictions=predictions,
       export_outputs={'test': classification})
 def model_fn(features, labels, mode):
   """Constant model function exporting a single classification output.

   Args:
     features: Unused.
     labels: Unused.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` holding the constant loss `[103]`, the constant
     predictions `[502]`, a global-step increment as train op, and a 'test'
     export output wrapping the constant `[[32.]]`.
   """
   del features, labels  # Unused.
   constant_loss = constant_op.constant([103])
   advance_step = state_ops.assign_add(training.get_global_step(), 1)
   constant_predictions = constant_op.constant([502])
   export_outputs = {
       'test': export_output.ClassificationOutput(
           constant_op.constant([[32.]])),
   }
   return model_fn_lib.EstimatorSpec(
       mode,
       loss=constant_loss,
       train_op=advance_step,
       predictions=constant_predictions,
       export_outputs=export_outputs)
 def _model_fn(features, labels, mode):
   """Checks the graph set-up, then returns an all-zero spec.

   `self` and `expected_random_seed` are taken from the enclosing scope. The
   function asserts that a global step exists and that the default graph's
   seed equals `expected_random_seed`.

   Args:
     features: Unused.
     labels: Unused.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` with zero-valued constants for loss, train op and
     predictions.
   """
   del features, labels  # Unused.
   self.assertIsNotNone(training.get_global_step())
   graph_seed = ops.get_default_graph().seed
   self.assertEqual(expected_random_seed, graph_seed)
   zero_loss = constant_op.constant(0.)
   zero_train_op = constant_op.constant(0.)
   zero_predictions = constant_op.constant([[0.]])
   return model_fn_lib.EstimatorSpec(
       mode=mode,
       loss=zero_loss,
       train_op=zero_train_op,
       predictions=zero_predictions)
Beispiel #23
0
 def model_fn(features, labels, mode):
   """Echoes entries of `features['x']` and sums two label entries as loss.

   Args:
     features: Mapping holding the input tensor under key 'x'.
     labels: Optional labels. When given, `labels[0][0] + labels[1][0]` is
       the loss; otherwise the loss is `None`.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` whose predictions expose `x[0][0]` and `x[1][0]` as
     'features_0' and 'features_1', with a global-step increment as train op.
   """
   loss = None if labels is None else labels[0][0] + labels[1][0]
   train_op = state_ops.assign_add(training.get_global_step(), 1)
   x = features['x']
   predictions = {
       'features_0': array_ops.identity([x[0][0]]),
       'features_1': array_ops.identity([x[1][0]]),
   }
   return model_fn_lib.EstimatorSpec(
       mode, loss=loss, train_op=train_op, predictions=predictions)
Beispiel #24
0
 def _model_fn_with_incremental_loss(features, labels, mode):
   """Model function whose loss grows each time it is evaluated.

   Args:
     features: Unused.
     labels: Unused.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` whose loss is twice a local variable that is
     incremented by 1. on each evaluation, and whose train op increments the
     global step.
   """
   del features, labels  # Unused.
   local_weight = variables.Variable(
       0.,
       name='local_weight',
       collections=[ops.GraphKeys.LOCAL_VARIABLES])
   incremented_weight = state_ops.assign_add(local_weight, 1.)
   # Loss will be 2, 4, 6, ...
   loss = 2 * incremented_weight
   increment_step = state_ops.assign_add(training.get_global_step(), 1)
   return model_fn_lib.EstimatorSpec(mode, loss=loss, train_op=increment_step)
Beispiel #25
0
 def _model_fn_with_incremental_loss(features, labels, mode):
   """Returns a spec whose loss increases by 2 on every evaluation.

   Args:
     features: Unused.
     labels: Unused.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec`. The loss doubles the value of a LOCAL_VARIABLES
     variable right after adding 1. to it; the train op adds one to the
     global step.
   """
   del features, labels  # Unused.
   weight_collections = [ops.GraphKeys.LOCAL_VARIABLES]
   weight = variables.Variable(
       0., name='local_weight', collections=weight_collections)
   # Loss will be 2, 4, 6, ...
   growing_loss = 2 * state_ops.assign_add(weight, 1.)
   advance_step = state_ops.assign_add(training.get_global_step(), 1)
   return model_fn_lib.EstimatorSpec(
       mode, loss=growing_loss, train_op=advance_step)
 def model_fn(features, labels, mode):
   """Exposes two entries of `features['x']` as predictions.

   Args:
     features: Mapping holding the input tensor under key 'x'.
     labels: Optional labels; when present the loss is
       `labels[0][0] + labels[1][0]`, otherwise it stays `None`.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` with predictions 'features_0' (`x[0][0]`) and
     'features_1' (`x[1][0]`) and a train op incrementing the global step.
   """
   if labels is None:
     loss = None
   else:
     loss = labels[0][0] + labels[1][0]
   advance_step = state_ops.assign_add(training.get_global_step(), 1)
   feature_x = features['x']
   first_entry = array_ops.identity([feature_x[0][0]])
   second_entry = array_ops.identity([feature_x[1][0]])
   return model_fn_lib.EstimatorSpec(
       mode,
       loss=loss,
       train_op=advance_step,
       predictions={'features_0': first_entry, 'features_1': second_entry})
 def model_fn(features, labels, mode):
   """Model function with a scaffold whose local init op adjusts `some_var`.

   Args:
     features: Unused.
     labels: Unused.
     mode: Forwarded unchanged to the `EstimatorSpec`.

   Returns:
     An `EstimatorSpec` whose loss is the value of `some_var` (initial value
     21), whose scaffold's `local_init_op` adds -3 to that variable, and
     whose train op increments the global step.
   """
   del features, labels  # Unused.
   some_var = variables.Variable(21, name='some_var')
   subtract_three = state_ops.assign_add(some_var, -3).op
   scaffold = monitored_session.Scaffold(local_init_op=subtract_three)
   increment_step = state_ops.assign_add(training.get_global_step(), 1)
   loss = array_ops.identity(some_var)
   return model_fn_lib.EstimatorSpec(
       mode, scaffold=scaffold, train_op=increment_step, loss=loss)
 def _create_and_assert_global_step(self, graph):
   """Creates the global step and sanity-checks it.

   Args:
     graph: The graph in which to create the global step tensor.

   Returns:
     The global step `Tensor`.
   """
   global_step = self._create_global_step(graph)
   # The new step must be what `training.get_global_step()` reports, and it
   # must have an integer dtype.
   assert global_step == training.get_global_step()
   assert global_step.dtype.is_integer
   return global_step
Beispiel #29
0
 def _create_head(self, mode, loss, eval_metrics):
   """Creates a head returning `TPUEstimatorSpec` based on mode.

   Args:
     mode: The mode; `_EVAL` selects the evaluation spec, anything else the
       training spec.
     loss: Loss tensor placed in the spec and minimized when training.
     eval_metrics: Passed as `eval_metrics` of the evaluation spec.

   Returns:
     A `TPUEstimatorSpec`.
   """
   if mode == _EVAL:
     eval_spec = tpu_estimator.TPUEstimatorSpec(
         mode=mode, eval_metrics=eval_metrics, loss=loss)
     return eval_spec

   # Train: gradient descent wrapped in a cross-shard optimizer.
   sgd = training.GradientDescentOptimizer(learning_rate=0.5)
   sharded_optimizer = tf.tpu.CrossShardOptimizer(sgd)
   minimize_op = sharded_optimizer.minimize(
       loss, global_step=training.get_global_step())
   return tpu_estimator.TPUEstimatorSpec(
       mode=mode, train_op=minimize_op, loss=loss)
Beispiel #30
0
 def model_fn(features, labels, mode):
     """Returns a spec whose scaffold subtracts 3 from `some_var` on init.

     Args:
       features: Unused.
       labels: Unused.
       mode: Forwarded unchanged to the `EstimatorSpec`.

     Returns:
       An `EstimatorSpec`. The loss reads variable `some_var` (initial value
       21); the scaffold's `local_init_op` adds -3 to it; the train op
       increments the global step.
     """
     del features, labels  # Unused.
     tracked_var = variables.Variable(21, name='some_var')
     local_init_op = state_ops.assign_add(tracked_var, -3).op
     scaffold = monitored_session.Scaffold(local_init_op=local_init_op)
     advance_step = state_ops.assign_add(training.get_global_step(), 1)
     var_value = array_ops.identity(tracked_var)
     return model_fn_lib.EstimatorSpec(
         mode, scaffold=scaffold, train_op=advance_step, loss=var_value)
Beispiel #31
0
 def model_fn(features, labels, mode):
   """Model function with constant losses and a mean-of-features metric.

   Args:
     features: Tensor consumed as a control dependency in TRAIN and averaged
       in EVAL.
     labels: Unused.
     mode: An `estimator_lib.ModeKeys` value.

   Returns:
     For TRAIN, an `EstimatorSpec` with loss 3. and a global-step increment
     as train op. For EVAL, an `EstimatorSpec` with loss 5. and a
     'mean_of_features' metric. `None` for any other mode.
   """
   del labels  # Unused.
   if mode == estimator_lib.ModeKeys.TRAIN:
     # The step increment waits on `features` so that the input is consumed.
     with ops.control_dependencies([features]):
       increment_step = state_ops.assign_add(training.get_global_step(), 1)
     train_loss = constant_op.constant(3.)
     return estimator_lib.EstimatorSpec(
         mode, loss=train_loss, train_op=increment_step)

   if mode == estimator_lib.ModeKeys.EVAL:
     eval_loss = constant_op.constant(5.)
     metric_ops = {'mean_of_features': metrics_lib.mean(features)}
     return estimator_lib.EstimatorSpec(
         mode, loss=eval_loss, eval_metric_ops=metric_ops)
 def model_fn(features, labels, mode):
   """Returns TRAIN/EVAL specs that consume `features`.

   Args:
     features: Tensor used as a control dependency when training and as the
       input of the mean metric when evaluating.
     labels: Unused.
     mode: An `estimator_lib.ModeKeys` value.

   Returns:
     An `EstimatorSpec` with constant loss 3. and a step-increment train op
     for TRAIN, or with constant loss 5. and the 'mean_of_features' metric
     for EVAL. Other modes fall through and yield `None`.
   """
   del labels  # Unused.
   is_train = mode == estimator_lib.ModeKeys.TRAIN
   if is_train:
     # Make the train op depend on `features` so the input is consumed.
     with ops.control_dependencies([features]):
       advance_step = state_ops.assign_add(training.get_global_step(), 1)
     return estimator_lib.EstimatorSpec(
         mode, loss=constant_op.constant(3.), train_op=advance_step)

   is_eval = mode == estimator_lib.ModeKeys.EVAL
   if is_eval:
     loss = constant_op.constant(5.)
     features_mean = metrics_lib.mean(features)
     return estimator_lib.EstimatorSpec(
         mode, loss=loss, eval_metric_ops={'mean_of_features': features_mean})
Beispiel #33
0
  def _create_and_assert_global_step(self, graph):
    """Creates the global step and verifies its basic properties.

    Args:
      graph: The graph in which to create the global step tensor.

    Returns:
      The global step `Tensor`.
    """
    created_step = self._create_global_step(graph)
    # The created step has to be the one `training.get_global_step()` finds.
    assert created_step == training.get_global_step()
    # A global step must have an integer dtype.
    assert created_step.dtype.is_integer
    return created_step
Beispiel #34
0
 def model_fn(features, labels, mode):
     """Model function backed by a mutable int32 -> int32 hash table.

     Predictions are the table lookups for `features['x']` (default -1). In
     TRAIN mode the train op inserts `features['x'] -> labels` into the
     table and increments the global step; otherwise there is no train op.

     Args:
       features: Mapping holding the lookup keys under 'x'.
       labels: Values inserted for the keys in TRAIN mode.
       mode: A `ModeKeys` value.

     Returns:
       An `EstimatorSpec` with a constant loss of 0.
     """
     table = lookup_ops.MutableHashTable(
         key_dtype=dtypes.int32, value_dtype=dtypes.int32, default_value=-1)
     predictions = table.lookup(features['x'])

     if mode == ModeKeys.TRAIN:
         insert_op = table.insert(features['x'], labels)
         increment_step = state_ops.assign_add(
             training.get_global_step(), 1)
         train_op = control_flow_ops.group(insert_op, increment_step)
     else:
         train_op = None

     zero_loss = constant_op.constant(0)
     return model_fn_lib.EstimatorSpec(
         mode, loss=zero_loss, predictions=predictions, train_op=train_op)
    def model_fn(features, labels, mode, params):
        """Dense regression model with optional TPU inference for PREDICT.

        `features['x']` is doubled and fed through a single-unit dense
        layer. In PREDICT mode the layer runs through
        `tpu_estimator.inference_on_tpu` when `params['use_tpu']` is truthy
        (directly otherwise), and both a default predict signature and a
        'classification' signature are exported. In every other mode the
        mean squared error loss and a cross-shard gradient-descent train op
        are built instead.

        Args:
            features: Mapping holding the input tensor under key 'x'.
            labels: Regression targets; read only outside PREDICT mode.
            mode: The mode; compared against `_PREDICT`.
            params: Mapping; `params['use_tpu']` is read in PREDICT mode.

        Returns:
            A `TPUEstimatorSpec` with predictions under key 'predictions'.
        """
        # This could be some pre-processing on CPU like calls to input layer
        # with embedding columns.
        doubled_x = features['x'] * 2

        def computation(input_tensor):
            zeros = init_ops.zeros_initializer()
            return layers.dense(input_tensor, 1, kernel_initializer=zeros)

        if mode == _PREDICT:
            loss = None
            train_op = None
            inputs = [doubled_x]
            if params['use_tpu']:
                tpu_result = tpu_estimator.inference_on_tpu(
                    computation,
                    inputs,
                    num_batch_threads=1,
                    max_batch_size=2,
                    batch_timeout_micros=100)
                predictions = array_ops.identity(tpu_result,
                                                 name='predictions')
            else:
                predictions = array_ops.identity(computation(*inputs),
                                                 name='predictions')

            predict_output = export_lib.PredictOutput(
                {'prediction': predictions})
            classes = string_ops.as_string(predictions, name='classes')
            default_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
            export_outputs = {
                default_key: predict_output,
                'classification': export_lib.ClassificationOutput(
                    classes=classes),
            }
        else:
            export_outputs = None
            predictions = computation(doubled_x)
            loss = losses.mean_squared_error(labels, predictions)
            sgd = training.GradientDescentOptimizer(learning_rate=0.5)
            optimizer = tf.tpu.CrossShardOptimizer(sgd)
            train_op = optimizer.minimize(loss, training.get_global_step())

        return tpu_estimator.TPUEstimatorSpec(
            mode,
            loss=loss,
            train_op=train_op,
            predictions={'predictions': predictions},
            export_outputs=export_outputs)
    def _model_fn(features, labels, mode, config, params=None):
        """Runs the wrapped `model_fn` on TPU shards for TRAIN mode.

        Any mode other than TRAIN is delegated to
        `_call_model_fn_without_tpu`. For TRAIN, the per-shard features and
        labels are turned into infeed enqueue ops plus a dequeue function,
        the wrapped `model_fn` is converted into a train step that runs on
        the TPU shards, and the resulting loss is returned together with
        infeed and logging hooks.

        Args:
          features: In TRAIN mode, a `_PerShardOutput` holding the features.
          labels: `None`, or in TRAIN mode a `_PerShardOutput` holding the
            labels.
          mode: A `model_fn_lib.ModeKeys` value.
          config: Run configuration; `config.tpu_config.num_shards` is read.
          params: Optional mapping of parameters. If it contains
            `_BATCH_SIZE_KEY`, that entry is divided by the number of shards.

        Returns:
          The result of `_call_model_fn_without_tpu` for non-TRAIN modes,
          otherwise an `EstimatorSpec` with the loss, the training hooks and
          a train op that groups numeric checks of all trainable variables.
        """

        # TODO(jhseu): Move to EVAL and PREDICT to TPU.
        if mode != model_fn_lib.ModeKeys.TRAIN:
            return _call_model_fn_without_tpu(model_fn, features, labels, mode,
                                              config, params)

        # Now for TPU training.
        # NOTE(review): this updates the caller's `params` mapping in place, so
        # calling this function again with the same mapping divides the batch
        # size again -- confirm callers pass a fresh copy.
        if params is not None and _BATCH_SIZE_KEY in params:
            params[_BATCH_SIZE_KEY] //= config.tpu_config.num_shards

        # Unpack the per-shard containers into plain lists.
        assert isinstance(features, _PerShardOutput)
        features = features.as_list()
        if labels is not None:
            assert isinstance(labels, _PerShardOutput)
            labels = labels.as_list()

        dequeue_fn, enqueue_fn = (_create_infeed_enqueue_ops_and_dequeue_fn(
            config, features, labels))

        loss = _train_on_tpu_shards(config,
                                    train_step=_convert_model_fn_to_train_step(
                                        model_fn, dequeue_fn, mode, config,
                                        params))

        # Gets the variables back from TPU nodes. This means the variables updated
        # by TPU will now be *synced* to host memory.
        # NOTE(review): the tensor being checked is each variable's value, while
        # the error message says "Gradient" -- confirm the wording is intended.
        update_ops = [
            array_ops.check_numerics(v.read_value(),
                                     'Gradient for %s is NaN' % v.name).op
            for v in variables.trainable_variables()
        ]

        # The infeed hook receives the enqueue function; the logging hook
        # reports loss and global step every 30 seconds.
        hooks = [
            TpuInfeedSessionHook(config, enqueue_fn),
            training.LoggingTensorHook(
                {
                    'loss': array_ops.identity(loss),
                    'step': training.get_global_step()
                },
                every_n_secs=30)
        ]

        return model_fn_lib.EstimatorSpec(
            mode,
            loss=array_ops.identity(loss),
            training_hooks=hooks,
            train_op=control_flow_ops.group(*update_ops))
def _model_fn(features, labels, mode):
    """Returns a small arithmetic spec for each of PREDICT, EVAL and TRAIN.

    Args:
        features: Operand of the prediction arithmetic.
        labels: Second operand for EVAL (`features + labels`) and TRAIN
            (`features * labels`); not read in PREDICT.
        mode: A `ModeKeys` value.

    Returns:
        An `EstimatorSpec` for the matching mode. PREDICT predicts
        `features + 1`; EVAL and TRAIN use the constant 5 plus variable `v`
        (initialized to 23) as loss, and TRAIN also increments the global
        step. Any other mode yields `None`.
    """
    v = variables.Variable(constant_op.constant(23), name='v')

    if mode == ModeKeys.PREDICT:
        return model_fn_lib.EstimatorSpec(
            ModeKeys.PREDICT, predictions=features + 1)

    if mode == ModeKeys.EVAL:
        eval_loss = constant_op.constant(5) + v
        eval_predictions = features + labels
        return model_fn_lib.EstimatorSpec(
            ModeKeys.EVAL, loss=eval_loss, predictions=eval_predictions)

    if mode == ModeKeys.TRAIN:
        train_predictions = features * labels
        train_loss = constant_op.constant(5) + v
        increment_step = state_ops.assign_add(training.get_global_step(), 1)
        return model_fn_lib.EstimatorSpec(
            ModeKeys.TRAIN,
            predictions=train_predictions,
            loss=train_loss,
            train_op=increment_step)
Beispiel #38
0
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
  """Model function reporting one configurable constant metric.

  Args:
    features: Unused.
    labels: Unused.
    mode: Forwarded unchanged to the `EstimatorSpec`.
    params: Mapping read via `.get`. A falsy or missing 'metric_name' falls
      back to 'metric'; a falsy or missing 'metric_value' falls back to 2.

  Returns:
    An `EstimatorSpec` with constant loss and predictions of 1., a train op
    that increments the global step, and a single eval metric whose update
    op is the loss op.
  """
  del features, labels  # Unused.
  name = params.get('metric_name') or 'metric'
  value = params.get('metric_value') or 2.
  step_var = training.get_global_step()
  loss = constant_op.constant(1.)
  # The loss op doubles as the metric's update op.
  update_op = loss.op
  value_tensor = constant_op.constant(value)
  metric_tensor = control_flow_ops.with_dependencies(
      [update_op], value_tensor)
  predictions = {'predictions': constant_op.constant(1.)}
  increment_step = state_ops.assign_add(step_var, 1)
  return model_fn_lib.EstimatorSpec(
      mode,
      loss=loss,
      predictions=predictions,
      train_op=increment_step,
      eval_metric_ops={name: (metric_tensor, update_op)})
Beispiel #39
0
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
  """Returns a spec carrying a single metric configured through `params`.

  Args:
    features: Unused.
    labels: Unused.
    mode: Forwarded unchanged to the `EstimatorSpec`.
    params: Mapping read via `.get`. 'metric_name' defaults to 'metric' and
      'metric_value' defaults to 2. whenever the entry is missing or falsy.

  Returns:
    An `EstimatorSpec` whose loss and predictions are the constant 1., whose
    train op adds one to the global step, and whose eval metric pairs a
    constant value tensor with the loss op as update op.
  """
  del features, labels  # Unused.
  key = params.get('metric_name') or 'metric'
  constant_value = params.get('metric_value') or 2.
  global_step_var = training.get_global_step()
  unit_loss = constant_op.constant(1.)
  # Reuse the loss op as the update op of the metric.
  metric_update = unit_loss.op
  metric_value_tensor = control_flow_ops.with_dependencies(
      [metric_update], constant_op.constant(constant_value))
  metric_ops = {key: (metric_value_tensor, metric_update)}
  return model_fn_lib.EstimatorSpec(
      mode,
      loss=unit_loss,
      predictions={'predictions': constant_op.constant(1.)},
      train_op=state_ops.assign_add(global_step_var, 1),
      eval_metric_ops=metric_ops)
Beispiel #40
0
def _model_fn_for_export_tests(features, labels, mode):
  """Builds a constant-valued `EstimatorSpec` with a classification export.

  Args:
    features: Ignored.
    labels: Ignored.
    mode: Forwarded to the returned `EstimatorSpec`.

  Returns:
    A `model_fn_lib.EstimatorSpec` with constant predictions and loss, a train
    op gated on a global-step increment, and one export output under the key
    'test'.
  """
  del features, labels  # Unused.
  # The variable handle itself is not needed; only its creation matters here.
  variables.Variable(1., name='weight')
  score_tensor = constant_op.constant([3.])
  class_tensor = constant_op.constant(['wumpus'])

  step_increment = state_ops.assign_add(training.get_global_step(), 1)
  # The train op is created under a control dependency on the step increment.
  with ops.control_dependencies([step_increment]):
    train_op = constant_op.constant(2.)

  predictions = constant_op.constant(10.)
  loss = constant_op.constant(1.)
  export_outputs = {
      'test': export_output.ClassificationOutput(score_tensor, class_tensor),
  }
  return model_fn_lib.EstimatorSpec(
      mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      export_outputs=export_outputs)
 def model_fn(features, labels, mode):
   """Returns an `EstimatorSpec` that supplies its own `Scaffold` init op.

   A non-trainable variable is created in the SAVEABLE_OBJECTS collection only,
   and the scaffold's init op groups that variable's initializer with the
   global step's initializer.

   Args:
     features: Ignored.
     labels: Ignored.
     mode: Forwarded to the returned `EstimatorSpec`.

   Returns:
     An `estimator_lib.EstimatorSpec` with constant loss and train op, the
     custom scaffold, and a 'mean_of_features' metric over a constant.
   """
   del features, labels  # Unused.
   saveable_var = variables.Variable(
       initial_value=[0.],
       trainable=False,
       collections=[ops.GraphKeys.SAVEABLE_OBJECTS])
   initializers = [
       saveable_var.initializer,
       training.get_global_step().initializer,
   ]
   init_op = control_flow_ops.group(initializers)

   loss = constant_op.constant(3.)
   scaffold = training.Scaffold(init_op=init_op)
   train_op = constant_op.constant(5.)
   metric = metrics_lib.mean(constant_op.constant(2.))
   return estimator_lib.EstimatorSpec(
       mode,
       loss=loss,
       scaffold=scaffold,
       train_op=train_op,
       eval_metric_ops={'mean_of_features': metric})
Beispiel #42
0
 def model_fn(features, labels, mode):
   """Builds an `EstimatorSpec` whose `Scaffold` carries a custom init op.

   The `VariableV1` created here is registered only in SAVEABLE_OBJECTS; the
   init op groups its initializer together with the global step's initializer.

   Args:
     features: Ignored.
     labels: Ignored.
     mode: Forwarded to the returned `EstimatorSpec`.

   Returns:
     An `estimator_lib.EstimatorSpec` with constant loss/train op, the custom
     scaffold, and a single 'mean_of_features' eval metric.
   """
   _ = (features, labels)  # Inputs are intentionally unused.
   w = variables.VariableV1(
       initial_value=[0.],
       trainable=False,
       collections=[ops.GraphKeys.SAVEABLE_OBJECTS])
   w_init = w.initializer
   step_init = training.get_global_step().initializer
   combined_init = control_flow_ops.group([w_init, step_init])

   spec_loss = constant_op.constant(3.)
   spec_scaffold = training.Scaffold(init_op=combined_init)
   spec_train_op = constant_op.constant(5.)
   eval_metrics = {
       'mean_of_features': metrics_lib.mean(constant_op.constant(2.)),
   }
   return estimator_lib.EstimatorSpec(
       mode,
       loss=spec_loss,
       scaffold=spec_scaffold,
       train_op=spec_train_op,
       eval_metric_ops=eval_metrics)
Beispiel #43
0
  def _model_fn(features, labels, mode, config, params=None):
    """Runs the enclosing `model_fn` on TPU shards for TRAIN, off-TPU otherwise.

    Args:
      features: For TRAIN, must be a `_PerShardOutput`.
      labels: For TRAIN, either `None` or a `_PerShardOutput`.
      mode: Estimator mode key; only TRAIN takes the TPU path.
      config: Run configuration; `config.tpu_config.num_shards` is read to
        scale the batch size.
      params: Optional dict. When it contains `_BATCH_SIZE_KEY`, that entry is
        floor-divided in place by the number of shards.

    Returns:
      For non-TRAIN modes, the result of `_call_model_fn_without_tpu`. For
      TRAIN, a `model_fn_lib.EstimatorSpec` with the loss, an infeed hook, a
      logging hook, and a train op grouping one `check_numerics` op per
      trainable variable.
    """
    # TODO(jhseu): Move to EVAL and PREDICT to TPU.
    if mode != model_fn_lib.ModeKeys.TRAIN:
      return _call_model_fn_without_tpu(
          model_fn, features, labels, mode, config, params)

    # Everything below is the TPU training path.
    if params is not None and _BATCH_SIZE_KEY in params:
      params[_BATCH_SIZE_KEY] //= config.tpu_config.num_shards

    assert isinstance(features, _PerShardOutput)
    features = features.as_list()
    if labels is not None:
      assert isinstance(labels, _PerShardOutput)
      labels = labels.as_list()

    infeed_fns = _create_infeed_enqueue_ops_and_dequeue_fn(
        config, features, labels)
    dequeue_fn, enqueue_fn = infeed_fns

    train_step = _convert_model_fn_to_train_step(
        model_fn, dequeue_fn, mode, config, params)
    loss = _train_on_tpu_shards(config, train_step=train_step)

    # Gets the variables back from TPU nodes. This means the variables updated
    # by TPU will now be *synced* to host memory.
    update_ops = []
    for var in variables.trainable_variables():
      checked = array_ops.check_numerics(
          var.read_value(), 'Gradient for %s is NaN' % var.name)
      update_ops.append(checked.op)

    infeed_hook = TpuInfeedSessionHook(config, enqueue_fn)
    logged_tensors = {
        'loss': array_ops.identity(loss),
        'step': training.get_global_step(),
    }
    logging_hook = training.LoggingTensorHook(logged_tensors, every_n_secs=30)

    spec_loss = array_ops.identity(loss)
    train_op = control_flow_ops.group(*update_ops)
    return model_fn_lib.EstimatorSpec(
        mode,
        loss=spec_loss,
        training_hooks=[infeed_hook, logging_hook],
        train_op=train_op)
        def _model_fn(features, labels, mode, params):
            """Model function that asserts the batch size it is handed.

            Outside export mode, `params['batch_size']` must equal
            `batch_size_dict[mode]`; in export mode `params` must not contain
            'batch_size' at all. The model is one dense unit with
            ones-initialized kernel, trained with mean squared error.

            Args:
              features: Mapping with an 'x' entry used as the dense input.
              labels: Regression targets; not used in predict mode.
              mode: Mode key, also used to index `batch_size_dict`.
              params: Mapping checked for the 'batch_size' entry.

            Returns:
              Whatever `_create_estimator_spec` returns for the given mode.
            """
            if self._export_mode:
                self.assertNotIn('batch_size', params)
            else:
                # Always check batch size in params
                self.assertEqual(batch_size_dict[mode], params['batch_size'])

            # Check the input feeds correct shape for train and eval. When eval
            # on CPU or predict, it is allowed to have dynamic shape. So, here
            # only validates the fully known shape (which covers the TPU train).
            x = features['x']
            if x.shape.is_fully_defined():
                self.assertEqual(batch_size_dict[mode], x.shape[0])

            predictions = layers.dense(
                x, 1, kernel_initializer=init_ops.ones_initializer())
            prediction_dict = {'predictions': predictions}
            export_outputs = {
                'predictions': export_output.RegressionOutput(predictions)
            }

            if mode == _PREDICT:
                return _create_estimator_spec(
                    mode=mode,
                    predictions=prediction_dict,
                    export_outputs=export_outputs)

            loss = losses.mean_squared_error(labels, predictions)

            base_optimizer = training.GradientDescentOptimizer(
                learning_rate=0.5)
            optimizer = tf.tpu.CrossShardOptimizer(base_optimizer)
            train_op = optimizer.minimize(
                loss, global_step=training.get_global_step())

            def _metric_fn(labels, predictions):
                return {
                    'absolute_error':
                    metrics_lib.mean_absolute_error(labels, predictions)
                }

            # (metric function, tensors passed to it) pair.
            eval_metrics = (_metric_fn, [labels, predictions])
            return _create_estimator_spec(
                mode=mode,
                loss=loss,
                predictions=prediction_dict,
                export_outputs=export_outputs,
                train_op=train_op,
                eval_metrics=eval_metrics)
def _model_fn_callable_variable_initializers(features, labels, mode):
    """Model function whose variable takes a callable initial value.

    Used for WrappedGraph tests. A variable named 'v' is created from a lambda
    returning the constant 23, and the returned spec depends on `mode`.

    Args:
        features: Tensor-like input combined arithmetically with `labels`.
        labels: Tensor-like labels; only used in EVAL and TRAIN.
        mode: One of the `ModeKeys` values.

    Returns:
        A `model_fn_lib.EstimatorSpec` for PREDICT, EVAL or TRAIN; `None` for
        any other mode.
    """
    _ = (features, labels)
    v = variables.Variable(lambda: constant_op.constant(23), name='v')

    if mode == ModeKeys.PREDICT:
        return model_fn_lib.EstimatorSpec(
            ModeKeys.PREDICT, predictions=features + 1)

    if mode == ModeKeys.EVAL:
        eval_loss = constant_op.constant(5) + v
        eval_predictions = features + labels
        return model_fn_lib.EstimatorSpec(
            ModeKeys.EVAL, loss=eval_loss, predictions=eval_predictions)

    if mode == ModeKeys.TRAIN:
        train_predictions = features * labels
        train_loss = constant_op.constant(5) + v
        step_increment = state_ops.assign_add(training.get_global_step(), 1)
        return model_fn_lib.EstimatorSpec(
            ModeKeys.TRAIN,
            predictions=train_predictions,
            loss=train_loss,
            train_op=step_increment)

    return None
def model_fn_global_step_incrementer(features, labels, mode, params):
    """Builds a `TPUEstimatorSpec` around `dense_computation(features)`.

    Args:
        features: Input forwarded to `dense_computation`.
        labels: Targets for the mean squared error loss; unused in predict mode.
        mode: Mode key; in predict mode no loss or train op is created.
        params: Ignored.

    Returns:
        A `tpu_estimator.TPUEstimatorSpec` with predictions and a 'test' export
        output, plus a loss and train op for every mode other than predict.
    """
    del params  # Unused.
    predictions = dense_computation(features)

    if mode == _PREDICT:
        loss = None
        train_op = None
    else:
        loss = losses.mean_squared_error(labels, predictions)
        sgd = training.GradientDescentOptimizer(learning_rate=0.5)
        optimizer = tf.tpu.CrossShardOptimizer(sgd)
        # Passing the global step makes `minimize` advance it on each update.
        train_op = optimizer.minimize(loss, training.get_global_step())

    export_outputs = {
        'test': export_output.PredictOutput({'prediction': predictions})
    }
    return tpu_estimator.TPUEstimatorSpec(
        mode,
        loss=loss,
        train_op=train_op,
        predictions={'predictions': predictions},
        export_outputs=export_outputs)
Beispiel #47
0
    def _build_train_op(self, loss):
        """Creates the training operation.

        When `self.use_target_graph` is set, the parent's train op is grouped
        with a conditional op that calls `_build_update_target_graph` whenever
        the global step is a multiple of `self.target_update_frequency`, and
        is a no-op otherwise.

        Args:
            loss: Loss passed to the parent class's `_build_train_op`.

        Returns:
            The parent's train op, or the grouped op described above.
        """
        base_train_op = super(BaseQModel, self)._build_train_op(loss)

        # Nothing to append when no target graph is in use.
        if not self.use_target_graph:
            return base_train_op

        step_remainder = tf.mod(
            training.get_global_step(), self.target_update_frequency)
        is_update_step = tf.equal(step_remainder, 0)
        update_op = tf.cond(
            is_update_step,
            self._build_update_target_graph,
            lambda: tf.no_op(name='no_op_copy_target'))

        # Append the target update op to the train op.
        return tf.group(
            base_train_op, update_op, name='train_and_update_target')
Beispiel #48
0
  def _model_fn(features, labels, mode, config, params):
    """An Estimator `model_fn` for TPUEstimator.

    Args:
      features: Input features handed to the wrapped model function.
      labels: Input labels handed to the wrapped model function.
      mode: Estimator mode key; only TRAIN with `use_tpu` takes the TPU path.
      config: Run configuration; `config.tpu_config.num_shards` is read.
      params: Parameters forwarded to `_ModelFnWrapper`.

    Returns:
      The result of `call_without_tpu` when not training on TPU; otherwise a
      `model_fn_lib.EstimatorSpec` with the loss, an infeed hook, a logging
      hook, and a train op grouping one `check_numerics` op per trainable
      variable.
    """
    wrapper = _ModelFnWrapper(model_fn, config, params, mode,
                              train_batch_size)

    # TODO(jhseu): Move to EVAL and PREDICT to TPU.
    if not use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
      return wrapper.call_without_tpu(features, labels)

    inputs = _InputsHolder(
        features=features,
        labels=labels,
        num_shards=config.tpu_config.num_shards)
    infeed_fns = _create_infeed_enqueue_ops_and_dequeue_fn(inputs, config)
    dequeue_fn, enqueue_fn = infeed_fns

    loss = _train_on_tpu_system(wrapper, dequeue_fn)

    # Gets the variables back from TPU nodes. This means the variables updated
    # by TPU will now be *synced* to host memory.
    update_ops = []
    for var in variables.trainable_variables():
      checked = array_ops.check_numerics(
          var.read_value(), 'Gradient for %s is NaN' % var.name)
      update_ops.append(checked.op)

    infeed_hook = TPUInfeedSessionHook(config, enqueue_fn)
    logged_tensors = {
        'loss': array_ops.identity(loss),
        'step': training.get_global_step(),
    }
    logging_hook = training.LoggingTensorHook(logged_tensors, every_n_secs=30)

    spec_loss = array_ops.identity(loss)
    train_op = control_flow_ops.group(*update_ops)
    return model_fn_lib.EstimatorSpec(
        mode,
        loss=spec_loss,
        training_hooks=[infeed_hook, logging_hook],
        train_op=train_op)
 def get_grad_multiplier(self):
   """Returns the gradient multiplier for the current global step.

   Returns:
     `None` when `self._grad_multiplier_fn` is unset or falsy; otherwise the
     result of calling it with the global step, converted to a float32 tensor.
   """
   if not self._grad_multiplier_fn:
     return None
   step = training.get_global_step()
   multiplier = self._grad_multiplier_fn(step)
   return ops.convert_to_tensor(multiplier, dtype=dtypes.float32)
Beispiel #50
0
def optimize_loss(loss,
                  global_step,
                  learning_rate,
                  optimizer,
                  gradient_noise_scale=None,
                  gradient_multipliers=None,
                  clip_gradients=None,
                  learning_rate_decay_fn=None,
                  update_ops=None,
                  variables=None,
                  name=None,
                  summaries=None,
                  colocate_gradients_with_ops=False,
                  increment_global_step=True):
  """Given loss and parameters for optimizer, returns a training op.

  Various ways of passing optimizers include:

  - by string specifying the name of the optimizer. See OPTIMIZER_CLS_NAMES
      for full list. E.g. `optimize_loss(..., optimizer='Adam')`.
  - by function taking learning rate `Tensor` as argument and returning an
      `Optimizer` instance. E.g. `optimize_loss(...,
      optimizer=lambda lr: tf.train.MomentumOptimizer(lr, momentum=0.5))`.
    Alternatively, if `learning_rate` is `None`, the function takes no
    arguments. E.g. `optimize_loss(..., learning_rate=None,
      optimizer=lambda: tf.train.MomentumOptimizer(0.5, momentum=0.5))`.
  - by a subclass of `Optimizer` having a single-argument constructor
      (the argument is the learning rate), such as AdamOptimizer or
      AdagradOptimizer. E.g. `optimize_loss(...,
      optimizer=tf.train.AdagradOptimizer)`.
  - by an instance of a subclass of `Optimizer`.
      E.g., `optimize_loss(..., optimizer=tf.train.AdagradOptimizer(0.5))`.

  Args:
    loss: Scalar `Tensor`.
    global_step: Scalar int `Tensor`, step counter to update on each step
                 unless `increment_global_step` is `False`. If not supplied,
                 it will be fetched from the default graph (see
                 `tf.train.get_global_step` for details). If it has
                 not been created, no step will be incremented with each weight
                 update. `learning_rate_decay_fn` requires `global_step`.
    learning_rate: float or `Tensor`, magnitude of update per each training
                   step. Can be `None`.
    optimizer: string, class or optimizer instance, used as trainer.
               string should be name of optimizer, like 'SGD',
                 'Adam', 'Adagrad'. Full list in OPTIMIZER_CLS_NAMES constant.
               class should be sub-class of `tf.Optimizer` that implements
                 `compute_gradients` and `apply_gradients` functions.
               optimizer instance should be instantiation of `tf.Optimizer`
                 sub-class and have `compute_gradients` and `apply_gradients`
                 functions.
    gradient_noise_scale: float or None, adds 0-mean normal noise scaled by this
                          value.
    gradient_multipliers: dict of variables or variable names to floats.
                          If present, gradients for specified
                          variables will be multiplied by given constant.
    clip_gradients: float, callable or `None`. If float, is provided, a global
      clipping is applied to prevent the norm of the gradient to exceed this
      value. Alternatively, a callable can be provided e.g.: adaptive_clipping.
      This callable takes a `list` of `(gradients, variables)` `tuple`s and
      returns the same thing with the gradients modified.
    learning_rate_decay_fn: function, takes `learning_rate` and `global_step`
                            `Tensor`s, returns `Tensor`.
                            Can be used to implement any learning rate decay
                            functions.
                            For example: `tf.train.exponential_decay`.
                            Ignored if `learning_rate` is not supplied.
    update_ops: list of update `Operation`s to execute at each step. If `None`,
                uses elements of UPDATE_OPS collection. The order of execution
                between `update_ops` and `loss` is non-deterministic.
    variables: list of variables to optimize or
               `None` to use all trainable variables.
    name: The name for this operation is used to scope operations and summaries.
    summaries: List of internal quantities to visualize on tensorboard. If not
               set, the loss, the learning rate, and the global norm of the
               gradients will be reported. The complete list of possible values
               is in OPTIMIZER_SUMMARIES.
    colocate_gradients_with_ops: If True, try colocating gradients with the
                                 corresponding op.
    increment_global_step: Whether to increment `global_step`. If your model
      calls `optimize_loss` multiple times per training step (e.g. to optimize
      different parts of the model), use this arg to avoid incrementing
      `global_step` more times than necessary.

  Returns:
    Training op.

  Raises:
    ValueError: if:
        * `loss` is an invalid type or shape.
        * `global_step` is an invalid type or shape.
        * `learning_rate` is an invalid type or value.
        * `optimizer` has the wrong type.
        * `clip_gradients` is neither float nor callable.
        * `learning_rate` and `learning_rate_decay_fn` are supplied, but no
          `global_step` is available.
        * `gradients` is empty.
  """
  loss = ops.convert_to_tensor(loss)
  contrib_framework.assert_scalar(loss)
  if global_step is None:
    global_step = train.get_global_step()
  else:
    train.assert_global_step(global_step)
  with vs.variable_scope(name, "OptimizeLoss", [loss, global_step]):
    # Update ops take UPDATE_OPS collection if not provided.
    if update_ops is None:
      update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
    # Make sure update ops are run before computing loss.
    if update_ops:
      loss = control_flow_ops.with_dependencies(list(update_ops), loss)

    # Learning rate variable, with possible decay.
    lr = None
    if learning_rate is not None:
      if (isinstance(learning_rate, ops.Tensor) and
          learning_rate.get_shape().ndims == 0):
        lr = learning_rate
      elif isinstance(learning_rate, float):
        if learning_rate < 0.0:
          # Format with `%`; passing the value as a second positional argument
          # would leave the placeholder unformatted in the error message.
          raise ValueError("Invalid learning_rate %s." % learning_rate)
        lr = vs.get_variable(
            "learning_rate", [],
            trainable=False,
            initializer=init_ops.constant_initializer(learning_rate))
      else:
        raise ValueError("Learning rate should be 0d Tensor or float. "
                         "Got %s of type %s" % (str(learning_rate),
                                                str(type(learning_rate))))
    if summaries is None:
      summaries = ["loss", "learning_rate", "global_gradient_norm"]
    else:
      for summ in summaries:
        if summ not in OPTIMIZER_SUMMARIES:
          raise ValueError("Summaries should be one of [%s], you provided %s." %
                           (", ".join(OPTIMIZER_SUMMARIES), summ))
    if learning_rate is not None and learning_rate_decay_fn is not None:
      if global_step is None:
        raise ValueError("global_step is required for learning_rate_decay_fn.")
      lr = learning_rate_decay_fn(lr, global_step)
      if "learning_rate" in summaries:
        summary.scalar("learning_rate", lr)

    # Create optimizer, given specified parameters.
    if isinstance(optimizer, six.string_types):
      if lr is None:
        raise ValueError("Learning rate is None, but should be specified if "
                         "optimizer is string (%s)." % optimizer)
      if optimizer not in OPTIMIZER_CLS_NAMES:
        raise ValueError(
            "Optimizer name should be one of [%s], you provided %s." %
            (", ".join(OPTIMIZER_CLS_NAMES), optimizer))
      opt = OPTIMIZER_CLS_NAMES[optimizer](learning_rate=lr)
    elif (isinstance(optimizer, type) and
          issubclass(optimizer, optimizer_.Optimizer)):
      if lr is None:
        raise ValueError("Learning rate is None, but should be specified if "
                         "optimizer is class (%s)." % optimizer)
      opt = optimizer(learning_rate=lr)
    elif isinstance(optimizer, optimizer_.Optimizer):
      opt = optimizer
    elif callable(optimizer):
      # A factory takes the learning rate only when one was supplied.
      if learning_rate is not None:
        opt = optimizer(lr)
      else:
        opt = optimizer()
      if not isinstance(opt, optimizer_.Optimizer):
        raise ValueError("Unrecognized optimizer: function should return "
                         "subclass of Optimizer. Got %s." % str(opt))
    else:
      raise ValueError("Unrecognized optimizer: should be string, "
                       "subclass of Optimizer, instance of "
                       "subclass of Optimizer or function with one argument. "
                       "Got %s." % str(optimizer))

    # All trainable variables, if specific variables are not specified.
    if variables is None:
      variables = vars_.trainable_variables()

    # Compute gradients.
    gradients = opt.compute_gradients(
        loss,
        variables,
        colocate_gradients_with_ops=colocate_gradients_with_ops)

    # Optionally add gradient noise.
    if gradient_noise_scale is not None:
      gradients = _add_scaled_noise_to_gradients(gradients,
                                                 gradient_noise_scale)

    # Multiply some gradients.
    if gradient_multipliers is not None:
      gradients = _multiply_gradients(gradients, gradient_multipliers)
      if not gradients:
        raise ValueError(
            "Empty list of (gradient, var) pairs encountered. This is most "
            "likely to be caused by an improper value of gradient_multipliers.")

    if "global_gradient_norm" in summaries or "gradient_norm" in summaries:
      summary.scalar("global_norm/gradient_norm",
                     clip_ops.global_norm(list(zip(*gradients))[0]))

    # Optionally clip gradients by global norm.
    if isinstance(clip_gradients, float):
      gradients = _clip_gradients_by_norm(gradients, clip_gradients)
    elif callable(clip_gradients):
      gradients = clip_gradients(gradients)
    elif clip_gradients is not None:
      raise ValueError(
          "Unknown type %s for clip_gradients" % type(clip_gradients))

    # Add scalar summary for loss.
    if "loss" in summaries:
      summary.scalar("loss", loss)

    # Add histograms for variables, gradients and gradient norms.
    for gradient, variable in gradients:
      if isinstance(gradient, ops.IndexedSlices):
        grad_values = gradient.values
      else:
        grad_values = gradient

      if grad_values is not None:
        var_name = variable.name.replace(":", "_")
        if "gradients" in summaries:
          summary.histogram("gradients/%s" % var_name, grad_values)
        if "gradient_norm" in summaries:
          summary.scalar("gradient_norm/%s" % var_name,
                         clip_ops.global_norm([grad_values]))

    if clip_gradients is not None and ("global_gradient_norm" in summaries or
                                       "gradient_norm" in summaries):
      summary.scalar("global_norm/clipped_gradient_norm",
                     clip_ops.global_norm(list(zip(*gradients))[0]))

    # Create gradient updates.
    grad_updates = opt.apply_gradients(
        gradients,
        global_step=global_step if increment_global_step else None,
        name="train")

    # Ensure the train_tensor computes grad_updates.
    train_tensor = control_flow_ops.with_dependencies([grad_updates], loss)

    return train_tensor
  def testSaveAndLoadSavedModelExport(
      self, model_builder, uses_learning_phase, optimizer, train_before_export):
    """Exports a Keras model as a SavedModel and checks each reloaded mode.

    The PREDICT graph is always reloaded and compared against the in-memory
    model's predictions. The EVAL and TRAIN graphs are only reloaded when an
    optimizer was supplied.

    Args:
      model_builder: Callable invoked as `model_builder(uses_learning_phase)`
        to build the Keras model under test.
      uses_learning_phase: Forwarded to `model_builder`; also selects whether
        the predictions fetched alongside the train op are expected to be
        close to zero or not.
      optimizer: Optimizer passed to `model.compile`, or `None` to skip
        compiling, training and evaluating before export.
      train_before_export: When true, `train_on_batch` is called once before
        export; the reloaded global step is expected to equal
        `int(train_before_export)`.
    """
    saved_model_path = self._save_model_dir()
    with self.session(graph=ops.Graph()):
      # Fixed seed so the random inputs and targets are reproducible.
      np.random.seed(130)
      input_arr = np.random.random((1, 3))
      target_arr = np.random.random((1, 3))

      model = model_builder(uses_learning_phase)
      if optimizer is not None:
        model.compile(
            loss='mse',
            optimizer=optimizer,
            metrics=['mae'])
        if train_before_export:
          model.train_on_batch(input_arr, target_arr)

        # Reference values the reloaded EVAL graph is compared against.
        ref_loss, ref_mae = model.evaluate(input_arr, target_arr)

      ref_predict = model.predict(input_arr)

      # Export SavedModel
      output_path = keras_saved_model.save_keras_model(model, saved_model_path)

    input_name = model.input_names[0]
    output_name = model.output_names[0]
    target_name = output_name + '_target'

    # Load predict graph, and test predictions
    with session.Session(graph=ops.Graph()) as sess:
      inputs, outputs, _ = load_model(sess, output_path,
                                      model_fn_lib.ModeKeys.PREDICT)

      predictions = sess.run(outputs[output_name],
                             {inputs[input_name]: input_arr})
      self.assertAllClose(ref_predict, predictions, atol=1e-05)

    if optimizer:
      # Load eval graph, and test predictions, loss and metric values
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs, _ = load_model(sess, output_path,
                                        model_fn_lib.ModeKeys.EVAL)

        # First obtain the loss and predictions, and run the metric update op by
        # feeding in the inputs and targets.
        loss, predictions, _ = sess.run(
            (outputs['loss'], outputs['predictions/' + output_name],
             outputs['metrics/mean_absolute_error/update_op']), {
                 inputs[input_name]: input_arr,
                 inputs[target_name]: target_arr
             })

        # The metric value should be run after the update op, to ensure that it
        # reflects the correct value.
        metric_value = sess.run(outputs['metrics/mean_absolute_error/value'])

        # Evaluation is expected to leave the global step where training left it.
        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertAllClose(ref_loss, loss, atol=1e-05)
        self.assertAllClose(ref_mae, metric_value, atol=1e-05)
        self.assertAllClose(ref_predict, predictions, atol=1e-05)

      # Load train graph, and check for the train op, and prediction values
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs, meta_graph_def = load_model(
            sess, output_path, model_fn_lib.ModeKeys.TRAIN)
        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertIn('loss', outputs)
        self.assertIn('metrics/mean_absolute_error/update_op', outputs)
        self.assertIn('metrics/mean_absolute_error/value', outputs)
        self.assertIn('predictions/' + output_name, outputs)

        # Train for a step
        train_op = loader_impl.get_train_op(meta_graph_def)
        train_outputs, _ = sess.run(
            [outputs, train_op], {inputs[input_name]: input_arr,
                                  inputs[target_name]: target_arr})
        # Running the train op once is expected to advance the global step by 1.
        self.assertEqual(int(train_before_export) + 1,
                         sess.run(training_module.get_global_step()))

        if uses_learning_phase:
          self.assertAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)
        else:
          self.assertNotAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)
  def testSaveAndLoadSavedModelExport(
      self, model_builder, uses_learning_phase, optimizer, train_before_export):
    """Exports a Keras model as a SavedModel and checks each reloaded mode.

    The PREDICT graph is always reloaded and compared against the in-memory
    model's predictions. The EVAL and TRAIN graphs are only reloaded when an
    optimizer was supplied.

    Args:
      model_builder: Callable invoked as `model_builder(uses_learning_phase)`
        to build the Keras model under test.
      uses_learning_phase: Forwarded to `model_builder`; also selects whether
        the predictions fetched alongside the train op are expected to be
        close to zero or not.
      optimizer: Optimizer passed to `model.compile`, or `None` to skip
        compiling, training and evaluating before export.
      train_before_export: When true, `train_on_batch` is called once before
        export; the reloaded global step is expected to equal
        `int(train_before_export)`.
    """
    saved_model_path = self._save_model_dir()
    with self.session(graph=ops.Graph()):
      # NOTE(review): no random seed is set here, so inputs differ between runs.
      input_arr = np.random.random((1, 3))
      target_arr = np.random.random((1, 3))

      model = model_builder(uses_learning_phase)
      if optimizer is not None:
        model.compile(
            loss='mse',
            optimizer=optimizer,
            metrics=['mae'])
        if train_before_export:
          model.train_on_batch(input_arr, target_arr)

        # Reference values the reloaded EVAL graph is compared against.
        ref_loss, ref_mae = model.evaluate(input_arr, target_arr)

      ref_predict = model.predict(input_arr)

      # Export SavedModel
      output_path = keras_saved_model.save_keras_model(model, saved_model_path)

    input_name = model.input_names[0]
    output_name = model.output_names[0]
    target_name = output_name + '_target'

    # Load predict graph, and test predictions
    with session.Session(graph=ops.Graph()) as sess:
      inputs, outputs = load_model(sess, output_path,
                                   model_fn_lib.ModeKeys.PREDICT)

      predictions = sess.run(outputs[output_name],
                             {inputs[input_name]: input_arr})
      self.assertAllClose(ref_predict, predictions, atol=1e-05)

    if optimizer:
      # Load eval graph, and test predictions, loss and metric values
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs = load_model(sess, output_path,
                                     model_fn_lib.ModeKeys.EVAL)

        # All outputs are fetched in a single run; the metric is read from the
        # fetched 'metrics/mae/update_op' entry below.
        eval_results = sess.run(outputs, {inputs[input_name]: input_arr,
                                          inputs[target_name]: target_arr})

        # Evaluation is expected to leave the global step where training left it.
        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05)
        self.assertAllClose(
            ref_mae, eval_results['metrics/mae/update_op'], atol=1e-05)
        self.assertAllClose(
            ref_predict, eval_results['predictions/' + output_name], atol=1e-05)

      # Load train graph, and check for the train op, and prediction values
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs = load_model(sess, output_path,
                                     model_fn_lib.ModeKeys.TRAIN)
        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertIn('loss', outputs)
        self.assertIn('metrics/mae/update_op', outputs)
        self.assertIn('metrics/mae/value', outputs)
        self.assertIn('predictions/' + output_name, outputs)

        # Train for a step
        # The collection lookup result is passed to `sess.run` as the second
        # fetch as-is.
        train_op = ops.get_collection(constants.TRAIN_OP_KEY)
        train_outputs, _ = sess.run(
            [outputs, train_op], {inputs[input_name]: input_arr,
                                  inputs[target_name]: target_arr})
        # Running the train op once is expected to advance the global step by 1.
        self.assertEqual(int(train_before_export) + 1,
                         sess.run(training_module.get_global_step()))

        if uses_learning_phase:
          self.assertAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)
        else:
          self.assertNotAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)