def testDo(self):
  """Runs the executor with the 'test' split excluded and checks its outputs.

  The statistics artifact declares 'train', 'eval' and 'test' splits, and
  'test' is listed in EXCLUDE_SPLITS_KEY. The test asserts that:
    * the output artifact's split_names are exactly ['train', 'eval'];
    * each of those splits has a SchemaDiff.pb with no anomaly_info entries;
    * no SchemaDiff.pb is written for the excluded 'test' split.
  """
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')

  eval_stats_artifact = standard_artifacts.ExampleStatistics()
  eval_stats_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
  eval_stats_artifact.split_names = artifact_utils.encode_split_names(
      ['train', 'eval', 'test'])

  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  validation_output = standard_artifacts.ExampleAnomalies()
  validation_output.uri = os.path.join(output_data_dir, 'output')

  input_dict = {
      STATISTICS_KEY: [eval_stats_artifact],
      SCHEMA_KEY: [schema_artifact],
  }

  exec_properties = {
      # List needs to be serialized before being passed into Do function.
      EXCLUDE_SPLITS_KEY: json_utils.dumps(['test'])
  }

  output_dict = {
      ANOMALIES_KEY: [validation_output],
  }

  example_validator_executor = executor.Executor()
  example_validator_executor.Do(input_dict, output_dict, exec_properties)

  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      validation_output.split_names)

  # Check example_validator outputs: every non-excluded split gets a
  # SchemaDiff.pb, and the test data is expected to be anomaly-free.
  for split in ('train', 'eval'):
    anomalies_path = os.path.join(validation_output.uri, 'Split-' + split,
                                  'SchemaDiff.pb')
    self.assertTrue(fileio.exists(anomalies_path))
    anomalies = anomalies_pb2.Anomalies()
    anomalies.ParseFromString(io_utils.read_bytes_file(anomalies_path))
    self.assertEqual(0, len(anomalies.anomaly_info))

  # Assert 'test' split is excluded: no anomalies file is written for it.
  test_anomalies_path = os.path.join(validation_output.uri, 'Split-test',
                                     'SchemaDiff.pb')
  self.assertFalse(fileio.exists(test_anomalies_path))
def compare_anomalies(output_uri: Text, expected_uri: Text) -> bool:
  """Compares anomalies files in output uri and recorded uri.

  Every file found (recursively) under `expected_uri` is paired with the file
  at the same relative location under `output_uri`. Both files are parsed as
  binary Anomalies protos, and their `anomaly_info` maps are compared.

  Args:
    output_uri: pipeline output artifact uri.
    expected_uri: recorded pipeline output artifact uri.

  Returns:
    boolean whether anomalies are same.
  """
  for dir_name, _, leaf_files in fileio.walk(expected_uri):
    # `dir_name` is rooted at `expected_uri`; swap only that leading prefix to
    # get the mirrored directory under `output_uri`.
    output_dir_name = dir_name.replace(expected_uri, output_uri, 1)
    for leaf_file in leaf_files:
      # These are already complete paths (they include the base uri). They
      # must not be joined with the base uri again: that would duplicate the
      # prefix for relative paths and for scheme uris such as gs://.
      expected_file_name = os.path.join(dir_name, leaf_file)
      file_name = os.path.join(output_dir_name, leaf_file)

      anomalies = anomalies_pb2.Anomalies()
      anomalies.ParseFromString(io_utils.read_bytes_file(file_name))
      expected_anomalies = anomalies_pb2.Anomalies()
      expected_anomalies.ParseFromString(
          io_utils.read_bytes_file(expected_file_name))
      if expected_anomalies.anomaly_info != anomalies.anomaly_info:
        return False
  return True
def test_remove_anomaly_types_removes_diff_regions(self):
  """Removing one of several reasons rewrites descriptions and drops diffs.

  The input anomaly has two reasons and two diff_regions. After the
  ENUM_TYPE_BYTES_NOT_STRING reason is removed, the expected proto keeps
  only the remaining reason and its descriptions, with no diff_regions.
  """
  anomaly_types_to_remove = {
      anomalies_pb2.AnomalyInfo.ENUM_TYPE_BYTES_NOT_STRING,
  }
  # The anomaly_info has multiple diff regions.
  anomalies = text_format.Parse(
      """
      anomaly_info {
        key: "feature_1"
        value {
          description: "Expected bytes but got string. Examples contain "
                       "values missing from the schema."
          severity: ERROR
          short_description: "Multiple errors"
          diff_regions {
            removed {
              start: 1
              contents: "Test contents"
            }
          }
          diff_regions {
            added {
              start: 1
              contents: "Test contents"
            }
          }
          reason {
            type: ENUM_TYPE_BYTES_NOT_STRING
            short_description: "Bytes not string"
            description: "Expected bytes but got string."
          }
          reason {
            type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
            short_description: "Unexpected string values"
            description: "Examples contain values missing from the schema."
          }
        }
      }""", anomalies_pb2.Anomalies())
  expected_result = text_format.Parse(
      """
      anomaly_info {
        key: "feature_1"
        value {
          description: "Examples contain values missing from the schema."
          severity: ERROR
          short_description: "Unexpected string values"
          reason {
            type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
            short_description: "Unexpected string values"
            description: "Examples contain values missing from the schema."
          }
        }
      }""", anomalies_pb2.Anomalies())
  anomalies_util.remove_anomaly_types(anomalies, anomaly_types_to_remove)
  compare.assertProtoEqual(self, anomalies, expected_result)
def test_remove_anomaly_types_does_not_change_proto(
    self, anomaly_types_to_remove, input_anomalies_proto_text):
  """Tests where remove_anomaly_types leaves the Anomalies proto untouched.

  A copy of the parsed input is taken before the call, and the mutated
  proto is then compared against that snapshot.
  """
  actual = anomalies_pb2.Anomalies()
  text_format.Parse(input_anomalies_proto_text, actual)

  # Snapshot the input before remove_anomaly_types mutates it in place.
  snapshot = anomalies_pb2.Anomalies()
  snapshot.CopyFrom(actual)

  anomalies_util.remove_anomaly_types(actual, anomaly_types_to_remove)
  compare.assertProtoEqual(self, actual, snapshot)
def test_remove_anomaly_types_changes_proto(self, anomaly_types_to_remove,
                                            input_anomalies_proto_text,
                                            expected_anomalies_proto_text):
  """Tests where remove_anomaly_types rewrites the Anomalies proto.

  The input text proto is parsed, filtered in place, and then compared with
  the separately parsed expected text proto.
  """
  actual = anomalies_pb2.Anomalies()
  text_format.Parse(input_anomalies_proto_text, actual)
  expected = anomalies_pb2.Anomalies()
  text_format.Parse(expected_anomalies_proto_text, expected)

  anomalies_util.remove_anomaly_types(actual, anomaly_types_to_remove)
  compare.assertProtoEqual(self, actual, expected)
def test_get_anomalies_dataframe(self):
  """Two anomalous features produce a 2x2 DataFrame.

  There is one row per feature and two columns, one for the short and one
  for the long description.
  """
  anomalies_text = """
      anomaly_info {
        key: "feature_1"
        value {
          description: "Expected bytes but got string."
          severity: ERROR
          short_description: "Bytes not string"
          reason {
            type: ENUM_TYPE_BYTES_NOT_STRING
            short_description: "Bytes not string"
            description: "Expected bytes but got string."
          }
        }
      }
      anomaly_info {
        key: "feature_2"
        value {
          description: "Examples contain values missing from the schema."
          severity: ERROR
          short_description: "Unexpected string values"
          reason {
            type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
            short_description: "Unexpected string values"
            description: "Examples contain values missing from the "
                         "schema."
          }
        }
      }
      """
  parsed_anomalies = anomalies_pb2.Anomalies()
  text_format.Parse(anomalies_text, parsed_anomalies)

  frame = display_util.get_anomalies_dataframe(parsed_anomalies)
  self.assertEqual(frame.shape, (2, 2))
def test_anomalies_slicer(self, input_anomalies_proto_text,
                          expected_slice_keys):
  """anomalies_slicer yields the expected slice keys for an empty table."""
  empty_table = pa.Table.from_arrays([])
  parsed = anomalies_pb2.Anomalies()
  text_format.Parse(input_anomalies_proto_text, parsed)
  self.assertCountEqual(
      anomalies_util.anomalies_slicer(empty_table, parsed),
      expected_slice_keys)
def test_anomalies_slicer(self, input_anomalies_proto_text,
                          expected_slice_keys):
  """The slicer built from the anomalies yields the expected slice keys.

  Every emitted example must equal the (empty) input record batch.
  """
  record_batch = pa.RecordBatch.from_arrays([])
  parsed = anomalies_pb2.Anomalies()
  text_format.Parse(input_anomalies_proto_text, parsed)
  slice_fn = anomalies_util.get_anomalies_slicer(parsed)

  emitted_keys = []
  for key, sliced_batch in slice_fn(record_batch):
    self.assertEqual(sliced_batch, record_batch)
    emitted_keys.append(key)
  self.assertCountEqual(emitted_keys, expected_slice_keys)
def load_anomalies_text(input_path: Text) -> anomalies_pb2.Anomalies:
  """Reads an Anomalies proto serialized in text format from a file.

  Args:
    input_path: Path of the text-format Anomalies file to read.

  Returns:
    The parsed Anomalies protocol buffer.
  """
  # text_format.Parse fills the supplied message in place and returns it.
  return text_format.Parse(
      io_util.read_file_to_string(input_path), anomalies_pb2.Anomalies())
def display(self, artifact: types.Artifact):
  """Renders the anomalies of each split of `artifact` in a notebook.

  For every split, an HTML header naming the split is shown, and then the
  split's anomalies are rendered via tfdv.display_anomalies. Artifacts older
  than `_ARTIFACT_VERSION_FOR_ANOMALIES_UPDATE` are read as text-format
  protos. Newer ones are read as binary protos.

  Args:
    artifact: The anomalies artifact to display.
  """
  # IPython is only needed when rendering, so it is imported lazily. The
  # public `IPython.display` module is used because importing `display` from
  # `IPython.core.display` is deprecated in recent IPython releases.
  from IPython.display import display  # pylint: disable=g-import-not-at-top
  from IPython.display import HTML  # pylint: disable=g-import-not-at-top

  for split in artifact_utils.decode_split_names(artifact.split_names):
    display(HTML('<div><b>%r split:</b></div><br/>' % split))
    anomalies_path = io_utils.get_only_uri_in_dir(
        artifact_utils.get_split_uri([artifact], split))
    if artifact_utils.is_artifact_version_older_than(
        artifact, artifact_utils._ARTIFACT_VERSION_FOR_ANOMALIES_UPDATE):  # pylint: disable=protected-access
      # Older artifacts store the anomalies as a text-format proto.
      anomalies = tfdv.load_anomalies_text(anomalies_path)
    else:
      anomalies = anomalies_pb2.Anomalies()
      anomalies_bytes = io_utils.read_bytes_file(anomalies_path)
      anomalies.ParseFromString(anomalies_bytes)
    tfdv.display_anomalies(anomalies)
def load_anomalies_binary(input_path: Text) -> anomalies_pb2.Anomalies:
  """Reads an Anomalies proto serialized in binary wire format from a file.

  Args:
    input_path: Path of the binary Anomalies file to read.

  Returns:
    The parsed Anomalies protocol buffer.
  """
  serialized = io_util.read_file_to_string(input_path, binary_mode=True)
  result = anomalies_pb2.Anomalies()
  result.ParseFromString(serialized)
  return result
def test_e2e(self, stats_options, expected_stats_pbtxt,
             expected_inferred_schema_pbtxt, schema_for_validation_pbtxt,
             expected_anomalies_pbtxt, expected_updated_schema_pbtxt):
  """End-to-end check over the sequence-example input file.

  Statistics are generated with a Beam pipeline and written to disk. The
  test then compares four results against expected text protos: the
  statistics, the inferred schema, the anomalies from validating against
  `schema_for_validation_pbtxt`, and the updated schema.
  """
  record_source = tf_sequence_example_record.TFSequenceExampleRecord(
      self._input_file, ['tfdv', 'test'])
  stats_path = os.path.join(self._output_dir, 'stats')
  with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | 'TFXIORead' >> record_source.BeamSource()
        | 'GenerateStats' >> tfdv.GenerateStatistics(stats_options)
        | 'WriteStats' >> tfdv.WriteStatisticsToTFRecord(stats_path))

  # Statistics generation.
  generated_stats = tfdv.load_statistics(stats_path)
  expected_stats = text_format.Parse(
      expected_stats_pbtxt, statistics_pb2.DatasetFeatureStatisticsList())
  stats_match = test_util.make_dataset_feature_stats_list_proto_equal_fn(
      self, expected_stats)
  stats_match([generated_stats])

  # Schema inference.
  inferred = tfdv.infer_schema(generated_stats, infer_feature_shape=True)
  if hasattr(inferred, 'generate_legacy_feature_spec'):
    inferred.ClearField('generate_legacy_feature_spec')
  self._assert_schema_equal(
      inferred,
      text_format.Parse(expected_inferred_schema_pbtxt, schema_pb2.Schema()))

  # Validation against the provided schema (baseline is ignored).
  validation_schema = text_format.Parse(schema_for_validation_pbtxt,
                                        schema_pb2.Schema())
  found_anomalies = tfdv.validate_statistics(generated_stats,
                                             validation_schema)
  found_anomalies.ClearField('baseline')
  expected_anomalies = text_format.Parse(expected_anomalies_pbtxt,
                                         anomalies_pb2.Anomalies())
  self.assertEqual(found_anomalies, expected_anomalies)

  # Schema update.
  updated = tfdv.update_schema(validation_schema, generated_stats,
                               infer_feature_shape=False)
  self._assert_schema_equal(
      updated,
      text_format.Parse(expected_updated_schema_pbtxt, schema_pb2.Schema()))
def testDo(self):
  """Runs the executor on the 'eval' statistics split and checks its output.

  The output directory must contain exactly 'anomalies.pbtxt', and the
  parsed proto must report at least one anomaly.
  """
  testdata_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')

  stats = standard_artifacts.ExampleStatistics()
  stats.uri = os.path.join(testdata_dir, 'statistics_gen')
  stats.split_names = artifact_utils.encode_split_names(['eval'])

  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(testdata_dir, 'schema_gen')

  output_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  anomalies_artifact = standard_artifacts.ExampleAnomalies()
  anomalies_artifact.uri = os.path.join(output_root, 'output')

  executor.Executor().Do(
      {
          executor.STATISTICS_KEY: [stats],
          executor.SCHEMA_KEY: [schema],
      },
      {executor.ANOMALIES_KEY: [anomalies_artifact]},
      {},
  )

  self.assertEqual(['anomalies.pbtxt'],
                   tf.io.gfile.listdir(anomalies_artifact.uri))
  parsed = io_utils.parse_pbtxt_file(
      os.path.join(anomalies_artifact.uri, 'anomalies.pbtxt'),
      anomalies_pb2.Anomalies())
  self.assertNotEqual(0, len(parsed.anomaly_info))
def test_load_anomalies_binary(self):
  """A binary-serialized Anomalies proto round-trips through the loader."""
  original = text_format.Parse(
      """
      anomaly_info {
        key: "feature_1"
        value {
          description: "Examples contain values missing from the "
                       "schema."
          severity: ERROR
          short_description: "Unexpected string values"
          reason {
            type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
            short_description: "Unexpected string values"
            description: "Examples contain values missing from the "
                         "schema."
          }
        }
      }""", anomalies_pb2.Anomalies())
  binary_path = os.path.join(FLAGS.test_tmpdir, 'anomalies.binpb')
  with open(binary_path, 'w+b') as out_file:
    out_file.write(original.SerializeToString())

  reloaded = anomalies_util.load_anomalies_binary(input_path=binary_path)
  self.assertEqual(original, reloaded)
def test_do(self):
  """Runs the executor on the 'eval' statistics and checks the output file.

  The output directory must contain exactly 'anomalies.pbtxt', and the
  parsed proto must report at least one anomaly.
  """
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')

  eval_stats_artifact = types.Artifact('ExampleStatsPath', split='eval')
  eval_stats_artifact.uri = os.path.join(source_data_dir,
                                         'statistics_gen/eval/')

  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen/')

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  validation_output = standard_artifacts.ExampleValidationResult()
  validation_output.uri = os.path.join(output_data_dir, 'output')

  input_dict = {
      'stats': [eval_stats_artifact],
      'schema': [schema_artifact],
  }
  output_dict = {
      'output': [validation_output],
  }

  exec_properties = {}

  example_validator_executor = executor.Executor()
  example_validator_executor.Do(input_dict, output_dict, exec_properties)
  # tf.gfile.ListDirectory is TF1-only; tf.io.gfile.listdir works in TF 1.x
  # (1.13+) and 2.x.
  self.assertEqual(['anomalies.pbtxt'],
                   tf.io.gfile.listdir(validation_output.uri))
  anomalies = io_utils.parse_pbtxt_file(
      os.path.join(validation_output.uri, 'anomalies.pbtxt'),
      anomalies_pb2.Anomalies())
  self.assertNotEqual(0, len(anomalies.anomaly_info))
def test_write_load_anomalies_text(self):
  """Writing then loading a text-format Anomalies proto round-trips it."""
  original = text_format.Parse(
      """
      anomaly_info {
        key: "feature_1"
        value {
          description: "Examples contain values missing from the "
                       "schema."
          severity: ERROR
          short_description: "Unexpected string values"
          reason {
            type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
            short_description: "Unexpected string values"
            description: "Examples contain values missing from the "
                         "schema."
          }
        }
      }""", anomalies_pb2.Anomalies())
  text_path = os.path.join(FLAGS.test_tmpdir, 'anomalies.pbtxt')
  anomalies_util.write_anomalies_text(anomalies=original,
                                      output_path=text_path)
  self.assertEqual(
      original, anomalies_util.load_anomalies_text(input_path=text_path))
def _check_single_dataset_statistics(statistics, arg_name):
  """Checks that `statistics` is a stats list holding exactly one dataset.

  Args:
    statistics: The value passed for the argument named `arg_name`.
    arg_name: Argument name used as the prefix of error messages.

  Returns:
    The single DatasetFeatureStatistics proto contained in `statistics`.

  Raises:
    TypeError: If `statistics` is not a DatasetFeatureStatisticsList proto.
    ValueError: If `statistics` does not contain exactly one dataset.
  """
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError('%s is of type %s, should be '
                    'a DatasetFeatureStatisticsList proto.' %
                    (arg_name, type(statistics).__name__))
  num_datasets = len(statistics.datasets)
  if num_datasets != 1:
    # An empty list must not be reported as containing "multiple" datasets.
    problem = 'no datasets' if num_datasets == 0 else 'multiple datasets'
    raise ValueError('%s proto contains %s. Only one dataset is currently '
                     'supported for validation.' % (arg_name, problem))
  return statistics.datasets[0]


def validate_statistics(
    statistics,
    schema,
    environment=None,
    previous_statistics=None,
    serving_statistics=None,
):
  """Validates the input statistics against the provided input schema.

  This method validates the `statistics` against the `schema`. If an optional
  `environment` is specified, the `schema` is filtered using the
  `environment` and the `statistics` is validated against the filtered schema.
  The optional `previous_statistics` and `serving_statistics` are the
  statistics computed over the treatment data for drift- and skew-detection,
  respectively.

  Args:
    statistics: A DatasetFeatureStatisticsList protocol buffer denoting the
      statistics computed over the current data. Validation is currently only
      supported for lists with a single DatasetFeatureStatistics proto.
    schema: A Schema protocol buffer.
    environment: An optional string denoting the validation environment.
      Must be one of the default environments specified in the schema.
      By default, validation assumes that all Examples in a pipeline adhere
      to a single schema. In some cases introducing slight schema variations
      is necessary, for instance features used as labels are required during
      training (and should be validated), but are missing during serving.
      Environments can be used to express such requirements. For example,
      assume a feature named 'LABEL' is required for training, but is expected
      to be missing from serving. This can be expressed by defining two
      distinct environments in schema: ["SERVING", "TRAINING"] and
      associating 'LABEL' only with environment "TRAINING".
    previous_statistics: An optional DatasetFeatureStatisticsList protocol
      buffer denoting the statistics computed over an earlier data (for
      example, previous day's data). If provided, the `validate_statistics`
      method will detect if there exists drift between current data and
      previous data. Configuration for drift detection can be done by
      specifying a `drift_comparator` in the schema. For now drift detection
      is only supported for categorical features.
    serving_statistics: An optional DatasetFeatureStatisticsList protocol
      buffer denoting the statistics computed over the serving data. If
      provided, the `validate_statistics` method will identify if there exists
      distribution skew between current data and serving data. Configuration
      for skew detection can be done by specifying a `skew_comparator` in the
      schema. For now skew detection is only supported for categorical
      features.

  Returns:
    An Anomalies protocol buffer.

  Raises:
    TypeError: If any of the input arguments is not of the expected type.
    ValueError: If any provided statistics proto does not contain exactly one
      dataset, or if `environment` is not one of the schema's default
      environments.
  """
  dataset_statistics = _check_single_dataset_statistics(statistics,
                                                        'statistics')
  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)

  if environment is not None:
    if environment not in schema.default_environment:
      raise ValueError('Environment %s not found in the schema.' % environment)
  else:
    # The native validator expects an empty string for "no environment".
    environment = ''

  _check_for_unsupported_stats_fields(dataset_statistics, 'statistics')
  _check_for_unsupported_schema_fields(schema)

  previous_dataset_statistics = None
  if previous_statistics is not None:
    previous_dataset_statistics = _check_single_dataset_statistics(
        previous_statistics, 'previous_statistics')
    _check_for_unsupported_stats_fields(previous_dataset_statistics,
                                        'previous_statistics')

  serving_dataset_statistics = None
  if serving_statistics is not None:
    serving_dataset_statistics = _check_single_dataset_statistics(
        serving_statistics, 'serving_statistics')
    _check_for_unsupported_stats_fields(serving_dataset_statistics,
                                        'serving_statistics')

  # Serialize the input protos. Absent optional statistics are passed to the
  # native validator as empty strings.
  serialized_schema = schema.SerializeToString()
  serialized_stats = dataset_statistics.SerializeToString()
  serialized_previous_stats = (
      previous_dataset_statistics.SerializeToString()
      if previous_dataset_statistics is not None else '')
  serialized_serving_stats = (
      serving_dataset_statistics.SerializeToString()
      if serving_dataset_statistics is not None else '')

  anomalies_proto_string = (
      pywrap_tensorflow_data_validation.ValidateFeatureStatistics(
          tf.compat.as_bytes(serialized_stats),
          tf.compat.as_bytes(serialized_schema),
          tf.compat.as_bytes(environment),
          tf.compat.as_bytes(serialized_previous_stats),
          tf.compat.as_bytes(serialized_serving_stats)))

  # Parse the serialized Anomalies proto.
  result = anomalies_pb2.Anomalies()
  result.ParseFromString(anomalies_proto_string)
  return result
def validate_statistics_internal(
    statistics: statistics_pb2.DatasetFeatureStatisticsList,
    schema: schema_pb2.Schema,
    environment: Optional[Text] = None,
    previous_span_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    serving_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    previous_version_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    validation_options: Optional[vo.ValidationOptions] = None,
    enable_diff_regions: bool = False
) -> anomalies_pb2.Anomalies:
  """Validates the input statistics against the provided input schema.

  `statistics` is checked against `schema`. If `environment` is given, the
  schema is first filtered to that environment. The optional
  `previous_span_statistics`, `serving_statistics` and
  `previous_version_statistics` are control data for drift detection, skew
  detection, and cross-version dataset-level anomaly detection,
  respectively.

  Args:
    statistics: A DatasetFeatureStatisticsList protocol buffer for the
      current data. It must contain either a single DatasetFeatureStatistics
      proto, or several slice protos that include the default slice (the
      slice with all examples). In the multi-slice case, only the default
      slice is validated.
    schema: A Schema protocol buffer.
    environment: An optional string naming the validation environment. It
      must be one of the schema's default environments. Environments let a
      single schema express slight variations: for example, a 'LABEL' feature
      that is required in "TRAINING" but absent in "SERVING" can be
      associated only with the "TRAINING" environment.
    previous_span_statistics: An optional DatasetFeatureStatisticsList for
      earlier data (for example, the previous day's). If provided, drift
      between current and previous data is detected, configured through a
      `drift_comparator` in the schema. For now drift detection is only
      supported for categorical features.
    serving_statistics: An optional DatasetFeatureStatisticsList for the
      serving data. If provided, distribution skew between current and
      serving data is detected, configured through a `skew_comparator` in the
      schema. For now skew detection is only supported for categorical
      features.
    previous_version_statistics: An optional DatasetFeatureStatisticsList
      for earlier data (typically the previous run's data within the same
      day). If provided, a change in the number of examples between the
      current and previous version is detected, configured through a
      `num_examples_version_comparator` in the schema.
    validation_options: Optional input used to specify the options of this
      validation.
    enable_diff_regions: Specifies whether to include a comparison between
      the existing schema and the fixed schema in the Anomalies protocol
      buffer output.

  Returns:
    An Anomalies protocol buffer.

  Raises:
    TypeError: If any of the input arguments is not of the expected type.
    ValueError: If a statistics proto contains multiple datasets, none of
      which corresponds to the default slice, or if `environment` is not in
      the schema's default environments.
  """

  def _checked_default_slice(stats_list, arg_name, unsupported_label):
    # Type-checks an optional control stats list, extracts its default slice
    # and rejects unsupported fields, in that order.
    if not isinstance(stats_list,
                      statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError('%s is of type %s, should be '
                      'a DatasetFeatureStatisticsList proto.' %
                      (arg_name, type(stats_list).__name__))
    default_slice = _get_default_dataset_statistics(stats_list)
    _check_for_unsupported_stats_fields(default_slice, unsupported_label)
    return default_slice

  def _bytes_or_empty(proto):
    # Absent optional inputs are passed to the native validator as ''.
    return '' if proto is None else proto.SerializeToString()

  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError(
        'statistics is of type %s, should be '
        'a DatasetFeatureStatisticsList proto.' % type(statistics).__name__)

  # Raises if there are multiple datasets and none is the default slice.
  current_slice = _get_default_dataset_statistics(statistics)

  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)

  if environment is None:
    environment = ''
  elif environment not in schema.default_environment:
    raise ValueError('Environment %s not found in the schema.' % environment)

  _check_for_unsupported_stats_fields(current_slice, 'statistics')

  previous_span_slice = None
  if previous_span_statistics is not None:
    previous_span_slice = _checked_default_slice(
        previous_span_statistics, 'previous_span_statistics',
        'previous_statistics')

  serving_slice = None
  if serving_statistics is not None:
    serving_slice = _checked_default_slice(
        serving_statistics, 'serving_statistics', 'serving_statistics')

  previous_version_slice = None
  if previous_version_statistics is not None:
    previous_version_slice = _checked_default_slice(
        previous_version_statistics, 'previous_version_statistics',
        'previous_version_statistics')

  features_needed_proto = validation_metadata_pb2.FeaturesNeededProto()
  requested_features = (
      validation_options.features_needed
      if validation_options is not None else None)
  if requested_features:
    for feature_path, reasons in requested_features.items():
      need_entry = features_needed_proto.path_and_reason_feature_need.add()
      need_entry.path.CopyFrom(feature_path.to_proto())
      for reason in reasons:
        need_entry.reason_feature_needed.add().comment = reason.comment

  config = validation_config_pb2.ValidationConfig()
  if validation_options is not None:
    config.new_features_are_warnings = (
        validation_options.new_features_are_warnings)
    config.severity_overrides.extend(validation_options.severity_overrides)

  native_args = [
      tf.compat.as_bytes(payload) for payload in (
          current_slice.SerializeToString(),
          schema.SerializeToString(),
          environment,
          _bytes_or_empty(previous_span_slice),
          _bytes_or_empty(serving_slice),
          _bytes_or_empty(previous_version_slice),
          features_needed_proto.SerializeToString(),
          config.SerializeToString(),
      )
  ]
  serialized_anomalies = (
      pywrap_tensorflow_data_validation.ValidateFeatureStatistics(
          *native_args, enable_diff_regions))

  return anomalies_pb2.Anomalies.FromString(serialized_anomalies)
def test_get_anomalies_dataframe_no_anomalies(self):
  """An empty Anomalies proto yields a DataFrame of shape (0, 2)."""
  empty_anomalies = anomalies_pb2.Anomalies()
  frame = display_util.get_anomalies_dataframe(empty_anomalies)
  self.assertEqual(frame.shape, (0, 2))