Example #1
0
 def test_reduces_federated_identity_to_member_identity(self):
     """Identity on a CLIENTS-placed int extracts to TF over the member type."""
     clients_int_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     identity_ref = building_blocks.Reference('x', clients_int_type)
     identity_fn = building_blocks.Lambda('x', clients_int_type, identity_ref)

     compiled = (
         mapreduce_transformations.consolidate_and_extract_local_processing(
             identity_fn))

     self.assertIsInstance(compiled, building_blocks.CompiledComputation)
     # The extracted TF must be typed over the unplaced member on both sides.
     member_type = clients_int_type.member
     expected_signature = computation_types.FunctionType(member_type,
                                                         member_type)
     self.assertEqual(compiled.type_signature, expected_signature)
Example #2
0
    def test_already_reduced_case(self):
        """`initialize` of the temperature-sensor example extracts to TF."""
        iterative_process = (
            canonical_form_utils.get_iterative_process_for_canonical_form(
                test_utils.get_temperature_sensor_example()))
        building_block = test_utils.computation_to_building_block(
            iterative_process.initialize)

        extracted = (
            mapreduce_transformations.consolidate_and_extract_local_processing(
                building_block))

        self.assertIsInstance(extracted, building_blocks.CompiledComputation)
        self.assertIsInstance(extracted.proto, computation_pb2.Computation)
        self.assertEqual(
            extracted.proto.WhichOneof('computation'), 'tensorflow')
    def test_already_reduced_case(self):
        """`initialize` of the temperature-sensor example extracts to TF."""
        example_form = mapreduce_test_utils.get_temperature_sensor_example()
        iterative_process = form_utils.get_iterative_process_for_map_reduce_form(
            example_form)
        building_block = iterative_process.initialize.to_building_block()

        extracted = transformations.consolidate_and_extract_local_processing(
            building_block, DEFAULT_GRAPPLER_CONFIG)

        self.assertIsInstance(extracted, building_blocks.CompiledComputation)
        self.assertIsInstance(extracted.proto, computation_pb2.Computation)
        self.assertEqual(
            extracted.proto.WhichOneof('computation'), 'tensorflow')
Example #4
0
def _extract_work(before_aggregate, after_aggregate):
    """Extracts `work` from `before_aggregate` and `after_aggregate`.

    This function is intended to be used by
    `get_canonical_form_for_iterative_process` only. As a result, this function
    does not assert that `before_aggregate` or `after_aggregate` has the
    expected structure; the caller is expected to perform these checks before
    calling this function.

    Args:
      before_aggregate: The first result of splitting `after_broadcast` on
        aggregate intrinsics.
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.

    Returns:
      `work` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we extract an AST of
        the wrong type.
    """
    # Re-express `before_aggregate` as a function of the selected c3 inputs,
    # then keep only its c6 outputs.
    c3_to_before = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            before_aggregate, [[0, 1], [1]]).result.function)
    c3_to_c6 = transformations.select_output_from_lambda(
        c3_to_before, [[0, 0], [1, 0]])

    # Keep only output 2 (c8) of `after_aggregate`, then re-express it as a
    # function of the selected c3 inputs.
    after_to_c8 = transformations.select_output_from_lambda(after_aggregate, 2)
    c3_to_c8 = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            after_to_c8, [[0, 0, 1], [0, 1]]).result.function)

    # Concatenate both outputs and zip them into a single federated result.
    unzipped = transformations.concatenate_function_outputs(c3_to_c6, c3_to_c8)
    c3_to_c4 = building_blocks.Lambda(
        unzipped.parameter_name,
        unzipped.parameter_type,
        building_block_factory.create_federated_zip(unzipped.result))

    return transformations.consolidate_and_extract_local_processing(c3_to_c4)
Example #5
0
def _extract_aggregate_functions(before_aggregate):
  """Extracts the aggregation functions from `before_aggregate`.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.

  Returns:
    A 4-tuple `(zero, accumulate, merge, report)` as specified by
    `canonical_form.CanonicalForm`. All are instances of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: Documented as raised if an
      AST of the wrong type is extracted; this function performs no such check
      itself.
  """
  # Outputs 1, 2, 3 and 4 of `before_aggregate` are, in that order, the zero,
  # accumulate, merge and report computations. Every output is selected first
  # and then each one is consolidated, in index order.
  selected = [
      transformations.select_output_from_lambda(before_aggregate, index).result
      for index in (1, 2, 3, 4)
  ]
  zero, accumulate, merge, report = [
      transformations.consolidate_and_extract_local_processing(comp)
      for comp in selected
  ]
  return zero, accumulate, merge, report
Example #6
0
 def test_reduces_federated_apply_to_equivalent_function(self):
   """Mapping identity over a CLIENTS value extracts to equivalent TF."""
   identity = building_blocks.Lambda('x', tf.int32,
                                     building_blocks.Reference('x', tf.int32))
   clients_arg = building_blocks.Reference(
       'arg', computation_types.FederatedType(tf.int32, placements.CLIENTS))
   mapped = building_block_factory.create_federated_map_or_apply(
       identity, clients_arg)

   extracted = mapreduce_transformations.consolidate_and_extract_local_processing(
       mapped)
   self.assertIsInstance(extracted, building_blocks.CompiledComputation)

   # Both building blocks become callables; they must agree on 0..9.
   to_computation = computation_wrapper_instances.building_block_to_computation
   compiled_fn = to_computation(extracted)
   reference_fn = to_computation(identity)
   for value in range(10):
     self.assertEqual(compiled_fn(value), reference_fn(value))
Example #7
0
def _extract_client_processing(after_broadcast, grappler_config):
    """Extracts `client_processing` from `after_broadcast`.

    Args:
      after_broadcast: The second result of splitting on the broadcast
        intrinsic; its parameter holds the server context and client data.
      grappler_config: Grappler configuration forwarded to
        `consolidate_and_extract_local_processing`.

    Returns:
      The result of `consolidate_and_extract_local_processing` applied to
      `after_broadcast` re-expressed over the selected sub-parameters.
    """
    # NOTE: the order of parameters here is different from `work`.
    # `work` is odd in that it takes its parameters as `(data, params)` rather
    # than `(params, data)` (the order of the iterative process / computation).
    # Here, we use the same `(params, data)` ordering as in the input computation.
    selected_paths = [
        (1,),    # context from server
        (0, 1),  # client data
    ]
    client_processing = _as_function_of_some_federated_subparameters(
        after_broadcast, selected_paths)
    return transformations.consolidate_and_extract_local_processing(
        client_processing, grappler_config)
Example #8
0
def _extract_work(before_aggregate, after_aggregate):
  """Converts `before_aggregate` and `after_aggregate` to `work`.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.
    after_aggregate: The second result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # Variable names (c3, c4, c5, c6) follow the conventions described in
  # `get_iterative_process_for_canonical_form()`.
  c3_to_before = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          before_aggregate, [[0, 1], [1]]).result.function)
  c3_to_c5 = transformations.select_output_from_lambda(c3_to_before, 0)

  after_to_c6 = transformations.select_output_from_lambda(after_aggregate, 2)
  c3_to_c6 = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          after_to_c6, [[0, 0, 1], [0, 1]]).result.function)

  # Join the c5 and c6 outputs and zip them into one federated value.
  unzipped = transformations.concatenate_function_outputs(c3_to_c5, c3_to_c6)
  zipped_result = building_block_factory.create_federated_zip(unzipped.result)
  c3_to_c4 = building_blocks.Lambda(unzipped.parameter_name,
                                    unzipped.parameter_type, zipped_result)

  return transformations.consolidate_and_extract_local_processing(c3_to_c4)
def extract_update(after_aggregate, canonical_form_types):
    """Converts `after_aggregate` to `update`.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        `intrinsic_defs.FEDERATED_AGGREGATE`.
      canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
        type signatures specified by the `tff.utils.IterativeProcess` we are
        compiling.

    Returns:
      `update` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we fail to extract a
        `building_blocks.CompiledComputation`, or we extract one of the wrong
        type.
    """
    # Variable names (s4, s5) follow the conventions described in
    # `get_iterative_process_for_canonical_form()`.
    s5_extracted = transformations.select_output_from_lambda(
        after_aggregate, [0, 1])
    s5_zipped = building_blocks.Lambda(
        s5_extracted.parameter_name,
        s5_extracted.parameter_type,
        building_block_factory.create_federated_zip(s5_extracted.result))
    s4_to_s5 = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            s5_zipped, [[0, 0, 0], [1]]).result.function)

    update = transformations.consolidate_and_extract_local_processing(s4_to_s5)

    if not isinstance(update, building_blocks.CompiledComputation):
        raise transformations.CanonicalFormCompilationError(
            'Failed to extract a `building_blocks.CompiledComputation` from '
            'update, instead received a {} (of type {}).'.format(
                type(update), update.type_signature))

    # Looked up only after the isinstance check, as in the original ordering.
    expected_type = canonical_form_types['update_type']
    if update.type_signature != expected_type:
        raise transformations.CanonicalFormCompilationError(
            'Extracted a TF block of the wrong type. Expected a function with type '
            '{}, but the type signature of the TF block was {}'.format(
                expected_type, update.type_signature))
    return update
Example #10
0
def _extract_prepare(before_broadcast):
    """Converts `before_broadcast` into `prepare`.

    Args:
      before_broadcast: The first result of splitting `next_comp` on
        `intrinsic_defs.FEDERATED_BROADCAST`.

    Returns:
      `prepare` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we extract an AST of
        the wrong type.
    """
    bind_selection = (
        transformations.bind_single_selection_as_argument_to_lower_level_lambda)
    # Bind selection index 0 (s1) as the argument to get the s1 -> s2 function.
    s1_to_s2 = bind_selection(before_broadcast, 0).result.function
    return transformations.consolidate_and_extract_local_processing(s1_to_s2)
Example #11
0
def extract_prepare(before_broadcast, canonical_form_types):
    """Converts `before_broadcast` into `prepare`.

    Args:
      before_broadcast: The first result of splitting `next_comp` on
        `tff_framework.FEDERATED_BROADCAST`.
      canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
        type signatures specified by the `tff.utils.IterativeProcess` we are
        compiling.

    Returns:
      `prepare` as specified by `canonical_form.CanonicalForm`, an instance of
      `tff_framework.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we fail to extract a
        `tff_framework.CompiledComputation`, or we extract one of the wrong
        type.
    """
    # Variable names (s1, s2) follow the conventions described in
    # `get_iterative_process_for_canonical_form()`.
    bind_selection = (
        transformations.bind_single_selection_as_argument_to_lower_level_lambda)
    s1_to_s2 = bind_selection(before_broadcast, 0).result.function
    prepare = transformations.consolidate_and_extract_local_processing(s1_to_s2)

    if not isinstance(prepare, tff_framework.CompiledComputation):
        raise transformations.CanonicalFormCompilationError(
            'Failed to extract a `tff_framework.CompiledComputation` from '
            'prepare, instead received a {} (of type {}).'.format(
                type(prepare), prepare.type_signature))

    expected_type = canonical_form_types['prepare_type']
    if prepare.type_signature != expected_type:
        raise transformations.CanonicalFormCompilationError(
            'Extracted a TF block of the wrong type. Expected a function with type '
            '{}, but the type signature of the TF block was {}'.format(
                expected_type, prepare.type_signature))
    return prepare
Example #12
0
def _extract_work(before_aggregate, grappler_config):
    """Extracts `work` from `before_aggregate`.

    This function is intended to be used by
    `get_map_reduce_form_for_iterative_process` only. As a result, this function
    does not assert that `before_aggregate` has the expected structure; the
    caller is expected to perform these checks before calling this function.

    Args:
      before_aggregate: The first result of splitting `after_broadcast` on
        aggregate intrinsics.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `work` as specified by `forms.MapReduceForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.MapReduceFormCompilationError: If we extract an AST of
        the wrong type.
    """
    # Re-express `before_aggregate` over the selected c3 sub-parameters.
    c3_to_before = _as_function_of_some_federated_subparameters(
        before_aggregate, [(0, 1), (1,)])
    # Keep only the c4 outputs, then zip them into one federated value.
    c3_to_unzipped_c4 = transformations.select_output_from_lambda(
        c3_to_before, [[0, 0], [1, 0]])
    zipped_result = building_block_factory.create_federated_zip(
        c3_to_unzipped_c4.result)
    c3_to_c4 = building_blocks.Lambda(c3_to_unzipped_c4.parameter_name,
                                      c3_to_unzipped_c4.parameter_type,
                                      zipped_result)
    return transformations.consolidate_and_extract_local_processing(
        c3_to_c4, grappler_config)
Example #13
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `iterative_process`
    into an instance of `tff.backends.mapreduce.CanonicalForm`.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types, or if `next` does not
        take and return `tff.NamedTupleType` values.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    initialize_comp = tff_framework.ComputationBuildingBlock.from_proto(
        iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

    next_comp = tff_framework.ComputationBuildingBlock.from_proto(
        iterative_process.next._computation_proto)  # pylint: disable=protected-access

    if not (isinstance(next_comp.type_signature.parameter, tff.NamedTupleType)
            and isinstance(next_comp.type_signature.result,
                           tff.NamedTupleType)):
        raise TypeError(
            'Any IterativeProcess compatible with CanonicalForm must '
            'have a `next` function which takes and returns instances '
            'of `tff.NamedTupleType`; your next function takes '
            'parameters of type {} and returns results of type {}'.format(
                next_comp.type_signature.parameter,
                next_comp.type_signature.result))

    # A two-element `next` result is padded with an empty CLIENTS-placed value
    # as a third element (a dummy client-metrics slot).
    if len(next_comp.type_signature.result) == 2:
        next_result = next_comp.result
        dummy_clients_metrics_appended = tff_framework.Tuple([
            next_result[0],
            next_result[1],
            tff.federated_value([], tff.CLIENTS)._comp  # pylint: disable=protected-access
        ])
        next_comp = tff_framework.Lambda(next_comp.parameter_name,
                                         next_comp.parameter_type,
                                         dummy_clients_metrics_appended)

    initialize_comp = tff_framework.replace_intrinsics_with_bodies(
        initialize_comp)
    next_comp = tff_framework.replace_intrinsics_with_bodies(next_comp)

    tff_framework.check_intrinsics_whitelisted_for_reduction(initialize_comp)
    tff_framework.check_intrinsics_whitelisted_for_reduction(next_comp)
    tff_framework.check_broadcast_not_dependent_on_aggregate(next_comp)

    # Split `next` at the broadcast, then split the remainder at the aggregate.
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsic(
            next_comp, tff_framework.FEDERATED_BROADCAST.uri))

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsic(
            after_broadcast, tff_framework.FEDERATED_AGGREGATE.uri))

    # Derive and check the expected CanonicalForm types stage by stage.
    init_info_packed = pack_initialize_comp_type_signature(
        initialize_comp.type_signature)

    next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                     init_info_packed)

    before_broadcast_info_packed = (
        check_and_pack_before_broadcast_type_signature(
            before_broadcast.type_signature, next_info_packed))

    before_aggregate_info_packed = (
        check_and_pack_before_aggregate_type_signature(
            before_aggregate.type_signature, before_broadcast_info_packed))

    canonical_form_types = check_and_pack_after_aggregate_type_signature(
        after_aggregate.type_signature, before_aggregate_info_packed)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)

    # Report the very type we compare against, so the error message names the
    # actual expectation rather than an unrelated type.
    expected_initialize_result = canonical_form_types['initialize_type'].member
    if not (isinstance(initialize, tff_framework.CompiledComputation)
            and initialize.type_signature.result == expected_initialize_result):
        # A non-functional type signature may have no `.result`; fall back to
        # the whole signature so the compilation error is not masked by an
        # AttributeError while formatting the message.
        actual_initialize_result = getattr(initialize.type_signature, 'result',
                                           initialize.type_signature)
        raise transformations.CanonicalFormCompilationError(
            'Compilation of initialize has failed. Expected to extract a '
            '`tff_framework.CompiledComputation` of type {}, instead we extracted '
            'a {} of type {}.'.format(expected_initialize_result,
                                      type(initialize),
                                      actual_initialize_result))

    prepare = extract_prepare(before_broadcast, canonical_form_types)

    work = extract_work(before_aggregate, after_aggregate,
                        canonical_form_types)

    zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
        before_aggregate, canonical_form_types)

    update = extract_update(after_aggregate, canonical_form_types)

    # Positional arguments; their order must match the CanonicalForm
    # constructor.
    cf = canonical_form.CanonicalForm(
        tff_framework.building_block_to_computation(initialize),
        tff_framework.building_block_to_computation(prepare),
        tff_framework.building_block_to_computation(work),
        tff_framework.building_block_to_computation(zero_noarg_function),
        tff_framework.building_block_to_computation(accumulate),
        tff_framework.building_block_to_computation(merge),
        tff_framework.building_block_to_computation(report),
        tff_framework.building_block_to_computation(update))
    return cf
Example #14
0
 def test_raises_reference_to_functional_type(self):
     """Extracting from a reference of functional type raises ValueError."""
     fn_type = computation_types.FunctionType(tf.int32, tf.int32)
     fn_ref = building_blocks.Reference('x', fn_type)
     extract = mapreduce_transformations.consolidate_and_extract_local_processing
     with self.assertRaisesRegex(ValueError, 'of functional type passed'):
         extract(fn_ref)
Example #15
0
 def test_raises_on_none(self):
     """Passing `None` instead of a building block raises TypeError."""
     extract = mapreduce_transformations.consolidate_and_extract_local_processing
     with self.assertRaises(TypeError):
         extract(None)
def get_canonical_form_for_iterative_process(ip):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `ip` into
  an instance of `tff.backends.mapreduce.CanonicalForm`. The `next`
  computation is split around `federated_broadcast` and then around the
  aggregation intrinsics (`federated_aggregate` and/or `federated_secure_sum`).
  Each extracted piece has its type signature checked against the expected
  types before the result is assembled.

  Args:
    ip: An instance of `tff.templates.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `ip.next` contains neither a `federated_aggregate` nor a
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)

  # Work on building-block ASTs deserialized from the computations' protos.
  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  # Replace intrinsics with their bodies, then check (per the helper names)
  # that only reducible intrinsics remain and that broadcast does not depend
  # on aggregation.
  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around `federated_broadcast` when it is called; otherwise a
  # helper builds the before/after halves for the no-broadcast case.
  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  # At least one aggregation intrinsic must be present in `next`.
  contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not (contains_federated_aggregate or contains_federated_secure_sum):
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # Split `after_broadcast` on both aggregation intrinsics when both are
  # present. When only one kind is present, a helper builds the split for the
  # missing one (NOTE(review): presumably so both halves keep a uniform shape;
  # confirm against the helpers).
  if contains_federated_aggregate and contains_federated_secure_sum:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))
  elif contains_federated_secure_sum:
    assert not contains_federated_aggregate
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    assert contains_federated_aggregate and not contains_federated_secure_sum
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))

  # Expected type signatures for every component, keyed by '<name>_type'.
  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  # Extract each component and check it against its expected type. The order
  # determines which failure is reported first.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  _check_type_equal(zero.type_signature, type_info['zero_type'])
  _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
  _check_type_equal(merge.type_signature, type_info['merge_type'])
  _check_type_equal(report.type_signature, type_info['report_type'])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  # Arguments are positional; their order must match the CanonicalForm
  # constructor.
  return canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(zero),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(bitwidth),
      computation_wrapper_instances.building_block_to_computation(update))
def _extract_update(after_aggregate):
  """Extracts `update` from `after_aggregate`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` only. As a result, this function
  does not assert that `after_aggregate` has the expected structure, the
  caller is expected to perform these checks before calling this function.

  The returned computation takes a single SERVER-placed argument of packed
  type `<s1, <s3, s4>>`. It unpacks that argument to `<s1, s3, s4>` and feeds
  it to the s6 -> s7 function derived from `after_aggregate`.

  Args:
    after_aggregate: The second result of splitting `after_broadcast` on
      aggregate intrinsics.

  Returns:
    `update` as specified by `canonical_form.CanonicalForm`, an instance of
    `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If we extract an AST of the
      wrong type.
  """
  # Keep outputs 0 and 1 of `after_aggregate` (s7) and zip them into a single
  # federated value.
  s7_elements_in_after_aggregate_result = [0, 1]
  s7_output_extracted = transformations.select_output_from_lambda(
      after_aggregate, s7_elements_in_after_aggregate_result)
  s7_output_zipped = building_blocks.Lambda(
      s7_output_extracted.parameter_name, s7_output_extracted.parameter_type,
      building_block_factory.create_federated_zip(s7_output_extracted.result))
  # Re-express the zipped function over the three selected parameter
  # elements (s6).
  s6_elements_in_after_aggregate_parameter = [[0, 0, 0], [1, 0], [1, 1]]
  s6_to_s7_computation = (
      transformations.zip_selection_as_argument_to_lower_level_lambda(
          s7_output_zipped,
          s6_elements_in_after_aggregate_parameter).result.function)

  # TODO(b/148942011): The transformation
  # `zip_selection_as_argument_to_lower_level_lambda` does not support selecting
  # from nested structures, therefore we need to pack the type signature
  # `<s1, s3, s4>` as `<s1, <s3, s4>>`.
  # Names are drawn from a generator seeded with `s6_to_s7_computation`; the
  # order of the two `next()` calls below determines which name each
  # reference gets.
  name_generator = building_block_factory.unique_name_generator(
      s6_to_s7_computation)

  # `pack_fn` maps a packed `<s1, <s3, s4>>` value back to flat `<s1, s3, s4>`.
  pack_ref_name = next(name_generator)
  pack_ref_type = computation_types.StructType([
      s6_to_s7_computation.parameter_type.member[0],
      computation_types.StructType([
          s6_to_s7_computation.parameter_type.member[1],
          s6_to_s7_computation.parameter_type.member[2],
      ]),
  ])
  pack_ref = building_blocks.Reference(pack_ref_name, pack_ref_type)
  sel_s1 = building_blocks.Selection(pack_ref, index=0)
  sel = building_blocks.Selection(pack_ref, index=1)
  sel_s3 = building_blocks.Selection(sel, index=0)
  sel_s4 = building_blocks.Selection(sel, index=1)
  result = building_blocks.Struct([sel_s1, sel_s3, sel_s4])
  pack_fn = building_blocks.Lambda(pack_ref.name, pack_ref.type_signature,
                                   result)
  # Outer lambda: take the packed value placed at SERVER, unpack it with a
  # federated map/apply of `pack_fn`, and call the s6 -> s7 function on it.
  ref_name = next(name_generator)
  ref_type = computation_types.FederatedType(pack_ref_type, placements.SERVER)
  ref = building_blocks.Reference(ref_name, ref_type)
  unpacked_args = building_block_factory.create_federated_map_or_apply(
      pack_fn, ref)
  call = building_blocks.Call(s6_to_s7_computation, unpacked_args)
  fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
  return transformations.consolidate_and_extract_local_processing(fn)
 def test_reduces_lambda_returning_empty_tuple_to_tf(self):
   """A lambda returning an empty struct extracts to a CompiledComputation."""
   lam = building_blocks.Lambda('x', tf.int32, building_blocks.Struct([]))
   extracted = transformations.consolidate_and_extract_local_processing(
       lam, DEFAULT_GRAPPLER_CONFIG)
   self.assertIsInstance(extracted, building_blocks.CompiledComputation)
 def test_raises_on_none(self):
   """Passing `None` instead of a building block raises TypeError."""
   extract = transformations.consolidate_and_extract_local_processing
   with self.assertRaises(TypeError):
     extract(None, DEFAULT_GRAPPLER_CONFIG)
Example #20
0
def get_canonical_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: Optional[
        tf.compat.v1.ConfigProto] = _GRAPPLER_DEFAULT_CONFIG):
    """Compiles an iterative process into a `tff.backends.mapreduce.CanonicalForm`.

    The `initialize` and `next` computations of `ip` are deserialized into
    building blocks, their intrinsics are replaced with bodies, and `next` is
    split around its `federated_broadcast` and its `federated_aggregate` /
    `federated_secure_sum` calls. Each resulting piece is consolidated into
    local TensorFlow and its type signature is checked against the type
    expected for that piece of the canonical form.

    Args:
      ip: A `tff.templates.IterativeProcess` that is compatible with canonical
        form. Iterative processes are only compatible if:
        * `initialize_fn` returns a single federated value placed at `SERVER`;
        * `next` takes exactly two arguments, the first of which is the state
          value placed at `SERVER`;
        * `next` returns exactly two values.
      grappler_config: An optional `tf.compat.v1.ConfigProto` that configures
        Grappler optimization of the TensorFlow graphs backing the resulting
        `tff.backends.mapreduce.CanonicalForm`. It is merged on top of a set of
        defaults that configure Grappler aggressively. If `None`, Grappler is
        bypassed.

    Returns:
      A `tff.backends.mapreduce.CanonicalForm` equivalent to `ip`.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If `next` contains neither a `federated_aggregate` nor a
        `federated_secure_sum`.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    if grappler_config is not None:
        py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
        # Layer the caller's settings over the defaults (proto MergeFrom).
        merged_config = tf.compat.v1.ConfigProto()
        merged_config.CopyFrom(_GRAPPLER_DEFAULT_CONFIG)
        merged_config.MergeFrom(grappler_config)
        grappler_config = merged_config

    from_proto = building_blocks.ComputationBuildingBlock.from_proto
    # pylint: disable=protected-access
    initialize_comp = from_proto(ip.initialize._computation_proto)
    next_comp = from_proto(ip.next._computation_proto)
    # pylint: enable=protected-access
    _check_iterative_process_compatible_with_canonical_form(
        initialize_comp, next_comp)

    initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
    next_comp = _replace_lambda_body_with_call_dominant_form(
        _replace_intrinsics_with_bodies(next_comp))

    for comp in (initialize_comp, next_comp):
        tree_analysis.check_contains_only_reducible_intrinsics(comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
    aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
    secure_sum_uri = intrinsic_defs.FEDERATED_SECURE_SUM.uri

    if tree_analysis.contains_called_intrinsic(next_comp, broadcast_uri):
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [broadcast_uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    has_aggregate = tree_analysis.contains_called_intrinsic(
        next_comp, aggregate_uri)
    has_secure_sum = tree_analysis.contains_called_intrinsic(
        next_comp, secure_sum_uri)
    if not has_aggregate and not has_secure_sum:
        raise ValueError(
            'Expected an `tff.templates.IterativeProcess` containing at least one '
            '`federated_aggregate` or `federated_secure_sum`, found none.')

    # From here on at least one of the two aggregation intrinsics is present.
    if not has_aggregate:
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_aggregate(
                after_broadcast))
    elif not has_secure_sum:
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_secure_sum(
                after_broadcast))
    else:
        before_aggregate, after_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                after_broadcast, [aggregate_uri, secure_sum_uri]))

    type_info = _get_type_info(initialize_comp, before_broadcast,
                               after_broadcast, before_aggregate,
                               after_aggregate)

    def _checked(block, type_key):
        # Verify each extracted piece right after it is produced.
        _check_type_equal(block.type_signature, type_info[type_key])
        return block

    initialize = _checked(
        transformations.consolidate_and_extract_local_processing(
            initialize_comp, grappler_config), 'initialize_type')
    prepare = _checked(
        _extract_prepare(before_broadcast, grappler_config), 'prepare_type')
    work = _checked(
        _extract_work(before_aggregate, grappler_config), 'work_type')

    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate, grappler_config)
    for block, type_key in ((zero, 'zero_type'),
                            (accumulate, 'accumulate_type'),
                            (merge, 'merge_type'),
                            (report, 'report_type')):
        _checked(block, type_key)

    bitwidth = _checked(
        _extract_federated_secure_sum_functions(before_aggregate,
                                                grappler_config),
        'bitwidth_type')
    update = _checked(
        _extract_update(after_aggregate, grappler_config), 'update_type')

    # `next` is documented to take exactly two (named) parameters.
    server_state_label, client_data_label = (
        label for label, _ in structure.iter_elements(
            ip.next.type_signature.parameter))

    to_computation = computation_wrapper_instances.building_block_to_computation
    return canonical_form.CanonicalForm(
        *(to_computation(block) for block in (initialize, prepare, work, zero,
                                               accumulate, merge, report,
                                               bitwidth, update)),
        server_state_label=server_state_label,
        client_data_label=client_data_label)
Example #21
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Compiles an iterative process into a `tff.backends.mapreduce.CanonicalForm`.

    The `initialize` and `next` computations of `iterative_process` are
    deserialized into building blocks, their intrinsics are replaced with
    bodies, and `next` is split around its `federated_broadcast` and
    `federated_aggregate` calls. Each resulting piece is consolidated into
    local TensorFlow and its type signature is checked against the type
    expected for that piece of the canonical form.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      A `tff.backends.mapreduce.CanonicalForm` equivalent to
      `iterative_process`.

    Raises:
      TypeError: If the arguments are of the wrong types.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    from_proto = building_blocks.ComputationBuildingBlock.from_proto
    # pylint: disable=protected-access
    initialize_comp = from_proto(iterative_process.initialize._computation_proto)
    next_comp = from_proto(iterative_process.next._computation_proto)
    # pylint: enable=protected-access
    _check_iterative_process_compatible_with_canonical_form(
        initialize_comp, next_comp)

    # A two-element `next` result is rewritten by
    # `_create_next_with_fake_client_output` (presumably adding a placeholder
    # client output; see that helper).
    if len(next_comp.type_signature.result) == 2:
        next_comp = _create_next_with_fake_client_output(next_comp)

    initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
    next_comp = _replace_intrinsics_with_bodies(next_comp)
    for comp in (initialize_comp, next_comp):
        tree_analysis.check_intrinsics_whitelisted_for_reduction(comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri
    if tree_analysis.contains_called_intrinsic(next_comp, broadcast_uri):
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [broadcast_uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

    type_info = _get_type_info(initialize_comp, before_broadcast,
                               after_broadcast, before_aggregate,
                               after_aggregate)

    def _checked(block, type_key):
        # Verify each extracted piece right after it is produced.
        _check_type_equal(block.type_signature, type_info[type_key])
        return block

    initialize = _checked(
        transformations.consolidate_and_extract_local_processing(
            initialize_comp), 'initialize_type')
    prepare = _checked(_extract_prepare(before_broadcast), 'prepare_type')
    work = _checked(
        _extract_work(before_aggregate, after_aggregate), 'work_type')

    zero, accumulate, merge, report = _extract_aggregate_functions(
        before_aggregate)
    for block, type_key in ((zero, 'zero_type'),
                            (accumulate, 'accumulate_type'),
                            (merge, 'merge_type'),
                            (report, 'report_type')):
        _checked(block, type_key)

    update = _checked(_extract_update(after_aggregate), 'update_type')

    to_computation = computation_wrapper_instances.building_block_to_computation
    return canonical_form.CanonicalForm(
        *(to_computation(block) for block in (initialize, prepare, work, zero,
                                               accumulate, merge, report,
                                               update)))
 def test_reduces_unplaced_lambda_leaving_type_signature_alone(self):
   """Reducing an unplaced `int32` identity lambda keeps its type signature."""
   identity = building_blocks.Lambda(
       'x', tf.int32, building_blocks.Reference('x', tf.int32))
   reduced = transformations.consolidate_and_extract_local_processing(identity)
   self.assertIsInstance(reduced, building_blocks.CompiledComputation)
   self.assertEqual(reduced.type_signature, identity.type_signature)
 def test_reduces_lambda_returning_empty_tuple_to_tf(self):
   """An `int32`-parameter lambda with an empty-tuple body reduces to TF."""
   self.skipTest('Depends on a lower level fix, currently in review.')
   empty_body_lambda = building_blocks.Lambda('x', tf.int32,
                                              building_blocks.Tuple([]))
   reduced = transformations.consolidate_and_extract_local_processing(
       empty_body_lambda)
   self.assertIsInstance(reduced, building_blocks.CompiledComputation)