Example #1
def infer_schema(statistics,
                 infer_feature_shape = True,
                 max_string_domain_size = 100
                ):
  """Infers schema from the input statistics.

  Args:
    statistics: A DatasetFeatureStatisticsList protocol buffer. Schema
        inference is currently only supported for lists with a single
        DatasetFeatureStatistics proto.
    infer_feature_shape: A boolean to indicate whether the shape of the
        features needs to be inferred from the statistics.
    max_string_domain_size: Maximum size of the domain of a string feature in
        order to be interpreted as a categorical feature.

  Returns:
    A Schema protocol buffer.

  Raises:
    TypeError: If the input argument is not of the expected type.
    ValueError: If the input statistics proto does not have only one dataset.
  """
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError(
        'statistics is of type %s, should be '
        'a DatasetFeatureStatisticsList proto.' % type(statistics).__name__)

  if len(statistics.datasets) != 1:
    raise ValueError('Only statistics proto with one dataset is currently '
                     'supported for inferring schema.')

  _check_for_unsupported_stats_fields(statistics.datasets[0], 'statistics')

  schema_proto_string = pywrap_tensorflow_data_validation.InferSchema(
      tf.compat.as_bytes(statistics.datasets[0].SerializeToString()),
      max_string_domain_size)

  # Parse the serialized Schema proto.
  result = schema_pb2.Schema()
  result.ParseFromString(schema_proto_string)

  _may_be_set_legacy_flag(result)

  # TODO(b/113605666): Push this shape inference logic into example validation
  # code.
  if infer_feature_shape:
    _infer_shape(result)

  return result
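
Before the revised version of infer_schema below, here is a minimal usage sketch of this first variant. It assumes the function is exposed through the TFDV public API as tfdv.infer_schema and that statistics are produced with tfdv.generate_statistics_from_csv; the CSV path and variable names are placeholders, not part of the code above.

import tensorflow_data_validation as tfdv

# Compute dataset statistics first (placeholder path; any statistics
# generation entry point that returns a DatasetFeatureStatisticsList works).
train_stats = tfdv.generate_statistics_from_csv(data_location='train_data.csv')

# Infer a Schema proto; string features with at most 100 distinct values
# are interpreted as categorical.
schema = tfdv.infer_schema(statistics=train_stats,
                           infer_feature_shape=True,
                           max_string_domain_size=100)
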
def infer_schema(
    statistics: statistics_pb2.DatasetFeatureStatisticsList,
    infer_feature_shape: bool = True,
    max_string_domain_size: int = 100,
    schema_transformations: Optional[List[
        Callable[[schema_pb2.Schema, statistics_pb2.DatasetFeatureStatistics],
                 schema_pb2.Schema]]] = None
) -> schema_pb2.Schema:
  """Infers schema from the input statistics.

  Args:
    statistics: A DatasetFeatureStatisticsList protocol buffer. Schema inference
      is currently supported only for lists with a single
      DatasetFeatureStatistics proto or lists with multiple
      DatasetFeatureStatistics protos corresponding to data slices that include
      the default slice (i.e., the slice with all examples). If a list with
      multiple DatasetFeatureStatistics protos is used, this function will infer
      the schema from the statistics corresponding to the default slice.
    infer_feature_shape: A boolean to indicate whether the shape of the
      features needs to be inferred from the statistics.
    max_string_domain_size: Maximum size of the domain of a string feature in
      order to be interpreted as a categorical feature.
    schema_transformations: List of transformation functions to apply to the
      auto-inferred schema. Each transformation function should take the schema
      and statistics as input and should return the transformed schema. The
      transformations are applied in the order provided in the list.

  Returns:
    A Schema protocol buffer.

  Raises:
    TypeError: If the input argument is not of the expected type.
    ValueError: If the input statistics proto contains multiple datasets, none
        of which corresponds to the default slice.
  """
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    raise TypeError(
        'statistics is of type %s, should be '
        'a DatasetFeatureStatisticsList proto.' % type(statistics).__name__)

  # This will raise an exception if there are multiple datasets, none of which
  # corresponds to the default slice.
  dataset_statistics = _get_default_dataset_statistics(statistics)

  # dataset_statistics may include stats for composite features like
  # SparseFeatures and WeightedFeatures. We cannot infer a useful schema from
  # these stats, so we remove them at the start.
  dataset_statistics = _remove_features_missing_common_stats(dataset_statistics)

  _check_for_unsupported_stats_fields(dataset_statistics, 'statistics')

  schema_proto_string = pywrap_tensorflow_data_validation.InferSchema(
      tf.compat.as_bytes(dataset_statistics.SerializeToString()),
      max_string_domain_size)

  # Parse the serialized Schema proto.
  result = schema_pb2.Schema()
  result.ParseFromString(schema_proto_string)

  _may_be_set_legacy_flag(result)

  # TODO(b/113605666): Push this shape inference logic into example validation
  # code.
  if infer_feature_shape:
    _infer_shape(result)

  if schema_transformations is not None:
    for transformation_fn in schema_transformations:
      result = transformation_fn(result, statistics.datasets[0])
  return result
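
The schema_transformations argument added in this revised version can be exercised as in the sketch below. This is only an illustration: relax_presence, train_stats, and the chosen presence threshold are assumptions made for the example, not part of the function above.

from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2


def relax_presence(schema: schema_pb2.Schema,
                   stats: statistics_pb2.DatasetFeatureStatistics
                  ) -> schema_pb2.Schema:
  # Hypothetical transformation: allow every feature to be missing in up to
  # 10% of examples, overriding whatever presence inference produced.
  for feature in schema.feature:
    feature.presence.min_fraction = 0.9
  return schema


schema = infer_schema(
    statistics=train_stats,  # a DatasetFeatureStatisticsList computed elsewhere
    infer_feature_shape=True,
    schema_transformations=[relax_presence])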