Example #1
0
    def federated_zip(self, value):
        """Implements `federated_zip` as defined in `api/intrinsics.py`.

        The argument is converted with `value_impl.to_value`, type-checked to be
        a `value_base.Value` whose type signature is a
        `computation_types.StructType`, and its underlying building block is
        passed to `building_block_factory.create_federated_zip`. The zipped
        computation is bound as a reference via `_bind_comp_as_reference` and
        wrapped in a new `value_impl.ValueImpl` on this object's context stack.

        Args:
          value: Anything accepted by `value_impl.to_value`; after conversion its
            type signature must be a `computation_types.StructType`.

        Returns:
          A `value_impl.ValueImpl` wrapping the bound federated-zip computation.

        Raises:
          Whatever `py_typecheck.check_type` raises when the converted value is
          not a `value_base.Value` or its type signature is not a
          `computation_types.StructType` (presumably `TypeError` -- confirm).
        """
        # TODO(b/113112108): Extend this to accept *args.
        #
        # TODO(b/113112108): The type system cannot express "an operation that
        # takes tuples of T of arbitrary length", so the zip is handed a single
        # struct-typed value. Other approaches (e.g. having the operator act on
        # sequences) may be worth exploring.
        # NOTE(review): an earlier version of this comment described an
        # iterate/unwrap approach and a fixed arity of 2; neither is visible in
        # this method any more -- confirm whether that restriction still holds.
        converted = value_impl.to_value(value, None, self._context_stack)
        py_typecheck.check_type(converted, value_base.Value)
        py_typecheck.check_type(
            converted.type_signature, computation_types.StructType)

        zipped_comp = building_block_factory.create_federated_zip(
            value_impl.ValueImpl.get_comp(converted))
        return value_impl.ValueImpl(
            self._bind_comp_as_reference(zipped_comp), self._context_stack)
Example #2
0
def extract_update(after_aggregate, canonical_form_types):
    """Converts `after_aggregate` to `update`.

    Outputs 0 and 1 of `after_aggregate` are selected and wrapped in a lambda
    whose result is their federated zip. That lambda is passed to
    `zip_selection_as_argument_to_lower_level_lambda` with the parameter
    selection `[[0, 0, 0], [1]]`, and the resulting function is reduced with
    `consolidate_and_extract_local_processing`. Finally, the extracted block's
    type signature is compared against `canonical_form_types['update_type']`.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        `intrinsic_defs.FEDERATED_AGGREGATE`.
      canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
        type signatures specified by the `tff.utils.IterativeProcess` we are
        compiling; must contain the key `'update_type'`.

    Returns:
      `update` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If the extracted block's
        type signature differs from `canonical_form_types['update_type']`.
    """
    # The s4/s5 naming follows `get_iterative_process_for_canonical_form()`.
    s5_extracted = transformations.select_output_from_lambda(
        after_aggregate, [0, 1])
    s5_zipped = building_blocks.Lambda(
        s5_extracted.parameter_name,
        s5_extracted.parameter_type,
        building_block_factory.create_federated_zip(s5_extracted.result))

    zipped_call = transformations.zip_selection_as_argument_to_lower_level_lambda(
        s5_zipped, [[0, 0, 0], [1]])
    s4_to_s5 = zipped_call.result.function

    update = transformations.consolidate_and_extract_local_processing(s4_to_s5)
    expected_type = canonical_form_types['update_type']
    if update.type_signature != expected_type:
        raise transformations.CanonicalFormCompilationError(
            'Extracted a TF block of the wrong type. Expected a function with type '
            '{}, but the type signature of the TF block was {}'.format(
                expected_type, update.type_signature))
    return update
Example #3
0
def _extract_work(before_aggregate, grappler_config):
    """Extracts `work` from `before_aggregate`.

    Intended for use by `get_canonical_form_for_iterative_process` only. It
    does not check that `before_aggregate` has the expected structure; the
    caller is expected to perform those checks before calling it.

    The parameter selection `[[0, 1], [1]]` is lifted into an argument of
    `before_aggregate` via `zip_selection_as_argument_to_lower_level_lambda`.
    The outputs `[[0, 0], [1, 0]]` are then selected, federated-zipped into a
    single result, and the resulting lambda is reduced with
    `consolidate_and_extract_local_processing`.

    Args:
      before_aggregate: The first result of splitting `after_broadcast` on
        aggregate intrinsics.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `work` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we extract an AST of the
        wrong type. NOTE(review): this function performs no type check itself;
        presumably the error can only originate inside the transformations it
        calls -- confirm.
    """
    c3_to_before_aggregate = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            before_aggregate, [[0, 1], [1]]).result.function)

    c3_to_unzipped_c4 = transformations.select_output_from_lambda(
        c3_to_before_aggregate, [[0, 0], [1, 0]])

    c3_to_c4 = building_blocks.Lambda(
        c3_to_unzipped_c4.parameter_name,
        c3_to_unzipped_c4.parameter_type,
        building_block_factory.create_federated_zip(c3_to_unzipped_c4.result),
    )
    return transformations.consolidate_and_extract_local_processing(
        c3_to_c4, grappler_config)
def _extract_update(after_aggregate):
  """Extracts `update` from `after_aggregate`.

  Intended for use by `get_canonical_form_for_iterative_process` only. It does
  not check that `after_aggregate` has the expected structure; the caller is
  expected to perform those checks before calling it.

  The function proceeds in three steps:

  1. Outputs 0 and 1 of `after_aggregate` are zipped into a single federated
     result.
  2. The parameter selection `[[0, 0, 0], [1, 0], [1, 1]]` is lifted into an
     argument via `zip_selection_as_argument_to_lower_level_lambda`.
  3. The flat `<s1, s3, s4>` parameter of that function is fed from a packed
     `<s1, <s3, s4>>` server value through a federated map/apply before local
     processing is extracted.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type. NOTE(review): no such check happens in this function itself;
      presumably it comes from `consolidate_and_extract_local_processing`.
  """
  s7_extracted = transformations.select_output_from_lambda(
      after_aggregate, [0, 1])
  s7_zipped = building_blocks.Lambda(
      s7_extracted.parameter_name,
      s7_extracted.parameter_type,
      building_block_factory.create_federated_zip(s7_extracted.result))
  s6_to_s7 = transformations.zip_selection_as_argument_to_lower_level_lambda(
      s7_zipped, [[0, 0, 0], [1, 0], [1, 1]]).result.function

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures, therefore we need to pack the type signature
  # `<s1, s3, s4>` as `<s1, <s3, s4>>`.
  names = building_block_factory.unique_name_generator(s6_to_s7)

  packed_name = next(names)
  flat_members = s6_to_s7.parameter_type.member
  packed_type = computation_types.StructType([
      flat_members[0],
      computation_types.StructType([flat_members[1], flat_members[2]]),
  ])
  packed_ref = building_blocks.Reference(packed_name, packed_type)

  # flatten: <s1, <s3, s4>> -> <s1, s3, s4>
  s1 = building_blocks.Selection(packed_ref, index=0)
  nested = building_blocks.Selection(packed_ref, index=1)
  flatten_fn = building_blocks.Lambda(
      packed_ref.name,
      packed_ref.type_signature,
      building_blocks.Struct([
          s1,
          building_blocks.Selection(nested, index=0),
          building_blocks.Selection(nested, index=1),
      ]))

  # update = arg@SERVER -> s6_to_s7(federated_map_or_apply(flatten_fn, arg))
  server_arg_name = next(names)
  server_arg_ref = building_blocks.Reference(
      server_arg_name,
      computation_types.FederatedType(packed_type, placements.SERVER))
  flattened_arg = building_block_factory.create_federated_map_or_apply(
      flatten_fn, server_arg_ref)
  body = building_blocks.Call(s6_to_s7, flattened_arg)
  wrapper = building_blocks.Lambda(
      server_arg_ref.name, server_arg_ref.type_signature, body)
  return transformations.consolidate_and_extract_local_processing(wrapper)
Example #5
0
def _extract_update(after_aggregate, grappler_config):
    """Extracts `update` from `after_aggregate`.

    Intended for use by `get_map_reduce_form_for_iterative_process` only. It
    does not check that `after_aggregate` has the expected structure; the
    caller is expected to perform those checks before calling it.

    The function proceeds in four steps:

    1. The result of `after_aggregate` is federated-zipped, and reference
       names are uniquified.
    2. It is rewritten as a function of the server state plus the four
       intrinsic results (aggregate, secure_sum_bitwidth, secure_sum,
       secure_modular_sum).
    3. That flat-input function is fed from a packed
       `<server_state, <results...>>` server value through a federated
       map/apply.
    4. Local processing is extracted from the result.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `update` as specified by `forms.MapReduceForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
        type. NOTE(review): this function performs no such check itself;
        presumably the error originates in the helpers it calls -- confirm.
    """
    after_aggregate_zipped = building_blocks.Lambda(
        after_aggregate.parameter_name, after_aggregate.parameter_type,
        building_block_factory.create_federated_zip(after_aggregate.result))
    # `create_federated_zip` doesn't produce unique reference names, but
    # `_as_function_of_some_federated_subparameters` requires them.
    after_aggregate_zipped, _ = tree_transformations.uniquify_reference_names(
        after_aggregate_zipped)
    server_state_index = ('original_arg', 'original_arg', 0)
    aggregate_result_index = ('intrinsic_results',
                              'federated_aggregate_result')
    secure_sum_bitwidth_result_index = ('intrinsic_results',
                                        'federated_secure_sum_bitwidth_result')
    secure_sum_result_index = ('intrinsic_results',
                               'federated_secure_sum_result')
    secure_modular_sum_result_index = ('intrinsic_results',
                                       'federated_secure_modular_sum_result')
    update_with_flat_inputs = _as_function_of_some_federated_subparameters(
        after_aggregate_zipped, (
            server_state_index,
            aggregate_result_index,
            secure_sum_bitwidth_result_index,
            secure_sum_result_index,
            secure_modular_sum_result_index,
        ))

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
    # from nested structures, therefore we need to transform the input from
    # <server_state, <aggregation_results...>> into
    # <server_state, aggregation_results...>
    # unpack = <v, <...>> -> <v, ...>
    name_generator = building_block_factory.unique_name_generator(
        update_with_flat_inputs)
    unpack_param_name = next(name_generator)
    original_param_type = update_with_flat_inputs.parameter_type.member
    unpack_param_type = computation_types.StructType([
        original_param_type[0],
        computation_types.StructType(original_param_type[1:]),
    ])
    unpack_param_ref = building_blocks.Reference(unpack_param_name,
                                                 unpack_param_type)
    # Element 0 is the server state; every remaining element is pulled out of
    # the nested struct at index 1. A fresh Selection node is built for each
    # element rather than sharing one node across the tree.
    unpacked_elements = [building_blocks.Selection(unpack_param_ref, index=0)]
    unpacked_elements.extend(
        building_blocks.Selection(
            building_blocks.Selection(unpack_param_ref, index=1), index=i)
        for i in range(len(original_param_type) - 1))
    unpack = building_blocks.Lambda(unpack_param_name, unpack_param_type,
                                    building_blocks.Struct(unpacked_elements))

    # update = v -> update_with_flat_inputs(federated_map(unpack, v))
    param_name = next(name_generator)
    param_type = computation_types.at_server(unpack_param_type)
    param_ref = building_blocks.Reference(param_name, param_type)
    update = building_blocks.Lambda(
        param_name, param_type,
        building_blocks.Call(
            update_with_flat_inputs,
            building_block_factory.create_federated_map_or_apply(
                unpack, param_ref)))
    return compiler.consolidate_and_extract_local_processing(
        update, grappler_config)
Example #6
0
def _merge_args(
    abstract_parameter_type,
    args: List[building_blocks.ComputationBuildingBlock],
    name_generator,
) -> building_blocks.ComputationBuildingBlock:
    """Merges the arguments of multiple function invocations into one.

    The merge strategy depends on the kind of `abstract_parameter_type`:

    * federated: `args` are packed into a `Struct`, passed through
      `create_federated_zip`, and reference names are uniquified.
    * tensor or abstract: `args` are packed into an unnamed `Struct`.
    * function: a single `Lambda` is built whose result is a `Struct` of calls,
      one per function in `args` (see the comment in that branch).
    * struct: each arg is bound to a fresh name in a `Block`, and every
      element position is merged recursively.

    Args:
      abstract_parameter_type: The abstract parameter type specification for
        the function being invoked. This is used to determine whether any
        functional parameters accept multiple arguments.
      args: A list where each element contains the arguments to a single call.
      name_generator: A generator used to create unique names.

    Returns:
      A building block to use as the new (merged) argument.

    Raises:
      TypeError: If `abstract_parameter_type` is not federated, tensor,
        abstract, function, or struct.
    """
    if abstract_parameter_type.is_federated():
        zip_args = building_block_factory.create_federated_zip(
            building_blocks.Struct(args))
        # `create_federated_zip` introduces repeated names.
        zip_args, _ = tree_transformations.uniquify_reference_names(
            zip_args, name_generator)
        return zip_args
    if (abstract_parameter_type.is_tensor()
            or abstract_parameter_type.is_abstract()):
        return building_blocks.Struct([(None, arg) for arg in args])
    if abstract_parameter_type.is_function():
        # For functions, we must compose them differently depending on whether
        # the abstract function (from the intrinsic definition) takes more than
        # one parameter.
        #
        # If it does not, such as in the `fn` argument to `federated_map`, we
        # can simply select out the argument and call the result:
        # `(fn0(arg[0]), fn1(arg[1]), ..., fnN(arg[N]))`
        #
        # If it takes multiple arguments such as the `accumulate` argument to
        # `federated_aggregate`, we have to select out the individual arguments
        # to pass to each function:
        #
        # `(
        #   fn0(arg[0][0], arg[1][0]),
        #   fn1(arg[0][1], arg[1][1]),
        #   ...
        #   fnN(arg[0][N], arg[1][N]),
        # )`
        param_name = next(name_generator)
        if abstract_parameter_type.parameter.is_struct():
            num_args = len(abstract_parameter_type.parameter)
            # parameter_types[i][n] is the type of the i-th parameter of the
            # n-th function in `args`.
            parameter_types = [
                [fn.type_signature.parameter[i] for fn in args]
                for i in range(num_args)
            ]
            param_type = computation_types.StructType(parameter_types)
            param_ref = building_blocks.Reference(param_name, param_type)
            calls = [
                building_blocks.Call(
                    fn,
                    building_blocks.Struct([
                        (None,
                         building_blocks.Selection(
                             building_blocks.Selection(param_ref, index=i),
                             index=n)) for i in range(num_args)
                    ])) for n, fn in enumerate(args)
            ]
        else:
            param_type = computation_types.StructType(
                [fn.type_signature.parameter for fn in args])
            param_ref = building_blocks.Reference(param_name, param_type)
            calls = [
                building_blocks.Call(
                    fn, building_blocks.Selection(param_ref, index=n))
                for n, fn in enumerate(args)
            ]
        return building_blocks.Lambda(
            parameter_name=param_name,
            parameter_type=param_type,
            result=building_blocks.Struct([(None, call) for call in calls]))
    if abstract_parameter_type.is_struct():
        # Bind each argument to a name so that we can reference them multiple
        # times without duplicating the argument expressions.
        arg_locals = []
        arg_refs = []
        for arg in args:
            arg_name = next(name_generator)
            arg_locals.append((arg_name, arg))
            arg_refs.append(
                building_blocks.Reference(arg_name, arg.type_signature))
        # Merge position-wise: the i-th element of every argument is merged
        # against the i-th element of the abstract struct type. The recursion
        # shares `name_generator`, so the evaluation order is significant.
        merged_args = [
            _merge_args(
                abstract_parameter_type[i],
                [building_blocks.Selection(ref, index=i) for ref in arg_refs],
                name_generator)
            for i in range(len(abstract_parameter_type))
        ]
        return building_blocks.Block(
            arg_locals,
            building_blocks.Struct([(None, arg) for arg in merged_args]))
    raise TypeError(f'Cannot merge args of type: {abstract_parameter_type}')