コード例 #1
0
ファイル: svm.py プロジェクト: zxie/tensorflow
  def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
               batch_size=None, steps=None, metrics=None, name=None):
    """See evaluable.Evaluable.

    Evaluates the wrapped estimator with a metrics dict assembled here. If
    `metrics` is empty or None, a single "accuracy" `MetricSpec` using
    `metrics_lib.streaming_accuracy` with `prediction_key=linear._CLASSES`
    is used. The metrics returned by
    `target_column.get_default_binary_metrics_for_eval([0.5])` are always
    added as `MetricSpec`s with `prediction_key=linear._LOGISTIC`, replacing
    any caller metric registered under the same name.

    Args:
      x, y, input_fn, feed_fn, batch_size, steps, name: Forwarded unchanged
        to `self._estimator.evaluate`.
      metrics: Optional dict of metric name -> `MetricSpec` or raw metric
        function. Raw functions are wrapped with `linear._wrap_metric`.
        The caller's dict is not modified.

    Returns:
      The result of `self._estimator.evaluate`.

    Raises:
      ValueError: If a raw (non-`MetricSpec`) metric is keyed by a tuple that
        is not of length 2, or whose second element is not one of
        `linear._CLASSES`, `linear._LOGISTIC` or `linear._PROBABILITIES`.
    """
    if metrics:
      # Copy so that adding the default binary metrics and wrapping raw
      # metric functions below never mutates the dict the caller passed in.
      metrics = dict(metrics)
    else:
      metrics = {
          "accuracy": metric_spec.MetricSpec(
              metric_fn=metrics_lib.streaming_accuracy,
              prediction_key=linear._CLASSES),
      }

    binary_metrics = target_column.get_default_binary_metrics_for_eval([0.5])
    for metric_key, metric_fn in binary_metrics.items():
      metrics[metric_key] = metric_spec.MetricSpec(
          metric_fn=metric_fn, prediction_key=linear._LOGISTIC)

    # TODO(b/31229024): Remove this loop
    valid_keys = {linear._CLASSES, linear._LOGISTIC, linear._PROBABILITIES}
    prepared_metrics = {}
    for metric_name, metric in metrics.items():
      if not isinstance(metric, metric_spec.MetricSpec):
        # Legacy raw metric function: validate a (name, prediction_key)
        # tuple key, then wrap the function.
        if isinstance(metric_name, tuple):
          if len(metric_name) != 2:
            raise ValueError(
                "Ignoring metric %s. It returned a tuple with len  "
                "%s, expected 2." % (metric_name, len(metric_name)))
          if metric_name[1] not in valid_keys:
            raise ValueError(
                "Ignoring metric %s. The 2nd element of its name "
                "should be in %s" % (metric_name, valid_keys))
        metric = linear._wrap_metric(metric)
      prepared_metrics[metric_name] = metric

    return self._estimator.evaluate(x=x, y=y, input_fn=input_fn,
                                    feed_fn=feed_fn, batch_size=batch_size,
                                    steps=steps, metrics=prepared_metrics,
                                    name=name)
コード例 #2
0
ファイル: linear.py プロジェクト: MrRabbit0o0/tensorflow
  def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
               batch_size=None, steps=None, metrics=None, name=None):
    """See evaluable.Evaluable.

    Evaluates the wrapped estimator with a normalized metrics dict. If
    `metrics` is empty or None, `metrics_lib.streaming_accuracy` is used
    under the key ("accuracy", _CLASSES). When `self._n_classes == 2`, the
    metrics returned by `target_column.get_default_binary_metrics_for_eval`
    ([0.5]) are added under keys (name, _LOGISTIC), replacing any caller
    metric with the same key. Every metric is wrapped with `_wrap_metric`.

    Args:
      x, y, input_fn, feed_fn, batch_size, steps, name: Forwarded unchanged
        to `self._estimator.evaluate`.
      metrics: Optional dict mapping either a `str` name or a
        (name, prediction_key) tuple to a metric function. A `str` key is
        treated as (name, _CLASSES). The caller's dict is not modified.

    Returns:
      The result of `self._estimator.evaluate`.

    Raises:
      ValueError: If a tuple key is not of length 2, its second element is not
        one of `_CLASSES`, `_LOGISTIC` or `_PROBABILITIES`, or a key is neither
        a tuple nor a `str`.
    """
    if metrics:
      # Copy so the caller's dict is never mutated by the additions below.
      metrics = dict(metrics)
    else:
      metrics = {("accuracy", _CLASSES): metrics_lib.streaming_accuracy}
    if self._n_classes == 2:
      binary_metrics = (
          target_column.get_default_binary_metrics_for_eval([0.5]))
      for metric_key, metric_fn in binary_metrics.items():
        metrics[(metric_key, _LOGISTIC)] = metric_fn

    valid_keys = {_CLASSES, _LOGISTIC, _PROBABILITIES}
    # Build a fresh dict instead of popping/inserting keys in `metrics` while
    # iterating over it; with live dict views that can raise RuntimeError or
    # revisit (and double-wrap) re-inserted keys.
    wrapped_metrics = {}
    for metric_name, metric in metrics.items():
      if isinstance(metric_name, tuple):
        if len(metric_name) != 2:
          raise ValueError("Ignoring metric %s. It returned a tuple with len  "
                           "%s, expected 2." % (metric_name, len(metric_name)))
        if metric_name[1] not in valid_keys:
          raise ValueError("Ignoring metric %s. The 2nd element of its name "
                           "should be in %s" % (metric_name, valid_keys))
      elif isinstance(metric_name, str):
        # Bare string names default to the class-prediction key.
        metric_name = (metric_name, _CLASSES)
      else:
        raise ValueError("Ignoring metric %s. Its name is not in the correct "
                         "form." % (metric_name,))
      wrapped_metrics[metric_name] = _wrap_metric(metric)

    return self._estimator.evaluate(x=x, y=y, input_fn=input_fn,
                                    feed_fn=feed_fn, batch_size=batch_size,
                                    steps=steps, metrics=wrapped_metrics,
                                    name=name)
コード例 #3
0
ファイル: linear.py プロジェクト: zjy-ucas/tensorflow
    def evaluate(self,
                 x=None,
                 y=None,
                 input_fn=None,
                 feed_fn=None,
                 batch_size=None,
                 steps=None,
                 metrics=None,
                 name=None):
        """See evaluable.Evaluable.

        Evaluates the wrapped estimator with a normalized metrics dict. If
        `metrics` is empty or None, `metrics_lib.streaming_accuracy` is used
        under the key ("accuracy", _CLASSES). When `self._n_classes == 2`, the
        metrics returned by
        `target_column.get_default_binary_metrics_for_eval([0.5])` are added
        under keys (name, _LOGISTIC), replacing any caller metric with the
        same key. Every metric is wrapped with `_wrap_metric`.

        Args:
            x, y, input_fn, feed_fn, batch_size, steps, name: Forwarded
                unchanged to `self._estimator.evaluate`.
            metrics: Optional dict mapping either a `str` name or a
                (name, prediction_key) tuple to a metric function. A `str`
                key is treated as (name, _CLASSES). The caller's dict is not
                modified.

        Returns:
            The result of `self._estimator.evaluate`.

        Raises:
            ValueError: If a tuple key is not of length 2, its second element
                is not one of `_CLASSES`, `_LOGISTIC` or `_PROBABILITIES`, or
                a key is neither a tuple nor a `str`.
        """
        if metrics:
            # Copy so the caller's dict is never mutated by the additions below.
            metrics = dict(metrics)
        else:
            metrics = {("accuracy", _CLASSES): metrics_lib.streaming_accuracy}
        if self._n_classes == 2:
            binary_metrics = (
                target_column.get_default_binary_metrics_for_eval([0.5]))
            for metric_key, metric_fn in binary_metrics.items():
                metrics[(metric_key, _LOGISTIC)] = metric_fn

        valid_keys = {_CLASSES, _LOGISTIC, _PROBABILITIES}
        # Build a fresh dict instead of popping/inserting keys in `metrics`
        # while iterating over it; with live dict views that can raise
        # RuntimeError or revisit (and double-wrap) re-inserted keys.
        wrapped_metrics = {}
        for metric_name, metric in metrics.items():
            if isinstance(metric_name, tuple):
                if len(metric_name) != 2:
                    raise ValueError(
                        "Ignoring metric %s. It returned a tuple with len  "
                        "%s, expected 2." % (metric_name, len(metric_name)))
                if metric_name[1] not in valid_keys:
                    raise ValueError(
                        "Ignoring metric %s. The 2nd element of its name "
                        "should be in %s" % (metric_name, valid_keys))
            elif isinstance(metric_name, str):
                # Bare string names default to the class-prediction key.
                metric_name = (metric_name, _CLASSES)
            else:
                raise ValueError(
                    "Ignoring metric %s. Its name is not in the correct "
                    "form." % (metric_name,))
            wrapped_metrics[metric_name] = _wrap_metric(metric)

        return self._estimator.evaluate(x=x,
                                        y=y,
                                        input_fn=input_fn,
                                        feed_fn=feed_fn,
                                        batch_size=batch_size,
                                        steps=steps,
                                        metrics=wrapped_metrics,
                                        name=name)