@parameterized.named_parameters(
    ('vocab_size0', 0),
    ('vocab_size_minus1', -1),
    ('vocab_size_minus5', -5),
)
def test_raises_on_bad_vocab_size(self, vocab_size):
  """A non-positive `vocab_size` must be rejected with `ValueError`."""
  train_client_spec = client_spec.ClientSpec(
      num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
  with self.assertRaises(ValueError):
    word_prediction_tasks.create_word_prediction_task(
        train_client_spec, vocab_size=vocab_size, use_synthetic_data=True)


def test_model_is_compatible_with_preprocessed_data(self):
  """The task's model accepts a batch drawn from its own preprocessed data."""
  train_client_spec = client_spec.ClientSpec(num_epochs=1, batch_size=10)
  task = word_prediction_tasks.create_word_prediction_task(
      train_client_spec, use_synthetic_data=True)
  centralized_data = task.datasets.get_centralized_test_data()
  sample_batch = next(iter(centralized_data))
  model = task.model_fn()
  # Completing a forward pass without error is the compatibility check itself;
  # no explicit assertion is needed.
  model.forward_pass(sample_batch)


if __name__ == '__main__':
  execution_contexts.set_local_python_execution_context()
  tf.test.main()
def setUp(self):
  """Installs a local Python execution context before each test runs."""
  super().setUp()
  execution_contexts.set_local_python_execution_context()