def get_metrics_from_events_dir(self):
    """Retrieves and aggregates metrics from Tensorboard Summary file.

    Returns:
      raw_metrics (dict): Key is Tensorboard Tag and value is a list of
        MetricPoint namedtuples.
      aggregated_metrics (dict): Key is metric name and value is a
        MetricPoint containing the aggregated value for that metric.

    Raises:
      ValueError: If metric aggregation fails.
    """
    tags_to_ignore = set(
        self.metric_collection_config.get('tags_to_ignore', []))
    raw_metrics = metrics.read_metrics_from_events_dir(
        self.events_dir, tags_to_ignore)
    if not raw_metrics:
        # Lazy %-style args so the message is only formatted when emitted.
        self.logger.warning('No metrics found in %s', self.events_dir)
        return {}, {}

    default_aggregation_strategies = self.metric_collection_config.get(
        'default_aggregation_strategies')
    metric_to_aggregation_strategies = self.metric_collection_config.get(
        'metric_to_aggregation_strategies')
    try:
        final_metrics = metrics.aggregate_metrics(
            raw_metrics,
            default_aggregation_strategies,
            metric_to_aggregation_strategies)
    except ValueError as e:
        # Chain the original exception so the root cause stays visible
        # in the traceback.
        raise ValueError(
            "Error during metric aggregation: {}".format(e)) from e
    return raw_metrics, final_metrics
def get_metrics_from_events_dir(self, job_status_dict):
    """Retrieves and aggregates metrics from Tensorboard Summary file.

    Args:
      job_status_dict (dict): Should contain `job_status`, `start_time`,
        and `stop_time` as keys.

    Returns:
      final_metrics (dict): Key is metric name and value is a MetricPoint
        containing the aggregated value for that metric.

    Raises:
      ValueError: If metric aggregation fails, if the `time_to_accuracy`
        config is missing required keys, or if time-to-accuracy cannot be
        computed from the raw metrics.
    """
    tags_to_ignore = set(
        self.metric_collection_config.get('tags_to_ignore', []))
    raw_metrics = metrics.read_metrics_from_events_dir(
        self.events_dir, tags_to_ignore)
    if not raw_metrics:
        # Lazy %-style args so the message is only formatted when emitted.
        self.logger.warning('No metrics found in %s', self.events_dir)
        return {}

    default_aggregation_strategies = self.metric_collection_config.get(
        'default_aggregation_strategies')
    metric_to_aggregation_strategies = self.metric_collection_config.get(
        'metric_to_aggregation_strategies')
    try:
        final_metrics = metrics.aggregate_metrics(
            raw_metrics,
            default_aggregation_strategies,
            metric_to_aggregation_strategies)
    except ValueError as e:
        # Chain the original exception so the root cause stays visible
        # in the traceback.
        raise ValueError(
            "Error during metric aggregation: {}".format(e)) from e

    # Record total wall time from the job's start/stop timestamps; the
    # stop time doubles as the MetricPoint's wall-time stamp.
    start_time = job_status_dict['start_time']
    stop_time = job_status_dict['stop_time']
    final_metrics['total_wall_time'] = metrics.MetricPoint(
        stop_time - start_time, stop_time)

    # Compute time_to_accuracy if requested in the config.
    tta_config = self.metric_collection_config.get('time_to_accuracy')
    if tta_config:
        # Both keys are required; reject a partial config up front.
        if not {'accuracy_tag', 'accuracy_threshold'} <= tta_config.keys():
            raise ValueError(
                'Invalid `time_to_accuracy` portion of config. '
                'See README for how to set up the config.')
        tag = tta_config['accuracy_tag']
        threshold = tta_config['accuracy_threshold']
        try:
            final_metrics['time_to_accuracy'] = metrics.time_to_accuracy(
                raw_metrics, tag, threshold)
        except ValueError as e:
            raise ValueError(
                'Error computing time to accuracy: {}'.format(e)) from e
    return final_metrics
def test_aggregate_metrics_custom(self):
    """Per-metric strategies should override the defaults for that metric."""
    raw_metrics = metrics.read_metrics_from_events_dir(self.temp_dir)
    final_metrics = metrics.aggregate_metrics(
        raw_metrics,
        default_strategies=['final', 'min', 'max'],
        metric_strategies={'accuracy': ['max']})
    # Compare metric values only; wall times are non-deterministic.
    actual_values = {}
    for metric_name, point in final_metrics.items():
        actual_values[metric_name] = point.metric_value
    expected_values = {
        'foo_final': 2,
        'foo_min': 1,
        'foo_max': 2,
        'accuracy_max': .5,
    }
    self.assertDictEqual(actual_values, expected_values)
def test_aggregate_metrics_default_all(self):
    """All default strategies should be applied to every metric."""
    raw_metrics = metrics.read_metrics_from_events_dir(self.temp_dir)
    final_metrics = metrics.aggregate_metrics(
        raw_metrics,
        default_strategies=['final', 'min', 'max', 'average'])
    # Compare metric values only; wall times are non-deterministic.
    actual_values = {}
    for metric_name, point in final_metrics.items():
        actual_values[metric_name] = point.metric_value
    expected_values = {
        'foo_final': 2,
        'foo_min': 1,
        'foo_max': 2,
        'foo_average': 1.5,
        'accuracy_final': .25,
        'accuracy_min': .125,
        'accuracy_max': .5,
        'accuracy_average': np.mean([.125, .25, .5]),
    }
    self.assertDictEqual(actual_values, expected_values)