  # The `@parameterized.named_parameters` decorator was missing from this
  # excerpt; the parameter values below are illustrative assumptions
  # (non-positive vocabulary sizes, which should be rejected with a ValueError).
  @parameterized.named_parameters(('vocab_size_0', 0),
                                  ('vocab_size_negative', -1))
  def test_raises_on_bad_vocab_size(self, vocab_size):
    train_client_spec = client_spec.ClientSpec(num_epochs=2,
                                               batch_size=10,
                                               max_elements=3,
                                               shuffle_buffer_size=5)
    with self.assertRaises(ValueError):
      word_prediction_tasks.create_word_prediction_task(
          train_client_spec,
          vocab_size=vocab_size,
          use_synthetic_data=True)

  # As above, the decorator values here are illustrative assumptions (valid
  # positive vocabulary sizes).
  @parameterized.named_parameters(('vocab_size_1', 1), ('vocab_size_10', 10),
                                  ('vocab_size_100', 100))
  def test_constructs_with_different_vocab_sizes(self, vocab_size):
    train_client_spec = client_spec.ClientSpec(num_epochs=2,
                                               batch_size=10,
                                               max_elements=3,
                                               shuffle_buffer_size=5)
    baseline_task_spec = word_prediction_tasks.create_word_prediction_task(
        train_client_spec, vocab_size=vocab_size, use_synthetic_data=True)
    self.assertIsInstance(baseline_task_spec, baseline_task.BaselineTask)

  def test_model_is_compatible_with_preprocessed_data(self):
    train_client_spec = client_spec.ClientSpec(num_epochs=1, batch_size=10)
    baseline_task_spec = word_prediction_tasks.create_word_prediction_task(
        train_client_spec, use_synthetic_data=True)
    # The model built by `model_fn` should accept a batch produced by the
    # task's preprocessing pipeline without raising.
    centralized_dataset = (
        baseline_task_spec.datasets.get_centralized_test_data())
    sample_batch = next(iter(centralized_dataset))
    model = baseline_task_spec.model_fn()
    model.forward_pass(sample_batch)
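

# Assumed test-module entry point (not part of the original excerpt): it lets
# the tests above be run directly, and presumes `tensorflow` is imported as
# `tf` at the top of this file, as is usual for tf.test.TestCase-based tests.
if __name__ == '__main__':
  tf.test.main()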