Ejemplo n.º 1
0
  def testEvaluatePerfectModel(self):
    # Trains the model for 300 steps, evaluates the resulting checkpoint a
    # single time and expects the reported accuracy to exceed 0.99.
    if tf.executing_eagerly():
      # tf.metrics.accuracy is not supported when eager execution is enabled.
      return
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_perfect_model_once')

    # Train a Model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Build the evaluation graph: classifier output rounded to the nearest
    # integer serves as the prediction compared against the labels.
    inputs = tf.constant(self._inputs, dtype=tf.float32)
    labels = tf.constant(self._labels, dtype=tf.float32)
    logits = logistic_classifier(inputs)
    predictions = tf.round(logits)

    accuracy, update_op = tf.compat.v1.metrics.accuracy(
        predictions=predictions, labels=labels)

    checkpoint_path = evaluation.wait_for_new_checkpoint(checkpoint_dir)

    # `update_op` is passed as the eval op and `accuracy` is fetched as the
    # final op; the hook is configured with a count of 1.
    final_ops_values = evaluation.evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=update_op,
        final_ops={'accuracy': accuracy},
        hooks=[
            evaluation.StopAfterNEvalsHook(1),
        ])
    # assertGreater reports both operands on failure, which assertTrue on a
    # pre-computed comparison cannot do.
    self.assertGreater(final_ops_values['accuracy'], .99)
Ejemplo n.º 2
0
  def testEvalOpAndFinalOp(self):
    # Passes an increment op as `eval_ops` together with a final op that
    # reads the incremented variable, and expects the final value to equal
    # the configured number of evals plus a fixed offset.
    if tf.executing_eagerly():
      return

    # One training step is enough to produce a checkpoint to evaluate.
    ckpt_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')
    self._train_model(ckpt_dir, num_steps=1)
    ckpt_path = evaluation.wait_for_new_checkpoint(ckpt_dir)

    # Build the model so that there is something to restore.
    logistic_classifier(tf.constant(self._inputs, dtype=tf.float32))

    eval_count = 5
    offset = 9.0

    try:
      counter = _local_variable(0.0, name='MyVar')
    except TypeError:  # `collections` doesn't exist in TF2.
      return
    increment_op = tf.compat.v1.assign_add(counter, 1.0)
    counter_plus_offset = tf.identity(counter) + offset

    stop_hook = evaluation.StopAfterNEvalsHook(eval_count)
    results = evaluation.evaluate_once(
        checkpoint_path=ckpt_path,
        eval_ops=increment_op,
        final_ops={'value': counter_plus_offset},
        hooks=[stop_hook])
    self.assertEqual(results['value'], eval_count + offset)
Ejemplo n.º 3
0
  def testOnlyFinalOp(self):
    # Calls `evaluate_once` with only `final_ops` (no eval ops, no hooks) and
    # expects the fetched value to equal the fixed offset added to a
    # zero-initialized local variable.

    # One training step is enough to produce a checkpoint to evaluate.
    ckpt_dir = os.path.join(self.get_temp_dir(), 'only_final_ops')
    self._train_model(ckpt_dir, num_steps=1)
    ckpt_path = evaluation.wait_for_new_checkpoint(ckpt_dir)

    # Build the model so that there is something to restore.
    logistic_classifier(tf.constant(self._inputs, dtype=tf.float32))

    offset = 9.0

    try:
      untouched_var = _local_variable(0.0, name='MyVar')
    except TypeError:  # `collections` doesn't exist in TF2.
      return
    var_plus_offset = tf.identity(untouched_var) + offset

    results = evaluation.evaluate_once(
        checkpoint_path=ckpt_path,
        final_ops={'value': var_plus_offset})
    self.assertEqual(results['value'], offset)