def test_stale_asset_collections_are_cleaned(self):
        vocabulary_file = os.path.join(tf.compat.as_bytes(self.get_temp_dir()),
                                       tf.compat.as_bytes('asset'))
        file_io.write_string_to_file(vocabulary_file, 'foo bar baz')

        export_path = os.path.join(tempfile.mkdtemp(), 'export')

        # Create a SavedModel that includes assets.
        with tf.compat.v1.Graph().as_default():
            with tf.compat.v1.Session().as_default() as session:
                input_string = tf.compat.v1.placeholder(tf.string)
                # Map string through a table loaded from an asset file
                initializer = tf.lookup.TextFileInitializer(
                    vocabulary_file,
                    key_dtype=tf.string,
                    key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
                    value_dtype=tf.int64,
                    value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
                table = tf.lookup.StaticHashTable(initializer,
                                                  default_value=12)
                table = lookup_ops.IdTableWithHashBuckets(table,
                                                          num_oov_buckets=12,
                                                          key_dtype=tf.string)
                output = table.lookup(input_string)
                inputs = {'input': input_string}
                outputs = {'output': output}
                saved_transform_io.write_saved_transform_from_session(
                    session, inputs, outputs, export_path)

        # Load it and save it again repeatedly, verifying that the asset collections
        # remain valid.
        for _ in [1, 2, 3]:
            with tf.compat.v1.Graph().as_default() as g:
                with tf.compat.v1.Session().as_default() as session:
                    input_string = tf.constant('dog')
                    inputs = {'input': input_string}
                    _, outputs = (saved_transform_io.
                                  partially_apply_saved_transform_internal(
                                      export_path, inputs))

                    self.assertEqual(
                        1,
                        len(
                            g.get_collection(
                                tf.compat.v1.GraphKeys.ASSET_FILEPATHS)))
                    self.assertEqual(
                        0, len(g.get_collection(tf.saved_model.ASSETS_KEY)))

                    # Check that every entry in ASSET_FILEPATHS refers to a Tensor in
                    # the graph; if not, get_tensor_by_name() raises a KeyError.
                    for asset_path in g.get_collection(
                            tf.compat.v1.GraphKeys.ASSET_FILEPATHS):
                        tensor_name = asset_path.name
                        g.get_tensor_by_name(tensor_name)

                    export_path = os.path.join(tempfile.mkdtemp(), 'export')
                    saved_transform_io.write_saved_transform_from_session(
                        session, inputs, outputs, export_path)
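# A hedged illustration, not part of the test above: the same asset round-trip
# sketched with the plain TF2 SavedModel API, assuming TensorFlow 2.x eager
# execution and a newline-separated vocabulary file. TextFileInitializer
# tracks the vocabulary file as an asset, so tf.saved_model.save copies it
# into the SavedModel's assets/ directory and re-resolves it on load.
import os
import tempfile

import tensorflow as tf

vocab_path = os.path.join(tempfile.mkdtemp(), 'asset')
tf.io.write_file(vocab_path, 'foo\nbar\nbaz')


class VocabLookup(tf.Module):
    """Wraps a file-backed StaticHashTable so the asset file is tracked."""

    def __init__(self, vocab_file):
        super().__init__()
        initializer = tf.lookup.TextFileInitializer(
            vocab_file,
            key_dtype=tf.string,
            key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
            value_dtype=tf.int64,
            value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
        self.table = tf.lookup.StaticHashTable(initializer, default_value=-1)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def __call__(self, keys):
        return self.table.lookup(keys)


export_dir = os.path.join(tempfile.mkdtemp(), 'export')
tf.saved_model.save(VocabLookup(vocab_path), export_dir)
restored = tf.saved_model.load(export_dir)
print(restored(tf.constant(['foo', 'baz', 'dog'])).numpy())  # [ 0  2 -1]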
Example #2
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table from the given dataset of
  `(key, value)` pairs.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...                               ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assigning out-of-vocabulary keys to buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs
      * The 2nd item in the `dataset` pairs has a dtype which is incompatible
        with `default_value`
      * `num_oov_buckets` is negative
      * `vocab_size` is not greater than zero
      * The `key_dtype` is not integer or string
  """
  elem_spec = dataset.element_spec
  if len(elem_spec) != 2:
    raise ValueError("The given dataset must contain pairs.")
  if default_value is None:
    default_value = -1
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("The dtype of the values requires manually setting a "
                       "compatible default_value.")
  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater or equal than 0, got %d." %
        num_oov_buckets)
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
      if num_oov_buckets:
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table
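# A hedged usage sketch, not part of the function above: calling the public
# tf.data.experimental.table_from_dataset API with num_oov_buckets > 0,
# assuming TensorFlow 2.x eager execution. In-vocabulary keys map to their
# ids 0..2; any other key hashes into the OOV bucket range
# [vocab_size, vocab_size + num_oov_buckets - 1], i.e. ids 3..7 here.
import tensorflow as tf

keys = tf.data.Dataset.from_tensor_slices(['emerson', 'lake', 'palmer'])
values = tf.data.Dataset.range(3)  # int64 ids 0, 1, 2
vocab = tf.data.Dataset.zip((keys, values))

table = tf.data.experimental.table_from_dataset(
    vocab, num_oov_buckets=5, key_dtype=tf.string)

print(table.lookup(tf.constant(['emerson', 'toto'])).numpy())
# -> [0, <some id in 3..7 for the OOV key 'toto'>]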