  def test_batch_and_split_fn_returns_dataset_with_correct_type_spec(self):
   token = tf.constant([[0, 1, 2, 3, 4]], dtype=tf.int64)
   ds = tf.data.Dataset.from_tensor_slices(token)
   padded_and_batched = stackoverflow_word_prediction.batch_and_split(
       ds, max_sequence_length=6, batch_size=1)
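    # batch_and_split pads each example to `max_sequence_length` and yields
    # (input, target) pairs, so both element specs should be [None, 6]
    # int64 tensors.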
   self.assertIsInstance(padded_and_batched, tf.data.Dataset)
    self.assertEqual(
        padded_and_batched.element_spec,
        (tf.TensorSpec([None, 6], dtype=tf.int64),
         tf.TensorSpec([None, 6], dtype=tf.int64)))

  def test_batch_and_split_fn_returns_dataset_yielding_expected_elements(self):
   token = tf.constant([[0, 1, 2, 3, 4]], dtype=tf.int64)
   ds = tf.data.Dataset.from_tensor_slices(token)
   padded_and_batched = stackoverflow_word_prediction.batch_and_split(
       ds, max_sequence_length=6, batch_size=1)
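    # For input tokens [0, 1, 2, 3, 4], the split should yield the
    # zero-padded input sequence and, as the target, the same sequence
    # shifted left by one position (next-word prediction).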
   num_elems = 0
   for elem in padded_and_batched:
     self.assertAllEqual(
         self.evaluate(elem[0]), np.array([[0, 1, 2, 3, 4, 0]], np.int64))
     self.assertAllEqual(
         self.evaluate(elem[1]), np.array([[1, 2, 3, 4, 0, 0]], np.int64))
     num_elems += 1
   self.assertEqual(num_elems, 1)


def create_preprocess_fn(vocab, num_oov_buckets, max_sequence_length,
                         max_elements_per_client, client_batch_size,
                         sort_by_date):
  """Builds a per-client preprocessing fn closing over these parameters."""

  def preprocess_fn(dataset):
    to_ids = stackoverflow_word_prediction.build_to_ids_fn(
        vocab=vocab,
        max_sequence_length=max_sequence_length,
        num_oov_buckets=num_oov_buckets)
    dataset = dataset.take(max_elements_per_client)
    if sort_by_date:
      # `_sort_examples_by_date` is assumed to be defined elsewhere in this
      # module. Sorting requires materializing the client's examples as a
      # single batch before re-splitting them.
      dataset = dataset.batch(max_elements_per_client).map(
          _sort_examples_by_date).unbatch()
    else:
      dataset = dataset.shuffle(max_elements_per_client)
    dataset = dataset.map(
        to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return stackoverflow_word_prediction.batch_and_split(
        dataset, max_sequence_length, client_batch_size)

  return preprocess_fn
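

# A minimal usage sketch (hypothetical values; assumes the raw client
# dataset follows the Stack Overflow schema that build_to_ids_fn expects,
# e.g. an OrderedDict of features including 'tokens'):
#
#   preprocess_fn = create_preprocess_fn(
#       vocab=['one', 'two'],
#       num_oov_buckets=1,
#       max_sequence_length=6,
#       max_elements_per_client=10,
#       client_batch_size=2,
#       sort_by_date=False)
#   preprocessed = preprocess_fn(raw_client_dataset)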