Example #1
def metric_instance(
    metric_config: config.MetricConfig,
    tfma_metric_classes: Optional[Dict[Text, Type[metric_types.Metric]]] = None
) -> metric_types.Metric:
    """Creates instance of metric associated with config."""
    if tfma_metric_classes is None:
        tfma_metric_classes = metric_types.registered_metrics()
    if metric_config.class_name in tfma_metric_classes:
        return _deserialize_tfma_metric(metric_config, tfma_metric_classes)
    elif not metric_config.module:
        return _deserialize_tf_metric(metric_config, {})
    else:
        cls = getattr(importlib.import_module(metric_config.module),
                      metric_config.class_name)
        if issubclass(cls, tf.keras.metrics.Metric):
            return _deserialize_tf_metric(metric_config,
                                          {metric_config.class_name: cls})
        elif issubclass(cls, tf.keras.losses.Loss):
            return _deserialize_tf_loss(metric_config,
                                        {metric_config.class_name: cls})
        elif issubclass(cls, metric_types.Metric):
            return _deserialize_tfma_metric(metric_config,
                                            {metric_config.class_name: cls})
        else:
            raise NotImplementedError(
                'unknown metric type {}: metric={}'.format(cls, metric_config))
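Below is a minimal usage sketch for metric_instance. The import paths and the choice of 'ExampleCount' (a metric assumed to be present in metric_types.registered_metrics()) are assumptions based on the surrounding examples and may differ between TFMA versions.

# Sketch only: import paths and metric class name are assumptions and may
# vary across TFMA releases.
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.metrics import metric_specs

# 'ExampleCount' is assumed to be registered with TFMA, so metric_instance
# resolves it through the TFMA registry rather than the Keras path.
metric_config = config_pb2.MetricConfig(class_name='ExampleCount')
instance = metric_specs.metric_instance(metric_config)
print(type(instance).__name__)  # ExampleCount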
Example #2
def has_attributions_metrics(
        metrics_specs: Iterable[config_pb2.MetricsSpec]) -> bool:
    """Returns true if any of the metrics_specs have attributions metrics."""
    tfma_metric_classes = metric_types.registered_metrics()
    for metrics_spec in metrics_specs:
        for metric_config in metrics_spec.metrics:
            instance = metric_specs.metric_instance(metric_config,
                                                    tfma_metric_classes)
            if isinstance(instance, AttributionsMetric):
                return True
    return False
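A hedged sketch of calling has_attributions_metrics on a spec that mixes an attributions metric with a regular one; the module path tensorflow_model_analysis.metrics.attributions and the 'TotalAttributions' class name are assumptions.

# Sketch only: module path and metric class names are assumptions.
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.metrics import attributions

spec = config_pb2.MetricsSpec(metrics=[
    config_pb2.MetricConfig(class_name='TotalAttributions'),
    config_pb2.MetricConfig(class_name='ExampleCount'),
])
# Expected True: TotalAttributions is an AttributionsMetric subclass,
# ExampleCount is not.
print(attributions.has_attributions_metrics([spec]))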
Example #3
def _keys_and_metrics_from_specs(
    metrics_specs: Iterable[config.MetricsSpec]
) -> Iterator[Tuple[metric_types.MetricKey, config.MetricConfig,
                    metric_types.Metric]]:
  """Yields key, config, instance tuples for each non-diff metric in specs."""
  tfma_metric_classes = metric_types.registered_metrics()
  for spec in metrics_specs:
    for aggregation_type, sub_keys in _create_sub_keys(spec).items():
      for metric_config in spec.metrics:
        instance = metric_instance(metric_config, tfma_metric_classes)
        for key in _keys_for_metric(instance.name, spec, aggregation_type,
                                    sub_keys):
          yield key, metric_config, instance
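_keys_and_metrics_from_specs is a generator, so callers simply iterate over the (key, config, instance) tuples it yields. The sketch below follows the single-argument signature shown in Example #3; this is a private helper, and both the signature and the import paths are assumptions that may differ in other TFMA versions.

# Sketch only: signature and import paths follow Example #3 and are assumptions.
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.metrics import metric_specs

specs = [
    config_pb2.MetricsSpec(
        metrics=[config_pb2.MetricConfig(class_name='ExampleCount')],
        model_names=['candidate'])
]
for key, metric_config, instance in metric_specs._keys_and_metrics_from_specs(specs):
    print(key, metric_config.class_name, type(instance).__name__)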
Example #4
def _keys_and_metrics_from_specs(
    metrics_specs: Iterable[config.MetricsSpec]
) -> Iterator[Tuple[metric_types.MetricKey, config.MetricConfig,
                    metric_types.Metric]]:
    """Yields key, config, instance tuples for each non-diff metric in specs."""
    tfma_metric_classes = metric_types.registered_metrics()
    for spec in metrics_specs:
        sub_keys = _create_sub_keys(spec) or [None]
        if spec.aggregate.macro_average or spec.aggregate.weighted_macro_average:
            sub_keys.append(None)

        for metric_config in spec.metrics:
            if metric_config.class_name in tfma_metric_classes:
                instance = _deserialize_tfma_metric(metric_config,
                                                    tfma_metric_classes)
            elif not metric_config.module:
                instance = _deserialize_tf_metric(metric_config, {})
            else:
                cls = getattr(importlib.import_module(metric_config.module),
                              metric_config.class_name)
                if issubclass(cls, tf.keras.metrics.Metric):
                    instance = _deserialize_tf_metric(
                        metric_config, {metric_config.class_name: cls})
                elif issubclass(cls, tf.keras.losses.Loss):
                    instance = _deserialize_tf_loss(
                        metric_config, {metric_config.class_name: cls})
                elif issubclass(cls, metric_types.Metric):
                    instance = _deserialize_tfma_metric(
                        metric_config, {metric_config.class_name: cls})
                else:
                    raise NotImplementedError(
                        'unknown metric type {}: metric={}'.format(
                            cls, metric_config))

            if (hasattr(instance, 'is_model_independent')
                    and instance.is_model_independent()):
                key = metric_types.MetricKey(name=instance.name)
                yield key, metric_config, instance
            else:
                for key in _keys_for_metric(instance.name, spec, sub_keys):
                    yield key, metric_config, instance
Example #5
def to_computations(
    metrics_specs: List[config.MetricsSpec],
    eval_config: Optional[config.EvalConfig] = None,
    schema: Optional[schema_pb2.Schema] = None
) -> metric_types.MetricComputations:
  """Returns computations associated with given metrics specs."""
  computations = []

  #
  # Split into TF metrics and TFMA metrics
  #

  # Dict[Text, Type[tf.keras.metrics.Metric]]
  tf_metric_classes = {}  # class_name -> class
  # Dict[Text, Type[tf.keras.losses.Loss]]
  tf_loss_classes = {}  # class_name -> class
  # List[metric_types.MetricsSpec]
  tf_metrics_specs = []
  # Dict[Text, Type[metric_types.Metric]]
  tfma_metric_classes = metric_types.registered_metrics()  # class_name -> class
  # List[metric_types.MetricsSpec]
  tfma_metrics_specs = []
  #
  # Note: Lists are used instead of Dicts for the following items because
  # protos are not hashable.
  #
  # List[List[_TFOrTFMAMetric]] (offsets align with metrics_specs).
  per_spec_metric_instances = []
  # List[List[_TFMetricOrLoss]] (offsets align with tf_metrics_specs).
  per_tf_spec_metric_instances = []
  # List[List[metric_types.Metric]] (offsets align with tfma_metrics_specs).
  per_tfma_spec_metric_instances = []
  for spec in metrics_specs:
    tf_spec = config.MetricsSpec()
    tf_spec.CopyFrom(spec)
    del tf_spec.metrics[:]
    tfma_spec = config.MetricsSpec()
    tfma_spec.CopyFrom(spec)
    del tfma_spec.metrics[:]
    for metric in spec.metrics:
      if metric.class_name in tfma_metric_classes:
        tfma_spec.metrics.append(metric)
      elif not metric.module:
        tf_spec.metrics.append(metric)
      else:
        cls = getattr(importlib.import_module(metric.module), metric.class_name)
        if issubclass(cls, tf.keras.metrics.Metric):
          tf_metric_classes[metric.class_name] = cls
          tf_spec.metrics.append(metric)
        elif issubclass(cls, tf.keras.losses.Loss):
          tf_loss_classes[metric.class_name] = cls
          tf_spec.metrics.append(metric)
        else:
          tfma_metric_classes[metric.class_name] = cls
          tfma_spec.metrics.append(metric)

    metric_instances = []
    if tf_spec.metrics:
      tf_metrics_specs.append(tf_spec)
      tf_metric_instances = []
      for m in tf_spec.metrics:
        # To distinguish losses from metrics, losses are required to set the
        # module name.
        if m.module == _TF_LOSSES_MODULE:
          tf_metric_instances.append(_deserialize_tf_loss(m, tf_loss_classes))
        else:
          tf_metric_instances.append(
              _deserialize_tf_metric(m, tf_metric_classes))
      per_tf_spec_metric_instances.append(tf_metric_instances)
      metric_instances.extend(tf_metric_instances)
    if tfma_spec.metrics:
      tfma_metrics_specs.append(tfma_spec)
      tfma_metric_instances = [
          _deserialize_tfma_metric(m, tfma_metric_classes)
          for m in tfma_spec.metrics
      ]
      per_tfma_spec_metric_instances.append(tfma_metric_instances)
      metric_instances.extend(tfma_metric_instances)
    per_spec_metric_instances.append(metric_instances)

  # Process TF specs
  computations.extend(
      _process_tf_metrics_specs(tf_metrics_specs, per_tf_spec_metric_instances,
                                eval_config))

  # Process TFMA specs
  computations.extend(
      _process_tfma_metrics_specs(tfma_metrics_specs,
                                  per_tfma_spec_metric_instances, eval_config,
                                  schema))

  # Process aggregation-based metrics (output aggregation and macro averaging).
  # Note that the processing of TF and TFMA specs was set up to create the
  # binarized metrics that macro averaging depends on.
  for i, spec in enumerate(metrics_specs):
    for aggregation_type, sub_keys in _create_sub_keys(spec).items():
      output_names = spec.output_names or ['']
      output_weights = dict(spec.output_weights)
      if not set(output_weights.keys()).issubset(output_names):
        raise ValueError(
            'one or more output_names used in output_weights does not exist: '
            'output_names={}, output_weights={}'.format(output_names,
                                                        output_weights))
      for model_name in spec.model_names or ['']:
        for sub_key in sub_keys:
          for metric in per_spec_metric_instances[i]:
            if (aggregation_type and (aggregation_type.macro_average or
                                      aggregation_type.weighted_macro_average)):
              class_weights = _class_weights(spec) or {}
              for output_name in output_names:
                macro_average_sub_keys = _macro_average_sub_keys(
                    sub_key, class_weights)
                if aggregation_type.macro_average:
                  computations.extend(
                      aggregation.macro_average(
                          metric.get_config()['name'],
                          sub_keys=macro_average_sub_keys,
                          eval_config=eval_config,
                          model_name=model_name,
                          output_name=output_name,
                          sub_key=sub_key,
                          class_weights=class_weights))
                elif aggregation_type.weighted_macro_average:
                  computations.extend(
                      aggregation.weighted_macro_average(
                          metric.get_config()['name'],
                          sub_keys=macro_average_sub_keys,
                          eval_config=eval_config,
                          model_name=model_name,
                          output_name=output_name,
                          sub_key=sub_key,
                          class_weights=class_weights))
            if output_weights:
              computations.extend(
                  aggregation.output_average(
                      metric.get_config()['name'],
                      output_weights=output_weights,
                      eval_config=eval_config,
                      model_name=model_name,
                      sub_key=sub_key))

  return computations
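A hedged sketch of driving to_computations with a single MetricsSpec; both metric class names are assumed to be TFMA-registered, and the import paths follow the layout implied by the examples rather than any specific TFMA release.

# Sketch only: import paths and metric class names are assumptions.
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.metrics import metric_specs

specs = [
    config_pb2.MetricsSpec(metrics=[
        config_pb2.MetricConfig(class_name='ExampleCount'),
        config_pb2.MetricConfig(class_name='WeightedExampleCount'),
    ])
]
computations = metric_specs.to_computations(
    specs, eval_config=config_pb2.EvalConfig())
print(len(computations))  # one or more MetricComputation entries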
Example #6
def to_computations(
    metrics_specs: List[config.MetricsSpec],
    eval_config: Optional[config.EvalConfig] = None,
    schema: Optional[schema_pb2.Schema] = None
) -> metric_types.MetricComputations:
    """Returns computations associated with given metrics specs."""
    computations = []

    #
    # Split into TF metrics and TFMA metrics
    #

    # Dict[Text, Type[tf.keras.metrics.Metric]]
    tf_metric_classes = {}  # class_name -> class
    # Dict[Text, Type[tf.keras.losses.Loss]]
    tf_loss_classes = {}  # class_name -> class
    # List[metric_types.MetricsSpec]
    tf_metrics_specs = []
    # Dict[Text, Type[metric_types.Metric]]
    tfma_metric_classes = metric_types.registered_metrics()  # class_name -> class
    # List[metric_types.MetricsSpec]
    tfma_metrics_specs = []
    #
    # Note: Lists are used instead of Dicts for the following items because
    # protos are not hashable.
    #
    # List[List[_TFOrTFMAMetric]] (offsets align with metrics_specs).
    per_spec_metric_instances = []
    # List[List[_TFMetricOrLoss]] (offsets align with tf_metrics_specs).
    per_tf_spec_metric_instances = []
    # List[List[metric_types.Metric]] (offsets align with tfma_metrics_specs).
    per_tfma_spec_metric_instances = []
    for spec in metrics_specs:
        tf_spec = config.MetricsSpec()
        tf_spec.CopyFrom(spec)
        del tf_spec.metrics[:]
        tfma_spec = config.MetricsSpec()
        tfma_spec.CopyFrom(spec)
        del tfma_spec.metrics[:]
        for metric in spec.metrics:
            if metric.class_name in tfma_metric_classes:
                tfma_spec.metrics.append(metric)
            elif not metric.module:
                tf_spec.metrics.append(metric)
            else:
                cls = getattr(importlib.import_module(metric.module),
                              metric.class_name)
                if issubclass(cls, tf.keras.metrics.Metric):
                    tf_metric_classes[metric.class_name] = cls
                    tf_spec.metrics.append(metric)
                elif issubclass(cls, tf.keras.losses.Loss):
                    tf_loss_classes[metric.class_name] = cls
                    tf_spec.metrics.append(metric)
                else:
                    tfma_metric_classes[metric.class_name] = cls
                    tfma_spec.metrics.append(metric)

        metric_instances = []
        if tf_spec.metrics:
            tf_metrics_specs.append(tf_spec)
            tf_metric_instances = []
            for m in tf_spec.metrics:
                # To distinguish losses from metrics, losses are required to set the
                # module name.
                if m.module == _TF_LOSSES_MODULE:
                    tf_metric_instances.append(
                        _deserialize_tf_loss(m, tf_loss_classes))
                else:
                    tf_metric_instances.append(
                        _deserialize_tf_metric(m, tf_metric_classes))
            per_tf_spec_metric_instances.append(tf_metric_instances)
            metric_instances.extend(tf_metric_instances)
        if tfma_spec.metrics:
            tfma_metrics_specs.append(tfma_spec)
            tfma_metric_instances = [
                _deserialize_tfma_metric(m, tfma_metric_classes)
                for m in tfma_spec.metrics
            ]
            per_tfma_spec_metric_instances.append(tfma_metric_instances)
            metric_instances.extend(tfma_metric_instances)
        per_spec_metric_instances.append(metric_instances)

    #
    # Group TF metrics by the subkeys, models and outputs. This is done in reverse
    # because model and subkey processing is done outside of TF and so each unique
    # sub key combination needs to be run through a separate model instance. Note
    # that output_names are handled by the tf_metric_computation since all the
    # outputs are batch calculated in a single model evaluation call.
    #

    # Dict[metric_types.SubKey, Dict[Text, List[int]]]
    tf_spec_indices_by_subkey = {}  # SubKey -> model_name -> [index(MetricSpec)]
    for i, spec in enumerate(tf_metrics_specs):
        sub_keys = _create_sub_keys(spec)
        if not sub_keys:
            sub_keys = [None]
        for sub_key in sub_keys:
            if sub_key not in tf_spec_indices_by_subkey:
                tf_spec_indices_by_subkey[sub_key] = {}
            # Dict[Text, List[int]]
            tf_spec_indices_by_model = tf_spec_indices_by_subkey[sub_key]  # name -> [index]
            model_names = spec.model_names
            if not model_names:
                model_names = [''
                               ]  # '' is name used when only one model is used
            for model_name in model_names:
                if model_name not in tf_spec_indices_by_model:
                    tf_spec_indices_by_model[model_name] = []
                tf_spec_indices_by_model[model_name].append(i)
    for sub_key, spec_indices_by_model in tf_spec_indices_by_subkey.items():
        for model_name, indices in spec_indices_by_model.items():
            # Class weights are a dict that is not hashable, so we store index to spec
            # containing class weights.
            metrics_by_class_weights_by_output = collections.defaultdict(dict)
            for i in indices:
                class_weights_i = None
                if tf_metrics_specs[i].HasField('aggregate'):
                    class_weights_i = i
                metrics_by_output = metrics_by_class_weights_by_output[
                    class_weights_i]
                output_names = ['']  # '' is name used when only one output
                if tf_metrics_specs[i].output_names:
                    output_names = tf_metrics_specs[i].output_names
                for output_name in output_names:
                    if output_name not in metrics_by_output:
                        metrics_by_output[output_name] = []
                    metrics_by_output[output_name].extend(
                        per_tf_spec_metric_instances[i])
            for i, metrics_by_output in metrics_by_class_weights_by_output.items():
                class_weights = None
                if i is not None:
                    class_weights = dict(
                        tf_metrics_specs[i].aggregate.class_weights)
                computations.extend(
                    tf_metric_wrapper.tf_metric_computations(
                        metrics_by_output,
                        eval_config=eval_config,
                        model_name=model_name,
                        sub_key=sub_key,
                        class_weights=class_weights))

    #
    # Group TFMA metric specs by the metric classes
    #

    # Dict[bytes, List[config.MetricSpec]]
    tfma_specs_by_metric_config = {}  # hash(MetricConfig) -> [MetricSpec]
    # Dict[bytes, metric_types.Metric]
    hashed_metrics = {}  # hash(MetricConfig) -> Metric
    for i, spec in enumerate(tfma_metrics_specs):
        for metric_config, metric in zip(spec.metrics,
                                         per_tfma_spec_metric_instances[i]):
            # Note that hashing by SerializeToString() is only safe if used within the
            # same process.
            config_hash = metric_config.SerializeToString()
            if config_hash not in tfma_specs_by_metric_config:
                hashed_metrics[config_hash] = metric
                tfma_specs_by_metric_config[config_hash] = []
            tfma_specs_by_metric_config[config_hash].append(spec)
    for config_hash, specs in tfma_specs_by_metric_config.items():
        metric = hashed_metrics[config_hash]
        for spec in specs:
            sub_keys = _create_sub_keys(spec)
            class_weights = None
            if spec.HasField('aggregate'):
                class_weights = dict(spec.aggregate.class_weights)
            computations.extend(
                metric.computations(
                    eval_config=eval_config,
                    schema=schema,
                    model_names=spec.model_names if spec.model_names else [''],
                    output_names=spec.output_names
                    if spec.output_names else [''],
                    sub_keys=sub_keys,
                    class_weights=class_weights,
                    query_key=spec.query_key))

    #
    # Create macro averaging metrics
    #

    for i, spec in enumerate(metrics_specs):
        if spec.aggregate.macro_average or spec.aggregate.weighted_macro_average:
            sub_keys = _create_sub_keys(spec)
            if sub_keys is None:
                raise ValueError(
                    'binarize settings are required when aggregate.macro_average or '
                    'aggregate.weighted_macro_average is used: spec={}'.format(
                        spec))
            for model_name in spec.model_names or ['']:
                for output_name in spec.output_names or ['']:
                    for metric in per_spec_metric_instances[i]:
                        if spec.aggregate.macro_average:
                            computations.extend(
                                aggregation.macro_average(
                                    metric.get_config()['name'],
                                    eval_config=eval_config,
                                    model_name=model_name,
                                    output_name=output_name,
                                    sub_keys=sub_keys,
                                    class_weights=dict(
                                        spec.aggregate.class_weights)))
                        elif spec.aggregate.weighted_macro_average:
                            computations.extend(
                                aggregation.weighted_macro_average(
                                    metric.get_config()['name'],
                                    eval_config=eval_config,
                                    model_name=model_name,
                                    output_name=output_name,
                                    sub_keys=sub_keys,
                                    class_weights=dict(
                                        spec.aggregate.class_weights)))

    return computations
Example #7
def to_computations(
    metrics_specs: List[config.MetricsSpec],
    eval_config: Optional[config.EvalConfig] = None,
    model_loaders: Optional[Dict[Text, types.ModelLoader]] = None
) -> metric_types.MetricComputations:
    """Returns computations associated with given metrics specs."""
    computations = []

    #
    # Split into TF metrics and TFMA metrics
    #

    # Dict[Text, Type[tf.keras.metrics.Metric]]
    tf_metric_classes = {}  # class_name -> class
    # List[metric_types.MetricsSpec]
    tf_metrics_specs = []
    # Dict[Text, Type[metric_types.Metric]]
    tfma_metric_classes = metric_types.registered_metrics()  # class_name -> class
    # List[metric_types.MetricsSpec]
    tfma_metrics_specs = []
    for spec in metrics_specs:
        tf_spec = config.MetricsSpec()
        tf_spec.CopyFrom(spec)
        del tf_spec.metrics[:]
        tfma_spec = config.MetricsSpec()
        tfma_spec.CopyFrom(spec)
        del tfma_spec.metrics[:]
        for metric in spec.metrics:
            if metric.class_name in tfma_metric_classes:
                tfma_spec.metrics.append(metric)
            elif not metric.module:
                tf_spec.metrics.append(metric)
            else:
                cls = getattr(importlib.import_module(metric.module),
                              metric.class_name)
                if issubclass(cls, tf.keras.metrics.Metric):
                    tf_metric_classes[metric.class_name] = cls
                    tf_spec.metrics.append(metric)
                else:
                    tfma_metric_classes[metric.class_name] = cls
                    tfma_spec.metrics.append(metric)
        if tf_spec.metrics:
            tf_metrics_specs.append(tf_spec)
        if tfma_spec.metrics:
            tfma_metrics_specs.append(tfma_spec)

    #
    # Group TF metrics by the subkeys, models and outputs. This is done in reverse
    # because model and subkey processing is done outside of TF and so each unique
    # sub key combination needs to be run through a separate model instance. Note
    # that output_names are handled by the tf_metric_computation since all the
    # outputs are batch calculated in a single model evaluation call.
    #

    # Dict[metric_types.SubKey, Dict[Text, List[config.MetricSpec]]]
    tf_specs_by_subkey = {}  # SubKey -> model_name -> [MetricSpec]
    for spec in tf_metrics_specs:
        sub_keys = _create_sub_keys(spec)
        if not sub_keys:
            sub_keys = [None]
        for sub_key in sub_keys:
            if sub_key not in tf_specs_by_subkey:
                tf_specs_by_subkey[sub_key] = {}
            # Dict[Text, List[config.MetricSpec]]
            tf_specs_by_model = tf_specs_by_subkey[sub_key]  # name -> [MetricsSpec]
            model_names = spec.model_names
            if not model_names:
                model_names = [''
                               ]  # '' is name used when only one model is used
            for model_name in model_names:
                if model_name not in tf_specs_by_model:
                    tf_specs_by_model[model_name] = []
                tf_specs_by_model[model_name].append(spec)
    for sub_key, specs_by_model in tf_specs_by_subkey.items():
        for model_name, specs in specs_by_model.items():
            metrics_by_output = {}
            for spec in specs:
                metrics = [
                    _deserialize_tf_metric(m, tf_metric_classes)
                    for m in spec.metrics
                ]
                if spec.output_names:
                    for output_name in spec.output_names:
                        if output_name not in metrics_by_output:
                            metrics_by_output[output_name] = []
                        metrics_by_output[output_name].extend(metrics)
                else:
                    if '' not in metrics_by_output:
                        metrics_by_output[''] = []  # '' is name used when only one output
                    metrics_by_output[''].extend(metrics)
            model_loader = None
            if model_loaders and model_name in model_loaders:
                model_loader = model_loaders[model_name]
            computations.extend(
                tf_metric_wrapper.tf_metric_computations(
                    metrics_by_output,
                    eval_config=eval_config,
                    model_name=model_name,
                    sub_key=sub_key,
                    model_loader=model_loader))

    #
    # Group TFMA metric specs by the metric classes
    #

    # Dict[bytes, List[config.MetricSpec]]
    tfma_specs_by_metric_config = {}  # hash(MetricConfig) -> [MetricSpec]
    # Dict[bytes, config.MetricConfig]
    hashed_metric_configs = {}  # hash(MetricConfig) -> MetricConfig
    for spec in tfma_metrics_specs:
        for metric_config in spec.metrics:
            # Note that hashing by SerializeToString() is only safe if used within the
            # same process.
            config_hash = metric_config.SerializeToString()
            if config_hash not in tfma_specs_by_metric_config:
                hashed_metric_configs[config_hash] = metric_config
                tfma_specs_by_metric_config[config_hash] = []
            tfma_specs_by_metric_config[config_hash].append(spec)
    for config_hash, specs in tfma_specs_by_metric_config.items():
        metric = _deserialize_tfma_metric(hashed_metric_configs[config_hash],
                                          tfma_metric_classes)
        for spec in specs:
            sub_keys = _create_sub_keys(spec)
            computations.extend(
                metric.computations(
                    eval_config=eval_config,
                    model_names=spec.model_names if spec.model_names else [''],
                    output_names=spec.output_names
                    if spec.output_names else [''],
                    sub_keys=sub_keys,
                    query_key=spec.query_key))
    return computations
Example #8
def metric_thresholds_from_metrics_specs(
    metrics_specs: List[config.MetricsSpec]
) -> Dict[metric_types.MetricKey, Union[config.GenericChangeThreshold,
                                        config.GenericValueThreshold]]:
    """Returns thresholds associated with given metrics specs."""
    result = {}

    tfma_metric_classes = metric_types.registered_metrics()

    for spec in metrics_specs:
        sub_keys = _create_sub_keys(spec) or [None]
        if spec.aggregate.macro_average or spec.aggregate.weighted_macro_average:
            sub_keys.append(None)

        # Add thresholds for metrics computed in-graph.
        for metric_name, threshold in spec.thresholds.items():
            for model_name in spec.model_names or ['']:
                for output_name in spec.output_names or ['']:
                    for sub_key in sub_keys:
                        if threshold.HasField('value_threshold'):
                            key = metric_types.MetricKey(
                                name=metric_name,
                                model_name=model_name,
                                output_name=output_name,
                                sub_key=sub_key,
                                is_diff=False)
                            result[key] = threshold.value_threshold
                        if threshold.HasField('change_threshold'):
                            key = metric_types.MetricKey(
                                name=metric_name,
                                model_name=model_name,
                                output_name=output_name,
                                sub_key=sub_key,
                                is_diff=True)
                            result[key] = threshold.change_threshold

        # Thresholds in MetricConfig override thresholds in MetricsSpec.
        for metric in spec.metrics:
            if not metric.HasField('threshold'):
                continue

            if metric.class_name in tfma_metric_classes:
                instance = _deserialize_tfma_metric(metric,
                                                    tfma_metric_classes)
            elif not metric.module:
                instance = _deserialize_tf_metric(metric, {})
            else:
                cls = getattr(importlib.import_module(metric.module),
                              metric.class_name)
                if issubclass(cls, tf.keras.metrics.Metric):
                    instance = _deserialize_tf_metric(metric,
                                                      {metric.class_name: cls})
                elif issubclass(cls, tf.keras.losses.Loss):
                    instance = _deserialize_tf_loss(metric,
                                                    {metric.class_name: cls})
                elif issubclass(cls, metric_types.Metric):
                    instance = _deserialize_tfma_metric(
                        metric, {metric.class_name: cls})
                else:
                    raise NotImplementedError(
                        'unknown metric type {}: metric={}'.format(
                            cls, metric))

            if (hasattr(instance, 'is_model_independent')
                    and instance.is_model_independent()):
                if metric.threshold.HasField('value_threshold'):
                    key = metric_types.MetricKey(name=instance.name,
                                                 is_diff=False)
                    result[key] = metric.threshold.value_threshold
            else:
                for model_name in spec.model_names or ['']:
                    for output_name in spec.output_names or ['']:
                        for sub_key in sub_keys:
                            if metric.threshold.HasField('value_threshold'):
                                key = metric_types.MetricKey(
                                    name=instance.name,
                                    model_name=model_name,
                                    output_name=output_name,
                                    sub_key=sub_key,
                                    is_diff=False)
                                result[key] = metric.threshold.value_threshold
                            if metric.threshold.HasField('change_threshold'):
                                key = metric_types.MetricKey(
                                    name=instance.name,
                                    model_name=model_name,
                                    output_name=output_name,
                                    sub_key=sub_key,
                                    is_diff=True)
                                result[key] = metric.threshold.change_threshold

    return result
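A hedged sketch of extracting thresholds from a MetricsSpec via metric_thresholds_from_metrics_specs. The threshold proto fields follow the HasField checks in Example #8; the import paths, the 'ExampleCount' class name, and the single-argument signature are assumptions that may not match every TFMA version.

# Sketch only: signature, import paths and metric class name are assumptions.
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.metrics import metric_specs

spec = config_pb2.MetricsSpec(metrics=[
    config_pb2.MetricConfig(
        class_name='ExampleCount',
        threshold=config_pb2.MetricThreshold(
            value_threshold=config_pb2.GenericValueThreshold(
                lower_bound={'value': 100}))),
])
thresholds = metric_specs.metric_thresholds_from_metrics_specs([spec])
for key, threshold in thresholds.items():
    print(key, type(threshold).__name__)  # e.g. GenericValueThreshold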