def test_stale_asset_collections_are_cleaned(self):
  vocabulary_file = os.path.join(
      tf.compat.as_bytes(self.get_temp_dir()), tf.compat.as_bytes('asset'))
  file_io.write_string_to_file(vocabulary_file, 'foo bar baz')

  export_path = os.path.join(tempfile.mkdtemp(), 'export')

  # Create a SavedModel that includes assets.
  with tf.compat.v1.Graph().as_default():
    with tf.compat.v1.Session().as_default() as session:
      input_string = tf.compat.v1.placeholder(tf.string)
      # Map the string through a table loaded from an asset file.
      initializer = tf.lookup.TextFileInitializer(
          vocabulary_file,
          key_dtype=tf.string,
          key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
          value_dtype=tf.int64,
          value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
      table = tf.lookup.StaticHashTable(initializer, default_value=12)
      table = lookup_ops.IdTableWithHashBuckets(
          table, num_oov_buckets=12, key_dtype=tf.string)
      output = table.lookup(input_string)
      inputs = {'input': input_string}
      outputs = {'output': output}
      saved_transform_io.write_saved_transform_from_session(
          session, inputs, outputs, export_path)

  # Load the transform and save it again repeatedly, verifying that the
  # asset collections remain valid after each round trip.
  for _ in [1, 2, 3]:
    with tf.compat.v1.Graph().as_default() as g:
      with tf.compat.v1.Session().as_default() as session:
        input_string = tf.constant('dog')
        inputs = {'input': input_string}
        _, outputs = (
            saved_transform_io.partially_apply_saved_transform_internal(
                export_path, inputs))

        self.assertEqual(
            1, len(g.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)))
        self.assertEqual(0, len(g.get_collection(tf.saved_model.ASSETS_KEY)))

        # Check that every entry in ASSET_FILEPATHS refers to a Tensor in the
        # graph; if one does not, get_tensor_by_name() raises KeyError.
        for asset_path in g.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS):
          tensor_name = asset_path.name
          g.get_tensor_by_name(tensor_name)

        export_path = os.path.join(tempfile.mkdtemp(), 'export')
        saved_transform_io.write_saved_transform_from_session(
            session, inputs, outputs, export_path)
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table from a dataset of (key, value)
  pairs. Any lookup of an out-of-vocabulary token returns a bucket ID based
  on its hash if `num_oov_buckets` is greater than zero; otherwise it is
  assigned the `default_value`. The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...     ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary keys. Defaults to
      -1.
    hasher_spec: A `HasherSpec` specifying the hash function used to assign
      out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs,
      * the second item of the `dataset` pairs has a dtype that is
        incompatible with `default_value`,
      * `num_oov_buckets` is negative, or
      * `vocab_size` is not greater than zero.
    TypeError: If `key_dtype` is neither integer nor string.
  """
  elem_spec = dataset.element_spec
  if len(elem_spec) != 2:
    raise ValueError("The given dataset must contain pairs.")
  if default_value is None:
    default_value = -1
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("The dtype of the values requires manually setting a "
                       "compatible default_value.")
  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater than or equal to 0, got %d." %
        num_oov_buckets)
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
      if num_oov_buckets:
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table
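
# A minimal usage sketch (illustrative, not part of the module above),
# assuming eager-mode TF 2.x, where this function is exported as
# `tf.data.experimental.table_from_dataset`. It shows how `num_oov_buckets`
# routes unseen keys into the hash-bucket range
# [vocab_size, vocab_size + num_oov_buckets - 1] instead of `default_value`.
if __name__ == "__main__":
  import tensorflow as tf

  keys = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
  values = tf.data.Dataset.range(3)  # int64 values 0, 1, 2
  table = tf.data.experimental.table_from_dataset(
      tf.data.Dataset.zip((keys, values)), num_oov_buckets=5, vocab_size=3)
  # In-vocabulary keys map to their dataset values; the unseen key "z"
  # hashes into one of the OOV buckets 3..7.
  print(table.lookup(tf.constant(["a", "b", "z"])).numpy())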