def test_replace_called_lambda_does_not_replace_uncalled_lambda(self):
    """Checks that an uncalled add-one lambda survives the transformation.

    The numbers of `Call` and `Lambda` nodes must match before and after the
    transformation, no `Block` may be introduced, and the computation must
    still map 1 to 2.
    """
    blocks = computation_building_blocks
    original = _create_lambda_to_add_one(tf.int32)

    # Sanity-check the input: no blocks yet, and it really adds one.
    self.assertEqual(_get_number_of_computations(original, blocks.Block), 0)
    original_impl = _to_comp(original)
    self.assertEqual(original_impl(1), 2)

    transformed = transformations.replace_called_lambda_with_block(original)

    # Nothing is called directly, so these node counts must not move.
    for node_type in (blocks.Call, blocks.Lambda):
        self.assertEqual(
            _get_number_of_computations(transformed, node_type),
            _get_number_of_computations(original, node_type))
    self.assertEqual(
        _get_number_of_computations(transformed, blocks.Block), 0)

    transformed_impl = _to_comp(transformed)
    self.assertEqual(transformed_impl(1), 2)
def test_replace_called_lambda_replaces_multiple_called_lambdas(self):
    """Checks that every called lambda in a chain of ten becomes a block.

    After the transformation there must be ten fewer `Call` nodes, ten fewer
    `Lambda` nodes and exactly ten `Block` nodes, while the computation still
    maps 1 to 11.
    """
    blocks = computation_building_blocks
    chain_length = 10

    parameter = blocks.Reference('arg', tf.int32)
    add_one = _create_lambda_to_add_one(parameter.type_signature)
    original = _create_lambda_to_chained_call(add_one, parameter, chain_length)

    # Sanity-check the input: no blocks yet, and it adds one ten times.
    self.assertEqual(_get_number_of_computations(original, blocks.Block), 0)
    original_impl = _to_comp(original)
    self.assertEqual(original_impl(1), 11)

    transformed = transformations.replace_called_lambda_with_block(original)

    # Each replaced called lambda removes one Call and one Lambda node.
    for node_type in (blocks.Call, blocks.Lambda):
        self.assertEqual(
            _get_number_of_computations(transformed, node_type),
            _get_number_of_computations(original, node_type) - chain_length)
    self.assertEqual(
        _get_number_of_computations(transformed, blocks.Block),
        chain_length)

    transformed_impl = _to_comp(transformed)
    self.assertEqual(transformed_impl(1), 11)
def test_replace_called_lambda_does_not_replace_uncalled_lambda(self):
    """Checks that a bare identity lambda keeps its `(arg -> arg)` form."""
    identity = _create_lambda_to_identity(tf.int32)

    result = transformations.replace_called_lambda_with_block(identity)

    # The representation must be unchanged and equal to the known literal.
    self.assertEqual(result.tff_repr, identity.tff_repr)
    self.assertEqual(result.tff_repr, '(arg -> arg)')
def test_replace_called_lambda_replaces_called_lambda(self):
    """Checks that `(arg -> arg)(x)` is rewritten to `(let arg=x in arg)`."""
    identity = _create_lambda_to_identity(tf.int32)
    operand = computation_building_blocks.Data('x', tf.int32)
    called_identity = computation_building_blocks.Call(identity, operand)

    result = transformations.replace_called_lambda_with_block(called_identity)

    # The input is checked after the call, as in the original test order.
    self.assertEqual(called_identity.tff_repr, '(arg -> arg)(x)')
    self.assertEqual(result.tff_repr, '(let arg=x in arg)')
def test_replace_called_lambda_does_not_replace_separated_called_lambda(
        self):
    """Checks that a lambda wrapped in an empty block is left alone when called.

    The call target is a `Block` with no locals whose result is the lambda, so
    the lambda is not called directly and the string form must not change.
    """
    blocks = computation_building_blocks
    parameter = blocks.Reference('arg', tf.int32)
    identity = _create_lambda_to_identity(parameter.type_signature)
    wrapped_identity = blocks.Block([], identity)
    original = blocks.Call(wrapped_identity, parameter)

    result = transformations.replace_called_lambda_with_block(original)

    self.assertEqual(str(result), str(original))
    self.assertEqual(str(result), '(let in (arg -> arg))(arg)')
def test_replace_called_lambda_does_not_replace_separated_called_lambda(
        self):
    """Checks that a lambda behind a dummy block is left alone when called.

    The call target is the block produced by `_create_dummy_block`, not the
    lambda itself, so the `tff_repr` must be unchanged by the transformation.
    """
    identity = _create_lambda_to_identity(tf.int32)
    wrapped_identity = _create_dummy_block(identity)
    operand = computation_building_blocks.Data('x', tf.int32)
    original = computation_building_blocks.Call(wrapped_identity, operand)

    result = transformations.replace_called_lambda_with_block(original)

    self.assertEqual(result.tff_repr, original.tff_repr)
    self.assertEqual(result.tff_repr, '(let local=data in (arg -> arg))(x)')
def prepare_for_rebinding(comp):
    """Prepares `comp` for extracting rebound variables.

    Reference names are made unique, called lambdas are replaced by blocks,
    and then a single postorder traversal inlines blocks and collapses
    selections from tuples.

    This does not necessarily guarantee that the resulting computation has no
    called lambdas; it merely removes one level of indirection. That reduction
    has proved sufficient for identifying variables which are about to be
    rebound in the top-level lambda, which is needed when compiler components
    factor work out from a single function into multiple functions.

    Because no sufficiency guarantee is made, the caller is responsible for
    ensuring that no unbound variables are introduced during the rebinding.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` from which
        all occurrences of a given variable need to be extracted and rebound.

    Returns:
      The result of
      `transformation_utils.transform_postorder_with_symbol_bindings` applied
      to the prepared computation, i.e. the computation with called lambdas
      replaced by blocks, blocks inlined and selections from tuples collapsed.
      NOTE(review): the helper's exact return shape is not visible here;
      confirm whether callers receive a bare building block or a
      `(comp, modified)` pair.

    Raises:
      TypeError: Presumably raised by `py_typecheck.check_type` when `comp` is
        not a `building_blocks.ComputationBuildingBlock` -- confirm.
    """
    # TODO(b/146430051): Follow up here and consider removing or enforcing more
    # strict output invariants when `remove_lambdas_and_blocks` is moved in
    # here.
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

    # Whole-tree preparation passes; only the rewritten tree is kept.
    comp, _ = transformations.uniquify_reference_names(comp)
    comp, _ = transformations.replace_called_lambda_with_block(comp)

    # Per-node passes, applied in this order at every visited node.
    node_passes = [
        transformations.InlineBlock(comp),
        transformations.ReplaceSelectionFromTuple(),
    ]
    symbols = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)

    def _inline_and_collapse(comp, symbol_tree):
        """Runs every per-node pass on `comp`, reporting whether any fired."""
        any_modified = False
        for node_pass in node_passes:
            # Global passes also need the symbol bindings; local ones do not.
            if node_pass.global_transform:
                pass_args = (comp, symbol_tree)
            else:
                pass_args = (comp,)
            comp, pass_modified = node_pass.transform(*pass_args)
            any_modified = any_modified or pass_modified
        return comp, any_modified

    return transformation_utils.transform_postorder_with_symbol_bindings(
        comp, _inline_and_collapse, symbols)
def compile(self, computation_to_compile):
    """Compiles `computation_to_compile`.

    The computation is converted to its building-block form, intrinsics are
    replaced with their bodies, a fixed sequence of simplification passes is
    applied, and the result is wrapped back into a `ComputationImpl`.

    Args:
      computation_to_compile: An instance of `computation_base.Computation` to
        compile.

    Returns:
      An instance of `computation_base.Computation` that represents the
      result.
    """
    py_typecheck.check_type(computation_to_compile,
                            computation_base.Computation)
    proto = computation_impl.ComputationImpl.get_proto(computation_to_compile)
    py_typecheck.check_type(proto, pb.Computation)
    tree = building_blocks.ComputationBuildingBlock.from_proto(proto)

    # TODO(b/113123410): Add a compiler options argument that characterizes the
    # desired form of the output. To be driven by what the specific backend the
    # pipeline is targeting is able to understand. Pending a more fleshed out
    # design of the backend API.

    # Replace intrinsics with their bodies, for now manually in a fixed order.
    # TODO(b/113123410): Replace this with a more automated implementation that
    # does not rely on manual maintenance.
    tree, _ = value_transformations.replace_all_intrinsics_with_bodies(
        tree, self._context_stack)

    simplification_passes = (
        # Replaces called lambdas with LET constructs with a single local
        # symbol.
        transformations.replace_called_lambda_with_block,
        # Removes mapped or applied identities.
        transformations.remove_mapped_or_applied_identity,
        # Remove duplicate computations. This is important! Otherwise the
        # semantics of non-deterministic computations (e.g. a
        # `tff.tf_computation` depending on `tf.random`) will give unexpected
        # behavior. Additionally, this may reduce the amount of calls into TF
        # for some ASTs.
        transformations.uniquify_reference_names,
        transformations.extract_computations,
        transformations.remove_duplicate_computations,
    )
    for simplify in simplification_passes:
        # Each pass returns a pair; only the rewritten tree is needed here.
        tree, _ = simplify(tree)

    return computation_impl.ComputationImpl(tree.proto, self._context_stack)
def compile(self, computation_to_compile):
    """Compiles `computation_to_compile`.

    The computation is converted to its building-block form, each registered
    intrinsic is replaced with its body, called lambdas are replaced with
    blocks, and the result is wrapped back into a `ComputationImpl`.

    Args:
      computation_to_compile: An instance of `computation_base.Computation` to
        compile.

    Returns:
      An instance of `computation_base.Computation` that represents the
      result.
    """
    py_typecheck.check_type(computation_to_compile,
                            computation_base.Computation)
    proto = computation_impl.ComputationImpl.get_proto(computation_to_compile)

    # TODO(b/113123410): Add a compiler options argument that characterizes the
    # desired form of the output. To be driven by what the specific backend the
    # pipeline is targeting is able to understand. Pending a more fleshed out
    # design of the backend API.

    py_typecheck.check_type(proto, pb.Computation)
    tree = computation_building_blocks.ComputationBuildingBlock.from_proto(
        proto)

    # Replace intrinsics with their bodies, for now manually in a fixed order.
    # TODO(b/113123410): Replace this with a more automated implementation that
    # does not rely on manual maintenance.
    for intrinsic_uri, intrinsic_body in six.iteritems(self._intrinsic_bodies):
        tree, _ = transformations.replace_intrinsic_with_callable(
            tree, intrinsic_uri, intrinsic_body, self._context_stack)

    # Replaces called lambdas with LET constructs with a single local symbol.
    tree, _ = transformations.replace_called_lambda_with_block(tree)

    # TODO(b/113123410): Add more transformations to simplify and optimize the
    # structure, e.g., such as:
    # * removing unnecessary lambdas,
    # * flattening the structure,
    # * merging TensorFlow blocks where appropriate,
    # * ...and so on.

    return computation_impl.ComputationImpl(tree.proto, self._context_stack)
def test_inline_conflicting_lambdas(self): comp = computation_building_blocks.Tuple( [computation_building_blocks.Data('test', tf.int32)]) input1 = computation_building_blocks.Reference('input2', comp.type_signature) first_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input2', input1.type_signature, input1), comp) input2 = computation_building_blocks.Reference( 'input2', first_level_call.type_signature) second_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input2', input2.type_signature, input2), first_level_call) self.assertEqual(str(second_level_call), '(input2 -> input2)((input2 -> input2)(<test>))') lambda_reduced_comp = transformations.replace_called_lambda_with_block( second_level_call) self.assertEqual( str(lambda_reduced_comp), '(let input2=(let input2=<test> in input2) in input2)') inlined = transformations.inline_blocks_with_n_referenced_locals( lambda_reduced_comp) self.assertEqual(str(inlined), '(let in (let in <test>))')
def test_replace_called_lambda_raises_type_error(self):
    """Checks that passing `None` as the computation raises `TypeError`."""
    invalid_comp = None
    with self.assertRaises(TypeError):
        transformations.replace_called_lambda_with_block(invalid_comp)