def test_test_preprocess_fn_return_dataset_element_spec_oov_buckets(self):
    """Checks the element_spec of the test preprocessing with 10 OOV buckets."""
    raw_ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    preprocess = stackoverflow_dataset.create_test_dataset_preprocess_fn(
        vocab=['one', 'must'], max_seq_len=10, num_oov_buckets=10)
    processed_ds = preprocess(raw_ds)
    # Each element is expected to be a pair of int64 tensors of width
    # max_seq_len (10) whose leading dimension is left unspecified.
    token_spec = tf.TensorSpec(shape=[None, 10], dtype=tf.int64)
    expected_spec = (token_spec, token_spec)
    self.assertEqual(processed_ds.element_spec, expected_spec)
 def test_test_preprocess_fn_returns_correct_sequence(self):
   ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   test_preprocess_fn = stackoverflow_dataset.create_test_dataset_preprocess_fn(
       max_seq_len=6, vocab=['one', 'must'], num_oov_buckets=1)
   test_preprocessed_ds = test_preprocess_fn(ds)
   element = next(iter(test_preprocessed_ds))
   # BOS is len(vocab)+2, EOS is len(vocab)+3, pad is 0, OOV is len(vocab)+1
   self.assertAllEqual(
       self.evaluate(element[0]), np.array([[4, 1, 2, 3, 5, 0]]))
 def test_test_preprocess_fn_returns_correct_sequence_oov_buckets(self):
   ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   test_preprocess_fn = stackoverflow_dataset.create_test_dataset_preprocess_fn(
       max_seq_len=6, vocab=['one', 'must'], num_oov_buckets=3)
   test_preprocessed_ds = test_preprocess_fn(ds)
   element = next(iter(test_preprocessed_ds))
   # BOS is len(vocab)+3+1
   self.assertEqual(self.evaluate(element[0])[0][0], 6)
   self.assertEqual(self.evaluate(element[0])[0][1], 1)
   self.assertEqual(self.evaluate(element[0])[0][2], 2)
   # OOV is [len(vocab)+1, len(vocab)+2, len(vocab)+3]
   self.assertIn(self.evaluate(element[0])[0][3], [3, 4, 5])
   # EOS is len(vocab)+3+2
   self.assertEqual(self.evaluate(element[0])[0][4], 7)
   # pad is 0
   self.assertEqual(self.evaluate(element[0])[0][5], 0)