Example #1
  def test_read_eval_metrics_when_no_events(self):
    eval_dir = tempfile.mkdtemp()
    self.assertTrue(os.path.exists(eval_dir))

    # No error should be raised when eval directory exists with no event files.
    self.assertEqual({}, early_stopping.read_eval_metrics(eval_dir))

    os.rmdir(eval_dir)
    self.assertFalse(os.path.exists(eval_dir))

    # No error should be raised when eval directory does not exist.
    self.assertEqual({}, early_stopping.read_eval_metrics(eval_dir))
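
For context, read_eval_metrics is the helper that the early-stopping hooks in the same module use to poll the eval directory. Below is a minimal sketch of typical usage through the public API, assuming a TF 1.x/2.x Estimator setup; the estimator, feature columns, and input functions are hypothetical placeholders, not part of the examples on this page.

import tensorflow as tf

# Hypothetical estimator; any tf.estimator.Estimator instance works the same way.
estimator = tf.estimator.LinearRegressor(feature_columns=feature_columns)

# Stop training once "loss" has not decreased for 10000 training steps, based
# on the eval metrics collected from the estimator's eval directory.
hook = tf.estimator.experimental.stop_if_no_decrease_hook(
    estimator,
    metric_name="loss",
    max_steps_without_decrease=10000,
)

tf.estimator.train_and_evaluate(
    estimator,
    tf.estimator.TrainSpec(input_fn=train_input_fn, hooks=[hook]),  # hypothetical input_fn
    tf.estimator.EvalSpec(input_fn=eval_input_fn),                  # hypothetical input_fn
)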
Example #2
  def test_read_eval_metrics(self):
    eval_dir = tempfile.mkdtemp()
    _write_events(
        eval_dir,
        [
            # steps, loss, accuracy
            (1000, 1, 2),
            (2000, 3, 4),
            (3000, 5, 6),
        ])
    self.assertEqual(
        {
            1000: {
                'loss': 1,
                'accuracy': 2
            },
            2000: {
                'loss': 3,
                'accuracy': 4
            },
            3000: {
                'loss': 5,
                'accuracy': 6
            },
        }, early_stopping.read_eval_metrics(eval_dir))
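
Examples #2 and #3 rely on a _write_events test helper that is not shown on this page. Below is a minimal sketch of what such a helper could look like, assuming it writes scalar summaries with tf.compat.v1.summary.FileWriter; the real helper may differ, for example by splitting events across several files.

import tensorflow as tf

def _write_events(eval_dir, params):
    # Write one scalar summary event per (steps, loss, accuracy) tuple.
    writer = tf.compat.v1.summary.FileWriter(eval_dir)
    for steps, loss, accuracy in params:
        summary = tf.compat.v1.Summary(value=[
            tf.compat.v1.Summary.Value(tag='loss', simple_value=loss),
            tf.compat.v1.Summary.Value(tag='accuracy', simple_value=accuracy),
        ])
        writer.add_summary(summary, global_step=steps)
    writer.close()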
Example #3
    def test_data_loss_error_ignored(self):
        eval_dir = tempfile.mkdtemp()
        _write_events(
            eval_dir,
            [
                # steps, loss, accuracy
                (1000, 1, 2),
                (2000, 3, 4),
                (3000, 5, 6),
            ])

        orig_tf_train_summary_iterator = tf.compat.v1.train.summary_iterator

        def _summary_iterator(*args, **kwargs):
            for event in orig_tf_train_summary_iterator(*args, **kwargs):
                yield event
                # Raise an error for one of the files after yielding a summary event.
                if event.HasField('summary'):
                    raise tf.errors.DataLossError(None, None,
                                                  'testing data loss')

        with mock.patch.object(tf.compat.v1.train,
                               'summary_iterator') as mock_summary_iterator:
            mock_summary_iterator.side_effect = _summary_iterator
            eval_results = early_stopping.read_eval_metrics(eval_dir)

        self.assertEqual({1000: {'loss': 1, 'accuracy': 2}}, eval_results)
Example #4
    def after_run(
        self,
        run_context: "tf.estimator.SessionRunContext",
        run_values: "tf.estimator.SessionRunValues",
    ) -> None:

        global_step = run_values.results
        # Get eval metrics every n steps.
        if self._timer.should_trigger_for_step(global_step):
            self._timer.update_last_triggered_step(global_step)
            eval_metrics = read_eval_metrics(self._estimator.eval_dir())
        else:
            eval_metrics = None
        if eval_metrics:
            summary_step = next(reversed(eval_metrics))
            latest_eval_metrics = eval_metrics[summary_step]
            # If there exists a new evaluation summary.
            if summary_step > self._current_summary_step:
                current_score = latest_eval_metrics[self._metric]
                if current_score is None:
                    current_score = float("nan")
                self._trial.report(float(current_score), step=summary_step)
                self._current_summary_step = summary_step
            if self._trial.should_prune():
                message = "Trial was pruned at iteration {}.".format(self._current_summary_step)
                raise optuna.TrialPruned(message)
Example #5
    def after_run(self, run_context, run_values):
        # type: (tf.train.SessionRunContext, tf.train.SessionRunValues) -> None

        global_step = run_values.results
        # Get eval metrics every n steps.
        if self.timer.should_trigger_for_step(global_step):
            self.timer.update_last_triggered_step(global_step)
            eval_metrics = read_eval_metrics(self.estimator.eval_dir())
        else:
            eval_metrics = None
        if eval_metrics:
            summary_step = next(reversed(eval_metrics))
            latest_eval_metrics = eval_metrics[summary_step]
            # If there exists a new evaluation summary.
            if summary_step > self.current_summary_step:
                current_score = latest_eval_metrics[self.metric]
                self.trial.report(current_score, step=summary_step)
                self.current_summary_step = summary_step
            if self.trial.should_prune():
                message = "Trial was pruned at iteration {}.".format(
                    self.current_summary_step)
                raise optuna.structs.TrialPruned(message)
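
Examples #4 and #5 show two versions of the after_run method from Optuna's TensorFlow estimator pruning hook. For context, below is a minimal sketch of how such a hook is typically attached to training via optuna.integration.TensorFlowPruningHook; the estimator, feature columns, and input functions are hypothetical placeholders, not part of the examples above.

import optuna
import tensorflow as tf


def objective(trial):
    # Hypothetical estimator; any tf.estimator.Estimator works the same way.
    estimator = tf.estimator.DNNClassifier(
        hidden_units=[trial.suggest_int("n_units", 16, 128)],
        feature_columns=feature_columns,  # assumed to be defined elsewhere
    )

    # The hook periodically calls read_eval_metrics() on the estimator's eval
    # directory, reports the chosen metric to the trial, and raises
    # TrialPruned when the pruner decides to stop the trial early.
    pruning_hook = optuna.integration.TensorFlowPruningHook(
        trial=trial,
        estimator=estimator,
        metric="accuracy",
        run_every_steps=100,
    )

    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn,  # hypothetical input function
        max_steps=10000,
        hooks=[pruning_hook],
    )
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)  # hypothetical
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    return estimator.evaluate(input_fn=eval_input_fn)["accuracy"]


study = optuna.create_study(direction="maximize",
                            pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)

Note that the hook only reports a score when a newer evaluation summary appears, which is why after_run in the examples above tracks the last seen summary step before calling trial.report.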