def parse_schema_file(schema_path):
    # type: (str) -> Schema
    """Read a schema file and return the proto object."""
    assert file_io.file_exists(schema_path), \
        "File not found: {}".format(schema_path)
    schema = Schema()
    with file_io.FileIO(schema_path, "rb") as f:
        schema.ParseFromString(f.read())
    return schema
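# Usage sketch for the function above: load a binary-serialized Schema proto
# and list its feature names. The path below is illustrative only.
example_schema = parse_schema_file("/tmp/schema.pb")
print([feature.name for feature in example_schema.feature])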
def _replicate_schema_for_sliced_validation(
    schema: schema_pb2.Schema, slice_names: Set[str]) -> schema_pb2.Schema:
  """Replicates features in a schema with prepended slice names."""
  # HasField returns a bool, so check it directly rather than comparing to None.
  if schema.HasField('dataset_constraints'):
    logging.error('DatasetConstraints will not be validated per-slice.')
  result = schema_pb2.Schema()
  result.string_domain.extend(schema.string_domain)
  result.float_domain.extend(schema.float_domain)
  result.int_domain.extend(schema.int_domain)
  for slice_name in slice_names:
    for feature in schema.feature:
      new_feature = result.feature.add()
      new_feature.CopyFrom(feature)
      new_feature.name = _prepend_slice_name(slice_name, feature.name)
    for sparse_feature in schema.sparse_feature:
      new_sparse_feature = result.sparse_feature.add()
      new_sparse_feature.CopyFrom(sparse_feature)
      new_sparse_feature.name = _prepend_slice_name(
          slice_name, sparse_feature.name)
    for weighted_feature in schema.weighted_feature:
      new_weighted_feature = result.weighted_feature.add()
      new_weighted_feature.CopyFrom(weighted_feature)
      new_weighted_feature.name = _prepend_slice_name(
          slice_name, weighted_feature.name)
  return result
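# Usage sketch (assumes this module's context, where `_prepend_slice_name` and
# the tensorflow_metadata protos are importable). A single-feature schema
# replicated across two slices yields two renamed copies of the feature.
from tensorflow_metadata.proto.v0 import schema_pb2

_example_schema = schema_pb2.Schema()
_age = _example_schema.feature.add()
_age.name = 'age'
_age.type = schema_pb2.INT

_replicated = _replicate_schema_for_sliced_validation(
    _example_schema, {'slice_a', 'slice_b'})
assert len(_replicated.feature) == 2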
def update_schema(schema: schema_pb2.Schema,
                  statistics: statistics_pb2.DatasetFeatureStatisticsList,
                  infer_feature_shape: Optional[bool] = True,
                  max_string_domain_size: Optional[int] = 100
                 ) -> schema_pb2.Schema:
  """Updates input schema to conform to the input statistics.

  Args:
    schema: A Schema protocol buffer.
    statistics: A DatasetFeatureStatisticsList protocol buffer. Schema
      inference is currently supported only for lists with a single
      DatasetFeatureStatistics proto or lists with multiple
      DatasetFeatureStatistics protos corresponding to data slices that
      include the default slice (i.e., the slice with all examples). If a
      list with multiple DatasetFeatureStatistics protos is used, this
      function will update the schema to conform to the statistics
      corresponding to the default slice.
    infer_feature_shape: A boolean to indicate whether the shapes of the
      features should be inferred from the statistics.
    max_string_domain_size: Maximum size of the domain of a string feature in
      order to be interpreted as a categorical feature.

  Returns:
    A Schema protocol buffer.

  Raises:
    TypeError: If the input argument is not of the expected type.
    ValueError: If the input statistics proto contains multiple datasets, none
      of which corresponds to the default slice.
  """
  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError(
        'statistics is of type %s, should be '
        'a DatasetFeatureStatisticsList proto.' % type(statistics).__name__)

  # This will raise an exception if there are multiple datasets, none of which
  # corresponds to the default slice.
  dataset_statistics = _get_default_dataset_statistics(statistics)
  _check_for_unsupported_stats_fields(dataset_statistics, 'statistics')

  schema_proto_string = pywrap_tensorflow_data_validation.UpdateSchema(
      tf.compat.as_bytes(schema.SerializeToString()),
      tf.compat.as_bytes(dataset_statistics.SerializeToString()),
      max_string_domain_size)

  # Parse the serialized Schema proto.
  result = schema_pb2.Schema()
  result.ParseFromString(schema_proto_string)

  # TODO(b/113605666): Push this shape inference logic into example validation
  # code.
  if infer_feature_shape:
    _infer_shape(result)

  return result
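# Usage sketch with the TFDV public API wrapper around `update_schema`. The
# file paths are illustrative assumptions; the data is expected to contain
# serialized tf.Examples.
import tensorflow_data_validation as tfdv

_train_stats = tfdv.generate_statistics_from_tfrecord(
    data_location='train.tfrecord')
_schema = tfdv.infer_schema(_train_stats)

# Fold statistics from newly arrived data into the existing schema.
_new_stats = tfdv.generate_statistics_from_tfrecord(
    data_location='day2.tfrecord')
_updated_schema = tfdv.update_schema(
    _schema, _new_stats, max_string_domain_size=100)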
def _validate_sql(sql_query: Text, schema: schema_pb2.Schema):
  arrow_schema = example_coder.ExamplesToRecordBatchDecoder(
      schema.SerializeToString()).ArrowSchema()
  formatted_query = slicing_util.format_slice_sql_query(sql_query)
  try:
    sql_util.RecordBatchSQLSliceQuery(formatted_query, arrow_schema)
  except Exception as e:  # pylint: disable=broad-except
    raise ValueError('One of the slice SQL queries (%s) raised an exception: '
                     '%s.' % (sql_query, repr(e)))
def parse_schema_txt_file(schema_path):
    # type: (str) -> Schema
    """Parse a tf.metadata Schema txt file into its in-memory representation."""
    assert file_io.file_exists(schema_path), "File not found: {}".format(
        schema_path)
    schema = Schema()
    schema_text = file_io.read_file_to_string(schema_path)
    google.protobuf.text_format.Parse(schema_text, schema)
    return schema
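# Usage sketch: round-trip a schema through the text format. The path is
# illustrative; `Schema` and `file_io` are the same imports this module uses.
from google.protobuf import text_format

_txt_schema = Schema()
_txt_feature = _txt_schema.feature.add()
_txt_feature.name = 'age'

file_io.write_string_to_file('/tmp/schema.pbtxt',
                             text_format.MessageToString(_txt_schema))
_parsed = parse_schema_txt_file('/tmp/schema.pbtxt')
assert _parsed.feature[0].name == 'age'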
def apply(cls, feature_spec):
    # type: (Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]]) -> Schema  # noqa: E501
    """Main entry point."""
    schema_proto = Schema()
    for k, v in feature_spec.items():
        if isinstance(v, tf.SparseFeature):
            cls._add_sparse_feature_to_proto(schema_proto, k, v)
        else:
            cls._add_feature_to_proto(schema_proto, k, v)
    return schema_proto
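# Usage sketch. `apply` above is a classmethod, so it is assumed to live on a
# converter class; `FeatureSpecToSchema` below is a placeholder for whatever
# that class is actually named. The feature spec uses the TF1-style types from
# the type comment above.
import tensorflow as tf

_feature_spec = {
    'age': tf.FixedLenFeature([], tf.int64),
    'tags': tf.VarLenFeature(tf.string),
}
_schema_proto = FeatureSpecToSchema.apply(_feature_spec)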
def update_schema(
    schema: schema_pb2.Schema,
    statistics: statistics_pb2.DatasetFeatureStatisticsList,
    infer_feature_shape: Optional[bool] = True,
    max_string_domain_size: Optional[int] = 100) -> schema_pb2.Schema:
  """Updates input schema to conform to the input statistics.

  Args:
    schema: A Schema protocol buffer.
    statistics: A DatasetFeatureStatisticsList protocol buffer. Schema
      inference is currently only supported for lists with a single
      DatasetFeatureStatistics proto.
    infer_feature_shape: A boolean to indicate whether the shapes of the
      features should be inferred from the statistics.
    max_string_domain_size: Maximum size of the domain of a string feature in
      order to be interpreted as a categorical feature.

  Returns:
    A Schema protocol buffer.

  Raises:
    TypeError: If the input argument is not of the expected type.
    ValueError: If the input statistics proto does not have only one dataset.
  """
  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)

  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError('statistics is of type %s, should be '
                    'a DatasetFeatureStatisticsList proto.' %
                    type(statistics).__name__)

  if len(statistics.datasets) != 1:
    raise ValueError('Only statistics proto with one dataset is currently '
                     'supported for inferring schema.')

  _check_for_unsupported_stats_fields(statistics.datasets[0], 'statistics')

  schema_proto_string = pywrap_tensorflow_data_validation.UpdateSchema(
      tf.compat.as_bytes(schema.SerializeToString()),
      tf.compat.as_bytes(statistics.datasets[0].SerializeToString()),
      max_string_domain_size)

  # Parse the serialized Schema proto.
  result = schema_pb2.Schema()
  result.ParseFromString(schema_proto_string)

  # TODO(b/113605666): Push this shape inference logic into example validation
  # code.
  if infer_feature_shape:
    _infer_shape(result)

  return result
def lists_to_partitions(
    datasets: DatasetFeatureStatisticsList,
    schema: Schema,
    examples: types.Artifact,
    partitions: List[List[Text]],
) -> List[Partition]:
  result: List[Partition] = []
  for p in partitions:
    name = '__'.join(p)
    partition = Partition(
        name=name,
        statistics=DatasetFeatureStatisticsList(datasets=[
            DatasetFeatureStatistics(
                name=name,
                num_examples=min([
                    getattr(feature, feature.WhichOneof(
                        'stats')).common_stats.num_non_missing
                    for feature in filter(
                        lambda f: feature_name(f) in p, dataset.features)
                ]),
                weighted_num_examples=0,
                features=list(
                    filter(lambda f: feature_name(f) in p, dataset.features)),
            ) for dataset in datasets.datasets
        ]),
        example_splits=[
            ExampleSplit(split=split, uri=os.path.join(examples.uri, split))
            for split in artifact_utils.decode_split_names(
                examples.split_names)
        ],
        schema=Schema(
            feature=filter(lambda f: f.name in p, schema.feature),
            sparse_feature=filter(lambda f: f.name in p,
                                  schema.sparse_feature),
            weighted_feature=filter(lambda f: f.name in p,
                                    schema.weighted_feature),
            string_domain=filter(lambda f: f.name in p, schema.string_domain),
            float_domain=filter(lambda f: f.name in p, schema.float_domain),
            int_domain=filter(lambda f: f.name in p, schema.int_domain),
            default_environment=schema.default_environment))
    result.append(partition)
  return result
def validate_statistics_internal(
    statistics: statistics_pb2.DatasetFeatureStatisticsList,
    schema: schema_pb2.Schema,
    environment: Optional[Text] = None,
    previous_span_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    serving_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    previous_version_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    validation_options: Optional[vo.ValidationOptions] = None,
    enable_diff_regions: bool = False
) -> anomalies_pb2.Anomalies:
  """Validates the input statistics against the provided input schema.

  This method validates the `statistics` against the `schema`. If an optional
  `environment` is specified, the `schema` is filtered using the `environment`
  and the `statistics` is validated against the filtered schema. The optional
  `previous_span_statistics`, `serving_statistics`, and
  `previous_version_statistics` are the statistics computed over the control
  data for drift detection, skew detection, and dataset-level anomaly
  detection across versions, respectively.

  Args:
    statistics: A DatasetFeatureStatisticsList protocol buffer denoting the
      statistics computed over the current data. Validation is currently
      supported only for lists with a single DatasetFeatureStatistics proto or
      lists with multiple DatasetFeatureStatistics protos corresponding to
      data slices that include the default slice (i.e., the slice with all
      examples). If a list with multiple DatasetFeatureStatistics protos is
      used, this function will validate the statistics corresponding to the
      default slice.
    schema: A Schema protocol buffer.
    environment: An optional string denoting the validation environment. Must
      be one of the default environments specified in the schema. By default,
      validation assumes that all Examples in a pipeline adhere to a single
      schema. In some cases introducing slight schema variations is necessary,
      for instance features used as labels are required during training (and
      should be validated), but are missing during serving. Environments can
      be used to express such requirements. For example, assume a feature
      named 'LABEL' is required for training, but is expected to be missing
      from serving. This can be expressed by defining two distinct
      environments in the schema: ["SERVING", "TRAINING"] and associating
      'LABEL' only with environment "TRAINING".
    previous_span_statistics: An optional DatasetFeatureStatisticsList
      protocol buffer denoting the statistics computed over earlier data (for
      example, the previous day's data). If provided, the
      `validate_statistics_internal` method will detect if there exists drift
      between current data and previous data. Configuration for drift
      detection can be done by specifying a `drift_comparator` in the schema.
      For now drift detection is only supported for categorical features.
    serving_statistics: An optional DatasetFeatureStatisticsList protocol
      buffer denoting the statistics computed over the serving data. If
      provided, the `validate_statistics_internal` method will identify if
      there exists distribution skew between current data and serving data.
      Configuration for skew detection can be done by specifying a
      `skew_comparator` in the schema. For now skew detection is only
      supported for categorical features.
    previous_version_statistics: An optional DatasetFeatureStatisticsList
      protocol buffer denoting the statistics computed over earlier data
      (typically, the previous run's data within the same day). If provided,
      the `validate_statistics_internal` method will detect if there exists a
      change in the number of examples between current data and previous
      version data. Configuration for such dataset-wide anomaly detection can
      be done by specifying a `num_examples_version_comparator` in the schema.
    validation_options: Optional input used to specify the options of this
      validation.
    enable_diff_regions: Specifies whether to include a comparison between the
      existing schema and the fixed schema in the Anomalies protocol buffer
      output.

  Returns:
    An Anomalies protocol buffer.

  Raises:
    TypeError: If any of the input arguments is not of the expected type.
    ValueError: If the input statistics proto contains multiple datasets, none
      of which corresponds to the default slice.
  """
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError(
        'statistics is of type %s, should be '
        'a DatasetFeatureStatisticsList proto.' % type(statistics).__name__)

  # This will raise an exception if there are multiple datasets, none of which
  # corresponds to the default slice.
  dataset_statistics = _get_default_dataset_statistics(statistics)

  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)

  if environment is not None:
    if environment not in schema.default_environment:
      raise ValueError('Environment %s not found in the schema.' % environment)
  else:
    environment = ''

  _check_for_unsupported_stats_fields(dataset_statistics, 'statistics')

  if previous_span_statistics is not None:
    if not isinstance(previous_span_statistics,
                      statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError(
          'previous_span_statistics is of type %s, should be '
          'a DatasetFeatureStatisticsList proto.' %
          type(previous_span_statistics).__name__)

    previous_dataset_statistics = _get_default_dataset_statistics(
        previous_span_statistics)
    _check_for_unsupported_stats_fields(previous_dataset_statistics,
                                        'previous_statistics')

  if serving_statistics is not None:
    if not isinstance(serving_statistics,
                      statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError(
          'serving_statistics is of type %s, should be '
          'a DatasetFeatureStatisticsList proto.' %
          type(serving_statistics).__name__)

    serving_dataset_statistics = _get_default_dataset_statistics(
        serving_statistics)
    _check_for_unsupported_stats_fields(serving_dataset_statistics,
                                        'serving_statistics')

  if previous_version_statistics is not None:
    if not isinstance(previous_version_statistics,
                      statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError('previous_version_statistics is of type %s, should be '
                      'a DatasetFeatureStatisticsList proto.' %
                      type(previous_version_statistics).__name__)

    previous_version_dataset_statistics = _get_default_dataset_statistics(
        previous_version_statistics)
    _check_for_unsupported_stats_fields(previous_version_dataset_statistics,
                                        'previous_version_statistics')

  # Serialize the input protos.
  serialized_schema = schema.SerializeToString()
  serialized_stats = dataset_statistics.SerializeToString()
  serialized_previous_span_stats = (
      previous_dataset_statistics.SerializeToString()
      if previous_span_statistics is not None else '')
  serialized_serving_stats = (
      serving_dataset_statistics.SerializeToString()
      if serving_statistics is not None else '')
  serialized_previous_version_stats = (
      previous_version_dataset_statistics.SerializeToString()
      if previous_version_statistics is not None else '')

  features_needed_pb = validation_metadata_pb2.FeaturesNeededProto()
  if validation_options is not None and validation_options.features_needed:
    for path, reason_list in validation_options.features_needed.items():
      path_and_reason_feature_need = (
          features_needed_pb.path_and_reason_feature_need.add())
      path_and_reason_feature_need.path.CopyFrom(path.to_proto())
      for reason in reason_list:
        r = path_and_reason_feature_need.reason_feature_needed.add()
        r.comment = reason.comment

  serialized_features_needed = features_needed_pb.SerializeToString()

  validation_config = validation_config_pb2.ValidationConfig()
  if validation_options is not None:
    validation_config.new_features_are_warnings = (
        validation_options.new_features_are_warnings)
    for override in validation_options.severity_overrides:
      validation_config.severity_overrides.append(override)
  serialized_validation_config = validation_config.SerializeToString()

  anomalies_proto_string = (
      pywrap_tensorflow_data_validation.ValidateFeatureStatistics(
          tf.compat.as_bytes(serialized_stats),
          tf.compat.as_bytes(serialized_schema),
          tf.compat.as_bytes(environment),
          tf.compat.as_bytes(serialized_previous_span_stats),
          tf.compat.as_bytes(serialized_serving_stats),
          tf.compat.as_bytes(serialized_previous_version_stats),
          tf.compat.as_bytes(serialized_features_needed),
          tf.compat.as_bytes(serialized_validation_config),
          enable_diff_regions))

  # Parse the serialized Anomalies proto.
  result = anomalies_pb2.Anomalies()
  result.ParseFromString(anomalies_proto_string)
  return result
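# Sketch of the environment setup described in the docstring above: 'LABEL' is
# required for TRAINING but absent from SERVING, so validating serving data
# with environment='SERVING' will not flag the missing label. The statistics
# object here is an empty placeholder; a real call would pass statistics
# computed over the serving data.
from tensorflow_metadata.proto.v0 import schema_pb2, statistics_pb2

_env_schema = schema_pb2.Schema()
_env_schema.default_environment.extend(['TRAINING', 'SERVING'])
_label = _env_schema.feature.add()
_label.name = 'LABEL'
_label.type = schema_pb2.BYTES
_label.in_environment.append('TRAINING')

_serving_stats = statistics_pb2.DatasetFeatureStatisticsList(
    datasets=[statistics_pb2.DatasetFeatureStatistics(num_examples=0)])
_anomalies = validate_statistics_internal(
    statistics=_serving_stats, schema=_env_schema, environment='SERVING')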
def _may_be_set_legacy_flag(schema: schema_pb2.Schema):
  """Sets legacy flag to False if it exists."""
  if getattr(schema, 'generate_legacy_feature_spec', None) is not None:
    schema.generate_legacy_feature_spec = False
def validate_statistics(
    statistics: statistics_pb2.DatasetFeatureStatisticsList,
    schema: schema_pb2.Schema,
    environment: Optional[str] = None,
    previous_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    serving_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
) -> anomalies_pb2.Anomalies:
  """Validates the input statistics against the provided input schema.

  This method validates the `statistics` against the `schema`. If an optional
  `environment` is specified, the `schema` is filtered using the `environment`
  and the `statistics` is validated against the filtered schema. The optional
  `previous_statistics` and `serving_statistics` are the statistics computed
  over the control data for drift detection and skew detection, respectively.

  Args:
    statistics: A DatasetFeatureStatisticsList protocol buffer denoting the
      statistics computed over the current data. Validation is currently only
      supported for lists with a single DatasetFeatureStatistics proto.
    schema: A Schema protocol buffer.
    environment: An optional string denoting the validation environment. Must
      be one of the default environments specified in the schema. By default,
      validation assumes that all Examples in a pipeline adhere to a single
      schema. In some cases introducing slight schema variations is necessary,
      for instance features used as labels are required during training (and
      should be validated), but are missing during serving. Environments can
      be used to express such requirements. For example, assume a feature
      named 'LABEL' is required for training, but is expected to be missing
      from serving. This can be expressed by defining two distinct
      environments in the schema: ["SERVING", "TRAINING"] and associating
      'LABEL' only with environment "TRAINING".
    previous_statistics: An optional DatasetFeatureStatisticsList protocol
      buffer denoting the statistics computed over earlier data (for example,
      the previous day's data). If provided, the `validate_statistics` method
      will detect if there exists drift between current data and previous
      data. Configuration for drift detection can be done by specifying a
      `drift_comparator` in the schema. For now drift detection is only
      supported for categorical features.
    serving_statistics: An optional DatasetFeatureStatisticsList protocol
      buffer denoting the statistics computed over the serving data. If
      provided, the `validate_statistics` method will identify if there exists
      distribution skew between current data and serving data. Configuration
      for skew detection can be done by specifying a `skew_comparator` in the
      schema. For now skew detection is only supported for categorical
      features.

  Returns:
    An Anomalies protocol buffer.

  Raises:
    TypeError: If any of the input arguments is not of the expected type.
    ValueError: If the input statistics proto does not have only one dataset.
  """
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError('statistics is of type %s, should be '
                    'a DatasetFeatureStatisticsList proto.' %
                    type(statistics).__name__)

  if len(statistics.datasets) != 1:
    raise ValueError('statistics proto contains multiple datasets. Only '
                     'one dataset is currently supported for validation.')

  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)

  if environment is not None:
    if environment not in schema.default_environment:
      raise ValueError('Environment %s not found in the schema.' % environment)
  else:
    environment = ''

  _check_for_unsupported_stats_fields(statistics.datasets[0], 'statistics')
  _check_for_unsupported_schema_fields(schema)

  if previous_statistics is not None:
    if not isinstance(previous_statistics,
                      statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError('previous_statistics is of type %s, should be '
                      'a DatasetFeatureStatisticsList proto.' %
                      type(previous_statistics).__name__)

    if len(previous_statistics.datasets) != 1:
      raise ValueError(
          'previous_statistics proto contains multiple datasets. '
          'Only one dataset is currently supported for validation.')

    _check_for_unsupported_stats_fields(previous_statistics.datasets[0],
                                        'previous_statistics')

  if serving_statistics is not None:
    if not isinstance(serving_statistics,
                      statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError('serving_statistics is of type %s, should be '
                      'a DatasetFeatureStatisticsList proto.' %
                      type(serving_statistics).__name__)

    if len(serving_statistics.datasets) != 1:
      raise ValueError(
          'serving_statistics proto contains multiple datasets. '
          'Only one dataset is currently supported for validation.')

    _check_for_unsupported_stats_fields(serving_statistics.datasets[0],
                                        'serving_statistics')

  # Serialize the input protos.
  serialized_schema = schema.SerializeToString()
  serialized_stats = statistics.datasets[0].SerializeToString()
  serialized_previous_stats = (
      previous_statistics.datasets[0].SerializeToString()
      if previous_statistics is not None else '')
  serialized_serving_stats = (
      serving_statistics.datasets[0].SerializeToString()
      if serving_statistics is not None else '')
  # TODO(b/138589321): Update API to support validation against previous
  # version stats.
  serialized_previous_version_stats = ''

  anomalies_proto_string = (
      pywrap_tensorflow_data_validation.ValidateFeatureStatistics(
          tf.compat.as_bytes(serialized_stats),
          tf.compat.as_bytes(serialized_schema),
          tf.compat.as_bytes(environment),
          tf.compat.as_bytes(serialized_previous_stats),
          tf.compat.as_bytes(serialized_serving_stats),
          tf.compat.as_bytes(serialized_previous_version_stats)))

  # Parse the serialized Anomalies proto.
  result = anomalies_pb2.Anomalies()
  result.ParseFromString(anomalies_proto_string)
  return result
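# Sketch of drift detection between two spans of data via the TFDV public API.
# File paths and the 'payment_type' feature name are illustrative assumptions.
import tensorflow_data_validation as tfdv

_day1_stats = tfdv.generate_statistics_from_tfrecord(
    data_location='day1.tfrecord')
_day2_stats = tfdv.generate_statistics_from_tfrecord(
    data_location='day2.tfrecord')
_schema = tfdv.infer_schema(_day1_stats)

# Configure an L-infinity drift comparator on a categorical feature; drift
# detection is currently only supported for categorical features (see above).
_payment_type = tfdv.get_feature(_schema, 'payment_type')
_payment_type.drift_comparator.infinity_norm.threshold = 0.01

_drift_anomalies = tfdv.validate_statistics(
    statistics=_day2_stats, schema=_schema, previous_statistics=_day1_stats)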