def extract_work(before_aggregate, after_aggregate, canonical_form_types):
  """Assembles the `work` computation from both halves of the aggregate split.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `tff_framework.FEDERATED_AGGREGATE`.
    after_aggregate: The second result of splitting `after_broadcast` on
      `tff_framework.FEDERATED_AGGREGATE`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
      type signatures specified by the `tff.utils.IterativeProcess` we are
      compiling. Only the `'work_type'` entry is read here.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, an instance of
    `tff_framework.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the extracted object is
      not a `tff_framework.CompiledComputation`, or if its type signature
      differs from `canonical_form_types['work_type']`.
  """
  # The c3/c4/c5/c6 labels follow the naming described in
  # `get_iterative_process_for_canonical_form()`.

  # Re-express `before_aggregate` as a function of c3, then keep only the c5
  # output (element 0 of its result).
  c3_paths_in_before_aggregate = [[0, 1], [1]]
  before_aggregate_from_c3 = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          before_aggregate, c3_paths_in_before_aggregate).result.function)
  c5_from_c3 = transformations.select_output_from_lambda(
      before_aggregate_from_c3, 0)

  # Keep only the c6 output (element 2 of the `after_aggregate` result), then
  # re-express it as a function of c3.
  c6_from_after_aggregate = transformations.select_output_from_lambda(
      after_aggregate, 2)
  c3_paths_in_after_aggregate = [[0, 0, 1], [0, 1]]
  c6_from_c3 = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          c6_from_after_aggregate,
          c3_paths_in_after_aggregate).result.function)

  # Concatenate both outputs and federated-zip them to obtain c4.
  unzipped_c4_from_c3 = transformations.concatenate_function_outputs(
      c5_from_c3, c6_from_c3)
  c4_from_c3 = tff_framework.Lambda(
      unzipped_c4_from_c3.parameter_name, unzipped_c4_from_c3.parameter_type,
      tff_framework.create_federated_zip(unzipped_c4_from_c3.result))

  work = transformations.consolidate_and_extract_local_processing(c4_from_c3)

  if not isinstance(work, tff_framework.CompiledComputation):
    not_compiled_message = (
        'Failed to extract a `tff_framework.CompiledComputation` from '
        'work, instead received a {} (of type {}).'.format(
            type(work), work.type_signature))
    raise transformations.CanonicalFormCompilationError(not_compiled_message)

  if work.type_signature != canonical_form_types['work_type']:
    wrong_type_message = (
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            canonical_form_types['work_type'], work.type_signature))
    raise transformations.CanonicalFormCompilationError(wrong_type_message)

  return work
def extract_aggregate_functions(before_aggregate, canonical_form_types):
  """Pulls `zero`, `accumulate`, `merge` and `report` out of `before_aggregate`.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `tff_framework.FEDERATED_AGGREGATE`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
      type signatures specified by the `tff.utils.IterativeProcess` we are
      compiling. The `'zero_type'`, `'accumulate_type'`, `'merge_type'` and
      `'report_type'` entries are read here.

  Returns:
    A tuple `(zero, accumulate, merge, report)` as specified by
    `canonical_form.CanonicalForm`. All are instances of
    `tff_framework.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If any extracted object is
      not a `tff_framework.CompiledComputation`, or if its type signature
      differs from the matching entry of `canonical_form_types`.
  """
  # The aggregation functions sit at consecutive positions 1..4 of the
  # `before_aggregate` result, in this order.
  function_names = ('zero', 'accumulate', 'merge', 'report')
  first_index_in_before_aggregate_result = 1

  # Select all four outputs first, then consolidate each one, so the calls
  # into `transformations` happen in the same order as four explicit stanzas.
  selected_tff_blocks = [
      transformations.select_output_from_lambda(
          before_aggregate,
          first_index_in_before_aggregate_result + offset).result
      for offset in range(len(function_names))
  ]
  extracted_blocks = [
      transformations.consolidate_and_extract_local_processing(tff_block)
      for tff_block in selected_tff_blocks
  ]

  for name, tf_block in zip(function_names, extracted_blocks):
    if not isinstance(tf_block, tff_framework.CompiledComputation):
      raise transformations.CanonicalFormCompilationError(
          'Failed to extract a `tff_framework.CompiledComputation` from '
          '{}, instead received a {} (of type {}).'.format(
              name, type(tf_block), tf_block.type_signature))
    type_key = '{}_type'.format(name)
    if tf_block.type_signature != canonical_form_types[type_key]:
      raise transformations.CanonicalFormCompilationError(
          'Extracted a TF block of the wrong type. Expected a function with type '
          '{}, but the type signature of the TF block was {}'.format(
              canonical_form_types[type_key], tf_block.type_signature))

  return tuple(extracted_blocks)
def extract_prepare(before_broadcast, canonical_form_types):
  """Derives `prepare` from `before_broadcast`.

  Args:
    before_broadcast: The first result of splitting `next_comp` on
      `intrinsic_defs.FEDERATED_BROADCAST`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
      type signatures specified by the `tff.utils.IterativeProcess` we are
      compiling. Only the `'prepare_type'` entry is read here.

  Returns:
    `prepare` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the type signature of
      the extracted computation differs from
      `canonical_form_types['prepare_type']`.
  """
  # The s1/s2 labels follow the naming described in
  # `get_iterative_process_for_canonical_form()`. s1 is element 0 of the
  # `before_broadcast` parameter.
  s1_position = 0
  rebound_before_broadcast = (
      transformations.bind_single_selection_as_argument_to_lower_level_lambda(
          before_broadcast, s1_position))
  s2_from_s1 = rebound_before_broadcast.result.function

  prepare = transformations.consolidate_and_extract_local_processing(
      s2_from_s1)

  if prepare.type_signature != canonical_form_types['prepare_type']:
    wrong_type_message = (
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            canonical_form_types['prepare_type'], prepare.type_signature))
    raise transformations.CanonicalFormCompilationError(wrong_type_message)

  return prepare
def extract_update(after_aggregate, canonical_form_types):
  """Derives `update` from `after_aggregate`.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
      type signatures specified by the `tff.utils.IterativeProcess` we are
      compiling. Only the `'update_type'` entry is read here.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the extracted object is
      not a `building_blocks.CompiledComputation`, or if its type signature
      differs from `canonical_form_types['update_type']`.
  """
  # The s4/s5 labels follow the naming described in
  # `get_iterative_process_for_canonical_form()`.

  # s5 is made of elements 0 and 1 of the `after_aggregate` result; select
  # them and federated-zip the selection.
  s5_positions_in_result = [0, 1]
  s5_selected = transformations.select_output_from_lambda(
      after_aggregate, s5_positions_in_result)
  s5_zipped = building_blocks.Lambda(
      s5_selected.parameter_name, s5_selected.parameter_type,
      building_block_factory.create_federated_zip(s5_selected.result))

  # Re-express the zipped computation as a function of s4, which is gathered
  # from these selection paths of the `after_aggregate` parameter.
  s4_paths_in_parameter = [[0, 0, 0], [1]]
  s5_from_s4 = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          s5_zipped, s4_paths_in_parameter).result.function)

  update = transformations.consolidate_and_extract_local_processing(s5_from_s4)

  if not isinstance(update, building_blocks.CompiledComputation):
    not_compiled_message = (
        'Failed to extract a `building_blocks.CompiledComputation` from '
        'update, instead received a {} (of type {}).'.format(
            type(update), update.type_signature))
    raise transformations.CanonicalFormCompilationError(not_compiled_message)

  if update.type_signature != canonical_form_types['update_type']:
    wrong_type_message = (
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            canonical_form_types['update_type'], update.type_signature))
    raise transformations.CanonicalFormCompilationError(wrong_type_message)

  return update
def get_canonical_form_for_iterative_process(iterative_process):
  """Compiles an iterative process into `tff.backends.mapreduce.CanonicalForm`.

  The `initialize` and `next` computations of `iterative_process` are turned
  into building blocks, split around the federated broadcast and federated
  aggregate intrinsics, and each resulting piece is reduced to one of the
  computations that make up a `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types, including a `next`
      function whose parameter or result is not a `tff.NamedTupleType`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_proto = iterative_process.initialize._computation_proto  # pylint: disable=protected-access
  initialize_comp = tff_framework.ComputationBuildingBlock.from_proto(
      initialize_proto)
  next_proto = iterative_process.next._computation_proto  # pylint: disable=protected-access
  next_comp = tff_framework.ComputationBuildingBlock.from_proto(next_proto)

  if (not isinstance(next_comp.type_signature.parameter, tff.NamedTupleType) or
      not isinstance(next_comp.type_signature.result, tff.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  if len(next_comp.type_signature.result) == 2:
    # A two-element result gets a third element appended: an empty federated
    # value placed at the clients, so the later stages see three outputs.
    original_result = next_comp.result
    empty_clients_value = tff.federated_value([], tff.CLIENTS)._comp  # pylint: disable=protected-access
    padded_result = tff_framework.Tuple(
        [original_result[0], original_result[1], empty_clients_value])
    next_comp = tff_framework.Lambda(next_comp.parameter_name,
                                     next_comp.parameter_type, padded_result)

  initialize_comp = tff_framework.replace_intrinsics_with_bodies(
      initialize_comp)
  next_comp = tff_framework.replace_intrinsics_with_bodies(next_comp)

  tff_framework.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tff_framework.check_intrinsics_whitelisted_for_reduction(next_comp)
  tff_framework.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around the broadcast, then split the remainder around the
  # aggregate.
  broadcast_uri = tff_framework.FEDERATED_BROADCAST.uri
  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsic(
          next_comp, broadcast_uri))
  aggregate_uri = tff_framework.FEDERATED_AGGREGATE.uri
  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsic(
          after_broadcast, aggregate_uri))

  # Each packing step receives the result of the previous one; the last call
  # yields the `dict` of expected type signatures used below.
  packed_types = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)
  packed_types = pack_next_comp_type_signature(next_comp.type_signature,
                                               packed_types)
  packed_types = check_and_pack_before_broadcast_type_signature(
      before_broadcast.type_signature, packed_types)
  packed_types = check_and_pack_before_aggregate_type_signature(
      before_aggregate.type_signature, packed_types)
  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, packed_types)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  if not (isinstance(initialize, tff_framework.CompiledComputation) and
          initialize.type_signature.result
          == canonical_form_types['initialize_type'].member):
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`tff_framework.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(next_comp.type_signature.parameter[0],
                                  type(initialize),
                                  initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)
  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)
  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)
  update = extract_update(after_aggregate, canonical_form_types)

  # Converted in the positional order expected by `CanonicalForm`.
  ordered_building_blocks = (initialize, prepare, work, zero_noarg_function,
                             accumulate, merge, report, update)
  computations = [
      tff_framework.building_block_to_computation(block)
      for block in ordered_building_blocks
  ]
  return canonical_form.CanonicalForm(*computations)
def _check_type_equal(actual, expected):
  """Raises a compilation error unless two TFF types compare equal.

  Args:
    actual: The `computation_types.Type` that was found.
    expected: The `computation_types.Type` that was required.

  Raises:
    transformations.CanonicalFormCompilationError: If `actual != expected`.
  """
  # Both operands are type-checked, `actual` first, before being compared.
  for candidate in (actual, expected):
    py_typecheck.check_type(candidate, computation_types.Type)
  if actual != expected:
    message = 'Expected type of {}, found {}.'.format(expected, actual)
    raise transformations.CanonicalFormCompilationError(message)
def _check_placement_equal(actual, expected):
  """Raises a compilation error unless two placements compare equal.

  Args:
    actual: The `placement_literals.PlacementLiteral` that was found.
    expected: The `placement_literals.PlacementLiteral` that was required.

  Raises:
    transformations.CanonicalFormCompilationError: If `actual != expected`.
  """
  # Both operands are type-checked, `actual` first, before being compared.
  for candidate in (actual, expected):
    py_typecheck.check_type(candidate, placement_literals.PlacementLiteral)
  if actual != expected:
    message = 'Expected placement of {}, found {}.'.format(expected, actual)
    raise transformations.CanonicalFormCompilationError(message)
def _check_len_equal(target, length):
  """Raises a compilation error unless `target` has exactly `length` items.

  Args:
    target: Any object supporting `len()`.
    length: The required length, an `int`.

  Raises:
    transformations.CanonicalFormCompilationError: If `len(target)` differs
      from `length`.
  """
  py_typecheck.check_type(length, int)
  found_length = len(target)
  if found_length == length:
    return
  raise transformations.CanonicalFormCompilationError(
      'Expected length of {}, found {}.'.format(length, found_length))
def _check_type(target, type_spec):
  """Raises a compilation error unless `target` is an instance of `type_spec`.

  Args:
    target: The object to inspect.
    type_spec: The required Python class; it must itself be a `type`.

  Raises:
    transformations.CanonicalFormCompilationError: If `target` is not an
      instance of `type_spec`.
  """
  py_typecheck.check_type(type_spec, type)
  if isinstance(target, type_spec):
    return
  raise transformations.CanonicalFormCompilationError(
      'Expected type of {}, found {}.'.format(type_spec, type(target)))
def _check_type_equal(actual, expected, label): if actual != expected: raise transformations.CanonicalFormCompilationError( 'Expected \'{}\' to have a type signature of {}, found {}.'.format( label, expected, actual))