Example #1
0
def _serialize_tf_loss(loss: tf.keras.losses.Loss) -> config.MetricConfig:
  """Builds a MetricConfig describing the given Keras loss.

  Args:
    loss: Keras loss instance to describe.

  Returns:
    A config.MetricConfig with the serialized class name, the module in
    which the loss's class is defined, and the serialized loss config
    encoded as JSON.
  """
  serialized = metric_util.serialize_loss(loss)
  loss_class_name = serialized['class_name']
  loss_module = loss.__class__.__module__
  # sort_keys keeps the JSON text stable for equal configs.
  loss_config_json = json.dumps(serialized['config'], sort_keys=True)
  return config.MetricConfig(
      class_name=loss_class_name,
      module=loss_module,
      config=loss_config_json,
  )
Example #2
0
def _metric_keys_and_configs(
    metrics: Dict[Text, List[_TFMetricOrLoss]],
    model_name: Text,
    sub_key: Optional[metric_types.SubKey],
    aggregation_type: Optional[metric_types.AggregationType],
) -> Tuple[_KeysBySubKey, _ConfigsBySubKey, _ConfigsBySubKey]:
  """Groups metric keys and serialized metric/loss configs by sub key.

  For each metric, the sub key is first resolved through
  `_verify_and_update_sub_key`, and all outputs are grouped under that
  resolved sub key.

  Args:
    metrics: Mapping from output name to the Keras metrics/losses for it.
    model_name: Model name stored in every generated MetricKey.
    sub_key: Requested sub key, passed to `_verify_and_update_sub_key`.
    aggregation_type: Aggregation type stored in every generated MetricKey.

  Returns:
    A (metric_keys, metric_configs, loss_configs) tuple of defaultdicts keyed
    by resolved sub key. metric_keys holds lists of MetricKeys. metric_configs
    and loss_configs hold dicts from output name to lists of serialized
    Keras metrics and losses, respectively. Objects that are neither
    a Keras Metric nor a Loss get a MetricKey but no serialized config.
  """
  keys_by_sub_key = collections.defaultdict(list)
  metric_cfgs_by_sub_key = collections.defaultdict(dict)
  loss_cfgs_by_sub_key = collections.defaultdict(dict)
  for output_name, metrics_list in metrics.items():
    for metric in metrics_list:
      resolved_sub_key = _verify_and_update_sub_key(model_name, output_name,
                                                    sub_key, metric)
      # Create both per-output lists up front so each (sub key, output) pair
      # appears in both config dicts, even if one list stays empty.
      metric_cfg_list = metric_cfgs_by_sub_key[resolved_sub_key].setdefault(
          output_name, [])
      loss_cfg_list = loss_cfgs_by_sub_key[resolved_sub_key].setdefault(
          output_name, [])
      keys_by_sub_key[resolved_sub_key].append(
          metric_types.MetricKey(
              name=metric.name,
              model_name=model_name,
              output_name=output_name,
              sub_key=resolved_sub_key,
              aggregation_type=aggregation_type))
      if isinstance(metric, tf.keras.metrics.Metric):
        metric_cfg_list.append(metric_util.serialize_metric(metric))
      elif isinstance(metric, tf.keras.losses.Loss):
        loss_cfg_list.append(metric_util.serialize_loss(metric))
  return keys_by_sub_key, metric_cfgs_by_sub_key, loss_cfgs_by_sub_key
def _metric_keys_and_configs(
    metrics: Dict[Text, List[_TFMetricOrLoss]], model_name: Text,
    sub_key: Optional[metric_types.SubKey]
) -> Tuple[List[metric_types.MetricKey], Dict[Text, List[Dict[Text, Any]]],
           Dict[Text, List[Dict[Text, Any]]]]:
  """Collects metric keys plus serialized metric and loss configs.

  Args:
    metrics: Mapping from output name to the Keras metrics/losses for it.
    model_name: Model name stored in every generated MetricKey.
    sub_key: Requested sub key. Each MetricKey uses the value returned by
      `_verify_and_update_sub_key` for its metric.

  Returns:
    A (metric_keys, metric_configs, loss_configs) tuple. metric_keys is a flat
    list with one MetricKey per metric. metric_configs and loss_configs map
    every output name in `metrics` (even those with no metrics) to the list of
    serialized Keras metrics and losses for that output.
  """
  keys = []
  metric_cfgs = {}
  loss_cfgs = {}
  for output_name, metrics_list in metrics.items():
    # Register the per-output lists immediately; they are filled in below.
    metric_cfgs[output_name] = output_metric_cfgs = []
    loss_cfgs[output_name] = output_loss_cfgs = []
    for metric in metrics_list:
      metric_name = metric.name
      resolved_sub_key = _verify_and_update_sub_key(model_name, output_name,
                                                    sub_key, metric)
      keys.append(
          metric_types.MetricKey(
              name=metric_name,
              model_name=model_name,
              output_name=output_name,
              sub_key=resolved_sub_key))
      if isinstance(metric, tf.keras.metrics.Metric):
        output_metric_cfgs.append(metric_util.serialize_metric(metric))
      elif isinstance(metric, tf.keras.losses.Loss):
        output_loss_cfgs.append(metric_util.serialize_loss(metric))
  return keys, metric_cfgs, loss_cfgs
Example #4
0
def _private_tf_loss(loss: tf.keras.losses.Loss) -> tf.keras.losses.Loss:
  """Returns a rebuilt copy of `loss` whose name starts with an underscore.

  The loss is serialized, its config name gets a leading '_' unless it
  already has one, and a new loss object is deserialized from that config.
  """
  serialized = metric_util.serialize_loss(loss)
  loss_config = serialized['config']
  current_name = loss_config['name']
  if current_name.startswith('_'):
    private_name = current_name
  else:
    private_name = '_' + current_name
  loss_config['name'] = private_name
  loss_cls = loss.__class__
  # Make the loss class resolvable by its class name during deserialization
  # (presumably needed for user-defined loss classes Keras does not know).
  with tf.keras.utils.custom_object_scope({loss_cls.__name__: loss_cls}):
    return tf.keras.losses.deserialize(serialized)