Example #1
    def testAdditionalHooks(self):
        checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
        log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')

        # First, save out the current model to a checkpoint:
        self._prepareCheckpoint(checkpoint_path)

        # Next, determine the metric to evaluate:
        value_op, update_op = metrics.accuracy(labels=self._labels,
                                               predictions=self._predictions)

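        # Create a tfdbg DumpingDebugHook that dumps the intermediate
        # tensors of each session.run() call to a 'run_*' directory
        # under dumping_root: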
        dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
        dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
        try:
            # Run the evaluation and verify the results:
            accuracy_value = evaluation.evaluate_once('',
                                                      checkpoint_path,
                                                      log_dir,
                                                      eval_op=update_op,
                                                      final_op=value_op,
                                                      hooks=[dumping_hook])
            self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

            dump = debug_data.DebugDumpDir(
                glob.glob(os.path.join(dumping_root, 'run_*'))[0])
            # Here we simply assert that the dumped data has been loaded and is
            # non-empty. We do not care about the detailed model-internal tensors or
            # their values.
            self.assertTrue(dump.dumped_tensor_data)
        finally:
            if os.path.isdir(dumping_root):
                shutil.rmtree(dumping_root)
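
Both this example and Example #2 call fixtures that are defined elsewhere on the test class and not shown here: _prepareCheckpoint, _labels, _predictions, and _expected_accuracy. Below is a minimal sketch of what such fixtures might look like, assuming TF1 graph mode; the class name, the constant labels/predictions, and the 0.75 accuracy target are illustrative assumptions, not the original test's code. The examples also assume module imports (glob, os, shutil, plus the TF-internal evaluation, metrics, hooks, and debug_data modules) whose exact paths vary across TF 1.x releases.

import tensorflow as tf

class EvaluateOnceTest(tf.test.TestCase):  # hypothetical class name

    def setUp(self):
        super(EvaluateOnceTest, self).setUp()
        # A global step gives tf.train.Saver() at least one global
        # variable to save when _prepareCheckpoint() runs.
        self._global_step = tf.train.get_or_create_global_step()
        # Fixed predictions/labels with a known accuracy (3 of 4 match),
        # so the tests have a deterministic value to assert against.
        self._predictions = tf.constant([1, 0, 1, 1], dtype=tf.int64)
        self._labels = tf.constant([1, 1, 1, 1], dtype=tf.int64)
        self._expected_accuracy = 0.75

    def _prepareCheckpoint(self, checkpoint_path):
        # Initialize global and local variables (metrics.accuracy creates
        # local ones) and save, so that evaluate_once() has a checkpoint
        # to restore from.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver()
        with self.test_session() as sess:
            sess.run(init_op)
            saver.save(sess, checkpoint_path)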
Example #2
    def testRestoredModelPerformance(self):
        checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
        log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')

        # First, save out the current model to a checkpoint:
        self._prepareCheckpoint(checkpoint_path)

        # Next, determine the metric to evaluate:
        value_op, update_op = metrics.accuracy(labels=self._labels,
                                               predictions=self._predictions)

        # Run the evaluation and verify the results:
        accuracy_value = evaluation.evaluate_once('',
                                                  checkpoint_path,
                                                  log_dir,
                                                  eval_op=update_op,
                                                  final_op=value_op)
        self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
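
For clarity, the same call can be written with keyword arguments. Assuming this is the slim variant of evaluate_once (the hooks parameter in Example #1 suggests so), its leading positional parameters are master, checkpoint_path, and logdir, where the empty master string means "run the evaluation in an in-process session":

accuracy_value = evaluation.evaluate_once(
    master='',
    checkpoint_path=checkpoint_path,
    logdir=log_dir,
    eval_op=update_op,
    final_op=value_op)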
Example #3
    def testErrorRaisedIfCheckpointDoesntExist(self):
        checkpoint_path = os.path.join(self.get_temp_dir(),
                                       'this_file_doesnt_exist')
        log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
        with self.assertRaises(ValueError):
            evaluation.evaluate_once('', checkpoint_path, log_dir)
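
The test relies on evaluate_once raising a ValueError when the checkpoint file is missing. As a sketch, calling code could guard against that case up front with tf.train.checkpoint_exists, a TF1 API that handles both V1 and V2 checkpoint formats:

import tensorflow as tf

if tf.train.checkpoint_exists(checkpoint_path):
    evaluation.evaluate_once('', checkpoint_path, log_dir)
else:
    print('No checkpoint found at %s' % checkpoint_path)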