def test_returns_comps_with_federated_aggregate_no_unbound_references(
    self):
  """Splits a lambda whose aggregate calls do not use its parameter.

  The lambda takes a `tf.int32` parameter named 'd' and returns a 2-tuple
  holding the same dummy called `federated_aggregate` twice. After splitting
  on `federated_aggregate`, both halves must be lambdas containing no called
  `federated_aggregate`; `before` must keep the original parameter type and
  `after` must keep the original result type.
  """
  aggregate_call = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')
  # The same building block object appears in both tuple slots.
  pair_of_calls = building_blocks.Tuple([aggregate_call] * 2)
  original = building_blocks.Lambda('d', tf.int32, pair_of_calls)
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri

  split = mapreduce_transformations.force_align_and_split_by_intrinsic(
      original, aggregate_uri)
  before, after = split

  def _aggregate_calls(node):
    # Number of called `federated_aggregate` intrinsics inside `node`.
    return tree_analysis.count(
        node,
        lambda n: building_block_analysis.is_called_intrinsic(
            n, aggregate_uri))

  self.assertIsInstance(original, building_blocks.Lambda)
  self.assertGreater(_aggregate_calls(original), 0)

  self.assertIsInstance(before, building_blocks.Lambda)
  self.assertEqual(_aggregate_calls(before), 0)
  self.assertEqual(before.parameter_type, original.parameter_type)

  self.assertIsInstance(after, building_blocks.Lambda)
  self.assertEqual(_aggregate_calls(after), 0)
  self.assertEqual(after.result.type_signature,
                   original.result.type_signature)
def test_returns_comps_with_federated_aggregate(self):
  """Splits the example training `next` computation on `federated_aggregate`.

  The `next` computation of the example training process is converted to a
  building block and split on called `federated_aggregate` intrinsics. Both
  halves must be lambdas containing no such calls; `before` must keep the
  original parameter type and `after` must keep the original result type.
  """
  training_process = test_utils.construct_example_training_comp()
  next_fn = test_utils.computation_to_building_block(training_process.next)
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri

  split = mapreduce_transformations.force_align_and_split_by_intrinsic(
      next_fn, aggregate_uri)
  before, after = split

  def _aggregate_calls(node):
    # Number of called `federated_aggregate` intrinsics inside `node`.
    return tree_analysis.count(
        node,
        lambda n: building_block_analysis.is_called_intrinsic(
            n, aggregate_uri))

  self.assertIsInstance(next_fn, building_blocks.Lambda)
  self.assertGreater(_aggregate_calls(next_fn), 0)

  self.assertIsInstance(before, building_blocks.Lambda)
  self.assertEqual(_aggregate_calls(before), 0)
  self.assertEqual(before.parameter_type, next_fn.parameter_type)

  self.assertIsInstance(after, building_blocks.Lambda)
  self.assertEqual(_aggregate_calls(after), 0)
  self.assertEqual(after.result.type_signature,
                   next_fn.result.type_signature)
def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process` into
  an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types, or if the `next`
      function of `iterative_process` does not both take and return instances
      of `tff.NamedTupleType`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_comp = tff_framework.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

  next_comp = tff_framework.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter, tff.NamedTupleType) and
          isinstance(next_comp.type_signature.result, tff.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    # Rebuild `next` so that it returns three elements: the two original
    # elements plus an empty value placed at CLIENTS. NOTE(review): the name
    # below suggests this stands in for missing client metrics; confirm.
    next_result = next_comp.result
    if isinstance(next_result, tff_framework.Tuple):
      # A literal tuple building block can be indexed directly.
      first_element = next_result[0]
      second_element = next_result[1]
    else:
      # Any other node producing a tuple (e.g. a Call or Block) is not
      # guaranteed to support Python indexing; select elements explicitly.
      first_element = tff_framework.Selection(next_result, index=0)
      second_element = tff_framework.Selection(next_result, index=1)
    dummy_clients_metrics_appended = tff_framework.Tuple([
        first_element,
        second_element,
        tff.federated_value([], tff.CLIENTS)._comp  # pylint: disable=protected-access
    ])
    next_comp = tff_framework.Lambda(next_comp.parameter_name,
                                     next_comp.parameter_type,
                                     dummy_clients_metrics_appended)

  initialize_comp = tff_framework.replace_intrinsics_with_bodies(
      initialize_comp)
  next_comp = tff_framework.replace_intrinsics_with_bodies(next_comp)

  tff_framework.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tff_framework.check_intrinsics_whitelisted_for_reduction(next_comp)
  tff_framework.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` at the broadcast, then split the post-broadcast part at the
  # aggregate, yielding the pieces the extract_* helpers operate on.
  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsic(
          next_comp, tff_framework.FEDERATED_BROADCAST.uri))

  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsic(
          after_broadcast, tff_framework.FEDERATED_AGGREGATE.uri))

  # Each packing/checking step consumes the result of the previous one,
  # accumulating the type information for the canonical form.
  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)

  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)

  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))

  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))

  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)

  expected_initialize_type = canonical_form_types['initialize_type'].member
  if not (isinstance(initialize, tff_framework.CompiledComputation) and
          initialize.type_signature.result == expected_initialize_type):
    # Report the type that was actually compared against above.
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`tff_framework.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(expected_initialize_type, type(initialize),
                                  initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)

  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)

  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)

  update = extract_update(after_aggregate, canonical_form_types)

  return canonical_form.CanonicalForm(
      tff_framework.building_block_to_computation(initialize),
      tff_framework.building_block_to_computation(prepare),
      tff_framework.building_block_to_computation(work),
      tff_framework.building_block_to_computation(zero_noarg_function),
      tff_framework.building_block_to_computation(accumulate),
      tff_framework.building_block_to_computation(merge),
      tff_framework.building_block_to_computation(report),
      tff_framework.building_block_to_computation(update))