class SchemaInferenceTest(test_case.TransformTestCase):

    # pylint: disable=g-long-lambda
    @test_case.named_parameters(
        dict(testcase_name='fixed_len_int',
             make_tensors_fn=lambda:
             {'x': tf.compat.v1.placeholder(tf.int64, (None, ))},
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}),
        dict(testcase_name='fixed_len_string',
             make_tensors_fn=lambda:
             {'x': tf.compat.v1.placeholder(tf.string, (None, ))},
             feature_spec={'x': tf.io.FixedLenFeature([], tf.string)}),
        dict(testcase_name='fixed_len_float',
             make_tensors_fn=lambda:
             {'x': tf.compat.v1.placeholder(tf.float32, (None, ))},
             feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}),
        dict(testcase_name='override',
             make_tensors_fn=_make_tensors_with_override,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)},
             domains={'x': schema_pb2.IntDomain(is_categorical=True)}),
        dict(testcase_name='override_with_session',
             make_tensors_fn=_make_tensors_with_override,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)},
             domains={
                 'x': schema_pb2.IntDomain(min=5, max=6, is_categorical=True)
             },
             create_session=True))
    # pylint: enable=g-long-lambda
    def test_infer_feature_schema(self,
                                  make_tensors_fn,
                                  feature_spec,
                                  domains=None,
                                  create_session=False):
        with tf.Graph().as_default() as graph:
            tensors = make_tensors_fn()

        if create_session:
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    tensors, graph, session)
        else:
            schema = schema_inference.infer_feature_schema(tensors, graph)

        expected_schema = dataset_schema.from_feature_spec(
            feature_spec, domains)
        self.assertEqual(schema, expected_schema)

    def test_infer_feature_schema_bad_rank(self):
        with tf.Graph().as_default() as graph:
            tensors = {
                'a': tf.compat.v1.placeholder(tf.float32, ()),
            }
        with self.assertRaises(ValueError):
            schema_inference.infer_feature_schema(tensors, graph)
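
The parameterized 'override' cases above use a helper, _make_tensors_with_override, that is not part of this excerpt. A minimal sketch of what it could look like follows, assuming set_tensor_schema_override is the schema_inference hook that records a (min, max) override for a tensor on the graph:

def _make_tensors_with_override():
    # Hypothetical sketch (not the library's actual helper): register a schema
    # override so that infer_feature_schema emits IntDomain(min=5, max=6,
    # is_categorical=True) once a session is available to evaluate the constants.
    x = tf.compat.v1.placeholder(tf.int64, (None,))
    schema_inference.set_tensor_schema_override(x, tf.constant(5), tf.constant(6))
    return {'x': x}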
Example #2
def _sparse_feature_from_feature_spec(spec, name, domains):
    """Returns a representation of a SparseFeature from a feature spec."""
    if isinstance(spec.index_key, list):
        assert isinstance(spec.size,
                          (list, tuple, tf.TensorShape)), type(spec.size)
        assert len(spec.index_key) == len(spec.size), (spec.index_key,
                                                       spec.size)
        spec_size = [
            s.value if isinstance(s, tf.compat.v1.Dimension) else s
            for s in spec.size
        ]
        int_domains = [
            schema_pb2.IntDomain(min=0, max=size -
                                 1) if size is not None else None
            for size in spec_size
        ]
        index_feature = [
            schema_pb2.Feature(name=key,
                               type=schema_pb2.INT,
                               int_domain=int_domain)
            for (key, int_domain) in zip(spec.index_key, int_domains)
        ]
        index_feature_ref = [
            schema_pb2.SparseFeature.IndexFeature(name=key)
            for key in spec.index_key
        ]
    else:
        # Create an index feature.
        index_feature = [
            schema_pb2.Feature(name=spec.index_key,
                               type=schema_pb2.INT,
                               int_domain=schema_pb2.IntDomain(min=0,
                                                               max=spec.size -
                                                               1))
        ]
        index_feature_ref = [
            schema_pb2.SparseFeature.IndexFeature(name=spec.index_key)
        ]

    # Create a value feature.
    value_feature = schema_pb2.Feature(name=spec.value_key)
    _set_type(name, value_feature, spec.dtype)
    _set_domain(name, value_feature, domains.get(name))

    # Create a sparse feature which refers to the index and value features.
    value_feature_ref = schema_pb2.SparseFeature.ValueFeature(
        name=spec.value_key)
    sparse_feature = schema_pb2.SparseFeature(
        name=name,
        is_sorted=True if spec.already_sorted else None,
        index_feature=index_feature_ref,
        value_feature=value_feature_ref)

    return (index_feature, value_feature, sparse_feature)
Example #3
def _sparse_feature_from_feature_spec(spec, name, domains):
    """Returns a representation of a SparseFeature from a feature spec."""
    if isinstance(spec.index_key, list):
        raise ValueError(
            'SparseFeature "{}" had index_key {}, but size and index_key '
            'fields should be single values'.format(name, spec.index_key))
    if isinstance(spec.size, list):
        raise ValueError(
            'SparseFeature "{}" had size {}, but size and index_key fields '
            'should be single values'.format(name, spec.size))

    # Create an index feature.
    index_feature = schema_pb2.Feature(name=spec.index_key,
                                       type=schema_pb2.INT,
                                       int_domain=schema_pb2.IntDomain(
                                           min=0, max=spec.size - 1))

    # Create a value feature.
    value_feature = schema_pb2.Feature(name=spec.value_key)
    _set_type(name, value_feature, spec.dtype)
    _set_domain(name, value_feature, domains.get(name))

    # Create a sparse feature which refers to the index and value features.
    index_feature_ref = schema_pb2.SparseFeature.IndexFeature(
        name=spec.index_key)
    value_feature_ref = schema_pb2.SparseFeature.ValueFeature(
        name=spec.value_key)
    sparse_feature = schema_pb2.SparseFeature(
        name=name,
        is_sorted=True if spec.already_sorted else None,
        index_feature=[index_feature_ref],
        value_feature=value_feature_ref)

    return (index_feature, value_feature, sparse_feature)
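
For illustration, a hedged sketch of the protos the single-index branch above is expected to produce for a spec like tf.io.SparseFeature(index_key='idx', value_key='val', dtype=tf.int64, size=10); the field values are inferred from the code above, not from library documentation:

# Assumed result of _sparse_feature_from_feature_spec(spec, 'x', {}) for the spec above.
_example_index_feature = schema_pb2.Feature(
    name='idx',
    type=schema_pb2.INT,
    int_domain=schema_pb2.IntDomain(min=0, max=9))  # max is size - 1
_example_sparse_feature = schema_pb2.SparseFeature(
    name='x',
    index_feature=[schema_pb2.SparseFeature.IndexFeature(name='idx')],
    value_feature=schema_pb2.SparseFeature.ValueFeature(name='val'))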
Example #4
def IntDomain(dtype, min_value=None, max_value=None, is_categorical=None):  # pylint: disable=invalid-name
    """Legacy constructor for an IntDomain."""
    if dtype != tf.int64:
        raise ValueError('IntDomain must be called with dtype=tf.int64')
    return schema_pb2.IntDomain(min=min_value,
                                max=max_value,
                                is_categorical=is_categorical)
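
The legacy wrapper above simply forwards to schema_pb2.IntDomain, so the following two constructions should yield equal protos (a small sketch for illustration):

# Proto messages compare field-by-field, so this assertion is expected to hold.
_legacy_domain = IntDomain(tf.int64, min_value=0, max_value=9, is_categorical=True)
_direct_domain = schema_pb2.IntDomain(min=0, max=9, is_categorical=True)
assert _legacy_domain == _direct_domain
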
def _infer_feature_schema_common(features, tensor_ranges, feature_annotations,
                                 global_annotations):
    """Given a dict of tensors, creates a `Schema`.

  Args:
    features: A dict mapping column names to `Tensor` or `SparseTensor`s. The
      `Tensor` or `SparseTensor`s should have a 0'th dimension which is
      interpreted as the batch dimension.
    tensor_ranges: A dict mapping a tensor to a tuple containing its min and max
      value.
    feature_annotations: dictionary from feature name to list of any_pb2.Any
      protos to be added as an annotation for that feature in the schema.
    global_annotations: list of any_pb2.Any protos to be added at the global
      schema level.

  Returns:
    A `Schema` proto.
  """
    domains = {}
    feature_tags = collections.defaultdict(list)
    for name, tensor in six.iteritems(features):
        if isinstance(tensor, tf.RaggedTensor):
            # Add the 'ragged_tensor' tag which will cause coder and
            # schema_as_feature_spec to raise an error, as currently there is no
            # feature spec for ragged tensors.
            feature_tags[name].append(schema_utils.RAGGED_TENSOR_TAG)
        if name in tensor_ranges:
            min_value, max_value = tensor_ranges[name]
            domains[name] = schema_pb2.IntDomain(min=min_value,
                                                 max=max_value,
                                                 is_categorical=True)
    feature_spec = _feature_spec_from_batched_tensors(features)

    schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains)

    # Add the annotations to the schema.
    for annotation in global_annotations:
        schema_proto.annotation.extra_metadata.add().CopyFrom(annotation)
    # Build a map from logical feature names to Feature protos
    feature_protos_by_name = {}
    for feature in schema_proto.feature:
        feature_protos_by_name[feature.name] = feature
    for sparse_feature in schema_proto.sparse_feature:
        for index_feature in sparse_feature.index_feature:
            feature_protos_by_name.pop(index_feature.name)
        value_feature = feature_protos_by_name.pop(
            sparse_feature.value_feature.name)
        feature_protos_by_name[sparse_feature.name] = value_feature
    # Update annotations
    for feature_name, annotations in feature_annotations.items():
        feature_proto = feature_protos_by_name[feature_name]
        for annotation in annotations:
            feature_proto.annotation.extra_metadata.add().CopyFrom(annotation)
    for feature_name, tags in feature_tags.items():
        feature_proto = feature_protos_by_name[feature_name]
        for tag in tags:
            feature_proto.annotation.tag.append(tag)
    return schema_proto
Example #6
def _domain_from_json(domain):
  """Translate a JSON domain dict into an IntDomain or None."""
  if domain.get('ints') is not None:
    def maybe_to_int(s):
      return int(s) if s is not None else None
    return schema_pb2.IntDomain(
        min=maybe_to_int(domain['ints'].get('min')),
        max=maybe_to_int(domain['ints'].get('max')),
        is_categorical=domain['ints'].get('isCategorical'))
  return None
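
A small usage sketch with a hypothetical JSON payload: JSON serializes int64 fields as strings, which is why the helper above converts min and max with int() before building the proto.

# Hypothetical input dict, shaped like the JSON the helper above expects.
_json_domain = {'ints': {'min': '0', 'max': '9', 'isCategorical': True}}
assert _domain_from_json(_json_domain) == schema_pb2.IntDomain(
    min=0, max=9, is_categorical=True)
assert _domain_from_json({'floats': {}}) is None  # only int domains are translated
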
    def testBucketizePerKey(self):
        def preprocessing_fn(inputs):
            x_bucketized = tft.bucketize_per_key(inputs['x'],
                                                 inputs['key'],
                                                 num_buckets=3,
                                                 epsilon=0.00001)
            return {'x': inputs['x'], 'x_bucketized': x_bucketized}

        # NOTE: We force 10 batches: data has 100 elements and we request a batch
        # size of 10.
        input_data = [{
            'x': x,
            'key': 'a' if x < 50 else 'b'
        } for x in range(1, 100)]
        input_metadata = tft_unit.metadata_from_feature_spec({
            'x':
            tf.io.FixedLenFeature([], tf.float32),
            'key':
            tf.io.FixedLenFeature([], tf.string)
        })

        def compute_quantile(instance):
            if instance['key'] == 'a':
                if instance['x'] < 17:
                    return 0
                elif instance['x'] < 33:
                    return 1
                else:
                    return 2
            else:
                if instance['x'] < 66:
                    return 0
                elif instance['x'] < 83:
                    return 1
                else:
                    return 2

        expected_data = [{
            'x_bucketized': compute_quantile(instance),
            'x': instance['x']
        } for instance in input_data]
        expected_metadata = tft_unit.metadata_from_feature_spec(
            {
                'x': tf.io.FixedLenFeature([], tf.float32),
                'x_bucketized': tf.io.FixedLenFeature([], tf.int64),
            }, {
                'x_bucketized':
                schema_pb2.IntDomain(min=0, max=2, is_categorical=True),
            })
        self.assertAnalyzeAndTransformResults(input_data,
                                              input_metadata,
                                              preprocessing_fn,
                                              expected_data,
                                              expected_metadata,
                                              desired_batch_size=10)
def set_protobuf_int(column_schema, feature):
    domain = column_schema.properties.get("domain", {})
    feature.int_domain.CopyFrom(
        schema_pb2.IntDomain(
            name=column_schema.name,
            min=domain.get("min", None),
            max=domain.get("max", None),
            is_categorical=(Tags.CATEGORICAL in column_schema.tags
                            or Tags.CATEGORICAL.value in column_schema.tags),
        ))
    feature.type = schema_pb2.FeatureType.INT
    return feature
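
A rough, self-contained sketch of how set_protobuf_int might be driven; the real caller passes a merlin ColumnSchema, which is stood in for here by a SimpleNamespace, and it is assumed that Tags.CATEGORICAL.value == 'categorical' and that Tags and schema_pb2 are importable from the surrounding module:

import types

# Hypothetical stand-in for a merlin ColumnSchema; field names mirror what the
# function above actually reads (name, properties['domain'], tags).
_fake_column_schema = types.SimpleNamespace(
    name='item_id',
    properties={'domain': {'min': 0, 'max': 10000}},
    tags={'categorical'})
_feature = set_protobuf_int(_fake_column_schema, schema_pb2.Feature(name='item_id'))
# _feature.int_domain now carries min=0, max=10000, is_categorical=True and
# _feature.type == schema_pb2.FeatureType.INT.
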
def infer_feature_schema(features, graph, session=None):
    """Given a dict of tensors, creates a `Schema`.

  Infers a schema, in the format of a tf.Transform `Schema`, for the given
  dictionary of tensors.

  If there is an override specified, we override the inferred schema for the
  given feature's tensor.  An override has the meaning that we should set
  is_categorical=True.  If session is not provided then we just set
  is_categorical=True, and if the session is provided then we also compute the
  values of the tensors representing the min and max values and set them in the
  schema.

  Args:
    features: A dict mapping column names to `Tensor` or `SparseTensor`s. The
        `Tensor` or `SparseTensor`s should have a 0'th dimension which is
        interpreted as the batch dimension.
    graph: A `tf.Graph` used to determine schema overrides.
    session: (optional) A `tf.Session` used to compute schema overrides.  If
        None, schema overrides will not be computed.

  Returns:
    A `Schema` object.
  """
    tensor_ranges = _get_tensor_schema_overrides(graph)
    if session is None:
        tensor_ranges = {
            tensor: (None, None)
            for tensor in tensor_ranges.keys()
        }
    else:
        tensor_ranges = session.run(tensor_ranges)

    domains = {}
    for name, tensor in six.iteritems(features):
        values = tensor.values if isinstance(tensor,
                                             tf.SparseTensor) else tensor
        if values in tensor_ranges:
            assert values.dtype == tf.int64
            min_value, max_value = tensor_ranges[values]
            domains[name] = schema_pb2.IntDomain(min=min_value,
                                                 max=max_value,
                                                 is_categorical=True)

    feature_spec = _feature_spec_from_batched_tensors(features)

    return dataset_schema.from_feature_spec(feature_spec, domains)
Example #10
  def _validate_column_schemas(self):
    """Validate that this Schema can be represented as a schema_pb2.Schema."""
    feature_spec = self.as_feature_spec()
    int_domains = {}
    for name, column_schema in self._column_schemas.items():
      domain = column_schema.domain
      if isinstance(domain, IntDomain):
        int_domains[name] = schema_pb2.IntDomain(
            min=domain.min_value, max=domain.max_value,
            is_categorical=domain.is_categorical)
    try:
      schema_utils.schema_from_feature_spec(feature_spec, int_domains)
    except Exception as e:
      raise ValueError(
          'The values of column_schemas were invalid, as detected when '
          'converting them to a schema_pb2.Schema proto.  Original error: '
          '{}'.format(e))
class SchemaInferenceTest(test_case.TransformTestCase):

    # pylint: disable=g-long-lambda
    @test_case.named_parameters(
        dict(testcase_name='fixed_len_int',
             make_tensors_fn=lambda:
             {'x': tf.compat.v1.placeholder(tf.int64, (None, ))},
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}),
        dict(testcase_name='fixed_len_string',
             make_tensors_fn=lambda:
             {'x': tf.compat.v1.placeholder(tf.string, (None, ))},
             feature_spec={'x': tf.io.FixedLenFeature([], tf.string)}),
        dict(testcase_name='fixed_len_float',
             make_tensors_fn=lambda:
             {'x': tf.compat.v1.placeholder(tf.float32, (None, ))},
             feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}),
        dict(testcase_name='override',
             make_tensors_fn=_make_tensors_with_override,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)},
             domains={'x': schema_pb2.IntDomain(is_categorical=True)}),
        dict(testcase_name='override_with_session',
             make_tensors_fn=_make_tensors_with_override,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)},
             domains={
                 'x': schema_pb2.IntDomain(min=5, max=6, is_categorical=True)
             },
             create_session=True))
    # pylint: enable=g-long-lambda
    def test_infer_feature_schema(self,
                                  make_tensors_fn,
                                  feature_spec,
                                  domains=None,
                                  create_session=False):
        with tf.Graph().as_default() as graph:
            tensors = make_tensors_fn()

        if create_session:
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    tensors, graph, session)
        else:
            schema = schema_inference.infer_feature_schema(tensors, graph)

        expected_schema = schema_utils.schema_from_feature_spec(
            feature_spec, domains)
        self.assertEqual(schema, expected_schema)

    def test_infer_feature_schema_bad_rank(self):
        with tf.Graph().as_default() as graph:
            tensors = {
                'a': tf.compat.v1.placeholder(tf.float32, ()),
            }
        with self.assertRaises(ValueError):
            schema_inference.infer_feature_schema(tensors, graph)

    def test_bucketization_annotation(self):
        # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
        # pylint: disable=g-import-not-at-top
        try:
            from tensorflow_transform import annotations_pb2
        except ImportError:
            return
        # pylint: enable=g-import-not-at-top
        with tf.Graph().as_default() as graph:
            inputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3]),
                'bar': tf.convert_to_tensor([0, 2, 0, 2]),
            }
            boundaries_foo = tf.expand_dims(tf.convert_to_tensor([.5, 1.5]),
                                            axis=0)
            boundaries_bar = tf.expand_dims(tf.convert_to_tensor([.1, .2]),
                                            axis=0)
            outputs = {}

            # tft.apply_buckets will annotate the feature in the output schema to
            # indicate the bucket boundaries that were applied.
            outputs['Bucketized_foo'] = mappers.apply_buckets(
                inputs['foo'], boundaries_foo)
            outputs['Bucketized_bar'] = mappers.apply_buckets(
                inputs['bar'], boundaries_bar)
            # Create a session to actually evaluate the annotations and extract the
            # output schema with annotations applied.
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema.feature, 2)
                for feature in schema.feature:
                    self.assertLen(feature.annotation.extra_metadata, 1)
                    for annotation in feature.annotation.extra_metadata:

                        # Extract the annotated message and validate its contents
                        message = annotations_pb2.BucketBoundaries()
                        annotation.Unpack(message)
                        if feature.name == 'Bucketized_foo':
                            self.assertAllClose(list(message.boundaries),
                                                [.5, 1.5])
                        elif feature.name == 'Bucketized_bar':
                            self.assertAllClose(list(message.boundaries),
                                                [.1, .2])
                        else:
                            raise RuntimeError('Unexpected features in schema')

    def test_global_annotation(self):
        # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
        # pylint: disable=g-import-not-at-top
        try:
            from tensorflow_transform import annotations_pb2
        except ImportError:
            return
        # pylint: enable=g-import-not-at-top
        with tf.Graph().as_default() as graph:
            outputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

            # Annotate an arbitrary proto at the schema level (not sure what global
            # schema boundaries would mean, but hey I'm just a test).
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = tf.raw_ops.EncodeProto(
                sizes=sizes,
                values=[tf.cast(boundaries, tf.float32)],
                field_names=['boundaries'],
                message_type=message_type)[0]
            type_url = os.path.join('type.googleapis.com', message_type)
            schema_inference.annotate(type_url, message_proto)

            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema.annotation.extra_metadata, 1)
                for annotation in schema.annotation.extra_metadata:
                    # Extract the annotated message and validate its contents
                    message = annotations_pb2.BucketBoundaries()
                    annotation.Unpack(message)
                    self.assertAllClose(list(message.boundaries), [1])
def _infer_feature_schema_common(
        features: Mapping[str, common_types.TensorType],
        tensor_ranges: Mapping[str, Tuple[int, int]],
        feature_annotations: Mapping[str, List[any_pb2.Any]],
        global_annotations: List[any_pb2.Any],
        is_evaluation_complete: bool) -> schema_pb2.Schema:
    """Given a dict of tensors, creates a `Schema`.

  Args:
    features: A dict mapping column names to `Tensor`, `SparseTensor` or
      `RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have
      a 0'th dimension which is interpreted as the batch dimension.
    tensor_ranges: A dict mapping a tensor to a tuple containing its min and max
      value.
    feature_annotations: dictionary from feature name to list of any_pb2.Any
      protos to be added as an annotation for that feature in the schema.
    global_annotations: list of any_pb2.Any protos to be added at the global
      schema level.
    is_evaluation_complete: A boolean indicating whether all analyzers have been
      evaluated or not.

  Returns:
    A `Schema` proto.
  """
    domains = {}
    feature_tags = collections.defaultdict(list)
    for name, tensor in features.items():
        if (isinstance(tensor, tf.RaggedTensor)
                and not common_types.is_ragged_feature_available()):
            # Add the 'ragged_tensor' tag which will cause coder and
            # schema_as_feature_spec to raise an error, as there is no feature spec
            # for ragged tensors in TF 1.x.
            feature_tags[name].append(schema_utils.RAGGED_TENSOR_TAG)
        if name in tensor_ranges:
            min_value, max_value = tensor_ranges[name]
            domains[name] = schema_pb2.IntDomain(min=min_value,
                                                 max=max_value,
                                                 is_categorical=True)
    feature_spec = _feature_spec_from_batched_tensors(features,
                                                      is_evaluation_complete)

    schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains)

    # Add the annotations to the schema.
    for annotation in global_annotations:
        schema_proto.annotation.extra_metadata.add().CopyFrom(annotation)
    # Build a map from logical feature names to Feature protos
    feature_protos_by_name = {}
    for feature in schema_proto.feature:
        feature_protos_by_name[feature.name] = feature
    for sparse_feature in schema_proto.sparse_feature:
        for index_feature in sparse_feature.index_feature:
            feature_protos_by_name.pop(index_feature.name)
        value_feature = feature_protos_by_name.pop(
            sparse_feature.value_feature.name)
        feature_protos_by_name[sparse_feature.name] = value_feature

    # Handle ragged tensor representations.
    tensor_representations = (
        tensor_representation_util.GetTensorRepresentationsFromSchema(
            schema_proto, schema_utils.TENSOR_REPRESENTATION_GROUP))
    if tensor_representations is not None:
        for name, tensor_representation in tensor_representations.items():
            feature_protos_by_name[
                name] = schema_utils.pop_ragged_source_columns(
                    name, tensor_representation, feature_protos_by_name)

    # Update annotations
    for feature_name, annotations in feature_annotations.items():
        feature_proto = feature_protos_by_name[feature_name]
        for annotation in annotations:
            feature_proto.annotation.extra_metadata.add().CopyFrom(annotation)
    for feature_name, tags in feature_tags.items():
        feature_proto = feature_protos_by_name[feature_name]
        for tag in tags:
            feature_proto.annotation.tag.append(tag)
    return schema_proto
def infer_feature_schema(features, graph, session=None):
    """Given a dict of tensors, creates a `Schema`.

  Infers a schema, in the format of a tf.Transform `Schema`, for the given
  dictionary of tensors.

  If there is an override specified, we override the inferred schema for the
  given feature's tensor.  An override has the meaning that we should set
  is_categorical=True.  If session is not provided then we just set
  is_categorical=True, and if the session is provided then we also compute the
  values of the tensors representing the min and max values and set them in the
  schema.

  If annotations have been specified, they are added to the output schema.

  Args:
    features: A dict mapping column names to `Tensor` or `SparseTensor`s. The
      `Tensor` or `SparseTensor`s should have a 0'th dimension which is
      interpreted as the batch dimension.
    graph: A `tf.Graph` used to determine schema overrides.
    session: (optional) A `tf.Session` used to compute schema overrides.  If
      None, schema overrides will not be computed.

  Returns:
    A `Schema` proto.
  """
    tensor_ranges = _get_tensor_schema_overrides(graph)
    if session is None:
        tensor_ranges = {hashable: (None, None) for hashable in tensor_ranges}
        tensor_annotations = {}
        global_annotations = []
    else:
        tensor_ranges = session.run(tensor_ranges)
        tensor_annotations, global_annotations = _get_schema_annotations(
            graph, session)

    domains = {}
    feature_annotations = {}
    feature_tags = collections.defaultdict(list)
    for name, tensor in six.iteritems(features):
        if isinstance(tensor, tf.SparseTensor):
            values = tensor.values
        elif isinstance(tensor, tf.RaggedTensor):
            values = tensor.flat_values
            # Add the 'ragged_tensor' tag which will cause coder and
            # schema_as_feature_spec to raise an error, as currently there is no
            # feature spec for ragged tensors.
            feature_tags[name].append(schema_utils.RAGGED_TENSOR_TAG)
        else:
            values = tensor
        hashable_values = tf_utils.hashable_tensor_or_op(values)
        if hashable_values in tensor_ranges:
            assert values.dtype == tf.int64
            min_value, max_value = tensor_ranges[hashable_values]
            domains[name] = schema_pb2.IntDomain(min=min_value,
                                                 max=max_value,
                                                 is_categorical=True)
        # tensor_annotations is a defaultdict(list) so always returns a list.
        feature_annotations[name] = tensor_annotations.get(hashable_values, [])
    feature_spec = _feature_spec_from_batched_tensors(features)

    schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains)

    # Add the annotations to the schema.
    for annotation in global_annotations:
        schema_proto.annotation.extra_metadata.add().CopyFrom(annotation)
    # Build a map from logical feature names to Feature protos
    feature_protos_by_name = {}
    for feature in schema_proto.feature:
        feature_protos_by_name[feature.name] = feature
    for sparse_feature in schema_proto.sparse_feature:
        for index_feature in sparse_feature.index_feature:
            feature_protos_by_name.pop(index_feature.name)
        value_feature = feature_protos_by_name.pop(
            sparse_feature.value_feature.name)
        feature_protos_by_name[sparse_feature.name] = value_feature
    # Update annotations
    for feature_name, annotations in feature_annotations.items():
        feature_proto = feature_protos_by_name[feature_name]
        for annotation in annotations:
            feature_proto.annotation.extra_metadata.add().CopyFrom(annotation)
    for feature_name, tags in feature_tags.items():
        feature_proto = feature_protos_by_name[feature_name]
        for tag in tags:
            feature_proto.annotation.tag.append(tag)
    return schema_proto
      input_metadata=tft.DatasetMetadata.from_feature_spec({
          'x':
          tf.io.FixedLenFeature([], tf.float32),
          'key':
          tf.io.FixedLenFeature([], tf.string)
      }),
      expected_data=[{
          'x_bucketized':
          _compute_simple_per_key_bucket(x, 'a' if x < 50 else 'b')
      } for x in range(1, 100)],
      expected_metadata=tft.DatasetMetadata.from_feature_spec(
          {
              'x_bucketized': tf.io.FixedLenFeature([], tf.int64),
          }, {
              'x_bucketized':
              schema_pb2.IntDomain(min=0, max=2, is_categorical=True),
          })),
 dict(testcase_name='sparse',
      input_data=[{
          'x': [x],
          'idx0': [0],
          'idx1': [0],
          'key': ['a'] if x < 50 else ['b']
      } for x in range(1, 100)],
      input_metadata=tft.DatasetMetadata.from_feature_spec({
          'x':
          tf.io.SparseFeature(['idx0', 'idx1'], 'x', tf.float32, (2, 2)),
          'key':
          tf.io.VarLenFeature(tf.string)
      }),
      expected_data=[{
class SchemaInferenceTest(test_case.TransformTestCase):
    def _get_schema(self,
                    preprocessing_fn,
                    use_compat_v1,
                    inputs=None,
                    input_signature=None,
                    create_session=False):
        if inputs is None:
            inputs = {}
        if input_signature is None:
            input_signature = {}
        if use_compat_v1:
            with tf.compat.v1.Graph().as_default() as graph:
                # Convert eager tensors to graph tensors.
                inputs_copy = {
                    k: tf.constant(v, input_signature[k].dtype)
                    for k, v in inputs.items()
                }
                tensors = preprocessing_fn(inputs_copy)
                if create_session:
                    # Create a session to actually evaluate the annotations and extract
                    # the output schema with annotations applied.
                    with tf.compat.v1.Session(graph=graph) as session:
                        schema = schema_inference.infer_feature_schema(
                            tensors, graph, session)
                else:
                    schema = schema_inference.infer_feature_schema(
                        tensors, graph)
        else:
            tf_func = tf.function(preprocessing_fn,
                                  input_signature=[input_signature
                                                   ]).get_concrete_function()
            tensors = tf.nest.pack_sequence_as(
                structure=tf_func.structured_outputs,
                flat_sequence=tf_func.outputs,
                expand_composites=True)
            metadata_fn = schema_inference.get_traced_metadata_fn(
                tensor_replacement_map={},
                preprocessing_fn=preprocessing_fn,
                input_signature=input_signature,
                base_temp_dir=os.path.join(self.get_temp_dir(),
                                           self._testMethodName),
                evaluate_schema_overrides=create_session)
            schema = schema_inference.infer_feature_schema_v2(
                tensors,
                metadata_fn.get_concrete_function(),
                evaluate_schema_overrides=create_session)
        return schema

    # pylint: disable=g-long-lambda
    @test_case.named_parameters(*test_case.cross_named_parameters([
        dict(testcase_name='fixed_len_int',
             make_tensors_fn=_make_tensors,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}),
        dict(testcase_name='fixed_len_string',
             make_tensors_fn=_make_tensors,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.string)}),
        dict(testcase_name='fixed_len_float',
             make_tensors_fn=_make_tensors,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}),
        dict(testcase_name='override',
             make_tensors_fn=_make_tensors_with_override,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)},
             domains={'x': schema_pb2.IntDomain(is_categorical=True)}),
        dict(testcase_name='override_with_session',
             make_tensors_fn=_make_tensors_with_override,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)},
             domains={
                 'x': schema_pb2.IntDomain(min=5, max=6, is_categorical=True)
             },
             create_session=True)
    ], [
        dict(testcase_name='compat_v1', use_compat_v1=True),
        dict(testcase_name='v2', use_compat_v1=False)
    ]))
    # pylint: enable=g-long-lambda
    def test_infer_feature_schema(self,
                                  make_tensors_fn,
                                  feature_spec,
                                  use_compat_v1,
                                  domains=None,
                                  create_session=False):
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
        x_val = '0' if feature_spec['x'].dtype == tf.string else 0
        inputs = {'x': [x_val]}
        input_signature = {
            'x': tf.TensorSpec([None], dtype=feature_spec['x'].dtype)
        }
        schema = self._get_schema(make_tensors_fn,
                                  use_compat_v1,
                                  inputs=inputs,
                                  input_signature=input_signature,
                                  create_session=create_session)
        expected_schema = schema_utils.schema_from_feature_spec(
            feature_spec, domains)
        self.assertEqual(schema, expected_schema)

    @test_case.named_parameters(
        dict(testcase_name='compat_v1', use_compat_v1=True),
        dict(testcase_name='v2', use_compat_v1=False))
    def test_infer_feature_schema_bad_rank(self, use_compat_v1):
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
        inputs = {'x': 0}
        input_signature = {'x': tf.TensorSpec([], dtype=tf.float32)}
        with self.assertRaises(ValueError):
            self._get_schema(_make_tensors,
                             use_compat_v1,
                             inputs=inputs,
                             input_signature=input_signature)

    @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE,
                     'Schema annotations are not available')
    @test_case.named_parameters(
        dict(testcase_name='compat_v1', use_compat_v1=True),
        dict(testcase_name='v2', use_compat_v1=False))
    def test_vocab_annotation(self, use_compat_v1):
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            analyzers._maybe_annotate_vocab_metadata(
                'file1', tf.constant(100, dtype=tf.int64))
            analyzers._maybe_annotate_vocab_metadata(
                'file2', tf.constant(200, dtype=tf.int64))
            return {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
            }

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        self.assertLen(schema.annotation.extra_metadata, 2)
        sizes = {}
        for annotation in schema.annotation.extra_metadata:
            message = annotations_pb2.VocabularyMetadata()
            annotation.Unpack(message)
            sizes[message.file_name] = message.unfiltered_vocabulary_size
        self.assertDictEqual(sizes, {'file1': 100, 'file2': 200})

    @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE,
                     'Schema annotations are not available')
    @test_case.named_parameters(
        dict(testcase_name='compat_v1', use_compat_v1=True),
        dict(testcase_name='v2', use_compat_v1=False))
    def test_bucketization_annotation(self, use_compat_v1):
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            inputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3]),
                'bar': tf.convert_to_tensor([0, 2, 0, 2]),
            }
            boundaries_foo = tf.expand_dims(tf.convert_to_tensor([.5, 1.5]),
                                            axis=0)
            boundaries_bar = tf.expand_dims(tf.convert_to_tensor([.1, .2]),
                                            axis=0)
            outputs = {}
            # tft.apply_buckets will annotate the feature in the output schema to
            # indicate the bucket boundaries that were applied.
            outputs['Bucketized_foo'] = mappers.apply_buckets(
                inputs['foo'], boundaries_foo)
            outputs['Bucketized_bar'] = mappers.apply_buckets(
                inputs['bar'], boundaries_bar)
            return outputs

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        self.assertLen(schema.feature, 2)
        for feature in schema.feature:
            self.assertLen(feature.annotation.extra_metadata, 1)
            for annotation in feature.annotation.extra_metadata:

                # Extract the annotated message and validate its contents
                message = annotations_pb2.BucketBoundaries()
                annotation.Unpack(message)
                if feature.name == 'Bucketized_foo':
                    self.assertAllClose(list(message.boundaries), [.5, 1.5])
                elif feature.name == 'Bucketized_bar':
                    self.assertAllClose(list(message.boundaries), [.1, .2])
                else:
                    raise RuntimeError('Unexpected features in schema')

    @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE,
                     'Schema annotations are not available')
    @test_case.named_parameters(
        dict(testcase_name='compat_v1', use_compat_v1=True),
        dict(testcase_name='v2', use_compat_v1=False))
    def test_global_annotation(self, use_compat_v1):
        # pylint: enable=g-import-not-at-top
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            # Annotate an arbitrary proto at the schema level (not sure what global
            # schema boundaries would mean, but hey I'm just a test).
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = tf.raw_ops.EncodeProto(
                sizes=sizes,
                values=[tf.cast(boundaries, tf.float32)],
                field_names=['boundaries'],
                message_type=message_type)[0]
            type_url = os.path.join('type.googleapis.com', message_type)
            schema_inference.annotate(type_url, message_proto)
            return {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        self.assertLen(schema.annotation.extra_metadata, 1)
        for annotation in schema.annotation.extra_metadata:
            # Extract the annotated message and validate its contents
            message = annotations_pb2.BucketBoundaries()
            annotation.Unpack(message)
            self.assertAllClose(list(message.boundaries), [1])

    @test_case.named_parameters(
        dict(testcase_name='compat_v1', use_compat_v1=True),
        dict(testcase_name='v2', use_compat_v1=False))
    def test_infer_feature_schema_with_ragged_tensor(self, use_compat_v1):
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            return {
                'foo':
                tf.RaggedTensor.from_row_splits(values=tf.constant(
                    [3, 1, 4, 1, 5, 9, 2, 6], tf.int64),
                                                row_splits=[0, 4, 4, 7, 8, 8]),
            }

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        expected_schema_ascii = """feature {
name: "foo"
type: INT
annotation {
tag: "ragged_tensor"
}
}
"""
        expected_schema = text_format.Parse(expected_schema_ascii,
                                            schema_pb2.Schema())
        schema_utils_legacy.set_generate_legacy_feature_spec(
            expected_schema, False)
        self.assertProtoEquals(expected_schema, schema)
        with self.assertRaisesRegexp(ValueError,
                                     'Feature "foo" had tag "ragged_tensor"'):
            schema_utils.schema_as_feature_spec(schema)
    def testBucketizePerKeyWithInfrequentKeys(self):
        def preprocessing_fn(inputs):
            x_bucketized = tft.bucketize_per_key(inputs['x'],
                                                 inputs['key'],
                                                 num_buckets=4,
                                                 epsilon=0.00001)
            return {'x': inputs['x'], 'x_bucketized': x_bucketized}

        input_data = [{
            'x': [],
            'key': []
        }, {
            'x': [5, 6],
            'key': ['a', 'a']
        }, {
            'x': [7],
            'key': ['a']
        }, {
            'x': [12],
            'key': ['b']
        }, {
            'x': [13],
            'key': ['b']
        }, {
            'x': [15],
            'key': ['c']
        }, {
            'x': [2],
            'key': ['d']
        }, {
            'x': [4],
            'key': ['d']
        }, {
            'x': [6],
            'key': ['d']
        }, {
            'x': [8],
            'key': ['d']
        }, {
            'x': [2],
            'key': ['e']
        }, {
            'x': [4],
            'key': ['e']
        }, {
            'x': [6],
            'key': ['e']
        }, {
            'x': [8],
            'key': ['e']
        }, {
            'x': [10],
            'key': ['e']
        }, {
            'x': [11],
            'key': ['e']
        }, {
            'x': [12],
            'key': ['e']
        }, {
            'x': [13],
            'key': ['e']
        }]  # pyformat: disable
        input_metadata = tft_unit.metadata_from_feature_spec({
            'x':
            tf.io.VarLenFeature(tf.float32),
            'key':
            tf.io.VarLenFeature(tf.string)
        })
        expected_data = [{
            'x': [],
            'x_bucketized': []
        }, {
            'x': [5, 6],
            'x_bucketized': [1, 2]
        }, {
            'x': [7],
            'x_bucketized': [3]
        }, {
            'x': [12],
            'x_bucketized': [1]
        }, {
            'x': [13],
            'x_bucketized': [3]
        }, {
            'x': [15],
            'x_bucketized': [1]
        }, {
            'x': [2],
            'x_bucketized': [0]
        }, {
            'x': [4],
            'x_bucketized': [1]
        }, {
            'x': [6],
            'x_bucketized': [2]
        }, {
            'x': [8],
            'x_bucketized': [3]
        }, {
            'x': [2],
            'x_bucketized': [0]
        }, {
            'x': [4],
            'x_bucketized': [0]
        }, {
            'x': [6],
            'x_bucketized': [1]
        }, {
            'x': [8],
            'x_bucketized': [1]
        }, {
            'x': [10],
            'x_bucketized': [2]
        }, {
            'x': [11],
            'x_bucketized': [2]
        }, {
            'x': [12],
            'x_bucketized': [3]
        }, {
            'x': [13],
            'x_bucketized': [2]
        }]  # pyformat: disable
        expected_metadata = tft_unit.metadata_from_feature_spec(
            {
                'x': tf.io.VarLenFeature(tf.float32),
                'x_bucketized': tf.io.VarLenFeature(tf.int64),
            }, {
                'x_bucketized':
                schema_pb2.IntDomain(min=0, max=3, is_categorical=True),
            })
        self.assertAnalyzeAndTransformResults(input_data,
                                              input_metadata,
                                              preprocessing_fn,
                                              expected_data,
                                              expected_metadata,
                                              desired_batch_size=10)
    def testBucketization(self, test_inputs, expected_boundaries, do_shuffle,
                          epsilon, should_apply, is_manual_boundaries,
                          input_dtype):
        test_inputs = list(test_inputs)

        # Shuffle the input to add randomness to input generated with
        # simple range().
        if do_shuffle:
            random.shuffle(test_inputs)

        def preprocessing_fn(inputs):
            x = tf.cast(inputs['x'], input_dtype)
            num_buckets = len(expected_boundaries) + 1
            if should_apply:
                if is_manual_boundaries:
                    bucket_boundaries = expected_boundaries
                else:
                    bucket_boundaries = tft.quantiles(inputs['x'], num_buckets,
                                                      epsilon)
                result = tft.apply_buckets(x, bucket_boundaries)
            else:
                result = tft.bucketize(x,
                                       num_buckets=num_buckets,
                                       epsilon=epsilon)
            return {'q_b': result}

        input_data = [{'x': [x]} for x in test_inputs]

        input_metadata = tft_unit.metadata_from_feature_spec({
            'x':
            tf.io.FixedLenFeature(
                [1], tft_unit.canonical_numeric_dtype(input_dtype))
        })

        # Sort the input based on value, index is used to create expected_data.
        indexed_input = enumerate(test_inputs)

        sorted_list = sorted(indexed_input, key=lambda p: p[1])

        # Expected data has the same size as input, one bucket per input value.
        expected_data = [None] * len(test_inputs)
        bucket = 0
        for (index, x) in sorted_list:
            # Increment the bucket number when crossing the boundary
            if (bucket < len(expected_boundaries)
                    and x >= expected_boundaries[bucket]):
                bucket += 1
            expected_data[index] = {'q_b': [bucket]}

        expected_metadata = tft_unit.metadata_from_feature_spec(
            {
                'q_b': tf.io.FixedLenFeature([1], tf.int64),
            }, {
                'q_b':
                schema_pb2.IntDomain(
                    min=0, max=len(expected_boundaries), is_categorical=True),
            })

        @contextlib.contextmanager
        def no_assert():
            yield None

        assertion = no_assert()
        if input_dtype == tf.float16:
            assertion = self.assertRaisesRegexp(
                TypeError,
                '.*DataType float16 not in list of allowed values.*')

        with assertion:
            self.assertAnalyzeAndTransformResults(
                input_data,
                input_metadata,
                preprocessing_fn,
                expected_data,
                expected_metadata,
                desired_batch_size=1000,
                # TODO(b/110855155): Remove this explicit use of DirectRunner.
                beam_pipeline=beam.Pipeline())
Example #18
"""Test metadata for tft_beam_io tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

from tensorflow_metadata.proto.v0 import schema_pb2

_FEATURE_SPEC = {
    'fixed_column': tf.io.FixedLenFeature([3], tf.string),
    'list_columm': tf.io.VarLenFeature(tf.int64),
}

COMPLETE_METADATA = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(
        _FEATURE_SPEC,
        domains={'list_columm': schema_pb2.IntDomain(min=-1, max=5)}))

INCOMPLETE_METADATA = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(
        _FEATURE_SPEC,
        # Values will be overridden by those in COMPLETE_METADATA
        domains={'list_columm': schema_pb2.IntDomain(min=0, max=0)}))
Example #19
    def testBucketizePerKeySparse(self):
        def preprocessing_fn(inputs):
            x_bucketized = tft.bucketize_per_key(inputs['x'],
                                                 inputs['key'],
                                                 num_buckets=3,
                                                 epsilon=0.00001)
            return {'x_bucketized': x_bucketized}

        # NOTE: We force 10 batches: data has 100 elements and we request a batch
        # size of 10.
        input_data = [{
            'x': [x],
            'idx0': [0],
            'idx1': [0],
            'key': ['a'] if x < 50 else ['b']
        } for x in range(1, 100)]
        input_metadata = tft_unit.metadata_from_feature_spec({
            'x':
            tf.io.SparseFeature(['idx0', 'idx1'], 'x', tf.float32, (2, 2)),
            'key':
            tf.io.VarLenFeature(tf.string)
        })

        def compute_bucket(instance):
            if instance['key'][0] == 'a':
                if instance['x'][0] < 17:
                    return 0
                elif instance['x'][0] < 33:
                    return 1
                else:
                    return 2
            else:
                if instance['x'][0] < 66:
                    return 0
                elif instance['x'][0] < 83:
                    return 1
                else:
                    return 2

        expected_data = [{
            'x_bucketized$sparse_values': [compute_bucket(instance)],
            'x_bucketized$sparse_indices_0': [0],
            'x_bucketized$sparse_indices_1': [0],
        } for instance in input_data]
        expected_metadata = tft_unit.metadata_from_feature_spec(
            {
                'x_bucketized':
                tf.io.SparseFeature([
                    'x_bucketized$sparse_indices_0',
                    'x_bucketized$sparse_indices_1'
                ],
                                    'x_bucketized$sparse_values',
                                    tf.int64, (None, None),
                                    already_sorted=True),
            }, {
                'x_bucketized':
                schema_pb2.IntDomain(min=0, max=2, is_categorical=True),
            })
        self.assertAnalyzeAndTransformResults(input_data,
                                              input_metadata,
                                              preprocessing_fn,
                                              expected_data,
                                              expected_metadata,
                                              desired_batch_size=10)
Example #20
from absl.testing import parameterized
from tensorflow_data_validation.utils import schema_util
from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2

FLAGS = flags.FLAGS

SET_DOMAIN_VALID_TESTS = [{
    'testcase_name':
    'int_domain',
    'input_schema_proto_text':
    '''feature { name: 'x' }''',
    'feature_name':
    'x',
    'domain':
    schema_pb2.IntDomain(min=1, max=5),
    'output_schema_proto_text':
    '''
          feature { name: 'x' int_domain { min: 1 max: 5 } }'''
}, {
    'testcase_name':
    'float_domain',
    'input_schema_proto_text':
    '''feature { name: 'x' }''',
    'feature_name':
    'x',
    'domain':
    schema_pb2.FloatDomain(min=1.1, max=5.1),
    'output_schema_proto_text':
    '''
          feature { name: 'x' float_domain { min: 1.1 max: 5.1 } }'''
Example #21
 def test_set_domain_invalid_schema(self):
     with self.assertRaisesRegexp(TypeError, 'should be a Schema proto'):
         schema_util.set_domain({}, 'feature', schema_pb2.IntDomain())
     }
 },
 # Test domains
 {
     'testcase_name': 'int_domain',
     'ascii_proto': """
       feature: {
         name: "x" type: INT
         int_domain {min: 0 max: 5 is_categorical: true}
       }
     """,
     'feature_spec': {
         'x': tf.io.VarLenFeature(tf.int64)
     },
     'domains': {
         'x': schema_pb2.IntDomain(min=0, max=5, is_categorical=True)
     }
 },
 {
     'testcase_name': 'string_domain',
     'ascii_proto': """
       feature: {
         name: "x" type: BYTES
         string_domain {value: "a" value: "b"}
       }
     """,
     'feature_spec': {
         'x': tf.io.VarLenFeature(tf.string)
     },
     'domains': {
         'x': schema_pb2.StringDomain(value=['a', 'b'])
class SchemaInferenceTest(test_case.TransformTestCase):

    # pylint: disable=g-long-lambda
    @test_case.named_parameters(
        dict(testcase_name='fixed_len_int',
             make_tensors_fn=lambda:
             {'x': tf.compat.v1.placeholder(tf.int64, (None, ))},
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}),
        dict(testcase_name='fixed_len_string',
             make_tensors_fn=lambda:
             {'x': tf.compat.v1.placeholder(tf.string, (None, ))},
             feature_spec={'x': tf.io.FixedLenFeature([], tf.string)}),
        dict(testcase_name='fixed_len_float',
             make_tensors_fn=lambda:
             {'x': tf.compat.v1.placeholder(tf.float32, (None, ))},
             feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}),
        dict(testcase_name='override',
             make_tensors_fn=_make_tensors_with_override,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)},
             domains={'x': schema_pb2.IntDomain(is_categorical=True)}),
        dict(testcase_name='override_with_session',
             make_tensors_fn=_make_tensors_with_override,
             feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)},
             domains={
                 'x': schema_pb2.IntDomain(min=5, max=6, is_categorical=True)
             },
             create_session=True))
    # pylint: enable=g-long-lambda
    def test_infer_feature_schema(self,
                                  make_tensors_fn,
                                  feature_spec,
                                  domains=None,
                                  create_session=False):
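        # When create_session is True, infer_feature_schema is given a session
        # so that tensor-valued domain overrides (e.g. the min/max in the
        # 'override_with_session' case above) can be evaluated.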
        with tf.compat.v1.Graph().as_default() as graph:
            tensors = make_tensors_fn()

        if create_session:
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    tensors, graph, session)
        else:
            schema = schema_inference.infer_feature_schema(tensors, graph)

        expected_schema = schema_utils.schema_from_feature_spec(
            feature_spec, domains)
        self.assertEqual(schema, expected_schema)

    def test_infer_feature_schema_bad_rank(self):
        with tf.compat.v1.Graph().as_default() as graph:
            tensors = {
                'a': tf.compat.v1.placeholder(tf.float32, ()),
            }
        with self.assertRaises(ValueError):
            schema_inference.infer_feature_schema(tensors, graph)

    @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE,
                     'Schema annotations are not available')
    def test_vocab_annotation(self):
        with tf.compat.v1.Graph().as_default() as graph:
            tensors = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
            }
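            # Record vocabulary metadata (file name and unfiltered vocabulary
            # size) for two vocabulary files; infer_feature_schema should
            # surface these as schema-level annotations in extra_metadata.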
            analyzers._maybe_annotate_vocab_metadata(
                'file1', tf.constant(100, dtype=tf.int64))
            analyzers._maybe_annotate_vocab_metadata(
                'file2', tf.constant(200, dtype=tf.int64))
            # Create a session to actually evaluate the annotations and extract
            # the output schema with annotations applied.
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    tensors, graph, session)
                self.assertLen(schema.annotation.extra_metadata, 2)
                sizes = {}
                for annotation in schema.annotation.extra_metadata:
                    message = annotations_pb2.VocabularyMetadata()
                    annotation.Unpack(message)
                    sizes[
                        message.file_name] = message.unfiltered_vocabulary_size
                self.assertDictEqual(sizes, {'file1': 100, 'file2': 200})

    @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE,
                     'Schema annotations are not available')
    def test_bucketization_annotation(self):
        with tf.compat.v1.Graph().as_default() as graph:
            inputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3]),
                'bar': tf.convert_to_tensor([0, 2, 0, 2]),
            }
            boundaries_foo = tf.expand_dims(tf.convert_to_tensor([.5, 1.5]),
                                            axis=0)
            boundaries_bar = tf.expand_dims(tf.convert_to_tensor([.1, .2]),
                                            axis=0)
            outputs = {}

            # tft.apply_buckets will annotate the feature in the output schema to
            # indicate the bucket boundaries that were applied.
            outputs['Bucketized_foo'] = mappers.apply_buckets(
                inputs['foo'], boundaries_foo)
            outputs['Bucketized_bar'] = mappers.apply_buckets(
                inputs['bar'], boundaries_bar)
            # Create a session to actually evaluate the annotations and extract
            # the output schema with annotations applied.
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema.feature, 2)
                for feature in schema.feature:
                    self.assertLen(feature.annotation.extra_metadata, 1)
                    for annotation in feature.annotation.extra_metadata:

                        # Extract the annotated message and validate its contents
                        message = annotations_pb2.BucketBoundaries()
                        annotation.Unpack(message)
                        if feature.name == 'Bucketized_foo':
                            self.assertAllClose(list(message.boundaries),
                                                [.5, 1.5])
                        elif feature.name == 'Bucketized_bar':
                            self.assertAllClose(list(message.boundaries),
                                                [.1, .2])
                        else:
                            raise RuntimeError('Unexpected features in schema')

    @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE,
                     'Schema annotations are not available')
    def test_global_annotation(self):
        with tf.compat.v1.Graph().as_default() as graph:
            outputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

            # Annotate an arbitrary proto at the schema (global) level.
            # BucketBoundaries has no real meaning as a schema-wide annotation;
            # it is simply a convenient message to exercise annotations with.
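            # Serialize a BucketBoundaries message with tf.raw_ops.EncodeProto
            # and register it via schema_inference.annotate; the type_url uses
            # the standard "type.googleapis.com/<full message name>" form.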
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = tf.raw_ops.EncodeProto(
                sizes=sizes,
                values=[tf.cast(boundaries, tf.float32)],
                field_names=['boundaries'],
                message_type=message_type)[0]
            type_url = os.path.join('type.googleapis.com', message_type)
            schema_inference.annotate(type_url, message_proto)

            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema.annotation.extra_metadata, 1)
                for annotation in schema.annotation.extra_metadata:
                    # Extract the annotated message and validate its contents
                    message = annotations_pb2.BucketBoundaries()
                    annotation.Unpack(message)
                    self.assertAllClose(list(message.boundaries), [1])

    def test_infer_feature_schema_with_ragged_tensor(self):
        with tf.compat.v1.Graph().as_default() as graph:
            outputs = {
                'foo':
                tf.RaggedTensor.from_row_splits(values=tf.constant(
                    [3, 1, 4, 1, 5, 9, 2, 6], tf.int64),
                                                row_splits=[0, 4, 4, 7, 8, 8]),
            }
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                expected_schema_ascii = """feature {
  name: "foo"
  type: INT
  annotation {
    tag: "ragged_tensor"
  }
}
"""
                expected_schema = text_format.Parse(expected_schema_ascii,
                                                    schema_pb2.Schema())
                schema_utils_legacy.set_generate_legacy_feature_spec(
                    expected_schema, False)
                self.assertProtoEquals(expected_schema, schema)
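                # A schema carrying the "ragged_tensor" tag cannot be converted
                # back into a feature spec, so schema_as_feature_spec should
                # raise.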
                with self.assertRaisesRegexp(
                        ValueError, 'Feature "foo" had tag "ragged_tensor"'):
                    schema_utils.schema_as_feature_spec(schema)