Code example #1
0
    def eval_child_model(self,
                         model,
                         data_loader,
                         mode,
                         only_noise_class=False):
        """Evaluate the child model, retrying on retryable TensorFlow errors.

        Each attempt runs ``helper_utils.eval_child_model`` inside a new
        ``self._new_session(model)`` context. Attempts that raise
        ``tf.errors.AbortedError`` or ``tf.errors.UnavailableError`` are logged
        and retried without limit; any other exception propagates.

        Args:
          model: image model that will be evaluated.
          data_loader: dataset object to extract eval data from.
          mode: will the model be evalled on train, val or test.
          only_noise_class: If True, evaluate the model only on examples from
            the noised class.

        Returns:
          A 4-tuple ``(accuracy, logit_norm_val, hidden_norm_val, cost)``, i.e.
          the values unpacked from ``helper_utils.eval_child_model``.
        """
        tf.logging.info('Evaluating child model in mode %s', mode)
        retryable_errors = (tf.errors.AbortedError, tf.errors.UnavailableError)
        while True:
            try:
                with self._new_session(model):
                    (accuracy, logit_norm_val, hidden_norm_val,
                     cost) = helper_utils.eval_child_model(
                         self.session, model, data_loader, mode,
                         only_noise_class)
                    tf.logging.info('Eval child model accuracy: %s', accuracy)
                    # Returning from inside the `with` still runs the session
                    # context's exit first; if that exit raises a retryable
                    # error, the except clause below catches it and the
                    # evaluation is retried (same as the original `break`).
                    return accuracy, logit_norm_val, hidden_norm_val, cost
            except retryable_errors as e:
                tf.logging.info('Retryable error caught: %s.  Retrying.', e)
Code example #2
0
File: train_cifar.py  Project: 812864539/models
  def eval_child_model(self, model, data_loader, mode):
    """Run an evaluation pass of the child model, retrying transient failures.

    Every attempt opens a fresh ``self._new_session(model)`` context and calls
    ``helper_utils.eval_child_model`` within it. Attempts failing with
    ``tf.errors.AbortedError`` or ``tf.errors.UnavailableError`` are logged and
    retried indefinitely; any other exception propagates to the caller.

    Args:
      model: image model that will be evaluated.
      data_loader: dataset object to extract eval data from.
      mode: will the model be evalled on train, val or test.

    Returns:
      Accuracy of the model on the specified dataset.
    """
    tf.logging.info('Evaluating child model in mode %s', mode)
    transient_errors = (tf.errors.AbortedError, tf.errors.UnavailableError)
    while True:
      try:
        with self._new_session(model):
          acc = helper_utils.eval_child_model(
              self.session, model, data_loader, mode)
          tf.logging.info('Eval child model accuracy: {}'.format(acc))
          # The session context exits before the value is handed back; an
          # error raised during that exit is still caught and retried below.
          return acc
      except transient_errors as err:
        tf.logging.info('Retryable error caught: %s.  Retrying.', err)