def testCorrectCardinality(self):
  dataset = dataset_ops.Dataset.range(10).filter(lambda x: True)
  self.assertEqual(
      self.evaluate(cardinality.cardinality(dataset)), cardinality.UNKNOWN)
  self.assertDatasetProduces(dataset, expected_output=range(10))
  dataset = dataset.apply(cardinality.assert_cardinality(10))
  self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
  self.assertDatasetProduces(dataset, expected_output=range(10))
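
# A minimal public-API sketch of the behavior the test above covers, assuming
# TF 2.x eager mode: `filter` makes the cardinality statically unknown, and
# `assert_cardinality` restores the known value without changing the data.
import tensorflow as tf

ds = tf.data.Dataset.range(10).filter(lambda x: True)
assert (tf.data.experimental.cardinality(ds).numpy() ==
        tf.data.experimental.UNKNOWN_CARDINALITY.numpy())
ds = ds.apply(tf.data.experimental.assert_cardinality(10))
assert tf.data.experimental.cardinality(ds).numpy() == 10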
def testIncorrectCardinality(self, num_elements, asserted_cardinality,
                             expected_error):
  dataset = dataset_ops.Dataset.range(num_elements)
  dataset = dataset.apply(
      cardinality.assert_cardinality(asserted_cardinality))
  get_next = self.getNext(dataset)
  with self.assertRaisesRegex(errors.FailedPreconditionError,
                              expected_error):
    while True:
      self.evaluate(get_next())
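
# A hedged sketch of the failure mode exercised above, via the public API:
# the assertion is checked lazily during iteration, so iterating a dataset
# whose asserted cardinality is wrong raises FailedPreconditionError.
import tensorflow as tf

ds = tf.data.Dataset.range(10).apply(
    tf.data.experimental.assert_cardinality(100))
try:
  for _ in ds:
    pass
except tf.errors.FailedPreconditionError as e:
  print(e.message)  # reports that the dataset had fewer elements than asserted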
def testAssertCardinality(self):
  dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
  dataset = dataset.flat_map(core_readers.TFRecordDataset)
  dataset = dataset.batch(5)
  dataset = dataset.apply(cardinality.assert_cardinality(42))
  dataset = distribute._AutoShardDataset(dataset, 5, 0)
  expected = [
      b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
      for f in (0, 5)
      for r in range(0, 10)
  ]
  self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
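
# Worker 0 of 5 ends up with files 0 and 5 out of the ten test files, which is
# what the `expected` comprehension above encodes. A rough public-API analogue
# of that file-level split (filenames are hypothetical, not the test fixtures):
import tensorflow as tf

filenames = ["file_%d.tfrecord" % i for i in range(10)]
ds = tf.data.Dataset.from_tensor_slices(filenames).shard(num_shards=5, index=0)
print(list(ds.as_numpy_iterator()))  # [b'file_0.tfrecord', b'file_5.tfrecord']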
def build_dataset(num_elements):
  return dataset_ops.Dataset.range(num_elements).apply(
      cardinality.assert_cardinality(num_elements))
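
# Hedged usage sketch: the helper tags a range dataset with its (correct)
# cardinality, so downstream code sees a statically known size.
ds = build_dataset(5)
# cardinality.cardinality(ds) now evaluates to 5 rather than UNKNOWN.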
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table based on the given dataset of pairs
  of (key, value).

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...     ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: The number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs
      * The 2nd item in the `dataset` pairs has a dtype which is incompatible
        with `default_value`
      * `num_oov_buckets` is negative
      * `vocab_size` is not greater than zero
      * The `key_dtype` is not integer or string
  """
  elem_spec = dataset.element_spec
  if len(elem_spec) != 2:
    raise ValueError("The given dataset must contain pairs.")
  if default_value is None:
    default_value = -1
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("The dtype of the values requires manually setting a "
                       "compatible default_value.")
  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater than or equal to 0, got %d." %
        num_oov_buckets)
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
      if num_oov_buckets:
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table
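
# A hedged sketch of the OOV-bucket path, assuming TF >= 2.3 where
# tf.data.experimental.table_from_dataset is available. With vocab_size=3 and
# num_oov_buckets=2, in-vocabulary keys map to [0, 2] and unknown keys hash
# into buckets [3, 4], matching the bucket ID range documented above.
import tensorflow as tf

keys = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
values = tf.data.Dataset.range(3)
ds = tf.data.Dataset.zip((keys, values))
table = tf.data.experimental.table_from_dataset(
    ds, num_oov_buckets=2, vocab_size=3)
print(table.lookup(tf.constant(["a", "z"])).numpy())  # [0, 3] or [0, 4]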