예제 #1
0
    def test_compile_computation(self):
        """Checks that the compiler pipeline eliminates `federated_sum`.

        `foo` maps a TF comparison of client temperatures against a broadcast
        server threshold and sums the resulting 0/1 values. After compiling
        `foo`, the compiled proto is rebuilt as a building block and walked
        postorder, asserting that no `Intrinsic` node carries the
        `FEDERATED_SUM` URI.
        """

        @computations.federated_computation([
            computation_types.FederatedType(tf.float32, placements.CLIENTS),
            computation_types.FederatedType(tf.float32, placements.SERVER,
                                            True)
        ])
        def foo(temperatures, threshold):
            return intrinsics.federated_sum(
                intrinsics.federated_map(
                    computations.tf_computation(
                        # `tf.to_int32` is deprecated (and absent in TF 2.x);
                        # `tf.cast` performs the same bool -> int32 conversion.
                        lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
                        [tf.float32, tf.float32]),
                    [temperatures,
                     intrinsics.federated_broadcast(threshold)]))

        pipeline = compiler_pipeline.CompilerPipeline(
            context_stack_impl.context_stack)

        compiled_foo = pipeline.compile(foo)

        def _not_federated_sum(x):
            """Fails the test if `x` is a `federated_sum` intrinsic."""
            if isinstance(x, computation_building_blocks.Intrinsic):
                self.assertNotEqual(x.uri, intrinsic_defs.FEDERATED_SUM.uri)
            return x, False

        transformation_utils.transform_postorder(
            computation_building_blocks.ComputationBuildingBlock.from_proto(
                computation_impl.ComputationImpl.get_proto(compiled_foo)),
            _not_federated_sum)
예제 #2
0
def count_tensorflow_variables_under(comp):
  """Returns the total number of TF variables in TensorFlow blocks under `comp`.

  Intended as an instrumentation aid, e.g. for checking the size and makeup of
  the TensorFlow artifacts that get generated.

  Args:
    comp: The `building_blocks.ComputationBuildingBlock` to inspect.

  Returns:
    An integer: the sum, over every `building_blocks.CompiledComputation` under
    `comp` whose proto holds a `tensorflow` computation, of the count reported
    by `building_block_analysis.count_tensorflow_variables_in`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  # TODO(b/129791812): Cleanup Python 2 and 3 compatibility
  # Counts are collected in a list (instead of rebinding a counter via
  # `nonlocal`) so the nested function also works under Python 2.
  per_block_counts = []

  def _record_tf_var_count(node):
    is_tf_block = (
        isinstance(node, building_blocks.CompiledComputation) and
        node.proto.WhichOneof('computation') == 'tensorflow')
    if is_tf_block:
      per_block_counts.append(
          building_block_analysis.count_tensorflow_variables_in(node))
    return node, False

  transformation_utils.transform_postorder(comp, _record_tf_var_count)
  return sum(per_block_counts)
예제 #3
0
def check_has_single_placement(comp, single_placement):
    """Verifies that every federated type under `comp` uses `single_placement`.

    Args:
      comp: The `computation_building_blocks.ComputationBuildingBlock` to scan.
      single_placement: The `placement_literals.PlacementLiteral` that must be
        the only placement appearing under `comp`.

    Raises:
      ValueError: If some node under `comp` has a
        `computation_types.FederatedType` type signature whose placement is
        not `single_placement`.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(single_placement,
                            placement_literals.PlacementLiteral)

    def _verify_placement(node):
        """Raises if `node` is federated at a placement other than the target."""
        node_type = node.type_signature
        is_federated = isinstance(node_type, computation_types.FederatedType)
        if is_federated and node_type.placement != single_placement:
            raise ValueError(
                'Comp contains a placement other than {}; '
                'placement {} on comp {} inside the structure. '.format(
                    single_placement, node_type.placement,
                    computation_building_blocks.compact_representation(node)))
        return node, False

    transformation_utils.transform_postorder(comp, _verify_placement)
예제 #4
0
def _get_unbound_references(comp):
    """Gets a Python `dict` of the unbound references in `comp`.

    Computations that are equal will have the same collections of unbound
    references, so it is safe to use `comp` as the key for this `dict` even
    though a given computation may appear in many positions in the AST.

    Args:
      comp: The computation building block to parse.

    Returns:
      A Python `dict` whose keys are the computations in `comp` and whose
      values are Python `set`s of the names of the unbound references in the
      subtree of that computation.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    references = {}

    def _update(comp):
        """Records the unbound reference names of `comp` in `references`.

        Runs postorder, so the entries for all children of `comp` already
        exist. Every stored value is a freshly built `set`: no set object is
        shared between two keys, so building one entry can never mutate
        another.
        """
        if isinstance(comp, computation_building_blocks.Reference):
            references[comp] = {comp.name}
        elif isinstance(comp, computation_building_blocks.Block):
            unbound = set()
            bound_names = set()
            # Locals bind sequentially: a name bound by an earlier local is
            # bound in every later local and in the result.
            for name, variable in comp.locals:
                unbound.update(references[variable] - bound_names)
                bound_names.add(name)
            unbound.update(references[comp.result] - bound_names)
            references[comp] = unbound
        elif isinstance(comp, computation_building_blocks.Call):
            # Copy before updating: updating `references[comp.function]` in
            # place would leak the argument's names into the function's entry.
            unbound = set(references[comp.function])
            if comp.argument is not None:
                unbound.update(references[comp.argument])
            references[comp] = unbound
        elif isinstance(comp, computation_building_blocks.Lambda):
            references[comp] = (
                references[comp.result] - {comp.parameter_name})
        elif isinstance(comp, computation_building_blocks.Selection):
            # Copy so the selection and its source do not share one set.
            references[comp] = set(references[comp.source])
        elif isinstance(comp, computation_building_blocks.Tuple):
            references[comp] = set(
                itertools.chain.from_iterable(references[e] for e in comp))
        else:
            references[comp] = set()
        return comp, False

    transformation_utils.transform_postorder(comp, _update)
    return references
예제 #5
0
def _get_number_of_nodes(comp, predicate=None):
    """Counts the nodes of `comp` for which `predicate` holds.

    Args:
      comp: The `computation_building_blocks.ComputationBuildingBlock` to walk.
      predicate: Optional one-argument callable; when `None`, every visited
        node is counted.

    Returns:
      The number of nodes visited by `transformation_utils.transform_postorder`
      that match `predicate`.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    # TODO(b/129791812): Cleanup Python 2 and 3 compatibility.
    matches = []

    def _visit(node):
        if predicate is None or predicate(node):
            matches.append(node)
        return node

    transformation_utils.transform_postorder(comp, _visit)
    return len(matches)
예제 #6
0
def uniquify_compiled_computation_names(comp):
  """Gives every compiled computation in `comp` a fresh, unique name.

  Walks `comp` postorder and rebuilds each
  `computation_building_blocks.CompiledComputation` it finds from the same
  proto, but with the next name drawn from a single unique-name generator.

  Args:
    comp: The computation building block in which to perform the renaming.

  Returns:
    A new computation with the transformation applied or the original `comp`.

  Raises:
    TypeError: If types do not match.
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)
  fresh_names = computation_constructing_utils.unique_name_generator(
      None, prefix='')

  def _rename(node):
    if isinstance(node, computation_building_blocks.CompiledComputation):
      renamed = computation_building_blocks.CompiledComputation(
          node.proto, six.next(fresh_names))
      return renamed, True
    return node, False

  return transformation_utils.transform_postorder(comp, _rename)
예제 #7
0
def remove_mapped_or_applied_identity(comp):
  r"""Removes all the mapped or applied identity functions in `comp`.

  This transform traverses `comp` postorder, matches the following pattern, and
  removes all the mapped or applied identity functions by replacing the
  following computation:

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Lambda(x), Comp(y)]
                           \
                            Ref(x)

  Intrinsic(<(x -> x), y>)

  with its argument:

  Comp(y)

  y

  Only `FEDERATED_MAP`, `FEDERATED_APPLY` and `SEQUENCE_MAP` intrinsics are
  matched.

  Args:
    comp: The computation building block in which to perform the removals.

  Returns:
    A new computation with the transformation applied or the original `comp`.

  Raises:
    TypeError: If types do not match.
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)

  def _is_mapped_identity(node):
    """Returns `True` if `node` maps or applies an identity function."""
    if not isinstance(node, computation_building_blocks.Call):
      return False
    if not isinstance(node.function, computation_building_blocks.Intrinsic):
      return False
    mapping_uris = (
        intrinsic_defs.FEDERATED_MAP.uri,
        intrinsic_defs.FEDERATED_APPLY.uri,
        intrinsic_defs.SEQUENCE_MAP.uri,
    )
    if node.function.uri not in mapping_uris:
      return False
    return bool(_is_identity_function(node.argument[0]))

  def _transform(node):
    if _is_mapped_identity(node):
      return node.argument[1], True
    return node, False

  return transformation_utils.transform_postorder(comp, _transform)
예제 #8
0
def count(comp, predicate=None):
    """Counts the computations under `comp` that satisfy `predicate`.

    Args:
      comp: The computation to test.
      predicate: Optional Python function that takes a computation and returns
        a boolean value. When `None`, every visited computation is counted.

    Returns:
      The number of computations visited by
      `transformation_utils.transform_postorder` for which `predicate` is
      `None` or returns a truthy value.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    hits = []

    def _tally(node):
        if predicate is None or predicate(node):
            hits.append(1)
        return node, False

    transformation_utils.transform_postorder(comp, _tally)
    return sum(hits)
예제 #9
0
    def test_parameters_are_mapped_together(self):
        """Checks that the concatenated function only references its parameter.

        Concatenating `x -> x` and `y -> y` must yield a computation whose
        names are unique and whose `Reference` nodes all use the resulting
        computation's own `parameter_name`.
        """
        x_lambda = building_blocks.Lambda(
            'x', tf.int32, building_blocks.Reference('x', tf.int32))
        y_lambda = building_blocks.Lambda(
            'y', tf.int32, building_blocks.Reference('y', tf.int32))
        concatenated = mapreduce_transformations.concatenate_function_outputs(
            x_lambda, y_lambda)
        shared_name = concatenated.parameter_name

        def _reject_foreign_reference(node):
            is_reference = isinstance(node, building_blocks.Reference)
            if is_reference and node.name != shared_name:
                raise ValueError
            return node, True

        tree_analysis.check_has_unique_names(concatenated)
        transformation_utils.transform_postorder(
            concatenated, _reject_foreign_reference)
예제 #10
0
def replace_called_lambda_with_block(comp):
    r"""Replaces all the called lambdas in `comp` with a block.

    This transform traverses `comp` postorder, matches the following pattern,
    and replaces the following computation containing a called lambda:

              Call
             /    \
    Lambda(x)      Comp(y)
             \
              Comp(z)

    (x -> z)(y)

    with the following computation containing a block:

               Block
              /     \
    [x=Comp(y)]       Comp(z)

    let x=y in z

    The argument `y` and the lambda's result `z` are retained; the other
    computations are replaced. This transformation is used to facilitate the
    merging of TFF orchestration logic, in particular to remove unnecessary
    lambda expressions and as a stepping stone for merging Blocks together.

    Args:
      comp: The computation building block in which to perform the
        replacements.

    Returns:
      A new computation with the transformation applied or the original `comp`.

    Raises:
      TypeError: If types do not match.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)

    def _transform(node):
        is_called_lambda = (
            isinstance(node, computation_building_blocks.Call) and
            isinstance(node.function, computation_building_blocks.Lambda))
        if not is_called_lambda:
            return node, False
        fn = node.function
        block = computation_building_blocks.Block(
            [(fn.parameter_name, node.argument)], fn.result)
        return block, True

    return transformation_utils.transform_postorder(comp, _transform)
예제 #11
0
def check_intrinsics_whitelisted_for_reduction(comp):
    """Checks that every intrinsic under `comp` is on the reduction whitelist.

    The whitelist holds the intrinsics currently reducible to
    `FEDERATED_AGGREGATE` or `FEDERATED_BROADCAST`, or to local processing.

    Args:
      comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
        to check.

    Raises:
      ValueError: If we encounter an intrinsic under `comp` that is not
        whitelisted as currently reducible.
    """
    # TODO(b/135930668): Factor this and other non-transforms (e.g.
    # `check_has_unique_names` out of this file into a structure specified for
    # static analysis of ASTs.
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    allowed_uris = (
        intrinsic_defs.FEDERATED_AGGREGATE.uri,
        intrinsic_defs.FEDERATED_APPLY.uri,
        intrinsic_defs.FEDERATED_BROADCAST.uri,
        intrinsic_defs.FEDERATED_MAP.uri,
        intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri,
        intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri,
        intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri,
        intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri,
        intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri,
    )

    def _check_whitelisted(node):
        if not isinstance(node, computation_building_blocks.Intrinsic):
            return node, False
        if node.uri in allowed_uris:
            return node, False
        raise ValueError(
            'Encountered an Intrinsic not currently reducible to aggregate or '
            'broadcast, the intrinsic {}'.format(
                computation_building_blocks.compact_representation(node)))

    transformation_utils.transform_postorder(comp, _check_whitelisted)
예제 #12
0
def replace_selection_from_tuple_with_tuple_element(comp):
    r"""Replaces any selection from a tuple with the underlying tuple element.

    Replaces any occurrences of:

                                Selection
                                    |
                                  Tuple
                                 / ... \
                             Comp  ...  Comp

    with the appropriate Comp, as determined by the `index` or `name` of the
    `Selection`. A named selection is resolved to a position by looking the
    name up among the element names of the tuple's type signature.

    Args:
      comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
        to transform.

    Returns:
      A possibly modified version of comp, without any occurrences of
      selections from tuples.

    Raises:
      TypeError: If `comp` is not an instance of
        `computation_building_blocks.ComputationBuildingBlock`.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)

    def _transform(node):
        selects_from_tuple = (
            isinstance(node, computation_building_blocks.Selection) and
            isinstance(node.source, computation_building_blocks.Tuple))
        if not selects_from_tuple:
            return node, False
        if node.name is None:
            position = node.index
        else:
            element_names = [
                name for name, _ in anonymous_tuple.to_elements(
                    node.source.type_signature)
            ]
            position = element_names.index(node.name)
        return node.source[position], True

    return transformation_utils.transform_postorder(comp, _transform)
예제 #13
0
def replace_intrinsic_with_callable(comp, uri, body, context_stack):
    """Replaces all the intrinsics with the given `uri` with a callable.

    This transform traverses `comp` postorder and replaces every intrinsic that
    has the given `uri` and a functional type signature with the building block
    produced from the polymorphic callable `body`. That callable represents the
    body of the intrinsic's implementation: given the intrinsic's parameter it
    constructs the intended result. This will typically be a Python function
    decorated with `@federated_computation` to make it a polymorphic callable.

    Args:
      comp: The computation building block in which to perform the
        replacements.
      uri: The URI of the intrinsic to replace.
      body: A polymorphic callable.
      context_stack: The context stack to use.

    Returns:
      A new computation with the transformation applied or the original `comp`.

    Raises:
      TypeError: If types do not match.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    py_typecheck.check_type(uri, six.string_types)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if not callable(body):
        raise TypeError('The body of the intrinsic must be a callable.')

    def _transform(node):
        """Swaps a matching intrinsic for a building block built from `body`."""
        is_target = (
            isinstance(node, computation_building_blocks.Intrinsic) and
            node.uri == uri and
            isinstance(node.type_signature, computation_types.FunctionType))
        if not is_target:
            return node
        # The builder below needs a callable taking exactly one argument.
        one_arg_body = lambda x: body(x)  # pylint: disable=unnecessary-lambda
        return federated_computation_utils.zero_or_one_arg_fn_to_building_block(
            one_arg_body,
            'arg',
            node.type_signature.parameter,
            context_stack,
            suggested_name=uri)

    return transformation_utils.transform_postorder(comp, _transform)
예제 #14
0
def merge_chained_blocks(comp):
    r"""Merges Block constructs defined in sequence in the AST of `comp`.

    Looks for occurrences of the following pattern:

          Block
         /     \
    [...]       Block
               /     \
          [...]       Comp(x)

    And merges them to

          Block
         /     \
    [...]       Comp(x)

    The outer block's locals come first, followed by the inner block's locals,
    which preserves the relative ordering of the local declarations and
    therefore the scoping rules.

    Because TFF Block constructs bind their variables in sequence, it is safe
    to concatenate the two locals lists in this implementation.

    Args:
      comp: The `computation_building_blocks.ComputationBuildingBlock` whose
        blocks should be merged if possible.

    Returns:
      Transformed version of `comp` with its neighboring blocks merged.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)

    def _transform(node):
        if not isinstance(node, computation_building_blocks.Block):
            return node, False
        inner = node.result
        if not isinstance(inner, computation_building_blocks.Block):
            return node, False
        merged_locals = node.locals + inner.locals
        merged = computation_building_blocks.Block(merged_locals, inner.result)
        return merged, True

    return transformation_utils.transform_postorder(comp, _transform)
예제 #15
0
def inline_blocks_with_n_referenced_locals(comp, inlining_threshold=1):
    """Replaces locals referenced few times in `comp` with their bound values.

    Reference counts come from `transformation_utils.scope_count_snapshot`,
    and the inlining itself is done by a `transformation_utils.InlineReferences`
    transform applied postorder over `comp`.

    Args:
      comp: The computation building block in which to inline the locals that
        are referenced at most `inlining_threshold` times.
      inlining_threshold: The threshold at or below which to inline
        computations. E.g. if `inlining_threshold` is 1, locals which are
        referenced exactly once will be inlined, but locals which are
        referenced twice or more will not.

    Returns:
      A modified version of `comp` in which the locals of
      `computation_building_blocks.Block`s that are referenced
      `inlining_threshold` or fewer times are inlined with their values.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)
    inliner = transformation_utils.InlineReferences(
        inlining_threshold,
        transformation_utils.scope_count_snapshot(comp),
        comp)
    return transformation_utils.transform_postorder(comp, inliner)
예제 #16
0
def replace_chained_federated_maps_with_federated_map(comp):
    r"""Replaces all the chained federated maps in `comp` with one federated map.

    This transform traverses `comp` postorder, matches the following pattern
    `*`, and replaces the following computation containing two federated map
    intrinsics:

              *Call
              /    \
    *Intrinsic     *Tuple
                   /     \
      x=Computation      *Call
                         /    \
               *Intrinsic     *Tuple
                              /     \
                 y=Computation       z=Computation

    federated_map(<x, federated_map(<y, z>)>)

    with the following computation containing one federated map intrinsic:

              Call
             /    \
    Intrinsic      Tuple
                  /     \
         Lambda(a)       z=Computation
                  \
                   Call
                  /    \
     x=Computation      Call
                       /    \
          y=Computation      Reference(a)

    federated_map(<(a -> x(y(a))), z>)

    The functional computations `x` and `y`, and the argument `z` are
    retained; the other computations are replaced.

    Args:
      comp: The computation building block in which to perform the
        replacements.

    Returns:
      A new computation with the transformation applied or the original `comp`.

    Raises:
      TypeError: If types do not match.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)

    def _is_federated_map(node):
        """Returns `True` if `node` is a call to the federated map intrinsic."""
        if not isinstance(node, computation_building_blocks.Call):
            return False
        fn = node.function
        return (isinstance(fn, computation_building_blocks.Intrinsic) and
                fn.uri == intrinsic_defs.FEDERATED_MAP.uri)

    def _transform(node):
        """Fuses `federated_map(<x, federated_map(<y, z>)>)` into one map."""
        if not (_is_federated_map(node) and
                _is_federated_map(node.argument[1])):
            return node
        inner_map = node.argument[1]
        mapped_value = inner_map.argument[1]
        # NOTE(review): the lambda parameter is always named 'inner_arg'; if
        # the outer or inner mapped function holds a free reference with that
        # name it would be captured. Presumably callers uniquify names first —
        # confirm.
        param = computation_building_blocks.Reference(
            'inner_arg', mapped_value.type_signature.member)
        inner_fn = inner_map.argument[0]
        outer_fn = node.argument[0]
        composed = computation_building_blocks.Call(
            outer_fn, computation_building_blocks.Call(inner_fn, param))
        fused_fn = computation_building_blocks.Lambda(
            param.name, param.type_signature, composed)
        new_arg = computation_building_blocks.Tuple([fused_fn, mapped_value])
        new_intrinsic = computation_building_blocks.Intrinsic(
            node.function.uri,
            computation_types.FunctionType(
                new_arg.type_signature, node.function.type_signature.result))
        return computation_building_blocks.Call(new_intrinsic, new_arg)

    return transformation_utils.transform_postorder(comp, _transform)
예제 #17
0
def replace_tuple_intrinsics_with_intrinsic(comp):
    r"""Replaces all the tuples of intrinsics in `comp` with one intrinsic.

    This transform traverses `comp` postorder, matches the following pattern,
    and replaces the following computation containing a tuple of called
    intrinsics all representing the same operation:

             Tuple
             |
             [Call,                        Call, ...]
             /    \                       /    \
    Intrinsic      Tuple         Intrinsic      Tuple
                   |                            |
          [Comp(f1), Comp(v1), ...]    [Comp(f2), Comp(v2), ...]

    <Intrinsic(<f1, v1>), Intrinsic(<f2, v2>)>

    with the following computation containing one called intrinsic:

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Block,               Tuple, ...]
                   /     \               |
           fn=Tuple       Lambda(arg)    [Comp(v1), Comp(v2), ...]
              |                      \
     [Comp(f1), Comp(f2), ...]        Tuple
                                      |
                                 [Call,                  Call, ...]
                                 /    \                 /    \
                           Sel(0)      Sel(0)     Sel(1)      Sel(1)
                          /           /          /           /
                   Ref(fn)    Ref(arg)    Ref(fn)    Ref(arg)

    Intrinsic(<
      (let fn=<f1, f2> in (arg -> <fn[0](arg[0]), fn[1](arg[1])>)),
      <v1, v2>,
    >)

    The functional computations `f1`, `f2`, etc..., and the computations `v1`,
    `v2`, etc... are retained; the other computations are replaced. Empty
    tuples are left unchanged.

    NOTE: This is just an example of what this transformation would look like
    when applied to a tuple of federated maps. The components `f1`, `f2`,
    `v1`, and `v2` and the number of those components are not important.

    NOTE: This transformation is implemented to match the following
    intrinsics:

    * intrinsic_defs.FEDERATED_MAP.uri
    * intrinsic_defs.FEDERATED_AGGREGATE.uri

    Args:
      comp: The computation building block in which to perform the
        replacements.

    Returns:
      A new computation with the transformation applied or the original `comp`.

    Raises:
      TypeError: If types do not match.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)

    def _should_transform(comp):
        """Returns `True` if `comp` is a non-empty tuple of matching calls.

        Every element must call the same intrinsic, and the first element's
        intrinsic must be one of the supported URIs. Empty tuples are rejected
        before `comp[0]` is evaluated; indexing an empty tuple would raise.
        """
        uri = (
            intrinsic_defs.FEDERATED_MAP.uri,
            intrinsic_defs.FEDERATED_AGGREGATE.uri,
        )
        return (isinstance(comp, computation_building_blocks.Tuple)
                and len(comp) > 0
                and _is_called_intrinsic(comp[0], uri) and all(
                    _is_called_intrinsic(element, comp[0].function.uri)
                    for element in comp))

    def _transform(comp):
        """Returns a new transformed computation or `comp`."""
        if not _should_transform(comp):
            return comp

        def _get_comps(comp):
            """Constructs a 2 dimensional Python list of computations.

            Args:
              comp: A `computation_building_blocks.Tuple` containing `n` called
                intrinsics with `m` arguments.

            Returns:
              A Python list of `m` lists; the `i`-th inner list holds the
              `i`-th argument of each of the `n` calls, in tuple order.
            """
            first_call = comp[0]
            comps = [[] for _ in range(len(first_call.argument))]
            for _, call in anonymous_tuple.to_elements(comp):
                for index, arg in enumerate(call.argument):
                    comps[index].append(arg)
            return comps

        def _create_block_to_calls(call_names, comps):
            r"""Constructs a transformed block computation from `comps`.

            Given the "original" computation containing `n` called intrinsics
            with `m` arguments, this function constructs the following
            computation:

                           Block
                          /     \
                  fn=Tuple       Lambda(arg)
                     |                      \
            [Comp(f1), Comp(f2), ...]        Tuple
                                             |
                                        [Call,                  Call, ...]
                                        /    \                 /    \
                                  Sel(0)      Sel(0)     Sel(1)      Sel(1)
                                 /           /          /           /
                          Ref(fn)    Ref(arg)    Ref(fn)    Ref(arg)

            with one `computation_building_blocks.Call` for each `n`. This
            computation represents one of `m` arguments that should be passed
            to the call of the "transformed" computation.

            Args:
              call_names: a Python list of names.
              comps: a Python list of computations.

            Returns:
              A `computation_building_blocks.Block`.
            """
            # `list(...)` materialises the pairs; under Python 3 `zip` returns
            # a one-shot iterator rather than a list.
            functions = computation_building_blocks.Tuple(
                list(zip(call_names, comps)))
            fn = computation_building_blocks.Reference(
                'fn', functions.type_signature)
            arg_type = [element.type_signature.parameter for element in comps]
            arg = computation_building_blocks.Reference('arg', arg_type)
            elements = []
            for index, name in enumerate(call_names):
                sel_fn = computation_building_blocks.Selection(fn, index=index)
                sel_arg = computation_building_blocks.Selection(arg,
                                                                index=index)
                call = computation_building_blocks.Call(sel_fn, sel_arg)
                elements.append((name, call))
            calls = computation_building_blocks.Tuple(elements)
            lam = computation_building_blocks.Lambda(arg.name,
                                                     arg.type_signature, calls)
            return computation_building_blocks.Block([('fn', functions)], lam)

        def _create_transformed_args_from_comps(call_names, elements):
            """Constructs a Python list of transformed computations.

            Given the "original" computation containing `n` called intrinsics
            with `m` arguments, this function constructs the following Python
            list of computations:

            [Block, Tuple, ...]

            with one `computation_building_blocks.Block` for each functional
            computation in `m` and one `computation_building_blocks.Tuple` for
            each non-functional computation in `m`. This list of computations
            represents the arguments that should be passed to the
            `computation_building_blocks.Call` of the "transformed"
            computation.

            Args:
              call_names: a Python list of names.
              elements: A 2 dimensional Python list of computations.

            Returns:
              A Python list of computations.
            """
            args = []
            for comps in elements:
                if isinstance(comps[0].type_signature,
                              computation_types.FunctionType):
                    arg = _create_block_to_calls(call_names, comps)
                else:
                    arg = computation_building_blocks.Tuple(
                        list(zip(call_names, comps)))
                args.append(arg)
            return args

        elements = anonymous_tuple.to_elements(comp)
        call_names = [name for name, _ in elements]
        comps = _get_comps(comp)
        args = _create_transformed_args_from_comps(call_names, comps)
        arg = computation_building_blocks.Tuple(args)
        parameter_type = computation_types.to_type(arg.type_signature)
        result_type = [(name, call.type_signature) for name, call in elements]
        intrinsic_type = computation_types.FunctionType(
            parameter_type, result_type)
        intrinsic = computation_building_blocks.Intrinsic(
            comp[0].function.uri, intrinsic_type)
        return computation_building_blocks.Call(intrinsic, arg)

    return transformation_utils.transform_postorder(comp, _transform)
예제 #18
0
def remove_mapped_or_applied_identity(comp):
    r"""Strips every mapped or applied identity function out of `comp`.

    Walks `comp` in postorder. Wherever a call to `FEDERATED_MAP`,
    `FEDERATED_APPLY` or `SEQUENCE_MAP` has an identity lambda as its function
    argument, the whole call is replaced by the intrinsic's value argument.
    In other words, the computation

              *Call
              /    \
    *Intrinsic     *Tuple
                   /     \
         *Lambda(a)       x=Computation
                   \
                   *Reference(a)

    (<(a -> a), x>)

    is rewritten to just its argument:

    x=Computation

    x

    Args:
      comp: The computation building block in which to perform the replacements.

    Returns:
      A new computation with the transformation applied or the original `comp`.

    Raises:
      TypeError: If types do not match.
    """
    py_typecheck.check_type(
        comp, computation_building_blocks.ComputationBuildingBlock)

    # Intrinsics for which applying an identity function is a no-op.
    identity_capable_uris = (
        intrinsic_defs.FEDERATED_MAP.uri,
        intrinsic_defs.FEDERATED_APPLY.uri,
        intrinsic_defs.SEQUENCE_MAP.uri,
    )

    def _is_identity_lambda(fn):
        """Returns `True` if `fn` is a lambda that returns its own parameter."""
        if not isinstance(fn, computation_building_blocks.Lambda):
            return False
        body = fn.result
        return (isinstance(body, computation_building_blocks.Reference) and
                fn.parameter_name == body.name)

    def _is_mapped_identity(node):
        """Returns `True` if `node` maps or applies an identity function."""
        if not isinstance(node, computation_building_blocks.Call):
            return False
        fn = node.function
        if not isinstance(fn, computation_building_blocks.Intrinsic):
            return False
        if fn.uri not in identity_capable_uris:
            return False
        # The function being mapped/applied is the first element of the
        # intrinsic's argument tuple.
        return _is_identity_lambda(node.argument[0])

    def _transform(node):
        # Rewrite `intrinsic(<(a -> a), x>)` to `x`; keep everything else.
        # NOTE(review): this returns a bare computation, while other
        # transforms built on `transform_postorder` elsewhere in this file
        # return `(comp, modified)` tuples — confirm which contract applies.
        if _is_mapped_identity(node):
            return node.argument[1]
        return node

    return transformation_utils.transform_postorder(comp, _transform)
예제 #19
0
def merge_tuple_intrinsics(comp, uri):
  r"""Merges all the tuples of intrinsics in `comp` into one intrinsic.

  This transform traverses `comp` postorder, matches the following pattern, and
  replaces the following computation containing a tuple of called intrinsics all
  representing the same operation:

           Tuple
           |
           [Call,                        Call, ...]
           /    \                       /    \
  Intrinsic      Tuple         Intrinsic      Tuple
                 |                            |
        [Comp(f1), Comp(v1), ...]    [Comp(f2), Comp(v2), ...]

  <Intrinsic(<f1, v1>), Intrinsic(<f2, v2>)>

  with the following computation containing one called intrinsic:

  federated_unzip(Call)
                 /    \
        Intrinsic      Tuple
                       |
                       [Block,    federated_zip(Tuple), ...]
                       /     \                  |
               fn=Tuple       Lambda(arg)       [Comp(v1), Comp(v2), ...]
                  |                      \
         [Comp(f1), Comp(f2), ...]        Tuple
                                          |
                                     [Call,                  Call, ...]
                                     /    \                 /    \
                               Sel(0)      Sel(0)     Sel(1)      Sel(1)
                              /           /          /           /
                       Ref(fn)    Ref(arg)    Ref(fn)    Ref(arg)

  Intrinsic(<
    (let fn=<f1, f2> in (arg -> <fn[0](arg[0]), fn[1](arg[1])>)),
    <v1, v2>,
  >)

  The functional computations `f1`, `f2`, etc..., and the computations `v1`,
  `v2`, etc... are retained; the other computations are replaced. Empty tuples
  never match the pattern and are left untouched.

  NOTE: This is just an example of what this transformation would look like when
  applied to a tuple of federated maps. The components `f1`, `f2`, `v1`, and
  `v2` and the number of those components are not important.

  This transformation is implemented to match the following intrinsics:

  * intrinsic_defs.FEDERATED_AGGREGATE.uri
  * intrinsic_defs.FEDERATED_BROADCAST.uri
  * intrinsic_defs.FEDERATED_MAP.uri

  Args:
    comp: The computation building block in which to perform the merges.
    uri: The URI of the intrinsic to merge.

  Returns:
    A new computation with the transformation applied or the original `comp`.

  Raises:
    TypeError: If types do not match.
    ValueError: If `uri` is not one of the supported intrinsic URIs listed
      above.
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(uri, six.string_types)
  expected_uri = (
      intrinsic_defs.FEDERATED_AGGREGATE.uri,
      intrinsic_defs.FEDERATED_BROADCAST.uri,
      intrinsic_defs.FEDERATED_MAP.uri,
  )
  if uri not in expected_uri:
    raise ValueError(
        'The value of `uri` is expected to be one of {}, found {}'.format(
            expected_uri, uri))

  def _should_transform(comp):
    """Returns `True` if `comp` is a non-empty tuple of calls to `uri`."""
    # An empty tuple (`<>`) has no first element; without the length check
    # `comp[0]` would raise as soon as the traversal reached one.
    return (isinstance(comp, computation_building_blocks.Tuple) and
            len(comp) > 0 and _is_called_intrinsic(comp[0], uri) and all(
                _is_called_intrinsic(element, comp[0].function.uri)
                for element in comp))

  def _transform_functional_args(comps):
    r"""Transforms the functional computations `comps`.

    Given a computation containing `n` called intrinsics with `m` arguments,
    this function constructs the following computation from the functional
    arguments of the called intrinsic:

                    Block
                   /     \
         [fn=Tuple]       Lambda(arg)
             |                       \
    [Comp(f1), Comp(f2), ...]         Tuple
                                      |
                                 [Call,                  Call, ...]
                                 /    \                 /    \
                           Sel(0)      Sel(0)     Sel(1)      Sel(1)
                          /           /          /           /
                   Ref(fn)    Ref(arg)    Ref(fn)    Ref(arg)

    with one `computation_building_blocks.Call` for each `n`. This computation
    represents one of `m` arguments that should be passed to the call of the
    transformed computation.

    NOTE(review): the merged lambda's parameter is the tuple of each
    function's own parameter type (`<P1, P2, ...>`), and `fn[i]` receives
    `arg[i]`. For functions whose parameter is itself a pair (presumably the
    `accumulate`/`merge` arguments of `FEDERATED_AGGREGATE`, e.g.
    `(<A, T> -> A)`), this yields `<<A1, T1>, <A2, T2>>`, which may not match
    how the merged intrinsic pairs its zipped arguments — confirm.

    Args:
      comps: a Python list of computations.

    Returns:
      A `computation_building_blocks.Block`.
    """
    functions = computation_building_blocks.Tuple(comps)
    fn = computation_building_blocks.Reference('fn', functions.type_signature)
    arg_type = [element.type_signature.parameter for element in comps]
    arg = computation_building_blocks.Reference('arg', arg_type)
    elements = []
    for index in range(len(comps)):
      sel_fn = computation_building_blocks.Selection(fn, index=index)
      sel_arg = computation_building_blocks.Selection(arg, index=index)
      call = computation_building_blocks.Call(sel_fn, sel_arg)
      elements.append(call)
    calls = computation_building_blocks.Tuple(elements)
    lam = computation_building_blocks.Lambda(arg.name, arg.type_signature,
                                             calls)
    return computation_building_blocks.Block((('fn', functions),), lam)

  def _transform_non_functional_args(comps):
    r"""Transforms the non-functional computations `comps`.

    Given a computation containing `n` called intrinsics with `m` arguments,
    this function constructs the following computation from the non-functional
    arguments of the called intrinsic:

    federated_zip(Tuple)
                  |
                  [Comp, Comp, ...]

    or

    Tuple
    |
    [Comp, Comp, ...]

    with one `computation_building_blocks.ComputationBuildingBlock` for each
    `n`. This computation represents one of `m` arguments that should be passed
    to the call of the transformed computation.

    Args:
      comps: A Python list of computations.

    Returns:
      A called `federated_zip` of the tuple of `comps` when the first
      computation has a `computation_types.FederatedType`; otherwise the
      `computation_building_blocks.Tuple` of `comps` itself.
    """
    values = computation_building_blocks.Tuple(comps)
    first_comp = comps[0]
    if isinstance(first_comp.type_signature, computation_types.FederatedType):
      return computation_constructing_utils.create_federated_zip(values)
    else:
      return values

  def _transform_args(comp):
    """Transforms the arguments from `comp`.

    Given a computation containing a tuple of intrinsics that can be merged,
    this function constructs the following computation from the arguments of
    the called intrinsic:

    Tuple
    |
    [Block, federated_zip(Tuple), ...]

    with one `computation_building_blocks.Block` for each functional computation
    in `m` and one called federated zip (or Tuple) for each non-functional
    computation in `m`. This list of computations represent the `m` arguments
    that should be passed to the call of the transformed computation.

    Args:
      comp: The computation building block in which to perform the merges.

    Returns:
      A `computation_building_blocks.ComputationBuildingBlock` representing the
      transformed arguments from `comp`.
    """
    first_comp = comp[0]
    # Group the i-th argument of every call together; a non-tuple argument is
    # treated as a single argument position.
    if isinstance(first_comp.argument, computation_building_blocks.Tuple):
      comps = [[] for _ in range(len(first_comp.argument))]
      for _, call in anonymous_tuple.to_elements(comp):
        for index, arg in enumerate(call.argument):
          comps[index].append(arg)
    else:
      comps = [[]]
      for _, call in anonymous_tuple.to_elements(comp):
        comps[0].append(call.argument)
    elements = []
    for args in comps:
      first_args = args[0]
      if isinstance(first_args.type_signature, computation_types.FunctionType):
        transformed_args = _transform_functional_args(args)
      else:
        transformed_args = _transform_non_functional_args(args)
      elements.append(transformed_args)
    if isinstance(first_comp.argument, computation_building_blocks.Tuple):
      return computation_building_blocks.Tuple(elements)
    else:
      return elements[0]

  def _transform(comp):
    """Returns a `(computation, modified)` pair for `comp`."""
    if not _should_transform(comp):
      return comp, False
    arg = _transform_args(comp)
    first_comp = comp[0]
    named_comps = anonymous_tuple.to_elements(comp)
    parameter_type = computation_types.to_type(arg.type_signature)
    # The merged intrinsic yields one federated value whose member is the tuple
    # of the individual calls' members; it is unzipped and renamed afterwards.
    type_signature = [call.type_signature.member for _, call in named_comps]
    result_type = computation_types.FederatedType(
        type_signature, first_comp.type_signature.placement,
        first_comp.type_signature.all_equal)
    intrinsic_type = computation_types.FunctionType(parameter_type, result_type)
    intrinsic = computation_building_blocks.Intrinsic(first_comp.function.uri,
                                                      intrinsic_type)
    call = computation_building_blocks.Call(intrinsic, arg)
    tup = computation_constructing_utils.create_federated_unzip(call)
    names = [name for name, _ in named_comps]
    transformed_comp = computation_constructing_utils.create_named_tuple(
        tup, names)
    return transformed_comp, True

  return transformation_utils.transform_postorder(comp, _transform)
예제 #20
0
def extract_intrinsics(comp):
  r"""Extracts intrinsics to the scope which binds any variable it depends on.

  This transform traverses `comp` postorder, matches the following pattern, and
  replaces the following computation containing a called intrinsic:

        ...
           \
            Call
           /    \
  Intrinsic      ...

  with the following computation containing a block with the extracted called
  intrinsic:

                  Block
                 /     \
         [x=Call]       ...
           /    \          \
  Intrinsic      ...        Ref(x)

  The called intrinsics are extracted to the scope which binds any variable the
  called intrinsic depends on. If the called intrinsic is not bound by any
  computation in `comp` it will be extracted to the root. Both the
  `parameter_name` of a `computation_building_blocks.Lambda` and the name of any
  variable defined by a `computation_building_blocks.Block` can affect the scope
  in which a reference in called intrinsic is bound.

  NOTE: This function will also extract blocks to the scope in which they are
  bound because block variables can restrict the scope in which intrinsics are
  bound.

  Args:
    comp: The computation building block in which to perform the extractions.
      The names of lambda parameters and locals in blocks in `comp` must be
      unique.

  Returns:
    A new computation with the transformation applied or the original `comp`.

  Raises:
    TypeError: If types do not match.
    ValueError: Presumably raised by `_check_has_unique_names` when the names
      of lambda parameters and block locals in `comp` are not unique (confirm
      against that helper).
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)
  _check_has_unique_names(comp)
  name_generator = computation_constructing_utils.unique_name_generator(comp)
  unbound_references = _get_unbound_references(comp)

  def _contains_unbound_reference(comp, names):
    """Returns `True` if `comp` contains unbound references to `names`.

    This function will update the non-local `unbound_references` captured from
    the parent context if `comp` is not contained in that collection. This can
    happen when new computations are created and added to the AST.

    Args:
      comp: The computation building block to test.
      names: A Python string or a list, tuple, or set of Python strings.

    Returns:
      `True` if any name in `names` occurs unbound in `comp`.
    """
    if isinstance(names, six.string_types):
      names = (names,)
    if comp not in unbound_references:
      references = _get_unbound_references(comp)
      unbound_references.update(references)
    return any(n in unbound_references[comp] for n in names)

  def _is_called_intrinsic_or_block(comp):
    """Returns `True` if `comp` is a called intrinsic or a block."""
    return (_is_called_intrinsic(comp) or
            isinstance(comp, computation_building_blocks.Block))

  def _should_transform(comp):
    """Returns `True` if `comp` should be transformed.

    The following `_extract_intrinsic_*` methods all depend on being invoked
    after `_should_transform` evaluates to `True` for a given `comp`. Because of
    this certain assumptions are made:

    * transformation functions will transform a given `comp`
    * block variables are guaranteed to not be empty

    Args:
      comp: The computation building block in which to test.
    """
    if isinstance(comp, computation_building_blocks.Block):
      return (_is_called_intrinsic_or_block(comp.result) or any(
          isinstance(e, computation_building_blocks.Block)
          for _, e in comp.locals))
    elif isinstance(comp, computation_building_blocks.Call):
      return _is_called_intrinsic_or_block(comp.argument)
    elif isinstance(comp, computation_building_blocks.Lambda):
      if _is_called_intrinsic(comp.result):
        return True
      if isinstance(comp.result, computation_building_blocks.Block):
        # A local can be hoisted out of the lambda when it references neither
        # the lambda parameter nor any local declared before it.
        for index, (_, variable) in enumerate(comp.result.locals):
          names = [n for n, _ in comp.result.locals[:index]]
          if (not _contains_unbound_reference(variable, comp.parameter_name) and
              not _contains_unbound_reference(variable, names)):
            return True
    elif isinstance(comp, computation_building_blocks.Selection):
      return _is_called_intrinsic_or_block(comp.source)
    elif isinstance(comp, computation_building_blocks.Tuple):
      return any(_is_called_intrinsic_or_block(e) for e in comp)
    return False

  def _extract_from_block(comp):
    """Returns a new computation with all intrinsics extracted."""
    if _is_called_intrinsic(comp.result):
      called_intrinsic = comp.result
      name = six.next(name_generator)
      # Copy before appending so the input block's locals are never mutated,
      # regardless of whether `locals` hands back a fresh list.
      variables = list(comp.locals)
      variables.append((name, called_intrinsic))
      result = computation_building_blocks.Reference(
          name, called_intrinsic.type_signature)
      return computation_building_blocks.Block(variables, result)
    elif isinstance(comp.result, computation_building_blocks.Block):
      # Flatten a block whose result is itself a block.
      return computation_building_blocks.Block(comp.locals + comp.result.locals,
                                               comp.result.result)
    else:
      # Flatten any locals whose value is a block into this block's locals.
      variables = []
      for name, variable in comp.locals:
        if isinstance(variable, computation_building_blocks.Block):
          variables.extend(variable.locals)
          variables.append((name, variable.result))
        else:
          variables.append((name, variable))
      return computation_building_blocks.Block(variables, comp.result)

  def _extract_from_call(comp):
    """Returns a new computation with all intrinsics extracted."""
    if _is_called_intrinsic(comp.argument):
      called_intrinsic = comp.argument
      name = six.next(name_generator)
      variables = ((name, called_intrinsic),)
      result = computation_building_blocks.Reference(
          name, called_intrinsic.type_signature)
    else:
      block = comp.argument
      variables = block.locals
      result = block.result
    call = computation_building_blocks.Call(comp.function, result)
    block = computation_building_blocks.Block(variables, call)
    return _extract_from_block(block)

  def _extract_from_lambda(comp):
    """Returns a new computation with all intrinsics extracted."""
    if _is_called_intrinsic(comp.result):
      called_intrinsic = comp.result
      name = six.next(name_generator)
      variables = ((name, called_intrinsic),)
      ref = computation_building_blocks.Reference(
          name, called_intrinsic.type_signature)
      if not _contains_unbound_reference(comp.result, comp.parameter_name):
        # The intrinsic does not use the parameter: hoist it above the lambda.
        fn = computation_building_blocks.Lambda(comp.parameter_name,
                                                comp.parameter_type, ref)
        return computation_building_blocks.Block(variables, fn)
      else:
        # The intrinsic depends on the parameter: bind it inside the lambda.
        block = computation_building_blocks.Block(variables, ref)
        return computation_building_blocks.Lambda(comp.parameter_name,
                                                  comp.parameter_type, block)
    else:
      block = comp.result
      extracted_variables = []
      retained_variables = []
      for name, variable in block.locals:
        names = [n for n, _ in retained_variables]
        if (not _contains_unbound_reference(variable, comp.parameter_name) and
            not _contains_unbound_reference(variable, names)):
          extracted_variables.append((name, variable))
        else:
          retained_variables.append((name, variable))
      if retained_variables:
        result = computation_building_blocks.Block(retained_variables,
                                                   block.result)
      else:
        result = block.result
      fn = computation_building_blocks.Lambda(comp.parameter_name,
                                              comp.parameter_type, result)
      block = computation_building_blocks.Block(extracted_variables, fn)
      return _extract_from_block(block)

  def _extract_from_selection(comp):
    """Returns a new computation with all intrinsics extracted."""
    if _is_called_intrinsic(comp.source):
      called_intrinsic = comp.source
      name = six.next(name_generator)
      variables = ((name, called_intrinsic),)
      result = computation_building_blocks.Reference(
          name, called_intrinsic.type_signature)
    else:
      block = comp.source
      variables = block.locals
      result = block.result
    selection = computation_building_blocks.Selection(
        result, name=comp.name, index=comp.index)
    block = computation_building_blocks.Block(variables, selection)
    return _extract_from_block(block)

  def _extract_from_tuple(comp):
    """Returns a new computation with all intrinsics extracted."""
    variables = []
    elements = []
    for name, element in anonymous_tuple.to_elements(comp):
      if _is_called_intrinsic_or_block(element):
        variable_name = six.next(name_generator)
        variables.append((variable_name, element))
        ref = computation_building_blocks.Reference(variable_name,
                                                    element.type_signature)
        elements.append((name, ref))
      else:
        elements.append((name, element))
    tup = computation_building_blocks.Tuple(elements)
    block = computation_building_blocks.Block(variables, tup)
    return _extract_from_block(block)

  def _transform(comp):
    """Returns a `(computation, modified)` pair for `comp`."""
    if not _should_transform(comp):
      return comp, False
    if isinstance(comp, computation_building_blocks.Block):
      comp = _extract_from_block(comp)
    elif isinstance(comp, computation_building_blocks.Call):
      comp = _extract_from_call(comp)
    elif isinstance(comp, computation_building_blocks.Lambda):
      comp = _extract_from_lambda(comp)
    elif isinstance(comp, computation_building_blocks.Selection):
      comp = _extract_from_selection(comp)
    elif isinstance(comp, computation_building_blocks.Tuple):
      comp = _extract_from_tuple(comp)
    return comp, True

  return transformation_utils.transform_postorder(comp, _transform)
예제 #21
0
def merge_chained_federated_maps_or_applys(comp):
  r"""Fuses back-to-back federated maps (or applies) in `comp` into one call.

  Walking `comp` in postorder, every call to a `FEDERATED_MAP` or
  `FEDERATED_APPLY` intrinsic whose value argument is itself a call to the same
  intrinsic:

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Comp(x), Call]
                          /    \
                 Intrinsic      Tuple
                                |
                                [Comp(y), Comp(z)]

  intrinsic(<x, intrinsic(<y, z>)>)

  is rewritten into a single call of that intrinsic whose function composes
  `y` then `x`:

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Block, Comp(z)]
                 /     \
       [fn=Tuple]       Lambda(arg)
           |                       \
   [Comp(y), Comp(x)]               Call
                                   /    \
                             Sel(1)      Call
                            /           /    \
                     Ref(fn)      Sel(0)      Ref(arg)
                                 /
                          Ref(fn)

  intrinsic(<(let fn=<y, x> in (arg -> fn[1](fn[0](arg)))), z>)

  The functions `x` and `y` and the value `z` are kept as-is; everything else
  is rebuilt.

  Args:
    comp: The computation building block in which to perform the merges.

  Returns:
    A new computation with the transformation applied or the original `comp`.

  Raises:
    TypeError: If types do not match.
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)

  # Intrinsics whose chained calls can be fused.
  mergeable_uris = (
      intrinsic_defs.FEDERATED_APPLY.uri,
      intrinsic_defs.FEDERATED_MAP.uri,
  )

  def _is_chained_call(node):
    """Returns `True` if `node` maps/applies over a call to the same intrinsic."""
    if not _is_called_intrinsic(node, mergeable_uris):
      return False
    inner_value = node.argument[1]
    return _is_called_intrinsic(inner_value, node.function.uri)

  def _chain_functions(fns):
    r"""Builds a block that applies `fns` in order, first to last.

    For `fns == (y, x)` the result is

      (let fn=<y, x> in (arg -> fn[1](fn[0](arg))))

    Args:
      fns: A Python sequence of functional computations.

    Returns:
      A `computation_building_blocks.Block`.
    """
    functions = computation_building_blocks.Tuple(fns)
    fn_ref = computation_building_blocks.Reference('fn',
                                                   functions.type_signature)
    param_ref = computation_building_blocks.Reference(
        'arg', fns[0].type_signature.parameter)
    # Wrap the parameter in successive calls: fn[0](arg), fn[1](fn[0](arg)), ...
    body = param_ref
    for position in range(len(fns)):
      selected_fn = computation_building_blocks.Selection(fn_ref, index=position)
      body = computation_building_blocks.Call(selected_fn, body)
    chained = computation_building_blocks.Lambda(param_ref.name,
                                                 param_ref.type_signature, body)
    return computation_building_blocks.Block([('fn', functions)], chained)

  def _transform(comp):
    """Returns a `(computation, modified)` pair for `comp`."""
    if not _is_chained_call(comp):
      return comp, False
    inner_call = comp.argument[1]
    # The inner function runs first, so it takes position 0 in the chain.
    merged_fn = _chain_functions((
        inner_call.argument[0],
        comp.argument[0],
    ))
    new_arg = computation_building_blocks.Tuple([
        merged_fn,
        inner_call.argument[1],
    ])
    new_type = computation_types.FunctionType(
        new_arg.type_signature, comp.function.type_signature.result)
    new_intrinsic = computation_building_blocks.Intrinsic(comp.function.uri,
                                                          new_type)
    return computation_building_blocks.Call(new_intrinsic, new_arg), True

  return transformation_utils.transform_postorder(comp, _transform)