def _serialize_tf_metric(
    metric: tf.keras.metrics.Metric) -> config.MetricConfig:
  """Converts a keras metric into a MetricConfig.

  Args:
    metric: Keras metric to serialize.

  Returns:
    A MetricConfig whose class_name comes from the dict produced by
    metric_util.serialize_metric and whose config is that dict's 'config'
    entry encoded as a JSON string with sorted keys.
  """
  serialized = metric_util.serialize_metric(metric)
  class_name = serialized['class_name']
  # Sorting the keys makes the JSON text independent of dict ordering.
  config_json = json.dumps(serialized['config'], sort_keys=True)
  return config.MetricConfig(class_name=class_name, config=config_json)
def _metric_keys_and_configs(
    metrics: Dict[Text, List[_TFMetricOrLoss]], model_name: Text,
    sub_key: Optional[metric_types.SubKey]
) -> Tuple[_KeysBySubKey, _ConfigsBySubKey, _ConfigsBySubKey]:
  """Groups metric keys, metric configs, and loss configs by sub key.

  Args:
    metrics: Mapping from output name to the metrics and/or losses for that
      output.
    model_name: Model name recorded in every generated MetricKey.
    sub_key: Sub key handed to _verify_and_update_sub_key for each entry; the
      value that call returns is the one used for grouping.

  Returns:
    A tuple (metric_keys, metric_configs, loss_configs):
      metric_keys: defaultdict from sub key to a list of MetricKeys.
      metric_configs: defaultdict from sub key to a dict mapping output name
        to a list of serialized keras metrics.
      loss_configs: defaultdict from sub key to a dict mapping output name to
        a list of serialized keras losses.
  """
  keys_by_sub_key = collections.defaultdict(list)
  metric_cfgs_by_sub_key = collections.defaultdict(dict)
  loss_cfgs_by_sub_key = collections.defaultdict(dict)

  for output_name, output_metrics in metrics.items():
    for item in output_metrics:
      resolved_sub_key = _verify_and_update_sub_key(
          model_name, output_name, sub_key, item)

      # Both config dicts get an entry for this output (possibly an empty
      # list) no matter which kind of object `item` turns out to be.
      metric_cfgs = metric_cfgs_by_sub_key[resolved_sub_key].setdefault(
          output_name, [])
      loss_cfgs = loss_cfgs_by_sub_key[resolved_sub_key].setdefault(
          output_name, [])

      key = metric_types.MetricKey(
          name=item.name,
          model_name=model_name,
          output_name=output_name,
          sub_key=resolved_sub_key)
      keys_by_sub_key[resolved_sub_key].append(key)

      # An item that is neither a keras Metric nor a keras Loss still gets a
      # key above but contributes no config.
      if isinstance(item, tf.keras.metrics.Metric):
        metric_cfgs.append(metric_util.serialize_metric(item))
      elif isinstance(item, tf.keras.losses.Loss):
        loss_cfgs.append(metric_util.serialize_loss(item))

  return keys_by_sub_key, metric_cfgs_by_sub_key, loss_cfgs_by_sub_key
def _metric_keys_and_configs(
    metrics: Dict[Text, List[_TFMetricOrLoss]], model_name: Text,
    sub_key: Optional[metric_types.SubKey]
) -> Tuple[List[metric_types.MetricKey], Dict[Text, List[Dict[Text, Any]]],
           Dict[Text, List[Dict[Text, Any]]]]:
  """Builds the metric keys plus per-output metric and loss configs.

  Args:
    metrics: Mapping from output name to the metrics and/or losses for that
      output.
    model_name: Model name recorded in every generated MetricKey.
    sub_key: Sub key handed to _verify_and_update_sub_key for each entry; the
      value that call returns is stored on the entry's MetricKey.

  Returns:
    A tuple (metric_keys, metric_configs, loss_configs):
      metric_keys: Flat list with one MetricKey per entry, in iteration order.
      metric_configs: Dict from output name to the serialized keras metrics
        for that output.
      loss_configs: Dict from output name to the serialized keras losses for
        that output.
    Every output name in `metrics` appears in both dicts, with an empty list
    when it has no entries of that kind.
  """
  keys = []
  metric_cfgs_by_output = {}
  loss_cfgs_by_output = {}

  for output_name, output_metrics in metrics.items():
    serialized_metrics = []
    serialized_losses = []
    for item in output_metrics:
      key_name = item.name
      resolved_sub_key = _verify_and_update_sub_key(
          model_name, output_name, sub_key, item)
      keys.append(
          metric_types.MetricKey(
              name=key_name,
              model_name=model_name,
              output_name=output_name,
              sub_key=resolved_sub_key))

      # An item that is neither a keras Metric nor a keras Loss still gets a
      # key above but contributes no config.
      if isinstance(item, tf.keras.metrics.Metric):
        serialized_metrics.append(metric_util.serialize_metric(item))
        continue
      if isinstance(item, tf.keras.losses.Loss):
        serialized_losses.append(metric_util.serialize_loss(item))

    metric_cfgs_by_output[output_name] = serialized_metrics
    loss_cfgs_by_output[output_name] = serialized_losses

  return keys, metric_cfgs_by_output, loss_cfgs_by_output
def _private_tf_metric(
    metric: tf.keras.metrics.Metric) -> tf.keras.metrics.Metric:
  """Builds a copy of the metric whose name starts with an underscore.

  The metric is serialized, its configured name is prefixed with '_' unless it
  already starts with one, and a new metric is deserialized from the result.

  Args:
    metric: Keras metric to copy.

  Returns:
    The metric deserialized from the adjusted config.
  """
  serialized = metric_util.serialize_metric(metric)
  metric_config = serialized['config']
  current_name = metric_config['name']
  if not current_name.startswith('_'):
    metric_config['name'] = '_' + current_name

  # Register the metric's own class under its class name so that deserialize
  # can resolve it while the scope is active.
  metric_cls = metric.__class__
  with tf.keras.utils.custom_object_scope({metric_cls.__name__: metric_cls}):
    return tf.keras.metrics.deserialize(serialized)