Example #1
0
    def test_replace_called_lambda_does_not_replace_uncalled_lambda(self):
        # A bare (uncalled) lambda must come out of the transformation with
        # unchanged Call/Lambda counts, no Blocks, and the same result when
        # invoked.
        original = _create_lambda_to_add_one(tf.int32)

        blocks_before = _get_number_of_computations(
            original, computation_building_blocks.Block)
        self.assertEqual(blocks_before, 0)
        original_impl = _to_comp(original)
        self.assertEqual(original_impl(1), 2)

        result = transformations.replace_called_lambda_with_block(original)

        # Neither the number of calls nor the number of lambdas may change.
        for kind in (computation_building_blocks.Call,
                     computation_building_blocks.Lambda):
            count_after = _get_number_of_computations(result, kind)
            count_before = _get_number_of_computations(original, kind)
            self.assertEqual(count_after, count_before)
        blocks_after = _get_number_of_computations(
            result, computation_building_blocks.Block)
        self.assertEqual(blocks_after, 0)
        result_impl = _to_comp(result)
        self.assertEqual(result_impl(1), 2)
Example #2
0
    def test_replace_called_lambda_replaces_multiple_called_lambdas(self):
        # Expect ten fewer Calls, ten fewer Lambdas and ten new Blocks after
        # the transformation, with the invocation result unchanged.
        # NOTE(review): presumably `_create_lambda_to_chained_call` chains
        # `num_calls` calls of the add-one lambda -- confirm against helper.
        num_calls = 10
        reference = computation_building_blocks.Reference('arg', tf.int32)
        add_one = _create_lambda_to_add_one(reference.type_signature)
        original = _create_lambda_to_chained_call(add_one, reference,
                                                  num_calls)

        blocks_before = _get_number_of_computations(
            original, computation_building_blocks.Block)
        self.assertEqual(blocks_before, 0)
        original_impl = _to_comp(original)
        self.assertEqual(original_impl(1), 11)

        result = transformations.replace_called_lambda_with_block(original)

        # Calls and Lambdas each drop by the same amount.
        for kind in (computation_building_blocks.Call,
                     computation_building_blocks.Lambda):
            count_after = _get_number_of_computations(result, kind)
            count_before = _get_number_of_computations(original, kind)
            self.assertEqual(count_after, count_before - num_calls)
        blocks_after = _get_number_of_computations(
            result, computation_building_blocks.Block)
        self.assertEqual(blocks_after, num_calls)
        result_impl = _to_comp(result)
        self.assertEqual(result_impl(1), 11)
    def test_replace_called_lambda_does_not_replace_uncalled_lambda(self):
        # An identity lambda that is never called must keep its compact
        # representation through the transformation.
        identity = _create_lambda_to_identity(tf.int32)

        result = transformations.replace_called_lambda_with_block(identity)

        self.assertEqual(result.tff_repr, identity.tff_repr)
        self.assertEqual(result.tff_repr, '(arg -> arg)')
    def test_replace_called_lambda_replaces_called_lambda(self):
        # A lambda applied directly to an argument is rewritten as a `let`
        # binding of the lambda's parameter to that argument.
        identity = _create_lambda_to_identity(tf.int32)
        data = computation_building_blocks.Data('x', tf.int32)
        called_identity = computation_building_blocks.Call(identity, data)

        result = transformations.replace_called_lambda_with_block(
            called_identity)

        # The input's representation is checked after the transformation ran.
        self.assertEqual(called_identity.tff_repr, '(arg -> arg)(x)')
        self.assertEqual(result.tff_repr, '(let arg=x in arg)')
Example #5
0
    def test_replace_called_lambda_does_not_replace_separated_called_lambda(
            self):
        # The lambda is wrapped in an empty block before being called, so the
        # function of the call is a Block rather than a Lambda; the string
        # form is expected to be unchanged by the transformation.
        reference = computation_building_blocks.Reference('arg', tf.int32)
        identity = _create_lambda_to_identity(reference.type_signature)
        wrapper = computation_building_blocks.Block([], identity)
        called_wrapper = computation_building_blocks.Call(wrapper, reference)

        result = transformations.replace_called_lambda_with_block(
            called_wrapper)

        self.assertEqual(str(result), str(called_wrapper))
        self.assertEqual(str(result), '(let  in (arg -> arg))(arg)')
    def test_replace_called_lambda_does_not_replace_separated_called_lambda(
            self):
        # The call targets a block that wraps the lambda, not the lambda
        # itself; the compact representation is expected to be unchanged.
        identity = _create_lambda_to_identity(tf.int32)
        wrapper = _create_dummy_block(identity)
        data = computation_building_blocks.Data('x', tf.int32)
        called_wrapper = computation_building_blocks.Call(wrapper, data)

        result = transformations.replace_called_lambda_with_block(
            called_wrapper)

        expected_repr = '(let local=data in (arg -> arg))(x)'
        self.assertEqual(result.tff_repr, called_wrapper.tff_repr)
        self.assertEqual(result.tff_repr, expected_repr)
def prepare_for_rebinding(comp):
    """Simplifies `comp` so that rebound variables can be extracted from it.

    The computation first has its reference names uniquified and its called
    lambdas replaced by blocks. A single postorder traversal then applies
    block inlining followed by collapsing of selections from tuples at each
    node.

    This does not necessarily guarantee that the result is free of called
    lambdas; it merely removes one level of indirection. That reduction has
    proved sufficient for identifying variables which are about to be rebound
    in the top-level lambda when compiler components factor work out of a
    single function into multiple functions. Because no guarantee of
    sufficiency is made, the caller is responsible for ensuring that no
    unbound variables are introduced during the rebinding.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` from which
        all occurrences of a given variable need to be extracted and rebound.

    Returns:
      The result of
      `transformation_utils.transform_postorder_with_symbol_bindings` run over
      the prepared computation, i.e. `comp` with called lambdas replaced by
      blocks, blocks inlined and selections from tuples collapsed.
      NOTE(review): the exact return shape of that helper is not visible
      here -- confirm whether callers receive a bare building block or a
      `(building_block, modified)` pair.
    """
    # TODO(b/146430051): Follow up here and consider removing or enforcing more
    # strict output invariants when `remove_lambdas_and_blocks` is moved in here.
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    comp, _ = transformations.uniquify_reference_names(comp)
    comp, _ = transformations.replace_called_lambda_with_block(comp)

    # Order matters: blocks are inlined before selections are collapsed. The
    # inliner is built from the already uniquified, lambda-free computation.
    passes = (
        transformations.InlineBlock(comp),
        transformations.ReplaceSelectionFromTuple(),
    )
    bindings = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)

    def _apply_passes(node, tree):
        """Runs every pass over `node`, reporting whether any changed it."""
        changed = False
        for current_pass in passes:
            # Only global transforms are handed the symbol tree.
            if current_pass.global_transform:
                pass_args = (node, tree)
            else:
                pass_args = (node,)
            node, pass_changed = current_pass.transform(*pass_args)
            if not changed:
                changed = pass_changed
        return node, changed

    return transformation_utils.transform_postorder_with_symbol_bindings(
        comp, _apply_passes, bindings)
Example #8
0
    def compile(self, computation_to_compile):
        """Compiles `computation_to_compile`.

        The computation is converted to building blocks, has its intrinsics
        replaced by their bodies, is run through a fixed sequence of
        simplifying transformations, and is wrapped back into a computation.

        Args:
          computation_to_compile: An instance of
            `computation_base.Computation` to compile.

        Returns:
          An instance of `computation_impl.ComputationImpl` that represents
          the result.
        """
        py_typecheck.check_type(computation_to_compile,
                                computation_base.Computation)
        proto = computation_impl.ComputationImpl.get_proto(
            computation_to_compile)
        py_typecheck.check_type(proto, pb.Computation)
        tree = building_blocks.ComputationBuildingBlock.from_proto(proto)

        # TODO(b/113123410): Add a compiler options argument that characterizes the
        # desired form of the output. To be driven by what the specific backend the
        # pipeline is targeting is able to understand. Pending a more fleshed out
        # design of the backend API.

        # Replace intrinsics with their bodies, for now manually in a fixed order.
        # TODO(b/113123410): Replace this with a more automated implementation that
        # does not rely on manual maintenance.
        tree, _ = value_transformations.replace_all_intrinsics_with_bodies(
            tree, self._context_stack)

        simplification_passes = (
            # Replaces called lambdas with LET constructs with a single local
            # symbol.
            transformations.replace_called_lambda_with_block,
            # Removes mapped or applied identities.
            transformations.remove_mapped_or_applied_identity,
            # Remove duplicate computations. This is important! Otherwise the
            # semantics of non-deterministic computations (e.g. a
            # `tff.tf_computation` depending on `tf.random`) will give
            # unexpected behavior. Additionally, this may reduce the amount
            # of calls into TF for some ASTs.
            transformations.uniquify_reference_names,
            transformations.extract_computations,
            transformations.remove_duplicate_computations,
        )
        for simplify in simplification_passes:
            # Each pass returns a pair; only the transformed tree is kept.
            tree, _ = simplify(tree)

        return computation_impl.ComputationImpl(tree.proto,
                                                self._context_stack)
    def compile(self, computation_to_compile):
        """Compiles `computation_to_compile`.

        The computation is converted to building blocks, each registered
        intrinsic is replaced by its body, called lambdas are replaced by
        blocks, and the result is wrapped back into a computation.

        Args:
          computation_to_compile: An instance of
            `computation_base.Computation` to compile.

        Returns:
          An instance of `computation_impl.ComputationImpl` that represents
          the result.
        """
        py_typecheck.check_type(computation_to_compile,
                                computation_base.Computation)
        proto = computation_impl.ComputationImpl.get_proto(
            computation_to_compile)

        # TODO(b/113123410): Add a compiler options argument that characterizes the
        # desired form of the output. To be driven by what the specific backend the
        # pipeline is targeting is able to understand. Pending a more fleshed out
        # design of the backend API.

        py_typecheck.check_type(proto, pb.Computation)
        tree = computation_building_blocks.ComputationBuildingBlock.from_proto(
            proto)

        # Replace intrinsics with their bodies, for now manually in a fixed order.
        # TODO(b/113123410): Replace this with a more automated implementation that
        # does not rely on manual maintenance.
        for intrinsic_uri, intrinsic_body in six.iteritems(
                self._intrinsic_bodies):
            tree, _ = transformations.replace_intrinsic_with_callable(
                tree, intrinsic_uri, intrinsic_body, self._context_stack)

        # Replaces called lambdas with LET constructs with a single local symbol.
        tree, _ = transformations.replace_called_lambda_with_block(tree)
        # TODO(b/113123410): Add more transformations to simplify and optimize the
        # structure, e.g., such as:
        # * removing unnecessary lambdas,
        # * flattening the structure,
        # * merging TensorFlow blocks where appropriate,
        # * ...and so on.

        return computation_impl.ComputationImpl(tree.proto,
                                                self._context_stack)
Example #10
0
 def test_inline_conflicting_lambdas(self):
     comp = computation_building_blocks.Tuple(
         [computation_building_blocks.Data('test', tf.int32)])
     input1 = computation_building_blocks.Reference('input2',
                                                    comp.type_signature)
     first_level_call = computation_building_blocks.Call(
         computation_building_blocks.Lambda('input2', input1.type_signature,
                                            input1), comp)
     input2 = computation_building_blocks.Reference(
         'input2', first_level_call.type_signature)
     second_level_call = computation_building_blocks.Call(
         computation_building_blocks.Lambda('input2', input2.type_signature,
                                            input2), first_level_call)
     self.assertEqual(str(second_level_call),
                      '(input2 -> input2)((input2 -> input2)(<test>))')
     lambda_reduced_comp = transformations.replace_called_lambda_with_block(
         second_level_call)
     self.assertEqual(
         str(lambda_reduced_comp),
         '(let input2=(let input2=<test> in input2) in input2)')
     inlined = transformations.inline_blocks_with_n_referenced_locals(
         lambda_reduced_comp)
     self.assertEqual(str(inlined), '(let  in (let  in <test>))')
Example #11
0
 def test_replace_called_lambda_raises_type_error(self):
     # Passing None instead of a building block must raise TypeError.
     self.assertRaises(
         TypeError, transformations.replace_called_lambda_with_block, None)