def finalize_metric(metric: tf.keras.metrics.Metric, values):
  """Rebuilds `metric` in the current scope, loads `values`, returns its result.

  A new metric object is reconstructed via
  `type(metric).from_config(metric.get_config())` so that its `tf.Variable`s
  are created in the current scope, at the time the caller wants them, rather
  than reusing the variables of `metric`.

  Note: reconstruction requires that `type(metric).__init__` accepts an
  argument for every key returned by `metric.get_config()`. This restricts
  which metric types can be used; giving constructor arguments defaults and
  exporting their values in `get_config()` (see
  `tf.keras.metrics.TopKCategoricalAccuracy`) works well.

  Args:
    metric: The `tf.keras.metrics.Metric` whose type and config are used to
      build the new metric.
    values: Values to assign to the new metric's variables, paired in order
      with `keras_metric.variables`. Pairing uses `zip`, so unmatched entries
      on either side are silently ignored.

  Returns:
    The new metric's `result()`, computed under a control dependency on the
    variable assignments.

  Raises:
    TypeError: If `type(metric)` cannot be constructed from the metric's
      config.
  """
  # Fetch the config once, outside the `try`, so the error message below can
  # reuse it (the metric has no `config()` method to call).
  config = metric.get_config()
  try:
    keras_metric = type(metric).from_config(config)
  except TypeError as e:
    # Re-raise with a more helpful message, chaining the original exception so
    # its stack trace is preserved.
    raise TypeError(
        'Caught exception trying to call `{t}.from_config()` with '
        'config {c}. Confirm that {t}.__init__() has an argument for '
        'each member of the config.\nException: {e}'.format(
            t=type(metric), c=config, e=e)) from e

  assignments = [
      variable.assign(value)
      for variable, value in zip(keras_metric.variables, values)
  ]
  # Ensure the variables hold `values` before the result is computed.
  with tf.control_dependencies(assignments):
    return keras_metric.result()
def get_naive_forecasting_performance(
    dataset: tf.data.Dataset,
    loss: tf.keras.losses.Loss,
    metric: tf.keras.metrics.Metric) -> Tuple:
  """Evaluates a naive "repeat the last value" forecaster over `dataset`.

  For each `(sequence, y_true)` element of `dataset`, the forecast is
  `sequence[..., -1]`, i.e. the last entry along the final axis.

  Args:
    dataset: Iterable yielding `(sequence, y_true)` pairs.
    loss: Loss callable; `loss(y_true, y_pred)` must return a tensor that
      supports `.numpy()`.
    metric: Metric to accumulate over the dataset. Its state is reset first.

  Returns:
    A `(mean_loss, metric_value)` tuple. `mean_loss` is the unweighted mean of
    the loss over the dataset's elements, and `metric_value` is
    `metric.result().numpy()`.

  Raises:
    ZeroDivisionError: If `dataset` yields no elements.
  """
  metric.reset_states()
  loss_total = 0
  num_elements = 0
  for inputs, targets in dataset:
    naive_forecast = inputs[..., -1]
    metric.update_state(targets, naive_forecast)
    loss_total = loss_total + loss(targets, naive_forecast).numpy()
    num_elements += 1
  mean_loss = loss_total / num_elements
  return mean_loss, metric.result().numpy()
def _check_keras_metric_config_constructable(
    metric: tf.keras.metrics.Metric):
  """Checks that a Keras metric is constructable from the `get_config()` method.

  Metrics whose class name is also an attribute of `tf.keras.metrics`
  (presumably the built-in Keras metrics) are not inspected further. For any
  other metric, every positional-or-keyword parameter of its `__init__` (other
  than `self`) must appear as a key of `metric.get_config()`.

  Args:
    metric: A single `tf.keras.metrics.Metric`.

  Raises:
    TypeError: If the metric is not an instance of `tf.keras.metrics.Metric`,
      or if the metric is not constructable from the `get_config()` method.
  """
  if not isinstance(metric, tf.keras.metrics.Metric):
    raise TypeError(f'Metric {type(metric)} is not a `tf.keras.metrics.Metric` '
                    'to be constructable from the `get_config()` method.')

  metric_type_str = type(metric).__name__
  # NOTE(review): this lookup is by class name only, so a custom metric whose
  # name matches a `tf.keras.metrics` attribute is skipped as well.
  if not hasattr(tf.keras.metrics, metric_type_str):
    # Unwrap any decorators around `__init__` so its own signature is read.
    _, init_fn = tf.__internal__.decorator.unwrap(metric.__init__)
    # `.args` covers positional-or-keyword parameters only; keyword-only
    # parameters and *args/**kwargs are not checked.
    init_args = inspect.getfullargspec(init_fn).args
    init_args.remove('self')
    get_config_args = metric.get_config().keys()
    extra_args = [arg for arg in init_args if arg not in get_config_args]
    if extra_args:
      # TODO(b/197746608): Remove the suggestion of updating `get_config` if
      # that code path is removed.
      raise TypeError(
          f'Metric {metric_type_str} is not constructable from the '
          '`get_config()` method, because `__init__` takes extra arguments '
          f'that are not included in the `get_config()`: {extra_args}. '
          'Pass the metric constructor instead, or update the `get_config()` '
          'in the metric class to include these extra arguments.\n'
          'Example:\n'
          'class CustomMetric(tf.keras.metrics.Metric):\n'
          '  def __init__(self, arg1, name=None, dtype=None):\n'
          '    super().__init__(name=name, dtype=dtype)\n'
          '    self._arg1 = arg1\n\n'
          '  def get_config(self):\n'
          '    config = super().get_config()\n'
          '    config[\'arg1\'] = self._arg1\n'
          '    return config')