コード例 #1
0
  def get_metrics_from_events_dir(self):
    """Retrieves and aggregates metrics from Tensorboard Summary file.

    Tags listed under `tags_to_ignore` in `self.metric_collection_config`
    are skipped when reading `self.events_dir`. The config's
    `default_aggregation_strategies` and `metric_to_aggregation_strategies`
    entries are passed to `metrics.aggregate_metrics`; either one may be
    absent from the config, in which case None is passed for it.

    Returns:
      raw_metrics (dict): Key is Tensorboard Tag and value is a list of
        MetricPoint namedtuples.
      aggregated_metrics (dict): Key is metric name and value is a MetricPoint
        containing the aggregated value for that metric.
      If no metrics are found in the events directory, a warning is logged
      and both dicts are returned empty.

    Raises:
      ValueError: If metric aggregation fails. The underlying error is
        chained as the `__cause__` of the raised exception.
    """
    tags_to_ignore = set(
        self.metric_collection_config.get('tags_to_ignore', []))
    raw_metrics = metrics.read_metrics_from_events_dir(
        self.events_dir, tags_to_ignore)

    if not raw_metrics:
      self.logger.warning("No metrics found in {}".format(self.events_dir))
      return {}, {}

    default_aggregation_strategies = self.metric_collection_config.get(
        'default_aggregation_strategies')
    metric_to_aggregation_strategies = self.metric_collection_config.get(
        'metric_to_aggregation_strategies')
    try:
      final_metrics = metrics.aggregate_metrics(
          raw_metrics,
          default_aggregation_strategies,
          metric_to_aggregation_strategies
      )
    except ValueError as e:
      # Chain explicitly so the traceback reports the aggregation failure as
      # the direct cause, not as an error that occurred while handling it.
      raise ValueError("Error during metric aggregation: {}".format(e)) from e
    return raw_metrics, final_metrics
コード例 #2
0
    def get_metrics_from_events_dir(self, job_status_dict):
        """Retrieves and aggregates metrics from Tensorboard Summary file.

        Besides the configured aggregations, this adds a `total_wall_time`
        metric built from the job's start/stop times and, when the config
        has a `time_to_accuracy` section, a `time_to_accuracy` metric.

        Args:
          job_status_dict (dict): Should contain `job_status`, `start_time`,
            and `stop_time` as keys. Only `start_time` and `stop_time` are
            read by this method.

        Returns:
          final_metrics (dict): Key is metric name and value is a MetricPoint
            containing the aggregated value for that metric. If no metrics
            are found in the events directory, a warning is logged and an
            empty dict is returned.

        Raises:
          ValueError: If metric aggregation fails, if the `time_to_accuracy`
            config lacks `accuracy_tag` or `accuracy_threshold`, or if
            computing time to accuracy fails. Underlying errors are chained
            as the `__cause__`.
          KeyError: If `job_status_dict` is missing `start_time` or
            `stop_time`.
        """
        tags_to_ignore = set(
            self.metric_collection_config.get('tags_to_ignore', []))
        raw_metrics = metrics.read_metrics_from_events_dir(
            self.events_dir, tags_to_ignore)

        if not raw_metrics:
            self.logger.warning("No metrics found in {}".format(
                self.events_dir))
            return {}

        default_aggregation_strategies = self.metric_collection_config.get(
            'default_aggregation_strategies')
        metric_to_aggregation_strategies = self.metric_collection_config.get(
            'metric_to_aggregation_strategies')
        try:
            final_metrics = metrics.aggregate_metrics(
                raw_metrics, default_aggregation_strategies,
                metric_to_aggregation_strategies)
        except ValueError as e:
            # Chain so the aggregation failure is reported as the cause.
            raise ValueError(
                "Error during metric aggregation: {}".format(e)) from e

        start_time = job_status_dict['start_time']
        stop_time = job_status_dict['stop_time']
        # Value is the whole job duration; stop_time is passed as the point's
        # second field (assumed to be its wall time -- confirm MetricPoint).
        final_metrics['total_wall_time'] = metrics.MetricPoint(
            stop_time - start_time, stop_time)

        # Compute time_to_accuracy if requested in the config.
        tta_config = self.metric_collection_config.get('time_to_accuracy')
        if tta_config:
            required_keys = ('accuracy_tag', 'accuracy_threshold')
            if any(key not in tta_config for key in required_keys):
                raise ValueError(
                    'Invalid `time_to_accuracy` portion of config. '
                    'See README for how to set up the config.')
            tag = tta_config['accuracy_tag']
            threshold = tta_config['accuracy_threshold']
            try:
                final_metrics['time_to_accuracy'] = metrics.time_to_accuracy(
                    raw_metrics, tag, threshold)
            except ValueError as e:
                raise ValueError(
                    'Error computing time to accuracy: {}'.format(e)) from e

        return final_metrics
コード例 #3
0
    def test_aggregate_metrics_custom(self):
        """Per-metric strategies replace the defaults for that metric.

        `accuracy` is configured with only `max`, so it produces a single
        `accuracy_max` entry, while `foo` uses the default final/min/max.
        """
        events = metrics.read_metrics_from_events_dir(self.temp_dir)
        aggregated = metrics.aggregate_metrics(
            events,
            default_strategies=['final', 'min', 'max'],
            metric_strategies={'accuracy': ['max']},
        )

        # Compare only the aggregated values: wall times are
        # non-deterministic, so they are dropped from each MetricPoint.
        values_by_name = {}
        for name, point in aggregated.items():
            values_by_name[name] = point.metric_value

        expected = {
            'foo_final': 2,
            'foo_min': 1,
            'foo_max': 2,
            'accuracy_max': .5,
        }
        self.assertDictEqual(values_by_name, expected)
コード例 #4
0
    def test_aggregate_metrics_default_all(self):
        """All four default strategies are applied to every metric."""
        events = metrics.read_metrics_from_events_dir(self.temp_dir)
        aggregated = metrics.aggregate_metrics(
            events,
            default_strategies=['final', 'min', 'max', 'average'],
        )

        # Compare only the aggregated values: wall times are
        # non-deterministic, so they are dropped from each MetricPoint.
        values_by_name = {}
        for name, point in aggregated.items():
            values_by_name[name] = point.metric_value

        expected = {
            'foo_final': 2,
            'foo_min': 1,
            'foo_max': 2,
            'foo_average': 1.5,
            'accuracy_final': .25,
            'accuracy_min': .125,
            'accuracy_max': .5,
            'accuracy_average': np.mean([.125, .25, .5]),
        }
        self.assertDictEqual(values_by_name, expected)