def test_federated_sum(self):
  """Runs the FEDERATED_SUM body over CLIENTS-placed int32 values."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  sum_body = intrinsic_impls[intrinsic_defs.FEDERATED_SUM.uri]
  client_ints = computation_types.FederatedType(tf.int32, placements.CLIENTS)

  @computations.federated_computation(client_ints)
  def comp(x):
    return sum_body(x)

  expected_signature = '({int32}@CLIENTS -> int32@SERVER)'
  self.assertEqual(str(comp.type_signature), expected_signature)
  # One client, then several clients.
  self.assertEqual(comp([1]), 1)
  self.assertEqual(comp([1, 2, 3]), 6)
def test_federated_generic_add_with_ints(self):
  """GENERIC_PLUS of a CLIENTS-placed int32 with itself stays on CLIENTS."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  plus_body = intrinsic_impls[intrinsic_defs.GENERIC_PLUS.uri]
  client_ints = computation_types.FederatedType(tf.int32, placements.CLIENTS)

  @computations.federated_computation(client_ints)
  def comp(x):
    return plus_body([x, x])

  expected_signature = '({int32}@CLIENTS -> {int32}@CLIENTS)'
  self.assertEqual(str(comp.type_signature), expected_signature)
  # Each client's value is doubled independently.
  self.assertEqual(comp([1]), [2])
  self.assertEqual(comp([1, 2, 3]), [2, 4, 6])
def test_generic_add_with_unplaced_named_tuples(self):
  """GENERIC_PLUS adds an unplaced named tuple to itself element-wise."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  plus_body = intrinsic_impls[intrinsic_defs.GENERIC_PLUS.uri]
  pair_type = computation_types.NamedTupleType([('a', tf.int32),
                                                ('b', tf.float32)])

  @computations.federated_computation(pair_type)
  def comp(x):
    return plus_body([x, x])

  expected_signature = '(<a=int32,b=float32> -> <a=int32,b=float32>)'
  self.assertEqual(str(comp.type_signature), expected_signature)
  doubled = anonymous_tuple.AnonymousTuple([('a', 2), ('b', 2.)])
  self.assertEqual(comp([1, 1.]), doubled)
def test_federated_weighted_mean_with_ints(self):
  """FEDERATED_WEIGHTED_MEAN body over int32 clients, passing x as both args.

  The expected results show the mean is sum(v * v) / sum(v) when each value is
  also its own weight. For [1, 2, 3] that is 14 / 6, which is not exactly
  representable as a float, so it is compared with a tolerance instead of
  exact float equality.
  """
  bodies = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)

  @computations.federated_computation(
      computation_types.FederatedType(tf.int32, placements.CLIENTS))
  def foo(x):
    # The same federated value is passed in both positions of the pair.
    return bodies[intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri]([x, x])

  self.assertEqual(
      str(foo.type_signature), '({int32}@CLIENTS -> float64@SERVER)')
  # 1.0 is exactly representable, so exact equality is safe here.
  self.assertEqual(foo([1]), 1.)
  # 14 / 6 is inexact; avoid depending on the exact rounding path.
  self.assertAlmostEqual(foo([1, 2, 3]), 14. / 6)
def test_generic_divide_with_unplaced_named_tuples(self):
  """GENERIC_DIVIDE of an unplaced named tuple by itself yields all ones."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  divide_body = intrinsic_impls[intrinsic_defs.GENERIC_DIVIDE.uri]
  pair_type = computation_types.StructType([('a', tf.int32),
                                            ('b', tf.float32)])

  @computations.federated_computation(pair_type)
  def comp(x):
    return divide_body([x, x])

  # int32 division is promoted to float64; float32 stays float32.
  expected_signature = '(<a=int32,b=float32> -> <a=float64,b=float32>)'
  self.assertEqual(str(comp.type_signature), expected_signature)
  ones = structure.Struct([('a', 1.), ('b', 1.)])
  self.assertEqual(comp([1, 1.]), ones)
def test_federated_sum(self):
  """Checks the type and reduced body produced by the FEDERATED_SUM body.

  The body is looked up by `intrinsic_defs.FEDERATED_SUM.uri`, like the other
  intrinsic tests, rather than by a hard-coded URI string.
  """
  bodies = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)

  @computations.federated_computation(
      computation_types.FederatedType(tf.int32, placements.CLIENTS))
  def foo(x):
    return bodies[intrinsic_defs.FEDERATED_SUM.uri](x)

  self.assertEqual(
      str(foo.type_signature), '({int32}@CLIENTS -> int32@SERVER)')
  self.assertEqual(
      _body_str(foo),
      '(FEDERATED_arg -> '
      'federated_reduce(<FEDERATED_arg,generic_zero,generic_plus>))')
def test_generic_add_with_unplaced_named_tuple_and_tensor(self):
  """GENERIC_PLUS broadcasts a scalar tensor across a named tuple's fields."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  plus_body = intrinsic_impls[intrinsic_defs.GENERIC_PLUS.uri]
  named_pair = [('a', tf.float32), ('b', tf.float32)]
  arg_type = computation_types.StructType([named_pair, tf.float32])

  @computations.federated_computation(arg_type)
  def comp(x):
    return plus_body(x)

  expected_signature = (
      '(<<a=float32,b=float32>,float32> -> <a=float32,b=float32>)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  summed = structure.Struct([('a', 2.), ('b', 2.)])
  self.assertEqual(comp([[1., 1.], 1.]), summed)
def test_generic_multiply_with_unplaced_named_tuple_and_tensor(self):
  """GENERIC_MULTIPLY scales each field of a named tuple by a scalar."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  multiply_body = intrinsic_impls[intrinsic_defs.GENERIC_MULTIPLY.uri]
  named_pair = [('a', tf.float32), ('b', tf.float32)]
  arg_type = computation_types.NamedTupleType([named_pair, tf.float32])

  @computations.federated_computation(arg_type)
  def comp(x):
    return multiply_body(x)

  expected_signature = (
      '(<<a=float32,b=float32>,float32> -> <a=float32,b=float32>)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  scaled = anonymous_tuple.AnonymousTuple([('a', 2.), ('b', 2.)])
  self.assertEqual(comp([[1., 1.], 2.]), scaled)
def test_generic_multiply_with_named_tuple_of_federated_types(self):
  """GENERIC_MULTIPLY on a named tuple whose fields are CLIENTS-placed ints."""
  bodies = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  fed_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)

  @computations.federated_computation([('a', fed_int), ('b', fed_int)])
  def foo(x):
    # Multiply the tuple by itself, i.e. square every client value per field.
    return bodies[intrinsic_defs.GENERIC_MULTIPLY.uri]([x, x])

  self.assertEqual(
      str(foo.type_signature),
      '(<a={int32}@CLIENTS,b={int32}@CLIENTS> -> '
      '<a={int32}@CLIENTS,b={int32}@CLIENTS>)')
  self.assertEqual(
      foo([[1], [1]]),
      anonymous_tuple.AnonymousTuple([('a', [1]), ('b', [1])]))
  # 1 * 1 == 1 cannot tell multiplication apart from returning an operand
  # unchanged, so also check values other than 1 (and distinct per field).
  self.assertEqual(
      foo([[2], [3]]),
      anonymous_tuple.AnonymousTuple([('a', [4]), ('b', [9])]))
def test_federated_generic_divide_with_federated_named_tuples(self):
  """GENERIC_DIVIDE of a CLIENTS-placed named tuple by itself per client."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  divide_body = intrinsic_impls[intrinsic_defs.GENERIC_DIVIDE.uri]
  client_pairs = computation_types.FederatedType(
      [('a', tf.int32), ('b', tf.float32)], placement_literals.CLIENTS)

  @computations.federated_computation(client_pairs)
  def comp(x):
    return divide_body([x, x])

  expected_signature = (
      '({<a=int32,b=float32>}@CLIENTS -> {<a=float64,b=float32>}@CLIENTS)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  # Every client yields the same all-ones quotient.
  self.assertEqual(
      comp([[1, 1.]]), [structure.Struct([('a', 1.), ('b', 1.)])])
  self.assertEqual(
      comp([[1, 1.], [1, 2.], [3, 3.]]),
      [structure.Struct([('a', 1.), ('b', 1.)])] * 3)
def test_generic_divide_with_named_tuple_of_federated_types(self):
  """GENERIC_DIVIDE on a named tuple whose fields are CLIENTS-placed ints."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  divide_body = intrinsic_impls[intrinsic_defs.GENERIC_DIVIDE.uri]
  fed_int = computation_types.FederatedType(tf.int32,
                                            placement_literals.CLIENTS)

  @computations.federated_computation([('a', fed_int), ('b', fed_int)])
  def comp(x):
    return divide_body([x, x])

  expected_signature = (
      '(<a={int32}@CLIENTS,b={int32}@CLIENTS> -> '
      '<a={float64}@CLIENTS,b={float64}@CLIENTS>)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  ones = structure.Struct([('a', [1.]), ('b', [1.])])
  self.assertEqual(comp([[1], [1]]), ones)
def test_federated_mean_named_tuple_with_tensor(self):
  """FEDERATED_MEAN averages each field of CLIENTS-placed named tuples."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  mean_body = intrinsic_impls[intrinsic_defs.FEDERATED_MEAN.uri]
  client_pairs = computation_types.FederatedType(
      [('a', tf.float32), ('b', tf.float32)], placement_literals.CLIENTS)

  @computations.federated_computation(client_pairs)
  def comp(x):
    return mean_body(x)

  expected_signature = (
      '({<a=float32,b=float32>}@CLIENTS -> <a=float32,b=float32>@SERVER)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  self.assertEqual(
      comp([[1., 1.]]), structure.Struct([('a', 1.), ('b', 1.)]))
  # Field b averages 1, 2, 3 -> 2.
  self.assertEqual(
      comp([[1., 1.], [1., 2.], [1., 3.]]),
      structure.Struct([('a', 1.), ('b', 2.)]))
def test_federated_sum_named_tuples(self):
  """FEDERATED_SUM sums each field of CLIENTS-placed named tuples."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  sum_body = intrinsic_impls[intrinsic_defs.FEDERATED_SUM.uri]
  client_pairs = computation_types.FederatedType(
      [('a', tf.int32), ('b', tf.float32)], placement_literals.CLIENTS)

  @computations.federated_computation(client_pairs)
  def comp(x):
    return sum_body(x)

  expected_signature = (
      '({<a=int32,b=float32>}@CLIENTS -> <a=int32,b=float32>@SERVER)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  single_client_total = {'a': 1, 'b': 2.}
  two_client_total = {'a': 4, 'b': 6.}
  self.assertDictEqual(
      structure.to_odict(comp([[1, 2.]])), single_client_total)
  self.assertDictEqual(
      structure.to_odict(comp([[1, 2.], [3, 4.]])), two_client_total)
def test_federated_generic_divide_with_unnamed_tuples(self):
  """GENERIC_DIVIDE of CLIENTS-placed unnamed tuples by themselves."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  divide_body = intrinsic_impls[intrinsic_defs.GENERIC_DIVIDE.uri]
  client_pairs = computation_types.FederatedType([tf.int32, tf.float32],
                                                 placements.CLIENTS)

  @computations.federated_computation(client_pairs)
  def comp(x):
    return divide_body([x, x])

  expected_signature = (
      '({<int32,float32>}@CLIENTS -> {<float64,float32>}@CLIENTS)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  self.assertEqual(
      comp([[1, 1.]]),
      [anonymous_tuple.AnonymousTuple([(None, 1.), (None, 1.)])])
  self.assertEqual(
      comp([[1, 1.], [1, 2.], [3, 3.]]),
      [anonymous_tuple.AnonymousTuple([(None, 1.), (None, 1.)])] * 3)
def test_federated_generic_add_with_unnamed_tuples(self):
  """GENERIC_PLUS doubles CLIENTS-placed unnamed tuples per client."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  plus_body = intrinsic_impls[intrinsic_defs.GENERIC_PLUS.uri]
  client_pairs = computation_types.FederatedType([tf.int32, tf.float32],
                                                 placement_literals.CLIENTS)

  @computations.federated_computation(client_pairs)
  def comp(x):
    return plus_body([x, x])

  expected_signature = (
      '({<int32,float32>}@CLIENTS -> {<int32,float32>}@CLIENTS)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  self.assertEqual(
      comp([[1, 1.]]), [structure.Struct([(None, 2), (None, 2.)])])
  expected_doubles = [
      structure.Struct([(None, 2), (None, second)])
      for second in (2., 4., 6.)
  ]
  self.assertEqual(comp([[1, 1.], [1, 2.], [1, 3.]]), expected_doubles)
def test_federated_weighted_mean_named_tuple_with_tensor(self):
  """FEDERATED_WEIGHTED_MEAN of named-tuple values with a scalar weight."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  weighted_mean_body = intrinsic_impls[
      intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri]
  named_pair = [('a', tf.float32), ('b', tf.float32)]
  client_args = computation_types.FederatedType([named_pair, tf.float32],
                                                placements.CLIENTS)

  @computations.federated_computation(client_args)
  def comp(x):
    return weighted_mean_body(x)

  expected_signature = ('({<<a=float32,b=float32>,float32>}@CLIENTS -> '
                        '<a=float32,b=float32>@SERVER)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  self.assertEqual(
      comp([[[1., 1.], 1.]]),
      anonymous_tuple.AnonymousTuple([('a', 1.), ('b', 1.)]))
  # b: (1*1 + 2*2 + 4*4) / (1 + 2 + 4) == 3.
  self.assertEqual(
      comp([[[1., 1.], 1.], [[1., 2.], 2.], [[1., 4.], 4.]]),
      anonymous_tuple.AnonymousTuple([('a', 1.), ('b', 3.)]))
def test_federated_generic_add_with_named_tuples(self):
  """GENERIC_PLUS doubles CLIENTS-placed named tuples per client."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  plus_body = intrinsic_impls[intrinsic_defs.GENERIC_PLUS.uri]
  client_pairs = computation_types.FederatedType(
      [('a', tf.int32), ('b', tf.float32)], placements.CLIENTS)

  @computations.federated_computation(client_pairs)
  def comp(x):
    return plus_body([x, x])

  expected_signature = (
      '({<a=int32,b=float32>}@CLIENTS -> {<a=int32,b=float32>}@CLIENTS)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  self.assertEqual(
      comp([[1, 1.]]),
      [anonymous_tuple.AnonymousTuple([('a', 2), ('b', 2.)])])
  expected_doubles = [
      anonymous_tuple.AnonymousTuple([('a', 2), ('b', b_value)])
      for b_value in (2., 4., 6.)
  ]
  self.assertEqual(comp([[1, 1.], [1, 2.], [1, 3.]]), expected_doubles)
def test_federated_sum(self):
  """FEDERATED_SUM lowers to federated_aggregate and sums client ints."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  sum_body = intrinsic_impls[intrinsic_defs.FEDERATED_SUM.uri]

  @computations.federated_computation(
      computation_types.FederatedType(tf.int32, placements.CLIENTS))
  def foo(x):
    return sum_body(x)

  self.assertEqual(
      str(foo.type_signature), '({int32}@CLIENTS -> int32@SERVER)')
  # The accumulate and merge operators share the same lambda shape, so the
  # pattern for that lambda is built once and used for both positions.
  operator_pattern = (r'\(binary_operator_arg -> comp#[a-z0-9]+'
                      r'\(<binary_operator_arg\[0\],'
                      r'binary_operator_arg\[1\]>\)\)')
  aggregate_args = ','.join([
      'FEDERATED_arg', 'generic_zero', operator_pattern, operator_pattern,
      'comp#[a-z0-9]+'
  ])
  body_string = (r'\(FEDERATED_arg -> federated_aggregate\(<' +
                 aggregate_args + r'>\)\)')
  self.assertRegexMatch(_body_str(foo), [body_string])
  self.assertEqual(foo([1]), 1)
  self.assertEqual(foo([1, 2, 3]), 6)
def replace_intrinsics_with_bodies(comp, context_stack):
  """Inlines every intrinsic that has a known body into `comp`.

  This is an AST-level reduction: it consumes a
  `building_blocks.ComputationBuildingBlock` and produces another one, in
  which every intrinsic with an entry in
  `intrinsic_bodies.get_intrinsic_bodies` is replaced by that body. It is
  meant to be the standard reduction that lowers all currently implemented
  intrinsics.

  Bodies are applied in the iteration order of the dict returned by
  `intrinsic_bodies.get_intrinsic_bodies`. Correctness relies on that dict
  listing more complex intrinsics before less complex ones.

  Args:
    comp: The `building_blocks.ComputationBuildingBlock` whose intrinsics
      should be replaced by their bodies.
    context_stack: The `context_stack_base.ContextStack` used when
      constructing the intrinsic bodies.

  Returns:
    A 2-tuple of the transformed `building_blocks.ComputationBuildingBlock`
    and a flag that is truthy if any intrinsic was actually replaced.

  Raises:
    TypeError: If `comp` or `context_stack` has the wrong type.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  uri_to_body = intrinsic_bodies.get_intrinsic_bodies(context_stack)
  any_replaced = False
  for intrinsic_uri, intrinsic_body in uri_to_body.items():
    comp, replaced_here = replace_intrinsics_with_callable(
        comp, intrinsic_uri, intrinsic_body, context_stack)
    any_replaced = any_replaced or replaced_here
  return comp, any_replaced
def test_generic_add_federated_named_tuple_by_tensor(self):
  """GENERIC_PLUS adds a CLIENTS scalar to each field of a CLIENTS tuple."""
  intrinsic_impls = intrinsic_bodies.get_intrinsic_bodies(
      context_stack_impl.context_stack)
  plus_body = intrinsic_impls[intrinsic_defs.GENERIC_PLUS.uri]
  named_pair = [('a', tf.float32), ('b', tf.float32)]
  client_args = computation_types.FederatedType([named_pair, tf.float32],
                                                placement_literals.CLIENTS)

  @computations.federated_computation(client_args)
  def comp(x):
    # Split the federated pair into its tuple and scalar components.
    return plus_body([x[0], x[1]])

  expected_signature = ('({<<a=float32,b=float32>,float32>}@CLIENTS -> '
                        '{<a=float32,b=float32>}@CLIENTS)')
  self.assertEqual(str(comp.type_signature), expected_signature)
  self.assertEqual(
      comp([[[1., 1.], 1.]]), [structure.Struct([('a', 2.), ('b', 2.)])])
  expected_sums = [
      structure.Struct([('a', a_value), ('b', b_value)])
      for a_value, b_value in ((2., 2.), (3., 4.), (5., 8.))
  ]
  self.assertEqual(
      comp([[[1., 1.], 1.], [[1., 2.], 2.], [[1., 4.], 4.]]), expected_sums)