def testInferFeatureSchemaWithSession(self):
  """Checks schema inference when a session is supplied.

  Feature 'c' is registered with an override whose bounds are the constants 5
  and 6. With a session passed in, the inferred schema is expected to carry
  those values in a categorical `IntDomain`, while 'a' and 'b' keep the plain
  schema inferred from their placeholders.
  """
  with tf.Graph().as_default() as graph:
    float_input = tf.placeholder(tf.float32, (None,))
    string_input = tf.placeholder(tf.string, (1, 2, 3))
    int_input = tf.placeholder(tf.int64, (None,))
    tensors = {'a': float_input, 'b': string_input, 'c': int_input}

    # Register the min/max override for 'c' before inferring the schema.
    lower_bound = tf.constant(5)
    upper_bound = tf.constant(6)
    schema_inference.set_tensor_schema_override(
        int_input, lower_bound, upper_bound)

    with tf.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(tensors, graph, session)

    fixed = dataset_schema.FixedColumnRepresentation
    overridden_domain = dataset_schema.IntDomain(
        tf.int64, 5, 6, is_categorical=True)
    expected_columns = {
        'a': dataset_schema.ColumnSchema(tf.float32, [], fixed()),
        'b': dataset_schema.ColumnSchema(tf.string, [2, 3], fixed()),
        'c': dataset_schema.ColumnSchema(overridden_domain, [], fixed()),
    }
    expected_schema = dataset_schema.Schema(column_schemas=expected_columns)
    self.assertEqual(schema, expected_schema)
def apply_buckets(x, bucket_boundaries, name=None):
  """Assigns a bucket index to every element of `x`.

  Args:
    x: A numeric input `Tensor` whose values should be mapped to buckets.
    bucket_boundaries: The bucket boundaries. NOTE(review): the schema maximum
      is read from dimension 1 of this argument's shape, so it presumably has
      to be at least rank 2 rather than a flat list -- confirm with callers.
    name: (Optional) A name for this operation.

  Returns:
    An int64 `Tensor` of the same shape as `x` in which each element is the
    bucket index of the corresponding input. Indices lie in the range
    [0, number of boundaries]. The returned tensor has a categorical
    `IntDomain` column schema attached to it.
  """
  with tf.name_scope(name, 'apply_buckets'):
    # The bucket ids are cast to int64 because int32 is not compatible with
    # the tf.Example parser (see _TF_EXAMPLE_ALLOWED_TYPES in
    # FixedColumnRepresentation() in tf_metadata/dataset_schema.py).
    bucket_indices = tf.to_int64(
        quantile_ops.bucketize_with_input_boundaries(
            x, boundaries=bucket_boundaries, name='assign_buckets'))

    # The upper bound is only known once the boundaries are computed, so the
    # schema stores a `Future` keyed by the name of the tensor holding it.
    num_boundaries = tf.shape(bucket_boundaries)[1]
    result_schema = dataset_schema.infer_column_schema_from_tensor(
        bucket_indices)
    result_schema.domain = dataset_schema.IntDomain(
        bucket_indices.dtype,
        min_value=0,
        max_value=futures.Future(num_boundaries.name),
        is_categorical=True)
    api.set_column_schema(bucket_indices, result_schema)

    return bucket_indices
def _augment_metadata(saved_model_dir, metadata):
  """Returns a copy of `metadata` enriched with min/max values from a SavedModel.

  The SavedModel is loaded into a fresh graph, the schema overrides registered
  for its output tensors are evaluated, and every feature that has such an
  override gets a new categorical `IntDomain` holding the evaluated bounds.
  For an output that is a `SparseTensor`, the override of its `values` tensor
  is the one consulted.

  Args:
    saved_model_dir: Location of a SavedModel.
    metadata: A `DatasetMetadata`.

  Returns:
    An augmented `DatasetMetadata`. The `DatasetMetadata` passed in is not
    modified; columns without an override are reused as they are.
  """
  with tf.Graph().as_default() as graph:
    with tf.Session(graph=graph) as session:
      _, outputs_by_name = (
          saved_transform_io.partially_apply_saved_transform_internal(
              saved_model_dir, {}))

      # Translate the per-tensor overrides recorded in the graph into
      # per-output-feature overrides.
      known_overrides = tft_api.get_tensor_schema_overrides()
      overrides_to_fetch = {}
      for feature_name, output in six.iteritems(outputs_by_name):
        if isinstance(output, tf.SparseTensor):
          lookup_key = output.values
        else:
          lookup_key = output
        if lookup_key in known_overrides:
          overrides_to_fetch[feature_name] = known_overrides[lookup_key]

      session.run(tf.global_variables_initializer())
      session.run(tf.tables_initializer())
      resolved_bounds = session.run(overrides_to_fetch)

  augmented = {}
  for feature_name, original in six.iteritems(metadata.schema.column_schemas):
    if feature_name not in resolved_bounds:
      augmented[feature_name] = original
      continue

    lower, upper = resolved_bounds[feature_name]
    assert original.domain.dtype == tf.int64
    assert isinstance(original.domain, dataset_schema.IntDomain)
    # An override always produces a categorical column; axes and
    # representation are carried over from the existing column schema.
    categorical_domain = dataset_schema.IntDomain(
        tf.int64, lower, upper, is_categorical=True)
    augmented[feature_name] = dataset_schema.ColumnSchema(
        categorical_domain, original.axes, original.representation)

  return dataset_metadata.DatasetMetadata(dataset_schema.Schema(augmented))
def test_schema_with_futures(self):
  """Checks that futures inside a schema are reported and then substituted.

  A schema is built with futures in three places (axis sizes, a default value
  and the bounds of an `IntDomain`). After `substitute_futures` it must report
  all futures as resolved and compare equal to a schema written with the
  concrete values.
  """
  bool_key = 'fixed_bool_without_default'
  int_key = 'fixed_int_with_default'
  categorical_key = 'fixed_categorical_int_with_range'

  schema = sch.Schema()
  deferred_axes = [
      5, futures.Future('foo_dim_1'), 7, futures.Future('foo_dim_3')]
  schema.column_schemas[bool_key] = sch.ColumnSchema(
      tf.bool, deferred_axes, sch.FixedColumnRepresentation())

  deferred_default = sch.FixedColumnRepresentation(
      default_value=futures.Future('bar_int_default'))
  schema.column_schemas[int_key] = sch.ColumnSchema(
      tf.int64, [1], deferred_default)

  deferred_domain = sch.IntDomain(
      tf.int64,
      futures.Future('baz_int_min'),
      futures.Future('baz_int_max'),
      is_categorical=True)
  schema.column_schemas[categorical_key] = sch.ColumnSchema(
      deferred_domain, [1], sch.FixedColumnRepresentation(default_value=0))

  self.assertFalse(schema.all_futures_resolved())

  resolved_values = {
      'foo_dim_1': 6,
      'foo_dim_3': 8,
      'bar_int_default': 12,
      'baz_int_min': 3,
      'baz_int_max': 4,
  }
  schema.substitute_futures(resolved_values)
  self.assertTrue(schema.all_futures_resolved())

  expected_schema = sch.Schema()
  expected_columns = [
      (bool_key,
       sch.ColumnSchema(
           tf.bool, [5, 6, 7, 8], sch.FixedColumnRepresentation())),
      (int_key,
       sch.ColumnSchema(
           tf.int64, [1], sch.FixedColumnRepresentation(default_value=12))),
      (categorical_key,
       sch.ColumnSchema(
           sch.IntDomain(tf.int64, 3, 4, is_categorical=True), [1],
           sch.FixedColumnRepresentation(default_value=0))),
  ]
  for column_name, column in expected_columns:
    expected_schema.column_schemas[column_name] = column

  self.assertEqual(expected_schema, schema)
def _to_domain(domain): if domain.get('ints') is not None: return sch.IntDomain(tf.int64) if domain.get('floats') is not None: return sch.FloatDomain(tf.float32) if domain.get('strings') is not None: return sch.StringDomain(tf.string) if domain.get('bools') is not None: return sch.BoolDomain(tf.bool) raise ValueError('Unknown domain: {}'.format(domain))
def compute_deferred_metadata(metadata, column_schema_overrides,
                              saved_model_dir, tensor_value_mapping):
  """Fills in overridden column domains with values evaluated from a graph.

  Args:
    metadata: An object whose `schema.column_schemas` maps column names to
      column schemas.
    column_schema_overrides: A dict keyed by column name. Each value exposes
      `min_value` and `max_value`, which are used both as the names of the
      tensors to fetch and as keys into the fetched results.
    saved_model_dir: Location of the SavedModel handed to
      `saved_transform_io.fetch_tensor_values`.
    tensor_value_mapping: A dict from an original tensor name to a
      `(value, is_asset)` pair. Each value is frozen into a constant that
      replaces the original tensor.

  Returns:
    A new `DatasetMetadata` in which every overridden column has a categorical
    int64 `IntDomain` with the evaluated min/max; other columns are reused
    unchanged.
  """
  names_to_fetch = set()
  for override in six.itervalues(column_schema_overrides):
    names_to_fetch.update((override.min_value, override.max_value))

  graph = tf.Graph()
  with graph.as_default():
    replacements = {}
    for original_name, value_and_flag in six.iteritems(tensor_value_mapping):
      value, is_asset = value_and_flag
      frozen = tf.constant(value)
      if is_asset:
        # Newly frozen constants that hold filenames have to be registered in
        # the ASSET_FILEPATHS collection.
        graph.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, frozen)
      replacements[original_name] = frozen

    with tf.Session(graph=graph) as session:
      fetches = saved_transform_io.fetch_tensor_values(
          saved_model_dir, replacements, names_to_fetch)
      session.run(tf.global_variables_initializer())
      session.run(tf.tables_initializer())
      fetched_values = session.run(fetches)

  resolved_columns = {}
  for column_name, original in six.iteritems(metadata.schema.column_schemas):
    if column_name not in column_schema_overrides:
      resolved_columns[column_name] = original
      continue

    override = column_schema_overrides[column_name]
    lower = fetched_values[override.min_value]
    upper = fetched_values[override.max_value]
    assert original.domain.dtype == tf.int64
    assert isinstance(original.domain, dataset_schema.IntDomain)
    # An override always produces a categorical column; axes and
    # representation are carried over from the existing column schema.
    categorical_domain = dataset_schema.IntDomain(
        tf.int64, lower, upper, is_categorical=True)
    resolved_columns[column_name] = dataset_schema.ColumnSchema(
        categorical_domain, original.axes, original.representation)

  return dataset_metadata.DatasetMetadata(
      dataset_schema.Schema(resolved_columns))
def test_int_domain_defaults(self):
  """Checks `IntDomain` attribute defaults and their explicit overrides.

  Without arguments the domain is non-categorical and spans the full int64
  range; `is_categorical`, `min_value` and `max_value` given explicitly are
  reflected by the attributes of the same name.
  """
  default_domain = sch.IntDomain(tf.int64)
  categorical_domain = sch.IntDomain(tf.int64, is_categorical=True)
  self.assertFalse(default_domain.is_categorical)
  self.assertTrue(categorical_domain.is_categorical)

  with_lower_bound = sch.IntDomain(tf.int64, min_value=-3)
  self.assertEqual(tf.int64.min, default_domain.min_value)
  self.assertEqual(-3, with_lower_bound.min_value)

  with_upper_bound = sch.IntDomain(tf.int64, max_value=3)
  self.assertEqual(tf.int64.max, default_domain.max_value)
  self.assertEqual(3, with_upper_bound.max_value)
def infer_feature_schema(features, graph, session=None):
  """Creates a `Schema` from a dict of tensors.

  A column schema is inferred for every tensor. If an override is registered
  for a feature's tensor (for a `SparseTensor`, for its `values`), the inferred
  domain is replaced by a categorical int64 `IntDomain`. When a session is
  provided, the tensors holding the override's min and max are evaluated and
  their values stored in the domain; without a session the column is only
  marked categorical and both bounds are passed as None.

  Args:
    features: A dict mapping column names to `Tensor` or `SparseTensor`s. The
      `Tensor` or `SparseTensor`s should have a 0'th dimension which is
      interpreted as the batch dimension.
    graph: A `tf.Graph`, used to look up schema overrides even if they are not
      computed.
    session: (optional) A `tf.Session` used to compute schema overrides. If
      None, schema overrides will not be computed.

  Returns:
    A `Schema` object.
  """
  overrides_by_tensor = _get_tensor_schema_overrides(graph)

  inferred_columns = {}
  for name, tensor in six.iteritems(features):
    inferred = dataset_schema.infer_column_schema_from_tensor(tensor)

    if isinstance(tensor, tf.SparseTensor):
      override_key = tensor.values
    else:
      override_key = tensor
    bound_tensors = overrides_by_tensor.get(override_key)

    if bound_tensors is None:
      inferred_columns[name] = inferred
      continue

    assert inferred.domain.dtype == tf.int64
    assert isinstance(inferred.domain, dataset_schema.IntDomain)
    if session is None:
      lower, upper = None, None
    else:
      lower, upper = session.run(bound_tensors)
    categorical_domain = dataset_schema.IntDomain(
        tf.int64, lower, upper, is_categorical=True)
    inferred_columns[name] = dataset_schema.ColumnSchema(
        categorical_domain, inferred.axes, inferred.representation)

  return dataset_schema.Schema(inferred_columns)
def _from_domain_dict(domain): """Translate a JSON domain dict into a Domain.""" if domain.get('ints') is not None: def maybe_to_int(s): return int(s) if s is not None else None return sch.IntDomain( tf.int64, maybe_to_int(domain['ints'].get('min')), maybe_to_int(domain['ints'].get('max')), domain['ints'].get('is_categorical')) if domain.get('floats') is not None: return sch.FloatDomain(tf.float32) if domain.get('strings') is not None: return sch.StringDomain(tf.string) if domain.get('bools') is not None: return sch.BoolDomain(tf.bool) raise ValueError('Unknown domain: {}'.format(domain))
def get_manually_created_schema():
  """Provides a test schema assembled directly from the Schema classes.

  Returns:
    A `Schema` with four fixed-length columns (one of them a categorical int
    with range [-5, 10]) and three variable-length columns.
  """
  fixed = sch.FixedColumnRepresentation
  var_len = sch.ListColumnRepresentation

  columns = {}

  # FixedLenFeatures
  ranged_categorical = sch.IntDomain(tf.int64, -5, 10, True)
  columns['fixed_categorical_int_with_range'] = sch.ColumnSchema(
      ranged_categorical, [], fixed())
  columns['fixed_int'] = sch.ColumnSchema(tf.int64, [5], fixed())
  columns['fixed_float'] = sch.ColumnSchema(tf.float32, [5], fixed())
  columns['fixed_string'] = sch.ColumnSchema(tf.string, [5], fixed())

  # VarLenFeatures
  columns['var_int'] = sch.ColumnSchema(tf.int64, None, var_len())
  columns['var_float'] = sch.ColumnSchema(tf.float32, None, var_len())
  columns['var_string'] = sch.ColumnSchema(tf.string, None, var_len())

  return sch.Schema(columns)
def _create_output_metadata(features_config, min_value, max_value):
  """Constructs a custom DatasetMetadata.

  The target feature and every transformed numeric feature become scalar
  float32 columns, the id feature becomes an int64 list column, and all
  transformed categorical features share one categorical int64 column schema
  bounded by `min_value` and `max_value`.

  Args:
    features_config: Features configuration mock. Its `TARGET_FEATURE`,
      `ID_FEATURE`, `NUMERIC_FEATURES` and `CATEGORICAL_FEATURES` attributes
      are read.
    min_value: Minimum value for IntDomain.
    max_value: Maximum value for IntDomain.

  Returns:
    A `tft.tf_metadata.dataset_metadata.DatasetMetadata` object.
  """

  def _scalar_float_column():
    # A fresh column schema per call, so columns do not share instances.
    return dataset_schema.ColumnSchema(
        tf.float32, [], dataset_schema.FixedColumnRepresentation())

  schema = {}
  schema[features_config.TARGET_FEATURE] = _scalar_float_column()
  schema[features_config.ID_FEATURE] = dataset_schema.ColumnSchema(
      tf.int64, [None], dataset_schema.ListColumnRepresentation())

  for feature in features_config.NUMERIC_FEATURES:
    schema[utils.make_transformed_key(feature)] = _scalar_float_column()

  # All categorical features intentionally point at the same schema object.
  categorical_domain = dataset_schema.IntDomain(
      tf.int64, min_value, max_value, is_categorical=True)
  categorical_col_schema = dataset_schema.ColumnSchema(
      categorical_domain, [], dataset_schema.FixedColumnRepresentation())
  for feature in features_config.CATEGORICAL_FEATURES:
    schema[utils.make_transformed_key(feature)] = categorical_col_schema

  return dataset_metadata.DatasetMetadata(schema)
def string_to_int(x, default_value=-1, top_k=None, frequency_threshold=None,
                  num_oov_buckets=0, vocab_filename=None):
  """Generates a vocabulary for `x` and maps it to an integer with this vocab.

  Args:
    x: A `Tensor` or `SparseTensor` of type tf.string.
    default_value: The value to use for out-of-vocabulary values, unless
      'num_oov_buckets' is greater than zero.
    top_k: Limit the generated vocabulary to the first `top_k` elements. If set
      to None, the full vocabulary is generated.
    frequency_threshold: Limit the generated vocabulary only to elements whose
      frequency is >= to the supplied threshold. If set to None, the full
      vocabulary is generated.
    num_oov_buckets: Any lookup of an out-of-vocabulary token will return a
      bucket ID based on its hash if `num_oov_buckets` is greater than zero.
      Otherwise it is assigned the `default_value`.
    vocab_filename: The file name for the vocabulary file. If None, the
      "uniques" scope name in the context of this graph will be used as the
      file name. If not None, should be unique within a given preprocessing
      function.

  Returns:
    A `Tensor` or `SparseTensor` where each string value is mapped to an
    integer; each unique string value is mapped to a different integer and
    integers are consecutive and starting from 0. A column schema with an
    `IntDomain` describing the value range is attached to the result.

  Raises:
    ValueError: If `top_k` or `frequency_threshold` is negative.
  """
  # Validate the two optional limits, coercing each to int first.
  if top_k is not None:
    top_k = int(top_k)
    if top_k < 0:
      raise ValueError('top_k must be non-negative, but got: %r' % top_k)

  if frequency_threshold is not None:
    frequency_threshold = int(frequency_threshold)
    if frequency_threshold < 0:
      raise ValueError(
          'frequency_threshold must be non-negative, but got: %r' %
          frequency_threshold)

  def _apply_vocab(x, vocabulary_file):
    vocab_table = lookup.string_to_index_table_from_file(
        vocabulary_file,
        num_oov_buckets=num_oov_buckets,
        default_value=default_value)
    size_of_table = vocab_table.size()
    return vocab_table.lookup(x), size_of_table

  with tf.name_scope('string_to_int'):
    # A prefix is only supplied when the caller did not name the file.
    prefix = analyzers.VOCAB_FILENAME_PREFIX if vocab_filename is None else None
    vocab_filename = analyzers.sanitized_vocab_filename(vocab_filename, prefix)
    vocabulary_file = analyzers.uniques(
        x,
        top_k=top_k,
        frequency_threshold=frequency_threshold,
        vocab_filename=vocab_filename)
    result, table_size = api.apply_function(_apply_vocab, x, vocabulary_file)

  # The lower bound is a plain Python value, whereas the upper bound is a
  # tensor whose value is unknown until the vocabulary has been computed; the
  # schema therefore stores it as a `Future` keyed by the tensor's name.
  #
  # `table_size` includes the oov buckets. `default_value` is folded into the
  # bounds only when no oov buckets are configured (num_oov_buckets <= 0).
  lower_bound = 0
  upper_bound = table_size - 1
  if num_oov_buckets <= 0:
    lower_bound = min(lower_bound, default_value)
    upper_bound = tf.maximum(upper_bound, default_value)

  result_schema = dataset_schema.infer_column_schema_from_tensor(result)
  result_schema.domain = dataset_schema.IntDomain(
      result.dtype,
      min_value=lower_bound,
      max_value=futures.Future(upper_bound.name),
      vocabulary_file=vocab_filename)
  api.set_column_schema(result, result_schema)

  return result
from tensorflow_transform.tf_metadata import metadata_io import unittest from tensorflow.python.framework import test_util _TEST_METADATA_COMPLETE = dataset_metadata.DatasetMetadata({ 'fixed_column': dataset_schema.ColumnSchema(tf.string, (1, 3, 2), dataset_schema.FixedColumnRepresentation()), 'fixed_column_with_default': dataset_schema.ColumnSchema( tf.float32, (1, 3, 2), dataset_schema.FixedColumnRepresentation(123.4)), 'list_columm': dataset_schema.ColumnSchema( dataset_schema.IntDomain(tf.int64, min_value=-1, max_value=5), (None, ), dataset_schema.ListColumnRepresentation()) }) _TEST_METADATA = dataset_metadata.DatasetMetadata({ 'fixed_column': dataset_schema.ColumnSchema(tf.string, (1, 3, 2), dataset_schema.FixedColumnRepresentation()), 'fixed_column_with_default': dataset_schema.ColumnSchema( tf.float32, (1, 3, 2), dataset_schema.FixedColumnRepresentation(123.4)), # zeros will be overriddden 'list_columm': dataset_schema.ColumnSchema( dataset_schema.IntDomain(tf.int64, min_value=0, max_value=0), (None, ),
tf.io.VarLenFeature(dtype=tf.int64), 'var_float': tf.io.VarLenFeature(dtype=tf.float32), 'var_string': tf.io.VarLenFeature(dtype=tf.string), } def get_test_schema(): return sch.from_feature_spec(test_feature_spec) _COLUMN_SCHEMAS = { # FixedLenFeatures 'fixed_categorical_int_with_range': sch.ColumnSchema(sch.IntDomain(tf.int64, -5, 10, True), [], sch.FixedColumnRepresentation()), 'fixed_int': sch.ColumnSchema(tf.int64, [5], sch.FixedColumnRepresentation()), 'fixed_float': sch.ColumnSchema(tf.float32, [5], sch.FixedColumnRepresentation()), 'fixed_string': sch.ColumnSchema(tf.string, [5], sch.FixedColumnRepresentation()), # VarLenFeatures 'var_int': sch.ColumnSchema(tf.int64, None, sch.ListColumnRepresentation()), 'var_float': sch.ColumnSchema(tf.float32, None, sch.ListColumnRepresentation()), 'var_string': sch.ColumnSchema(tf.string, None, sch.ListColumnRepresentation()) }
def apply_vocab(x, deferred_vocab_filename_tensor, default_value=-1,
                num_oov_buckets=0, lookup_fn=None, name=None):
  r"""Maps `x` to a vocabulary specified by the deferred tensor.

  This function also writes domain statistics about the vocabulary min and max
  values. Note that the min and max are inclusive, and depend on the vocab
  size, num_oov_buckets and default_value.

  In case one of the tokens contains the '\n' or '\r' characters or is empty it
  will be discarded since we are currently writing the vocabularies as text
  files. This behavior will likely be fixed/improved in the future.

  Args:
    x: A `Tensor` or `SparseTensor` of type tf.string to which the vocabulary
      transformation should be applied. The column names are those intended
      for the transformed tensors.
    deferred_vocab_filename_tensor: The deferred vocab filename tensor as
      returned by `tft.uniques`.
    default_value: The value to use for out-of-vocabulary values, unless
      'num_oov_buckets' is greater than zero.
    num_oov_buckets: Any lookup of an out-of-vocabulary token will return a
      bucket ID based on its hash if `num_oov_buckets` is greater than zero.
      Otherwise it is assigned the `default_value`.
    lookup_fn: Optional lookup function. If specified it should take a tensor
      and a deferred vocab filename as input and return a lookup `op` along
      with the table size. By default a lookup through
      `lookup.index_table_from_file` is performed.
    name: (Optional) A name for this operation.

  Returns:
    A `Tensor` or `SparseTensor` where each string value is mapped to an
    integer by the vocabulary lookup. A categorical `IntDomain` column schema
    is attached to the result.
  """

  def _apply_vocab(y, deferred_vocab_filename_tensor):
    vocab_table = lookup.index_table_from_file(
        deferred_vocab_filename_tensor,
        num_oov_buckets=num_oov_buckets,
        default_value=default_value)
    size_of_table = vocab_table.size()
    return vocab_table.lookup(y), size_of_table

  with tf.name_scope(name, 'apply_vocab'):
    lookup_fn = lookup_fn if lookup_fn else _apply_vocab
    result, table_size = api.apply_function(
        lookup_fn, x, deferred_vocab_filename_tensor)

    # The lower bound is a plain Python value, whereas the upper bound is a
    # tensor whose value is unknown until the vocabulary has been computed;
    # the schema therefore stores it as a `Future` keyed by the tensor's name.
    #
    # `table_size` includes the oov buckets. `default_value` is folded into
    # the bounds only when no oov buckets are configured
    # (num_oov_buckets <= 0).
    lower_bound = 0
    upper_bound = table_size - 1
    if num_oov_buckets <= 0:
      lower_bound = min(lower_bound, default_value)
      upper_bound = tf.maximum(upper_bound, default_value)

    result_schema = dataset_schema.infer_column_schema_from_tensor(result)

    # Keep only the last '/'-separated component of the vocab path, i.e. the
    # file name relative to its directory.
    path_components = tf.string_split([deferred_vocab_filename_tensor], '/')
    file_name_tensor = path_components.values[-1]

    result_schema.domain = dataset_schema.IntDomain(
        result.dtype,
        min_value=lower_bound,
        max_value=futures.Future(upper_bound.name),
        is_categorical=True,
        vocabulary_file=futures.Future(file_name_tensor.name))
    api.set_column_schema(result, result_schema)

    return result
# VarLenFeatures 'var_int': tf.VarLenFeature(dtype=tf.int64), 'var_float': tf.VarLenFeature(dtype=tf.float32), 'var_string': tf.VarLenFeature(dtype=tf.string), } def get_test_schema(): return sch.from_feature_spec(test_feature_spec) _COLUMN_SCHEMAS = { # FixedLenFeatures 'fixed_categorical_int_with_range': sch.ColumnSchema( sch.IntDomain(tf.int64, -5, 10, True), [], sch.FixedColumnRepresentation()), 'fixed_int': sch.ColumnSchema( tf.int64, [5], sch.FixedColumnRepresentation()), 'fixed_float': sch.ColumnSchema( tf.float32, [5], sch.FixedColumnRepresentation()), 'fixed_string': sch.ColumnSchema( tf.string, [5], sch.FixedColumnRepresentation()), # VarLenFeatures 'var_int': sch.ColumnSchema( tf.int64, None, sch.ListColumnRepresentation()), 'var_float': sch.ColumnSchema( tf.float32, None, sch.ListColumnRepresentation()), 'var_string': sch.ColumnSchema( tf.string, None, sch.ListColumnRepresentation()) }
from tensorflow_transform.tf_metadata import metadata_io import unittest from tensorflow.python.framework import test_util _TEST_METADATA = dataset_metadata.DatasetMetadata({ 'fixed_column': dataset_schema.ColumnSchema(tf.string, (1, 3, 2), dataset_schema.FixedColumnRepresentation()), 'fixed_column_with_default': dataset_schema.ColumnSchema( tf.float32, (1, 3, 2), dataset_schema.FixedColumnRepresentation(123.4)), 'list_columm': dataset_schema.ColumnSchema( dataset_schema.IntDomain(tf.int64, min_value=-1, max_value=5), (None, ), dataset_schema.ListColumnRepresentation()) }) _TEST_METADATA_WITH_FUTURES = dataset_metadata.DatasetMetadata({ 'fixed_column': dataset_schema.ColumnSchema(tf.string, (1, 3, 2), dataset_schema.FixedColumnRepresentation()), 'fixed_column_with_default': dataset_schema.ColumnSchema( tf.float32, (1, futures.Future('a'), 2), dataset_schema.FixedColumnRepresentation(123.4)), 'list_columm': dataset_schema.ColumnSchema( dataset_schema.IntDomain(tf.int64, min_value=-1,
def get_manually_created_schema():
  """Provides a test schema assembled directly from the Schema classes.

  Returns:
    A `Schema` holding fixed-length columns (with and without defaults, plus
    two int columns carrying an `IntDomain`), variable-length columns and
    sparse columns for the bool, int, float and string dtypes.
  """
  fixed = sch.FixedColumnRepresentation
  var_len = sch.ListColumnRepresentation

  def _sparse(value_field_name, index_field_name, index_flag):
    # One value field plus a single index field; `index_flag` is forwarded as
    # the second positional argument of `SparseIndexField`.
    return sch.SparseColumnRepresentation(
        value_field_name,
        [sch.SparseIndexField(index_field_name, index_flag)])

  named_columns = [
      # FixedLenFeatures
      ('fixed_bool_with_default',
       sch.ColumnSchema(tf.bool, [1], fixed(default_value=False))),
      ('fixed_bool_without_default',
       sch.ColumnSchema(tf.bool, [5], fixed())),
      ('fixed_int_with_default',
       sch.ColumnSchema(tf.int64, [1], fixed(default_value=0))),
      ('fixed_categorical_int_with_range',
       sch.ColumnSchema(sch.IntDomain(tf.int64, -5, 10, True), [1], fixed(0))),
      ('fixed_categorical_int_with_vocab',
       sch.ColumnSchema(
           sch.IntDomain(tf.int64, vocabulary_file='test_filename'), [1],
           fixed(0))),
      ('fixed_int_without_default',
       sch.ColumnSchema(tf.int64, [5], fixed())),
      ('fixed_float_with_default',
       sch.ColumnSchema(tf.float32, [1], fixed(default_value=0.0))),
      ('fixed_float_without_default',
       sch.ColumnSchema(tf.float32, [5], fixed())),
      ('fixed_string_with_default',
       sch.ColumnSchema(tf.string, [1], fixed(default_value='default'))),
      ('fixed_string_without_default',
       sch.ColumnSchema(tf.string, [5], fixed())),
      ('3d_fixed_int_without_default',
       sch.ColumnSchema(tf.int64, [5, 6, 7], fixed())),
      # VarLenFeatures
      ('var_bool', sch.ColumnSchema(tf.bool, None, var_len())),
      ('var_int', sch.ColumnSchema(tf.int64, None, var_len())),
      ('var_float', sch.ColumnSchema(tf.float32, None, var_len())),
      ('var_string', sch.ColumnSchema(tf.string, None, var_len())),
      # SparseFeatures
      ('sparse_bool',
       sch.ColumnSchema(
           tf.bool, [15],
           _sparse('sparse_bool_value', 'sparse_bool_index', True))),
      ('sparse_int',
       sch.ColumnSchema(
           tf.int64, [150],
           _sparse('sparse_int_value', 'sparse_int_index', False))),
      ('sparse_float',
       sch.ColumnSchema(
           tf.float32, [1500],
           _sparse('sparse_float_value', 'sparse_float_index', False))),
      ('sparse_string',
       sch.ColumnSchema(
           tf.string, [15000],
           _sparse('sparse_string_value', 'sparse_string_index', True))),
  ]

  schema = sch.Schema()
  for column_name, column in named_columns:
    schema.column_schemas[column_name] = column
  return schema
from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io _TEST_METADATA_COMPLETE = dataset_metadata.DatasetMetadata({ 'fixed_column': dataset_schema.ColumnSchema( tf.string, (3,), dataset_schema.FixedColumnRepresentation()), 'list_columm': dataset_schema.ColumnSchema( tf.float32, (None,), dataset_schema.ListColumnRepresentation()) }) _TEST_METADATA = dataset_metadata.DatasetMetadata({ 'fixed_column': dataset_schema.ColumnSchema( tf.string, (3,), dataset_schema.FixedColumnRepresentation()), # zeros will be overriddden 'list_columm': dataset_schema.ColumnSchema( dataset_schema.IntDomain(tf.int64, min_value=0, max_value=0), (None,), dataset_schema.ListColumnRepresentation()) }) class BeamMetadataIoTest(test_util.TensorFlowTestCase): def testReadTransformFn(self): path = self.get_temp_dir() # NOTE: we don't need to create or write to the transform_fn directory since # ReadTransformFn never inspects this directory. transform_fn_dir = os.path.join( path, tft.TFTransformOutput.TRANSFORM_FN_DIR) transformed_metadata_dir = os.path.join( path, tft.TFTransformOutput.TRANSFORMED_METADATA_DIR) metadata_io.write_metadata(_TEST_METADATA_COMPLETE,