def testSerializeDeserializeEvalConfig(self):
    """Checks that an EvalConfig survives a write/load round trip.

    A config with two slicing specs plus options is serialized with
    `_serialize_eval_run` into `eval_config.json` inside a temp directory and
    read back through `load_eval_run`. Every returned piece must equal what
    was handed to the writer.
    """
    output_path = self._getTempDir()

    # Values handed to the writer alongside the config.
    data_location = '/path/to/data'
    file_format = 'tfrecords'
    model_location = '/path/to/model'

    options = config_pb2.Options()
    options.compute_confidence_intervals.value = False
    options.min_slice_size.value = 1
    country_spec = config_pb2.SlicingSpec(
        feature_keys=['country'], feature_values={'age': '5', 'gender': 'f'})
    interest_spec = config_pb2.SlicingSpec(
        feature_keys=['interest'], feature_values={'age': '6', 'gender': 'm'})
    eval_config = config_pb2.EvalConfig(
        slicing_specs=[country_spec, interest_spec], options=options)

    config_file = os.path.join(output_path, 'eval_config.json')
    with tf.io.gfile.GFile(config_file, 'w') as f:
        serialized_run = eval_config_writer._serialize_eval_run(
            eval_config, data_location, file_format, {'': model_location})
        f.write(serialized_run)

    loaded_run = eval_config_writer.load_eval_run(output_path)
    (got_eval_config, got_data_location, got_file_format,
     got_model_locations) = loaded_run

    self.assertEqual(eval_config, got_eval_config)
    self.assertEqual(data_location, got_data_location)
    self.assertEqual(file_format, got_file_format)
    self.assertEqual({'': model_location}, got_model_locations)
Пример #2
0
 def testDeserializeSliceSpec_hashable(self):
   """Deserialized single and cross slice specs must work as dict keys."""
   single_proto = config_pb2.SlicingSpec(feature_values={'a': '1'})
   single_slice_spec = slicer.deserialize_slice_spec(single_proto)

   cross_proto = slicer.config_pb2.CrossSlicingSpec(
       baseline_spec=config_pb2.SlicingSpec(),
       slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])
   cross_slice_spec = slicer.deserialize_slice_spec(cross_proto)

   # Inserting hashes each spec; the lookups below hash them again.
   slice_map = {}
   slice_map[single_slice_spec] = 1
   slice_map[cross_slice_spec] = 2

   self.assertEqual(slice_map[single_slice_spec], 1)
   self.assertEqual(slice_map[cross_slice_spec], 2)
Пример #3
0
def get_missing_slices(
    slicing_details: Iterable[validation_result_pb2.SlicingDetails],
    eval_config: config_pb2.EvalConfig
) -> List[Union[config_pb2.SlicingSpec, config_pb2.CrossSlicingSpec]]:
    """Lists threshold slice specs from the EvalConfig that have no details.

    Args:
      slicing_details: Slicing details already recorded.
      eval_config: Eval config whose metrics specs supply the thresholds.

    Returns:
      The slice specs attached to a non-baseline threshold whose serialized
      form is absent from the hashed details, each listed once. Empty when
      nothing is missing.
    """
    seen_by_hash = _hashed_slicing_details(slicing_details)
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = None
    if baseline_spec:
        baseline_model_name = baseline_spec.name

    missing_slices = []
    for metric_key, sliced_thresholds in thresholds.items():
        # Thresholds keyed on the baseline model are not considered.
        if metric_key.model_name == baseline_model_name:
            continue
        for spec, _ in sliced_thresholds:
            # A falsy spec is replaced by a default SlicingSpec.
            spec = spec or config_pb2.SlicingSpec()
            spec_hash = spec.SerializeToString()
            if spec_hash in seen_by_hash:
                continue
            missing_slices.append(spec)
            # Several metrics/thresholds may share a slice; marking it as
            # seen keeps it from being reported more than once.
            seen_by_hash[spec_hash] = validation_result_pb2.SlicingDetails()
    return missing_slices
Пример #4
0
    def _slicing_spec(
            self,
            me_slicing_spec: me_proto.SlicingSpec) -> config_pb2.SlicingSpec:
        """Convert ME SlicingSpec into TFMA.

    Args:
      me_slicing_spec: Input ME SlicingSpec.

    Returns:
      TFMA SlicingSpec.
    """
        if not me_slicing_spec:
            return None
        tfma_slicing_spec = config_pb2.SlicingSpec()
        if me_slicing_spec.feature_key_specs:
            tfma_slicing_spec.feature_keys.extend([
                ColumnSpec(spec).as_string()
                for spec in me_slicing_spec.feature_key_specs
            ])
        for feature_value in me_slicing_spec.feature_values:
            tfma_slicing_spec.feature_values.update({
                ColumnSpec(feature_value.name_spec).as_string():
                feature_value.value
            })
        return tfma_slicing_spec
 def testSerializeDeserializeLegacyEvalConfig(self):
     """Checks that a pickled legacy config loads as the expected EvalConfig.

     A `LegacyConfig` is pickled into an `eval_config` TFRecord file together
     with the TFMA version, then read back with `load_eval_run`. The loaded
     config must equal an EvalConfig built from the same slices and options.
     """
     output_path = self._getTempDir()

     country_slice = slicer.SingleSliceSpec(
         columns=['country'], features=[('age', 5), ('gender', 'f')])
     interest_slice = slicer.SingleSliceSpec(
         columns=['interest'], features=[('age', 6), ('gender', 'm')])
     old_config = LegacyConfig(
         model_location='/path/to/model',
         data_location='/path/to/data',
         slice_spec=[country_slice, interest_slice],
         example_count_metric_key=None,
         example_weight_metric_key='key',
         compute_confidence_intervals=False,
         k_anonymization_count=1)

     final_dict = {
         'tfma_version': tfma_version.VERSION,
         'eval_config': old_config,
     }
     record_path = os.path.join(output_path, 'eval_config')
     with tf.io.TFRecordWriter(record_path) as w:
         w.write(pickle.dumps(final_dict))

     loaded_run = eval_config_writer.load_eval_run(output_path)
     got_eval_config, got_data_location, _, got_model_locations = loaded_run

     # Expected config: the legacy integer feature values appear as strings.
     options = config_pb2.Options()
     options.compute_confidence_intervals.value = (
         old_config.compute_confidence_intervals)
     options.min_slice_size.value = old_config.k_anonymization_count
     expected_country = config_pb2.SlicingSpec(
         feature_keys=['country'], feature_values={'age': '5', 'gender': 'f'})
     expected_interest = config_pb2.SlicingSpec(
         feature_keys=['interest'], feature_values={'age': '6', 'gender': 'm'})
     eval_config = config_pb2.EvalConfig(
         slicing_specs=[expected_country, expected_interest], options=options)

     self.assertEqual(eval_config, got_eval_config)
     self.assertEqual(old_config.data_location, got_data_location)
     self.assertLen(got_model_locations, 1)
     first_model_location = list(got_model_locations.values())[0]
     self.assertEqual(old_config.model_location, first_model_location)
Пример #6
0
def _get_tfma_slicing_specs(
    slice_features: List[List[ColumnSpec]]) -> List[config_pb2.SlicingSpec]:
  """Turn lists of feature columns into TFMA SlicingSpec values.

  Args:
    slice_features: List of slices, each being a list of feature keys.

  Returns:
    A list that starts with an empty SlicingSpec (the overall slice) followed
    by one SlicingSpec per non-empty entry of `slice_features`, in order.
  """
  # The overall slice always comes first.
  specs = [config_pb2.SlicingSpec()]
  # Empty feature lists are skipped.
  specs.extend(
      config_pb2.SlicingSpec(
          feature_keys=[column.as_string() for column in columns])
      for columns in slice_features
      if columns)
  return specs
Пример #7
0
  def testSqlSliceKeyExtractorWithMultipleSchema(self):
    """Runs the SQL slice key extractor over two batches with different columns.

    The second record batch carries an additional `extra_field` column. Both
    batches are expected to produce the same slice keys.
    """
    slice_keys_sql = """
        SELECT
          STRUCT(fixed_string)
        FROM
          example.fixed_string,
          example.fixed_int
        WHERE fixed_int = 1
        """
    eval_config = config_pb2.EvalConfig(
        slicing_specs=[config_pb2.SlicingSpec(slice_keys_sql=slice_keys_sql)])
    slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(
        eval_config)

    def shared_columns():
      # Fresh arrays on every call so the batches do not share objects.
      return [
          pa.array([[1], [1], [2]], type=pa.list_(pa.int64())),
          pa.array([[1.0], [1.0], [2.0]], type=pa.list_(pa.float64())),
          pa.array([['fixed_string1'], ['fixed_string2'], ['fixed_string3']],
                   type=pa.list_(pa.string())),
      ]

    record_batch_1 = pa.RecordBatch.from_arrays(
        shared_columns(), ['fixed_int', 'fixed_float', 'fixed_string'])
    extra_column = pa.array(
        [['extra_field1'], ['extra_field2'], ['extra_field3']],
        type=pa.list_(pa.string()))
    record_batch_2 = pa.RecordBatch.from_arrays(
        shared_columns() + [extra_column],
        ['fixed_int', 'fixed_float', 'fixed_string', 'extra_field'])

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      batches = pipeline | 'Create' >> beam.Create(
          [record_batch_1, record_batch_2], reshuffle=False)
      extracts = (
          batches
          | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts())
      result = (
          extracts
          | slice_key_extractor.stage_name >> slice_key_extractor.ptransform)

      # pylint: enable=no-value-for-parameter

      def check_result(got):
        # Rows 1 and 2 match the WHERE clause; row 3 yields no slice key.
        expected_slice_keys = [[(('fixed_string', 'fixed_string1'),)],
                               [(('fixed_string', 'fixed_string2'),)], []]
        try:
          self.assertLen(got, 2)
          for index in range(2):
            self.assertEqual(got[index][constants.SLICE_KEY_TYPES_KEY],
                             expected_slice_keys)

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result)
Пример #8
0
  def testSqlSliceKeyExtractor(self):
    """Runs the SQL slice key extractor over one batch of three tf.Examples.

    The query keeps rows whose `fixed_int` is 1, so the first two examples
    are expected to get a `fixed_string` slice key and the third none.
    """
    slice_keys_sql = """
        SELECT
          STRUCT(fixed_string)
        FROM
          example.fixed_string,
          example.fixed_int
        WHERE fixed_int = 1
        """
    eval_config = config_pb2.EvalConfig(
        slicing_specs=[config_pb2.SlicingSpec(slice_keys_sql=slice_keys_sql)])
    slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(
        eval_config)

    tfx_io = tf_example_record.TFExampleBeamRecord(
        physical_format='inmem',
        telemetry_descriptors=['test', 'component'],
        schema=_SCHEMA,
        raw_record_column_name=constants.ARROW_INPUT_COLUMN)

    # (fixed_int, fixed_float, fixed_string) for each example.
    example_values = [
        (1, 1.0, 'fixed_string1'),
        (1, 1.0, 'fixed_string2'),
        (2, 0.0, 'fixed_string3'),
    ]
    examples = [
        self._makeExample(
            fixed_int=int_value, fixed_float=float_value,
            fixed_string=string_value)
        for int_value, float_value, string_value in example_values
    ]

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      serialized = pipeline | 'Create' >> beam.Create(
          [example.SerializeToString() for example in examples],
          reshuffle=False)
      batched = serialized | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3)
      extracts = (
          batched
          | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts())
      result = (
          extracts
          | slice_key_extractor.stage_name >> slice_key_extractor.ptransform)

      # pylint: enable=no-value-for-parameter

      def check_result(got):
        expected_slice_keys = [[(('fixed_string', 'fixed_string1'),)],
                               [(('fixed_string', 'fixed_string2'),)], []]
        try:
          self.assertLen(got, 1)
          self.assertEqual(got[0][constants.SLICE_KEY_TYPES_KEY],
                           expected_slice_keys)

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result)
Пример #9
0
def validate_metrics(
    sliced_metrics: Tuple[Union[slicer.SliceKeyType, slicer.CrossSliceKeyType],
                          Dict['metric_types.MetricKey',
                               Any]], eval_config: config_pb2.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Validates the metrics of one slice against the configured thresholds.

    Thresholds are derived from `eval_config.metrics_specs`. Each metric that
    has thresholds is checked against every threshold whose slice spec applies
    to the given slice key; thresholds whose metric is absent from `metrics`
    are reported as failures with the message 'Metric not found.'. Thresholds
    keyed on the baseline model are never checked.

    Args:
      sliced_metrics: A `(slice_key, metrics)` pair, where `slice_key` is a
        single or cross slice key and `metrics` maps metric keys to values.
      eval_config: Eval config supplying the metrics specs and, optionally,
        the baseline model spec.

    Returns:
      A ValidationResult. `validation_ok` is True unless at least one failure
      was recorded, in which case the failures for this slice are appended to
      `metric_validations_per_slice`. `validation_details.slicing_details`
      receives one entry per threshold check that was actually performed.
    """
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)
    is_cross_slice = slicer.is_cross_slice_key(sliced_key)

    def _check_threshold(key: metric_types.MetricKey,
                         threshold: _ThresholdType, metric: Any) -> bool:
        """Verify a metric given its metric key and metric value."""
        metric = float(metric)
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            # Unset bounds default to +/- infinity; both bounds are inclusive.
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            return metric >= lower_bound and metric <= upper_bound
        elif isinstance(threshold, config_pb2.GenericChangeThreshold):
            # The metric value itself is used as the difference; the baseline
            # value is looked up in `metrics` under the baseline key.
            diff = metric
            metric_baseline = float(
                metrics[key.make_baseline_key(baseline_model_name)])
            # NOTE(review): with a NaN ratio every comparison below evaluates
            # to False, so the check fails whenever the baseline is ~0, even
            # if no relative bound is configured -- confirm this is intended.
            if math.isclose(metric_baseline, 0.0):
                ratio = float('nan')
            else:
                ratio = diff / metric_baseline
            # Defaults are chosen so that an unset bound never rejects.
            if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError(
                    '"UNKNOWN" direction for change threshold: {}.'.format(
                        threshold))
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            # Any other direction has already raised above, so one of these
            # two branches always returns.
            if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
                return diff <= absolute and ratio <= relative
            elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
                return diff >= absolute and ratio >= relative
        else:
            raise ValueError('Unknown threshold: {}'.format(threshold))

    def _copy_metric(metric, to):
        """Stores the metric as a double value on the proto `to`."""
        # Will add more types when more MetricValue are supported.
        to.double_value.value = float(metric)

    def _copy_threshold(threshold, to):
        """Copies the threshold into the field of `to` matching its type."""
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config_pb2.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    def _add_to_set(s, v):
        """Adds value to set. Returns true if didn't exist."""
        if v in s:
            return False
        else:
            s.add(v)
            return True

    # Empty metrics per slice is considered validated.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    # Copy of the thresholds; entries are removed as their metric is seen, so
    # whatever remains afterwards had no metric value at all.
    unchecked_thresholds = dict(thresholds)
    for metric_key, metric in metrics.items():
        if metric_key not in thresholds:
            continue
        del unchecked_thresholds[metric_key]
        # Not meaningful to check threshold for baseline model, thus always return
        # True if such threshold is configured. We also do not compare Message type
        # metrics.
        if metric_key.model_name == baseline_model_name:
            continue
        msg = ''
        # Serialized thresholds that already produced a failure for this metric.
        existing_failures = set()
        for slice_spec, threshold in thresholds[metric_key]:
            if slice_spec is not None:
                # A SlicingSpec threshold is skipped for cross slice keys and
                # for slice keys it is not applicable to.
                if (isinstance(slice_spec, config_pb2.SlicingSpec)
                        and (is_cross_slice or not slicer.SingleSliceSpec(
                            spec=slice_spec).is_slice_applicable(sliced_key))):
                    continue
                # A CrossSlicingSpec threshold is skipped for non-cross slice
                # keys and for cross slice keys it is not applicable to.
                if (isinstance(slice_spec, config_pb2.CrossSlicingSpec)
                        and (not is_cross_slice
                             or not slicer.is_cross_slice_applicable(
                                 cross_slice_key=sliced_key,
                                 cross_slicing_spec=slice_spec))):
                    continue
            elif is_cross_slice:
                # A threshold without a slice spec is not applied to cross
                # slice keys.
                continue
            try:
                check_result = _check_threshold(metric_key, threshold, metric)
            except ValueError:
                # A ValueError raised while checking is recorded as a failed
                # check with an explanatory message instead of propagating.
                msg = """
          Invalid metrics or threshold for comparison: The type of the metric
          is: {}, the metric value is: {}, and the threshold is: {}.
          """.format(type(metric), metric, threshold)
                check_result = False
            else:
                msg = ''
            if not check_result:
                # The same threshold values could be set for multiple matching slice
                # specs. Only store the first match.
                #
                # Note that hashing by SerializeToString() is only safe if used within
                # the same process.
                if not _add_to_set(existing_failures,
                                   threshold.SerializeToString()):
                    continue
                failure = validation_for_slice.failures.add()
                failure.metric_key.CopyFrom(metric_key.to_proto())
                _copy_metric(metric, failure.metric_value)
                _copy_threshold(threshold, failure.metric_threshold)
                failure.message = msg
            # Track we have completed a validation check for slice spec and metric
            slicing_details = result.validation_details.slicing_details.add()
            if slice_spec is not None:
                if isinstance(slice_spec, config_pb2.SlicingSpec):
                    slicing_details.slicing_spec.CopyFrom(slice_spec)
                else:
                    slicing_details.cross_slicing_spec.CopyFrom(slice_spec)
            else:
                # No slice spec: recorded as a default (empty) SlicingSpec.
                slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec())
            slicing_details.num_matching_slices = 1
    # All unchecked thresholds are considered failures.
    # Note: the loop variable below rebinds the outer `thresholds` name; the
    # outer mapping is not read again after this point.
    for metric_key, thresholds in unchecked_thresholds.items():
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for slice_spec, threshold in thresholds:
            if slice_spec is not None:
                # The spec kind (cross vs. single) must match the slice key kind.
                if is_cross_slice != isinstance(slice_spec,
                                                config_pb2.CrossSlicingSpec):
                    continue
                # Applicability is only tested here for cross slice keys.
                if (is_cross_slice
                        and not slicer.is_cross_slice_applicable(
                            cross_slice_key=sliced_key,
                            cross_slicing_spec=slice_spec)):
                    continue
            elif is_cross_slice:
                continue
            # The same threshold values could be set for multiple matching slice
            # specs. Only store the first match.
            #
            # Note that hashing by SerializeToString() is only safe if used within
            # the same process.
            if not _add_to_set(existing_failures,
                               threshold.SerializeToString()):
                continue
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = 'Metric not found.'
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        # Attach the slice key in the field matching its kind.
        if not is_cross_slice:
            validation_for_slice.slice_key.CopyFrom(
                slicer.serialize_slice_key(sliced_key))
        else:
            validation_for_slice.cross_slice_key.CopyFrom(
                slicer.serialize_cross_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result
Пример #10
0
 def to_proto(self) -> config_pb2.SlicingSpec:
     """Builds a SlicingSpec proto from this spec's columns and features.

     Returns:
       A SlicingSpec whose `feature_keys` are `self._columns` and whose
       `feature_values` map each feature name to `str()` of its value.
     """
     feature_values = {}
     for name, value in self._features:
         feature_values[name] = str(value)
     return config_pb2.SlicingSpec(
         feature_keys=self._columns, feature_values=feature_values)
Пример #11
0
class SlicerTest(testutil.TensorflowModelAnalysisTest, parameterized.TestCase):

  def setUp(self):
    """Runs the base class setup and turns on `longMessage` for this test."""
    super().setUp()
    self.longMessage = True  # pylint: disable=invalid-name

  def _makeFeaturesDict(self, features_dict):
    result = {}
    for key, value in features_dict.items():
      result[key] = {'node': np.array(value)}
    return result

  def assertSliceResult(self, name, features_dict, columns, features, expected):
    """Asserts the slices produced for one features dict, ignoring order.

    Builds a `SingleSliceSpec` from `columns` and `features`, slices
    `features_dict` with it, and compares the result with `expected` as an
    unordered collection.

    Args:
      name: Test case name, used only in the failure message.
      features_dict: Features dict to slice.
      columns: Columns passed to `SingleSliceSpec`.
      features: Features passed to `SingleSliceSpec`.
      expected: Slice keys expected from the slicer, in any order.
    """
    spec = slicer.SingleSliceSpec(columns=columns, features=features)
    msg = 'Test case %s: slice on columns %s, features %s' % (name, columns,
                                                              features)
    actual = slicer.get_slices_for_features_dicts([features_dict], None, [spec])
    # Call the unittest assertion directly, as the other tests in this class
    # do, instead of going through the six compatibility wrapper.
    self.assertCountEqual(expected, actual, msg)

  def testDeserializeSliceKey(self):
    """A SliceKey proto with int64, bytes and float values deserializes to tuples."""
    slice_key_text = """
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 1.0
          }
        """
    slice_key_proto = text_format.Parse(slice_key_text,
                                        metrics_for_slice_pb2.SliceKey())

    expected_slice_key = [('age', 5), ('language', 'english'), ('price', 1.0)]
    got_slice_key = slicer.deserialize_slice_key(slice_key_proto)
    self.assertCountEqual(expected_slice_key, got_slice_key)

  def testDeserializeCrossSliceKey(self):
    """A CrossSliceKey proto deserializes to a (baseline, comparison) pair."""
    cross_slice_key_text = """
          baseline_slice_key {
            single_slice_keys {
              column: 'age'
              int64_value: 5
            }
            single_slice_keys {
              column: 'language'
              bytes_value: 'english'
            }
            single_slice_keys {
              column: 'price'
              float_value: 1.0
            }
          }
          comparison_slice_key {
            single_slice_keys {
              column: 'age'
              int64_value: 8
            }
            single_slice_keys {
              column: 'language'
              bytes_value: 'hindi'
            }
          }
        """
    cross_slice_key_proto = text_format.Parse(
        cross_slice_key_text, metrics_for_slice_pb2.CrossSliceKey())

    expected_baseline = (('age', 5), ('language', 'english'), ('price', 1.0))
    expected_comparison = (('age', 8), ('language', 'hindi'))
    got_slice_key = slicer.deserialize_cross_slice_key(cross_slice_key_proto)
    self.assertCountEqual((expected_baseline, expected_comparison),
                          got_slice_key)

  def testSliceEquality(self):
    """Checks equality, hashing and inequality of SingleSliceSpec instances."""
    overall = slicer.SingleSliceSpec()
    age_column = slicer.SingleSliceSpec(columns=['age'])
    age_feature = slicer.SingleSliceSpec(features=[('age', 5)])
    age_and_gender = slicer.SingleSliceSpec(
        columns=['age'], features=[('gender', 'f')])

    # Each spec is paired with a freshly built twin so that the checks cannot
    # pass on object identity alone.
    twins = [
        (overall, slicer.SingleSliceSpec()),
        (age_column, slicer.SingleSliceSpec(columns=['age'])),
        (age_feature, slicer.SingleSliceSpec(features=[('age', 5)])),
        (age_and_gender,
         slicer.SingleSliceSpec(columns=['age'], features=[('gender', 'f')])),
    ]
    for original, twin in twins:
      self.assertEqual(original, twin)
      self.assertEqual(hash(original), hash(twin))

    distinct_pairs = [
        (overall, age_column),
        (age_column, age_feature),
        (age_column, age_and_gender),
        (age_feature, age_and_gender),
    ]
    for first, second in distinct_pairs:
      self.assertNotEqual(first, second)

    self.assertCountEqual([slicer.SingleSliceSpec()], [overall])
    rebuilt_specs = [
        slicer.SingleSliceSpec(columns=['age']),
        slicer.SingleSliceSpec(),
        slicer.SingleSliceSpec(features=[('age', 5)]),
        slicer.SingleSliceSpec(columns=['age'], features=[('gender', 'f')])
    ]
    self.assertCountEqual(
        rebuilt_specs, [age_and_gender, age_feature, overall, age_column])

  def testNoOverlappingColumns(self):
    """Building a spec with 'age' in both positional arguments must raise."""
    with self.assertRaises(ValueError):
      slicer.SingleSliceSpec(['age'], [('age', 5)])

  def testNonUTF8ValueRaisesValueError(self):
    """Slicing on a column holding a non-UTF-8 byte raises a ValueError.

    The error message is expected to match the column name.
    """
    column_name = 'column_name'
    spec = slicer.SingleSliceSpec(columns=[column_name])
    features_dict = self._makeFeaturesDict({column_name: [b'\x8a']})
    with self.assertRaisesRegex(ValueError, column_name):
      slices = slicer.get_slices_for_features_dicts([features_dict], None,
                                                    [spec])
      # Consume the result so the slicing actually runs inside the context.
      list(slices)

  def testGetSlicesForFeaturesDictUnivalent(self):
    """Slices a features dict where every feature has exactly one value."""
    features_dict = self._makeFeaturesDict({
        'gender': ['f'],
        'age': [5],
        'interest': ['cars']
    })
    # Each case: (name, columns, features, expected slice keys).
    cases = [
        ('Overall', [], [], [()]),
        ('Feature does not match', [], [('age', 99)], []),
        ('No such column', ['no_such_column'], [], []),
        ('Single column', ['age'], [], [(('age', 5),)]),
        ('Single feature', [], [('age', 5)], [(('age', 5),)]),
        ('Single feature type mismatch', [], [('age', '5')], [(('age', 5),)]),
        ('One column, one feature',
         ['gender'], [('age', 5)], [(('age', 5), ('gender', 'f'))]),
        ('Two features', ['interest', 'gender'], [('age', 5)],
         [(('age', 5), ('gender', 'f'), ('interest', 'cars'))]),
    ]  # pyformat: disable
    for case_name, case_columns, case_features, case_expected in cases:
      self.assertSliceResult(case_name, features_dict, case_columns,
                             case_features, case_expected)

  def testGetSlicesForFeaturesDictMultivalent(self):
    """Slices a features dict where some features hold several values."""
    features_dict = self._makeFeaturesDict({
        'gender': ['f'],
        'age': [5],
        'interests': ['cars', 'dogs'],
        'fruits': ['apples', 'pears']
    })

    # Each case: (name, columns, features, expected slice keys).
    cases = [
        ('One column', ['fruits'], [],
         [(('fruits', 'apples'),),
          (('fruits', 'pears'),)]),
        ('Two columns', ['fruits', 'interests'], [],
         [(('fruits', 'apples'), ('interests', 'cars')),
          (('fruits', 'apples'), ('interests', 'dogs')),
          (('fruits', 'pears'), ('interests', 'cars')),
          (('fruits', 'pears'), ('interests', 'dogs'))]),
        ('One feature', [], [('interests', 'cars')],
         [(('interests', 'cars'),)]),
        ('Two features', [], [('gender', 'f'), ('interests', 'cars')],
         [(('gender', 'f'), ('interests', 'cars'))]),
        ('One column, one feature', ['fruits'], [('interests', 'cars')],
         [(('fruits', 'apples'), ('interests', 'cars')),
          (('fruits', 'pears'), ('interests', 'cars'))]),
        ('One column, two features', ['fruits'],
         [('gender', 'f'), ('interests', 'cars')],
         [(('fruits', 'apples'), ('gender', 'f'), ('interests', 'cars')),
          (('fruits', 'pears'), ('gender', 'f'), ('interests', 'cars'))]),
        ('Two columns, one feature', ['interests', 'fruits'],
         [('gender', 'f')],
         [(('fruits', 'apples'), ('gender', 'f'), ('interests', 'cars')),
          (('fruits', 'pears'), ('gender', 'f'), ('interests', 'cars')),
          (('fruits', 'apples'), ('gender', 'f'), ('interests', 'dogs')),
          (('fruits', 'pears'), ('gender', 'f'), ('interests', 'dogs'))]),
        ('Two columns, two features', ['interests', 'fruits'],
         [('gender', 'f'), ('age', 5)],
         [(('age', 5), ('fruits', 'apples'), ('gender', 'f'),
           ('interests', 'cars')),
          (('age', 5), ('fruits', 'pears'), ('gender', 'f'),
           ('interests', 'cars')),
          (('age', 5), ('fruits', 'apples'), ('gender', 'f'),
           ('interests', 'dogs')),
          (('age', 5), ('fruits', 'pears'), ('gender', 'f'),
           ('interests', 'dogs'))]),
    ]  # pyformat: disable

    for case_name, case_columns, case_features, case_expected in cases:
      self.assertSliceResult(case_name, features_dict, case_columns,
                             case_features, case_expected)

  def testGetSlicesForFeaturesDictMultipleSingleSliceSpecs(self):
    """Slices one features dict with four specs at once.

    The spec asking for age 4 contributes nothing to the expected result,
    since the features dict holds age 5.
    """
    features_dict = self._makeFeaturesDict({
        'gender': ['f'],
        'age': [5],
        'interest': ['cars']
    })

    slice_spec = [
        slicer.SingleSliceSpec(),
        slicer.SingleSliceSpec(columns=['age']),
        slicer.SingleSliceSpec(features=[('age', 4)]),
        slicer.SingleSliceSpec(columns=['gender'], features=[('age', 5)]),
    ]
    expected = [(), (('age', 5),), (('age', 5), ('gender', 'f'))]

    actual = slicer.get_slices_for_features_dicts([features_dict], None,
                                                  slice_spec)
    self.assertCountEqual(expected, actual)

  def testStringifySliceKey(self):
    """Checks the string form produced for several slice keys."""
    # Each case: (name, slice key, expected string).
    cases = [
        ('overall', (), 'Overall'),
        ('one bytes feature', (('age_str', '5'),), 'age_str:5'),
        ('one int64 feature', (('age', 1),), 'age:1'),
        ('mixed', (('age', 1), ('gender', 'f')), 'age_X_gender:1_X_f'),
        ('more', (('age', 1), ('gender', 'f'), ('interest', 'cars')),
         'age_X_gender_X_interest:1_X_f_X_cars'),
        ('unicode', (('text', b'\xe4\xb8\xad\xe6\x96\x87'),), u'text:\u4e2d\u6587'),
    ]  # pyformat: disable
    for case_name, slice_key, expected_string in cases:
      actual_string = slicer.stringify_slice_key(slice_key)
      self.assertEqual(expected_string, actual_string, msg=case_name)

  @parameterized.named_parameters(
      ('empty_slice_keys', [], np.array([])),
      ('specific_and_overall_slice_key', [('f', 1), ()],
       np.array([('f', 1), ()], dtype=object)),
  )
  def testSliceKeysToNumpy(self, slice_keys_tuples, expected_slice_keys_array):
    """Converting slice key tuples must give the expected numpy array."""
    actual_slice_keys_array = slicer.slice_keys_to_numpy_array(
        slice_keys_tuples)
    np.testing.assert_array_equal(actual_slice_keys_array,
                                  expected_slice_keys_array)

  def testSliceKeysToNumpyOverall(self):
    """A lone overall key `()` becomes a one-element object array."""
    overall_only = [()]
    keys_array = slicer.slice_keys_to_numpy_array(overall_only)
    self.assertIsInstance(keys_array, np.ndarray)
    self.assertEqual(keys_array.dtype, object)
    self.assertEqual(keys_array.shape, (1,))
    first_key = keys_array[0]
    self.assertEqual(first_key, ())

  def testIsCrossSliceApplicable(self):
    """Checks `is_cross_slice_applicable` on matching and non-matching keys."""

    def values_spec(**feature_values):
      # SlicingSpec built from feature name -> value pairs.
      return config_pb2.SlicingSpec(feature_values=feature_values)

    def keys_spec(*feature_keys):
      # SlicingSpec built from feature key names.
      return config_pb2.SlicingSpec(feature_keys=list(feature_keys))

    def cross_spec(baseline, *comparisons):
      # CrossSlicingSpec with one baseline spec and N comparison specs.
      return config_pb2.CrossSlicingSpec(
          baseline_spec=baseline, slicing_specs=list(comparisons))

    # Each scenario: (expected, description, cross slice key, spec).
    scenarios = [
        (True, 'overall pass',
         ((), (('b', 2),)),
         cross_spec(config_pb2.SlicingSpec(), values_spec(b='2'))),
        (True, 'value pass',
         ((('a', 1),), (('b', 2),)),
         cross_spec(values_spec(a='1'), values_spec(b='2'))),
        (True, 'baseline key pass',
         ((('a', 1),), (('b', 2),)),
         cross_spec(keys_spec('a'), values_spec(b='2'))),
        (True, 'comparison key pass',
         ((('a', 1),), (('b', 2),)),
         cross_spec(values_spec(a='1'), keys_spec('b'))),
        (True, 'comparison multiple key pass',
         ((('a', 1),), (('c', 3),)),
         cross_spec(values_spec(a='1'), keys_spec('b'), keys_spec('c'))),
        (False, 'overall fail',
         ((('a', 1),), (('b', 2),)),
         cross_spec(config_pb2.SlicingSpec(), values_spec(b='2'))),
        (False, 'value fail',
         ((('a', 1),), (('b', 3),)),
         cross_spec(values_spec(a='1'), values_spec(b='2'))),
        (False, 'baseline key fail',
         ((('c', 1),), (('b', 2),)),
         cross_spec(keys_spec('a'), values_spec(b='2'))),
        (False, 'comparison key fail',
         ((('a', 1),), (('c', 3),)),
         cross_spec(values_spec(a='1'), keys_spec('b'))),
        (False, 'comparison multiple key fail',
         ((('a', 1),), (('d', 3),)),
         cross_spec(values_spec(a='1'), keys_spec('b'), keys_spec('c'))),
    ]  # pyformat: disable

    for expected, description, cross_key, spec in scenarios:
      applicable = slicer.is_cross_slice_applicable(
          cross_slice_key=cross_key, cross_slicing_spec=spec)
      self.assertEqual(expected, applicable, msg=description)

  def testGetSliceKeyType(self):
    """Classifies keys as single or cross slice keys; rejects malformed ones."""
    single_slice_keys = [
        ('overall', ()),
        ('one bytes feature', (('a', '5'),)),
        ('one int64 feature', (('a', 1),)),
        ('mixed', (('a', 1), ('b', 'f'))),
        ('more', (('a', 1), ('b', 'f'), ('c', 'cars'))),
        ('unicode', (('a', b'\xe4\xb8\xad\xe6\x96\x87'),)),
    ]  # pyformat: disable
    cross_slice_keys = [
        ('CrossSlice overall', ((), ())),
        ('CrossSlice one slice key baseline', ((('a', '5'),), ())),
        ('CrossSlice one slice key comparison', ((), (('a', 1),))),
        ('CrossSlice two simple slice key', ((('a', 1),), (('b', 'f'),))),
        ('CrossSlice two multiple slice key',
         ((('a', 1), ('b', 'f'), ('c', '11')),
          (('a2', 1), ('b', 'm'), ('c', '11')))),
    ]  # pyformat: disable

    grouped_cases = (
        (slicer.SliceKeyType, single_slice_keys),
        (slicer.CrossSliceKeyType, cross_slice_keys),
    )
    for expected_type, labelled_keys in grouped_cases:
      for label, key in labelled_keys:
        self.assertEqual(
            expected_type, slicer.get_slice_key_type(key), msg=label)

    # Note: `('a')` and `(('a'))` are plain strings, not tuples.
    malformed_keys = [
        ('Unrecognized 1: ', ('a')),
        ('Unrecognized 2: ', ('a',)),
        ('Unrecognized 3: ', ('a', 1)),
        ('Unrecognized 4: ', (('a'))),
        ('Unrecognized 5: ', (('a',))),
        ('Unrecognized 6: ', ((), (), ())),
        ('Unrecognized 7: ', ((('a', 1),), (('b', 1),), (('c', 1),))),
        ('Unrecognized 8: ', ((('a', 1),), ('b', 1))),
        ('Unrecognized 9: ', (('a', 1), (('b', 1),))),
    ]  # pyformat: disable
    for label, key in malformed_keys:
      failure_message = label + str(key)
      with self.assertRaises(TypeError, msg=failure_message):
        slicer.get_slice_key_type(key)

  @parameterized.named_parameters(
      dict(
          testcase_name='_single_slice_spec',
          slice_type=slicer.SingleSliceSpec,
          slicing_spec=config_pb2.SlicingSpec(feature_values={'a': '1'})),
      dict(
          testcase_name='_cross_slice_spec',
          slice_type=slicer.CrossSliceSpec,
          slicing_spec=config_pb2.CrossSlicingSpec(
              baseline_spec=config_pb2.SlicingSpec(),
              slicing_specs=[
                  config_pb2.SlicingSpec(feature_values={'b': '2'})
              ])),
  )
  def testDeserializeSliceSpec(self, slice_type, slicing_spec):
    """Deserializing a spec proto yields an instance of the expected class.

    Args:
      slice_type: Class the deserialized spec must be an instance of.
      slicing_spec: Proto passed to `slicer.deserialize_slice_spec`.
    """
    deserialized = slicer.deserialize_slice_spec(slicing_spec)
    self.assertIsInstance(deserialized, slice_type)

  def testDeserializeSliceSpec_hashable(self):
    """Deserialized single and cross slice specs can be used as dict keys."""
    single_slice_spec = slicer.deserialize_slice_spec(
        config_pb2.SlicingSpec(feature_values={'a': '1'}))
    # Use `config_pb2` directly (as the rest of this test does) rather than
    # reaching through `slicer.config_pb2`, which only works as long as the
    # slicer module happens to import the proto module under that name.
    cross_slice_spec = slicer.deserialize_slice_spec(
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(),
            slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})]))
    # Check either of them can be hashed and used as keys.
    slice_map = {single_slice_spec: 1, cross_slice_spec: 2}
    self.assertEqual(slice_map[single_slice_spec], 1)
    self.assertEqual(slice_map[cross_slice_spec], 2)

  def testIsSliceApplicable(self):
    """Checks `SingleSliceSpec.is_slice_applicable` for several specs/keys."""
    # Slice key shared by the first four scenarios.
    three_column_key = (
        ('column1', 'value1'),
        ('column3', 'value3'),
        ('column4', 'value4'),
    )
    # Each scenario: (label, columns, features, slice key, expected).
    scenarios = [
        ('applicable',
         ['column1'],
         [('column3', 'value3'), ('column4', 'value4')],
         three_column_key,
         True),
        ('wrongcolumns',
         ['column1', 'column2'],
         [('column3', 'value3'), ('column4', 'value4')],
         three_column_key,
         False),
        ('wrongfeatures',
         ['column1'],
         [('column3', 'value3')],
         three_column_key,
         False),
        ('nocolumns',
         [],
         [('column3', 'value3')],
         three_column_key,
         False),
        ('nofeatures', ['column1'], [], (('column1', 'value1'),), True),
        ('empty slice key', ['column1'], [('column2', 'value1')], (), False),
        ('overall', [], [], (), True),
    ]  # pyformat: disable

    for label, spec_columns, spec_features, key, expected in scenarios:
      spec = slicer.SingleSliceSpec(
          columns=spec_columns, features=spec_features)
      applicable = spec.is_slice_applicable(key)
      self.assertEqual(applicable, expected, msg=label)

  def testSliceDefaultSlice(self):
    """Fans out inputs under only the default `SingleSliceSpec()`.

    Each wrapped FPL is expected to be emitted exactly once, paired with the
    empty slice key `()`.
    """
    with beam.Pipeline() as pipeline:
      fpls = create_fpls()

      metrics = (
          pipeline
          | 'CreateTestInput' >> beam.Create(fpls)
          | 'WrapFpls' >> beam.Map(wrap_fpl)
          | 'ExtractSlices' >> slice_key_extractor.ExtractSliceKeys(
              [slicer.SingleSliceSpec()])
          | 'FanoutSlices' >> slicer.FanoutSlices())

      def check_result(got):
        try:
          self.assertLen(got, 2)
          expected_result = [
              ((), wrap_fpl(fpls[0])),
              ((), wrap_fpl(fpls[1])),
          ]
          # The two outputs may arrive in either order, so compare as
          # multisets. Unlike a bare assertTrue over a boolean expression,
          # this reports the mismatching elements when it fails.
          self.assertCountEqual(got, expected_result)
        except AssertionError as err:
          # Re-raise as a Beam assertion error, as the sibling tests do.
          raise util.BeamAssertException(err)

      util.assert_that(metrics, check_result)

  def testSliceOneSlice(self):
    """Fans out inputs under the overall spec plus a `gender` column spec."""
    slice_specs = [
        slicer.SingleSliceSpec(),
        slicer.SingleSliceSpec(columns=['gender']),
    ]

    with beam.Pipeline() as pipeline:
      fpls = create_fpls()
      wrapped_inputs = (
          pipeline
          | 'CreateTestInput' >> beam.Create(fpls, reshuffle=False)
          | 'WrapFpls' >> beam.Map(wrap_fpl))
      sliced = (
          wrapped_inputs
          | 'ExtractSlices' >> slice_key_extractor.ExtractSliceKeys(slice_specs)
          | 'FanoutSlices' >> slicer.FanoutSlices())

      def check_result(got):
        try:
          self.assertLen(got, 4)
          # Every input appears once under the overall key and once under
          # its gender key.
          overall_key = ()
          expected_pairs = [
              (overall_key, wrap_fpl(fpls[0])),
              (overall_key, wrap_fpl(fpls[1])),
              ((('gender', 'f'),), wrap_fpl(fpls[0])),
              ((('gender', 'm'),), wrap_fpl(fpls[1])),
          ]
          self.assertCountEqual(got, expected_pairs)
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(sliced, check_result)

  def testMultidimSlices(self):
    """Fans out batched extracts whose slice keys are stacked in an ndarray.

    Each input extract carries, under `constants.SLICE_KEY_TYPES_KEY`, one
    row of slice keys per example in the batch. The expected output pairs
    each distinct slice key with the extract it came from, with the slice-key
    entry removed.
    """
    data = [{
        'features': {
            'gender': [['f'], ['f']],
            'age': [[13], [13]],
            'interest': [['cars'], ['cars']]
        },
        'predictions': [[1], [1]],
        'labels': [[0], [0]],
        constants.SLICE_KEY_TYPES_KEY:
            np.array([
                slicer.slice_keys_to_numpy_array([(), (('gender', 'f'),)]),
                slicer.slice_keys_to_numpy_array([(), (('gender', 'f'),)])
            ])
    }, {
        'features': {
            'gender': [['f'], ['m']],
            'age': [[13], [10]],
            'interest': [['cars'], ['cars']]
        },
        'predictions': [[1], [1]],
        'labels': [[0], [0]],
        constants.SLICE_KEY_TYPES_KEY:
            np.array([
                slicer.slice_keys_to_numpy_array([(), (('gender', 'f'),)]),
                slicer.slice_keys_to_numpy_array([(), (('gender', 'm'),)])
            ])
    }]

    # Expected payloads are the inputs without the slice-key entry. They are
    # built as copies instead of `del`-ing from `data` inside the callback:
    # the in-place deletion raised KeyError if the callback ran a second time,
    # which would hide the original assertion failure.
    expected_extracts = [{
        key: value
        for key, value in extracts.items()
        if key != constants.SLICE_KEY_TYPES_KEY
    }
                         for extracts in data]

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'CreateTestInput' >> beam.Create(data, reshuffle=False)
          | 'FanoutSlices' >> slicer.FanoutSlices())

      def check_result(got):
        try:
          self.assertLen(got, 5)
          expected_result = [
              ((), expected_extracts[0]),
              ((), expected_extracts[1]),
              ((('gender', 'f'),), expected_extracts[0]),
              ((('gender', 'f'),), expected_extracts[1]),
              ((('gender', 'm'),), expected_extracts[1]),
          ]
          self.assertCountEqual(got, expected_result)
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result)

  def testMultidimOverallSlices(self):
    """Fans out batches that contain the overall slice key `()`.

    Covers slice keys stored as a `types.VarLenTensorValue` (rows of
    different lengths) and as a plain ndarray (rows of equal length).
    """
    data = [
        {
            constants.SLICE_KEY_TYPES_KEY:  # variable length batch case
                types.VarLenTensorValue.from_dense_rows([
                    slicer.slice_keys_to_numpy_array([(('gender', 'f'),), ()]),
                    slicer.slice_keys_to_numpy_array([()])
                ])
        },
        {
            constants.SLICE_KEY_TYPES_KEY:  # fixed length batch case
                np.array([
                    slicer.slice_keys_to_numpy_array([()]),
                    slicer.slice_keys_to_numpy_array([()])
                ])
        }
    ]

    # Expected payloads are the inputs without the slice-key entry. They are
    # built as copies instead of `del`-ing from `data` inside the callback:
    # the in-place deletion raised KeyError if the callback ran a second time,
    # which would hide the original assertion failure.
    expected_extracts = [{
        key: value
        for key, value in extracts.items()
        if key != constants.SLICE_KEY_TYPES_KEY
    }
                         for extracts in data]

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'CreateTestInput' >> beam.Create(data, reshuffle=False)
          | 'FanoutSlices' >> slicer.FanoutSlices())

      def check_result(got):
        try:
          expected_result = [
              ((('gender', 'f'),), expected_extracts[0]),
              ((), expected_extracts[0]),
              ((), expected_extracts[1]),
          ]
          self.assertCountEqual(got, expected_result)
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result)

  def testFilterOutSlices(self):
    """Runs `FilterOutSlices` with `min_slice_size=2` on two sliced values.

    The slice with count 2 must keep its value unchanged; the slice with
    count 1 must carry the error metric key in its output.
    """
    slice_key_1 = (('slice_key', 'slice1'),)
    slice_key_2 = (('slice_key', 'slice2'),)
    slice_key_3 = (('slice_key', 'slice3'),)

    values_list = [
        (slice_key_1, {'val11': 'val12'}),
        (slice_key_2, {'val21': 'val22'}),
    ]
    slice_counts_list = [
        (slice_key_1, 2),
        (slice_key_2, 1),
        (slice_key_3, 0),
    ]

    def check_output(got):
      try:
        self.assertLen(got, 2)
        # Index the (slice key, value) pairs by slice key.
        by_slice_key = dict(got)

        self.assertEqual(by_slice_key[slice_key_1], {'val11': 'val12'})
        self.assertIn(metric_keys.ERROR_METRIC, by_slice_key[slice_key_2])
      except AssertionError as err:
        raise util.BeamAssertException(err)

    with beam.Pipeline() as pipeline:
      slice_counts_pcoll = (
          pipeline | 'CreateSliceCountsPColl' >> beam.Create(slice_counts_list))
      values_pcoll = pipeline | 'CreateValuesPColl' >> beam.Create(values_list)
      output_dict = (
          values_pcoll
          | 'FilterOutSlices' >> slicer.FilterOutSlices(
              slice_counts_pcoll,
              min_slice_size=2,
              error_metric_key=metric_keys.ERROR_METRIC))
      util.assert_that(output_dict, check_output)

  @parameterized.named_parameters(
      dict(
          testcase_name='matching_single_spec',
          slice_key=(('f1', 1),),
          slice_specs=[slicer.SingleSliceSpec(features=[('f1', 1)])],
          expected_result=True),
      dict(
          testcase_name='matching_single_spec_with_float',
          slice_key=(('f1', '1.0'),),
          slice_specs=[slicer.SingleSliceSpec(features=[('f1', '1.0')])],
          expected_result=True),
      dict(
          testcase_name='non_matching_single_spec',
          slice_key=(('f1', 1),),
          slice_specs=[slicer.SingleSliceSpec(columns=['f2'])],
          expected_result=False),
      dict(
          testcase_name='matching_multiple_specs',
          slice_key=(('f1', 1),),
          slice_specs=[
              slicer.SingleSliceSpec(columns=['f1']),
              slicer.SingleSliceSpec(columns=['f2']),
          ],
          expected_result=True),
      dict(
          testcase_name='empty_specs',
          slice_key=(('f1', 1),),
          slice_specs=[],
          expected_result=False),
  )
  def testSliceKeyMatchesSliceSpecs(self, slice_key, slice_specs,
                                    expected_result):
    """Checks `slice_key_matches_slice_specs` against the expected outcome.

    Args:
      slice_key: Slice key under test.
      slice_specs: List of `SingleSliceSpec`s the key is matched against.
      expected_result: Value the matcher is expected to return.
    """
    matched = slicer.slice_key_matches_slice_specs(slice_key, slice_specs)
    self.assertEqual(expected_result, matched)
# Example #12
# 0
 def testIsCrossSliceApplicable(self):
   """Checks `is_cross_slice_applicable` for applicable and other keys."""
   # Each entry: (description, cross slice key, cross slicing spec).
   applicable_cases = [
       ('overall pass',
        ((), (('b', 2),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(),
            slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
       ('value pass',
        ((('a', 1),), (('b', 2),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
            slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
       ('baseline key pass',
        ((('a', 1),), (('b', 2),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(feature_keys=['a']),
            slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
       ('comparison key pass',
        ((('a', 1),), (('b', 2),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
            slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b'])])),
       ('comparison multiple key pass',
        ((('a', 1),), (('c', 3),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
            slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b']),
                           config_pb2.SlicingSpec(feature_keys=['c'])])),
   ]  # pyformat: disable
   not_applicable_cases = [
       ('overall fail',
        ((('a', 1),), (('b', 2),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(),
            slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
       ('value fail',
        ((('a', 1),), (('b', 3),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
            slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
       ('baseline key fail',
        ((('c', 1),), (('b', 2),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(feature_keys=['a']),
            slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
       ('comparison key fail',
        ((('a', 1),), (('c', 3),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
            slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b'])])),
       ('comparison multiple key fail',
        ((('a', 1),), (('d', 3),)),
        config_pb2.CrossSlicingSpec(
            baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
            slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b']),
                           config_pb2.SlicingSpec(feature_keys=['c'])])),
   ]  # pyformat: disable

   grouped_cases = (
       (True, applicable_cases),
       (False, not_applicable_cases),
   )
   for expected, cases in grouped_cases:
     for description, cross_key, cross_spec in cases:
       outcome = slicer.is_cross_slice_applicable(
           cross_slice_key=cross_key, cross_slicing_spec=cross_spec)
       self.assertEqual(expected, outcome, msg=description)
# Example #13
# 0
    def testMetricThresholdsFromMetricsSpecs(self):
        """Counts the thresholds collected per metric key from metrics specs.

        Builds four `MetricsSpec`s that declare thresholds at spec level
        (plain, per-slice and cross-slice) and at metric-config level, then
        checks how many thresholds `metric_thresholds_from_metrics_specs`
        returns for every expected `MetricKey`.
        """
        slice_specs = [
            config_pb2.SlicingSpec(feature_keys=['feature1']),
            config_pb2.SlicingSpec(feature_values={'feature2': 'value1'})
        ]

        # For cross slice tests.
        baseline_slice_spec = config_pb2.SlicingSpec(feature_keys=['feature3'])

        # --- Local factories; each call returns a fresh proto. ---
        def value_only():
            return config_pb2.MetricThreshold(
                value_threshold=config_pb2.GenericValueThreshold())

        def change_only():
            return config_pb2.MetricThreshold(
                change_threshold=config_pb2.GenericChangeThreshold())

        def value_and_change():
            return config_pb2.MetricThreshold(
                value_threshold=config_pb2.GenericValueThreshold(),
                change_threshold=config_pb2.GenericChangeThreshold())

        def per_slice(threshold):
            return config_pb2.PerSliceMetricThreshold(
                slicing_specs=slice_specs, threshold=threshold)

        def cross_slice(threshold):
            return config_pb2.CrossSliceMetricThreshold(
                cross_slicing_specs=[
                    config_pb2.CrossSlicingSpec(
                        baseline_spec=baseline_slice_spec,
                        slicing_specs=slice_specs)
                ],
                threshold=threshold)

        spec_level_thresholds = config_pb2.MetricsSpec(
            thresholds={
                'auc': value_only(),
                'mean/label': value_and_change(),
                'mse': change_only(),
            },
            per_slice_thresholds={
                'auc':
                config_pb2.PerSliceMetricThresholds(
                    thresholds=[per_slice(value_only())]),
                'mean/label':
                config_pb2.PerSliceMetricThresholds(
                    thresholds=[per_slice(value_and_change())]),
            },
            cross_slice_thresholds={
                'auc':
                config_pb2.CrossSliceMetricThresholds(
                    thresholds=[cross_slice(value_and_change())]),
                'mse':
                config_pb2.CrossSliceMetricThresholds(thresholds=[
                    cross_slice(change_only()),
                    # Test for duplicate cross_slicing_spec.
                    cross_slice(value_only()),
                ]),
            },
            model_names=['model_name'],
            output_names=['output_name'])

        example_count_spec = config_pb2.MetricsSpec(
            metrics=[
                config_pb2.MetricConfig(
                    class_name='ExampleCount',
                    config=json.dumps({'name': 'example_count'}),
                    threshold=value_only())
            ],
            model_names=['model_name1', 'model_name2'],
            example_weights=config_pb2.ExampleWeightOptions(unweighted=True))

        weighted_example_count_spec = config_pb2.MetricsSpec(
            metrics=[
                config_pb2.MetricConfig(
                    class_name='WeightedExampleCount',
                    config=json.dumps({'name': 'weighted_example_count'}),
                    threshold=value_only())
            ],
            model_names=['model_name1', 'model_name2'],
            output_names=['output_name1', 'output_name2'],
            example_weights=config_pb2.ExampleWeightOptions(weighted=True))

        binarized_spec = config_pb2.MetricsSpec(
            metrics=[
                config_pb2.MetricConfig(
                    class_name='MeanSquaredError',
                    config=json.dumps({'name': 'mse'}),
                    threshold=change_only()),
                config_pb2.MetricConfig(
                    class_name='MeanLabel',
                    config=json.dumps({'name': 'mean_label'}),
                    threshold=change_only(),
                    per_slice_thresholds=[per_slice(change_only())],
                    cross_slice_thresholds=[cross_slice(change_only())]),
            ],
            model_names=['model_name'],
            output_names=['output_name'],
            binarize=config_pb2.BinarizationOptions(
                class_ids={'values': [0, 1]}),
            aggregate=config_pb2.AggregationOptions(
                macro_average=True, class_weights={
                    0: 1.0,
                    1: 1.0
                }))

        metrics_specs = [
            spec_level_thresholds,
            example_count_spec,
            weighted_example_count_spec,
            binarized_spec,
        ]

        thresholds = metric_specs.metric_thresholds_from_metrics_specs(
            metrics_specs, eval_config=config_pb2.EvalConfig())

        def main_key(name, **extra):
            # Key for the 'model_name' / 'output_name' combination.
            return metric_types.MetricKey(
                name=name,
                model_name='model_name',
                output_name='output_name',
                **extra)

        def weighted_count_key(model_name, output_name):
            return metric_types.MetricKey(
                name='weighted_example_count',
                model_name=model_name,
                output_name=output_name,
                example_weighted=True)

        def macro_average():
            return metric_types.AggregationType(macro_average=True)

        # Maps each expected key to its expected number of thresholds.
        expected_keys_and_threshold_counts = {
            main_key('auc', is_diff=False, example_weighted=None): 4,
            main_key('auc', is_diff=True, example_weighted=None): 1,
            main_key('mean/label', is_diff=True, example_weighted=None): 3,
            main_key('mean/label', is_diff=False, example_weighted=None): 3,
            metric_types.MetricKey(
                name='example_count', model_name='model_name1'): 1,
            metric_types.MetricKey(
                name='example_count', model_name='model_name2'): 1,
            weighted_count_key('model_name1', 'output_name1'): 1,
            weighted_count_key('model_name1', 'output_name2'): 1,
            weighted_count_key('model_name2', 'output_name1'): 1,
            weighted_count_key('model_name2', 'output_name2'): 1,
            main_key(
                'mse', sub_key=metric_types.SubKey(class_id=0),
                is_diff=True): 1,
            main_key(
                'mse', sub_key=metric_types.SubKey(class_id=1),
                is_diff=True): 1,
            main_key('mse', is_diff=False, example_weighted=None): 1,
            main_key('mse', is_diff=True, example_weighted=None): 2,
            main_key(
                'mse', aggregation_type=macro_average(), is_diff=True): 1,
            main_key(
                'mean_label',
                sub_key=metric_types.SubKey(class_id=0),
                is_diff=True): 4,
            main_key(
                'mean_label',
                sub_key=metric_types.SubKey(class_id=1),
                is_diff=True): 4,
            main_key(
                'mean_label', aggregation_type=macro_average(),
                is_diff=True): 4,
        }

        self.assertLen(thresholds, len(expected_keys_and_threshold_counts))
        for metric_key, expected_count in (
                expected_keys_and_threshold_counts.items()):
            self.assertIn(metric_key, thresholds)
            self.assertLen(thresholds[metric_key], expected_count,
                           'failed for key {}'.format(metric_key))
# Example #14
# 0
    def testSqlSliceKeyExtractorWithTransformedFeatures(self):
        """Runs the SQL slice key extractor on extracts with transformed features.

        The extracts hold raw features plus per-model transformed features
        for two models. The expected slice keys are the `fixed_string`
        values of the first two rows and an empty row for the third.
        """
        sql_slicing_spec = config_pb2.SlicingSpec(slice_keys_sql="""
            SELECT
              STRUCT(fixed_string)
            FROM
              example.fixed_string,
              example.fixed_int
            WHERE fixed_int = 1
            """)
        eval_config = config_pb2.EvalConfig(
            model_specs=[
                config_pb2.ModelSpec(name='model1'),
                config_pb2.ModelSpec(name='model2'),
            ],
            slicing_specs=[sql_slicing_spec])
        extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config)

        raw_features = {
            'fixed_int': np.array([1, 1, 2]),
        }
        model1_features = {
            'fixed_int':
            np.array([1, 1, 2]),
            'fixed_float':
            np.array([1.0, 1.0, 0.0]),
            'fixed_string':
            np.array(['fixed_string1', 'fixed_string2', 'fixed_string3']),
        }
        model2_features = {
            'fixed_int':
            np.array([1, 1, 2]),
            'fixed_string':
            np.array(['fixed_string1', 'fixed_string2', 'fixed_string3']),
        }
        extracts = {
            constants.FEATURES_KEY: raw_features,
            constants.TRANSFORMED_FEATURES_KEY: {
                'model1': model1_features,
                'model2': model2_features,
            },
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            source = pipeline | 'CreateTestInput' >> beam.Create([extracts])
            result = source | extractor.stage_name >> extractor.ptransform

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    # One row of slice keys per input example; the third
                    # row is empty.
                    expected_slice_keys = (
                        types.VarLenTensorValue.from_dense_rows([
                            slicer_lib.slice_keys_to_numpy_array([
                                (('fixed_string', 'fixed_string1'), )
                            ]),
                            slicer_lib.slice_keys_to_numpy_array([
                                (('fixed_string', 'fixed_string2'), )
                            ]),
                            np.array([])
                        ]))
                    np.testing.assert_equal(
                        got[0][constants.SLICE_KEY_TYPES_KEY],
                        expected_slice_keys)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result)