def _prepare_for_rebinding(bb):
  """Returns a semantically equivalent version of `bb` suitable for rebinding.

  Runs, in order: `compiler.normalize_all_equal_bit`,
  `tree_transformations.remove_mapped_or_applied_identity`,
  `transformations.to_call_dominant`, and
  `tree_transformations.remove_unused_block_locals`.
  """
  normalized = compiler.normalize_all_equal_bit(bb)
  without_identities, _ = (
      tree_transformations.remove_mapped_or_applied_identity(normalized))
  call_dominant = transformations.to_call_dominant(without_identities)
  pruned, _ = tree_transformations.remove_unused_block_locals(call_dominant)
  return pruned
def test_inlines_selection_from_struct(self):
  """A selection from a freshly built struct is inlined to the element."""
  int_type = computation_types.to_type(tf.int32)
  bb = building_blocks

  def x_ref():
    return bb.Reference('x', int_type)

  selection = bb.Selection(bb.Struct([x_ref()]), index=0)
  after = transformations.to_call_dominant(
      bb.Lambda('x', int_type, selection))
  self.assert_compact_representations_equal(
      after, bb.Lambda('x', int_type, x_ref()))
def test_evaluates_called_lambdas(self):
  """Called lambdas are evaluated; each outer call yields one binding."""
  int_type = computation_types.to_type(tf.int32)
  int_to_int_type = computation_types.FunctionType(int_type, int_type)
  int_thunk_type = computation_types.FunctionType(None, int_type)
  bb = building_blocks
  ext_fn = bb.Data('ext', int_to_int_type)

  def ref(name, type_spec=int_type):
    return bb.Reference(name, type_spec)

  def ext_of_x():
    return bb.Call(ext_fn, ref('x'))

  # -> (let result = ext(x) in (-> result))
  # Every invocation of this no-arg lambda should introduce exactly one
  # binding; repeated calls of the returned thunk all reference that binding.
  higher_fn = bb.Lambda(
      None, None,
      bb.Block([('result', ext_of_x())],
               bb.Lambda(None, None, ref('result'))))
  fn_type = higher_fn.type_signature

  bindings = [
      # fn = -> (let result = ext(x) in (-> result))
      ('fn', higher_fn),
      # _var2 = ext(x); get_val1 = -> _var2
      ('get_val1', bb.Call(ref('fn', fn_type))),
      # _var3 = ext(x); get_val2 = -> _var3
      ('get_val2', bb.Call(ref('fn', fn_type))),
      # val11 = _var2
      ('val11', bb.Call(ref('get_val1', int_thunk_type))),
      # val12 = _var2
      ('val12', bb.Call(ref('get_val1', int_thunk_type))),
      # val2 = _var3
      ('val2', bb.Call(ref('get_val2', int_thunk_type))),
  ]
  # Result: <_var2, _var2, _var3>
  before = bb.Lambda(
      'x', int_type,
      bb.Block(bindings,
               bb.Struct([ref('val11'), ref('val12'), ref('val2')])))

  after = transformations.to_call_dominant(before)

  expected = bb.Lambda(
      'x', int_type,
      bb.Block([('_var2', ext_of_x()), ('_var3', ext_of_x())],
               bb.Struct([ref('_var2'), ref('_var2'), ref('_var3')])))
  self.assert_compact_representations_equal(after, expected)
def compile_local_computation_to_tensorflow(
    comp: building_blocks.ComputationBuildingBlock,
) -> building_blocks.ComputationBuildingBlock:
  """Compiles a fully specified local computation to TensorFlow.

  Args:
    comp: A `building_blocks.ComputationBuildingBlock` which can be compiled
      to TensorFlow. In order to compile a computation to TensorFlow, it must
      not contain
        1. References to values defined outside of comp,
        2. `Data`, `Intrinsic`, or `Placement` blocks, or
        3. Calls to intrinsics or non-TensorFlow computations.

  Returns:
    A `building_blocks.ComputationBuildingBlock` containing a TensorFlow-only
    representation of `comp`. If `comp` is of functional type, this will be
    a `building_blocks.CompiledComputation`. Otherwise, it will be a
    `building_blocks.Call` which wraps a `building_blocks.CompiledComputation`.
  """
  # Non-functional inputs are wrapped in a no-arg lambda, compiled via the
  # functional path below, and then called with no argument.
  if not comp.type_signature.is_function():
    lambda_wrapped = building_blocks.Lambda(None, None, comp)
    return building_blocks.Call(
        compile_local_computation_to_tensorflow(lambda_wrapped), None)

  parameter_type = comp.type_signature.parameter
  type_analysis.check_tensorflow_compatible_type(parameter_type)
  type_analysis.check_tensorflow_compatible_type(comp.type_signature.result)

  # Already a TensorFlow computation: return it unchanged.
  if (comp.is_compiled_computation() and
      comp.proto.WhichOneof('computation') == 'tensorflow'):
    return comp

  # Ensure that unused values are removed and that reference bindings have
  # unique names.
  comp = unpack_compiled_computations(comp)
  comp = transformations.to_call_dominant(comp)

  if parameter_type is None:
    to_evaluate = building_blocks.Call(comp)

    @tensorflow_computation.tf_computation
    def result_computation():
      return _evaluate_to_tensorflow(to_evaluate, {})

  else:
    # The parameter is bound under a generated name; presumably
    # `unique_name_generator` yields names not already used inside `comp`
    # (TODO confirm), so the binding cannot shadow an existing reference.
    name_generator = building_block_factory.unique_name_generator(comp)
    parameter_name = next(name_generator)
    to_evaluate = building_blocks.Call(
        comp, building_blocks.Reference(parameter_name, parameter_type))

    @tensorflow_computation.tf_computation(parameter_type)
    def result_computation(arg):
      # Struct-typed arguments are converted with
      # `structure.from_container(..., recursive=True)` before evaluation.
      if parameter_type.is_struct():
        arg = structure.from_container(arg, recursive=True)
      return _evaluate_to_tensorflow(to_evaluate, {parameter_name: arg})

  return result_computation.to_compiled_building_block()
def test_creates_block_for_non_lambda(self):
  """A non-lambda computation is wrapped in a block binding its call."""
  bb = building_blocks
  int_type = computation_types.TensorType(tf.int32)
  pair_type = computation_types.StructType([(None, int_type),
                                            (None, int_type)])
  thunk_type = computation_types.FunctionType(None, pair_type)
  ext_call = bb.Call(bb.Data('ext', thunk_type))

  after = transformations.to_call_dominant(bb.Selection(ext_call, index=0))

  expected = bb.Block(
      [('_var1', ext_call)],
      bb.Selection(bb.Reference('_var1', pair_type), index=0))
  self.assert_compact_representations_equal(after, expected)
def test_splits_on_selected_intrinsic_nested_in_tuple_broadcast(self):
  """Splitting works when a broadcast is nested inside a selected struct."""
  inner_broadcast = (
      building_block_test_utils.create_whimsy_called_federated_broadcast())
  server_value = building_blocks.Data('a',
                                      computation_types.at_server(tf.int32))
  selected = building_blocks.Selection(
      building_blocks.Struct([server_value, inner_broadcast]), index=0)
  outer_broadcast = building_block_factory.create_federated_broadcast(selected)

  comp = building_blocks.Lambda(
      'a', tf.int32, transformations.to_call_dominant(outer_broadcast))

  self.assert_splits_on(
      comp, building_block_factory.create_null_federated_broadcast())
def test_inlines_references(self):
  """Chains of reference-to-reference bindings are fully inlined."""
  int_type = computation_types.to_type(tf.int32)

  def ref(name):
    return building_blocks.Reference(name, int_type)

  def lam(param_name, body):
    return building_blocks.Lambda(param_name, int_type, body)

  chained = building_blocks.Block([('y', ref('x')), ('z', ref('y'))],
                                  ref('z'))
  after = transformations.to_call_dominant(lam('x', chained))
  self.assert_compact_representations_equal(after, lam('x', ref('x')))
def test_inlines_structs(self):
  """Struct-valued bindings are inlined into their uses."""
  bb = building_blocks
  int_type = computation_types.to_type(tf.int32)
  single_type = computation_types.StructType([int_type])
  nested_type = computation_types.StructType([single_type])

  body = bb.Block(
      [
          ('y', bb.Struct([bb.Reference('x', int_type)])),
          ('z', bb.Struct([bb.Reference('y', single_type)])),
      ],
      bb.Reference('z', nested_type))
  after = transformations.to_call_dominant(bb.Lambda('x', int_type, body))

  expected = bb.Lambda(
      'x', int_type, bb.Struct([bb.Struct([bb.Reference('x', int_type)])]))
  self.assert_compact_representations_equal(after, expected)
def test_call_to_higher_order_external_allowed(self):
  """A call to an external higher-order function is bound as one local."""
  bb = building_blocks
  int_type = computation_types.TensorType(tf.int32)
  int_to_int_type = computation_types.FunctionType(int_type, int_type)
  higher_order_type = computation_types.FunctionType(int_to_int_type,
                                                     int_type)
  external_fn = bb.Data('call_with_one', higher_order_type)
  callback = bb.Lambda('x', int_type, bb.Data('num', int_type))
  call_ext = bb.Call(external_fn, callback)

  after = transformations.to_call_dominant(call_ext)

  after.check_block()
  self.assertLen(after.locals, 1)
  [(bound_name, bound_value)] = after.locals
  self.assertEqual(bound_value.compact_representation(),
                   call_ext.compact_representation())
  self.assert_compact_representations_equal(
      after.result, bb.Reference(bound_name, call_ext.type_signature))
def test_creates_binding_for_each_call(self):
  """Each nested call gets its own binding in the resulting block."""
  bb = building_blocks
  int_type = computation_types.to_type(tf.int32)
  ext_fn = bb.Data('ext', computation_types.FunctionType(int_type, int_type))

  def apply_ext(arg_name):
    return bb.Call(ext_fn, bb.Reference(arg_name, int_type))

  before = bb.Lambda('x', int_type, bb.Call(ext_fn, apply_ext('x')))
  after = transformations.to_call_dominant(before)

  expected = bb.Lambda(
      'x', int_type,
      bb.Block([('_var1', apply_ext('x')), ('_var2', apply_ext('_var1'))],
               bb.Reference('_var2', int_type)))
  self.assert_compact_representations_equal(after, expected)
def _replace_lambda_body_with_call_dominant_form(
    comp: building_blocks.Lambda) -> building_blocks.Lambda:
  """Returns `comp` with its body converted to call-dominant form.

  Call-dominant form ensures that all higher-order functions are fully
  resolved, and that called intrinsics are pulled out into a top-level
  let-binding. Together these conditions ensure, first, that pattern-matching
  on calls to intrinsics is sufficient to identify communication operators in
  `force_align_and_split_by_intrinsics` and, second, that there are no nested
  intrinsics which would cause that function to fail.

  Args:
    comp: The `building_blocks.Lambda` whose body should be converted to
      call-dominant form.

  Returns:
    A transformed version of `comp` whose body is call-dominant.
  """
  # Guard both the input and the transformed output with `check_lambda()`.
  comp.check_lambda()
  call_dominant = transformations.to_call_dominant(comp)
  call_dominant.check_lambda()
  return call_dominant
def transform_to_native_form(
    comp: computation_impl.ConcreteComputation,
    transform_math_to_tf: bool = False,
    grappler_config: Optional[tf.compat.v1.ConfigProto] = None
) -> computation_impl.ConcreteComputation:
  """Compiles a computation for execution in the TFF native runtime.

  The proto underlying `comp` is first converted to call-dominant form (see
  `tff.framework.to_call_dominant` for a definition). Then, optionally, its
  local math is compiled to TensorFlow and its TensorFlow graphs are optimized
  with Grappler. Finally, Grappler is disabled on TF call ops and IDs are added
  to the TensorFlow computations.

  Args:
    comp: Instance of `computation_impl.ConcreteComputation` to compile.
    transform_math_to_tf: Whether to additionally transform math to TensorFlow
      graphs. Necessary if running on an execution stack without
      ReferenceResolvingExecutors underneath FederatingExecutors.
    grappler_config: Configuration for Grappler optimizations to perform on the
      TensorFlow computations. If `None`, Grappler will not be run and no
      optimizations will be applied.

  Returns:
    A new `computation_impl.ConcreteComputation` representing the compiled
    version of `comp`. If any stage raises `ValueError`, the failure is logged
    at debug level and `comp` itself is returned unchanged.
  """
  span_scope = 'transform_to_native_form'

  def log_compiled(form):
    logging.debug('Computation compiled to:')
    logging.debug(form.formatted_representation())

  source_bb = building_blocks.ComputationBuildingBlock.from_proto(
      computation_impl.ConcreteComputation.get_proto(comp))
  try:
    logging.debug('Compiling TFF computation to CDF.')
    with tracing.span(span_scope, 'to_call_dominant', span=True):
      form = transformations.to_call_dominant(source_bb)
    log_compiled(form)

    if transform_math_to_tf:
      logging.debug('Compiling local computations to TensorFlow.')
      with tracing.span(
          span_scope, 'compile_local_subcomputations_to_tensorflow',
          span=True):
        form = compiler.compile_local_subcomputations_to_tensorflow(form)
      log_compiled(form)

    if grappler_config is not None:
      with tracing.span(span_scope, 'optimize_tf_graphs', span=True):
        form, _ = (
            compiled_computation_transformations.optimize_tensorflow_graphs(
                form, grappler_config))

    with tracing.span(
        span_scope, 'transform_tf_call_ops_disable_grappler', span=True):
      form, _ = (compiled_computation_transformations
                 .transform_tf_call_ops_to_disable_grappler(form))

    with tracing.span(span_scope, 'transform_tf_add_ids', span=True):
      form, _ = compiled_computation_transformations.transform_tf_add_ids(
          form)

    return computation_impl.ConcreteComputation.from_building_block(form)
  except ValueError as e:
    # Best-effort compilation: fall back to the original computation.
    logging.debug('Compilation for native runtime failed with error %s', e)
    logging.debug('computation: %s', source_bb.compact_representation())
    return comp
def transformation_fn(x):
  """Removes mapped/applied identities from `x`, then makes it call-dominant."""
  without_identity, _ = (
      tree_transformations.remove_mapped_or_applied_identity(x))
  return transformations.to_call_dominant(without_identity)
def compile_to_mergeable_comp_form(
    comp: computation_impl.ConcreteComputation
) -> mergeable_comp_execution_context.MergeableCompForm:
  """Compiles a computation with a single aggregation to `MergeableCompForm`.

  Compilation proceeds by splitting on the lone aggregation, and using the
  aggregation's internal functions to generate a semantically equivalent
  instance of `mergeable_comp_execution_context.MergeableCompForm`.

  The resulting `up_to_merge` computation runs everything up to and including
  the aggregation's merge, using an identity report. The `after_merge`
  computation applies the aggregation's real report function via
  `federated_map`, then runs the remainder of the original computation.

  Args:
    comp: Instance of `computation_impl.ConcreteComputation` to compile.
      Assumed to be representable as a computation with a single aggregation
      in its body, so that for example two parallel aggregations are allowed,
      but multiple dependent aggregations are disallowed. Additionally assumed
      to be of functional type.

  Returns:
    A semantically equivalent instance of
    `mergeable_comp_execution_context.MergeableCompForm`.

  Raises:
    TypeError: If `comp` is not a building block, or is not of functional TFF
      type. (NOTE(review): presumably raised by `_ensure_lambda`; not visible
      here, so confirm.)
    ValueError: If `comp` cannot be represented as a computation with at most
      one aggregation in its body.
  """
  original_return_type = comp.type_signature.result
  building_block = comp.to_building_block()
  lam = _ensure_lambda(building_block)
  lowered_bb, _ = tree_transformations.replace_intrinsics_with_bodies(lam)

  # We transform the body of this computation to easily preserve the top-level
  # lambda required by force-aligning.
  call_dominant_body_bb = transformations.to_call_dominant(lowered_bb.result)
  call_dominant_bb = building_blocks.Lambda(lowered_bb.parameter_name,
                                            lowered_bb.parameter_type,
                                            call_dominant_body_bb)

  # This check should not throw false positives because we just ensured we are
  # in call-dominant form.
  tree_analysis.check_aggregate_not_dependent_on_aggregate(call_dominant_bb)

  before_agg, after_agg = transformations.force_align_and_split_by_intrinsics(
      call_dominant_bb,
      [building_block_factory.create_null_federated_aggregate()])

  # Construct a report function which accepts the result of merge.
  # Index 3 of the aggregate parameter is presumably the merge function, given
  # the `federated_aggregate(value, zero, accumulate, merge, report)` argument
  # order used below.
  merge_fn_type = before_agg.type_signature.result[
      'federated_aggregate_param'][3]
  identity_report = computation_impl.ConcreteComputation.from_building_block(
      building_block_factory.create_compiled_identity(merge_fn_type.result))

  zero_comp, accumulate_comp, merge_comp, report_comp = _extract_federated_aggregate_computations(
      before_agg)

  before_agg_callable = computation_impl.ConcreteComputation.from_building_block(
      before_agg)
  after_agg_callable = computation_impl.ConcreteComputation.from_building_block(
      after_agg)

  if before_agg.type_signature.parameter is not None:
    # TODO(b/147499373): If None-arguments were uniformly represented as empty
    # tuples, we would be able to avoid this (and related) ugly casing.

    # Aggregates with the identity report, so the server-side result is the
    # merged accumulator rather than the reported value.
    @federated_computation.federated_computation(
        before_agg.type_signature.parameter)
    def up_to_merge_computation(arg):
      federated_aggregate_args = before_agg_callable(
          arg)['federated_aggregate_param']
      value_to_aggregate = federated_aggregate_args[0]
      zero = zero_comp()
      return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                            accumulate_comp, merge_comp,
                                            identity_report)

    # Applies the real report function to the merged result, then resumes the
    # post-aggregation part of the original computation.
    @federated_computation.federated_computation(
        before_agg.type_signature.parameter,
        computation_types.at_server(identity_report.type_signature.result))
    def after_merge_computation(top_level_arg, merge_result):
      reported_result = intrinsics.federated_map(report_comp, merge_result)
      return after_agg_callable(top_level_arg, [reported_result])

  else:

    # No-argument variant of `up_to_merge_computation` above.
    @federated_computation.federated_computation()
    def up_to_merge_computation():
      federated_aggregate_args = before_agg_callable(
      )['federated_aggregate_param']
      value_to_aggregate = federated_aggregate_args[0]
      zero = zero_comp()
      return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                            accumulate_comp, merge_comp,
                                            identity_report)

    # No-argument variant of `after_merge_computation` above. NOTE(review):
    # `after_agg` is called with a doubly nested `[[reported_result]]` here vs.
    # `(top_level_arg, [reported_result])` above; presumably this matches the
    # parameter shape `force_align_and_split_by_intrinsics` produces when the
    # original has no argument — confirm.
    @federated_computation.federated_computation(
        computation_types.at_server(identity_report.type_signature.result))
    def after_merge_computation(merge_result):
      reported_result = intrinsics.federated_map(report_comp, merge_result)
      return after_agg_callable([[reported_result]])

  # Re-annotate the result type so `after_merge` advertises the original
  # computation's return type.
  annotated_type_signature = computation_types.FunctionType(
      after_merge_computation.type_signature.parameter, original_return_type)
  after_merge_computation = computation_impl.ConcreteComputation.with_type(
      after_merge_computation, annotated_type_signature)

  return mergeable_comp_execution_context.MergeableCompForm(
      up_to_merge=up_to_merge_computation,
      merge=merge_comp,
      after_merge=after_merge_computation)
def _compile_to_tf(fn):
  """Makes `fn` call-dominant, strips placements, and compiles it to TF."""
  unplaced, _ = tree_transformations.strip_placement(
      transformations.to_call_dominant(fn))
  return compiler.compile_local_subcomputations_to_tensorflow(unplaced)
def consolidate_and_extract_local_processing(comp, grappler_config_proto):
  """Consolidates all the local processing in `comp`.

  The input computation `comp` must have the following properties:

  1. The output of `comp` may be of a federated type or unplaced. We refer to
     the placement `p` of that type as the placement of `comp`. There is no
     placement anywhere in the body of `comp` different than `p`. If `comp`
     is of a functional type, and has a parameter, the type of that parameter
     is a federated type placed at `p` as well, or unplaced if the result of
     the function is unplaced.

  2. The only intrinsics that may appear in the body of `comp` are those that
     manipulate data locally within the same placement. The exact set of these
     intrinsics will be gradually updated. At the moment, we support only the
     following:

     * Either `federated_apply` or `federated_map`, depending on whether
       `comp` is `SERVER`- or `CLIENTS`-placed. `federated_map_all_equal` is
       also allowed in the `CLIENTS`-placed case.

     * Either `federated_value_at_server` or `federated_value_at_clients`,
       likewise placement-dependent.

     * Either `federated_zip_at_server` or `federated_zip_at_clients`, again
       placement-dependent.

     Anything else, including `sequence_*` operators, should have been reduced
     already prior to calling this function.

  3. There are no lambdas in the body of `comp` except for `comp` itself being
     possibly a (top-level) lambda. All other lambdas must have been reduced.
     This requirement may eventually be relaxed by embedding a lambda reducer
     into this helper method.

  4. If `comp` is of a functional type, it is either an instance of
     `building_blocks.CompiledComputation`, in which case there is nothing for
     us to do here, or a `building_blocks.Lambda`.

  5. There is at most one unbound reference under `comp`, and this is only
     allowed in the case that `comp` is not of a functional type.

  NOTE(review): this function calls `comp.type_signature.check_function()`,
  so `comp` must in fact be of functional type here; the non-functional
  descriptions above (condition 5, and `T@p` below) are not reachable through
  this entry point.

  Aside from the intrinsics specified above, and the possibility of allowing
  lambdas, blocks, and references given the constraints above, the remaining
  constructs in `comp` include a combination of tuples, selections, calls,
  and sections of TensorFlow (as `CompiledComputation`s). This helper function
  contains the logic to consolidate these constructs.

  The output of this transformation is always a single section of TensorFlow,
  which we henceforth refer to as `result`, the exact form of which depends on
  the placement of `comp` and the presence or absence of an argument.

  a. If there is no argument in `comp`, and `comp` is `SERVER`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     federated_value_at_server(result())
     ```

  b. If there is no argument in `comp`, and `comp` is `CLIENTS`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     federated_value_at_clients(result())
     ```

  c. If there is an argument in `comp`, and `comp` is `SERVER`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     (arg -> federated_apply(<result, arg>))
     ```

  d. If there is an argument in `comp`, and `comp` is `CLIENTS`-placed, then
     the `result` is such that `comp` can be equivalently represented as:

     ```
     (arg -> federated_map(<result, arg>))
     ```

  If the type of `comp` is `T@p` (thus `comp` is non-functional), the type of
  `result` is `T`, where `p` is the specific (concrete) placement of `comp`.

  If the type of `comp` is `(T@p -> U@p)`, then the type of `result` must be
  `(T -> U)`, where `p` is again a specific placement.

  Args:
    comp: An instance of `building_blocks.ComputationBuildingBlock` that
      serves as the input to this transformation, as described above.
    grappler_config_proto: An instance of `tf.compat.v1.ConfigProto` to
      configure Grappler graph optimization of the generated TensorFlow graph.
      If `grappler_config_proto` has
      `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler is
      bypassed.

  Returns:
    An instance of `building_blocks.CompiledComputation` that holds the
    TensorFlow section produced by this extraction step, as described above.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  comp.type_signature.check_function()
  # Converting to call-dominant form drops unused subcomputations, which may
  # reference placements different from that of the result.
  placement_free, _ = tree_transformations.strip_placement(
      transformations.to_call_dominant(comp))
  tf_section = parse_tff_to_tf(placement_free, grappler_config_proto)
  check_extraction_result(placement_free, tf_section)
  return tf_section
def transformation_fn(x):
  """Uniquifies names, removes mapped/applied identities, then call-dominates.

  The three steps run in this order: `uniquify_reference_names`,
  `remove_mapped_or_applied_identity`, and `to_call_dominant`.
  """
  uniquified, _ = tree_transformations.uniquify_reference_names(x)
  without_identity, _ = (
      tree_transformations.remove_mapped_or_applied_identity(uniquified))
  return transformations.to_call_dominant(without_identity)