Example #1
def parse_schema_file(schema_path):  # type: (str) -> Schema
    """
    Read a schema file and return the proto object.
    """
    assert file_io.file_exists(schema_path), "File not found: {}".format(schema_path)
    schema = Schema()
    with file_io.FileIO(schema_path, "rb") as f:
        schema.ParseFromString(f.read())
    return schema
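A hedged usage sketch of the helper above; the path is a placeholder, not a real file:

# Hypothetical usage: parse a binary-serialized Schema and inspect it.
schema = parse_schema_file('/tmp/schema.pb')
for feature in schema.feature:
    print(feature.name, feature.type)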
def _replicate_schema_for_sliced_validation(
        schema: schema_pb2.Schema, slice_names: Set[str]) -> schema_pb2.Schema:
    """Replicates features in a schema with prepended slice names."""
    if schema.HasField('dataset_constraints'):
        logging.error('DatasetConstraints will not be validated per-slice.')
    result = schema_pb2.Schema()
    result.string_domain.extend(schema.string_domain)
    result.float_domain.extend(schema.float_domain)
    result.int_domain.extend(schema.int_domain)
    for slice_name in slice_names:
        for feature in schema.feature:
            new_feature = result.feature.add()
            new_feature.CopyFrom(feature)
            new_feature.name = _prepend_slice_name(slice_name, feature.name)
        for sparse_feature in schema.sparse_feature:
            new_sparse_feature = result.sparse_feature.add()
            new_sparse_feature.CopyFrom(sparse_feature)
            new_sparse_feature.name = _prepend_slice_name(
                slice_name, sparse_feature.name)
        for weighted_feature in schema.weighted_feature:
            new_weighted_feature = result.weighted_feature.add()
            new_weighted_feature.CopyFrom(weighted_feature)
            new_weighted_feature.name = _prepend_slice_name(
                slice_name, weighted_feature.name)
    return result
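The _prepend_slice_name helper is not shown above. A minimal sketch consistent with how the replicated names are produced (slice name joined to the feature name with a separator) could look like the following; the separator constant is an assumption, not the library's actual choice:

# Sketch only: the real helper and separator may differ.
_SLICE_SEPARATOR = '__'

def _prepend_slice_name(slice_name, name):
    # e.g. ('male', 'age') -> 'male__age', so per-slice copies of a
    # feature can coexist in one replicated Schema proto.
    return slice_name + _SLICE_SEPARATOR + name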
def update_schema(schema: schema_pb2.Schema,
                  statistics: statistics_pb2.DatasetFeatureStatisticsList,
                  infer_feature_shape: Optional[bool] = True,
                  max_string_domain_size: Optional[int] = 100
                 ) -> schema_pb2.Schema:
  """Updates input schema to conform to the input statistics.

  Args:
    schema: A Schema protocol buffer.
    statistics: A DatasetFeatureStatisticsList protocol buffer. Schema inference
      is currently supported only for lists with a single
      DatasetFeatureStatistics proto or lists with multiple
      DatasetFeatureStatistics protos corresponding to data slices that include
      the default slice (i.e., the slice with all examples). If a list with
      multiple DatasetFeatureStatistics protos is used, this function will
      update the schema to conform to the statistics corresponding to the
      default slice.
    infer_feature_shape: A boolean indicating whether the shapes of the
      features should be inferred from the statistics.
    max_string_domain_size: Maximum size of the domain of a string feature
      for it to be interpreted as a categorical feature.

  Returns:
    A Schema protocol buffer.

  Raises:
    TypeError: If the input argument is not of the expected type.
    ValueError: If the input statistics proto contains multiple datasets, none
        of which corresponds to the default slice.
  """
  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError(
        'statistics is of type %s, should be '
        'a DatasetFeatureStatisticsList proto.' % type(statistics).__name__)

  # This will raise an exception if there are multiple datasets, none of which
  # corresponds to the default slice.
  dataset_statistics = _get_default_dataset_statistics(statistics)

  _check_for_unsupported_stats_fields(dataset_statistics, 'statistics')

  schema_proto_string = pywrap_tensorflow_data_validation.UpdateSchema(
      tf.compat.as_bytes(schema.SerializeToString()),
      tf.compat.as_bytes(dataset_statistics.SerializeToString()),
      max_string_domain_size)

  # Parse the serialized Schema proto.
  result = schema_pb2.Schema()
  result.ParseFromString(schema_proto_string)

  # TODO(b/113605666): Push this shape inference logic into example validation
  # code.
  if infer_feature_shape:
    _infer_shape(result)
  return result
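A hedged usage sketch of the public wrapper, assuming the tensorflow_data_validation package and hypothetical TFRecord paths:

import tensorflow_data_validation as tfdv

# Hypothetical data locations.
day1_stats = tfdv.generate_statistics_from_tfrecord(
    data_location='/tmp/data/day1.tfrecord')
schema = tfdv.infer_schema(statistics=day1_stats)

# When new data arrives, widen the existing schema to cover it instead
# of re-inferring from scratch.
day2_stats = tfdv.generate_statistics_from_tfrecord(
    data_location='/tmp/data/day2.tfrecord')
schema = tfdv.update_schema(schema=schema, statistics=day2_stats)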
def _validate_sql(sql_query: Text, schema: schema_pb2.Schema):
  arrow_schema = example_coder.ExamplesToRecordBatchDecoder(
      schema.SerializeToString()).ArrowSchema()
  formatted_query = slicing_util.format_slice_sql_query(sql_query)
  try:
    sql_util.RecordBatchSQLSliceQuery(formatted_query, arrow_schema)
  except Exception as e:  # pylint: disable=broad-except
    raise ValueError('Slice SQL query %s raised an exception: %s.'
                     % (sql_query, repr(e)))
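A hedged usage sketch: each configured slicing query can be checked up front so a malformed query fails before the pipeline runs. The `queries` and `schema` variables are assumed to exist elsewhere, and the query strings themselves are placeholders rather than the exact SQL dialect the slicer expects:

# Hypothetical pre-flight check over the slicing configuration.
for query in queries:
    _validate_sql(query, schema)  # raises ValueError on a bad query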
Example #5
def parse_schema_txt_file(schema_path):  # type: (str) -> Schema
    """
    Parse a tf.metadata Schema txt file into its in-memory representation.
    """
    assert file_io.file_exists(schema_path), "File not found: {}".format(
        schema_path)
    schema = Schema()
    schema_text = file_io.read_file_to_string(schema_path)
    google.protobuf.text_format.Parse(schema_text, schema)
    return schema
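For reference, a minimal text-format Schema that this parser would accept; the feature definition is illustrative only:

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2

# Illustrative schema.pbtxt content.
SCHEMA_PBTXT = """
feature {
  name: "age"
  type: INT
  presence { min_fraction: 1.0 }
  shape { dim { size: 1 } }
}
"""

schema = schema_pb2.Schema()
text_format.Parse(SCHEMA_PBTXT, schema)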
Example #6
    def apply(cls, feature_spec):
        # type: (Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]]) -> Schema  # noqa: E501
        """
        Main entry point.
        """
        schema_proto = Schema()
        for k, v in feature_spec.items():
            if isinstance(v, tf.SparseFeature):
                cls._add_sparse_feature_to_proto(schema_proto, k, v)
            else:
                cls._add_feature_to_proto(schema_proto, k, v)

        return schema_proto
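A hedged usage sketch; the enclosing converter class is not shown above, so the name FeatureSpecToSchema below is an assumption, as is the feature spec itself:

import tensorflow as tf

# Illustrative feature spec covering the three supported feature kinds.
feature_spec = {
    'age': tf.io.FixedLenFeature([], tf.int64),
    'tags': tf.io.VarLenFeature(tf.string),
    'emb': tf.io.SparseFeature(index_key='emb_idx', value_key='emb_val',
                               dtype=tf.float32, size=16),
}
schema_proto = FeatureSpecToSchema.apply(feature_spec)  # hypothetical class name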
def update_schema(
        schema: schema_pb2.Schema,
        statistics: statistics_pb2.DatasetFeatureStatisticsList,
        infer_feature_shape: Optional[bool] = True,
        max_string_domain_size: Optional[int] = 100) -> schema_pb2.Schema:
    """Updates input schema to conform to the input statistics.

  Args:
    schema: A Schema protocol buffer.
    statistics: A DatasetFeatureStatisticsList protocol buffer. Schema
        inference is currently only supported for lists with a single
        DatasetFeatureStatistics proto.
    infer_feature_shape: A boolean to indicate if shape of the features need
        to be inferred from the statistics.
    max_string_domain_size: Maximum size of the domain of a string feature in
        order to be interpreted as a categorical feature.

  Returns:
    A Schema protocol buffer.

  Raises:
    TypeError: If the input argument is not of the expected type.
    ValueError: If the input statistics proto does not have only one dataset.
  """
    if not isinstance(schema, schema_pb2.Schema):
        raise TypeError('schema is of type %s, should be a Schema proto.' %
                        type(schema).__name__)
    if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
        raise TypeError('statistics is of type %s, should be '
                        'a DatasetFeatureStatisticsList proto.' %
                        type(statistics).__name__)
    if len(statistics.datasets) != 1:
        raise ValueError('Only statistics proto with one dataset is currently '
                         'supported for inferring schema.')

    _check_for_unsupported_stats_fields(statistics.datasets[0], 'statistics')

    schema_proto_string = pywrap_tensorflow_data_validation.UpdateSchema(
        tf.compat.as_bytes(schema.SerializeToString()),
        tf.compat.as_bytes(statistics.datasets[0].SerializeToString()),
        max_string_domain_size)

    # Parse the serialized Schema proto.
    result = schema_pb2.Schema()
    result.ParseFromString(schema_proto_string)

    # TODO(b/113605666): Push this shape inference logic into example validation
    # code.
    if infer_feature_shape:
        _infer_shape(result)
    return result
Example #8
def lists_to_partitions(
    datasets: DatasetFeatureStatisticsList,
    schema: Schema,
    examples: types.Artifact,
    partitions: List[List[Text]],
) -> List[Partition]:
    result: List[Partition] = []
    for p in partitions:
        name = '__'.join(p)
        partition = Partition(
            name=name,
            statistics=DatasetFeatureStatisticsList(datasets=[
                DatasetFeatureStatistics(
                    name=name,
                    num_examples=min([
                        getattr(feature, feature.WhichOneof(
                            'stats')).common_stats.num_non_missing
                        for feature in filter(lambda f: feature_name(f) in p,
                                              dataset.features)
                    ]),
                    weighted_num_examples=0,
                    features=list(
                        filter(lambda f: feature_name(f) in p,
                               dataset.features)),
                ) for dataset in datasets.datasets
            ]),
            example_splits=[
                ExampleSplit(split=split,
                             uri=os.path.join(examples.uri, split)) for split
                in artifact_utils.decode_split_names(examples.split_names)
            ],
            schema=Schema(feature=filter(lambda f: f.name in p,
                                         schema.feature),
                          sparse_feature=filter(lambda f: f.name in p,
                                                schema.sparse_feature),
                          weighted_feature=filter(lambda f: f.name in p,
                                                  schema.weighted_feature),
                          string_domain=filter(lambda f: f.name in p,
                                               schema.string_domain),
                          float_domain=filter(lambda f: f.name in p,
                                              schema.float_domain),
                          int_domain=filter(lambda f: f.name in p,
                                            schema.int_domain),
                          default_environment=schema.default_environment))
        result.append(partition)
    return result
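For orientation, a hedged example of the `partitions` argument: each inner list names the features that land in one Partition, and the joined names become the partition name; the feature names are placeholders.

partitions = [
    ['age', 'income'],   # -> partition named 'age__income'
    ['education'],       # -> partition named 'education'
]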
def validate_statistics_internal(
    statistics: statistics_pb2.DatasetFeatureStatisticsList,
    schema: schema_pb2.Schema,
    environment: Optional[Text] = None,
    previous_span_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    serving_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    previous_version_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    validation_options: Optional[vo.ValidationOptions] = None,
    enable_diff_regions: bool = False
) -> anomalies_pb2.Anomalies:
  """Validates the input statistics against the provided input schema.

  This method validates the `statistics` against the `schema`. If an optional
  `environment` is specified, the `schema` is filtered using the
  `environment` and the `statistics` is validated against the filtered schema.
  The optional `previous_span_statistics`, `serving_statistics`, and
  `previous_version_statistics` are the statistics computed over the control
  data for drift detection, skew detection, and dataset-level anomaly detection
  across versions, respectively.

  Args:
    statistics: A DatasetFeatureStatisticsList protocol buffer denoting the
       statistics computed over the current data. Validation is currently
       supported only for lists with a single DatasetFeatureStatistics proto or
       lists with multiple DatasetFeatureStatistics protos corresponding to data
       slices that include the default slice (i.e., the slice with all
       examples). If a list with multiple DatasetFeatureStatistics protos is
       used, this function will validate the statistics corresponding to the
       default slice.
    schema: A Schema protocol buffer.
    environment: An optional string denoting the validation environment.
        Must be one of the default environments specified in the schema.
        By default, validation assumes that all Examples in a pipeline adhere
        to a single schema. In some cases introducing slight schema variations
        is necessary, for instance features used as labels are required during
        training (and should be validated), but are missing during serving.
        Environments can be used to express such requirements. For example,
        assume a feature named 'LABEL' is required for training, but is expected
        to be missing from serving. This can be expressed by defining two
        distinct environments in schema: ["SERVING", "TRAINING"] and
        associating 'LABEL' only with environment "TRAINING".
    previous_span_statistics: An optional DatasetFeatureStatisticsList protocol
        buffer denoting the statistics computed over earlier data (for
        example, the previous day's data). If provided, the
        `validate_statistics_internal` method will detect if there exists drift
        between current data and previous data. Configuration for drift
        detection can be done by specifying a `drift_comparator` in the schema.
        For now drift detection is only supported for categorical features.
    serving_statistics: An optional DatasetFeatureStatisticsList protocol
        buffer denoting the statistics computed over the serving data. If
        provided, the `validate_statistics_internal` method will identify if
        there exists distribution skew between current data and serving data.
        Configuration for skew detection can be done by specifying a
        `skew_comparator` in the schema. For now skew detection is only
        supported for categorical features.
    previous_version_statistics: An optional DatasetFeatureStatisticsList
        protocol buffer denoting the statistics computed over earlier data
        (typically, the previous run's data within the same day). If provided,
        the `validate_statistics_internal` method will detect if there exists a
        change in the number of examples between current data and previous
        version data. Configuration for such dataset-wide anomaly detection can
        be done by specifying a `num_examples_version_comparator` in the schema.
    validation_options: Optional input used to specify the options of this
        validation.
    enable_diff_regions: Specifies whether to include a comparison between the
        existing schema and the fixed schema in the Anomalies protocol buffer
        output.

  Returns:
    An Anomalies protocol buffer.

  Raises:
    TypeError: If any of the input arguments is not of the expected type.
    ValueError: If the input statistics proto contains multiple datasets, none
        of which corresponds to the default slice.
  """
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError(
        'statistics is of type %s, should be '
        'a DatasetFeatureStatisticsList proto.' % type(statistics).__name__)

  # This will raise an exception if there are multiple datasets, none of which
  # corresponds to the default slice.
  dataset_statistics = _get_default_dataset_statistics(statistics)

  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)

  if environment is not None:
    if environment not in schema.default_environment:
      raise ValueError('Environment %s not found in the schema.' % environment)
  else:
    environment = ''

  _check_for_unsupported_stats_fields(dataset_statistics, 'statistics')

  if previous_span_statistics is not None:
    if not isinstance(
        previous_span_statistics, statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError(
          'previous_span_statistics is of type %s, should be '
          'a DatasetFeatureStatisticsList proto.'
          % type(previous_span_statistics).__name__)

    previous_dataset_statistics = _get_default_dataset_statistics(
        previous_span_statistics)
    _check_for_unsupported_stats_fields(previous_dataset_statistics,
                                        'previous_statistics')

  if serving_statistics is not None:
    if not isinstance(
        serving_statistics, statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError(
          'serving_statistics is of type %s, should be '
          'a DatasetFeatureStatisticsList proto.'
          % type(serving_statistics).__name__)

    serving_dataset_statistics = _get_default_dataset_statistics(
        serving_statistics)
    _check_for_unsupported_stats_fields(serving_dataset_statistics,
                                        'serving_statistics')

  if previous_version_statistics is not None:
    if not isinstance(previous_version_statistics,
                      statistics_pb2.DatasetFeatureStatisticsList):
      raise TypeError('previous_version_statistics is of type %s, should be '
                      'a DatasetFeatureStatisticsList proto.' %
                      type(previous_version_statistics).__name__)

    previous_version_dataset_statistics = _get_default_dataset_statistics(
        previous_version_statistics)
    _check_for_unsupported_stats_fields(previous_version_dataset_statistics,
                                        'previous_version_statistics')

  # Serialize the input protos.
  serialized_schema = schema.SerializeToString()
  serialized_stats = dataset_statistics.SerializeToString()
  serialized_previous_span_stats = (
      previous_dataset_statistics.SerializeToString()
      if previous_span_statistics is not None else '')
  serialized_serving_stats = (
      serving_dataset_statistics.SerializeToString()
      if serving_statistics is not None else '')
  serialized_previous_version_stats = (
      previous_version_dataset_statistics.SerializeToString()
      if previous_version_statistics is not None else '')

  features_needed_pb = validation_metadata_pb2.FeaturesNeededProto()
  if validation_options is not None and validation_options.features_needed:
    for path, reason_list in validation_options.features_needed.items():
      path_and_reason_feature_need = (
          features_needed_pb.path_and_reason_feature_need.add())
      path_and_reason_feature_need.path.CopyFrom(path.to_proto())
      for reason in reason_list:
        r = path_and_reason_feature_need.reason_feature_needed.add()
        r.comment = reason.comment

  serialized_features_needed = features_needed_pb.SerializeToString()

  validation_config = validation_config_pb2.ValidationConfig()
  if validation_options is not None:
    validation_config.new_features_are_warnings = (
        validation_options.new_features_are_warnings)
    for override in validation_options.severity_overrides:
      validation_config.severity_overrides.append(override)
  serialized_validation_config = validation_config.SerializeToString()

  anomalies_proto_string = (
      pywrap_tensorflow_data_validation.ValidateFeatureStatistics(
          tf.compat.as_bytes(serialized_stats),
          tf.compat.as_bytes(serialized_schema),
          tf.compat.as_bytes(environment),
          tf.compat.as_bytes(serialized_previous_span_stats),
          tf.compat.as_bytes(serialized_serving_stats),
          tf.compat.as_bytes(serialized_previous_version_stats),
          tf.compat.as_bytes(serialized_features_needed),
          tf.compat.as_bytes(serialized_validation_config),
          enable_diff_regions))

  # Parse the serialized Anomalies proto.
  result = anomalies_pb2.Anomalies()
  result.ParseFromString(anomalies_proto_string)
  return result
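A hedged usage sketch of environment-aware validation through the public API (assumes the tensorflow_data_validation package; the 'LABEL' feature and environment names follow the docstring's example, and `serving_stats` is assumed to be computed elsewhere):

import tensorflow_data_validation as tfdv

schema.default_environment.append('TRAINING')
schema.default_environment.append('SERVING')
# 'LABEL' is required in training data but absent at serving time.
tfdv.get_feature(schema, 'LABEL').not_in_environment.append('SERVING')

serving_anomalies = tfdv.validate_statistics(
    statistics=serving_stats, schema=schema, environment='SERVING')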
def _may_be_set_legacy_flag(schema: schema_pb2.Schema):
  """Sets legacy flag to False if it exists."""
  if getattr(schema, 'generate_legacy_feature_spec', None) is not None:
    schema.generate_legacy_feature_spec = False
def validate_statistics(
    statistics: statistics_pb2.DatasetFeatureStatisticsList,
    schema: schema_pb2.Schema,
    environment: Optional[str] = None,
    previous_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
    serving_statistics: Optional[
        statistics_pb2.DatasetFeatureStatisticsList] = None,
) -> anomalies_pb2.Anomalies:
    """Validates the input statistics against the provided input schema.

  This method validates the `statistics` against the `schema`. If an optional
  `environment` is specified, the `schema` is filtered using the
  `environment` and the `statistics` is validated against the filtered schema.
  The optional `previous_statistics` and `serving_statistics` are the statistics
  computed over the treatment data for drift- and skew-detection, respectively.

  Args:
    statistics: A DatasetFeatureStatisticsList protocol buffer denoting the
        statistics computed over the current data. Validation is currently only
        supported for lists with a single DatasetFeatureStatistics proto.
    schema: A Schema protocol buffer.
    environment: An optional string denoting the validation environment.
        Must be one of the default environments specified in the schema.
        By default, validation assumes that all Examples in a pipeline adhere
        to a single schema. In some cases introducing slight schema variations
        is necessary, for instance features used as labels are required during
        training (and should be validated), but are missing during serving.
        Environments can be used to express such requirements. For example,
        assume a feature named 'LABEL' is required for training, but is expected
        to be missing from serving. This can be expressed by defining two
        distinct environments in schema: ["SERVING", "TRAINING"] and
        associating 'LABEL' only with environment "TRAINING".
    previous_statistics: An optional DatasetFeatureStatisticsList protocol
        buffer denoting the statistics computed over an earlier data (for
        example, previous day's data). If provided, the `validate_statistics`
        method will detect if there exists drift between current data and
        previous data. Configuration for drift detection can be done by
        specifying a `drift_comparator` in the schema. For now drift detection
        is only supported for categorical features.
    serving_statistics: An optional DatasetFeatureStatisticsList protocol
        buffer denoting the statistics computed over the serving data. If
        provided, the `validate_statistics` method will identify if there exists
        distribution skew between current data and serving data. Configuration
        for skew detection can be done by specifying a `skew_comparator` in the
        schema. For now skew detection is only supported for categorical
        features.

  Returns:
    An Anomalies protocol buffer.

  Raises:
    TypeError: If any of the input arguments is not of the expected type.
    ValueError: If the input statistics proto does not have only one dataset.
  """
    if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
        raise TypeError('statistics is of type %s, should be '
                        'a DatasetFeatureStatisticsList proto.' %
                        type(statistics).__name__)

    if len(statistics.datasets) != 1:
        raise ValueError('statistics proto contains multiple datasets. Only '
                         'one dataset is currently supported for validation.')

    if not isinstance(schema, schema_pb2.Schema):
        raise TypeError('schema is of type %s, should be a Schema proto.' %
                        type(schema).__name__)

    if environment is not None:
        if environment not in schema.default_environment:
            raise ValueError('Environment %s not found in the schema.' %
                             environment)
    else:
        environment = ''

    _check_for_unsupported_stats_fields(statistics.datasets[0], 'statistics')
    _check_for_unsupported_schema_fields(schema)

    if previous_statistics is not None:
        if not isinstance(previous_statistics,
                          statistics_pb2.DatasetFeatureStatisticsList):
            raise TypeError('previous_statistics is of type %s, should be '
                            'a DatasetFeatureStatisticsList proto.' %
                            type(previous_statistics).__name__)

        if len(previous_statistics.datasets) != 1:
            raise ValueError(
                'previous_statistics proto contains multiple datasets. '
                'Only one dataset is currently supported for validation.')

        _check_for_unsupported_stats_fields(previous_statistics.datasets[0],
                                            'previous_statistics')

    if serving_statistics is not None:
        if not isinstance(serving_statistics,
                          statistics_pb2.DatasetFeatureStatisticsList):
            raise TypeError('serving_statistics is of type %s, should be '
                            'a DatasetFeatureStatisticsList proto.' %
                            type(serving_statistics).__name__)

        if len(serving_statistics.datasets) != 1:
            raise ValueError(
                'serving_statistics proto contains multiple datasets. '
                'Only one dataset is currently supported for validation.')

        _check_for_unsupported_stats_fields(serving_statistics.datasets[0],
                                            'serving_statistics')

    # Serialize the input protos.
    serialized_schema = schema.SerializeToString()
    serialized_stats = statistics.datasets[0].SerializeToString()
    serialized_previous_stats = (
        previous_statistics.datasets[0].SerializeToString()
        if previous_statistics is not None else '')
    serialized_serving_stats = (
        serving_statistics.datasets[0].SerializeToString()
        if serving_statistics is not None else '')
    # TODO(b/138589321): Update API to support validation against previous version
    # stats.
    serialized_previous_version_stats = ''

    anomalies_proto_string = (
        pywrap_tensorflow_data_validation.ValidateFeatureStatistics(
            tf.compat.as_bytes(serialized_stats),
            tf.compat.as_bytes(serialized_schema),
            tf.compat.as_bytes(environment),
            tf.compat.as_bytes(serialized_previous_stats),
            tf.compat.as_bytes(serialized_serving_stats),
            tf.compat.as_bytes(serialized_previous_version_stats)))

    # Parse the serialized Anomalies proto.
    result = anomalies_pb2.Anomalies()
    result.ParseFromString(anomalies_proto_string)
    return result
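A hedged sketch of configuring the comparators that the docstring refers to (assumes the tensorflow_data_validation helpers; feature names, thresholds, and the statistics variables are placeholders):

import tensorflow_data_validation as tfdv

# L-infinity distance thresholds for skew (training vs. serving) and
# drift (current span vs. previous span); the values are illustrative.
tfdv.get_feature(schema, 'payment_type').skew_comparator.infinity_norm.threshold = 0.01
tfdv.get_feature(schema, 'company').drift_comparator.infinity_norm.threshold = 0.001

anomalies = tfdv.validate_statistics(
    statistics=train_stats,
    schema=schema,
    previous_statistics=previous_day_stats,
    serving_statistics=serving_stats)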