Example #1
def TensorRepresentations(self) -> tensor_adapter.TensorRepresentations:
    result = tensor_rep_util.GetTensorRepresentationsFromSchema(
        self._schema)
    if result is None:
        result = (tensor_rep_util.InferTensorRepresentationsFromSchema(
            self._schema))
    return result
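
A minimal standalone sketch of the same get-or-infer fallback, not taken from the source; it assumes tensor_representation_util lives under tfx_bsl.tfxio (as in recent tfx_bsl releases) and uses an illustrative schema with legacy feature-spec generation disabled:

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.tfxio import tensor_representation_util

schema = text_format.Parse("""
  generate_legacy_feature_spec: false
  feature { name: "x" type: INT }
""", schema_pb2.Schema())

# The schema carries no tensor_representation_group, so the explicit lookup
# returns None and we fall back to inference, mirroring the method above.
representations = (
    tensor_representation_util.GetTensorRepresentationsFromSchema(schema))
if representations is None:
  representations = (
      tensor_representation_util.InferTensorRepresentationsFromSchema(schema))
print(representations['x'])  # expected: a varlen_sparse_tensor representation for "x"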
Example #2
def TensorRepresentations(self) -> tensor_adapter.TensorRepresentations:
  result = (
      tensor_representation_util.GetTensorRepresentationsFromSchema(
          self._schema))
  if result is None:
    raise ValueError("For SequenceExample, TensorRepresentations must be "
                     "specified in the schema.")
  return result
Example #3
def testGetTensorRepresentationsFromSchema(self):
  self.assertIsNone(
      tensor_representation_util.GetTensorRepresentationsFromSchema(
          schema_pb2.Schema()))
  schema = text_format.Parse("""
    tensor_representation_group {
      key: ""
      value {
        tensor_representation {
          key: "a"
          value { }
        }
      }
    }
  """, schema_pb2.Schema())
  result = tensor_representation_util.GetTensorRepresentationsFromSchema(
      schema)
  self.assertTrue(result)
  self.assertIn('a', result)
Example #4
def schema_as_feature_spec(
        schema_proto: schema_pb2.Schema) -> SchemaAsFeatureSpecResult:
    """Generates a feature spec from a Schema proto.

    For a Feature with a FixedShape we generate a FixedLenFeature with no default.
    For a Feature without a FixedShape we generate a VarLenFeature.  For a
    SparseFeature we generate a SparseFeature.

    Args:
      schema_proto: A Schema proto.

    Returns:
      A pair (feature spec, domains) where feature spec is a dict whose keys are
          feature names and values are instances of FixedLenFeature, VarLenFeature
          or SparseFeature, and `domains` is a dict whose keys are feature names
          and values are one of the `domain_info` oneof, e.g. IntDomain.

    Raises:
      ValueError: If the schema proto is invalid.
    """
    for feature in schema_proto.feature:
        if RAGGED_TENSOR_TAG in feature.annotation.tag:
            raise ValueError(
                'Feature "{}" had tag "{}".  Features represented by a '
                'RaggedTensor cannot be serialized/deserialized to Example proto or '
                'other formats, and cannot have a feature spec generated for '
                'them.'.format(feature.name, RAGGED_TENSOR_TAG))

    if schema_utils_legacy.get_generate_legacy_feature_spec(schema_proto):
        return _legacy_schema_as_feature_spec(schema_proto)
    feature_spec = {}
    # Will hold the domain_info (IntDomain, FloatDomain etc.) of the feature.  For
    # sparse features, will hold the domain_info of the values feature.  Features
    # that do not have a domain set will not be present in `domains`.
    domains = {}
    feature_by_name = {
        feature.name: feature
        for feature in schema_proto.feature
    }
    string_domains = _get_string_domains(schema_proto)

    # Generate a `tf.SparseFeature` for each element of
    # `schema_proto.sparse_feature`.  This also removes the features from
    # `feature_by_name`.
    # TODO(KesterTong): Allow sparse features to share index features.
    for feature in schema_proto.sparse_feature:
        if _include_in_parsing_spec(feature):
            feature_spec[feature.name], domains[feature.name] = (
                _sparse_feature_as_feature_spec(feature, feature_by_name,
                                                string_domains))

    # Handle ragged `TensorRepresentation`s.
    tensor_representations = (
        tensor_representation_util.GetTensorRepresentationsFromSchema(
            schema_proto, TENSOR_REPRESENTATION_GROUP))
    if tensor_representations is not None:
        for name, tensor_representation in tensor_representations.items():
            if name in feature_by_name:
                raise ValueError(
                    'Ragged TensorRepresentation name "{}" conflicts with a different '
                    'feature in the same schema.'.format(name))
            (feature_spec[name],
             domains[name]) = (_ragged_tensor_representation_as_feature_spec(
                 name, tensor_representation, feature_by_name, string_domains))

    # Generate a `tf.FixedLenFeature` or `tf.VarLenFeature` for each element of
    # `schema_proto.feature` that was not referenced by a `SparseFeature` or a
    # ragged `TensorRepresentation`.
    for name, feature in feature_by_name.items():
        if _include_in_parsing_spec(feature):
            feature_spec[name], domains[name] = _feature_as_feature_spec(
                feature, string_domains)

    schema_utils_legacy.check_for_unsupported_features(schema_proto)

    domains = {
        name: domain
        for name, domain in domains.items() if domain is not None
    }
    return SchemaAsFeatureSpecResult(feature_spec, domains)
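
A hedged round-trip sketch, not from the source, assuming the public tensorflow_transform.tf_metadata.schema_utils module exposes schema_from_feature_spec alongside schema_as_feature_spec; the feature names and IntDomain below are illustrative:

import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_transform.tf_metadata import schema_utils

feature_spec = {
    'label': tf.io.FixedLenFeature([], tf.int64),  # fixed shape -> FixedLenFeature
    'terms': tf.io.VarLenFeature(tf.string),       # no fixed shape -> VarLenFeature
}
domains = {'label': schema_pb2.IntDomain(min=0, max=1, is_categorical=True)}

schema = schema_utils.schema_from_feature_spec(feature_spec, domains)

# Reading the schema back yields a (feature spec, domains) pair; only features
# that actually carry a domain show up in the returned domains dict.
parsed_spec, parsed_domains = schema_utils.schema_as_feature_spec(schema)
print(parsed_spec['terms'])     # VarLenFeature(dtype=tf.string)
print(parsed_domains['label'])  # IntDomain with min: 0, max: 1, is_categorical: true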
Example #5
def _infer_feature_schema_common(
        features: Mapping[str, common_types.TensorType],
        tensor_ranges: Mapping[str, Tuple[int, int]],
        feature_annotations: Mapping[str, List[any_pb2.Any]],
        global_annotations: List[any_pb2.Any],
        is_evaluation_complete: bool) -> schema_pb2.Schema:
    """Given a dict of tensors, creates a `Schema`.

    Args:
      features: A dict mapping column names to `Tensor`, `SparseTensor` or
        `RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have
        a 0'th dimension which is interpreted as the batch dimension.
      tensor_ranges: A dict mapping a tensor to a tuple containing its min and max
        value.
      feature_annotations: dictionary from feature name to list of any_pb2.Any
        protos to be added as an annotation for that feature in the schema.
      global_annotations: list of any_pb2.Any protos to be added at the global
        schema level.
      is_evaluation_complete: A boolean indicating whether all analyzers have been
        evaluated or not.

    Returns:
      A `Schema` proto.
    """
    domains = {}
    feature_tags = collections.defaultdict(list)
    for name, tensor in features.items():
        if (isinstance(tensor, tf.RaggedTensor)
                and not common_types.is_ragged_feature_available()):
            # Add the 'ragged_tensor' tag which will cause coder and
            # schema_as_feature_spec to raise an error, as there is no feature spec
            # for ragged tensors in TF 1.x.
            feature_tags[name].append(schema_utils.RAGGED_TENSOR_TAG)
        if name in tensor_ranges:
            min_value, max_value = tensor_ranges[name]
            domains[name] = schema_pb2.IntDomain(min=min_value,
                                                 max=max_value,
                                                 is_categorical=True)
    feature_spec = _feature_spec_from_batched_tensors(features,
                                                      is_evaluation_complete)

    schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains)

    # Add the annotations to the schema.
    for annotation in global_annotations:
        schema_proto.annotation.extra_metadata.add().CopyFrom(annotation)
    # Build a map from logical feature names to Feature protos
    feature_protos_by_name = {}
    for feature in schema_proto.feature:
        feature_protos_by_name[feature.name] = feature
    for sparse_feature in schema_proto.sparse_feature:
        for index_feature in sparse_feature.index_feature:
            feature_protos_by_name.pop(index_feature.name)
        value_feature = feature_protos_by_name.pop(
            sparse_feature.value_feature.name)
        feature_protos_by_name[sparse_feature.name] = value_feature

    # Handle ragged tensor representations.
    tensor_representations = (
        tensor_representation_util.GetTensorRepresentationsFromSchema(
            schema_proto, schema_utils.TENSOR_REPRESENTATION_GROUP))
    if tensor_representations is not None:
        for name, tensor_representation in tensor_representations.items():
            feature_protos_by_name[
                name] = schema_utils.pop_ragged_source_columns(
                    name, tensor_representation, feature_protos_by_name)

    # Update annotations
    for feature_name, annotations in feature_annotations.items():
        feature_proto = feature_protos_by_name[feature_name]
        for annotation in annotations:
            feature_proto.annotation.extra_metadata.add().CopyFrom(annotation)
    for feature_name, tags in feature_tags.items():
        feature_proto = feature_protos_by_name[feature_name]
        for tag in tags:
            feature_proto.annotation.tag.append(tag)
    return schema_proto
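
A minimal sketch, not from the source, of the annotation bookkeeping this helper performs, using only calls the snippet above already relies on (schema_from_feature_spec, annotation.extra_metadata, annotation.tag); the packed payload and the tag string are hypothetical stand-ins:

import tensorflow as tf
from google.protobuf import any_pb2
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_transform.tf_metadata import schema_utils

schema = schema_utils.schema_from_feature_spec({'x': tf.io.VarLenFeature(tf.int64)})

# Global annotation: appended to schema.annotation.extra_metadata.
payload = any_pb2.Any()
payload.Pack(schema_pb2.Schema())  # hypothetical stand-in payload
schema.annotation.extra_metadata.add().CopyFrom(payload)

# Per-feature tag, mirroring the feature_tags handling above.
feature_by_name = {feature.name: feature for feature in schema.feature}
feature_by_name['x'].annotation.tag.append('example_tag')  # hypothetical tag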