def test_flatten_fn_with_names(self, n):
  """Checks `flatten_first_index` types for unnamed and named additions.

  Builds an identity lambda over an `n`-element named tuple, then verifies
  that flattening in a new element (first unnamed, then named 'new')
  produces the expected function type signature in each case.
  """
  ref = computation_building_blocks.Reference(
      'test', [(str(k), tf.int32) for k in range(n)])
  identity_fn = computation_building_blocks.Lambda(
      'test', ref.type_signature, ref)

  # Case 1: append an unnamed int32 element.
  addend_unnamed = (None, computation_types.to_type(tf.int32))
  expected_param_unnamed = computation_types.NamedTupleType(
      [ref.type_signature, addend_unnamed])
  expected_result_unnamed = computation_types.to_type(
      [(str(k), tf.int32) for k in range(n)] + [tf.int32])
  expected_fn_type_unnamed = computation_types.FunctionType(
      expected_param_unnamed, expected_result_unnamed)
  flattened_unnamed = value_utils.flatten_first_index(
      value_impl.to_value(identity_fn, None, _context_stack), addend_unnamed,
      _context_stack)
  self.assertEqual(
      str(flattened_unnamed.type_signature), str(expected_fn_type_unnamed))

  # Case 2: append an element named 'new'.
  addend_named = ('new', tf.int32)
  expected_param_named = computation_types.NamedTupleType(
      [ref.type_signature, addend_named])
  expected_result_named = computation_types.to_type(
      [(str(k), tf.int32) for k in range(n)] + [('new', tf.int32)])
  expected_fn_type_named = computation_types.FunctionType(
      expected_param_named, expected_result_named)
  flattened_named = value_utils.flatten_first_index(
      value_impl.to_value(identity_fn, None, _context_stack), addend_named,
      _context_stack)
  self.assertEqual(
      str(flattened_named.type_signature), str(expected_fn_type_named))
def test_flatten_fn_comp_raises_typeerror(self):
  """Ensures `flatten_first_index` rejects a raw building block.

  `flatten_first_index` expects a `Value`; passing the bare
  `computation_building_blocks.Lambda` directly must raise `TypeError`.
  """
  input_reference = computation_building_blocks.Reference(
      'test', [tf.int32] * 5)
  input_fn = computation_building_blocks.Lambda(
      'test', input_reference.type_signature, input_reference)
  type_to_add = computation_types.NamedTupleType([tf.int32])
  # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12);
  # `assertRaisesRegex` is the supported spelling with identical behavior.
  with self.assertRaisesRegex(TypeError, '(Expected).*(Value)'):
    _ = value_utils.flatten_first_index(input_fn, type_to_add, _context_stack)
def test_flatten_fn(self, n):
  """Checks `flatten_first_index` types when appending an unnamed element.

  Builds an identity lambda over an `n`-tuple of int32 and verifies that
  flattening in one extra unnamed int32 yields a function whose result is a
  flat `(n + 1)`-tuple of int32.
  """
  ref = computation_building_blocks.Reference('test', [tf.int32] * n)
  identity_fn = computation_building_blocks.Lambda(
      'test', ref.type_signature, ref)
  addend = (None, computation_types.to_type(tf.int32))
  expected_parameter_type = computation_types.NamedTupleType(
      [ref.type_signature, addend])
  expected_result_type = computation_types.to_type([tf.int32] * (n + 1))
  expected_fn_type = computation_types.FunctionType(expected_parameter_type,
                                                    expected_result_type)
  flattened = value_utils.flatten_first_index(
      value_impl.to_value(identity_fn, None, _context_stack), addend,
      _context_stack)
  self.assertEqual(str(flattened.type_signature), str(expected_fn_type))
def federated_zip(self, value):
  """Implements `federated_zip` as defined in `api/intrinsics.py`.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Extend this to accept *args.

  # TODO(b/113112108): We use the iterate/unwrap approach below because
  # our type system is not powerful enough to express the concept of
  # "an operation that takes tuples of T of arbitrary length", and therefore
  # the intrinsic federated_zip must only take a fixed number of arguments,
  # here fixed at 2. There are other potential approaches to getting around
  # this problem (e.g. having the operator act on sequences and thereby
  # sidestepping the issue) which we may want to explore.
  value = value_impl.to_value(value, None, self._context_stack)
  py_typecheck.check_type(value, value_base.Value)
  py_typecheck.check_type(value.type_signature,
                          computation_types.NamedTupleType)
  elements_to_zip = anonymous_tuple.to_elements(value.type_signature)
  num_elements = len(elements_to_zip)
  py_typecheck.check_type(elements_to_zip[0][1],
                          computation_types.FederatedType)
  # The placement of the first element selects the mapping intrinsic used to
  # apply the zip/flatten lambdas: `federated_map` for CLIENTS,
  # `federated_apply` for SERVER.
  output_placement = elements_to_zip[0][1].placement
  zip_apply_fn = {
      placements.CLIENTS: self.federated_map,
      placements.SERVER: self.federated_apply
  }
  if output_placement not in zip_apply_fn:
    raise TypeError(
        'federated_zip only supports components with CLIENTS or '
        'SERVER placement, [{}] is unsupported'.format(output_placement))
  if num_elements == 0:
    raise ValueError('federated_zip is only supported on nonempty tuples.')
  if num_elements == 1:
    # Single-element case: wrap the member back into a one-tuple, keeping the
    # element's (possibly absent) name; no pairwise zipping is needed.
    input_ref = computation_building_blocks.Reference(
        'value_in', elements_to_zip[0][1].member)
    output_tuple = computation_building_blocks.Tuple([(elements_to_zip[0][0],
                                                       input_ref)])
    lam = computation_building_blocks.Lambda(
        'value_in', input_ref.type_signature, output_tuple)
    return zip_apply_fn[output_placement](lam, value[0])
  # Every element must be a federated value with the same placement as the
  # first, since the result is placed at `output_placement`.
  for _, elem in elements_to_zip:
    py_typecheck.check_type(elem, computation_types.FederatedType)
    if elem.placement is not output_placement:
      raise TypeError(
          'The elements of the named tuple to zip must be placed at {}.'
          .format(output_placement))
  named_comps = [(elements_to_zip[k][0],
                  value_impl.ValueImpl.get_comp(value[k]))
                 for k in range(len(value))]
  # Zip the first two elements, then fold in the remaining elements one at a
  # time. `flatten_func` (initially the identity on the zipped member) is
  # extended each iteration so that the nested pair structure produced by the
  # repeated two-tuple zips is flattened back into a flat tuple at the end.
  tuple_to_zip = anonymous_tuple.AnonymousTuple(
      [named_comps[0], named_comps[1]])
  zipped = value_utils.zip_two_tuple(
      value_impl.to_value(tuple_to_zip, None, self._context_stack),
      self._context_stack)
  inputs = value_impl.to_value(
      computation_building_blocks.Reference('inputs',
                                            zipped.type_signature.member),
      None, self._context_stack)
  flatten_func = value_impl.to_value(
      computation_building_blocks.Lambda(
          'inputs', zipped.type_signature.member,
          value_impl.ValueImpl.get_comp(inputs)), None, self._context_stack)
  for k in range(2, num_elements):
    zipped = value_utils.zip_two_tuple(
        value_impl.to_value(
            computation_building_blocks.Tuple(
                [value_impl.ValueImpl.get_comp(zipped), named_comps[k]]),
            None, self._context_stack), self._context_stack)
    # Record the name and member type of the element just zipped in, and
    # extend `flatten_func` accordingly.
    last_zipped = (named_comps[k][0], named_comps[k][1].type_signature.member)
    flatten_func = value_utils.flatten_first_index(flatten_func, last_zipped,
                                                   self._context_stack)
  return zip_apply_fn[output_placement](flatten_func, zipped)