Esempio n. 1
0
def replace_intrinsics_with_bodies(comp):
  """Replaces each implemented intrinsic in `comp` with its body.

  This is an AST-to-AST transformation: the input is a
  `building_blocks.ComputationBuildingBlock` and so is the output. It is meant
  to be the standard reduction entry point, reducing every intrinsic that
  currently has a body defined in `intrinsic_bodies.py`.

  Correctness relies on the contract of
  `intrinsic_bodies.get_intrinsic_bodies`: the dict it returns must be ordered
  from more complex intrinsics to less complex ones.

  Args:
    comp: The `building_blocks.ComputationBuildingBlock` whose intrinsics
      should be replaced by their bodies.

  Returns:
    A `building_blocks.ComputationBuildingBlock` in which every intrinsic
    defined in `intrinsic_bodies.py` has been replaced by its body.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  stack = context_stack_impl.context_stack
  # The second element of the returned pair (whether anything changed) is not
  # needed by callers of this helper.
  reduced, _ = value_transformations.replace_all_intrinsics_with_bodies(
      comp, stack)
  return reduced
Esempio n. 2
0
  def test_generic_divide_reduces(self):
    """A bare GENERIC_DIVIDE intrinsic is fully reduced away."""
    stack = context_stack_impl.context_stack
    divide_uri = intrinsic_defs.GENERIC_DIVIDE.uri
    # Signature (float32, float32) -> float32 for the binary divide intrinsic.
    binary_float_type = computation_types.FunctionType(
        [tf.float32, tf.float32], tf.float32)
    intrinsic = building_blocks.Intrinsic(divide_uri, binary_float_type)

    occurrences_before = _count_intrinsics(intrinsic, divide_uri)
    reduced, modified = value_transformations.replace_all_intrinsics_with_bodies(
        intrinsic, stack)
    occurrences_after = _count_intrinsics(reduced, divide_uri)

    self.assertGreater(occurrences_before, 0)
    self.assertEqual(occurrences_after, 0)
    # The reduced tree must only contain intrinsics allowed after reduction.
    tree_analysis.check_intrinsics_whitelisted_for_reduction(reduced)
    self.assertTrue(modified)
Esempio n. 3
0
    def compile(self, computation_to_compile):
        """Compiles `computation_to_compile`.

        The computation is converted to a building block, and its intrinsics
        are replaced by their bodies. Called lambdas are rewritten as blocks,
        mapped/applied identities are removed, and duplicate computations are
        eliminated before the result is wrapped back into a computation.

        Args:
          computation_to_compile: An instance of `computation_base.Computation`
            to compile.

        Returns:
          An instance of `computation_base.Computation` that represents the
          result.
        """
        py_typecheck.check_type(computation_to_compile,
                                computation_base.Computation)
        computation_proto = computation_impl.ComputationImpl.get_proto(
            computation_to_compile)
        py_typecheck.check_type(computation_proto, pb.Computation)
        comp = building_blocks.ComputationBuildingBlock.from_proto(
            computation_proto)

        # TODO(b/113123410): Add a compiler options argument that characterizes the
        # desired form of the output. To be driven by what the specific backend the
        # pipeline is targeting is able to understand. Pending a more fleshed out
        # design of the backend API.

        # Replace intrinsics with their bodies, for now manually in a fixed order.
        # TODO(b/113123410): Replace this with a more automated implementation that
        # does not rely on manual maintenance.
        comp, _ = value_transformations.replace_all_intrinsics_with_bodies(
            comp, self._context_stack)

        # Replaces called lambdas with LET constructs with a single local symbol.
        comp, _ = transformations.replace_called_lambda_with_block(comp)

        # Removes mapped or applied identities.
        comp, _ = transformations.remove_mapped_or_applied_identity(comp)

        # Remove duplicate computations. This is important! Otherwise the
        # semantics of non-deterministic computations (e.g. a
        # `tff.tf_computation` depending on `tf.random`) will give unexpected
        # behavior. Additionally, this may reduce the number of calls into TF
        # for some ASTs. Reference names are made unique and computations are
        # extracted first, so that the deduplication pass operates on the
        # extracted form.
        comp, _ = transformations.uniquify_reference_names(comp)
        comp, _ = transformations.extract_computations(comp)
        comp, _ = transformations.remove_duplicate_computations(comp)

        return computation_impl.ComputationImpl(comp.proto,
                                                self._context_stack)
    def test_federated_weighted_mean_reduces(self):
        """A weighted federated mean is reduced so no FEDERATED_WEIGHTED_MEAN remains."""
        stack = context_stack_impl.context_stack
        weighted_mean_uri = intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri
        client_floats = computation_types.FederatedType(
            tf.float32, placements.CLIENTS)

        @computations.federated_computation(client_floats)
        def foo(x):
            # Supplying a weight (here the value itself) is what makes this a
            # weighted mean; the assertion below checks that the
            # FEDERATED_WEIGHTED_MEAN intrinsic actually appears.
            return intrinsics.federated_mean(x, x)

        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            foo._computation_proto)
        occurrences_before = _count_intrinsics(building_block, weighted_mean_uri)
        reduced, modified = value_transformations.replace_all_intrinsics_with_bodies(
            building_block, stack)
        occurrences_after = _count_intrinsics(reduced, weighted_mean_uri)

        self.assertGreater(occurrences_before, 0)
        self.assertEqual(occurrences_after, 0)
        self.assertTrue(modified)
 def test_raises_on_none(self):
     """Passing `None` instead of a building block raises `TypeError`."""
     stack = context_stack_impl.context_stack
     with self.assertRaises(TypeError):
         value_transformations.replace_all_intrinsics_with_bodies(None, stack)