Exemple #1
0
    def testDo(self):
        """Runs the executor with the 'test' split excluded and checks outputs.

        The statistics artifact advertises the splits train/eval/test while
        the exec properties exclude 'test'. After `Do`, the output artifact
        must list only train and eval, contain a parseable `SchemaDiff.pb`
        without anomalies for each of them, and contain nothing for 'test'.
        """
        testdata_root = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        # Inputs: pre-computed statistics for three splits plus a schema.
        stats_artifact = standard_artifacts.ExampleStatistics()
        stats_artifact.uri = os.path.join(testdata_root, 'statistics_gen')
        stats_artifact.split_names = artifact_utils.encode_split_names(
            ['train', 'eval', 'test'])

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(testdata_root, 'schema_gen')

        # Output: anomalies artifact rooted in a per-test directory.
        output_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        anomalies_artifact = standard_artifacts.ExampleAnomalies()
        anomalies_artifact.uri = os.path.join(output_root, 'output')

        inputs = {
            STATISTICS_KEY: [stats_artifact],
            SCHEMA_KEY: [schema],
        }
        # The list has to be serialized before it is handed to Do().
        properties = {EXCLUDE_SPLITS_KEY: json_utils.dumps(['test'])}
        outputs = {ANOMALIES_KEY: [anomalies_artifact]}

        validator = executor.Executor()
        validator.Do(inputs, outputs, properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         anomalies_artifact.split_names)

        def schema_diff_path(split_dir):
            return os.path.join(anomalies_artifact.uri, split_dir,
                                'SchemaDiff.pb')

        # Check example_validator outputs for the splits that were kept.
        kept_paths = [
            schema_diff_path(split_dir)
            for split_dir in ('Split-train', 'Split-eval')
        ]
        for path in kept_paths:
            self.assertTrue(fileio.exists(path))

        parsed_anomalies = []
        for path in kept_paths:
            proto = anomalies_pb2.Anomalies()
            proto.ParseFromString(io_utils.read_bytes_file(path))
            parsed_anomalies.append(proto)

        for proto in parsed_anomalies:
            self.assertEqual(0, len(proto.anomaly_info))

        # Assert 'test' split is excluded.
        self.assertFalse(fileio.exists(schema_diff_path('Split-test')))
Exemple #2
0
def compare_anomalies(output_uri: Text, expected_uri: Text) -> bool:
    """Compares anomalies files in output uri and recorded uri.

    Every file found while walking `expected_uri` is parsed as a serialized
    Anomalies proto and compared with the file at the mirrored location
    under `output_uri`. Only the `anomaly_info` field is compared.

    Args:
      output_uri: pipeline output artifact uri.
      expected_uri: recorded pipeline output artifact uri.

    Returns:
      True if every recorded file has the same `anomaly_info` as its
      counterpart under `output_uri`, False at the first mismatch.
    """
    for dir_name, _, leaf_files in fileio.walk(expected_uri):
        # Mirror the recorded directory under the output root. This relies
        # on the walked directory names starting with `expected_uri`.
        output_dir = dir_name.replace(expected_uri, output_uri, 1)
        for leaf_file in leaf_files:
            expected_file_name = os.path.join(dir_name, leaf_file)
            file_name = os.path.join(output_dir, leaf_file)

            # Both paths are already rooted at their respective uri. Joining
            # them onto the root a second time would duplicate the root for
            # relative paths and for scheme URIs such as "gs://...", which
            # os.path.join does not treat as absolute.
            anomalies = anomalies_pb2.Anomalies()
            anomalies.ParseFromString(io_utils.read_bytes_file(file_name))
            expected_anomalies = anomalies_pb2.Anomalies()
            expected_anomalies.ParseFromString(
                io_utils.read_bytes_file(expected_file_name))

            if expected_anomalies.anomaly_info != anomalies.anomaly_info:
                return False
    return True
Exemple #3
0
 def test_remove_anomaly_types_removes_diff_regions(self):
     """Removing one of two reasons also drops the anomaly's diff regions.

     The input anomaly has two reasons and two diff regions. After removing
     ENUM_TYPE_BYTES_NOT_STRING, the expected proto keeps only the remaining
     reason, takes its descriptions from that reason and has no diff regions.
     """
     # The anomaly_info has multiple diff regions.
     input_text = """
    anomaly_info {
      key: "feature_1"
      value {
           description: "Expected bytes but got string. Examples contain "
             "values missing from the schema."
           severity: ERROR
           short_description: "Multiple errors"
           diff_regions {
             removed {
               start: 1
               contents: "Test contents"
             }
           }
           diff_regions {
             added {
               start: 1
               contents: "Test contents"
             }
           }
           reason {
             type: ENUM_TYPE_BYTES_NOT_STRING
             short_description: "Bytes not string"
             description: "Expected bytes but got string."
           }
           reason {
             type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
             short_description: "Unexpected string values"
             description: "Examples contain values missing from the schema."
           }
       }
     }"""
     expected_text = """
    anomaly_info {
      key: "feature_1"
      value {
           description: "Examples contain values missing from the schema."
           severity: ERROR
           short_description: "Unexpected string values"
           reason {
             type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
             short_description: "Unexpected string values"
             description: "Examples contain values missing from the schema."
           }
       }
     }"""
     types_to_drop = {anomalies_pb2.AnomalyInfo.ENUM_TYPE_BYTES_NOT_STRING}
     actual = text_format.Parse(input_text, anomalies_pb2.Anomalies())
     expected = text_format.Parse(expected_text, anomalies_pb2.Anomalies())

     # remove_anomaly_types mutates `actual` in place.
     anomalies_util.remove_anomaly_types(actual, types_to_drop)
     compare.assertProtoEqual(self, actual, expected)
Exemple #4
0
 def test_remove_anomaly_types_does_not_change_proto(
         self, anomaly_types_to_remove, input_anomalies_proto_text):
     """Tests where remove_anomaly_types does not modify the Anomalies proto.

     Args:
       anomaly_types_to_remove: Anomaly types handed to remove_anomaly_types.
       input_anomalies_proto_text: Text-format Anomalies proto used as input.
     """
     actual = text_format.Parse(input_anomalies_proto_text,
                                anomalies_pb2.Anomalies())
     # Snapshot the parsed input so the in-place call can be compared with it.
     snapshot = anomalies_pb2.Anomalies()
     snapshot.CopyFrom(actual)

     anomalies_util.remove_anomaly_types(actual, anomaly_types_to_remove)
     compare.assertProtoEqual(self, actual, snapshot)
Exemple #5
0
 def test_remove_anomaly_types_changes_proto(self, anomaly_types_to_remove,
                                             input_anomalies_proto_text,
                                             expected_anomalies_proto_text):
     """Tests where remove_anomaly_types modifies the Anomalies proto.

     Args:
       anomaly_types_to_remove: Anomaly types handed to remove_anomaly_types.
       input_anomalies_proto_text: Text-format Anomalies proto used as input.
       expected_anomalies_proto_text: Text-format Anomalies proto the input
         must equal after the call.
     """
     actual = text_format.Parse(input_anomalies_proto_text,
                                anomalies_pb2.Anomalies())
     expected = text_format.Parse(expected_anomalies_proto_text,
                                  anomalies_pb2.Anomalies())

     # The call mutates `actual` in place.
     anomalies_util.remove_anomaly_types(actual, anomaly_types_to_remove)
     compare.assertProtoEqual(self, actual, expected)
 def test_get_anomalies_dataframe(self):
     """Two anomalous features produce a DataFrame of shape (2, 2)."""
     anomalies_text = """
 anomaly_info {
  key: "feature_1"
  value {
     description: "Expected bytes but got string."
     severity: ERROR
     short_description: "Bytes not string"
     reason {
       type: ENUM_TYPE_BYTES_NOT_STRING
       short_description: "Bytes not string"
       description: "Expected bytes but got string."
     }
   }
 }
 anomaly_info {
   key: "feature_2"
   value {
     description: "Examples contain values missing from the schema."
     severity: ERROR
     short_description: "Unexpected string values"
     reason {
       type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
       short_description: "Unexpected string values"
       description: "Examples contain values missing from the "
         "schema."
     }
   }
 }
 """
     parsed = text_format.Parse(anomalies_text, anomalies_pb2.Anomalies())
     dataframe = display_util.get_anomalies_dataframe(parsed)
     # One row per feature; one column each for the short description and
     # the long description.
     self.assertEqual(dataframe.shape, (2, 2))
Exemple #7
0
 def test_anomalies_slicer(self, input_anomalies_proto_text,
                           expected_slice_keys):
     """Checks the slice keys anomalies_slicer returns for an empty table.

     Args:
       input_anomalies_proto_text: Text-format Anomalies proto to slice by.
       expected_slice_keys: Slice keys expected back, in any order.
     """
     empty_table = pa.Table.from_arrays([])
     parsed = text_format.Parse(input_anomalies_proto_text,
                                anomalies_pb2.Anomalies())
     self.assertCountEqual(
         anomalies_util.anomalies_slicer(empty_table, parsed),
         expected_slice_keys)
 def test_anomalies_slicer(self, input_anomalies_proto_text,
                           expected_slice_keys):
     """Checks the (slice key, example) pairs produced by the slicer.

     Args:
       input_anomalies_proto_text: Text-format Anomalies proto used to build
         the slicer.
       expected_slice_keys: Slice keys expected back, in any order.
     """
     empty_batch = pa.RecordBatch.from_arrays([])
     parsed = text_format.Parse(input_anomalies_proto_text,
                                anomalies_pb2.Anomalies())
     slice_fn = anomalies_util.get_anomalies_slicer(parsed)

     seen_keys = []
     for key, sliced_example in slice_fn(empty_batch):
         # Every yielded example must equal the input batch.
         self.assertEqual(sliced_example, empty_batch)
         seen_keys.append(key)
     self.assertCountEqual(seen_keys, expected_slice_keys)
def load_anomalies_text(input_path: Text) -> anomalies_pb2.Anomalies:
    """Reads a text-format Anomalies proto from a file.

    Args:
      input_path: File path from which to load the Anomalies proto.

    Returns:
      An Anomalies protocol buffer.
    """
    # text_format.Parse fills the message passed to it and returns it.
    return text_format.Parse(
        io_util.read_file_to_string(input_path), anomalies_pb2.Anomalies())
Exemple #10
0
 def display(self, artifact: types.Artifact):
   """Displays the anomalies of every split of `artifact` in the notebook.

   For each split name decoded from `artifact.split_names`, an HTML heading
   is shown, the anomalies file located via `io_utils.get_only_uri_in_dir`
   is loaded, and the result is passed to `tfdv.display_anomalies`.

   Args:
     artifact: The artifact whose per-split anomalies are rendered.
   """
   # Import from the public IPython.display module. Importing these names
   # from IPython.core.display is deprecated.
   from IPython.display import display  # pylint: disable=g-import-not-at-top
   from IPython.display import HTML  # pylint: disable=g-import-not-at-top
   for split in artifact_utils.decode_split_names(artifact.split_names):
     display(HTML('<div><b>%r split:</b></div><br/>' % split))
     anomalies_path = io_utils.get_only_uri_in_dir(
         artifact_utils.get_split_uri([artifact], split))
     # Artifacts older than the anomalies-update version are read with the
     # text loader; newer ones are parsed as a serialized proto.
     if artifact_utils.is_artifact_version_older_than(
         artifact, artifact_utils._ARTIFACT_VERSION_FOR_ANOMALIES_UPDATE):  # pylint: disable=protected-access
       anomalies = tfdv.load_anomalies_text(anomalies_path)
     else:
       anomalies = anomalies_pb2.Anomalies()
       anomalies_bytes = io_utils.read_bytes_file(anomalies_path)
       anomalies.ParseFromString(anomalies_bytes)
     tfdv.display_anomalies(anomalies)
def load_anomalies_binary(input_path: Text) -> anomalies_pb2.Anomalies:
    """Reads a binary-serialized Anomalies proto from a file.

    Args:
      input_path: File path from which to load the Anomalies proto.

    Returns:
      An Anomalies protocol buffer.
    """
    result = anomalies_pb2.Anomalies()
    # The file is read in binary mode and parsed as a serialized proto.
    serialized = io_util.read_file_to_string(input_path, binary_mode=True)
    result.ParseFromString(serialized)
    return result
Exemple #12
0
    def test_e2e(self, stats_options, expected_stats_pbtxt,
                 expected_inferred_schema_pbtxt, schema_for_validation_pbtxt,
                 expected_anomalies_pbtxt, expected_updated_schema_pbtxt):
        """Generates statistics, then infers, validates and updates a schema.

        Args:
          stats_options: Options passed to tfdv.GenerateStatistics.
          expected_stats_pbtxt: Expected DatasetFeatureStatisticsList (text).
          expected_inferred_schema_pbtxt: Expected inferred Schema (text).
          schema_for_validation_pbtxt: Schema (text) to validate against.
          expected_anomalies_pbtxt: Expected Anomalies (text).
          expected_updated_schema_pbtxt: Expected updated Schema (text).
        """
        source = tf_sequence_example_record.TFSequenceExampleRecord(
            self._input_file, ['tfdv', 'test'])
        stats_path = os.path.join(self._output_dir, 'stats')

        # Step 1: compute statistics with Beam and write them to disk.
        with beam.Pipeline() as pipeline:
            _ = (pipeline
                 | 'TFXIORead' >> source.BeamSource()
                 | 'GenerateStats' >> tfdv.GenerateStatistics(stats_options)
                 | 'WriteStats' >> tfdv.WriteStatisticsToTFRecord(stats_path))

        computed_stats = tfdv.load_statistics(stats_path)
        expected_stats = text_format.Parse(
            expected_stats_pbtxt,
            statistics_pb2.DatasetFeatureStatisticsList())
        assert_stats_match = (
            test_util.make_dataset_feature_stats_list_proto_equal_fn(
                self, expected_stats))
        assert_stats_match([computed_stats])

        # Step 2: infer a schema from the computed statistics.
        inferred_schema = tfdv.infer_schema(computed_stats,
                                            infer_feature_shape=True)
        if hasattr(inferred_schema, 'generate_legacy_feature_spec'):
            inferred_schema.ClearField('generate_legacy_feature_spec')
        expected_inferred_schema = text_format.Parse(
            expected_inferred_schema_pbtxt, schema_pb2.Schema())
        self._assert_schema_equal(inferred_schema, expected_inferred_schema)

        # Step 3: validate the statistics against the supplied schema.
        validation_schema = text_format.Parse(schema_for_validation_pbtxt,
                                              schema_pb2.Schema())
        found_anomalies = tfdv.validate_statistics(computed_stats,
                                                   validation_schema)
        # The baseline field is cleared so it takes no part in the comparison.
        found_anomalies.ClearField('baseline')
        expected_anomalies = text_format.Parse(expected_anomalies_pbtxt,
                                               anomalies_pb2.Anomalies())
        self.assertEqual(found_anomalies, expected_anomalies)

        # Step 4: update the validation schema from the statistics.
        updated_schema = tfdv.update_schema(validation_schema,
                                            computed_stats,
                                            infer_feature_shape=False)
        expected_updated_schema = text_format.Parse(
            expected_updated_schema_pbtxt, schema_pb2.Schema())
        self._assert_schema_equal(updated_schema, expected_updated_schema)
Exemple #13
0
    def testDo(self):
        """Runs the executor on the 'eval' split and expects anomalies.

        After `Do`, the output directory must contain exactly
        `anomalies.pbtxt`, and the parsed proto must report at least one
        anomaly.
        """
        testdata_root = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        # Inputs: statistics for the eval split plus a schema.
        stats_artifact = standard_artifacts.ExampleStatistics()
        stats_artifact.uri = os.path.join(testdata_root, 'statistics_gen')
        stats_artifact.split_names = artifact_utils.encode_split_names(
            ['eval'])

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(testdata_root, 'schema_gen')

        # Output: anomalies artifact rooted in a per-test directory.
        output_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        anomalies_artifact = standard_artifacts.ExampleAnomalies()
        anomalies_artifact.uri = os.path.join(output_root, 'output')

        inputs = {
            executor.STATISTICS_KEY: [stats_artifact],
            executor.SCHEMA_KEY: [schema],
        }
        outputs = {executor.ANOMALIES_KEY: [anomalies_artifact]}
        properties = {}

        validator = executor.Executor()
        validator.Do(inputs, outputs, properties)

        self.assertEqual(['anomalies.pbtxt'],
                         tf.io.gfile.listdir(anomalies_artifact.uri))
        pbtxt_path = os.path.join(anomalies_artifact.uri, 'anomalies.pbtxt')
        parsed = io_utils.parse_pbtxt_file(pbtxt_path,
                                           anomalies_pb2.Anomalies())
        self.assertNotEqual(0, len(parsed.anomaly_info))
 def test_load_anomalies_binary(self):
     """A serialized Anomalies proto on disk is loaded back unchanged."""
     anomalies_text = """
      anomaly_info {
        key: "feature_1"
        value {
           description: "Examples contain values missing from the "
             "schema."
           severity: ERROR
           short_description: "Unexpected string values"
           reason {
             type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
             short_description: "Unexpected string values"
             description: "Examples contain values missing from the "
               "schema."
           }
         }
       }"""
     original = text_format.Parse(anomalies_text, anomalies_pb2.Anomalies())
     binary_path = os.path.join(FLAGS.test_tmpdir, 'anomalies.binpb')

     # Write the serialized proto so the loader has a binary file to read.
     with open(binary_path, 'w+b') as sink:
         sink.write(original.SerializeToString())

     loaded = anomalies_util.load_anomalies_binary(input_path=binary_path)
     self.assertEqual(original, loaded)
Exemple #15
0
    def test_do(self):
        """Runs the executor on eval statistics and expects anomalies.

        After `Do`, the output directory must contain exactly
        `anomalies.pbtxt`, and the parsed proto must report at least one
        anomaly.
        """
        testdata_root = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        # Inputs: an eval-split statistics artifact plus a schema.
        stats_artifact = types.Artifact('ExampleStatsPath', split='eval')
        stats_artifact.uri = os.path.join(testdata_root,
                                          'statistics_gen/eval/')

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(testdata_root, 'schema_gen/')

        # Output: validation result rooted in a per-test directory.
        output_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        result_artifact = standard_artifacts.ExampleValidationResult()
        result_artifact.uri = os.path.join(output_root, 'output')

        inputs = {
            'stats': [stats_artifact],
            'schema': [schema],
        }
        outputs = {'output': [result_artifact]}
        properties = {}

        validator = executor.Executor()
        validator.Do(inputs, outputs, properties)

        self.assertEqual(['anomalies.pbtxt'],
                         tf.gfile.ListDirectory(result_artifact.uri))
        pbtxt_path = os.path.join(result_artifact.uri, 'anomalies.pbtxt')
        parsed = io_utils.parse_pbtxt_file(pbtxt_path,
                                           anomalies_pb2.Anomalies())
        self.assertNotEqual(0, len(parsed.anomaly_info))
Exemple #16
0
 def test_write_load_anomalies_text(self):
     """An Anomalies proto written as text is loaded back unchanged."""
     anomalies_text = """
          anomaly_info {
            key: "feature_1"
            value {
               description: "Examples contain values missing from the "
                 "schema."
               severity: ERROR
               short_description: "Unexpected string values"
               reason {
                 type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
                 short_description: "Unexpected string values"
                 description: "Examples contain values missing from the "
                   "schema."
               }
             }
           }"""
     original = text_format.Parse(anomalies_text, anomalies_pb2.Anomalies())
     text_path = os.path.join(FLAGS.test_tmpdir, 'anomalies.pbtxt')

     # Round trip: write with the utility, then read back with its loader.
     anomalies_util.write_anomalies_text(anomalies=original,
                                         output_path=text_path)
     round_tripped = anomalies_util.load_anomalies_text(input_path=text_path)
     self.assertEqual(original, round_tripped)
Exemple #17
0
def validate_statistics(
    statistics,
    schema,
    environment=None,
    previous_statistics=None,
    serving_statistics=None,
):
    """Validates the input statistics against the provided input schema.

    The `statistics` are checked against `schema`. When `environment` is
    given, the schema is filtered by that environment before validation.
    `previous_statistics` and `serving_statistics` are optional additional
    statistics used for drift detection and skew detection, respectively.

    Args:
      statistics: A DatasetFeatureStatisticsList protocol buffer denoting the
          statistics computed over the current data. Only lists holding a
          single DatasetFeatureStatistics proto are supported.
      schema: A Schema protocol buffer.
      environment: An optional string denoting the validation environment.
          Must be one of the default environments specified in the schema.
          Environments express per-context schema variations. For example,
          a feature 'LABEL' that is required for training but missing from
          serving can be modelled by defining the environments
          ["SERVING", "TRAINING"] in the schema and associating 'LABEL'
          only with "TRAINING".
      previous_statistics: An optional DatasetFeatureStatisticsList protocol
          buffer denoting the statistics computed over earlier data (for
          example, the previous day's data). If provided, drift between
          current and previous data is detected. Drift detection is
          configured with a `drift_comparator` in the schema and is
          currently supported only for categorical features.
      serving_statistics: An optional DatasetFeatureStatisticsList protocol
          buffer denoting the statistics computed over the serving data. If
          provided, distribution skew between current and serving data is
          detected. Skew detection is configured with a `skew_comparator`
          in the schema and is currently supported only for categorical
          features.

    Returns:
      An Anomalies protocol buffer.

    Raises:
      TypeError: If any of the input arguments is not of the expected type.
      ValueError: If any statistics proto does not hold exactly one dataset,
          or if `environment` is not a default environment of the schema.
    """
    stats_list_cls = statistics_pb2.DatasetFeatureStatisticsList

    # --- Required inputs -------------------------------------------------
    if not isinstance(statistics, stats_list_cls):
        raise TypeError('statistics is of type %s, should be '
                        'a DatasetFeatureStatisticsList proto.' %
                        type(statistics).__name__)

    if len(statistics.datasets) != 1:
        raise ValueError('statistics proto contains multiple datasets. Only '
                         'one dataset is currently supported for validation.')

    if not isinstance(schema, schema_pb2.Schema):
        raise TypeError('schema is of type %s, should be a Schema proto.' %
                        type(schema).__name__)

    # A missing environment is passed on as the empty string.
    if environment is None:
        environment = ''
    elif environment not in schema.default_environment:
        raise ValueError('Environment %s not found in the schema.' %
                         environment)

    current_dataset = statistics.datasets[0]
    _check_for_unsupported_stats_fields(current_dataset, 'statistics')
    _check_for_unsupported_schema_fields(schema)

    # --- Optional drift input --------------------------------------------
    has_previous = previous_statistics is not None
    if has_previous:
        if not isinstance(previous_statistics, stats_list_cls):
            raise TypeError('previous_statistics is of type %s, should be '
                            'a DatasetFeatureStatisticsList proto.' %
                            type(previous_statistics).__name__)

        if len(previous_statistics.datasets) != 1:
            raise ValueError(
                'previous_statistics proto contains multiple datasets. '
                'Only one dataset is currently supported for validation.')

        _check_for_unsupported_stats_fields(previous_statistics.datasets[0],
                                            'previous_statistics')

    # --- Optional skew input ---------------------------------------------
    has_serving = serving_statistics is not None
    if has_serving:
        if not isinstance(serving_statistics, stats_list_cls):
            raise TypeError('serving_statistics is of type %s, should be '
                            'a DatasetFeatureStatisticsList proto.' %
                            type(serving_statistics).__name__)

        if len(serving_statistics.datasets) != 1:
            raise ValueError(
                'serving_statistics proto contains multiple datasets. '
                'Only one dataset is currently supported for validation.')

        _check_for_unsupported_stats_fields(serving_statistics.datasets[0],
                                            'serving_statistics')

    # Serialize the input protos; absent optional inputs become ''.
    schema_bytes = schema.SerializeToString()
    stats_bytes = current_dataset.SerializeToString()
    if has_previous:
        previous_bytes = previous_statistics.datasets[0].SerializeToString()
    else:
        previous_bytes = ''
    if has_serving:
        serving_bytes = serving_statistics.datasets[0].SerializeToString()
    else:
        serving_bytes = ''

    serialized_anomalies = (
        pywrap_tensorflow_data_validation.ValidateFeatureStatistics(
            tf.compat.as_bytes(stats_bytes),
            tf.compat.as_bytes(schema_bytes),
            tf.compat.as_bytes(environment),
            tf.compat.as_bytes(previous_bytes),
            tf.compat.as_bytes(serving_bytes)))

    # Parse the serialized Anomalies proto.
    anomalies = anomalies_pb2.Anomalies()
    anomalies.ParseFromString(serialized_anomalies)
    return anomalies
def validate_statistics_internal(
    statistics: statistics_pb2.DatasetFeatureStatisticsList,
    schema: schema_pb2.Schema,
    environment: Optional[Text] = None,
    previous_span_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    serving_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    previous_version_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    validation_options: Optional[vo.ValidationOptions] = None,
    enable_diff_regions: bool = False
) -> anomalies_pb2.Anomalies:
  """Validates the input statistics against the provided input schema.

  This method validates the `statistics` against the `schema`. If an optional
  `environment` is specified, the `schema` is filtered using the
  `environment` and the `statistics` is validated against the filtered schema.
  The optional `previous_span_statistics`, `serving_statistics`, and
  `previous_version_statistics` are the statistics computed over the control
  data for drift detection, skew detection, and dataset-level anomaly detection
  across versions, respectively.

  Args:
    statistics: A DatasetFeatureStatisticsList protocol buffer denoting the
       statistics computed over the current data. Validation is currently
       supported only for lists with a single DatasetFeatureStatistics proto or
       lists with multiple DatasetFeatureStatistics protos corresponding to data
       slices that include the default slice (i.e., the slice with all
       examples). If a list with multiple DatasetFeatureStatistics protos is
       used, this function will validate the statistics corresponding to the
       default slice.
    schema: A Schema protocol buffer.
    environment: An optional string denoting the validation environment.
        Must be one of the default environments specified in the schema.
        By default, validation assumes that all Examples in a pipeline adhere
        to a single schema. In some cases introducing slight schema variations
        is necessary, for instance features used as labels are required during
        training (and should be validated), but are missing during serving.
        Environments can be used to express such requirements. For example,
        assume a feature named 'LABEL' is required for training, but is expected
        to be missing from serving. This can be expressed by defining two
        distinct environments in schema: ["SERVING", "TRAINING"] and
        associating 'LABEL' only with environment "TRAINING".
    previous_span_statistics: An optional DatasetFeatureStatisticsList protocol
        buffer denoting the statistics computed over an earlier data (for
        example, previous day's data). If provided, the
        `validate_statistics_internal` method will detect if there exists drift
        between current data and previous data. Configuration for drift
        detection can be done by specifying a `drift_comparator` in the schema.
        For now drift detection is only supported for categorical features.
    serving_statistics: An optional DatasetFeatureStatisticsList protocol
        buffer denoting the statistics computed over the serving data. If
        provided, the `validate_statistics_internal` method will identify if
        there exists distribution skew between current data and serving data.
        Configuration for skew detection can be done by specifying a
        `skew_comparator` in the schema. For now skew detection is only
        supported for categorical features.
    previous_version_statistics: An optional DatasetFeatureStatisticsList
        protocol buffer denoting the statistics computed over an earlier data
        (typically, previous run's data within the same day). If provided,
        the `validate_statistics_internal` method will detect if there exists a
        change in the number of examples between current data and previous
        version data. Configuration for such dataset-wide anomaly detection can
        be done by specifying a `num_examples_version_comparator` in the schema.
    validation_options: Optional input used to specify the options of this
        validation.
    enable_diff_regions: Specifies whether to include a comparison between the
        existing schema and the fixed schema in the Anomalies protocol buffer
        output.

  Returns:
    An Anomalies protocol buffer.

  Raises:
    TypeError: If any of the input arguments is not of the expected type.
    ValueError: If the input statistics proto contains multiple datasets, none
        of which corresponds to the default slice, or if `environment` is not
        one of the schema's default environments.
  """
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError(
        'statistics is of type %s, should be '
        'a DatasetFeatureStatisticsList proto.' % type(statistics).__name__)

  # This will raise an exception if there are multiple datasets, none of which
  # corresponds to the default slice.
  dataset_statistics = _get_default_dataset_statistics(statistics)

  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)

  if environment is not None:
    if environment not in schema.default_environment:
      raise ValueError('Environment %s not found in the schema.' % environment)
  else:
    # A missing environment is passed on as the empty string.
    environment = ''

  _check_for_unsupported_stats_fields(dataset_statistics, 'statistics')

  if previous_span_statistics is not None:
    if not isinstance(
        previous_span_statistics, statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError(
          'previous_span_statistics is of type %s, should be '
          'a DatasetFeatureStatisticsList proto.'
          % type(previous_span_statistics).__name__)

    previous_span_dataset_statistics = _get_default_dataset_statistics(
        previous_span_statistics)
    # Label the input with its actual parameter name, as the sibling calls
    # do. NOTE(review): presumably the label only appears in diagnostics
    # emitted by _check_for_unsupported_stats_fields - confirm.
    _check_for_unsupported_stats_fields(previous_span_dataset_statistics,
                                        'previous_span_statistics')

  if serving_statistics is not None:
    if not isinstance(
        serving_statistics, statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError(
          'serving_statistics is of type %s, should be '
          'a DatasetFeatureStatisticsList proto.'
          % type(serving_statistics).__name__)

    serving_dataset_statistics = _get_default_dataset_statistics(
        serving_statistics)
    _check_for_unsupported_stats_fields(serving_dataset_statistics,
                                        'serving_statistics')

  if previous_version_statistics is not None:
    if not isinstance(previous_version_statistics,
                      statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError('previous_version_statistics is of type %s, should be '
                      'a DatasetFeatureStatisticsList proto.' %
                      type(previous_version_statistics).__name__)

    previous_version_dataset_statistics = _get_default_dataset_statistics(
        previous_version_statistics)
    _check_for_unsupported_stats_fields(previous_version_dataset_statistics,
                                        'previous_version_statistics')

  # Serialize the input protos. Optional inputs that were not supplied are
  # passed on as the empty string.
  serialized_schema = schema.SerializeToString()
  serialized_stats = dataset_statistics.SerializeToString()
  serialized_previous_span_stats = (
      previous_span_dataset_statistics.SerializeToString()
      if previous_span_statistics is not None else '')
  serialized_serving_stats = (
      serving_dataset_statistics.SerializeToString()
      if serving_statistics is not None else '')
  serialized_previous_version_stats = (
      previous_version_dataset_statistics.SerializeToString()
      if previous_version_statistics is not None else '')

  # Translate the features_needed mapping of the options into its proto form:
  # one entry per path, carrying the comment of every reason.
  features_needed_pb = validation_metadata_pb2.FeaturesNeededProto()
  if validation_options is not None and validation_options.features_needed:
    for path, reason_list in validation_options.features_needed.items():
      path_and_reason_feature_need = (
          features_needed_pb.path_and_reason_feature_need.add())
      path_and_reason_feature_need.path.CopyFrom(path.to_proto())
      for reason in reason_list:
        r = path_and_reason_feature_need.reason_feature_needed.add()
        r.comment = reason.comment

  serialized_features_needed = features_needed_pb.SerializeToString()

  # Without validation options an empty ValidationConfig is serialized.
  validation_config = validation_config_pb2.ValidationConfig()
  if validation_options is not None:
    validation_config.new_features_are_warnings = (
        validation_options.new_features_are_warnings)
    for override in validation_options.severity_overrides:
      validation_config.severity_overrides.append(override)
  serialized_validation_config = validation_config.SerializeToString()

  anomalies_proto_string = (
      pywrap_tensorflow_data_validation.ValidateFeatureStatistics(
          tf.compat.as_bytes(serialized_stats),
          tf.compat.as_bytes(serialized_schema),
          tf.compat.as_bytes(environment),
          tf.compat.as_bytes(serialized_previous_span_stats),
          tf.compat.as_bytes(serialized_serving_stats),
          tf.compat.as_bytes(serialized_previous_version_stats),
          tf.compat.as_bytes(serialized_features_needed),
          tf.compat.as_bytes(serialized_validation_config),
          enable_diff_regions))

  # Parse the serialized Anomalies proto.
  result = anomalies_pb2.Anomalies()
  result.ParseFromString(anomalies_proto_string)
  return result
 def test_get_anomalies_dataframe_no_anomalies(self):
     """An empty Anomalies proto yields a DataFrame of shape (0, 2)."""
     dataframe = display_util.get_anomalies_dataframe(
         anomalies_pb2.Anomalies())
     self.assertEqual(dataframe.shape, (0, 2))