def test_two_tuple_zip_fails_bad_args(self):
    """Checks the error paths of `value_utils.zip_two_tuple`.

    Three malformed inputs are tried, in order:
      * a pair whose second element is placed at SERVER, expecting a
        `TypeError` matching 'should be placed at CLIENTS';
      * a building-block reference passed directly, without going through
        `value_impl.to_value`, expecting a `TypeError` matching
        '(Expected).*(Value)';
      * a three-element tuple, expecting a `ValueError` matching
        'must be a 2-tuple'.
    """

    def _federated(dtype, placement):
        # Every federated type in this test passes True as the third argument.
        return computation_types.FederatedType(dtype, placement, True)

    def _tuple_reference(name, element_types):
        return computation_building_blocks.Reference(
            name, computation_types.NamedTupleType(element_types))

    # Case 1: mixed CLIENTS / SERVER placement.
    mixed_placement_ref = _tuple_reference('test', [
        _federated(tf.int32, placements.CLIENTS),
        _federated(tf.bool, placements.SERVER),
    ])
    with self.assertRaisesRegexp(TypeError, 'should be placed at CLIENTS'):
        wrapped = value_impl.to_value(
            mixed_placement_ref, None, _context_stack)
        value_utils.zip_two_tuple(wrapped, _context_stack)

    # Case 2: the argument is a raw reference rather than a wrapped value.
    unwrapped_ref = _tuple_reference('test', [
        _federated(tf.int32, placements.CLIENTS),
        _federated(tf.bool, placements.CLIENTS),
    ])
    with self.assertRaisesRegexp(TypeError, '(Expected).*(Value)'):
        value_utils.zip_two_tuple(unwrapped_ref, _context_stack)

    # Case 3: three elements instead of two.
    triple_ref = _tuple_reference('three_tuple_test', [
        _federated(tf.int32, placements.CLIENTS),
        _federated(tf.int32, placements.CLIENTS),
        _federated(tf.int32, placements.CLIENTS),
    ])
    with self.assertRaisesRegexp(ValueError, 'must be a 2-tuple'):
        wrapped = value_impl.to_value(triple_ref, None, _context_stack)
        value_utils.zip_two_tuple(wrapped, _context_stack)
def test_two_tuple_zip_with_client_all_equal_int_and_bool(self):
    """Zips a CLIENTS-placed (int32, bool) pair and checks the result type.

    The zipped value's type signature is expected to render as
    '{<int32,bool>}@CLIENTS'.
    """
    element_types = [
        computation_types.FederatedType(tf.int32, placements.CLIENTS, True),
        computation_types.FederatedType(tf.bool, placements.CLIENTS, True),
    ]
    pair_type = computation_types.NamedTupleType(element_types)
    pair_ref = computation_building_blocks.Reference('test', pair_type)
    pair_value = value_impl.to_value(pair_ref, None, _context_stack)

    zipped = value_utils.zip_two_tuple(pair_value, _context_stack)

    expected_signature = '{<int32,bool>}@CLIENTS'
    self.assertEqual(str(zipped.type_signature), expected_signature)
def federated_zip(self, value):
    """Implements `federated_zip` as defined in `api/intrinsics.py`.

    The argument is converted to a value with a named tuple type whose
    elements are all federated types sharing one placement. A one-element
    tuple is handled by mapping/applying a lambda that wraps the single
    member in a tuple. Longer tuples are zipped pairwise with
    `value_utils.zip_two_tuple`, and a function built up alongside with
    `value_utils.flatten_first_index` is finally mapped/applied over the
    zipped result.

    Args:
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
      ValueError: If the tuple to zip has no elements.
    """
    # TODO(b/113112108): Extend this to accept *args.

    # TODO(b/113112108): We use the iterate/unwrap approach below because
    # our type system is not powerful enough to express the concept of
    # "an operation that takes tuples of T of arbitrary length", and therefore
    # the intrinsic federated_zip must only take a fixed number of arguments,
    # here fixed at 2. There are other potential approaches to getting around
    # this problem (e.g. having the operator act on sequences and thereby
    # sidestepping the issue) which we may want to explore.
    value = value_impl.to_value(value, None, self._context_stack)
    py_typecheck.check_type(value, value_base.Value)
    py_typecheck.check_type(value.type_signature,
                            computation_types.NamedTupleType)
    elements_to_zip = anonymous_tuple.to_elements(value.type_signature)
    num_elements = len(elements_to_zip)

    # Reject empty tuples before the first element is indexed below;
    # otherwise the indexing fails first and this error is never raised.
    if num_elements == 0:
        raise ValueError('federated_zip is only supported on nonempty tuples.')

    # The first element determines the placement every other element must
    # share, and which intrinsic is used to apply the final function.
    py_typecheck.check_type(elements_to_zip[0][1],
                            computation_types.FederatedType)
    output_placement = elements_to_zip[0][1].placement
    zip_apply_fn = {
        placements.CLIENTS: self.federated_map,
        placements.SERVER: self.federated_apply
    }
    if output_placement not in zip_apply_fn:
        raise TypeError(
            'federated_zip only supports components with CLIENTS or '
            'SERVER placement, [{}] is unsupported'.format(output_placement))

    if num_elements == 1:
        # Nothing to zip: wrap the single member in a one-element tuple by
        # mapping/applying `value_in -> <name=value_in>` over it.
        input_ref = computation_building_blocks.Reference(
            'value_in', elements_to_zip[0][1].member)
        output_tuple = computation_building_blocks.Tuple(
            [(elements_to_zip[0][0], input_ref)])
        lam = computation_building_blocks.Lambda(
            'value_in', input_ref.type_signature, output_tuple)
        return zip_apply_fn[output_placement](lam, value[0])

    for _, elem in elements_to_zip:
        py_typecheck.check_type(elem, computation_types.FederatedType)
        if elem.placement is not output_placement:
            raise TypeError(
                'The elements of the named tuple to zip must be placed at {}.'
                .format(output_placement))

    # Pair each element name with the computation backing that element.
    named_comps = [(elements_to_zip[k][0],
                    value_impl.ValueImpl.get_comp(value[k]))
                   for k in range(len(value))]

    # Seed the iteration by zipping the first two elements.
    tuple_to_zip = anonymous_tuple.AnonymousTuple(
        [named_comps[0], named_comps[1]])
    zipped = value_utils.zip_two_tuple(
        value_impl.to_value(tuple_to_zip, None, self._context_stack),
        self._context_stack)

    # Start from the identity function over the member type of the first
    # zipped pair; it is extended once per additional element below.
    inputs = value_impl.to_value(
        computation_building_blocks.Reference(
            'inputs', zipped.type_signature.member),
        None, self._context_stack)
    flatten_func = value_impl.to_value(
        computation_building_blocks.Lambda(
            'inputs', zipped.type_signature.member,
            value_impl.ValueImpl.get_comp(inputs)),
        None, self._context_stack)

    # Fold in the remaining elements one at a time: zip the running result
    # with the next component, and update `flatten_func` with that
    # component's name and member type (presumably so the nested pairs can
    # be flattened back into a single tuple; see
    # `value_utils.flatten_first_index`).
    for k in range(2, num_elements):
        zipped = value_utils.zip_two_tuple(
            value_impl.to_value(
                computation_building_blocks.Tuple(
                    [value_impl.ValueImpl.get_comp(zipped), named_comps[k]]),
                None, self._context_stack),
            self._context_stack)
        last_zipped = (named_comps[k][0],
                       named_comps[k][1].type_signature.member)
        flatten_func = value_utils.flatten_first_index(
            flatten_func, last_zipped, self._context_stack)

    return zip_apply_fn[output_placement](flatten_func, zipped)