@parameterized.named_parameters(
        ('vocab_size0', 0),
        ('vocab_size_minus1', -1),
        ('vocab_size_minus5', -5),
    )
    def test_raises_on_bad_vocab_size(self, vocab_size):
        train_client_spec = client_spec.ClientSpec(num_epochs=2,
                                                   batch_size=10,
                                                   max_elements=3,
                                                   shuffle_buffer_size=5)
        with self.assertRaises(ValueError):
            word_prediction_tasks.create_word_prediction_task(
                train_client_spec,
                vocab_size=vocab_size,
                use_synthetic_data=True)

    def test_model_is_compatible_with_preprocessed_data(self):
        """Smoke test: the task's model accepts a batch of its own test data.

        Builds the task on synthetic data, pulls the first batch from the
        centralized test dataset, and runs a forward pass on it. The test
        passes as long as nothing raises.
        """
        spec = client_spec.ClientSpec(num_epochs=1, batch_size=10)
        task = word_prediction_tasks.create_word_prediction_task(
            spec, use_synthetic_data=True)

        test_data = task.datasets.get_centralized_test_data()
        first_batch = next(iter(test_data))

        task.model_fn().forward_pass(first_batch)


# Script entry point: only runs when this file is executed directly.
if __name__ == '__main__':
    # Install the local Python execution context once, before the test
    # runner starts, so it is already in place for every test case.
    execution_contexts.set_local_python_execution_context()
    tf.test.main()
# Example #2
# 0
 def setUp(self):
     # Per-test fixture hook: run the parent class's setUp first, then
     # install the local Python execution context for the test that follows.
     super().setUp()
     execution_contexts.set_local_python_execution_context()