def get_named_parameters_for_supported_intrinsics(
) -> List[Tuple[str, Any, Any]]:
  """Returns `parameterized.named_parameters` entries for supported intrinsics.

  Each entry is a 3-tuple `(test_name, comp, comp_type)`. It is built from the
  test name followed by the unpacked `(comp, comp_type)` pair returned by the
  matching `executor_test_utils.create_whimsy_intrinsic_def_*` helper.

  Returns:
    A list of `(str, comp, comp_type)` tuples, one per intrinsic.
  """
  # pyformat: disable
  return [
      ('intrinsic_def_federated_aggregate',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_aggregate()),
      ('intrinsic_def_federated_apply',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_apply()),
      ('intrinsic_def_federated_broadcast',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_broadcast()),
      ('intrinsic_def_federated_eval_at_clients',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_eval_at_clients()),
      ('intrinsic_def_federated_eval_at_server',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_eval_at_server()),
      ('intrinsic_def_federated_map',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_map()),
      ('intrinsic_def_federated_map_all_equal',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_map_all_equal()),
      ('intrinsic_def_federated_mean',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_mean()),
      ('intrinsic_def_federated_select',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_select()),
      ('intrinsic_def_federated_sum',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_sum()),
      ('intrinsic_def_federated_value_at_clients',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_value_at_clients()),
      ('intrinsic_def_federated_value_at_server',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_value_at_server()),
      ('intrinsic_def_federated_weighted_mean',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_weighted_mean()),
      ('intrinsic_def_federated_zip_at_clients',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_zip_at_clients()),
      ('intrinsic_def_federated_zip_at_server',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_zip_at_server()),
  ]
  # pyformat: enable
class FederatingExecutorCreateCallTest(executor_test_utils.AsyncTestCase,
                                       parameterized.TestCase):
  """Tests `create_call` on executors built by `create_test_executor`."""

  # pyformat: disable
  @parameterized.named_parameters([
      ('intrinsic_def_federated_aggregate',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_aggregate(),
       [executor_test_utils.create_whimsy_value_at_clients(),
        executor_test_utils.create_whimsy_value_unplaced(),
        executor_test_utils.create_whimsy_computation_tensorflow_add(),
        executor_test_utils.create_whimsy_computation_tensorflow_add(),
        executor_test_utils.create_whimsy_computation_tensorflow_identity()],
       43.0),
      ('intrinsic_def_federated_apply',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_apply(),
       [executor_test_utils.create_whimsy_computation_tensorflow_identity(),
        executor_test_utils.create_whimsy_value_at_server()],
       10.0),
      ('intrinsic_def_federated_broadcast',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_broadcast(),
       [executor_test_utils.create_whimsy_value_at_server()],
       10.0),
      ('intrinsic_def_federated_eval_at_clients',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_eval_at_clients(),
       [executor_test_utils.create_whimsy_computation_tensorflow_constant()],
       [10.0] * 3),
      ('intrinsic_def_federated_eval_at_server',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_eval_at_server(),
       [executor_test_utils.create_whimsy_computation_tensorflow_constant()],
       10.0),
      ('intrinsic_def_federated_map',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_map(),
       [executor_test_utils.create_whimsy_computation_tensorflow_identity(),
        executor_test_utils.create_whimsy_value_at_clients()],
       [10.0, 11.0, 12.0]),
      ('intrinsic_def_federated_map_all_equal',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_map_all_equal(),
       [executor_test_utils.create_whimsy_computation_tensorflow_identity(),
        executor_test_utils.create_whimsy_value_at_clients_all_equal()],
       10.0),
      ('intrinsic_def_federated_mean',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_mean(),
       [executor_test_utils.create_whimsy_value_at_clients()],
       11.0),
      ('intrinsic_def_federated_select',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_select(),
       executor_test_utils.create_whimsy_federated_select_args(),
       executor_test_utils.create_whimsy_federated_select_expected_result(),
      ),
      ('intrinsic_def_federated_sum',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_sum(),
       [executor_test_utils.create_whimsy_value_at_clients()],
       33.0),
      ('intrinsic_def_federated_value_at_clients',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_value_at_clients(),
       [executor_test_utils.create_whimsy_value_unplaced()],
       10.0),
      ('intrinsic_def_federated_value_at_server',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_value_at_server(),
       [executor_test_utils.create_whimsy_value_unplaced()],
       10.0),
      ('intrinsic_def_federated_weighted_mean',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_weighted_mean(),
       [executor_test_utils.create_whimsy_value_at_clients(),
        executor_test_utils.create_whimsy_value_at_clients()],
       11.060606),
      ('intrinsic_def_federated_zip_at_clients',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_zip_at_clients(),
       [executor_test_utils.create_whimsy_value_at_clients(),
        executor_test_utils.create_whimsy_value_at_clients()],
       [structure.Struct([(None, 10.0), (None, 10.0)]),
        structure.Struct([(None, 11.0), (None, 11.0)]),
        structure.Struct([(None, 12.0), (None, 12.0)])]),
      ('intrinsic_def_federated_zip_at_server',
       *executor_test_utils.create_whimsy_intrinsic_def_federated_zip_at_server(),
       [executor_test_utils.create_whimsy_value_at_server(),
        executor_test_utils.create_whimsy_value_at_server()],
       structure.Struct([(None, 10.0), (None, 10.0)])),
      ('computation_intrinsic',
       *executor_test_utils.create_whimsy_computation_intrinsic(),
       [executor_test_utils.create_whimsy_computation_tensorflow_constant()],
       10.0),
      ('computation_tensorflow',
       *executor_test_utils.create_whimsy_computation_tensorflow_identity(),
       [executor_test_utils.create_whimsy_value_unplaced()],
       10.0),
  ])
  # pyformat: enable
  def test_returns_value_with_comp_and_arg(self, comp, comp_type, args,
                                           expected_result):
    """Calls `comp` on the embedded `args` and checks type and value.

    Args:
      comp: The computation to embed and call.
      comp_type: The type signature of `comp`.
      args: A list of `(value, type)` pairs, each passed to `create_value`.
      expected_result: The expected output of `compute()` on the call result.
    """
    executor = create_test_executor()
    comp = self.run_sync(executor.create_value(comp, comp_type))
    elements = [self.run_sync(executor.create_value(*x)) for x in args]
    # Multiple arguments are packed into a single struct before the call; a
    # lone argument is passed through unchanged.
    if len(elements) > 1:
      arg = self.run_sync(executor.create_struct(elements))
    else:
      arg = elements[0]
    result = self.run_sync(executor.create_call(comp, arg))

    self.assertIsInstance(result, executor_value_base.ExecutorValue)
    self.assertEqual(result.type_signature.compact_representation(),
                     comp_type.result.compact_representation())
    actual_result = self.run_sync(result.compute())
    self.assert_maybe_list_equal(actual_result, expected_result)

  def assert_maybe_list_equal(self, actual_result, expected_result):
    """Asserts equality, recursing element-wise into lists and datasets.

    When both values are lists, or both are `tf.data.Dataset`s, the two
    sequences must have the same number of elements and are compared pairwise
    (recursively). Otherwise the values are compared with `assertEqual`.

    Args:
      actual_result: The computed value.
      expected_result: The expected value.
    """
    if (all_isinstance([actual_result, expected_result], list) or
        all_isinstance([actual_result, expected_result], tf.data.Dataset)):
      # Materialize both sides so a length mismatch is detected; `zip` alone
      # would silently drop the extra elements and let the test pass.
      actual_elements = list(actual_result)
      expected_elements = list(expected_result)
      self.assertEqual(
          len(actual_elements), len(expected_elements),
          'Length mismatch: actual {}, expected {}.'.format(
              actual_elements, expected_elements))
      for actual_element, expected_element in zip(actual_elements,
                                                  expected_elements):
        self.assert_maybe_list_equal(actual_element, expected_element)
    else:
      self.assertEqual(actual_result, expected_result)

  def test_returns_value_with_intrinsic_def_federated_eval_at_clients_and_random(
      self):
    """Checks that a random computation evaluated at clients yields distinct values."""
    executor = create_test_executor(number_of_clients=3)
    comp, comp_type = executor_test_utils.create_whimsy_intrinsic_def_federated_eval_at_clients(
    )
    arg, arg_type = executor_test_utils.create_whimsy_computation_tensorflow_random(
    )

    comp = self.run_sync(executor.create_value(comp, comp_type))
    arg = self.run_sync(executor.create_value(arg, arg_type))
    result = self.run_sync(executor.create_call(comp, arg))

    self.assertIsInstance(result, executor_value_base.ExecutorValue)
    self.assertEqual(result.type_signature.compact_representation(),
                     comp_type.result.compact_representation())
    actual_result = self.run_sync(result.compute())
    # Any duplicate among the per-client values shrinks the set, which the
    # test treats as the random computation not being re-run per client.
    unique_results = {x.numpy() for x in actual_result}
    if len(actual_result) != len(unique_results):
      self.fail(
          'Expected the result to contain different random numbers, found {}.'
          .format(actual_result))

  # pyformat: disable
  @parameterized.named_parameters([
      ('computation_tensorflow',
       *executor_test_utils.create_whimsy_computation_tensorflow_empty()),
  ])
  # pyformat: enable
  def test_returns_value_with_comp_only(self, comp, comp_type):
    """Calls a no-argument computation and expects an empty result."""
    executor = create_test_executor()
    comp = self.run_sync(executor.create_value(comp, comp_type))
    result = self.run_sync(executor.create_call(comp))

    self.assertIsInstance(result, executor_value_base.ExecutorValue)
    self.assertEqual(result.type_signature.compact_representation(),
                     comp_type.result.compact_representation())
    actual_result = self.run_sync(result.compute())
    expected_result = []
    self.assertCountEqual(actual_result, expected_result)

  def test_raises_type_error_with_unembedded_comp(self):
    """`create_call` rejects a raw computation that was never embedded."""
    executor = create_test_executor()
    comp, _ = executor_test_utils.create_whimsy_computation_tensorflow_identity(
    )
    arg, arg_type = executor_test_utils.create_whimsy_value_unplaced()

    arg = self.run_sync(executor.create_value(arg, arg_type))
    with self.assertRaises(TypeError):
      self.run_sync(executor.create_call(comp, arg))

  def test_raises_type_error_with_unembedded_arg(self):
    """`create_call` rejects a raw argument that was never embedded."""
    executor = create_test_executor()
    comp, comp_type = executor_test_utils.create_whimsy_computation_tensorflow_identity(
    )
    arg, _ = executor_test_utils.create_whimsy_value_unplaced()

    comp = self.run_sync(executor.create_value(comp, comp_type))
    with self.assertRaises(TypeError):
      self.run_sync(executor.create_call(comp, arg))

  # pyformat: disable
  @parameterized.named_parameters([
      ('computation_intrinsic',
       *executor_test_utils.create_whimsy_computation_intrinsic()),
      ('computation_lambda',
       *executor_test_utils.create_whimsy_computation_lambda_identity()),
      ('computation_tensorflow',
       *executor_test_utils.create_whimsy_computation_tensorflow_identity()),
  ] + get_named_parameters_for_supported_intrinsics())
  # pyformat: enable
  def test_raises_type_error_with_comp_and_bad_arg(self, comp, comp_type):
    """Calling `comp` with a `tf.string` argument raises `TypeError`."""
    executor = create_test_executor()
    bad_arg = 'string'
    bad_arg_type = computation_types.TensorType(tf.string)

    comp = self.run_sync(executor.create_value(comp, comp_type))
    arg = self.run_sync(executor.create_value(bad_arg, bad_arg_type))
    with self.assertRaises(TypeError):
      self.run_sync(executor.create_call(comp, arg))

  # pyformat: disable
  @parameterized.named_parameters([
      ('computation_lambda',
       *executor_test_utils.create_whimsy_computation_lambda_empty()),
      ('federated_type_at_clients',
       *executor_test_utils.create_whimsy_value_at_clients()),
      ('federated_type_at_clients_all_equal',
       *executor_test_utils.create_whimsy_value_at_clients_all_equal()),
      ('federated_type_at_server',
       *executor_test_utils.create_whimsy_value_at_server()),
      ('unplaced_type',
       *executor_test_utils.create_whimsy_value_unplaced()),
  ])
  # pyformat: enable
  def test_raises_value_error_with_comp(self, comp, comp_type):
    """Calling each of these embedded values with no argument raises `ValueError`."""
    executor = create_test_executor()

    comp = self.run_sync(executor.create_value(comp, comp_type))
    with self.assertRaises(ValueError):
      self.run_sync(executor.create_call(comp))

  def test_raises_not_implemented_error_with_intrinsic_def_federated_secure_sum_bitwidth(
      self):
    """The secure-sum-bitwidth intrinsic is expected to be unimplemented here."""
    executor = create_test_executor()
    comp, comp_type = executor_test_utils.create_whimsy_intrinsic_def_federated_secure_sum_bitwidth(
    )
    arg_1 = [10, 11, 12]
    arg_1_type = computation_types.at_clients(tf.int32, all_equal=False)
    arg_2 = 10
    arg_2_type = computation_types.TensorType(tf.int32)

    comp = self.run_sync(executor.create_value(comp, comp_type))
    arg_1 = self.run_sync(executor.create_value(arg_1, arg_1_type))
    arg_2 = self.run_sync(executor.create_value(arg_2, arg_2_type))
    args = self.run_sync(executor.create_struct([arg_1, arg_2]))
    with self.assertRaises(NotImplementedError):
      self.run_sync(executor.create_call(comp, args))

  def test_raises_not_implemented_error_with_unimplemented_intrinsic(self):
    """An intrinsic with a definition but no executor support raises."""
    executor = create_test_executor()
    # `whimsy_intrinsic` definition is needed to allow lookup.
    # NOTE(review): presumably constructing an `IntrinsicDef` registers its
    # uri globally so the later lookup succeeds; confirm in `intrinsic_defs`.
    whimsy_intrinsic = intrinsic_defs.IntrinsicDef(
        'WHIMSY_INTRINSIC', 'whimsy_intrinsic',
        computation_types.AbstractType('T'))
    type_signature = computation_types.TensorType(tf.int32)
    comp = pb.Computation(
        intrinsic=pb.Intrinsic(uri='whimsy_intrinsic'),
        type=type_serialization.serialize_type(type_signature))
    del whimsy_intrinsic

    comp = self.run_sync(executor.create_value(comp))
    with self.assertRaises(NotImplementedError):
      self.run_sync(executor.create_call(comp))