Esempio n. 1
0
def remove_called_lambdas_and_blocks(comp):
  """Removes any called lambdas and blocks from `comp`.

  This function first resolves any higher-order functions, so that replacing
  called lambdas with blocks and then inlining the block locals cannot result
  in more called lambdas. It then performs this sequence of transformations,
  taking care to inline selections from tuples before inlining the rest of
  the block locals to prevent possible combinatorial growth of the generated
  AST. Because inlining block locals may itself reintroduce
  selection-from-tuple patterns, selections are collapsed once more at the
  end so that the postcondition documented below actually holds.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` from which we
      want to remove called lambdas and blocks.

  Returns:
    A tuple `(transformed_comp, modified)`. `transformed_comp` is a
    transformed version of `comp` which has no called lambdas or blocks, and
    no extraneous selections from tuples; `modified` is truthy if any of the
    transformation passes reported a change.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`
      (raised by `py_typecheck.check_type`).
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  comp, names_uniquified = tree_transformations.uniquify_reference_names(comp)
  comp, fns_resolved = tree_transformations.resolve_higher_order_functions(comp)
  comp, lambdas_replaced = tree_transformations.replace_called_lambda_with_block(
      comp)
  if fns_resolved or lambdas_replaced:
    # Re-uniquify after rewriting so the following passes see unique names.
    comp, _ = tree_transformations.uniquify_reference_names(comp)
  # Collapse selections before inlining locals to avoid blowing up the AST.
  comp, sels_removed = tree_transformations.inline_selections_from_tuple(comp)
  if sels_removed:
    comp, _ = tree_transformations.uniquify_reference_names(comp)
  comp, locals_inlined = tree_transformations.inline_block_locals(comp)
  if locals_inlined:
    # Inlining local symbols may reintroduce the selection-from-tuple pattern;
    # collapse it again so the "no extraneous selections" postcondition holds.
    comp, _ = tree_transformations.inline_selections_from_tuple(comp)
    comp, _ = tree_transformations.uniquify_reference_names(comp)
  modified = (names_uniquified or fns_resolved or lambdas_replaced or
              sels_removed or locals_inlined)
  return comp, modified
Esempio n. 2
0
def remove_called_lambdas_and_blocks(comp):
    """Removes any called lambdas and blocks from `comp`.

    This function first resolves any higher-order functions, so that replacing
    called lambdas with blocks and then inlining the block locals cannot result
    in more called lambdas. It then performs this sequence of transformations,
    taking care to inline selections from tuples at appropriate stages to
    prevent possible combinatorial growth of the generated AST.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` from which we
        want to remove called lambdas and blocks.

    Returns:
      A tuple `(transformed_comp, modified)`. `transformed_comp` is a
      transformed version of `comp` which has no called lambdas or blocks, and
      no extraneous selections from tuples; `modified` is truthy if any of the
      transformation passes reported a change.

    Raises:
      TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`
        (raised by `py_typecheck.check_type`).
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    comp, names_uniquified = tree_transformations.uniquify_reference_names(
        comp)
    # TODO(b/162888191): Remove this gating when `resolve_higher_order_functions`
    # is more efficient, or avoided.
    if _contains_higher_order_fns(comp):
        # `resolve_higher_order_functions` can be expensive, so we only call into it
        # when necessary.
        comp, fns_resolved = tree_transformations.resolve_higher_order_functions(
            comp)
    else:
        # We must still inline any functional references. We first inline selections
        # from tuples to prevent the AST from becoming unnecessarily large.
        comp, sels_removed = tree_transformations.inline_selections_from_tuple(
            comp)
        if sels_removed:
            comp, _ = tree_transformations.uniquify_reference_names(comp)
        comp, locals_inlined = tree_transformations.inline_block_locals(comp)
        # Record whether this branch changed anything; `sels_removed` and
        # `locals_inlined` are reassigned by the passes below.
        fns_resolved = sels_removed or locals_inlined
    comp, lambdas_replaced = tree_transformations.replace_called_lambda_with_block(
        comp)
    if fns_resolved or lambdas_replaced:
        comp, _ = tree_transformations.uniquify_reference_names(comp)
    comp, sels_removed = tree_transformations.inline_selections_from_tuple(
        comp)
    if sels_removed:
        comp, _ = tree_transformations.uniquify_reference_names(comp)
    comp, locals_inlined = tree_transformations.inline_block_locals(comp)
    if locals_inlined:
        # Inlining local symbols may reintroduce selection-from-tuple pattern,
        # combinatorially increasing build times in the worst case. We ensure
        # here that remove_called_lambdas_and_blocks respects the postcondition that
        # selections from tuples are always collapsed.
        comp, _ = tree_transformations.inline_selections_from_tuple(comp)
        comp, _ = tree_transformations.uniquify_reference_names(comp)
    modified = (names_uniquified or fns_resolved or lambdas_replaced or
                sels_removed or locals_inlined)
    return comp, modified
  def test_returns_tree(self):
    """Checks the before/after-aggregate split matches the secure-sum split."""
    ip = get_iterative_process_for_sum_example_with_no_federated_aggregate()
    next_tree = building_blocks.ComputationBuildingBlock.from_proto(
        ip.next._computation_proto)

    before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_aggregate(
        next_tree)

    before_federated_secure_sum, after_federated_secure_sum = (
        transformations.force_align_and_split_by_intrinsics(
            next_tree, [intrinsic_defs.FEDERATED_SECURE_SUM.uri]))
    self.assertIsInstance(before_aggregate, building_blocks.Lambda)
    self.assertIsInstance(before_aggregate.result, building_blocks.Tuple)
    self.assertLen(before_aggregate.result, 2)

    # pyformat: disable
    self.assertEqual(
        before_aggregate.result[0].formatted_representation(),
        '<\n'
        '  federated_value_at_clients(<>),\n'
        '  <>,\n'
        '  (_var1 -> <>),\n'
        '  (_var2 -> <>),\n'
        '  (_var3 -> <>)\n'
        '>'
    )
    # pyformat: enable

    self.assertEqual(
        before_aggregate.result[1].formatted_representation(),
        before_federated_secure_sum.result.formatted_representation())

    self.assertIsInstance(after_aggregate, building_blocks.Lambda)
    self.assertIsInstance(after_aggregate.result, building_blocks.Call)
    # Names are uniquified on both sides so the comparison ignores naming.
    actual_tree, _ = tree_transformations.uniquify_reference_names(
        after_aggregate.result.function)
    expected_tree, _ = tree_transformations.uniquify_reference_names(
        after_federated_secure_sum)
    self.assertEqual(actual_tree.formatted_representation(),
                     expected_tree.formatted_representation())

    # pyformat: disable
    self.assertEqual(
        after_aggregate.result.argument.formatted_representation(),
        '<\n'
        '  _var4[0],\n'
        '  _var4[1][1]\n'
        '>'
    )
    # pyformat: enable
Esempio n. 4
0
def get_map_reduce_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.MapReduceForm:
    """Constructs `tff.backends.mapreduce.MapReduceForm` given iterative process.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` that is compatible
        with MapReduce form. Iterative processes are only compatible if
        `initialize_fn` returns a single federated value placed at `SERVER`,
        and `next` takes exactly two arguments (the first of which must be the
        state value placed at `SERVER`) and returns exactly two values.
      grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
        configure Grappler graph optimization of the TensorFlow graphs backing
        the resulting `tff.backends.mapreduce.MapReduceForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If the input `grappler_config` has
        `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
        is bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.MapReduceForm` equivalent to the
      provided `tff.templates.IterativeProcess`.

    Raises:
      TypeError: If the arguments are of the wrong types.
      transformations.MapReduceFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    initialize_bb, next_bb = (
        check_iterative_process_compatible_with_map_reduce_form(ip))
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    # Split `next` into the stages around the broadcast and aggregate points.
    next_bb, _ = tree_transformations.uniquify_reference_names(next_bb)
    before_broadcast, after_broadcast = _split_ast_on_broadcast(next_bb)
    before_aggregate, after_aggregate = _split_ast_on_aggregate(
        after_broadcast)

    # Extract each MapReduceForm component from the relevant stage.
    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_bb, grappler_config)
    prepare = _extract_prepare(before_broadcast, grappler_config)
    work = _extract_work(before_aggregate, grappler_config)
    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate, grappler_config)
    bitwidth = _extract_federated_secure_sum_bitwidth_functions(
        before_aggregate, grappler_config)
    update = _extract_update(after_aggregate, grappler_config)

    # Unpacking into two labels relies on `next` having exactly two parameters.
    next_parameter_names = structure.name_list_with_nones(
        ip.next.type_signature.parameter)
    server_state_label, client_data_label = next_parameter_names
    comps = (computation_wrapper_instances.building_block_to_computation(bb)
             for bb in (initialize, prepare, work, zero, accumulate, merge,
                        report, bitwidth, update))
    return forms.MapReduceForm(*comps,
                               server_state_label=server_state_label,
                               client_data_label=client_data_label)
Esempio n. 5
0
def concatenate_function_outputs(first_function, second_function):
    """Constructs a new function concatenating the outputs of its arguments.

    Assumes that `first_function` and `second_function` already have unique
    names, and have declared parameters of the same type. The constructed
    function will bind its parameter to each of the parameters of
    `first_function` and `second_function`, and return the result of executing
    these functions in parallel and concatenating the outputs in a tuple.

    Args:
      first_function: Instance of `building_blocks.Lambda` whose result we wish
        to concatenate with the result of `second_function`.
      second_function: Instance of `building_blocks.Lambda` whose result we
        wish to concatenate with the result of `first_function`.

    Returns:
      A new instance of `building_blocks.Lambda` with unique names representing
      the computation described above.

    Raises:
      TypeError: If the arguments are not instances of `building_blocks.Lambda`,
        or declare parameters of different types.
    """

    py_typecheck.check_type(first_function, building_blocks.Lambda)
    py_typecheck.check_type(second_function, building_blocks.Lambda)
    tree_analysis.check_has_unique_names(first_function)
    tree_analysis.check_has_unique_names(second_function)

    if first_function.parameter_type != second_function.parameter_type:
        # Report the parameter types that were compared, not the full function
        # signatures, so the message matches its own wording.
        raise TypeError(
            'Must pass two functions which declare the same parameter '
            'type to `concatenate_function_outputs`; you have passed '
            'one function which declared a parameter of type {}, and '
            'another which declares a parameter of type {}'.format(
                first_function.parameter_type,
                second_function.parameter_type))

    def _rename_first_function_arg(comp):
        # Rebinds references to the first function's parameter so that both
        # bodies refer to the single parameter of the concatenated lambda.
        if comp.is_reference() and comp.name == first_function.parameter_name:
            if comp.type_signature != second_function.parameter_type:
                # Internal invariant: guaranteed by the parameter-type check.
                raise AssertionError('{}, {}'.format(
                    comp.type_signature, second_function.parameter_type))
            return building_blocks.Reference(second_function.parameter_name,
                                             comp.type_signature), True
        return comp, False

    first_function, _ = transformation_utils.transform_postorder(
        first_function, _rename_first_function_arg)

    concatenated_function = building_blocks.Lambda(
        second_function.parameter_name, second_function.parameter_type,
        building_blocks.Struct([first_function.result,
                                second_function.result]))

    renamed, _ = tree_transformations.uniquify_reference_names(
        concatenated_function)

    return renamed
Esempio n. 6
0
    def _remove_functional_symbol_bindings(comp):
        """Strips symbol bindings whose values are of functional type.

        The passes run in a fixed order: uniquify reference names, replace
        called lambdas with blocks, inline selections from tuples (re-uniquifying
        if any were inlined), `_inline_functions`, then drop unused block locals.

        Returns:
          A `(comp, modified)` pair, where `modified` is truthy if any pass
          reported a change.
        """

        comp, renamed = tree_transformations.uniquify_reference_names(comp)
        comp, replaced = (
            tree_transformations.replace_called_lambda_with_block(comp))
        comp, inlined_selections = (
            tree_transformations.inline_selections_from_tuple(comp))
        if inlined_selections:
            comp, _ = tree_transformations.uniquify_reference_names(comp)
        comp, inlined_functions = _inline_functions(comp)
        comp, removed_locals = (
            tree_transformations.remove_unused_block_locals(comp))

        return comp, (renamed or replaced or inlined_selections
                      or inlined_functions or removed_locals)
Esempio n. 7
0
def _force_align_intrinsics_to_top_level_lambda(comp, uri):
    """Forcefully aligns `comp` by the intrinsics for the given `uri`.

    The intrinsics matching `uri` are extracted to the top-level lambda. When
    that succeeds, they are grouped (if several URIs are given) and each URI's
    tuple intrinsics are deduped and merged. If merging changed the tree, the
    now-nested intrinsics are extracted once more. The result should contain
    exactly one instance of the intrinsic for the given `uri`, bound only by
    the `parameter_name` of `comp`.

    Args:
      comp: The `building_blocks.Lambda` to align.
      uri: A Python `list` of intrinsic URIs.

    Returns:
      A new computation with the transformation applied or the original `comp`.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(uri, list)
    for intrinsic_uri in uri:
        py_typecheck.check_type(intrinsic_uri, str)

    comp, _ = tree_transformations.uniquify_reference_names(comp)
    if not _can_extract_intrinsics_to_top_level_lambda(comp, uri):
        comp, _ = tree_transformations.replace_called_lambda_with_block(comp)
    comp = _inline_block_variables_required_to_align_intrinsics(comp, uri)
    comp, extracted = _extract_intrinsics_to_top_level_lambda(comp, uri)
    if not extracted:
        return comp

    if len(uri) > 1:
        comp, _ = _group_by_intrinsics_in_top_level_lambda(comp)

    any_merged = False
    for intrinsic_uri in uri:
        comp, merged = transformations.dedupe_and_merge_tuple_intrinsics(
            comp, intrinsic_uri)
        if merged:
            # Merging goes through building block factories that do not name
            # references uniquely, so restore unique names.
            comp, _ = tree_transformations.uniquify_reference_names(comp)
        any_merged = any_merged or merged

    if any_merged:
        # Merged intrinsics end up nested and can no longer be split until
        # they are extracted again.
        comp, _ = _extract_intrinsics_to_top_level_lambda(comp, uri)
    return comp
Esempio n. 8
0
def _inline_block_variables_required_to_align_intrinsics(comp, uri):
    """Inlines the variables required to align the intrinsic for the given `uri`.

    This function inlines only the block variables required to align an
    intrinsic, which is necessary because many transformations insert block
    variables that do not impact alignment and should not be inlined.

    Additionally, this function iteratively attempts to inline block variables
    as long as the intrinsic cannot be extracted to the top-level lambda. This
    means that unbound references in variables that are inlined will also be
    inlined.

    Args:
      comp: The `building_blocks.Lambda` to transform.
      uri: A Python `list` of intrinsic URIs.

    Returns:
      A new computation with the transformation applied or the original `comp`.

    Raises:
      tree_transformations.TransformationError: If the called intrinsics
        matching `uri` have no unbound references other than the parameter of
        `comp`, or if inlining the block variables they depend on does not
        modify the AST (which would otherwise loop forever).
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(uri, list)
    for x in uri:
        py_typecheck.check_type(x, str)

    while not _can_extract_intrinsics_to_top_level_lambda(comp, uri):
        unbound_references = transformation_utils.get_map_of_unbound_references(
            comp)
        variable_names = set()
        for intrinsic in _get_called_intrinsics(comp, uri):
            # The lambda's own parameter is not a block variable. Take the set
            # difference instead of `discard`, so the sets held in
            # `unbound_references` are not mutated.
            variable_names.update(
                unbound_references[intrinsic] - {comp.parameter_name})
        if not variable_names:
            raise tree_transformations.TransformationError(
                'Inlining `Block` variables has failed. Expected to find unbound '
                'references for called `Intrisic`s matching the URI: \'{}\', but '
                'none were found in the AST: \n{}'.format(
                    uri, comp.formatted_representation()))
        comp, modified = tree_transformations.inline_block_locals(
            comp, variable_names=variable_names)
        if modified:
            comp, _ = tree_transformations.uniquify_reference_names(comp)
        else:
            raise tree_transformations.TransformationError(
                'Inlining `Block` variables has failed, this will result in an '
                'infinite loop. Expected to modify the AST by inlining the variable '
                'names: \'{}\', but no transformations to the AST: \n{}'.
                format(variable_names, comp.formatted_representation()))
    return comp
Esempio n. 9
0
    def test_keeps_existing_nonoverlapping_names(self):
        """Distinct, non-shadowing names come back unchanged."""
        data = building_blocks.Data('data', tf.int32)
        block = building_blocks.Block([('a', data), ('b', data)], data)

        result, was_modified = tree_transformations.uniquify_reference_names(
            block)

        expected = '(let a=data,b=data in data)'
        self.assertEqual(block.compact_representation(), expected)
        self.assertEqual(result.compact_representation(), expected)
        self.assertFalse(was_modified)
Esempio n. 10
0
    def test_single_level_block(self):
        """Repeated bindings of one name within a single block are renamed."""
        a_ref = building_blocks.Reference('a', tf.int32)
        data = building_blocks.Data('data', tf.int32)
        bindings = (('a', data), ('a', a_ref), ('a', a_ref))
        block = building_blocks.Block(bindings, a_ref)

        result, was_modified = tree_transformations.uniquify_reference_names(
            block)

        self.assertEqual(block.compact_representation(),
                         '(let a=data,a=a,a=a in a)')
        self.assertEqual(result.compact_representation(),
                         '(let a=data,_var1=a,_var2=_var1 in _var2)')
        tree_analysis.check_has_unique_names(result)
        self.assertTrue(was_modified)
Esempio n. 11
0
    def test_nested_blocks(self):
        """Shadowing across nested blocks is resolved with fresh names."""
        a_ref = building_blocks.Reference('a', tf.int32)
        data = building_blocks.Data('data', tf.int32)
        inner = building_blocks.Block([('a', data), ('a', a_ref)], a_ref)
        outer = building_blocks.Block([('a', data), ('a', a_ref)], inner)

        result, was_modified = tree_transformations.uniquify_reference_names(
            outer)

        self.assertEqual(outer.compact_representation(),
                         '(let a=data,a=a in (let a=data,a=a in a))')
        self.assertEqual(
            result.compact_representation(),
            '(let a=data,_var1=a in (let _var2=data,_var3=_var2 in _var3))')
        tree_analysis.check_has_unique_names(result)
        self.assertTrue(was_modified)
Esempio n. 12
0
    def test_renames_lambda_but_not_unbound_reference_when_given_name_generator(
            self):
        """A supplied name generator renames bound names but not free ones."""
        free_x = building_blocks.Reference('x', tf.int32)
        lam = building_blocks.Lambda('y', tf.float32, free_x)

        generator = building_block_factory.unique_name_generator(lam)
        renamed, was_modified = tree_transformations.uniquify_reference_names(
            lam, generator)

        self.assertEqual(lam.compact_representation(), '(y -> x)')
        self.assertEqual(renamed.compact_representation(), '(_var1 -> x)')
        self.assertEqual(renamed.type_signature, lam.type_signature)
        self.assertTrue(was_modified)
Esempio n. 13
0
def prepare_for_rebinding(comp):
    """Prepares `comp` for extracting rebound variables.

    Currently, this means replacing all called lambdas and inlining all blocks.
    This does not necessarily guarantee that the resulting computation has no
    called lambdas; it merely reduces a level of indirection here. This
    reduction has proved sufficient for identifying variables which are about
    to be rebound in the top-level lambda, in particular when compiler
    components factor work out from a single function into multiple functions.
    Since this function makes no guarantees about sufficiency, it is the
    responsibility of the caller to ensure that no unbound variables are
    introduced during the rebinding.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` from which
        all occurrences of a given variable need to be extracted and rebound.

    Returns:
      Whatever `transformation_utils.transform_postorder_with_symbol_bindings`
      returns for the inlining/collapsing transform below. The transform it is
      given returns `(comp, modified)` pairs, so this is presumably a
      `(transformed_comp, modified)` pair, where `transformed_comp` has had all
      called lambdas replaced by blocks, all blocks inlined and all selections
      from tuples collapsed. NOTE(review): confirm callers unpack a pair.

    Raises:
      TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`
        (raised by `py_typecheck.check_type`).
    """
    # TODO(b/146430051): Follow up here and consider removing or enforcing more
    # strict output invariants when `remove_called_lambdas_and_blocks` is moved
    # in here.
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    comp, _ = tree_transformations.uniquify_reference_names(comp)
    comp, _ = tree_transformations.replace_called_lambda_with_block(comp)
    # `InlineBlock` is constructed from the prepared tree before traversal.
    block_inliner = tree_transformations.InlineBlock(comp)
    selection_replacer = tree_transformations.ReplaceSelectionFromTuple()
    transforms = [block_inliner, selection_replacer]
    symbol_tree = transformation_utils.SymbolTree(
        transformation_utils.ReferenceCounter)

    def _transform_fn(comp, symbol_tree):
        """Transform function chaining inlining and collapsing selections.

        Parameter names are kept as-is (they shadow the enclosing names) in
        case the traversal passes them by keyword.
        """
        modified = False
        for transform in transforms:
            # Global transforms need the symbol bindings; local ones do not.
            if transform.global_transform:
                comp, transform_modified = transform.transform(
                    comp, symbol_tree)
            else:
                comp, transform_modified = transform.transform(comp)
            modified = modified or transform_modified
        return comp, modified

    return transformation_utils.transform_postorder_with_symbol_bindings(
        comp, _transform_fn, symbol_tree)
Esempio n. 14
0
    def test_nested_lambdas(self):
        """Already-unique nested called lambdas come back unchanged."""
        data = building_blocks.Data('data', tf.int32)
        a_ref = building_blocks.Reference('a', data.type_signature)
        inner_call = building_blocks.Call(
            building_blocks.Lambda('a', a_ref.type_signature, a_ref), data)
        b_ref = building_blocks.Reference('b', inner_call.type_signature)
        outer_call = building_blocks.Call(
            building_blocks.Lambda('b', b_ref.type_signature, b_ref),
            inner_call)

        result, was_modified = tree_transformations.uniquify_reference_names(
            outer_call)

        self.assertEqual(result.compact_representation(),
                         '(b -> b)((a -> a)(data))')
        tree_analysis.check_has_unique_names(result)
        self.assertFalse(was_modified)
Esempio n. 15
0
def generate_tensorflow_for_local_function(comp):
    """Generates TensorFlow for a local TFF computation.

    Invocations are deduplicated (per `tree_analysis.trees_equal`), so the
    number of calls under `comp` may shrink. The computation is uniquified,
    converted to local call-dominant form, packaged as a deduplicated block and
    handed to `create_tensorflow_representing_block`. A lambda is handled by
    generating TensorFlow for its body, rewrapping it and running `TFParser`.

    Args:
      comp: Instance of `building_blocks.ComputationBuildingBlock` for which we
        wish to generate TensorFlow.

    Returns:
      A `(generated, True)` pair. `generated` is either a called
      `building_blocks.CompiledComputation` or a
      `building_blocks.CompiledComputation` itself, for non-functional and
      functional `comp` respectively. The constant `True` matches the
      `transformation_utils.TransformSpec` pattern.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    uniquified, _ = tree_transformations.uniquify_reference_names(comp)
    local_comp = transform_to_local_call_dominant(uniquified)

    def _as_deduplicated_block(inner):
        # Always hand a Block to the TF generator, even if dedup produced none.
        deduped, _ = tree_transformations.remove_duplicate_block_locals(inner)
        if deduped.is_block():
            return deduped
        return building_blocks.Block([], deduped)

    if not local_comp.is_lambda():
        block = _as_deduplicated_block(local_comp)
        generated, _ = create_tensorflow_representing_block(block)
        return generated, True

    block = _as_deduplicated_block(local_comp.result)
    generated_body, _ = create_tensorflow_representing_block(block)
    wrapper = building_blocks.Lambda(local_comp.parameter_name,
                                     local_comp.parameter_type, generated_body)
    parsed, _ = transformation_utils.transform_postorder(
        wrapper, tree_to_cc_transformations.TFParser())
    return parsed, True
Esempio n. 16
0
    def test_block_lambda_block_lambda(self):
        """Alternating blocks and called lambdas all receive fresh names."""
        a_ref = building_blocks.Reference('a', tf.int32)
        identity = building_blocks.Lambda('a', tf.int32, a_ref)
        inner_call = building_blocks.Call(identity, a_ref)
        inner_block = building_blocks.Block([('a', a_ref), ('a', a_ref)],
                                            inner_call)
        middle_lambda = building_blocks.Lambda('a', tf.int32, inner_block)
        middle_call = building_blocks.Call(middle_lambda, a_ref)
        data = building_blocks.Data('data', tf.int32)
        outer_block = building_blocks.Block([('a', data), ('a', a_ref)],
                                            middle_call)

        result, was_modified = tree_transformations.uniquify_reference_names(
            outer_block)

        self.assertEqual(
            outer_block.compact_representation(),
            '(let a=data,a=a in (a -> (let a=a,a=a in (a -> a)(a)))(a))')
        self.assertEqual(
            result.compact_representation(),
            '(let a=data,_var1=a in (_var2 -> (let _var3=_var2,_var4=_var3 in (_var5 -> _var5)(_var4)))(_var1))'
        )
        tree_analysis.check_has_unique_names(result)
        self.assertTrue(was_modified)
Esempio n. 17
0
def generate_tensorflow_for_local_computation(comp):
    """Generates TensorFlow for a local TFF computation.

  This function performs a deduplication of function invocations
  according to `tree_analysis.trees_equal`, and hence may reduce the number
  of calls under `comp`.

  We assume `comp` has type which can be represented by either a call to a
  no-arg `building_blocks.CompiledComputation` of type `tensorflow`, or such a
  `building_blocks.CompiledComputation` itself. That is, the type signature of
  `comp` must be either a potentially nested structure of
  `computation_types.TensorType`s and `computation_types.SequenceType`s, or a
  function whose parameter and return types are such potentially nested
  structures.

  Further, we assume that there are no intrinsic or data building blocks inside
  `comp`.

  Processing: `comp` is given unique reference names, wrapped in a no-arg
  `building_blocks.Lambda` and converted to local call-dominant form. The
  result of that lambda is then packaged as a deduplicated
  `building_blocks.Block` and passed to `create_tensorflow_representing_block`.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` for which we
      wish to generate TensorFlow.

  Returns:
    A `(generated, modified)` pair. `generated` is either a called instance of
    `building_blocks.CompiledComputation` or a
    `building_blocks.CompiledComputation` itself, depending on whether `comp`
    is of non-functional or functional type respectively. `modified` matches
    the `transformation_utils.TransformSpec` pattern. It is `True`, except
    when the call-dominant form is already a bare compiled computation; in
    that case it is `not comp.is_compiled_computation()`.

  Raises:
    tree_transformations.TransformationError: If a functional computation in
      local call-dominant form matches none of the expected shapes.
  """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    names_uniquified, _ = tree_transformations.uniquify_reference_names(comp)
    # We ensure the argument to `transform_to_local_call_dominant` is a Lambda, as
    # required.
    lambda_wrapping_comp = building_blocks.Lambda(None, None, names_uniquified)
    # CFG for local CDF plus the type of `lambda_wrapping_comp` imply result must
    # be another no-arg lambda. Its `.result` is the call-dominant form of
    # `comp` itself.
    local_cdf_comp = transform_to_local_call_dominant(
        lambda_wrapping_comp).result

    def _package_as_deduplicated_block(inner_comp):
        # Deduplicate block locals, and always return a Block (possibly with no
        # locals), which is what `create_tensorflow_representing_block` is
        # handed below.
        repacked_block, _ = tree_transformations.remove_duplicate_block_locals(
            inner_comp)
        if not repacked_block.is_block():
            repacked_block = building_blocks.Block([], repacked_block)
        return repacked_block

    if local_cdf_comp.type_signature.is_function():
        # The CFG for local call dominant tells us that the following patterns are
        # possible for a functional computation respecting the structural
        # restrictions we require for `comp`:
        #   1. CompiledComputation
        #   2. Block(bindings, CompiledComp)
        #   3. Block(bindings, Lambda(non-functional result with at most one Block))
        #   4. Lambda(non-functional result with at most one Block)
        if local_cdf_comp.is_compiled_computation():
            # Case 1. Already compiled; report a modification only if the input
            # was not itself a compiled computation.
            return local_cdf_comp, not comp.is_compiled_computation()
        elif local_cdf_comp.is_block():
            if local_cdf_comp.result.is_compiled_computation():
                # Case 2. The bindings in `local_cdf_comp` cannot be referenced in
                # `local_cdf_comp.result`; we may return it directly.
                return local_cdf_comp.result, True
            elif local_cdf_comp.result.is_lambda():
                # Case 3. We reduce to case 4 and pass through, by moving the
                # block's locals inside the lambda's body.
                local_cdf_comp = building_blocks.Lambda(
                    local_cdf_comp.result.parameter_name,
                    local_cdf_comp.result.parameter_type,
                    building_blocks.Block(local_cdf_comp.locals,
                                          local_cdf_comp.result.result))
                # Reduce potential chain of blocks.
                local_cdf_comp, _ = tree_transformations.merge_chained_blocks(
                    local_cdf_comp)
                # This fall-through is intended, since we have merged with case 4.
        # Note: a Block whose result is neither a compiled computation nor a
        # lambda is still a Block here and reaches the `else` below.
        if local_cdf_comp.is_lambda():
            # Case 4. Generate TensorFlow for the body, rewrap it in a lambda
            # with the same parameter, then run the TF parser over it.
            repacked_block = _package_as_deduplicated_block(
                local_cdf_comp.result)
            tf_generated, _ = create_tensorflow_representing_block(
                repacked_block)
            tff_func = building_blocks.Lambda(local_cdf_comp.parameter_name,
                                              local_cdf_comp.parameter_type,
                                              tf_generated)
            tf_parser_callable = tree_to_cc_transformations.TFParser()
            tf_generated, _ = transformation_utils.transform_postorder(
                tff_func, tf_parser_callable)
        else:
            raise tree_transformations.TransformationError(
                'Unexpected structure encountered for functional computation in '
                'local call-dominant form: \n'
                f'{local_cdf_comp.formatted_representation()}')
    else:
        # The CFG for local call dominant tells us no lambdas or blocks may be
        # present under `comp` for non-functional types which can be represented in
        # TensorFlow (in particular, structures of functions are disallowed by this
        # restriction). So we may package as a block directly.
        repacked_block = _package_as_deduplicated_block(local_cdf_comp)
        tf_generated, _ = create_tensorflow_representing_block(repacked_block)
    return tf_generated, True
 def transformation_fn(x):
   """Uniquifies names, drops mapped/applied identities, then goes call-dominant.

   Returns:
     The result of `transformations.to_call_dominant` on the cleaned-up tree.
   """
   uniquified, _ = tree_transformations.uniquify_reference_names(x)
   without_identities, _ = (
       tree_transformations.remove_mapped_or_applied_identity(uniquified))
   return transformations.to_call_dominant(without_identities)
Esempio n. 19
0
def _extract_update(after_aggregate, grappler_config):
    """Extracts `update` from `after_aggregate`.

    Only `get_map_reduce_form_for_iterative_process` is expected to call this
    function, so the structure of `after_aggregate` is not validated here; the
    caller is responsible for those checks.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `update` as specified by `forms.MapReduceForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
        type.
    """
    zipped_result = building_block_factory.create_federated_zip(
        after_aggregate.result)
    zipped = building_blocks.Lambda(after_aggregate.parameter_name,
                                    after_aggregate.parameter_type,
                                    zipped_result)
    # `create_federated_zip` can reuse reference names, while
    # `_as_function_of_some_federated_subparameters` needs them to be unique.
    zipped, _ = tree_transformations.uniquify_reference_names(zipped)

    # Paths into the parameter of `zipped` that `update` depends on.
    subparameter_paths = (
        ('original_arg', 'original_arg', 0),
        ('intrinsic_results', 'federated_aggregate_result'),
        ('intrinsic_results', 'federated_secure_sum_bitwidth_result'),
        ('intrinsic_results', 'federated_secure_sum_result'),
        ('intrinsic_results', 'federated_secure_modular_sum_result'),
    )
    flat_update = _as_function_of_some_federated_subparameters(
        zipped, subparameter_paths)

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support
    # selecting from nested structures. The argument is therefore accepted as
    # <server_state, <aggregation_results...>> and a mapped function flattens
    # it to <server_state, aggregation_results...>.
    names = building_block_factory.unique_name_generator(flat_update)
    flatten_name = next(names)
    flat_member_type = flat_update.parameter_type.member
    nested_type = computation_types.StructType([
        flat_member_type[0],
        computation_types.StructType(flat_member_type[1:]),
    ])
    nested_ref = building_blocks.Reference(flatten_name, nested_type)

    def _select(source, index):
        # Shorthand for selecting the element at `index` from `source`.
        return building_blocks.Selection(source, index=index)

    # flatten = <v, <...>> -> <v, ...>
    flatten_body = building_blocks.Struct(
        [_select(nested_ref, 0)] +
        [_select(_select(nested_ref, 1), i)
         for i in range(len(flat_member_type) - 1)])
    flatten = building_blocks.Lambda(flatten_name, nested_type, flatten_body)

    # update = v -> flat_update(federated_map(flatten, v))
    arg_name = next(names)
    arg_type = computation_types.at_server(nested_type)
    arg_ref = building_blocks.Reference(arg_name, arg_type)
    flattened_arg = building_block_factory.create_federated_map_or_apply(
        flatten, arg_ref)
    update = building_blocks.Lambda(
        arg_name, arg_type, building_blocks.Call(flat_update, flattened_arg))
    return compiler.consolidate_and_extract_local_processing(
        update, grappler_config)
Esempio n. 20
0
def _merge_args(
    abstract_parameter_type,
    args: List[building_blocks.ComputationBuildingBlock],
    name_generator,
) -> building_blocks.ComputationBuildingBlock:
    """Combines the arguments of several invocations into a single argument.

    Args:
      abstract_parameter_type: The abstract parameter type of the function
        being invoked. Its kind (federated, tensor/abstract, function or
        struct) selects how the arguments are combined; for function types it
        also decides whether each function argument is called with a single
        value or with a struct of values.
      args: One entry per invocation, holding that invocation's argument.
      name_generator: A generator yielding unique names for the references,
        parameters and block locals created here.

    Returns:
      A building block to use as the new (merged) argument.

    Raises:
      TypeError: If `abstract_parameter_type` is none of federated, tensor,
        abstract, function or struct.
    """
    if abstract_parameter_type.is_federated():
        # `create_federated_zip` introduces repeated names, so they are made
        # unique using the shared generator.
        zipped, _ = tree_transformations.uniquify_reference_names(
            building_block_factory.create_federated_zip(
                building_blocks.Struct(args)), name_generator)
        return zipped

    if (abstract_parameter_type.is_tensor() or
            abstract_parameter_type.is_abstract()):
        return building_blocks.Struct([(None, arg) for arg in args])

    if abstract_parameter_type.is_function():
        # Function arguments are combined into one function whose parameter
        # bundles the parameters of all of them.
        #
        # Single-parameter abstract function (e.g. the `fn` of
        # `federated_map`): call each function on its own selection,
        #   `<fn0(arg[0]), fn1(arg[1]), ..., fnN(arg[N])>`.
        #
        # Multi-parameter abstract function (e.g. the `accumulate` of
        # `federated_aggregate`): arg[i] groups the i-th parameter of every
        # function, so function n receives
        #   `fnN(arg[0][N], arg[1][N], ...)`.
        fn_param_name = next(name_generator)
        if not abstract_parameter_type.parameter.is_struct():
            fn_param_type = computation_types.StructType(
                [fn.type_signature.parameter for fn in args])
            fn_param_ref = building_blocks.Reference(fn_param_name,
                                                     fn_param_type)
            calls = [
                building_blocks.Call(
                    fn, building_blocks.Selection(fn_param_ref, index=n))
                for n, fn in enumerate(args)
            ]
        else:
            arity = len(abstract_parameter_type.parameter)
            fn_param_type = computation_types.StructType([
                [fn.type_signature.parameter[i] for fn in args]
                for i in range(arity)
            ])
            fn_param_ref = building_blocks.Reference(fn_param_name,
                                                     fn_param_type)
            calls = []
            for n, fn in enumerate(args):
                fn_arg = building_blocks.Struct([
                    (None,
                     building_blocks.Selection(
                         building_blocks.Selection(fn_param_ref, index=i),
                         index=n)) for i in range(arity)
                ])
                calls.append(building_blocks.Call(fn, fn_arg))
        return building_blocks.Lambda(
            parameter_name=fn_param_name,
            parameter_type=fn_param_type,
            result=building_blocks.Struct([(None, c) for c in calls]))

    if abstract_parameter_type.is_struct():
        # Each argument is bound to a fresh local so that it can be referenced
        # once per struct element without duplicating it.
        arg_locals = [(next(name_generator), arg) for arg in args]
        arg_refs = [
            building_blocks.Reference(local_name, local_value.type_signature)
            for local_name, local_value in arg_locals
        ]
        merged = [
            _merge_args(
                abstract_parameter_type[i],
                [building_blocks.Selection(ref, index=i) for ref in arg_refs],
                name_generator)
            for i in range(len(abstract_parameter_type))
        ]
        return building_blocks.Block(
            arg_locals,
            building_blocks.Struct([(None, m) for m in merged]))

    raise TypeError(f'Cannot merge args of type: {abstract_parameter_type}')
Esempio n. 21
0
    def test_returns_tree(self):
        """Compares the no-federated-aggregate split with a secure-sum split.

        Builds the `next` tree of the sum example and splits it with
        `_create_before_and_after_aggregate_for_no_federated_aggregate`. The
        resulting pieces are compared against those produced by
        `force_align_and_split_by_intrinsics` on `FEDERATED_SECURE_SUM`.
        """
        ip = get_iterative_process_for_sum_example_with_no_federated_aggregate(
        )
        next_tree = building_blocks.ComputationBuildingBlock.from_proto(
            ip.next._computation_proto)

        before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_aggregate(
            next_tree)

        # Reference split: the same tree split directly on secure sum.
        before_federated_secure_sum, after_federated_secure_sum = (
            transformations.force_align_and_split_by_intrinsics(
                next_tree, [intrinsic_defs.FEDERATED_SECURE_SUM.uri]))
        self.assertIsInstance(before_aggregate, building_blocks.Lambda)
        self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
        self.assertLen(before_aggregate.result, 2)

        # The first element is expected to be a five-element struct of
        # trivial values (presumably stand-ins for the absent
        # federated_aggregate arguments). The generated `_varN` names depend
        # on the exact sequence of transformations above.
        # pyformat: disable
        self.assertEqual(
            before_aggregate.result[0].formatted_representation(), '<\n'
            '  federated_value_at_clients(<>),\n'
            '  <>,\n'
            '  (_var1 -> <>),\n'
            '  (_var2 -> <>),\n'
            '  (_var3 -> <>)\n'
            '>')
        # pyformat: enable

        # trees_equal will fail if computations refer to unbound references, so we
        # create a new dummy computation to bind them.
        unbound_refs_in_before_agg_result = transformation_utils.get_map_of_unbound_references(
            before_aggregate.result[1])[before_aggregate.result[1]]
        unbound_refs_in_before_secure_sum_result = transformation_utils.get_map_of_unbound_references(
            before_federated_secure_sum.result)[
                before_federated_secure_sum.result]

        # One placeholder value is bound to every unbound name on both sides.
        dummy_data = building_blocks.Data('data',
                                          computation_types.AbstractType('T'))

        blk_binding_refs_in_before_agg = building_blocks.Block(
            [(name, dummy_data) for name in unbound_refs_in_before_agg_result],
            before_aggregate.result[1])
        blk_binding_refs_in_before_secure_sum = building_blocks.Block(
            [(name, dummy_data)
             for name in unbound_refs_in_before_secure_sum_result],
            before_federated_secure_sum.result)

        self.assertTrue(
            tree_analysis.trees_equal(blk_binding_refs_in_before_agg,
                                      blk_binding_refs_in_before_secure_sum))

        # The function called by `after_aggregate` must match the secure-sum
        # "after" part once reference names are uniquified on both sides.
        self.assertIsInstance(after_aggregate, building_blocks.Lambda)
        self.assertIsInstance(after_aggregate.result, building_blocks.Call)
        actual_after_aggregate_tree, _ = tree_transformations.uniquify_reference_names(
            after_aggregate.result.function)
        expected_after_aggregate_tree, _ = tree_transformations.uniquify_reference_names(
            after_federated_secure_sum)
        self.assertTrue(
            tree_analysis.trees_equal(actual_after_aggregate_tree,
                                      expected_after_aggregate_tree))

        # The call's argument is expected to select `_var4[0]` and
        # `_var4[1][1]`.
        # pyformat: disable
        self.assertEqual(
            after_aggregate.result.argument.formatted_representation(), '<\n'
            '  _var4[0],\n'
            '  _var4[1][1]\n'
            '>')
        # pyformat: enable
 def transformation_fn(x):
     """Applies three `tree_transformations` passes to `x` in a fixed order.

     The passes are `uniquify_reference_names`, then `inline_block_locals`,
     then `remove_mapped_or_applied_identity`. Each call returns a pair; the
     first element is fed to the next pass and the second is discarded.

     Args:
       x: The building block to transform.

     Returns:
       The building block produced by the last pass.
     """
     passes = (
         tree_transformations.uniquify_reference_names,
         tree_transformations.inline_block_locals,
         tree_transformations.remove_mapped_or_applied_identity,
     )
     for step in passes:
         x, _ = step(x)
     return x
Esempio n. 23
0
 def test_raises_type_error(self):
     """Passing `None` to `uniquify_reference_names` must raise TypeError."""
     self.assertRaises(TypeError,
                       tree_transformations.uniquify_reference_names, None)
Esempio n. 24
0
def _extract_update(after_aggregate, grappler_config):
    """Extracts `update` from `after_aggregate`.

    Only `get_map_reduce_form_for_iterative_process` is expected to call this
    function, so the structure of `after_aggregate` is not validated here; the
    caller must perform those checks beforehand.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `update` as specified by `forms.MapReduceForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.MapReduceFormCompilationError: If we extract an AST of
        the wrong type.
    """
    # s7: keep result elements 0 and 1 and zip them into one federated value.
    s7_extracted = transformations.select_output_from_lambda(
        after_aggregate, [0, 1])
    s7_zipped = building_blocks.Lambda(
        s7_extracted.parameter_name, s7_extracted.parameter_type,
        building_block_factory.create_federated_zip(s7_extracted.result))
    # `create_federated_zip` can reuse reference names, while
    # `_as_function_of_some_federated_subparameters` needs them to be unique.
    s7_zipped, _ = tree_transformations.uniquify_reference_names(s7_zipped)
    # s6: the parameter elements that s7 depends on.
    s6_to_s7 = _as_function_of_some_federated_subparameters(
        s7_zipped, [(0, 0, 0), (1, 0), (1, 1)])

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support
    # selecting from nested structures. The argument is therefore accepted as
    # `<s1, <s3, s4>>`, and a mapped function flattens it back to the
    # `<s1, s3, s4>` that `s6_to_s7` expects.
    names = building_block_factory.unique_name_generator(s6_to_s7)

    nested_name = next(names)
    flat_members = s6_to_s7.parameter_type.member
    nested_type = computation_types.StructType([
        flat_members[0],
        computation_types.StructType([flat_members[1], flat_members[2]]),
    ])
    nested_ref = building_blocks.Reference(nested_name, nested_type)
    inner = building_blocks.Selection(nested_ref, index=1)
    flatten_body = building_blocks.Struct([
        building_blocks.Selection(nested_ref, index=0),
        building_blocks.Selection(inner, index=0),
        building_blocks.Selection(inner, index=1),
    ])
    flatten_fn = building_blocks.Lambda(nested_ref.name,
                                        nested_ref.type_signature,
                                        flatten_body)

    arg_name = next(names)
    arg_type = computation_types.FederatedType(nested_type, placements.SERVER)
    arg_ref = building_blocks.Reference(arg_name, arg_type)
    flattened_arg = building_block_factory.create_federated_map_or_apply(
        flatten_fn, arg_ref)
    update_fn = building_blocks.Lambda(
        arg_ref.name, arg_ref.type_signature,
        building_blocks.Call(s6_to_s7, flattened_arg))
    return transformations.consolidate_and_extract_local_processing(
        update_fn, grappler_config)
Esempio n. 25
0
 def transform(self, comp):
     """Returns the result of `tree_transformations.uniquify_reference_names`.

     Args:
       comp: The building block whose reference names are uniquified.

     Returns:
       Whatever `uniquify_reference_names(comp)` returns, unchanged.
     """
     uniquify = tree_transformations.uniquify_reference_names
     return uniquify(comp)