Пример #1
0
 def testPredictionsFromDifferentGraph(self):
   """Checks that predictions built in another graph are rejected.

   The prediction tensor is created while one `tf.Graph` is the default and
   is then handed to `EstimatorSpec` under a different default graph; a
   `ValueError` matching 'must be from the default graph' is expected.
   """
   with tf.Graph().as_default():
     predictions = {'loss': tf.constant(1.)}
   with tf.Graph().as_default(), self.cached_session():
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(ValueError,
                                 'must be from the default graph'):
       model_fn.EstimatorSpec(
           mode=ModeKeys.EVAL, predictions=predictions, loss=tf.constant(1.))
Пример #2
0
 def testPredictionsNumber(self):
   """Checks that a plain Python number as a prediction value is rejected.

   Expects a `TypeError` matching 'predictions[number] must be Tensor' when
   the predictions dict maps 'number' to a float instead of a tensor.
   """
   with tf.Graph().as_default(), self.cached_session():
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(TypeError,
                                 r'predictions\[number\] must be Tensor'):
       model_fn.EstimatorSpec(
           mode=ModeKeys.EVAL,
           predictions={'number': 1.},
           loss=tf.constant(1.))
Пример #3
0
 def testLossSparseTensor(self):
   """Checks that a `SparseTensor` loss is rejected in EVAL mode.

   Expects a `TypeError` matching 'loss must be Tensor'.
   """
   with tf.Graph().as_default(), self.cached_session():
     loss = tf.sparse.SparseTensor(indices=[[0]], values=[0.], dense_shape=[1])
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(TypeError, 'loss must be Tensor'):
       model_fn.EstimatorSpec(
           mode=ModeKeys.EVAL,
           predictions={'prediction': tf.constant(1.)},
           loss=loss)
Пример #4
0
 def testTrainOpFromDifferentGraph(self):
   """Checks that a train_op built in another graph is rejected.

   The op is created while one `tf.Graph` is the default and is then passed
   to `EstimatorSpec` under a different default graph; a `ValueError`
   matching 'must be from the default graph' is expected.
   """
   with tf.Graph().as_default():
     train_op = tf.no_op()
   with tf.Graph().as_default(), self.cached_session():
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(ValueError,
                                 'must be from the default graph'):
       model_fn.EstimatorSpec(
           mode=ModeKeys.TRAIN, loss=tf.constant(1.), train_op=train_op)
Пример #5
0
 def testPredictionHookInvalid(self):
   """Checks that an `_InvalidHook` instance in prediction_hooks is rejected.

   Expects a `TypeError` matching
   'All hooks must be SessionRunHook instances'.
   """
   with tf.Graph().as_default(), self.cached_session():
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(
         TypeError, 'All hooks must be SessionRunHook instances'):
       model_fn.EstimatorSpec(
           mode=ModeKeys.PREDICT,
           predictions=tf.constant(1.),
           prediction_hooks=[_InvalidHook()])
Пример #6
0
    def _assertDefaultExportOutputForPredictions(self, predictions):
        """Asserts the default serving output mirrors `predictions`.

        Builds a PREDICT-mode `EstimatorSpec` from `predictions` and compares
        the outputs stored under the default serving signature key with the
        outputs of a `PredictOutput` built from the same predictions.

        Args:
          predictions: Value forwarded both to `EstimatorSpec` and to
            `export_output.PredictOutput`.
        """
        estimator_spec = model_fn.EstimatorSpec(
            mode=ModeKeys.PREDICT, predictions=predictions)
        wanted_outputs = export_output.PredictOutput(predictions).outputs

        default_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        actual_outputs = estimator_spec.export_outputs[default_key].outputs
        self.assertEqual(actual_outputs, wanted_outputs)
Пример #7
0
def _model_fn_train_only(features, labels):
    """Builds a TRAIN-mode EstimatorSpec around a single variable named 'v'.

    Args:
      features: Multiplied by `labels` to form the predictions.
      labels: Multiplied by `features` to form the predictions.

    Returns:
      A `model_fn_lib.EstimatorSpec` in `ModeKeys.TRAIN` whose loss is
      `5 + v` and whose train_op adds 1 to the global step.
    """
    var_v = tf.Variable(tf.constant(23), name='v')
    # Ops are created in the same order as the keyword arguments below.
    predictions = features * labels
    loss = tf.constant(5) + var_v
    global_step = tf.compat.v1.train.get_global_step()
    increment_step = tf.compat.v1.assign_add(global_step, 1)
    return model_fn_lib.EstimatorSpec(
        ModeKeys.TRAIN,
        predictions=predictions,
        loss=loss,
        train_op=increment_step)
Пример #8
0
 def testReplaceDoesReplace(self):
     """Checks that `_replace` swaps the predictions of an EstimatorSpec."""
     with ops.Graph().as_default(), self.cached_session():
         scalar_loss = constant_op.constant(1.)
         original = model_fn.EstimatorSpec(
             mode=ModeKeys.EVAL,
             predictions={'loss': scalar_loss},
             loss=scalar_loss)
         replaced = original._replace(predictions={'m': scalar_loss})
         replaced_keys = list(replaced.predictions.keys())
         self.assertEqual(['m'], replaced_keys)
Пример #9
0
def _model_fn_callable_variable_initializers(features, labels, mode):
  """Model_fn with callable variable initializers (for WrappedGraph tests).

  A variable named 'v' is created with a lambda initializer, then an
  `EstimatorSpec` is returned for PREDICT, EVAL or TRAIN mode. Any other
  mode falls through and returns None.

  Args:
    features: Used to build predictions in every handled mode.
    labels: Combined with `features` in EVAL and TRAIN modes.
    mode: Compared against the `ModeKeys` constants.

  Returns:
    A `model_fn_lib.EstimatorSpec` for the matching mode, or None.
  """
  var_v = tf.Variable(lambda: tf.constant(23), name='v')

  if mode == ModeKeys.PREDICT:
    return model_fn_lib.EstimatorSpec(
        ModeKeys.PREDICT, predictions=features + 1)

  if mode == ModeKeys.EVAL:
    eval_loss = tf.constant(5) + var_v
    return model_fn_lib.EstimatorSpec(
        ModeKeys.EVAL, loss=eval_loss, predictions=features + labels)

  if mode == ModeKeys.TRAIN:
    # Ops are created in the original order: predictions, loss, train_op.
    train_predictions = features * labels
    train_loss = tf.constant(5) + var_v
    global_step = tf.compat.v1.train.get_global_step()
    return model_fn_lib.EstimatorSpec(
        ModeKeys.TRAIN,
        predictions=train_predictions,
        loss=train_loss,
        train_op=tf.compat.v1.assign_add(global_step, 1))

  return None
Пример #10
0
 def testTrainOpNotOperationAndTensor(self):
   """Checks that a string passed as train_op is rejected.

   Expects a `TypeError` matching 'train_op must be Operation or Tensor'.
   """
   with tf.Graph().as_default(), self.cached_session():
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(TypeError,
                                 'train_op must be Operation or Tensor'):
       model_fn.EstimatorSpec(
           mode=ModeKeys.TRAIN,
           loss=tf.constant(1.),
           train_op='Not an Operation or Tensor')
Пример #11
0
 def testLossNumber(self):
     """Tests that error is raised when loss is a number (not Tensor).

     Expects a `TypeError` matching 'loss must be Tensor'.
     """
     with ops.Graph().as_default(), self.cached_session():
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(TypeError, 'loss must be Tensor'):
             model_fn.EstimatorSpec(
                 mode=ModeKeys.EVAL,
                 predictions={'loss': constant_op.constant(1.)},
                 loss=1.)
Пример #12
0
 def testReplaceRaisesConstructorChecks(self):
     """Checks that `_replace` applies the constructor's validation.

     A valid EVAL spec is built first; replacing its loss with a two-element
     constant is expected to raise a `ValueError` matching
     'Loss must be scalar'.
     """
     with ops.Graph().as_default(), self.cached_session():
         loss = constant_op.constant(1.)
         spec = model_fn.EstimatorSpec(mode=ModeKeys.EVAL,
                                       predictions={'loss': loss},
                                       loss=loss)
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(ValueError, 'Loss must be scalar'):
             spec._replace(loss=constant_op.constant([1., 2.]))
Пример #13
0
 def testScaffoldInvalid(self):
     """Checks that an `_InvalidScaffold` instance as scaffold is rejected.

     Expects a `TypeError` matching 'scaffold must be tf.train.Scaffold'.
     """
     with ops.Graph().as_default(), self.cached_session():
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(
                 TypeError, r'scaffold must be tf\.train\.Scaffold'):
             model_fn.EstimatorSpec(mode=ModeKeys.TRAIN,
                                    loss=constant_op.constant(1.),
                                    train_op=control_flow_ops.no_op(),
                                    scaffold=_InvalidScaffold())
Пример #14
0
 def testTrainingHookInvalid(self):
     """Checks that an `_InvalidHook` instance in training_hooks is rejected.

     Expects a `TypeError` matching
     'All hooks must be SessionRunHook instances'.
     """
     with tf.Graph().as_default(), self.cached_session():
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(
                 TypeError, 'All hooks must be SessionRunHook instances'):
             model_fn.EstimatorSpec(mode=ModeKeys.TRAIN,
                                    loss=tf.constant(1.),
                                    train_op=tf.no_op(),
                                    training_hooks=[_InvalidHook()])
Пример #15
0
 def model_fn(features, labels, mode):
   """Returns a spec whose loss is read from `features['x'][1][0]`.

   The loss is selected under a control dependency on `features['x']`, and
   the train_op adds 1 to the global step.

   Args:
     features: Mapping with an 'x' entry that supports `[1][0]` indexing.
     labels: Unused.
     mode: Forwarded unchanged to `EstimatorSpec`.

   Returns:
     A `model_fn_lib.EstimatorSpec` for `mode`.
   """
   del labels  # Not used by this model.
   with ops.control_dependencies([features['x']]):
     selected_loss = features['x'][1][0]
   step_increment = state_ops.assign_add(training.get_global_step(), 1)
   return model_fn_lib.EstimatorSpec(
       mode, loss=selected_loss, train_op=step_increment)
Пример #16
0
  def model_fn(features, labels, mode):
    """model_fn for keras Estimator.

    Clones and builds the enclosing `keras_model` for `mode` and packages the
    clone's outputs, loss, train op and metrics into an `EstimatorSpec`.

    Args:
      features: Forwarded to `_clone_and_build_model`.
      labels: Forwarded to `_clone_and_build_model`.
      mode: One of the `model_fn_lib.ModeKeys` values.

    Returns:
      A `model_fn_lib.EstimatorSpec`. Loss and eval metrics are set for every
      mode except PREDICT; train_op is set only for TRAIN. The predictions
      are also exported under `_DEFAULT_SERVING_KEY` as a `PredictOutput`.

    Raises:
      ValueError: If a distribution strategy is active and the Keras model's
        optimizer is not a native TensorFlow optimizer.
    """
    # Raise an error when users use DistributionStrategy with native Keras
    # optimizers. Currently we only support native TensorFlow optimizers.
    if (distribution_strategy_context.has_distribution_strategy() and
        not isinstance(keras_model.optimizer,
                       (tf_optimizer_module.Optimizer,
                        optimizers.TFOptimizer))):
      raise ValueError('Only TensorFlow native optimizers are supported with '
                       'DistributionStrategy.')

    model = _clone_and_build_model(mode, keras_model, custom_objects, features,
                                   labels)
    # We need to make sure that the output names of the last layer in the model
    # is the same for each of the cloned models. This is required for mirrored
    # strategy when we call regroup.
    if distribution_strategy_context.has_distribution_strategy():
      # Compile once, outside the loop; strips a trailing `_<digit>` suffix.
      trailing_index = re.compile(r'_\d$')
      model_output_names = [
          trailing_index.sub('', name) for name in model.output_names
      ]
    else:
      model_output_names = model.output_names

    # Get inputs to EstimatorSpec
    predictions = dict(zip(model_output_names, model.outputs))

    loss = None
    train_op = None
    eval_metric_ops = None

    # Mode keys are compared by value: identity (`is`) is only accidentally
    # true for equal strings and must not be relied upon.
    is_train = mode == model_fn_lib.ModeKeys.TRAIN
    is_predict = mode == model_fn_lib.ModeKeys.PREDICT

    # Set loss and metric only during train and evaluate.
    if not is_predict:
      if is_train:
        model._make_train_function()  # pylint: disable=protected-access
      else:
        model._make_test_function()  # pylint: disable=protected-access
      loss = model.total_loss

      eval_metric_ops = _convert_keras_metrics_to_estimator(model)

    # Set train_op only during train.
    if is_train:
      train_op = model.train_function.updates_op

    if not model._is_graph_network:  # pylint: disable=protected-access
      # Reset model state to original state,
      # to avoid `model_fn` being destructive for the initial model argument.
      models.in_place_subclassed_model_state_restoration(keras_model)
    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs={
            _DEFAULT_SERVING_KEY:
            export_lib.export_output.PredictOutput(predictions)
        })
Пример #17
0
 def testExportOutputsNoDict(self):
   """Checks that export_outputs must be a dict.

   A bare `ClassificationOutput` is passed instead of a dict; a `TypeError`
   matching 'export_outputs must be dict' is expected.
   """
   with tf.Graph().as_default(), self.cached_session():
     predictions = {'loss': tf.constant(1.)}
     classes = tf.constant('hello')
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(TypeError, 'export_outputs must be dict'):
       model_fn.EstimatorSpec(
           mode=ModeKeys.PREDICT,
           predictions=predictions,
           export_outputs=export_output.ClassificationOutput(classes=classes))
 def testEvalMetricOpsNoDict(self):
     """Checks that eval_metric_ops must be a dict.

     A tensor is passed instead of a dict; a `TypeError` matching
     'eval_metric_ops must be a dict' is expected.
     """
     with ops.Graph().as_default(), self.cached_session():
         loss = constant_op.constant(1.)
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(TypeError,
                                     'eval_metric_ops must be a dict'):
             model_fn.EstimatorSpec(mode=ModeKeys.EVAL,
                                    predictions={'loss': loss},
                                    loss=loss,
                                    eval_metric_ops=loss)
 def testLossFromDifferentGraph(self):
     """Checks that a loss built in another graph is rejected in TRAIN mode.

     The loss is created while one graph is the default and is then passed to
     `EstimatorSpec` under a different default graph; a `ValueError` matching
     'must be from the default graph' is expected.
     """
     with ops.Graph().as_default():
         loss = constant_op.constant(1.)
     with ops.Graph().as_default(), self.cached_session():
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(ValueError,
                                     'must be from the default graph'):
             model_fn.EstimatorSpec(mode=ModeKeys.TRAIN,
                                    loss=loss,
                                    train_op=control_flow_ops.no_op())
Пример #20
0
def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers):
    """Produce an EstimatorSpec with appropriately scaled loss.

    Args:
      tower_spec: Spec whose `loss` attribute is inspected and scaled.
      loss_reduction: Forwarded to `_scale_loss`.
      number_of_towers: Forwarded to `_scale_loss`.

    Returns:
      `tower_spec` itself when its loss is None; otherwise a new
      `model_fn_lib.EstimatorSpec` built from the spec's fields with the
      loss replaced by the result of `_scale_loss`.
    """
    # Nothing to scale when the tower reported no loss.
    if tower_spec.loss is None:
        return tower_spec

    spec_fields = _asdict(tower_spec)
    scaled_loss = _scale_loss(
        tower_spec.loss, loss_reduction, number_of_towers)
    spec_fields['loss'] = scaled_loss
    return model_fn_lib.EstimatorSpec(**spec_fields)
 def testLossSparseTensor(self):
     """Checks that a `SparseTensor` loss is rejected in TRAIN mode.

     Expects a `TypeError` matching 'loss must be Tensor'.
     """
     with ops.Graph().as_default(), self.cached_session():
         loss = sparse_tensor.SparseTensor(indices=[[0]],
                                           values=[0.],
                                           dense_shape=[1])
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(TypeError, 'loss must be Tensor'):
             model_fn.EstimatorSpec(mode=ModeKeys.TRAIN,
                                    loss=loss,
                                    train_op=control_flow_ops.no_op())
Пример #22
0
 def model_fn(features, labels, mode):
   """Returns a spec made entirely of constants; inputs are ignored.

   Args:
     features: Unused.
     labels: Unused.
     mode: Forwarded unchanged to `EstimatorSpec`.

   Returns:
     A `model_fn_lib.EstimatorSpec` with constant loss `[103]`, constant
     predictions `[502]`, a train_op that adds 1 to the global step, and a
     'test' export output wrapping the constant `[[32.]]`.
   """
   del features, labels  # Unused.
   # Ops are created in the original order: loss, train_op, predictions,
   # then the export output constant.
   const_loss = constant_op.constant([103])
   step_increment = state_ops.assign_add(training.get_global_step(), 1)
   const_predictions = constant_op.constant([502])
   classification = export_output.ClassificationOutput(
       constant_op.constant([[32.]]))
   return model_fn_lib.EstimatorSpec(
       mode,
       loss=const_loss,
       train_op=step_increment,
       predictions=const_predictions,
       export_outputs={'test': classification})
Пример #23
0
 def testReplaceNotAllowModeChange(self):
   """Checks that `_replace` refuses to change the mode of a spec.

   Replacing the mode with the same value (EVAL) must succeed; replacing it
   with TRAIN is expected to raise a `ValueError` matching
   'mode of EstimatorSpec cannot be changed'.
   """
   with ops.Graph().as_default(), self.cached_session():
     loss = constant_op.constant(1.)
     spec = model_fn.EstimatorSpec(
         mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
     spec._replace(mode=model_fn.ModeKeys.EVAL)
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(ValueError,
                                 'mode of EstimatorSpec cannot be changed'):
       spec._replace(mode=model_fn.ModeKeys.TRAIN)
Пример #24
0
 def testTupleMetric(self):
   """Tests that no errors are raised when a metric is tuple-valued."""
   with ops.Graph().as_default(), self.cached_session():
     scalar = constant_op.constant(1.)
     # The metric value is a nested tuple of tensors; ops are created in the
     # original order (inner constant first, then the no-op).
     nested_value = (scalar, scalar, (constant_op.constant(2), scalar))
     update_op = control_flow_ops.no_op()
     model_fn.EstimatorSpec(
         mode=model_fn.ModeKeys.EVAL,
         loss=scalar,
         eval_metric_ops={'some_metric': (nested_value, update_op)})
Пример #25
0
 def testEvalMetricOpsNoTensorOrOperation(self):
     """Checks that a non-tensor metric value is rejected.

     The metric tuple contains the string 'NonTensor'; a `TypeError` matching
     'must be Operation or Tensor' is expected.
     """
     with ops.Graph().as_default(), self.cached_session():
         loss = constant_op.constant(1.)
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(TypeError,
                                     'must be Operation or Tensor'):
             model_fn.EstimatorSpec(
                 mode=model_fn.ModeKeys.EVAL,
                 predictions={'loss': loss},
                 loss=loss,
                 eval_metric_ops={'loss': ('NonTensor', loss)})
 def testLossFromDifferentGraph(self):
     """Checks that a loss built in another graph is rejected in EVAL mode.

     The loss is created while one graph is the default and is then passed to
     `EstimatorSpec` under a different default graph; a `ValueError` matching
     'must be from the default graph' is expected.
     """
     with ops.Graph().as_default():
         loss = constant_op.constant(1.)
     with ops.Graph().as_default(), self.cached_session():
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(ValueError,
                                     'must be from the default graph'):
             model_fn.EstimatorSpec(
                 mode=ModeKeys.EVAL,
                 predictions={'prediction': constant_op.constant(1.)},
                 loss=loss)
Пример #27
0
 def model_fn(features, labels, mode):
   """Returns a spec built from indexed entries of `features` and `labels`.

   Args:
     features: Mapping with an 'x' entry supporting `[0][0]` and `[1][0]`.
     labels: None, or a value supporting `[0][0]` and `[1][0]`; when given,
       the two selected entries are summed to form the loss.
     mode: Forwarded unchanged to `EstimatorSpec`.

   Returns:
     A `model_fn_lib.EstimatorSpec` whose loss is None when `labels` is
     None, with a train_op that adds 1 to the global step and two identity
     predictions, 'features_0' and 'features_1'.
   """
   if labels is None:
     summed_loss = None
   else:
     summed_loss = labels[0][0] + labels[1][0]
   # Ops are created in the original order: train_op, then predictions.
   step_increment = state_ops.assign_add(training.get_global_step(), 1)
   x = features['x']
   outputs = {
       'features_0': array_ops.identity([x[0][0]]),
       'features_1': array_ops.identity([x[1][0]]),
   }
   return model_fn_lib.EstimatorSpec(
       mode, loss=summed_loss, train_op=step_increment, predictions=outputs)
Пример #28
0
 def testPredictionsSparseTensor(self):
   """Checks that a `SparseTensor` prediction value is rejected.

   Expects a `TypeError` matching 'predictions[sparse] must be Tensor'.
   """
   with tf.Graph().as_default(), self.cached_session():
     predictions = {
         'sparse':
             tf.sparse.SparseTensor(
                 indices=[[0]], values=[0.], dense_shape=[1])
     }
     # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(TypeError,
                                 r'predictions\[sparse\] must be Tensor'):
       model_fn.EstimatorSpec(mode=ModeKeys.PREDICT, predictions=predictions)
Пример #29
0
 def model_fn(features, labels, mode):
     """Returns a spec with a custom scaffold; inputs are ignored.

     A variable named 'some_var' is created with initial value 21. The
     scaffold's local_init_op adds -3 to it, and the loss is an identity of
     the variable.

     Args:
       features: Unused.
       labels: Unused.
       mode: Forwarded unchanged to `EstimatorSpec`.

     Returns:
       A `model_fn_lib.EstimatorSpec` carrying the scaffold, a train_op that
       adds 1 to the global step, and the variable-backed loss.
     """
     del features, labels  # Unused.
     some_var = variables.Variable(21, name='some_var')
     local_init = state_ops.assign_add(some_var, -3).op
     custom_scaffold = monitored_session.Scaffold(local_init_op=local_init)
     # Ops are created in the original order: train_op, then loss.
     step_increment = state_ops.assign_add(training.get_global_step(), 1)
     var_loss = array_ops.identity(some_var)
     return model_fn_lib.EstimatorSpec(mode,
                                       scaffold=custom_scaffold,
                                       train_op=step_increment,
                                       loss=var_loss)
 def testEvalMetricOpsWithoutUpdates(self):
     """Checks that a metric object that was never updated is rejected.

     A `metrics.Mean()` is created without any further calls on it and passed
     in eval_metric_ops; a `ValueError` whose message contains the literal
     text 'Please call update_state(...)' is expected.
     """
     with ops.Graph().as_default():
         eval_metric_ops = {'mean': metrics.Mean()}
     with ops.Graph().as_default(), self.cached_session():
         loss = constant_op.constant(1.)
         # The pattern is escaped so that '(...)' matches literally instead
         # of acting as a regex group of three arbitrary characters.
         # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
         with self.assertRaisesRegex(ValueError,
                                     r'Please call update_state\(\.\.\.\)'):
             model_fn.EstimatorSpec(mode=ModeKeys.EVAL,
                                    predictions={'loss': loss},
                                    loss=loss,
                                    eval_metric_ops=eval_metric_ops)