예제 #1
0
def _serialize_metrics(metrics: Tuple[slicer.SliceKeyType, Dict[Text, Any]],
                       post_export_metrics: List[types.AddMetricsCallbackType]
                      ) -> bytes:
  """Converts the given slice metrics into serialized proto MetricsForSlice.

  Args:
    metrics: The slice metrics.
    post_export_metrics: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized proto MetricsForSlice.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  result = metrics_for_slice_pb2.MetricsForSlice()
  slice_key, slice_metrics = metrics

  # Convert the slice key (needed for both the error and the success path).
  result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  if metric_keys.ERROR_METRIC in slice_metrics:
    # An upstream error was recorded for this slice: emit only the error
    # message rather than attempting to convert (possibly partial) metrics.
    # Note: the original code built a second MetricsForSlice here and shadowed
    # the `metrics` parameter; reusing `result` produces identical output.
    tf.logging.warning('Error for slice: %s with error message: %s ', slice_key,
                       slice_metrics[metric_keys.ERROR_METRIC])
    result.metrics[metric_keys.ERROR_METRIC].debug_message = slice_metrics[
        metric_keys.ERROR_METRIC]
    return result.SerializeToString()

  # Convert the slice metrics.
  convert_slice_metrics(slice_metrics, post_export_metrics, result)
  return result.SerializeToString()
예제 #2
0
        def check_result(got):  # pylint: disable=invalid-name
            """Beam assert callback validating the single overall-slice result.

            Checks the in-memory metric values against the enclosing test's
            expected_values_dict, then checks that populate_stats_and_pop
            converts the AUPRC entry into the expected bounded_value proto.
            """
            try:
                self.assertEqual(1, len(got), 'got: %s' % got)
                (slice_key, value) = got[0]
                # The overall slice is keyed by the empty tuple.
                self.assertEqual((), slice_key)
                self.assertDictElementsAlmostEqual(value, expected_values_dict)

                # Check serialization too.
                # Note that we can't just make this a dict, since proto maps
                # allow uninitialized key access, i.e. they act like defaultdicts.
                output_metrics = metrics_for_slice_pb2.MetricsForSlice(
                ).metrics
                auc_metric.populate_stats_and_pop(value, output_metrics)
                self.assertProtoEquals(
                    """
            bounded_value {
              lower_bound {
                value: 0.6999999
              }
              upper_bound {
                value: 0.7777776
              }
              value {
                value: 0.7407472
              }
            }
            """, output_metrics[metric_keys.AUPRC])
            except AssertionError as err:
                # Surface the failure to the Beam test harness.
                raise util.BeamAssertException(err)
def revert_slice_keys_for_transformed_features(
        metrics: List[metrics_for_slice_pb2.MetricsForSlice],
        statistics: statistics_pb2.DatasetFeatureStatisticsList) -> List[metrics_for_slice_pb2.MetricsForSlice]:
    """Revert the slice keys for the transformed features.

  Args:
    metrics: List of slice metrics protos.
    statistics: Data statistics used to configure AutoSliceKeyExtractor.

  Returns:
    List of slice metrics protos where transformed features are mapped back to
    raw features in the slice keys.
  """
    result = []
    # Per-feature quantile boundaries derived from the dataset statistics.
    boundaries = auto_slice_key_extractor._get_quantile_boundaries(statistics)  # pylint: disable=protected-access
    for slice_metrics in metrics:
        # Mutate a copy so the caller's input protos stay untouched.
        transformed_metrics = metrics_for_slice_pb2.MetricsForSlice()
        transformed_metrics.CopyFrom(slice_metrics)
        for single_slice_key in transformed_metrics.slice_key.single_slice_keys:
            # Only columns carrying the transform prefix need reverting.
            if single_slice_key.column.startswith(
                    auto_slice_key_extractor.TRANSFORMED_FEATURE_PREFIX):
                # Strip the prefix to recover the raw feature name.
                raw_feature = single_slice_key.column[
                    len(auto_slice_key_extractor.TRANSFORMED_FEATURE_PREFIX):]
                single_slice_key.column = raw_feature
                # The slice value (whichever oneof kind is set) is presumably a
                # bucket index; map it back to its quantile boundary — TODO
                # confirm against _get_bucket_boundary.
                (start, end) = auto_slice_key_extractor._get_bucket_boundary(  # pylint: disable=protected-access
                    getattr(single_slice_key,
                            single_slice_key.WhichOneof('kind')),
                    boundaries[raw_feature])
                single_slice_key.bytes_value = _format_boundary(start, end)
        result.append(transformed_metrics)
    return result
예제 #4
0
def revert_slice_keys_for_transformed_features(
    metrics: List[metrics_for_slice_pb2.MetricsForSlice],
    statistics: statistics_pb2.DatasetFeatureStatisticsList) -> List[metrics_for_slice_pb2.MetricsForSlice]:
  """Revert the slice keys for the transformed features.

  Args:
    metrics: List of slice metrics protos.
    statistics: Data statistics used to configure AutoSliceKeyExtractor.

  Returns:
    List of slice metrics protos where transformed features are mapped back to
    raw features in the slice keys.
  """
  result = []
  # Per-feature quantile boundaries derived from the dataset statistics.
  boundaries = auto_slice_key_extractor.get_quantile_boundaries(statistics)
  for slice_metrics in metrics:
    # Mutate a copy so the caller's input protos stay untouched.
    transformed_metrics = metrics_for_slice_pb2.MetricsForSlice()
    transformed_metrics.CopyFrom(slice_metrics)
    for single_slice_key in transformed_metrics.slice_key.single_slice_keys:
      # get_raw_feature presumably maps a (possibly transformed) column/value
      # pair back to the raw feature name and a printable value — TODO confirm
      # against its definition.
      raw_feature_name, raw_feature_value = get_raw_feature(
          single_slice_key.column,
          getattr(single_slice_key, single_slice_key.WhichOneof('kind')),
          boundaries)
      single_slice_key.column = raw_feature_name
      single_slice_key.bytes_value = raw_feature_value
    result.append(transformed_metrics)
  return result
예제 #5
0
def _serialize_metrics(
    metrics,
    post_export_metrics):
  """Serializes slice metrics into a MetricsForSlice proto string.

  Args:
    metrics: The slice metrics as a (slice key, metrics dict) pair.
    post_export_metrics: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized proto MetricsForSlice.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  slice_key, slice_metrics = metrics

  # Build the proto: slice key first, then the converted metric values.
  serialized = metrics_for_slice_pb2.MetricsForSlice()
  serialized.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  _convert_slice_metrics(slice_metrics, post_export_metrics, serialized)
  return serialized.SerializeToString()
예제 #6
0
    def testConvertMetricsProto(self):
        """Converts a double metric with a confidence interval to a dict.

        A metric carrying a confidence_interval is expected to convert to a
        'boundedValue' entry with lower/upper bounds.
        """
        metrics_for_slice = text_format.Parse(
            """
      slice_key {}
      metric_keys_and_values {
        key {
          name: "metric_name"
        }
        value: {
          double_value { value: 1.0 }
        }
        confidence_interval {
          lower_bound: { double_value: { value: 0.5 } }
          upper_bound: { double_value: { value: 1.5 } }
        }
      }""", metrics_for_slice_pb2.MetricsForSlice())

        got = util.convert_metrics_proto_to_dict(metrics_for_slice)
        # Nested under two empty-string keys — presumably (model_name,
        # output_name); verify against convert_metrics_proto_to_dict.
        expected = ((), {
            '': {
                '': {
                    'metric_name': {
                        'boundedValue': {
                            'lowerBound': 0.5,
                            'upperBound': 1.5,
                            'value': 1.0
                        }
                    }
                }
            }
        })
        self.assertEqual(got, expected)
        # NOTE(review): check_result below is defined but never invoked inside
        # this method — presumably meant as a beam assert_that callback for a
        # fairness-AUC pipeline; confirm it is not dead code.
        def check_result(got):
            """Beam assert callback for fairness-AUC metrics over 3 slices."""
            try:
                self.assertEqual(3, len(got), 'got: %s' % got)
                for _, value in got:
                    expected_value = {
                        # Subgroup
                        'post_export_metrics/fairness/auc/subgroup_auc/fixed_int':
                        0.5,
                        'post_export_metrics/fairness/auc/subgroup_auc/fixed_int/lower_bound':
                        0.25,
                        'post_export_metrics/fairness/auc/subgroup_auc/fixed_int/upper_bound':
                        0.75,
                        # BNSP
                        'post_export_metrics/fairness/auc/bnsp_auc/fixed_int':
                        0.5,
                        'post_export_metrics/fairness/auc/bnsp_auc/fixed_int/lower_bound':
                        0.25,
                        'post_export_metrics/fairness/auc/bnsp_auc/fixed_int/upper_bound':
                        0.75,
                        # BPSN
                        'post_export_metrics/fairness/auc/bpsn_auc/fixed_int':
                        0.5,
                        'post_export_metrics/fairness/auc/bpsn_auc/fixed_int/lower_bound':
                        0.25,
                        'post_export_metrics/fairness/auc/bpsn_auc/fixed_int/upper_bound':
                        0.75,
                        'average_loss':
                        0.5,
                    }
                    self.assertDictElementsAlmostEqual(value, expected_value)

                # Check serialization too.
                output_metrics = metrics_for_slice_pb2.MetricsForSlice(
                ).metrics
                for slice_key, value in got:
                    fairness_auc.populate_stats_and_pop(
                        slice_key, value, output_metrics)
                # All three fairness AUC variants serialize to the same
                # bounded value in this fixture.
                for key in (
                        metric_keys.FAIRNESS_AUC + '/subgroup_auc/fixed_int',
                        metric_keys.FAIRNESS_AUC + '/bpsn_auc/fixed_int',
                        metric_keys.FAIRNESS_AUC + '/bnsp_auc/fixed_int',
                ):
                    self.assertProtoEquals(
                        """
              bounded_value {
                lower_bound {
                  value: 0.2500001
                }
                upper_bound {
                  value: 0.7499999
                }
                value {
                  value: 0.5
                }
                methodology: RIEMANN_SUM
              }
              """, output_metrics[key])
            except AssertionError as err:
                # Surface the failure to the Beam test harness.
                raise util.BeamAssertException(err)
 def testUncertaintyValuedMetrics(self):
     """ValueWithConfidenceInterval serializes to bounded_value protos.

     NaN values (value and bounds) must pass through unchanged.
     """
     slice_key = _make_slice_key()
     slice_metrics = {
         'one_dim':
         types.ValueWithConfidenceInterval(2.0, 1.0, 3.0),
         'nans':
         types.ValueWithConfidenceInterval(float('nan'), float('nan'),
                                           float('nan')),
     }
     expected_metrics_for_slice = text_format.Parse(
         """
     slice_key {}
     metrics {
       key: "one_dim"
       value {
         bounded_value {
           value {
             value: 2.0
           }
           lower_bound {
             value: 1.0
           }
           upper_bound {
             value: 3.0
           }
           methodology: POISSON_BOOTSTRAP
         }
       }
     }
     metrics {
       key: "nans"
       value {
         bounded_value {
           value {
             value: nan
           }
           lower_bound {
             value: nan
           }
           upper_bound {
             value: nan
           }
           methodology: POISSON_BOOTSTRAP
         }
       }
     }
     """, metrics_for_slice_pb2.MetricsForSlice())
     got = metrics_and_plots_evaluator._serialize_metrics(
         (slice_key, slice_metrics), [])
     # Round-trip through bytes before comparing.
     self.assertProtoEquals(
         expected_metrics_for_slice,
         metrics_for_slice_pb2.MetricsForSlice.FromString(got))
 def testTensorValuedMetrics(self):
     """Numpy-array metrics serialize to array_value protos.

     Covers float32, bytes (strings) and int64 arrays of rank 1-3; the
     array data is flattened row-major with the shape recorded separately.
     """
     slice_key = _make_slice_key()
     slice_metrics = {
         'one_dim':
         np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
         'two_dims':
         np.array([['two', 'dims', 'test'], ['TWO', 'DIMS', 'TEST']]),
         'three_dims':
         np.array([[[100, 200, 300]], [[500, 600, 700]]], dtype=np.int64),
     }
     expected_metrics_for_slice = text_format.Parse(
         """
     slice_key {}
     metrics {
       key: "one_dim"
       value {
         array_value {
           data_type: FLOAT32
           shape: 4
           float32_values: [1.0, 2.0, 3.0, 4.0]
         }
       }
     }
     metrics {
       key: "two_dims"
       value {
         array_value {
           data_type: BYTES
           shape: [2, 3]
           bytes_values: ["two", "dims", "test", "TWO", "DIMS", "TEST"]
         }
       }
     }
     metrics {
       key: "three_dims"
       value {
         array_value {
           data_type: INT64
           shape: [2, 1, 3]
           int64_values: [100, 200, 300, 500, 600, 700]
         }
       }
     }
     """, metrics_for_slice_pb2.MetricsForSlice())
     got = metrics_and_plots_evaluator._serialize_metrics(
         (slice_key, slice_metrics), [])
     # Round-trip through bytes before comparing.
     self.assertProtoEquals(
         expected_metrics_for_slice,
         metrics_for_slice_pb2.MetricsForSlice.FromString(got))
예제 #10
0
  def testConvertSliceMetricsToProtoEmptyMetrics(self):
    """An error-metric-only slice converts to a debug-message-only proto."""
    slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
    slice_metrics = {metric_keys.ERROR_METRIC: 'error_message'}

    callbacks = [post_export_metrics.auc(), post_export_metrics.auc(curve='PR')]
    actual_metrics = (
        metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
            (slice_key, slice_metrics), callbacks))

    # Only the slice key and the error debug message should be populated.
    expected = metrics_for_slice_pb2.MetricsForSlice()
    expected.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
    expected.metrics[metric_keys.ERROR_METRIC].debug_message = 'error_message'
    self.assertProtoEquals(expected, actual_metrics)
  def testSerializeMetrics_emptyMetrics(self):
    """An error-metric-only slice serializes to a debug-message-only proto."""
    slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
    slice_metrics = {metric_keys.ERROR_METRIC: 'error_message'}

    callbacks = [post_export_metrics.auc(), post_export_metrics.auc(curve='PR')]
    serialized = metrics_and_plots_serialization._serialize_metrics(
        (slice_key, slice_metrics), callbacks)

    # Only the slice key and the error debug message should be populated.
    expected = metrics_for_slice_pb2.MetricsForSlice()
    expected.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
    expected.metrics[metric_keys.ERROR_METRIC].debug_message = 'error_message'
    self.assertProtoEquals(
        expected,
        metrics_for_slice_pb2.MetricsForSlice.FromString(serialized))
예제 #12
0
  def testConvertSliceMetricsToProtoStringMetrics(self):
    """Bytes-valued metrics convert to bytes_value entries as-is."""
    slice_key = _make_slice_key()
    slice_metrics = {
        'valid_ascii': b'test string',
        'valid_unicode': b'\xF0\x9F\x90\x84',  # U+1F404, Cow
        'invalid_unicode': b'\xE2\x28\xA1',
    }
    expected_metrics_for_slice = metrics_for_slice_pb2.MetricsForSlice()
    expected_metrics_for_slice.slice_key.SetInParent()
    # Each metric value should land verbatim in the proto's bytes_value.
    for name, raw_bytes in slice_metrics.items():
      expected_metrics_for_slice.metrics[name].bytes_value = raw_bytes

    got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
        (slice_key, slice_metrics), [])
    self.assertProtoEquals(expected_metrics_for_slice, got)
예제 #13
0
  def testStringMetrics(self):
    """Bytes-valued metrics survive a serialize/deserialize round trip."""
    slice_key = _make_slice_key()
    slice_metrics = {
        'valid_ascii': b'test string',
        'valid_unicode': b'\xF0\x9F\x90\x84',  # U+1F404, Cow
        'invalid_unicode': b'\xE2\x28\xA1',
    }
    expected_metrics_for_slice = metrics_for_slice_pb2.MetricsForSlice()
    expected_metrics_for_slice.slice_key.SetInParent()
    # Each metric value should land verbatim in the proto's bytes_value.
    for name, raw_bytes in slice_metrics.items():
      expected_metrics_for_slice.metrics[name].bytes_value = raw_bytes

    got = serialization._serialize_metrics((slice_key, slice_metrics), [])
    self.assertProtoEquals(
        expected_metrics_for_slice,
        metrics_for_slice_pb2.MetricsForSlice.FromString(got))
예제 #14
0
    def testSerializeMetrics(self):
        """A MetricKey-keyed double metric serializes with slice key intact.

        Passes None for the metric callbacks, so no post-export conversion
        is exercised here.
        """
        slice_key = _make_slice_key('age', 5, 'language', 'english', 'price',
                                    0.3)
        slice_metrics = {
            metric_types.MetricKey(name='accuracy', output_name='output_name'):
            0.8
        }
        expected_metrics_for_slice = text_format.Parse(
            """
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metric_keys_and_values {
          key {
            name: "accuracy"
            output_name: "output_name"
          }
          value {
            double_value {
              value: 0.8
            }
          }
        }""", metrics_for_slice_pb2.MetricsForSlice())

        got = metrics_and_plots_serialization._serialize_metrics(
            (slice_key, slice_metrics), None)
        # Round-trip through bytes before comparing.
        self.assertProtoEquals(
            expected_metrics_for_slice,
            metrics_for_slice_pb2.MetricsForSlice.FromString(got))
예제 #15
0
        def check_result(got):  # pylint: disable=invalid-name
            """Beam assert callback for precision/recall-at-K on one slice.

            The metric value is a table with columns [cutoff, precision,
            recall]; checks the raw table values and the converted
            value_at_cutoffs protos.
            """
            try:
                self.assertEqual(1, len(got), 'got: %s' % got)
                (slice_key, value) = got[0]
                # The overall slice is keyed by the empty tuple.
                self.assertEqual((), slice_key)
                self.assertIn(metric_keys.PRECISION_RECALL_AT_K, value)
                table = value[metric_keys.PRECISION_RECALL_AT_K]
                # Columns: 0 = cutoff, 1 = precision, 2 = recall.
                cutoffs = table[:, 0].tolist()
                precision = table[:, 1].tolist()
                recall = table[:, 2].tolist()

                self.assertEqual(cutoffs, [0, 1, 2, 3, 5])
                self.assertSequenceAlmostEqual(
                    precision,
                    [4.0 / 9.0, 2.0 / 3.0, 2.0 / 6.0, 4.0 / 9.0, 4.0 / 9.0])
                self.assertSequenceAlmostEqual(
                    recall,
                    [4.0 / 4.0, 2.0 / 4.0, 2.0 / 4.0, 4.0 / 4.0, 4.0 / 4.0])

                # Check serialization too.
                # Note that we can't just make this a dict, since proto maps
                # allow uninitialized key access, i.e. they act like defaultdicts.
                output_metrics = metrics_for_slice_pb2.MetricsForSlice(
                ).metrics
                precision_recall_metric.populate_stats_and_pop(
                    value, output_metrics)
                self.assertProtoEquals(
                    """
            value_at_cutoffs {
              values {
                cutoff: 0
                value: 0.44444444
              }
            }
            value_at_cutoffs {
              values {
                cutoff: 1
                value: 0.66666666
              }
            }
            value_at_cutoffs {
              values {
                cutoff: 2
                value: 0.33333333
              }
            }
            value_at_cutoffs {
              values {
                cutoff: 3
                value: 0.44444444
              }
            }
            value_at_cutoffs {
              values {
                cutoff: 5
                value: 0.44444444
              }
            }
            """, output_metrics[metric_keys.PRECISION_AT_K])
                self.assertProtoEquals(
                    """
            value_at_cutoffs {
              values {
                cutoff: 0
                value: 1.0
              }
            }
            value_at_cutoffs {
              values {
                cutoff: 1
                value: 0.5
              }
            }
            value_at_cutoffs {
              values {
                cutoff: 2
                value: 0.5
              }
            }
            value_at_cutoffs {
              values {
                cutoff: 3
                value: 1.0
              }
            }
            value_at_cutoffs {
              values {
                cutoff: 5
                value: 1.0
              }
            }
            """, output_metrics[metric_keys.RECALL_AT_K])
            except AssertionError as err:
                # Surface the failure to the Beam test harness.
                raise util.BeamAssertException(err)
    def testSerializeMetrics(self):
        """AUC/AUPRC metrics with bounds serialize to bounded_value protos.

        Uses string.Template because the $auc/$auprc metric key names are
        produced by _full_key at runtime.
        """
        slice_key = _make_slice_key('age', 5, 'language', 'english', 'price',
                                    0.3)
        slice_metrics = {
            'accuracy': 0.8,
            _full_key(metric_keys.AUPRC): 0.1,
            _full_key(metric_keys.lower_bound(metric_keys.AUPRC)): 0.05,
            _full_key(metric_keys.upper_bound(metric_keys.AUPRC)): 0.17,
            _full_key(metric_keys.AUC): 0.2,
            _full_key(metric_keys.lower_bound(metric_keys.AUC)): 0.1,
            _full_key(metric_keys.upper_bound(metric_keys.AUC)): 0.3
        }
        expected_metrics_for_slice = text_format.Parse(
            string.Template("""
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "$auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
              methodology: RIEMANN_SUM
            }
          }
        }
        metrics {
          key: "$auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
              methodology: RIEMANN_SUM
            }
          }
        }""").substitute(auc=_full_key(metric_keys.AUC),
                         auprc=_full_key(metric_keys.AUPRC)),
            metrics_for_slice_pb2.MetricsForSlice())

        got = metrics_and_plots_evaluator._serialize_metrics(
            (slice_key, slice_metrics),
            [post_export_metrics.auc(),
             post_export_metrics.auc(curve='PR')])
        # Round-trip through bytes before comparing.
        self.assertProtoEquals(
            expected_metrics_for_slice,
            metrics_for_slice_pb2.MetricsForSlice.FromString(got))
    def testSerializeConfusionMatrices(self):
        """Confusion-matrix rows serialize to confusion_matrix_at_thresholds.

        Each row is [fn, tn, fp, tp, precision, recall] at one threshold;
        NaN precision (no positives at threshold 1.0) must pass through.
        """
        slice_key = _make_slice_key()

        thresholds = [0.25, 0.75, 1.00]
        # Row layout: [false_neg, true_neg, false_pos, true_pos, precision,
        # recall], matching the expected proto below.
        matrices = [[0.0, 1.0, 0.0, 2.0, 1.0, 1.0],
                    [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
                    [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0]]

        slice_metrics = {
            _full_key(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES):
            matrices,
            _full_key(metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS):
            thresholds,
        }
        expected_metrics_for_slice = text_format.Parse(
            """
        slice_key {}
        metrics {
          key: "post_export_metrics/confusion_matrix_at_thresholds"
          value {
            confusion_matrix_at_thresholds {
              matrices {
                threshold: 0.25
                false_negatives: 0.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 2.0
                precision: 1.0
                recall: 1.0
                bounded_false_negatives {
                  value {
                    value: 0.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 2.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 1.0
                  }
                }
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
                bounded_false_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 1.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 0.5
                  }
                }
              }
              matrices {
                threshold: 1.00
                false_negatives: 2.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 0.0
                precision: nan
                recall: 0.0
                bounded_false_negatives {
                  value {
                    value: 2.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: nan
                  }
                }
                bounded_recall {
                  value {
                    value: 0.0
                  }
                }
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())

        got = metrics_and_plots_evaluator._serialize_metrics(
            (slice_key, slice_metrics),
            [post_export_metrics.confusion_matrix_at_thresholds(thresholds)])
        # Round-trip through bytes before comparing.
        self.assertProtoEquals(
            expected_metrics_for_slice,
            metrics_for_slice_pb2.MetricsForSlice.FromString(got))
예제 #18
0
    def testLoadMetricsAsDataframe_DoubleValueOnly(self):
        """Loads a metrics tfrecord into a DataFrame, one row per metric.

        Also checks include_empty_columns=True, which adds all-None columns
        (sub_key, aggregation_type, confidence_interval) for fields unset in
        the protos.
        """
        metrics_for_slice = text_format.Parse(
            """
        slice_key {
           single_slice_keys {
             column: "age"
             float_value: 38.0
           }
           single_slice_keys {
             column: "sex"
             bytes_value: "Female"
           }
         }
         metric_keys_and_values {
           key {
             name: "mean_absolute_error"
             example_weighted {
             }
           }
           value {
             double_value {
               value: 0.1
             }
           }
         }
         metric_keys_and_values {
           key {
             name: "mean_squared_logarithmic_error"
             example_weighted {
             }
           }
           value {
             double_value {
               value: 0.02
             }
           }
         }
         """, metrics_for_slice_pb2.MetricsForSlice())
        path = os.path.join(absltest.get_default_test_tmpdir(),
                            'metrics.tfrecord')
        # Write the proto to disk in the format load_metrics_as_dataframe
        # reads.
        with tf.io.TFRecordWriter(path) as writer:
            writer.write(metrics_for_slice.SerializeToString())
        df = experimental.load_metrics_as_dataframe(path)

        expected = pd.DataFrame({
            'slice':
            ['age = 38.0; sex = b\'Female\'', 'age = 38.0; sex = b\'Female\''],
            'name': ['mean_absolute_error', 'mean_squared_logarithmic_error'],
            'model_name': ['', ''],
            'output_name': ['', ''],
            'example_weighted': [False, False],
            'is_diff': [False, False],
            'display_value': [str(0.1), str(0.02)],
            'metric_value': [
                metrics_for_slice_pb2.MetricValue(double_value={'value': 0.1}),
                metrics_for_slice_pb2.MetricValue(double_value={'value': 0.02})
            ],
        })
        pd.testing.assert_frame_equal(expected, df)

        # Include empty column.
        df = experimental.load_metrics_as_dataframe(path,
                                                    include_empty_columns=True)
        expected = pd.DataFrame({
            'slice':
            ['age = 38.0; sex = b\'Female\'', 'age = 38.0; sex = b\'Female\''],
            'name': ['mean_absolute_error', 'mean_squared_logarithmic_error'],
            'model_name': ['', ''],
            'output_name': ['', ''],
            'sub_key': [None, None],
            'aggregation_type': [None, None],
            'example_weighted': [False, False],
            'is_diff': [False, False],
            'display_value': [str(0.1), str(0.02)],
            'metric_value': [
                metrics_for_slice_pb2.MetricValue(double_value={'value': 0.1}),
                metrics_for_slice_pb2.MetricValue(double_value={'value': 0.02})
            ],
            'confidence_interval': [None, None],
        })
        pd.testing.assert_frame_equal(expected, df)
  def testWriteMetricsAndPlots(self):
    metrics_file = os.path.join(self._getTempDir(), 'metrics')
    plots_file = os.path.join(self._getTempDir(), 'plots')
    temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir')

    _, eval_export_dir = (
        fixed_prediction_estimator.simple_fixed_prediction_estimator(
            None, temp_eval_export_dir))
    eval_config = config.EvalConfig(
        model_specs=[config.ModelSpec()],
        options=config.Options(
            disabled_outputs={'values': ['eval_config.json']}))
    eval_shared_model = self.createTestEvalSharedModel(
        eval_saved_model_path=eval_export_dir,
        add_metrics_callbacks=[
            post_export_metrics.example_count(),
            post_export_metrics.calibration_plot_and_prediction_histogram(
                num_buckets=2)
        ])
    extractors = [
        predict_extractor.PredictExtractor(eval_shared_model),
        slice_key_extractor.SliceKeyExtractor()
    ]
    evaluators = [
        metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model)
    ]
    output_paths = {
        constants.METRICS_KEY: metrics_file,
        constants.PLOTS_KEY: plots_file
    }
    writers = [
        metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
            output_paths, eval_shared_model.add_metrics_callbacks)
    ]

    with beam.Pipeline() as pipeline:
      example1 = self._makeExample(prediction=0.0, label=1.0)
      example2 = self._makeExample(prediction=1.0, label=1.0)

      # pylint: disable=no-value-for-parameter
      _ = (
          pipeline
          | 'Create' >> beam.Create([
              example1.SerializeToString(),
              example2.SerializeToString(),
          ])
          | 'ExtractEvaluateAndWriteResults' >>
          model_eval_lib.ExtractEvaluateAndWriteResults(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              extractors=extractors,
              evaluators=evaluators,
              writers=writers))
      # pylint: enable=no-value-for-parameter

    expected_metrics_for_slice = text_format.Parse(
        """
        slice_key {}
        metrics {
          key: "average_loss"
          value {
            double_value {
              value: 0.5
            }
          }
        }
        metrics {
          key: "post_export_metrics/example_count"
          value {
            double_value {
              value: 2.0
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())

    metric_records = []
    for record in tf.compat.v1.python_io.tf_record_iterator(metrics_file):
      metric_records.append(
          metrics_for_slice_pb2.MetricsForSlice.FromString(record))
    self.assertEqual(1, len(metric_records), 'metrics: %s' % metric_records)
    self.assertProtoEquals(expected_metrics_for_slice, metric_records[0])

    expected_plots_for_slice = text_format.Parse(
        """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples {
              }
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {
                value: 1.0
              }
            }
         }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())

    plot_records = []
    for record in tf.compat.v1.python_io.tf_record_iterator(plots_file):
      plot_records.append(
          metrics_for_slice_pb2.PlotsForSlice.FromString(record))
    self.assertEqual(1, len(plot_records), 'plots: %s' % plot_records)
    self.assertProtoEquals(expected_plots_for_slice, plot_records[0])
 def test_find_top_slices(self):
     """Tests find_top_slices with LOWER and HIGHER comparison types.

     Builds dataset statistics for a 'country' string feature and an 'age'
     numeric feature (3-bucket quantile histogram, 500 examples per bucket),
     plus metrics for the overall slice and four candidate slices. Verifies
     which slices are reported as significantly below (LOWER) or above
     (HIGHER) the overall 'accuracy' of 0.8.
     """
     # NOTE(review): the expected results below name slices like
     # 'age:[1.0, 6.0]' even though the input slice keys use
     # 'transformed_age' bucket indices, so these statistics are presumably
     # used to map bucket indices back to raw feature ranges — confirm
     # against auto_slicing_util.
     statistics = text_format.Parse(
         """
     datasets{
       num_examples: 1500
       features {
         path { step: 'country' }
         type: STRING
         string_stats {
           unique: 10
         }
       }
       features {
         path { step: 'age' }
         type: INT
         num_stats {
           common_stats {
             num_non_missing: 1500
             min_num_values: 1
             max_num_values: 1
           }
           min: 1
           max: 18
           histograms {
             buckets {
               low_value: 1
               high_value: 6.0
               sample_count: 500
             }
             buckets {
               low_value: 6.0
               high_value: 12.0
               sample_count: 500
             }
             buckets {
               low_value: 12.0
               high_value: 18.0
               sample_count: 500
             }
             type: QUANTILES
           }
         }
       }
     }
     """, statistics_pb2.DatasetFeatureStatisticsList())
     metrics = [
         # Overall (empty) slice: accuracy 0.8 over 1500 examples.
         text_format.Parse(
             """
     slice_key {
     }
     metric_keys_and_values {
       key { name: "accuracy" }
       value {
         bounded_value {
           value { value: 0.8 }
           lower_bound { value: 0.5737843 }
           upper_bound { value: 1.0262157 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 0.5737843 }
           upper_bound { value: 1.0262157 }
           t_distribution_value {
             sample_mean { value: 0.8 }
             sample_standard_deviation { value: 0.1 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 0.8 }
           }
         }
       }
     }
     metric_keys_and_values {
       key { name: "example_count" }
       value {
         bounded_value {
           value { value: 1500 }
           lower_bound { value: 1500 }
           upper_bound { value: 1500 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 1500 }
           upper_bound { value: 1500 }
           t_distribution_value {
             sample_mean { value: 1500 }
             sample_standard_deviation { value: 0 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 1500 }
           }
         }
       }
     }
     """, metrics_for_slice_pb2.MetricsForSlice()),
         # transformed_age bucket 1: accuracy 0.4 (below overall), 500
         # examples.
         text_format.Parse(
             """
     slice_key {
       single_slice_keys {
         column: 'transformed_age'
         int64_value: 1
       }
     }
     metric_keys_and_values {
       key { name: "accuracy" }
       value {
         bounded_value {
           value { value: 0.4 }
           lower_bound { value: 0.3737843 }
           upper_bound { value: 0.6262157 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 0.3737843 }
           upper_bound { value: 0.6262157 }
           t_distribution_value {
             sample_mean { value: 0.4 }
             sample_standard_deviation { value: 0.1 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 0.4 }
           }
         }
       }
     }
     metric_keys_and_values {
       key { name: "example_count" }
       value {
         bounded_value {
           value { value: 500 }
           lower_bound { value: 500 }
           upper_bound { value: 500 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 500 }
           upper_bound { value: 500 }
           t_distribution_value {
             sample_mean { value: 500 }
             sample_standard_deviation { value: 0 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 500 }
           }
         }
       }
     }
     """, metrics_for_slice_pb2.MetricsForSlice()),
         # transformed_age bucket 2: accuracy 0.79 — absent from both
         # expected result sets below.
         text_format.Parse(
             """
     slice_key {
       single_slice_keys {
         column: 'transformed_age'
         int64_value: 2
       }
     }
     metric_keys_and_values {
       key { name: "accuracy" }
       value {
         bounded_value {
           value { value: 0.79 }
           lower_bound { value: 0.5737843 }
           upper_bound { value: 1.0262157 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 0.5737843 }
           upper_bound { value: 1.0262157 }
           t_distribution_value {
             sample_mean { value: 0.79 }
             sample_standard_deviation { value: 0.1 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 0.79 }
           }
         }
       }
     }
     metric_keys_and_values {
       key { name: "example_count" }
       value {
         bounded_value {
           value { value: 500 }
           lower_bound { value: 500 }
           upper_bound { value: 500 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 500 }
           upper_bound { value: 500 }
           t_distribution_value {
             sample_mean { value: 500 }
             sample_standard_deviation { value: 0 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 500}
           }
         }
       }
     }
     """, metrics_for_slice_pb2.MetricsForSlice()),
         # transformed_age bucket 3: accuracy 0.9 (above overall).
         text_format.Parse(
             """
     slice_key {
       single_slice_keys {
         column: 'transformed_age'
         int64_value: 3
       }
     }
     metric_keys_and_values {
       key { name: "accuracy" }
       value {
         bounded_value {
           value { value: 0.9 }
           lower_bound { value: 0.5737843 }
           upper_bound { value: 1.0262157 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 0.5737843 }
           upper_bound { value: 1.0262157 }
           t_distribution_value {
             sample_mean { value: 0.9 }
             sample_standard_deviation { value: 0.1 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 0.9 }
           }
         }
       }
     }
     metric_keys_and_values {
       key { name: "example_count" }
       value {
         bounded_value {
           value { value: 500 }
           lower_bound { value: 500 }
           upper_bound { value: 500 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 500 }
           upper_bound { value: 500 }
           t_distribution_value {
             sample_mean { value: 500 }
             sample_standard_deviation { value: 0 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 500}
           }
         }
       }
     }
     """, metrics_for_slice_pb2.MetricsForSlice()),
         # country USA: accuracy 0.9 (above overall).
         text_format.Parse(
             """
     slice_key {
       single_slice_keys {
         column: 'country'
         bytes_value: 'USA'
       }
     }
     metric_keys_and_values {
       key { name: "accuracy" }
       value {
         bounded_value {
           value { value: 0.9 }
           lower_bound { value: 0.5737843 }
           upper_bound { value: 1.0262157 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 0.5737843 }
           upper_bound { value: 1.0262157 }
           t_distribution_value {
             sample_mean { value: 0.9 }
             sample_standard_deviation { value: 0.1 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 0.9 }
           }
         }
       }
     }
     metric_keys_and_values {
       key { name: "example_count" }
       value {
         bounded_value {
           value { value: 500 }
           lower_bound { value: 500 }
           upper_bound { value: 500 }
           methodology: POISSON_BOOTSTRAP
         }
         confidence_interval {
           lower_bound { value: 500 }
           upper_bound { value: 500 }
           t_distribution_value {
             sample_mean { value: 500 }
             sample_standard_deviation { value: 0 }
             sample_degrees_of_freedom { value: 9 }
             unsampled_value { value: 500}
           }
         }
       }
     }
     """, metrics_for_slice_pb2.MetricsForSlice())
     ]
     # LOWER: only the first age bucket (accuracy 0.4) is flagged.
     # NOTE(review): effect_size 4.0 looks like (0.8 - 0.4) / 0.1, i.e.
     # metric delta over the sample stddev — confirm the exact formula in
     # auto_slicing_util.
     self.assertCountEqual(
         auto_slicing_util.find_top_slices(metrics,
                                           metric_key='accuracy',
                                           statistics=statistics,
                                           comparison_type='LOWER'),
         [
             auto_slicing_util.SliceComparisonResult(
                 slice_key=u'age:[1.0, 6.0]',
                 num_examples=500.0,
                 slice_metric=0.4,
                 base_metric=0.8,
                 pvalue=0.0,
                 effect_size=4.0)
         ])
     # HIGHER: both slices with accuracy 0.9 are flagged; bucket 2
     # (accuracy 0.79, within the base interval) is not.
     self.assertCountEqual(
         auto_slicing_util.find_top_slices(metrics,
                                           metric_key='accuracy',
                                           statistics=statistics,
                                           comparison_type='HIGHER'),
         [
             auto_slicing_util.SliceComparisonResult(
                 slice_key=u'age:[12.0, 18.0]',
                 num_examples=500.0,
                 slice_metric=0.9,
                 base_metric=0.8,
                 pvalue=7.356017854191938e-70,
                 effect_size=0.9999999999999996),
             auto_slicing_util.SliceComparisonResult(
                 slice_key=u'country:USA',
                 num_examples=500.0,
                 slice_metric=0.9,
                 base_metric=0.8,
                 pvalue=7.356017854191938e-70,
                 effect_size=0.9999999999999996)
         ])
# Example #21 (0)
  def testConvertSliceMetricsToProtoFromLegacyStrings(self):
    """Tests conversion of legacy string-keyed slice metrics to proto.

    A plain float metric ('accuracy') becomes a double_value, while AUC and
    AUPRC — supplied together with their legacy lower/upper bound keys —
    become bounded_values with RIEMANN_SUM methodology. The post-export
    metric callbacks passed to the converter identify which keys carry
    bounds.
    """
    slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
    # Legacy-style flat dict: the bounds are separate entries whose keys are
    # derived from the base metric key.
    slice_metrics = {
        'accuracy': 0.8,
        metric_keys.AUPRC: 0.1,
        metric_keys.lower_bound_key(metric_keys.AUPRC): 0.05,
        metric_keys.upper_bound_key(metric_keys.AUPRC): 0.17,
        metric_keys.AUC: 0.2,
        metric_keys.lower_bound_key(metric_keys.AUC): 0.1,
        metric_keys.upper_bound_key(metric_keys.AUC): 0.3
    }
    # $auc / $auprc are substituted with the actual metric key names below.
    expected_metrics_for_slice = text_format.Parse(
        string.Template("""
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "$auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
              methodology: RIEMANN_SUM
            }
          }
        }
        metrics {
          key: "$auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
              methodology: RIEMANN_SUM
            }
          }
        }""").substitute(auc=metric_keys.AUC, auprc=metric_keys.AUPRC),
        metrics_for_slice_pb2.MetricsForSlice())

    got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
        (slice_key, slice_metrics),
        [post_export_metrics.auc(),
         post_export_metrics.auc(curve='PR')])
    self.assertProtoEquals(expected_metrics_for_slice, got)
    def testSerializeDeserializeToFile(self):
        """Round-trips metrics, plots and the eval config through disk.

        Serialized MetricsForSlice / PlotsForSlice protos are written with
        model_eval_lib.WriteResults using the default writers, then read
        back via load_and_deserialize_metrics / load_and_deserialize_plots
        and compared against the in-memory originals. The EvalConfig is
        round-tripped through WriteEvalConfig / load_eval_config.
        """
        metrics_slice_key = _make_slice_key(b'fruit', b'pear', b'animal',
                                            b'duck')
        # Metrics fixture: plain double values plus bounded AUC/AUPRC.
        metrics_for_slice = text_format.Parse(
            """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "pear"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "duck"
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "example_weight"
          value {
            double_value {
              value: 10.0
            }
          }
        }
        metrics {
          key: "auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
            }
          }
        }
        metrics {
          key: "auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
            }
          }
        }""", metrics_for_slice_pb2.MetricsForSlice())
        # Plots fixture: a 4-bucket calibration histogram under the empty
        # plot key, for a different slice than the metrics fixture.
        plots_for_slice = text_format.Parse(
            """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "peach"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "cow"
          }
        }
        plots {
          key: ''
          value {
            calibration_histogram_buckets {
              buckets {
                lower_threshold_inclusive: -inf
                upper_threshold_exclusive: 0.0
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
              buckets {
                lower_threshold_inclusive: 0.0
                upper_threshold_exclusive: 0.5
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 1.0 }
                total_weighted_refined_prediction { value: 0.3 }
              }
              buckets {
                lower_threshold_inclusive: 0.5
                upper_threshold_exclusive: 1.0
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.7 }
              }
              buckets {
                lower_threshold_inclusive: 1.0
                upper_threshold_exclusive: inf
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
            }
          }
        }""", metrics_for_slice_pb2.PlotsForSlice())
        plots_slice_key = _make_slice_key(b'fruit', b'peach', b'animal',
                                          b'cow')
        eval_config = model_eval_lib.EvalConfig(
            model_location='/path/to/model',
            data_location='/path/to/data',
            slice_spec=[
                slicer.SingleSliceSpec(features=[('age', 5), ('gender', 'f')],
                                       columns=['country']),
                slicer.SingleSliceSpec(features=[('age', 6), ('gender', 'm')],
                                       columns=['interest'])
            ],
            example_weight_metric_key='key')

        output_path = self._getTempDir()
        # Write the pre-serialized protos through the default writers.
        with beam.Pipeline() as pipeline:
            metrics = (pipeline
                       | 'CreateMetrics' >> beam.Create(
                           [metrics_for_slice.SerializeToString()]))
            plots = (pipeline
                     | 'CreatePlots' >> beam.Create(
                         [plots_for_slice.SerializeToString()]))
            evaluation = {
                constants.METRICS_KEY: metrics,
                constants.PLOTS_KEY: plots
            }
            _ = (evaluation
                 | 'WriteResults' >> model_eval_lib.WriteResults(
                     writers=model_eval_lib.default_writers(
                         output_path=output_path)))
            _ = pipeline | model_eval_lib.WriteEvalConfig(
                eval_config, output_path)

        # Read everything back and compare with the in-memory originals.
        metrics = metrics_and_plots_evaluator.load_and_deserialize_metrics(
            path=os.path.join(output_path,
                              model_eval_lib._METRICS_OUTPUT_FILE))
        plots = metrics_and_plots_evaluator.load_and_deserialize_plots(
            path=os.path.join(output_path, model_eval_lib._PLOTS_OUTPUT_FILE))
        self.assertSliceMetricsListEqual(
            [(metrics_slice_key, metrics_for_slice.metrics)], metrics)
        self.assertSlicePlotsListEqual(
            [(plots_slice_key, plots_for_slice.plots)], plots)
        got_eval_config = model_eval_lib.load_eval_config(output_path)
        self.assertEqual(eval_config, got_eval_config)
# Example #23 (0)
    def testSerializeMetricsRanges(self):
        """Tests _serialize_metrics on metrics carrying value ranges.

        'accuracy' is supplied as a ValueWithTDistribution and is expected
        to serialize as a bounded_value with POISSON_BOOTSTRAP methodology
        (bounds presumably derived from the t-distribution with df=9 —
        confirm against the serializer), while AUC/AUPRC with legacy bound
        keys serialize with RIEMANN_SUM methodology.
        """
        slice_key = _make_slice_key('age', 5, 'language', 'english', 'price',
                                    0.3)
        slice_metrics = {
            'accuracy': types.ValueWithTDistribution(0.8, 0.1, 9, 0.8),
            metric_keys.AUPRC: 0.1,
            metric_keys.lower_bound_key(metric_keys.AUPRC): 0.05,
            metric_keys.upper_bound_key(metric_keys.AUPRC): 0.17,
            metric_keys.AUC: 0.2,
            metric_keys.lower_bound_key(metric_keys.AUC): 0.1,
            metric_keys.upper_bound_key(metric_keys.AUC): 0.3
        }
        # $auc / $auprc are substituted with the real metric key names.
        expected_metrics_for_slice = text_format.Parse(
            string.Template("""
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metrics {
          key: "accuracy"
          value {
            bounded_value {
              value {
                value: 0.8
              }
              lower_bound {
                value: 0.5737843
              }
              upper_bound {
                value: 1.0262157
              }
              methodology: POISSON_BOOTSTRAP
            }
          }
        }
        metrics {
          key: "$auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
              methodology: RIEMANN_SUM
            }
          }
        }
        metrics {
          key: "$auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
              methodology: RIEMANN_SUM
            }
          }
        }""").substitute(auc=metric_keys.AUC, auprc=metric_keys.AUPRC),
            metrics_for_slice_pb2.MetricsForSlice())

        # _serialize_metrics returns serialized bytes; parse before compare.
        got = metrics_and_plots_serialization._serialize_metrics(
            (slice_key, slice_metrics),
            [post_export_metrics.auc(),
             post_export_metrics.auc(curve='PR')])
        self.assertProtoEquals(
            expected_metrics_for_slice,
            metrics_for_slice_pb2.MetricsForSlice.FromString(got))
# Example #24 (0)
 def testUncertaintyValuedMetrics(self):
   """Tests proto conversion of ValueWithTDistribution-valued metrics.

   A finite t-distribution value converts to a bounded_value plus a
   confidence_interval carrying the raw t-distribution statistics; a
   metric whose samples are all NaN propagates nan into every field.
   """
   slice_key = _make_slice_key()
   slice_metrics = {
       'one_dim':
           types.ValueWithTDistribution(2.0, 1.0, 3, 2.0),
       'nans':
           types.ValueWithTDistribution(
               float('nan'), float('nan'), -1, float('nan')),
   }
   # NOTE(review): the bounds -1.1824463 / 5.1824463 look like
   # mean +/- t-critical * stddev for df=3 — confirm against the
   # converter implementation.
   expected_metrics_for_slice = text_format.Parse(
       """
        slice_key {}
        metrics {
          key: "one_dim"
          value {
            bounded_value {
              value {
                value: 2.0
              }
              lower_bound {
                value: -1.1824463
              }
              upper_bound {
                value: 5.1824463
              }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound {
                value: -1.1824463
              }
              upper_bound {
                value: 5.1824463
              }
              t_distribution_value {
                sample_mean {
                  value: 2.0
                }
                sample_standard_deviation {
                  value: 1.0
                }
                sample_degrees_of_freedom {
                  value: 3
                }
                unsampled_value {
                  value: 2.0
                }
              }
            }
          }
        }
        metrics {
          key: "nans"
          value {
            bounded_value {
              value {
                value: nan
              }
              lower_bound {
                value: nan
              }
              upper_bound {
                value: nan
              }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound {
                value: nan
              }
              upper_bound {
                value: nan
              }
              t_distribution_value {
                sample_mean {
                  value: nan
                }
                sample_standard_deviation {
                  value: nan
                }
                sample_degrees_of_freedom {
                  value: -1
                }
                unsampled_value {
                  value: nan
                }
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())
   got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
       (slice_key, slice_metrics), [])
   self.assertProtoEquals(expected_metrics_for_slice, got)
    def test_find_significant_slices(self):
        metrics = [
            text_format.Parse(
                """
        slice_key {
        }
        metric_keys_and_values {
          key { name: "accuracy" }
          value {
            bounded_value {
              value { value: 0.8 }
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              t_distribution_value {
                sample_mean { value: 0.8 }
                sample_standard_deviation { value: 0.1 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 0.8 }
              }
            }
          }
        }
        metric_keys_and_values {
          key { name: "example_count" }
          value {
            bounded_value {
              value { value: 1500 }
              lower_bound { value: 1500 }
              upper_bound { value: 1500 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 1500 }
              upper_bound { value: 1500 }
              t_distribution_value {
                sample_mean { value: 1500 }
                sample_standard_deviation { value: 0 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 1500 }
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice()),
            text_format.Parse(
                """
        slice_key {
          single_slice_keys {
            column: 'age'
            bytes_value: '[1.0, 6.0)'
          }
        }
        metric_keys_and_values {
          key { name: "accuracy" }
          value {
            bounded_value {
              value { value: 0.4 }
              lower_bound { value: 0.3737843 }
              upper_bound { value: 0.6262157 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 0.3737843 }
              upper_bound { value: 0.6262157 }
              t_distribution_value {
                sample_mean { value: 0.4 }
                sample_standard_deviation { value: 0.1 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 0.4 }
              }
            }
          }
        }
        metric_keys_and_values {
          key { name: "example_count" }
          value {
            bounded_value {
              value { value: 500 }
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              t_distribution_value {
                sample_mean { value: 500 }
                sample_standard_deviation { value: 0 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 500 }
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice()),
            text_format.Parse(
                """
        slice_key {
          single_slice_keys {
            column: 'age'
            bytes_value: '[6.0, 12.0)'
          }
        }
        metric_keys_and_values {
          key { name: "accuracy" }
          value {
            bounded_value {
              value { value: 0.79 }
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              t_distribution_value {
                sample_mean { value: 0.79 }
                sample_standard_deviation { value: 0.1 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 0.79 }
              }
            }
          }
        }
        metric_keys_and_values {
          key { name: "example_count" }
          value {
            bounded_value {
              value { value: 500 }
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              t_distribution_value {
                sample_mean { value: 500 }
                sample_standard_deviation { value: 0 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 500}
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice()),
            text_format.Parse(
                """
        slice_key {
          single_slice_keys {
            column: 'age'
            bytes_value: '[12.0, 18.0)'
          }
        }
        metric_keys_and_values {
          key { name: "accuracy" }
          value {
            bounded_value {
              value { value: 0.9 }
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              t_distribution_value {
                sample_mean { value: 0.9 }
                sample_standard_deviation { value: 0.1 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 0.9 }
              }
            }
          }
        }
        metric_keys_and_values {
          key { name: "example_count" }
          value {
            bounded_value {
              value { value: 500 }
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              t_distribution_value {
                sample_mean { value: 500 }
                sample_standard_deviation { value: 0 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 500}
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice()),
            text_format.Parse(
                """
        slice_key {
          single_slice_keys {
            column: 'country'
            bytes_value: 'USA'
          }
        }
        metric_keys_and_values {
          key { name: "accuracy" }
          value {
            bounded_value {
              value { value: 0.9 }
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              t_distribution_value {
                sample_mean { value: 0.9 }
                sample_standard_deviation { value: 0.1 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 0.9 }
              }
            }
          }
        }
        metric_keys_and_values {
          key { name: "example_count" }
          value {
            bounded_value {
              value { value: 500 }
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              t_distribution_value {
                sample_mean { value: 500 }
                sample_standard_deviation { value: 0 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 500}
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice()),
            text_format.Parse(
                """
        slice_key {
          single_slice_keys {
            column: 'country'
            bytes_value: 'USA'
          }
          single_slice_keys {
            column: 'age'
            bytes_value: '[12.0, 18.0)'
          }
        }
        metric_keys_and_values {
          key { name: "accuracy" }
          value {
            bounded_value {
              value { value: 0.9 }
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 0.5737843 }
              upper_bound { value: 1.0262157 }
              t_distribution_value {
                sample_mean { value: 0.9 }
                sample_standard_deviation { value: 0.1 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 0.9 }
              }
            }
          }
        }
        metric_keys_and_values {
          key { name: "example_count" }
          value {
            bounded_value {
              value { value: 500 }
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound { value: 500 }
              upper_bound { value: 500 }
              t_distribution_value {
                sample_mean { value: 500 }
                sample_standard_deviation { value: 0 }
                sample_degrees_of_freedom { value: 9 }
                unsampled_value { value: 500}
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())
        ]
        result = auto_slicing_util.partition_slices(metrics,
                                                    metric_key='accuracy',
                                                    comparison_type='LOWER')
        self.assertCountEqual([s.slice_key for s in result[0]],
                              [(('age', '[1.0, 6.0)'), )])
        self.assertCountEqual([s.slice_key for s in result[1]],
                              [(('age', '[6.0, 12.0)'), ),
                               (('age', '[12.0, 18.0)'), ),
                               (('country', 'USA'), ),
                               (('country', 'USA'), ('age', '[12.0, 18.0)'))])

        result = auto_slicing_util.partition_slices(metrics,
                                                    metric_key='accuracy',
                                                    comparison_type='HIGHER')
        self.assertCountEqual([s.slice_key for s in result[0]],
                              [(('age', '[12.0, 18.0)'), ),
                               (('country', 'USA'), ),
                               (('country', 'USA'), ('age', '[12.0, 18.0)'))])
        self.assertCountEqual([s.slice_key for s in result[1]],
                              [(('age', '[1.0, 6.0)'), ),
                               (('age', '[6.0, 12.0)'), )])
예제 #26
0
  def testConvertSliceMetricsToProtoConfusionMatrices(self):
    """Tests proto conversion of confusion-matrix-at-thresholds metrics.

    Feeds raw matrix rows plus their thresholds through
    convert_slice_metrics_to_proto() with the confusion_matrix_at_thresholds
    post-export metric callback, and checks the exact MetricsForSlice proto
    that comes out (including bounded_value / t_distribution mirrors of each
    cell, and the NaN precision at threshold 1.0).
    """
    # Empty (overall) slice key.
    slice_key = _make_slice_key()

    thresholds = [0.25, 0.75, 1.00]
    # One row per threshold; columns are
    # [false_negatives, true_negatives, false_positives, true_positives,
    #  precision, recall] — matching the field order asserted in the expected
    # proto below. At threshold 1.0 there are no positive predictions, so
    # precision is NaN.
    matrices = [[0.0, 1.0, 0.0, 2.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
                [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0]]

    slice_metrics = {
        metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES: matrices,
        metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS: thresholds,
    }
    # Byte-exact expected output proto; every matrix cell is duplicated into
    # bounded_* and t_distribution_* fields by the conversion.
    expected_metrics_for_slice = text_format.Parse(
        """
        slice_key {}
        metrics {
          key: "post_export_metrics/confusion_matrix_at_thresholds"
          value {
            confusion_matrix_at_thresholds {
              matrices {
                threshold: 0.25
                false_negatives: 0.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 2.0
                precision: 1.0
                recall: 1.0
                bounded_false_negatives {
                  value {
                    value: 0.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 2.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 1.0
                  }
                }
                t_distribution_false_negatives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_true_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_positives {
                  unsampled_value {
                    value: 2.0
                  }
                }
                t_distribution_false_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_precision {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_recall {
                  unsampled_value {
                    value: 1.0
                  }
                }
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
                bounded_false_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 1.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 0.5
                  }
                }
                t_distribution_false_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_positives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_false_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_precision {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_recall {
                  unsampled_value {
                    value: 0.5
                  }
                }
              }
              matrices {
                threshold: 1.00
                false_negatives: 2.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 0.0
                precision: nan
                recall: 0.0
                bounded_false_negatives {
                  value {
                    value: 2.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: nan
                  }
                }
                bounded_recall {
                  value {
                    value: 0.0
                  }
                }
                t_distribution_false_negatives {
                  unsampled_value {
                    value: 2.0
                  }
                }
                t_distribution_true_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_false_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_precision {
                  unsampled_value {
                    value: nan
                  }
                }
                t_distribution_recall {
                  unsampled_value {
                    value: 0.0
                  }
                }
              }
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())

    # The callback converts the raw matrices/thresholds entries into the
    # structured confusion_matrix_at_thresholds proto field.
    got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
        (slice_key, slice_metrics),
        [post_export_metrics.confusion_matrix_at_thresholds(thresholds)])
    self.assertProtoEquals(expected_metrics_for_slice, got)
예제 #27
0
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyOrCrossSliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.MetricsForSlice:
    """Converts the given slice metrics into a MetricsForSlice proto.

  Args:
    metrics: The slice metrics as a (slice key, metrics dict) pair.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The MetricsForSlice proto.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
    proto = metrics_for_slice_pb2.MetricsForSlice()
    key, metric_dict = metrics

    # Cross-slice keys are serialized into a dedicated proto field.
    if slicer.is_cross_slice_key(key):
        proto.cross_slice_key.CopyFrom(slicer.serialize_cross_slice_key(key))
    else:
        proto.slice_key.CopyFrom(slicer.serialize_slice_key(key))

    # Work on a copy so that callbacks below can pop entries without mutating
    # the caller's dict.
    metric_dict = metric_dict.copy()

    # An error entry short-circuits conversion: only the debug message is kept.
    if metric_keys.ERROR_METRIC in metric_dict:
        logging.warning('Error for slice: %s with error message: %s ', key,
                        metric_dict[metric_keys.ERROR_METRIC])
        proto.metrics[metric_keys.ERROR_METRIC].debug_message = metric_dict[
            metric_keys.ERROR_METRIC]
        return proto

    # Let legacy add_metrics_callbacks convert their own entries into
    # structured output, but only when no new-style MetricKey entries exist.
    uses_metric_keys = any(
        isinstance(k, metric_types.MetricKey) for k in metric_dict)
    if add_metrics_callbacks and not uses_metric_keys:
        for callback in add_metrics_callbacks:
            if hasattr(callback, 'populate_stats_and_pop'):
                callback.populate_stats_and_pop(key, metric_dict,
                                                proto.metrics)

    # Deterministic output order: iterate metric names sorted.
    for metric_name in sorted(metric_dict.keys()):
        raw_value = metric_dict[metric_name]
        confidence_interval = None
        if isinstance(raw_value, types.ValueWithTDistribution):
            point_estimate = raw_value.unsampled_value
            _, lower_bound, upper_bound = (
                math_util.calculate_confidence_interval(raw_value))
            confidence_interval = metrics_for_slice_pb2.ConfidenceInterval(
                lower_bound=convert_metric_value_to_proto(lower_bound),
                upper_bound=convert_metric_value_to_proto(upper_bound),
                standard_error=convert_metric_value_to_proto(
                    raw_value.sample_standard_deviation),
                degrees_of_freedom={
                    'value': raw_value.sample_degrees_of_freedom
                })
            metric_value = convert_metric_value_to_proto(point_estimate)

            # If metric can be stored to double_value metrics, replace it with a
            # bounded_value for backwards compatibility.
            # TODO(b/188575688): remove this logic to stop populating bounded_value
            if metric_value.WhichOneof('type') == 'double_value':
                # setting bounded_value clears double_value in the same oneof scope.
                metric_value.bounded_value.value.value = point_estimate
                metric_value.bounded_value.lower_bound.value = lower_bound
                metric_value.bounded_value.upper_bound.value = upper_bound
                metric_value.bounded_value.methodology = (
                    metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP)
        else:
            metric_value = convert_metric_value_to_proto(raw_value)

        # New-style keys go into metric_keys_and_values; legacy string keys go
        # into the metrics map.
        if isinstance(metric_name, metric_types.MetricKey):
            proto.metric_keys_and_values.add(
                key=metric_name.to_proto(),
                value=metric_value,
                confidence_interval=confidence_interval)
        else:
            proto.metrics[metric_name].CopyFrom(metric_value)

    return proto
예제 #28
0
        def check_result(got):  # pylint: disable=invalid-name
            """Beam assert callback: verifies the confusion-matrix metrics.

            Checks that the single overall slice carries the expected raw
            matrices/thresholds entries, then serializes them via
            populate_stats_and_pop and compares against the exact proto text.
            """
            try:
                # Exactly one (slice_key, metrics) pair, for the overall slice.
                self.assertEqual(1, len(got), 'got: %s' % got)
                (slice_key, value) = got[0]
                self.assertEqual((), slice_key)
                self.assertIn(
                    metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES, value)
                matrices = value[
                    metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES]
                # Each matrix row is
                # [false_negatives, true_negatives, false_positives,
                #  true_positives, precision, recall] for one threshold.
                #            |      | ---- Threshold ----
                # true label | pred | 0.25 | 0.75 | 1.00
                #     -      | 0.0  | TN   | TN   | TN
                #     +      | 0.5  | TP   | FN   | FN
                #     +      | 1.0  | TP   | TP   | FN
                self.assertSequenceAlmostEqual(matrices[0],
                                               [0.0, 1.0, 0.0, 2.0, 1.0, 1.0])
                self.assertSequenceAlmostEqual(matrices[1],
                                               [1.0, 1.0, 0.0, 1.0, 1.0, 0.5])
                # No positive predictions at threshold 1.0 -> precision is NaN.
                self.assertSequenceAlmostEqual(
                    matrices[2],
                    [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0])
                self.assertIn(
                    metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS,
                    value)
                thresholds = value[
                    metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS]
                self.assertAlmostEqual(0.25, thresholds[0])
                self.assertAlmostEqual(0.75, thresholds[1])
                self.assertAlmostEqual(1.00, thresholds[2])

                # Check serialization too.
                # Note that we can't just make this a dict, since proto maps
                # allow uninitialized key access, i.e. they act like defaultdicts.
                output_metrics = metrics_for_slice_pb2.MetricsForSlice(
                ).metrics
                confusion_matrix_at_thresholds_metric.populate_stats_and_pop(
                    value, output_metrics)
                self.assertProtoEquals(
                    """
            confusion_matrix_at_thresholds {
              matrices {
                threshold: 0.25
                false_negatives: 0.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 2.0
                precision: 1.0
                recall: 1.0
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
              }
              matrices {
                threshold: 1.00
                false_negatives: 2.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 0.0
                precision: nan
                recall: 0.0
              }
            }
            """, output_metrics[metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS])
            except AssertionError as err:
                # Surface assertion failures through Beam's testing machinery.
                raise util.BeamAssertException(err)
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.MetricsForSlice:
    """Converts the given slice metrics into a MetricsForSlice proto.

  Args:
    metrics: The slice metrics as a (slice key, metrics dict) pair.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The MetricsForSlice proto.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
    result = metrics_for_slice_pb2.MetricsForSlice()
    slice_key, slice_metrics = metrics

    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

    # Copy so callbacks below may pop entries without mutating the caller's
    # dict.
    slice_metrics = slice_metrics.copy()

    # An error entry short-circuits conversion: only the debug message is kept.
    if metric_keys.ERROR_METRIC in slice_metrics:
        logging.warning('Error for slice: %s with error message: %s ',
                        slice_key, slice_metrics[metric_keys.ERROR_METRIC])
        result.metrics[metric_keys.ERROR_METRIC].debug_message = slice_metrics[
            metric_keys.ERROR_METRIC]
        return result

    # Convert the metrics from add_metrics_callbacks to the structured output if
    # defined (only when no new-style MetricKey entries are present).
    if add_metrics_callbacks and (not any(
            isinstance(k, metric_types.MetricKey)
            for k in slice_metrics.keys())):
        for add_metrics_callback in add_metrics_callbacks:
            if hasattr(add_metrics_callback, 'populate_stats_and_pop'):
                add_metrics_callback.populate_stats_and_pop(
                    slice_key, slice_metrics, result.metrics)
    # Sorted for deterministic proto output.
    for key in sorted(slice_metrics.keys()):
        value = slice_metrics[key]
        metric_value = metrics_for_slice_pb2.MetricValue()
        if isinstance(value,
                      metrics_for_slice_pb2.ConfusionMatrixAtThresholds):
            metric_value.confusion_matrix_at_thresholds.CopyFrom(value)
        elif isinstance(
                value,
                metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds):
            metric_value.multi_class_confusion_matrix_at_thresholds.CopyFrom(
                value)
        elif isinstance(value, types.ValueWithTDistribution):
            # Currently we populate both bounded_value and confidence_interval.
            # Avoid populating bounded_value once the UI handles confidence_interval.
            # Convert to a bounded value. 95% confidence level is computed here.
            _, lower_bound, upper_bound = (
                math_util.calculate_confidence_interval(value))
            metric_value.bounded_value.value.value = value.unsampled_value
            metric_value.bounded_value.lower_bound.value = lower_bound
            metric_value.bounded_value.upper_bound.value = upper_bound
            metric_value.bounded_value.methodology = (
                metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP)
            # Populate confidence_interval
            metric_value.confidence_interval.lower_bound.value = lower_bound
            metric_value.confidence_interval.upper_bound.value = upper_bound
            t_dist_value = metrics_for_slice_pb2.TDistributionValue()
            t_dist_value.sample_mean.value = value.sample_mean
            t_dist_value.sample_standard_deviation.value = (
                value.sample_standard_deviation)
            t_dist_value.sample_degrees_of_freedom.value = (
                value.sample_degrees_of_freedom)
            # Once the UI handles confidence interval, we will avoid setting this and
            # instead use the double_value.
            t_dist_value.unsampled_value.value = value.unsampled_value
            metric_value.confidence_interval.t_distribution_value.CopyFrom(
                t_dist_value)
        elif isinstance(value, six.binary_type):
            # Convert textual types to string metrics.
            metric_value.bytes_value = value
        elif isinstance(value, six.text_type):
            # Convert textual types to string metrics.
            metric_value.bytes_value = value.encode('utf8')
        elif isinstance(value, np.ndarray):
            # Convert NumPy arrays to ArrayValue.
            metric_value.array_value.CopyFrom(_convert_to_array_value(value))
        else:
            # We try to convert to float values.
            try:
                metric_value.double_value.value = float(value)
            except (TypeError, ValueError) as e:
                metric_value.unknown_type.value = str(value)
                # Bug fix: Python 3 exceptions have no `.message` attribute, so
                # the previous `e.message` raised AttributeError on this path
                # (the old pytype disable comment flagged exactly this). Use
                # str(e) to record the conversion error instead.
                metric_value.unknown_type.error = str(e)

        # New-style keys go into metric_keys_and_values; legacy string keys go
        # into the metrics map.
        if isinstance(key, metric_types.MetricKey):
            key_and_value = result.metric_keys_and_values.add()
            key_and_value.key.CopyFrom(key.to_proto())
            key_and_value.value.CopyFrom(metric_value)
        else:
            result.metrics[key].CopyFrom(metric_value)

    return result
예제 #30
0
  def testConvertSliceMetricsToProtoMetricsRanges(self):
    """Tests proto conversion of bounded/range-style metrics.

    Covers a ValueWithTDistribution ('accuracy') which must be expanded into
    both a bounded_value and a confidence_interval, plus AUPRC/AUC values with
    explicit lower/upper bound keys which must become RIEMANN_SUM
    bounded_values. Also exercises a multi-feature slice key with int64,
    bytes, and float feature values.
    """
    slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
    slice_metrics = {
        # (sample_mean, sample_standard_deviation, sample_degrees_of_freedom,
        #  unsampled_value) — see the t_distribution_value in the expected
        # proto below.
        'accuracy': types.ValueWithTDistribution(0.8, 0.1, 9, 0.8),
        metric_keys.AUPRC: 0.1,
        metric_keys.lower_bound_key(metric_keys.AUPRC): 0.05,
        metric_keys.upper_bound_key(metric_keys.AUPRC): 0.17,
        metric_keys.AUC: 0.2,
        metric_keys.lower_bound_key(metric_keys.AUC): 0.1,
        metric_keys.upper_bound_key(metric_keys.AUC): 0.3
    }
    # string.Template is used so the $auc/$auprc placeholders pick up the
    # actual metric key constants.
    expected_metrics_for_slice = text_format.Parse(
        string.Template("""
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metrics {
          key: "accuracy"
          value {
            bounded_value {
              value {
                value: 0.8
              }
              lower_bound {
                value: 0.5737843
              }
              upper_bound {
                value: 1.0262157
              }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound {
                value: 0.5737843
              }
              upper_bound {
                value: 1.0262157
              }
              t_distribution_value {
                sample_mean {
                  value: 0.8
                }
                sample_standard_deviation {
                  value: 0.1
                }
                sample_degrees_of_freedom {
                  value: 9
                }
                unsampled_value {
                  value: 0.8
                }
              }
            }
          }
        }
        metrics {
          key: "$auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
              methodology: RIEMANN_SUM
            }
          }
        }
        metrics {
          key: "$auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
              methodology: RIEMANN_SUM
            }
          }
        }""").substitute(auc=metric_keys.AUC, auprc=metric_keys.AUPRC),
        metrics_for_slice_pb2.MetricsForSlice())

    # The auc() callbacks fold the *_lower_bound/*_upper_bound keys into
    # bounded_value entries of the output proto.
    got = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
        (slice_key, slice_metrics),
        [post_export_metrics.auc(),
         post_export_metrics.auc(curve='PR')])
    self.assertProtoEquals(expected_metrics_for_slice, got)