def update_schema(schema: schema_pb2.Schema,
                  statistics: statistics_pb2.DatasetFeatureStatisticsList,
                  infer_feature_shape: Optional[bool] = True,
                  max_string_domain_size: Optional[int] = 100
                 ) -> schema_pb2.Schema:
  """Builds a new schema by revising `schema` to agree with `statistics`.

  Both protos are serialized and handed to the native `UpdateSchema` routine;
  its serialized output is parsed into a fresh Schema proto, which is what
  gets returned.

  Args:
    schema: The Schema protocol buffer to start from.
    statistics: A DatasetFeatureStatisticsList protocol buffer. It may hold a
      single DatasetFeatureStatistics proto, or several protos describing
      data slices as long as one of them is the default slice (the slice
      containing all examples). With several protos, only the statistics of
      the default slice are used to update the schema.
    infer_feature_shape: Whether feature shapes should be inferred on the
      updated schema.
    max_string_domain_size: Largest domain size for which a string feature is
      still interpreted as a categorical feature.

  Returns:
    A Schema protocol buffer.

  Raises:
    TypeError: If `schema` or `statistics` is not of the expected proto type.
    ValueError: If `statistics` holds multiple datasets and none of them
        corresponds to the default slice.
  """
  if not isinstance(schema, schema_pb2.Schema):
    actual_type = type(schema).__name__
    raise TypeError(
        'schema is of type %s, should be a Schema proto.' % actual_type)
  if not isinstance(statistics, statistics_pb2.DatasetFeatureStatisticsList):
    actual_type = type(statistics).__name__
    raise TypeError(
        'statistics is of type %s, should be '
        'a DatasetFeatureStatisticsList proto.' % actual_type)

  # Selecting the default slice raises when several datasets are present and
  # none of them is the default slice.
  default_slice_stats = _get_default_dataset_statistics(statistics)
  _check_for_unsupported_stats_fields(default_slice_stats, 'statistics')

  # The native routine exchanges protos in serialized (bytes) form.
  serialized_schema = tf.compat.as_bytes(schema.SerializeToString())
  serialized_stats = tf.compat.as_bytes(
      default_slice_stats.SerializeToString())
  serialized_update = pywrap_tensorflow_data_validation.UpdateSchema(
      serialized_schema, serialized_stats, max_string_domain_size)

  # Turn the serialized response back into a Schema proto.
  updated_schema = schema_pb2.Schema()
  updated_schema.ParseFromString(serialized_update)

  # TODO(b/113605666): Push this shape inference logic into example validation
  # code.
  if infer_feature_shape:
    _infer_shape(updated_schema)
  return updated_schema
예제 #2
0
def update_schema(
        schema: schema_pb2.Schema,
        statistics: statistics_pb2.DatasetFeatureStatisticsList,
        infer_feature_shape: Optional[bool] = True,
        max_string_domain_size: Optional[int] = 100) -> schema_pb2.Schema:
    """Builds a new schema by revising `schema` to agree with `statistics`.

    The schema and the single dataset's statistics are serialized and handed
    to the native `UpdateSchema` routine; its serialized output is parsed
    into a fresh Schema proto, which is what gets returned.

    Args:
        schema: The Schema protocol buffer to start from.
        statistics: A DatasetFeatureStatisticsList protocol buffer. It must
            contain exactly one DatasetFeatureStatistics proto.
        infer_feature_shape: Whether feature shapes should be inferred on the
            updated schema.
        max_string_domain_size: Largest domain size for which a string
            feature is still interpreted as a categorical feature.

    Returns:
        A Schema protocol buffer.

    Raises:
        TypeError: If `schema` or `statistics` is not of the expected proto
            type.
        ValueError: If `statistics` does not contain exactly one dataset.
    """
    if not isinstance(schema, schema_pb2.Schema):
        actual_type = type(schema).__name__
        raise TypeError(
            'schema is of type %s, should be a Schema proto.' % actual_type)
    if not isinstance(statistics,
                      statistics_pb2.DatasetFeatureStatisticsList):
        actual_type = type(statistics).__name__
        raise TypeError('statistics is of type %s, should be '
                        'a DatasetFeatureStatisticsList proto.' % actual_type)

    dataset_count = len(statistics.datasets)
    if dataset_count != 1:
        raise ValueError('Only statistics proto with one dataset is currently '
                         'supported for inferring schema.')

    # Exactly one dataset is present at this point.
    sole_dataset = statistics.datasets[0]
    _check_for_unsupported_stats_fields(sole_dataset, 'statistics')

    # The native routine exchanges protos in serialized (bytes) form.
    serialized_schema = tf.compat.as_bytes(schema.SerializeToString())
    serialized_stats = tf.compat.as_bytes(sole_dataset.SerializeToString())
    serialized_update = pywrap_tensorflow_data_validation.UpdateSchema(
        serialized_schema, serialized_stats, max_string_domain_size)

    # Turn the serialized response back into a Schema proto.
    updated_schema = schema_pb2.Schema()
    updated_schema.ParseFromString(serialized_update)

    # TODO(b/113605666): Push this shape inference logic into example
    # validation code.
    if infer_feature_shape:
        _infer_shape(updated_schema)
    return updated_schema