Ejemplo n.º 1
0
def get_base_metric_name(metric, weighted=False):
  """Builds the display name for a metric.

  The string aliases 'accuracy'/'acc' collapse to 'acc' and
  'crossentropy'/'ce' collapse to 'ce'. Anything else is resolved with
  `metrics_module.get`, and the resolved object is named by its `name`
  attribute if it has one, otherwise by its `__name__`.

  Arguments:
      metric: Metric function name or reference.
      weighted: Boolean indicating if the metric for which we are adding
          names is weighted; a truthy value adds a 'weighted_' prefix.

  Returns:
      The metric name.
  """
  prefix = 'weighted_' if weighted else ''

  if metric in ('accuracy', 'acc'):
    base_name = 'acc'
  elif metric in ('crossentropy', 'ce'):
    base_name = 'ce'
  else:
    resolved = metrics_module.get(metric)
    # An explicit `name` attribute takes precedence over `__name__`.
    base_name = (resolved.name if hasattr(resolved, 'name')
                 else resolved.__name__)

  return prefix + base_name
Ejemplo n.º 2
0
def get_base_metric_name(metric, weighted=False):
    """Builds the display name for a metric.

    Short aliases are used for the built-in string names: 'accuracy' and
    'acc' map to 'acc', 'crossentropy' and 'ce' map to 'ce'. Any other
    value is looked up via `metrics_module.get` and named after the
    resolved object's `name` attribute when present, else its `__name__`.

    Arguments:
        metric: Metric function name or reference.
        weighted: Boolean indicating if the metric for which we are adding
            names is weighted; a truthy value adds a 'weighted_' prefix.

    Returns:
        The metric name.
    """
    prefix = 'weighted_' if weighted else ''

    # Each entry pairs the accepted spellings with the canonical short name.
    for aliases, short_name in ((('accuracy', 'acc'), 'acc'),
                                (('crossentropy', 'ce'), 'ce')):
        if metric in aliases:
            return prefix + short_name

    metric_fn = metrics_module.get(metric)
    # An explicit `name` attribute wins over the function's `__name__`.
    if hasattr(metric_fn, 'name'):
        return prefix + metric_fn.name
    return prefix + metric_fn.__name__
Ejemplo n.º 3
0
def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):
  """Resolves a metric identifier to a `(metric_name, metric_fn)` pair.

  'accuracy' and 'acc' are special-cased: the concrete accuracy function is
  chosen from the output shape and the loss, and the name is always 'acc'.
  Any other identifier is resolved through `metrics_module.get`.

  Arguments:
      metric: Metric function name or reference.
      internal_output_shapes: Shape of the output the metric applies to; only
          its last dimension is inspected. May be None (the default), in
          which case the accuracy variant is chosen from `loss_func` alone.
      loss_func: Loss function used for that output, or None.

  Returns:
      A tuple `(metric_name, metric_fn)`.
  """
  if metric == 'accuracy' or metric == 'acc':
    # Custom handling of accuracy (because of class mode duality): the same
    # alias maps to binary, sparse categorical or categorical accuracy.
    output_shape = internal_output_shapes
    # Guard the default `None`; indexing it would raise TypeError. Without a
    # shape, the decision falls back to the loss function.
    single_unit_output = output_shape is not None and output_shape[-1] == 1
    if single_unit_output or loss_func == losses.binary_crossentropy:
      # case: binary accuracy
      acc_fn = metrics_module.binary_accuracy
    elif loss_func == losses.sparse_categorical_crossentropy:
      # case: categorical accuracy with sparse targets
      acc_fn = metrics_module.sparse_categorical_accuracy
    else:
      acc_fn = metrics_module.categorical_accuracy
    return 'acc', acc_fn

  metric_fn = metrics_module.get(metric)
  # Prefer `__name__` (as before); fall back to a `name` attribute for
  # objects that lack `__name__`, matching `get_base_metric_name`.
  metric_name = getattr(metric_fn, '__name__', None)
  if metric_name is None:
    metric_name = metric_fn.name
  return metric_name, metric_fn
Ejemplo n.º 4
0
def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):
    """Resolves a metric identifier to a `(metric_name, metric_fn)` pair.

    'accuracy' and 'acc' are special-cased: the concrete accuracy function
    is chosen from the output shape and the loss, and the name is always
    'acc'. Any other identifier is resolved through `metrics_module.get`.

    Arguments:
        metric: Metric function name or reference.
        internal_output_shapes: Shape of the output the metric applies to;
            only its last dimension is inspected. May be None (the default),
            in which case the accuracy variant is chosen from `loss_func`
            alone.
        loss_func: Loss function used for that output, or None.

    Returns:
        A tuple `(metric_name, metric_fn)`.
    """
    if metric == 'accuracy' or metric == 'acc':
        # Custom handling of accuracy (because of class mode duality): the
        # same alias maps to binary, sparse categorical or categorical
        # accuracy.
        output_shape = internal_output_shapes
        # Guard the default `None`; indexing it would raise TypeError.
        # Without a shape, the decision falls back to the loss function.
        single_unit_output = (output_shape is not None
                              and output_shape[-1] == 1)
        if single_unit_output or loss_func == losses.binary_crossentropy:
            # case: binary accuracy
            acc_fn = metrics_module.binary_accuracy
        elif loss_func == losses.sparse_categorical_crossentropy:
            # case: categorical accuracy with sparse targets
            acc_fn = metrics_module.sparse_categorical_accuracy
        else:
            acc_fn = metrics_module.categorical_accuracy
        return 'acc', acc_fn

    metric_fn = metrics_module.get(metric)
    # Prefer `__name__` (as before); fall back to a `name` attribute for
    # objects that lack `__name__`, matching `get_base_metric_name`.
    metric_name = getattr(metric_fn, '__name__', None)
    if metric_name is None:
        metric_name = metric_fn.name
    return metric_name, metric_fn