def test_raises_on_bad_vocab_size(self, vocab_size):
    """Task creation must reject the supplied `vocab_size` with ValueError.

    `vocab_size` is presumably injected by a parameterization decorator
    that is not visible here — TODO confirm the set of bad values used.
    """
    # Built outside the assertRaises context so that only the task factory
    # call (not the client spec) is allowed to produce the ValueError.
    spec = client_spec.ClientSpec(
        num_epochs=2,
        batch_size=10,
        max_elements=3,
        shuffle_buffer_size=5,
    )
    with self.assertRaises(ValueError):
        word_prediction_tasks.create_word_prediction_task(
            spec,
            vocab_size=vocab_size,
            use_synthetic_data=True,
        )
def test_constructs_with_different_vocab_sizes(self, vocab_size):
    """A task built with the supplied `vocab_size` is a `BaselineTask`.

    `vocab_size` is presumably injected by a parameterization decorator
    that is not visible here — TODO confirm.
    """
    spec = client_spec.ClientSpec(
        num_epochs=2,
        batch_size=10,
        max_elements=3,
        shuffle_buffer_size=5,
    )
    task = word_prediction_tasks.create_word_prediction_task(
        spec,
        vocab_size=vocab_size,
        use_synthetic_data=True,
    )
    self.assertIsInstance(task, baseline_task.BaselineTask)
def test_model_is_compatible_with_preprocessed_data(self):
    """The task's model can run a forward pass on a centralized test batch.

    The test passes as long as `forward_pass` does not raise; no output
    values are checked.
    """
    spec = client_spec.ClientSpec(num_epochs=1, batch_size=10)
    task = word_prediction_tasks.create_word_prediction_task(
        spec, use_synthetic_data=True)
    test_data = task.datasets.get_centralized_test_data()
    # Only the first batch is needed to exercise the model's input contract.
    first_batch = next(iter(test_data))
    task.model_fn().forward_pass(first_batch)