def test_raises_on_non_federated_selection(self):
    # The lambda's parameter is the plain tuple type `[tf.int32]` (no
    # federated type is constructed here), and the selection `[[0]]` picks
    # its only element; the transformation is expected to reject this.
    body = building_blocks.Reference('x', [tf.int32])
    lam = building_blocks.Lambda('x', [tf.int32], body)
    selections = [[0]]
    with self.assertRaises(TypeError):
        mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
            lam, selections)
def test_raises_on_non_tuple_parameter(self):
    # Here the lambda's parameter type is the bare `tf.int32` rather than a
    # tuple, so there is nothing for the selection `[[0]]` to index into.
    body = building_blocks.Reference('x', tf.int32)
    lam = building_blocks.Lambda('x', tf.int32, body)
    selections = [[0]]
    with self.assertRaises(TypeError):
        mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
            lam, selections)
def test_raises_on_selection_from_non_tuple(self):
    # The selection path `[0, 0]` asks for element 0 of element 0, but
    # element 0 of the parameter is `tf.int32`, not a tuple. The raised
    # TypeError must mention 'nonexistent index'.
    body = building_blocks.Reference('x', [tf.int32])
    lam = building_blocks.Lambda('x', [tf.int32], body)
    selections = [[0, 0]]
    with self.assertRaisesRegex(TypeError, 'nonexistent index'):
        mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
            lam, selections)
def test_raises_on_non_lambda_comp(self):
    # A `Reference` is passed where the other tests in this group pass a
    # `Lambda`; the transformation is expected to raise TypeError.
    not_a_lambda = building_blocks.Reference('x', [tf.int32])
    selections = [0]
    with self.assertRaises(TypeError):
        mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
            not_a_lambda, selections)
def test_raises_on_none(self):
    # `None` in place of the computation argument must raise TypeError.
    selections = [0]
    with self.assertRaises(TypeError):
        mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
            None, selections)
def _extract_update(after_aggregate):
    """Builds the `update` computation out of `after_aggregate`.

    Only `get_canonical_form_for_iterative_process` is meant to call this
    helper. It performs no structural validation of `after_aggregate`; the
    caller must have verified the expected structure beforehand.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.

    Returns:
      `update` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we extract an AST of
        the wrong type.
    """
    # Keep only outputs 0 and 1 of `after_aggregate` (the s7 elements) and
    # federated-zip them into a single result.
    s7_indices = [0, 1]
    extracted = transformations.select_output_from_lambda(
        after_aggregate, s7_indices)
    zipped_result = building_block_factory.create_federated_zip(
        extracted.result)
    zipped_lambda = building_blocks.Lambda(
        extracted.parameter_name, extracted.parameter_type, zipped_result)

    # Selection paths into the parameter that make up the s6 elements.
    s6_paths = [[0, 0, 0], [1, 0], [1, 1]]
    lowered = transformations.zip_selection_as_argument_to_lower_level_lambda(
        zipped_lambda, s6_paths)
    s6_to_s7_computation = lowered.result.function

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support
    # selecting from nested structures, therefore we need to pack the type
    # signature `<s1, s3, s4>` as `<s1, <s3, s4>>`.
    fresh_names = building_block_factory.unique_name_generator(
        s6_to_s7_computation)

    # Inner lambda: takes the nested `<s1, <s3, s4>>` value and flattens it
    # to the `<s1, s3, s4>` tuple.
    nested_name = next(fresh_names)
    flat_members = s6_to_s7_computation.parameter_type.member
    inner_pair_type = computation_types.NamedTupleType(
        [flat_members[1], flat_members[2]])
    nested_type = computation_types.NamedTupleType(
        [flat_members[0], inner_pair_type])
    nested_ref = building_blocks.Reference(nested_name, nested_type)
    inner_pair = building_blocks.Selection(nested_ref, index=1)
    flattened = building_blocks.Tuple([
        building_blocks.Selection(nested_ref, index=0),
        building_blocks.Selection(inner_pair, index=0),
        building_blocks.Selection(inner_pair, index=1),
    ])
    flatten_fn = building_blocks.Lambda(
        nested_ref.name, nested_ref.type_signature, flattened)

    # Outer lambda: accepts the nested value placed at the server, flattens
    # it through a federated map/apply, and forwards it to the s6 -> s7
    # computation.
    outer_name = next(fresh_names)
    outer_type = computation_types.FederatedType(
        nested_type, placements.SERVER)
    outer_ref = building_blocks.Reference(outer_name, outer_type)
    flat_args = building_block_factory.create_federated_map_or_apply(
        flatten_fn, outer_ref)
    invocation = building_blocks.Call(s6_to_s7_computation, flat_args)
    wrapper = building_blocks.Lambda(
        outer_ref.name, outer_ref.type_signature, invocation)
    return transformations.consolidate_and_extract_local_processing(wrapper)
def test_raises_on_non_int_index(self):
    # The selection path contains the string 'a' instead of an integer
    # index; the transformation is expected to raise TypeError.
    body = building_blocks.Reference('x', [('a', tf.int32)])
    lam = building_blocks.Lambda('x', [tf.int32], body)
    selections = [['a']]
    with self.assertRaises(TypeError):
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            lam, selections)
def test_raises_on_selection_tuple(self):
    # Passing the selections as a tuple instead of a list should be rejected
    # with TypeError.
    #
    # The original argument was `(0)`, which Python evaluates to the plain
    # int `0` (parentheses alone do not create a tuple), so the tuple case
    # named by this test was never exercised. The trailing comma makes it a
    # real one-element tuple.
    body = building_blocks.Reference('x', [tf.int32])
    lam = building_blocks.Lambda('x', tf.int32, body)
    selections = (0,)
    with self.assertRaises(TypeError):
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            lam, selections)