Beispiel #1
0
 def test_batch_and_split_fn_returns_dataset_with_correct_type_spec(self):
   """Checks the dataset type and element_spec produced by `batch_and_split`.

   A dataset holding the single int64 sequence [0, 1, 2, 3, 4] is passed to
   `batch_and_split` with `sequence_length=6` and `batch_size=1`. The result
   must be a `tf.data.Dataset` whose element_spec is a pair of identical
   int64 `TensorSpec`s of shape [None, 6].
   """
   raw_tokens = tf.data.Dataset.from_tensor_slices(
       tf.constant([[0, 1, 2, 3, 4]], dtype=tf.int64))
   result = word_prediction_preprocessing.batch_and_split(
       raw_tokens, sequence_length=6, batch_size=1)
   self.assertIsInstance(result, tf.data.Dataset)
   # Both halves of the pair share the same spec: batch dim unknown, width 6.
   part_spec = tf.TensorSpec([None, 6], dtype=tf.int64)
   self.assertEqual(result.element_spec, (part_spec, part_spec))
Beispiel #2
0
 def test_batch_and_split_fn_returns_dataset_yielding_expected_elements(self):
   """Checks the values yielded by `batch_and_split`.

   The int64 sequence [0, 1, 2, 3, 4] is batched with `sequence_length=6`
   and `batch_size=1`. Exactly one batch is expected. Its first tensor is
   the tokens right-padded with 0 to length 6. Its second tensor is the same
   sequence shifted left by one position and padded with 0.
   """
   raw_tokens = tf.data.Dataset.from_tensor_slices(
       tf.constant([[0, 1, 2, 3, 4]], dtype=tf.int64))
   result = word_prediction_preprocessing.batch_and_split(
       raw_tokens, sequence_length=6, batch_size=1)
   expected_first = tf.constant([[0, 1, 2, 3, 4, 0]], dtype=tf.int64)
   expected_second = tf.constant([[1, 2, 3, 4, 0, 0]], dtype=tf.int64)
   # Count batches manually so an empty dataset fails the final assertion
   # rather than silently passing.
   batch_count = 0
   for batch in result:
     self.assertAllEqual(self.evaluate(batch[0]), expected_first)
     self.assertAllEqual(self.evaluate(batch[1]), expected_second)
     batch_count += 1
   self.assertEqual(batch_count, 1)