# NOTE(review): the next three statements are the tail of a test method whose
# `def` line precedes this chunk; `ds`, `preprocess_spec`, `num_epochs` and
# `batch_size` are presumably bound earlier in that method — confirm and
# re-indent under it when restoring the file layout.
preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
    preprocess_spec, vocab=['A'], sequence_length=10)
preprocessed_ds = preprocess_fn(ds)
# The expected length is ceil(num_epochs / batch_size), cast to int32 so it
# compares against the computed dataset length.
self.assertEqual(
    _compute_length_of_dataset(preprocessed_ds),
    tf.cast(tf.math.ceil(num_epochs / batch_size), tf.int32))


# NOTE(review): this is a test method (it takes `self`); its enclosing class
# header is outside this chunk.
@parameterized.named_parameters(
    ('max_elements1', 1),
    ('max_elements3', 3),
    ('max_elements7', 7),
    ('max_elements11', 11),
    ('max_elements18', 18),
)
def test_ds_length_with_max_elements(self, max_elements):
    """Checks the preprocessed dataset length when `max_elements` is set.

    The client spec uses `num_epochs=repeat_size` and `batch_size=1`; the
    length reported by `_compute_length_of_dataset` must equal
    `min(repeat_size, max_elements)`.

    Args:
      max_elements: Value forwarded as `ClientSpec.max_elements`; supplied by
        the `named_parameters` decorator (1, 3, 7, 11 or 18).
    """
    repeat_size = 10
    ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    preprocess_spec = client_spec.ClientSpec(
        num_epochs=repeat_size, batch_size=1, max_elements=max_elements)
    # Unlike the preceding test, no `sequence_length` is passed here.
    preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        preprocess_spec, vocab=['A'])
    preprocessed_ds = preprocess_fn(ds)
    # NOTE(review): the expectation below presumably relies on TEST_DATA
    # holding a single example, so that repeating it yields `repeat_size`
    # elements before the `max_elements` cap — confirm against TEST_DATA.
    self.assertEqual(
        _compute_length_of_dataset(preprocessed_ds),
        min(repeat_size, max_elements))


if __name__ == '__main__':
    # Set the local execution context, then hand control to the TF test runner.
    execution_contexts.set_local_execution_context()
    tf.test.main()
# NOTE(review): this method takes `self`, so it presumably belongs to a test
# class whose header lies outside this chunk — re-indent under that class when
# restoring the file layout.
def test_inner_federated_type_raises(self):
    """A struct of two `SERVER_FLOAT` types is rejected with a `TypeError`.

    The error message is expected to mention 'FederatedType'.
    """
    with self.assertRaisesRegex(TypeError, 'FederatedType'):
        # Type construction stays inside the context, as in the original, so
        # a matching error raised at either step satisfies the assertion.
        struct_of_server_floats = computation_types.to_type(
            [SERVER_FLOAT, SERVER_FLOAT])
        distributors.build_broadcast_process(struct_of_server_floats)


class BroadcastProcessExecutionTest(test_case.TestCase):
    """Executes broadcast processes and checks state, result and measurements."""

    def test_broadcast_scalar(self):
        """A scalar is returned as the result; state and measurements are `()`."""
        process = distributors.build_broadcast_process(SERVER_FLOAT.member)
        initial_state = process.initialize()

        output = process.next(initial_state, 2.5)

        self.assertEqual((), output.state)
        self.assertAllClose(2.5, output.result)
        self.assertEqual((), output.measurements)

    def test_broadcast_struct(self):
        """A nested struct is returned as the result; state/measurements are `()`."""
        value_type = computation_types.to_type(
            [(tf.float32, (2, )), tf.int32])
        process = distributors.build_broadcast_process(value_type)
        # The same immutable tuple serves as both the input and the expectation.
        value = ((1.0, 2.5), 3)

        output = process.next(process.initialize(), value)

        self.assertEqual((), output.state)
        self.assertAllClose(value, output.result)
        self.assertEqual((), output.measurements)


if __name__ == '__main__':
    # Set the local execution context with `default_num_clients=1`, then run
    # the suite.
    execution_contexts.set_local_execution_context(default_num_clients=1)
    test_case.main()