def testWithEpochLimit(self):
    """Runs evaluation_loop over inputs wrapped in a one-epoch limit.

    Checks that the accuracy returned through `final_op` matches
    `self._expected_accuracy`.
    """
    # Directories used by the checkpoint save and the evaluation loop.
    temp_root = self.get_temp_dir()
    checkpoint_dir = os.path.join(temp_root, 'tmp_logs/')
    eval_log_dir = os.path.join(temp_root, 'tmp_logs2/')
    for directory in (checkpoint_dir, eval_log_dir):
        gfile.MakeDirs(directory)

    # Wrap both inputs so they only yield a single epoch of data.
    epoch_labels = input.limit_epochs(self._labels, num_epochs=1)
    epoch_predictions = input.limit_epochs(self._predictions, num_epochs=1)
    accuracy_op, accuracy_update_op = metrics.accuracy(
        labels=epoch_labels, predictions=epoch_predictions)

    initialize_all = control_flow_ops.group(
        variables.global_variables_initializer(),
        variables.local_variables_initializer())

    # Write a checkpoint of the freshly initialized variables so that
    # evaluation_loop has something to restore from.
    checkpoint_saver = saver_lib.Saver()
    with self.cached_session() as session:
        session.run(initialize_all)
        checkpoint_saver.save(session, os.path.join(checkpoint_dir, 'chkpt'))

    # num_evals is much larger than one epoch of data. Presumably the loop
    # stops once the epoch-limited inputs are exhausted (TODO confirm).
    result = evaluation.evaluation_loop(
        '',
        checkpoint_dir,
        eval_log_dir,
        eval_op=accuracy_update_op,
        final_op=accuracy_op,
        max_number_of_evaluations=1,
        num_evals=10000)
    self.assertAlmostEqual(result, self._expected_accuracy)
def testTimeoutFnOnEvaluationLoop(self):
    """Checks that evaluation_loop keeps polling timeout_fn until it returns True."""
    # A mutable container lets the nested callback update shared state
    # without rebinding an outer name (which would need `nonlocal`).
    state = {'calls': 0}

    def _TimeoutFn():
        state['calls'] += 1
        return state['calls'] >= 3

    # The checkpoint dir is '' and timeout=0. No evaluation is expected;
    # the loop should only call timeout_fn until it asks to stop.
    evaluation.evaluation_loop('', '', '', timeout=0, timeout_fn=_TimeoutFn)
    self.assertEqual(state['calls'], 3)
def testFinalOpsOnEvaluationLoop(self):
    """Checks that evaluation_loop returns the final_op value and runs custom hooks."""
    accuracy_op, accuracy_update_op = metrics.accuracy(
        labels=self._labels, predictions=self._predictions)
    initialize_all = control_flow_ops.group(
        variables.global_variables_initializer(),
        variables.local_variables_initializer())

    # Directories used by the checkpoint save and the evaluation loop.
    temp_root = self.get_temp_dir()
    checkpoint_dir = os.path.join(temp_root, 'tmp_logs/')
    eval_log_dir = os.path.join(temp_root, 'tmp_logs2/')
    for directory in (checkpoint_dir, eval_log_dir):
        gfile.MakeDirs(directory)

    # Write a checkpoint of the freshly initialized variables so that
    # evaluation_loop has something to restore from.
    checkpoint_saver = saver_lib.Saver()
    with self.cached_session() as session:
        session.run(initialize_all)
        checkpoint_saver.save(session, os.path.join(checkpoint_dir, 'chkpt'))

    class _EndRecordingHook(session_run_hook.SessionRunHook):
        """Session hook that only records whether its `end()` was called."""

        def __init__(self):
            self.end_called = False

        def end(self, session):
            self.end_called = True

    recording_hook = _EndRecordingHook()

    result = evaluation.evaluation_loop(
        '',
        checkpoint_dir,
        eval_log_dir,
        eval_op=accuracy_update_op,
        final_op=accuracy_op,
        hooks=[recording_hook],
        max_number_of_evaluations=1)
    self.assertAlmostEqual(result, self._expected_accuracy)
    # The custom hook passed through `hooks=` must have been invoked.
    self.assertTrue(recording_hook.end_called)