Exemplo n.º 1
0
        def finalize_metric(metric: tf.keras.metrics.Metric, values):
            """Rebuilds `metric` in the current scope and returns its result.

            A fresh instance of `type(metric)` is constructed with
            `from_config(metric.get_config())`. Each of its `variables` is then
            assigned, in order, from `values`. Finally, `result()` is computed
            under a control dependency on those assignments.

            Args:
                metric: The `tf.keras.metrics.Metric` whose type and config are
                    used to build the new instance.
                values: A sequence with one entry per variable of the rebuilt
                    metric, in `variables` order. Presumably these are the
                    aggregated metric variable values; confirm against the
                    caller.

            Returns:
                The value of `result()` on the rebuilt metric, evaluated after
                all variable assignments.

            Raises:
                TypeError: If `type(metric).from_config()` raises a
                    `TypeError`, e.g. because `__init__` does not accept every
                    key of `get_config()`. The original error is chained as
                    the cause.
            """
            # Note: reconstruction requires that `type(metric).__init__` accept
            # every key returned by `metric.get_config()`. This restricts which
            # metrics can be used, but a workable pattern is to use default
            # arguments and export the values in `get_config()` (see
            # `tf.keras.metrics.TopKCategoricalAccuracy`).
            try:
                # Reconstruct the metric in the current scope so that its
                # `tf.Variable`s are created where we want them.
                keras_metric = type(metric).from_config(metric.get_config())
            except TypeError as e:
                # Re-raise with a more helpful message while keeping the
                # original exception (and its traceback) as the cause.
                raise TypeError(
                    'Caught exception trying to call `{t}.from_config()` with '
                    'config {c}. Confirm that {t}.__init__() has an argument for '
                    'each member of the config.\nException: {e}'.format(
                        t=type(metric), c=metric.get_config(), e=e)) from e

            # NOTE(review): `zip` silently stops at the shorter of
            # `keras_metric.variables` and `values`; a length mismatch is not
            # detected here.
            assignments = [
                variable.assign(value)
                for variable, value in zip(keras_metric.variables, values)
            ]
            with tf.control_dependencies(assignments):
                return keras_metric.result()
Exemplo n.º 2
0
def get_naive_forecasting_performance(dataset: tf.data.Dataset,
                                      loss: tf.keras.losses.Loss,
                                      metric: tf.keras.metrics.Metric) -> Tuple:
  """Evaluates a naive "repeat the last value" forecast over `dataset`.

  For every `(sequence, y_true)` element of `dataset`, the prediction is
  `sequence[..., -1]`, i.e. the last entry along the final axis of `sequence`.

  Args:
    dataset: Yields `(sequence, y_true)` pairs.
    loss: Called as `loss(y_true, y_pred)` once per element; its result must
      support `.numpy()`.
    metric: Reset with `reset_states()` first, then updated with every element.
      Its state is mutated in place.

  Returns:
    A tuple `(mean_loss, metric_result)`. `mean_loss` is the unweighted mean of
    the per-element losses. `metric_result` is `metric.result().numpy()`.

  Raises:
    ValueError: If `dataset` yields no elements.
  """
  metric.reset_states()
  losses = []
  for sequence, y_true in dataset:
    # Naive forecast: reuse the last observed entry as the prediction.
    y_pred = sequence[..., -1]
    metric.update_state(y_true, y_pred)
    losses.append(loss(y_true, y_pred).numpy())
  if not losses:
    # Without this guard, the mean below fails with a bare ZeroDivisionError.
    raise ValueError('`dataset` yielded no elements; cannot compute the naive '
                     'forecasting performance.')
  return sum(losses) / len(losses), metric.result().numpy()
Exemplo n.º 3
0
def _check_keras_metric_config_constructable(metric: tf.keras.metrics.Metric):
  """Checks that a Keras metric is constructable from the `get_config()` method.

  Metrics whose class name is an attribute of `tf.keras.metrics` are accepted
  without further inspection. For any other metric, every positional parameter
  of its (decorator-unwrapped) `__init__`, except the instance itself, must
  appear as a key of `metric.get_config()`. Keyword-only parameters are not
  checked.

  Args:
    metric: A single `tf.keras.metrics.Metric`.

  Raises:
    TypeError: If the metric is not an instance of `tf.keras.metrics.Metric`, or
      if the metric is not constructable from the `get_config()` method.
  """
  if not isinstance(metric, tf.keras.metrics.Metric):
    raise TypeError(f'Metric {type(metric)} is not a `tf.keras.metrics.Metric` '
                    'to be constructable from the `get_config()` method.')

  metric_type_str = type(metric).__name__

  # Metrics shipped in `tf.keras.metrics` are not inspected further.
  if not hasattr(tf.keras.metrics, metric_type_str):
    # Strip any TF decorators so the real `__init__` signature is inspected.
    _, init_fn = tf.__internal__.decorator.unwrap(metric.__init__)
    # `getfullargspec` keeps the already-bound first parameter (the instance)
    # even for bound methods. Drop it by position rather than by the name
    # 'self', which raised ValueError when absent (e.g. `__init__(*args)`).
    init_args = inspect.getfullargspec(init_fn).args[1:]
    get_config_args = metric.get_config().keys()
    extra_args = [arg for arg in init_args if arg not in get_config_args]
    if extra_args:
      # TODO(b/197746608): Remove the suggestion of updating `get_config` if
      # that code path is removed.
      raise TypeError(
          f'Metric {metric_type_str} is not constructable from the '
          '`get_config()` method, because `__init__` takes extra arguments '
          f'that are not included in the `get_config()`: {extra_args}. '
          'Pass the metric constructor instead, or update the `get_config()` '
          'in the metric class to include these extra arguments.\n'
          'Example:\n'
          'class CustomMetric(tf.keras.metrics.Metric):\n'
          '  def __init__(self, arg1, **kwargs):\n'
          '    super().__init__(**kwargs)\n'
          '    self._arg1 = arg1\n\n'
          '  def get_config(self):\n'
          '    config = super().get_config()\n'
          '    config[\'arg1\'] = self._arg1\n'
          '    return config')