def test_get_size_info(self, num_clients):
  """Checks the sizing executor factory's broadcast/aggregate accounting.

  Builds a small federated computation that broadcasts a server threshold,
  maps `count_over`/`count_total` over client datasets, and averages the
  result, then asserts that the factory recorded the expected per-client
  payloads and bit counts.

  Args:
    num_clients: Number of simulated clients to run the computation with.
  """

  @computations.federated_computation(
      type_factory.at_clients(computation_types.SequenceType(tf.float32)),
      type_factory.at_server(tf.float32))
  def comp(temperatures, threshold):
    # Pair each client's dataset with the broadcast threshold so the mapped
    # function sees both.
    client_data = [temperatures, intrinsics.federated_broadcast(threshold)]
    result_map = intrinsics.federated_map(
        count_over, intrinsics.federated_zip(client_data))
    count_map = intrinsics.federated_map(count_total, temperatures)
    return intrinsics.federated_mean(result_map, count_map)

  factory = executor_stacks.sizing_executor_factory(num_clients=num_clients)
  default_executor.set_default_executor(factory)
  to_float = lambda x: tf.cast(x, tf.float32)
  # Every client gets the same 10-element float dataset.
  temperatures = [tf.data.Dataset.range(10).map(to_float)] * num_clients
  threshold = 15.0
  comp(temperatures, threshold)
  # Each client receives a tf.float32 and uploads two tf.float32 values.
  expected_broadcast_bits = num_clients * 32
  expected_aggregate_bits = expected_broadcast_bits * 2
  # get_size_info() returns (broadcast_history, aggregate_history,
  # broadcast_bits, aggregate_bits); each history maps a placement key to a
  # list of [element_count, dtype] records.
  expected = ({
      (('CLIENTS', num_clients),): [[1, tf.float32]] * num_clients
  }, {
      (('CLIENTS', num_clients),): [[1, tf.float32]] * num_clients * 2
  }, [expected_broadcast_bits], [expected_aggregate_bits])
  self.assertEqual(expected, factory.get_size_info())
def test_with_reference_executor(self):
  """Installing a ReferenceExecutor makes it the current context."""
  stack = context_stack_impl.context_stack
  ref_executor = reference_executor.ReferenceExecutor()
  # Sanity check: the fresh executor is not already the active context.
  self.assertIsNot(stack.current, ref_executor)
  default_executor.set_default_executor(ref_executor)
  # After installation, the executor instance itself is the active context.
  self.assertIs(stack.current, ref_executor)
def test_as_default_executor(self):
  """An IREE executor set as the default can run a simple tf_computation."""
  iree_ex = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
  # Wrap the single executor instance in a factory before installing it.
  factory = executor_factory.create_executor_factory(lambda _: iree_ex)
  default_executor.set_default_executor(factory)

  @computations.tf_computation(tf.float32)
  def add_one(x):
    return x + 1.0

  self.assertEqual(add_one(10.0), 11.0)
def test_with_executor_factory(self):
  """Installing an executor factory pushes an ExecutionContext that owns it."""
  stack = context_stack_impl.context_stack
  factory = executor_factory.ExecutorFactoryImpl(lambda _: None)
  # Before installation there is no ExecutionContext on the stack.
  self.assertNotIsInstance(stack.current, execution_context.ExecutionContext)
  default_executor.set_default_executor(factory)
  # Afterwards the current context is an ExecutionContext wrapping exactly
  # the factory we supplied.
  self.assertIsInstance(stack.current, execution_context.ExecutionContext)
  self.assertIs(stack.current._executor_factory, factory)
('sum_example_with_no_federated_secure_sum', get_iterative_process_for_sum_example_with_no_federated_secure_sum()), ('sum_example_with_no_update', get_iterative_process_for_sum_example_with_no_update()), ('sum_example_with_no_server_state', get_iterative_process_for_sum_example_with_no_server_state()), ('minimal_sum_example', get_iterative_process_for_minimal_sum_example()), ('example_with_unused_lambda_arg', test_utils.get_iterative_process_for_example_with_unused_lambda_arg()), ('example_with_unused_tf_computation_arg', test_utils.get_iterative_process_for_example_with_unused_tf_computation_arg()), ) # pyformat: enable def test_returns_canonical_form(self, ip): cf = canonical_form_utils.get_canonical_form_for_iterative_process(ip) self.assertIsInstance(cf, canonical_form.CanonicalForm) def test_raises_value_error_for_sum_example_with_no_aggregation(self): ip = get_iterative_process_for_sum_example_with_no_aggregation() with self.assertRaises(ValueError): canonical_form_utils.get_canonical_form_for_iterative_process(ip) if __name__ == '__main__': reference_executor = reference_executor.ReferenceExecutor() default_executor.set_default_executor(reference_executor) test.main()
        'The return type of next_fn must be assignable to the first parameter'
    ):
      # next_fn yields a float while the state (from `initialize`) is int32,
      # so constructing the IterativeProcess must fail the type check above.
      @computations.federated_computation(tf.int32)
      def add_bad_result(_):
        return 0.0

      iterative_process.IterativeProcess(
          initialize_fn=initialize, next_fn=add_bad_result)

    with self.assertRaisesRegex(
        TypeError,
        'The return type of next_fn must be assignable to the first parameter'
    ):
      # Same failure mode with a multi-element result: the first element of
      # the tuple (the new state) is a float, not int32.
      @computations.federated_computation(tf.int32)
      def add_bad_multi_result(_):
        return 0.0, 0

      iterative_process.IterativeProcess(
          initialize_fn=initialize, next_fn=add_bad_multi_result)


if __name__ == '__main__':
  # Note: num_clients must be explicit here to correctly test the broadcast
  # behavior. Otherwise TFF will infer there are zero clients, which is an
  # error.
  executor = executor_stacks.local_executor_factory(num_clients=3)
  default_executor.set_default_executor(executor)
  test.main()
def test_raises_type_error_with_none(self):
  """set_default_executor must reject a None argument with TypeError."""
  # Callable form of assertRaises: invoke set_default_executor(None) and
  # expect it to raise.
  self.assertRaises(TypeError, default_executor.set_default_executor, None)