# Example #1 (score: 0)
    def testWithEpochLimit(self):
        """Evaluation loop terminates cleanly when inputs span one epoch."""
        limited_predictions = input.limit_epochs(
            self._predictions, num_epochs=1)
        limited_labels = input.limit_epochs(self._labels, num_epochs=1)

        value_op, update_op = metrics.accuracy(
            labels=limited_labels, predictions=limited_predictions)

        init_op = control_flow_ops.group(
            variables.global_variables_initializer(),
            variables.local_variables_initializer())

        # Set up directories for the checkpoint and the event logs.
        checkpoint_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
        log_dir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
        gfile.MakeDirs(checkpoint_dir)
        gfile.MakeDirs(log_dir)

        # Persist freshly initialized variables so the loop has a checkpoint
        # to restore from.
        saver = saver_lib.Saver()
        with self.cached_session() as sess:
            init_op.run()
            saver.save(sess, os.path.join(checkpoint_dir, 'chkpt'))

        # Run the evaluation loop; num_evals is deliberately huge so the
        # epoch limit on the inputs is what ends the evaluation.
        accuracy_value = evaluation.evaluation_loop(
            '',
            checkpoint_dir,
            log_dir,
            eval_op=update_op,
            final_op=value_op,
            max_number_of_evaluations=1,
            num_evals=10000)
        self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
# Example #2 (score: 0)
    def testTimeoutFnOnEvaluationLoop(self):
        """The loop keeps polling timeout_fn until it returns True."""
        # A one-element list gives the nested closure mutable state, so it
        # can count how many times the loop has invoked it.
        call_count = [0]

        def _CountingTimeoutFn():
            call_count[0] += 1
            return call_count[0] >= 3

        # No evaluation work is required; the loop should simply call the
        # timeout function repeatedly until it signals completion.
        evaluation.evaluation_loop(
            '', '', '', timeout=0, timeout_fn=_CountingTimeoutFn)
        self.assertEqual(call_count[0], 3)
# Example #3 (score: 0)
    def testFinalOpsOnEvaluationLoop(self):
        """final_op is returned and a custom session hook's end() fires."""
        value_op, update_op = metrics.accuracy(
            labels=self._labels, predictions=self._predictions)
        init_op = control_flow_ops.group(
            variables.global_variables_initializer(),
            variables.local_variables_initializer())

        # Set up directories for the checkpoint and the event logs.
        checkpoint_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
        log_dir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
        gfile.MakeDirs(checkpoint_dir)
        gfile.MakeDirs(log_dir)

        # Persist freshly initialized variables so the loop has a checkpoint
        # to restore from.
        saver = saver_lib.Saver()
        with self.cached_session() as sess:
            init_op.run()
            saver.save(sess, os.path.join(checkpoint_dir, 'chkpt'))

        class _State(object):
            # Mutable flag shared between the test and the hook below.

            def __init__(self):
                self.hook_was_run = False

        state = _State()

        class _RecordingHook(session_run_hook.SessionRunHook):
            # Session run hook that records that its end() method ran.

            def __init__(self, state):
                self.state = state

            def end(self, session):
                self.state.hook_was_run = True

        # Run one pass of the evaluation loop with the hook attached.
        accuracy_value = evaluation.evaluation_loop(
            '',
            checkpoint_dir,
            log_dir,
            eval_op=update_op,
            final_op=value_op,
            hooks=[_RecordingHook(state)],
            max_number_of_evaluations=1)
        self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

        # The hook's end() must have executed during the loop.
        self.assertTrue(state.hook_was_run)