Example #1
def testCorrectCardinality(self):
  dataset = dataset_ops.Dataset.range(10).filter(lambda x: True)
  # filter() makes the length statically unknown until it is asserted.
  self.assertEqual(
      self.evaluate(cardinality.cardinality(dataset)), cardinality.UNKNOWN)
  self.assertDatasetProduces(dataset, expected_output=range(10))
  dataset = dataset.apply(cardinality.assert_cardinality(10))
  self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
  self.assertDatasetProduces(dataset, expected_output=range(10))
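
For context, the same check through the public tf.data API (a minimal sketch, assuming TF 2.x eager execution):

import tensorflow as tf

ds = tf.data.Dataset.range(10).filter(lambda x: True)
# filter() hides the length, so the cardinality is statically unknown (-2).
print(tf.data.experimental.cardinality(ds).numpy())
# The caller knows this filter drops nothing, so the true length can be asserted.
ds = ds.apply(tf.data.experimental.assert_cardinality(10))
print(tf.data.experimental.cardinality(ds).numpy())  # 10
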
Example #2
def testIncorrectCardinality(self, num_elements, asserted_cardinality,
                             expected_error):
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset = dataset.apply(
      cardinality.assert_cardinality(asserted_cardinality))
  get_next = self.getNext(dataset)
  # A wrong assertion is only detected once the dataset is iterated.
  with self.assertRaisesRegex(errors.FailedPreconditionError, expected_error):
    while True:
      self.evaluate(get_next())
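
With the public API, a mismatched assertion surfaces as a FailedPreconditionError during iteration, not at dataset construction (a sketch, TF 2.x eager assumed):

import tensorflow as tf

ds = tf.data.Dataset.range(10).apply(
    tf.data.experimental.assert_cardinality(100))
try:
  for _ in ds:  # the error fires when the data ends before 100 elements
    pass
except tf.errors.FailedPreconditionError as e:
  print(e.message)
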
Example #3
def testAssertCardinality(self):
  dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
  dataset = dataset.flat_map(core_readers.TFRecordDataset)
  dataset = dataset.batch(5)
  dataset = dataset.apply(cardinality.assert_cardinality(42))
  # Auto-shard across 5 workers; this pipeline reads shard 0.
  dataset = distribute._AutoShardDataset(dataset, 5, 0)

  expected = [
      b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
      for f in (0, 5) for r in range(0, 10)
  ]
  self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
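
distribute._AutoShardDataset is internal API. A rough public-API analogue of the file-level sharding it applies here (a sketch with a hypothetical file glob, not the auto-sharder's actual graph rewrite):

import tensorflow as tf

files = tf.data.Dataset.list_files("/data/*.tfrecord", shuffle=False)  # hypothetical path
files = files.shard(num_shards=5, index=0)  # worker 0 of 5 reads files 0, 5, 10, ...
ds = files.flat_map(tf.data.TFRecordDataset).batch(5)
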
Example #4
def build_dataset(num_elements):
  return dataset_ops.Dataset.range(num_elements).apply(
      cardinality.assert_cardinality(num_elements))
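
For illustration, calling the helper and checking that the asserted length is reported (a sketch; assumes an eager TF 2.x environment where the internal imports above resolve):

import tensorflow as tf

ds = build_dataset(5)
print(tf.data.experimental.cardinality(ds).numpy())  # 5
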
Example #5
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table based on the given dataset of
  `(key, value)` pairs.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...                               ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: The number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs
      * The 2nd item in the `dataset` pairs has a dtype which is incompatible
        with `default_value`
      * `num_oov_buckets` is negative
      * `vocab_size` is not greater than zero
      * The `key_dtype` is not integer or string
  """
  elem_spec = dataset.element_spec
  if len(elem_spec) != 2:
    raise ValueError("The given dataset must contain pairs.")
  if default_value is None:
    default_value = -1
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("The dtype of the values requires manually setting a "
                       "compatible default_value.")
  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater or equal than 0, got %d." %
        num_oov_buckets)
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
      if num_oov_buckets:
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table
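
A short usage sketch with string keys and OOV buckets (the three-word vocabulary is made up for illustration; TF 2.x eager assumed):

import tensorflow as tf

words = tf.data.Dataset.from_tensor_slices(["low", "mid", "high"])  # hypothetical vocab
ids = tf.data.Dataset.range(3)  # int64 values, as the OOV bucket table expects
table = tf.data.experimental.table_from_dataset(
    tf.data.Dataset.zip((words, ids)),
    num_oov_buckets=2,
    vocab_size=3)
# In-vocabulary keys map to their ids; OOV keys hash into buckets [3, 4].
print(table.lookup(tf.constant(["low", "unseen"])).numpy())  # e.g. [0, 3]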