Example #1
    def _ProjectTfmdSchema(self,
                           tensor_names: List[Text]) -> schema_pb2.Schema:
        """Projects self._schema by the given tensor names."""
        tensor_representations = self.TensorRepresentations()
        tensor_names = set(tensor_names)
        if not tensor_names.issubset(tensor_representations):
            raise ValueError(
                "Unable to project {} because they were not in the original "
                "TensorRepresentations.".format(tensor_names -
                                                tensor_representations))
        paths = set()
        for tensor_name in tensor_names:
            paths.update(
                tensor_rep_util.GetSourceColumnsFromTensorRepresentation(
                    tensor_representations[tensor_name]))
        result = schema_pb2.Schema()
        # Note: We only copy projected features into the new schema because the
        # coder, and ArrowSchema() only care about Schema.feature. If they start
        # depending on other Schema fields then those fields must also be projected.
        for f in self._schema.feature:
            if path.ColumnPath(f.name) in paths:
                result.feature.add().CopyFrom(f)

        tensor_rep_util.SetTensorRepresentationsInSchema(
            result, {
                k: v
                for k, v in tensor_representations.items() if k in tensor_names
            })

        return result
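For context, the projection above hinges on GetSourceColumnsFromTensorRepresentation resolving each tensor representation to the schema columns it reads from. A minimal sketch, assuming tfx_bsl's tensor_representation_util module path and an illustrative column name:

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.tfxio import tensor_representation_util

# A dense representation that reads its values from the schema column "age"
# (the column name is illustrative).
tensor_rep = text_format.Parse(
    """
    dense_tensor {
      column_name: "age"
      shape { dim { size: 1 } }
    }
    """, schema_pb2.TensorRepresentation())

# _ProjectTfmdSchema keeps only schema features whose ColumnPath appears in
# the union of these source columns.
source_columns = (
    tensor_representation_util.GetSourceColumnsFromTensorRepresentation(
        tensor_rep))
print(source_columns)  # One ColumnPath, pointing at "age".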
Example #2
def _get_ragged_column_names(
        tensor_representation: schema_pb2.TensorRepresentation) -> List[str]:
    """Extracts source column names from a ragged tensor representation."""
    source_columns = (
        tensor_representation_util.GetSourceColumnsFromTensorRepresentation(
            tensor_representation))
    result = []
    for column in source_columns:
        if len(column.steps()) != 1:
            raise NotImplementedError(
                "Support of RaggedFeatures with multiple steps in feature_path is "
                "not implemented, got {}".format(len(column.steps())))
        result.append(column.steps()[0])
    return result
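A hedged usage sketch for the helper above, using an illustrative ragged representation whose value and row-length paths each have a single step (as the NotImplementedError check requires):

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.tfxio import tensor_representation_util

ragged_rep = text_format.Parse(
    """
    ragged_tensor {
      feature_path { step: "token_ids" }
      partition { row_length: "sentence_lengths" }
    }
    """, schema_pb2.TensorRepresentation())

# The value path comes first, followed by the row-length partition columns;
# with single-step paths the helper returns exactly these step names.
source_columns = (
    tensor_representation_util.GetSourceColumnsFromTensorRepresentation(
        ragged_rep))
print([c.steps()[0] for c in source_columns])
# Expected from _get_ragged_column_names: ["token_ids", "sentence_lengths"]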
Example #3
def pop_ragged_source_columns(
        name: str, tensor_representation: schema_pb2.TensorRepresentation,
        feature_by_name: Dict[str, schema_pb2.Feature]) -> schema_pb2.Feature:
    """Removes source columns of a ragged tensor from the given features dict.

  Args:
    name: Name of the ragged tensor.
    tensor_representation: Ragged TensorRepresentation.
    feature_by_name: Dict of features that contains source columns of the ragged
      TensorRepresentation.

  Returns:
    Value feature of the ragged tensor.

  Raises:
    ValueError: If any of the source columns are missing in the features dict.
  """
    source_columns = (
        tensor_representation_util.GetSourceColumnsFromTensorRepresentation(
            tensor_representation))
    missing_column_error_format = (
        'Ragged feature "{}" referred to value feature "{}" which did not exist '
        'in the schema or was referred to as an index or value multiple times.'
    )

    assert source_columns
    assert len(source_columns[0].steps()) == 1, (name,
                                                 source_columns[0].steps())
    try:
        value_feature = feature_by_name.pop(source_columns[0].steps()[0])
    except KeyError:
        raise ValueError(
            missing_column_error_format.format(name,
                                               source_columns[0].steps()[0]))
    for column_path in source_columns[1:]:
        assert len(column_path.steps()) == 1, (name, column_path.steps())
        try:
            row_length_feature = feature_by_name.pop(column_path.steps()[0])
        except KeyError:
            raise ValueError(
                missing_column_error_format.format(name,
                                                   column_path.steps()[0]))
        if row_length_feature.type != schema_pb2.FeatureType.INT:
            raise ValueError(
                'Row length feature "{}" is not an integer feature.'.format(
                    row_length_feature.name))
    return value_feature
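A minimal usage sketch, assuming it runs in the same module as pop_ragged_source_columns and using illustrative feature names:

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2

feature_by_name = {
    "token_ids": text_format.Parse(
        'name: "token_ids" type: INT', schema_pb2.Feature()),
    "sentence_lengths": text_format.Parse(
        'name: "sentence_lengths" type: INT', schema_pb2.Feature()),
}
ragged_rep = text_format.Parse(
    """
    ragged_tensor {
      feature_path { step: "token_ids" }
      partition { row_length: "sentence_lengths" }
    }
    """, schema_pb2.TensorRepresentation())

# Both source columns are removed from the dict; the value feature is returned.
value_feature = pop_ragged_source_columns("tokens", ragged_rep, feature_by_name)
assert value_feature.name == "token_ids"
assert not feature_by_name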
Example #4
  def _ProjectTfmdSchema(self, tensor_names: List[Text]) -> schema_pb2.Schema:
    """Projects self._schema by the given tensor names."""
    tensor_representations = self.TensorRepresentations()
    tensor_names = set(tensor_names)
    if not tensor_names.issubset(tensor_representations):
      raise ValueError(
          "Unable to project {} because they were not in the original "
          "TensorRepresentations.".format(tensor_names -
                                          tensor_representations))
    used_paths = set()
    for tensor_name in tensor_names:
      used_paths.update(
          tensor_representation_util.GetSourceColumnsFromTensorRepresentation(
              tensor_representations[tensor_name]))
    result = schema_pb2.Schema()
    # Note: We only copy projected features into the new schema because the
    # coder, and ArrowSchema() only care about Schema.feature. If they start
    # depending on other Schema fields then those fields must also be projected.
    for f in self._schema.feature:
      p = path.ColumnPath(f.name)
      if f.name == _SEQUENCE_COLUMN_NAME:
        if f.type != schema_pb2.STRUCT:
          raise ValueError(
              "Feature {} was expected to be of type STRUCT, but got {}"
              .format(f.name, f))
        result_sequence_struct = schema_pb2.Feature()
        result_sequence_struct.CopyFrom(f)
        result_sequence_struct.ClearField("struct_domain")
        any_sequence_feature_projected = False
        for sf in f.struct_domain.feature:
          sequence_feature_path = p.child(sf.name)
          if sequence_feature_path in used_paths:
            any_sequence_feature_projected = True
            result_sequence_struct.struct_domain.feature.add().CopyFrom(sf)
        if any_sequence_feature_projected:
          result.feature.add().CopyFrom(result_sequence_struct)
      elif p in used_paths:
        result.feature.add().CopyFrom(f)

    tensor_representation_util.SetTensorRepresentationsInSchema(
        result,
        {k: v for k, v in tensor_representations.items() if k in tensor_names})

    return result
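The sequence branch above relies on ColumnPath's two-step addressing of struct children. A small sketch, assuming ColumnPath is importable from tfx_bsl.arrow.path and that _SEQUENCE_COLUMN_NAME is the "##SEQUENCE##" struct column used for SequenceExample data (both are assumptions):

from tfx_bsl.arrow import path  # assumed location of ColumnPath

# Assumed value of _SEQUENCE_COLUMN_NAME.
sequence_column = path.ColumnPath("##SEQUENCE##")

# A sequence feature "tokens" is addressed as a child of the sequence struct,
# which is how used_paths entries for sequence features look above.
child_path = sequence_column.child("tokens")
print(child_path.steps())  # Two steps: the struct column, then the child name.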
Example #5
    def _GetTfExampleParserConfig(
            self) -> Tuple[Dict[Text, Any], Dict[Text, Text]]:
        """Creates a dict feature spec that can be used in tf.io.parse_example().

    To reduce confusion: 'tensor names' are the keys of TensorRepresentations,
    'feature names' are the keys of the tf.Example parser config, and
    'column names' are the names of the features in the schema.

    Returns:
      Two maps. The first is the parser config that maps from feature
      name to a tf.io Feature. The second is a mapping from feature names to
      tensor names.

    Raises:
      ValueError: if the tf.Example parser config is invalid.
    """
        if self._schema is None:
            raise ValueError(
                "Unable to create a parsing config because no schema is provided."
            )

        column_name_to_type = {f.name: f.type for f in self._schema.feature}
        features = {}
        feature_name_to_tensor_name = {}
        for tensor_name, tensor_rep in self.TensorRepresentations().items():
            paths = tensor_rep_util.GetSourceColumnsFromTensorRepresentation(
                tensor_rep)
            if len(paths) == 1:
                # The parser config refers to a single tf.Example feature. In this case,
                # the key to the parser config needs to be the name of the feature.
                column_name = paths[0].initial_step()
                value_type = column_name_to_type[column_name]
            else:
                # The parser config needs to refer to multiple tf.Example features. In
                # this case the key to the parser config does not matter. We preserve
                # the tensor representation key.
                column_name = tensor_name
                value_type = column_name_to_type[
                    tensor_rep_util.
                    GetSourceValueColumnFromTensorRepresentation(
                        tensor_rep).initial_step()]
            parse_config = tensor_rep_util.CreateTfExampleParserConfig(
                tensor_rep, value_type)

            if _is_multi_column_parser_config(parse_config):
                # Create internal naming, to prevent possible naming collisions between
                # tensor_name and column_name.
                feature_name = _FEATURE_NAME_PREFIX + tensor_name + "_" + column_name
            else:
                feature_name = column_name
            if feature_name in feature_name_to_tensor_name:
                clashing_tensor_rep = self.TensorRepresentations()[
                    feature_name_to_tensor_name[feature_name]]
                raise ValueError(
                    f"Unable to create a valid parsing config. Feature "
                    f"name: {feature_name} is a duplicate of "
                    f"tensor representation: {clashing_tensor_rep}")
            feature_name_to_tensor_name[feature_name] = tensor_name
            features[feature_name] = parse_config

        _validate_tf_example_parser_config(features, self._schema)

        return features, feature_name_to_tensor_name
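To illustrate one entry of the returned parser config, here is a hedged sketch with an illustrative single-column representation; the column name and value type are assumptions, and the module path for tensor_representation_util is assumed to be tfx_bsl's:

import tensorflow as tf
from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.tfxio import tensor_representation_util

tensor_rep = text_format.Parse(
    'varlen_sparse_tensor { column_name: "query_tokens" }',
    schema_pb2.TensorRepresentation())

# Single source column: the parser-config key is the column name itself.
parse_config = tensor_representation_util.CreateTfExampleParserConfig(
    tensor_rep, schema_pb2.FeatureType.BYTES)
features = {"query_tokens": parse_config}

# The resulting config plugs directly into tf.io.parse_example.
serialized = tf.constant([b""])  # stand-in for serialized tf.Examples
parsed = tf.io.parse_example(serialized, features)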
Example #6
  def testGetSourceColumnsFromTensorRepresentation(self, pbtxt, expected):
    self.assertEqual(
        [path.ColumnPath(e) for e in expected],
        tensor_representation_util.GetSourceColumnsFromTensorRepresentation(
            text_format.Parse(pbtxt, schema_pb2.TensorRepresentation())))
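For reference, a hedged example of the kind of (pbtxt, expected) pair such a parameterized test case would receive (the column name is illustrative):

_CASE_PBTXT = 'varlen_sparse_tensor { column_name: "terms" }'
_CASE_EXPECTED = ["terms"]
# The assertion then compares [path.ColumnPath("terms")] against the source
# columns extracted from the parsed TensorRepresentation.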