def test_batch_and_split_fn_returns_dataset_with_correct_type_spec(self):
  """Checks that `batch_and_split` yields a dataset of (input, target) specs."""
  tokens = tf.constant([[0, 1, 2, 3, 4]], dtype=tf.int64)
  dataset = tf.data.Dataset.from_tensor_slices(tokens)
  result = stackoverflow_word_prediction.batch_and_split(
      dataset, max_sequence_length=6, batch_size=1)
  self.assertIsInstance(result, tf.data.Dataset)
  # TensorSpec equality is by value, so one spec object serves for both halves.
  expected_spec = tf.TensorSpec([None, 6], dtype=tf.int64)
  self.assertEqual(result.element_spec, (expected_spec, expected_spec))
def test_batch_and_split_fn_returns_dataset_yielding_expected_elements(self):
  """Checks that `batch_and_split` yields exactly one padded (input, target) pair."""
  tokens = tf.constant([[0, 1, 2, 3, 4]], dtype=tf.int64)
  dataset = tf.data.Dataset.from_tensor_slices(tokens)
  result = stackoverflow_word_prediction.batch_and_split(
      dataset, max_sequence_length=6, batch_size=1)
  # Inputs are the sequence padded to length 6; targets are shifted one step.
  expected_input = np.array([[0, 1, 2, 3, 4, 0]], np.int64)
  expected_target = np.array([[1, 2, 3, 4, 0, 0]], np.int64)
  element_count = 0
  for inputs, targets in result:
    self.assertAllEqual(self.evaluate(inputs), expected_input)
    self.assertAllEqual(self.evaluate(targets), expected_target)
    element_count += 1
  self.assertEqual(element_count, 1)
def preprocess_fn(dataset):
  """Converts a client's raw examples to batched (input, target) id sequences.

  Closes over `vocab`, `max_sequence_length`, `num_oov_buckets`,
  `max_elements_per_client`, `sort_by_date`, `_sort_examples_by_date`, and
  `client_batch_size` from the enclosing scope.

  Args:
    dataset: A `tf.data.Dataset` of raw client examples.

  Returns:
    A `tf.data.Dataset` of padded and batched (input, target) pairs.
  """
  to_ids = stackoverflow_word_prediction.build_to_ids_fn(
      vocab=vocab,
      max_sequence_length=max_sequence_length,
      num_oov_buckets=num_oov_buckets)
  dataset = dataset.take(max_elements_per_client)
  if sort_by_date:
    # Batch the whole client's data into one element so it can be sorted
    # globally, then unbatch back to individual examples.
    dataset = (
        dataset.batch(max_elements_per_client)
        .map(_sort_examples_by_date)
        .unbatch())
  else:
    dataset = dataset.shuffle(max_elements_per_client)
  dataset = dataset.map(
      to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return stackoverflow_word_prediction.batch_and_split(
      dataset, max_sequence_length, client_batch_size)