Beispiel #1
0
 def test_reduces_lambda_returning_empty_tuple_to_tf(self):
     """A lambda whose body is an empty struct is compiled to TensorFlow."""
     lam = building_blocks.Lambda('x', tf.int32, building_blocks.Struct([]))
     result = compiler.consolidate_and_extract_local_processing(
         lam, DEFAULT_GRAPPLER_CONFIG)
     self.assertIsInstance(result, building_blocks.CompiledComputation)
Beispiel #2
0
def _extract_prepare(before_broadcast, grappler_config):
    """Extracts `prepare` from `before_broadcast`.

    Only `get_map_reduce_form_for_iterative_process` is meant to call this
    helper, so it does not validate the structure of `before_broadcast`; the
    caller is responsible for performing those checks beforehand.

    Args:
      before_broadcast: The first result of splitting `next_bb` on
        `intrinsic_defs.FEDERATED_BROADCAST`.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `prepare` as specified by `forms.MapReduceForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
        type.
    """
    # Element 0 of the parameter of `before_broadcast` is the server state.
    state_index = 0
    prepare_comp = _as_function_of_single_subparameter(before_broadcast,
                                                       state_index)
    return compiler.consolidate_and_extract_local_processing(
        prepare_comp, grappler_config)
Beispiel #3
0
def _extract_compute_server_context(before_broadcast, grappler_config):
    """Extracts `compute_server_context` from `before_broadcast`.

    Args:
      before_broadcast: The computation to extract from. It is rewritten by
        `_as_function_of_single_subparameter` in terms of the element at index 0
        of its parameter.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `compute_server_context`, as produced by
      `compiler.consolidate_and_extract_local_processing`.
    """
    # Index 0 of `before_broadcast`'s parameter; `_extract_prepare` selects the
    # same index and names it the server state, so the name here matches.
    server_state_index_in_before_broadcast = 0
    compute_server_context = _as_function_of_single_subparameter(
        before_broadcast, server_state_index_in_before_broadcast)
    return compiler.consolidate_and_extract_local_processing(
        compute_server_context, grappler_config)
Beispiel #4
0
 def test_reduces_unplaced_lambda_leaving_type_signature_alone(self):
     """Compiling an unplaced identity lambda keeps its type signature."""
     identity = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     compiled = compiler.consolidate_and_extract_local_processing(
         identity, DEFAULT_GRAPPLER_CONFIG)
     self.assertIsInstance(compiled, building_blocks.CompiledComputation)
     self.assertEqual(compiled.type_signature, identity.type_signature)
Beispiel #5
0
def _compile_selected_output_as_tensorflow_function(
        comp: building_blocks.Lambda, path: building_block_factory.Path,
        grappler_config) -> building_blocks.CompiledComputation:
    """Compiles the function-valued output of `comp` at `path` to TensorFlow.

    Args:
      comp: The lambda whose output is selected from.
      path: Location, within the result of `comp`, of the output to compile.
      grappler_config: Grappler configuration forwarded to the compiler.

    Returns:
      The result of `compiler.consolidate_and_extract_local_processing` applied
      to the `.result` of the selected output.
    """
    selection = building_block_factory.select_output_from_lambda(comp, path)
    return compiler.consolidate_and_extract_local_processing(
        selection.result, grappler_config)
Beispiel #6
0
def _compile_selected_output_to_no_argument_tensorflow(
        comp: building_blocks.Lambda, path: building_block_factory.Path,
        grappler_config) -> building_blocks.CompiledComputation:
    """Compiles the value output of `comp` at `path` as no-argument TensorFlow.

    The selected value does not depend on `comp`'s parameter, so it is wrapped
    in a parameterless lambda before compilation.

    Args:
      comp: The lambda whose output is selected from.
      path: Location, within the result of `comp`, of the value to compile.
      grappler_config: Grappler configuration forwarded to the compiler.

    Returns:
      The compiled no-argument computation.
    """
    selected_value = building_block_factory.select_output_from_lambda(
        comp, path).result
    no_arg_lambda = building_blocks.Lambda(None, None, selected_value)
    return compiler.consolidate_and_extract_local_processing(
        no_arg_lambda, grappler_config)
Beispiel #7
0
 def test_reduces_unplaced_lambda_to_equivalent_tf(self):
     """The extracted TensorFlow agrees with the source identity lambda."""
     identity = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     compiled = compiler.consolidate_and_extract_local_processing(
         identity, DEFAULT_GRAPPLER_CONFIG)
     to_concrete = computation_impl.ConcreteComputation.from_building_block
     run_compiled = to_concrete(compiled)
     run_identity = to_concrete(identity)
     for value in range(10):
         self.assertEqual(run_compiled(value), run_identity(value))
Beispiel #8
0
 def test_reduces_federated_identity_to_member_identity(self):
     """A federated identity lambda reduces to the unplaced member identity."""
     fed_int_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     param_ref = building_blocks.Reference('x', fed_int_type)
     lam = building_blocks.Lambda('x', fed_int_type, param_ref)
     compiled = compiler.consolidate_and_extract_local_processing(
         lam, DEFAULT_GRAPPLER_CONFIG)
     self.assertIsInstance(compiled, building_blocks.CompiledComputation)
     member = fed_int_type.member
     expected_type = computation_types.FunctionType(member, member)
     self.assertEqual(compiled.type_signature, expected_type)
Beispiel #9
0
    def test_already_reduced_case(self):
        """Compiling an already-reduced `initialize` yields TensorFlow."""
        example_form = mapreduce_test_utils.get_temperature_sensor_example()
        process = form_utils.get_iterative_process_for_map_reduce_form(
            example_form)
        comp = process.initialize.to_building_block()

        result = compiler.consolidate_and_extract_local_processing(
            comp, DEFAULT_GRAPPLER_CONFIG)

        self.assertIsInstance(result, building_blocks.CompiledComputation)
        proto = result.proto
        self.assertIsInstance(proto, computation_pb2.Computation)
        self.assertEqual(proto.WhichOneof('computation'), 'tensorflow')
Beispiel #10
0
 def test_reduces_federated_value_at_clients_to_equivalent_noarg_function(
         self):
     """A no-arg lambda returning a CLIENTS-placed zero compiles to TF."""
     scalar_int = computation_types.TensorType(tf.int32, shape=[])
     zero = building_block_factory.create_tensorflow_constant(scalar_int, 0)
     zero_at_clients = building_block_factory.create_federated_value(
         zero, placements.CLIENTS)
     no_arg_fn = building_blocks.Lambda(None, None, zero_at_clients)
     compiled = compiler.consolidate_and_extract_local_processing(
         no_arg_fn, DEFAULT_GRAPPLER_CONFIG)
     runnable = computation_impl.ConcreteComputation.from_building_block(
         compiled)
     self.assertEqual(runnable(), 0)
Beispiel #11
0
def _extract_client_processing(after_broadcast, grappler_config):
    """Extracts `client_processing` from `after_broadcast`.

    Args:
      after_broadcast: The computation to extract from; the context sent from
        the server and the client data are selected from its parameter.
      grappler_config: Grappler configuration forwarded to the compiler.

    Returns:
      The result of `compiler.consolidate_and_extract_local_processing` on the
      extracted `client_processing` computation.
    """
    # NOTE: this parameter order differs from `work`, which takes
    # `(data, params)` instead of the `(params, data)` order used by the
    # iterative process / computation. Here the input computation's
    # `(params, data)` order is kept.
    subparameter_paths = [
        (1, ),  # context from the server
        (0, 1),  # client data
    ]
    client_processing = _as_function_of_some_federated_subparameters(
        after_broadcast, subparameter_paths)
    return compiler.consolidate_and_extract_local_processing(
        client_processing, grappler_config)
Beispiel #12
0
def _extract_work(before_aggregate, grappler_config):
    """Extracts `work` from `before_aggregate`.

    Only `get_map_reduce_form_for_iterative_process` is meant to call this
    helper, so it does not validate the structure of `before_aggregate`; the
    caller is responsible for performing those checks beforehand.

    Args:
      before_aggregate: The first result of splitting `after_broadcast` on
        aggregate intrinsics.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `work` as specified by `forms.MapReduceForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
        type.
    """
    # Paths of `work`'s inputs inside the parameter of `before_aggregate`.
    input_paths = [
        ('original_arg', 1),  # client data
        ('federated_broadcast_result', ),
    ]
    work_to_before_aggregate = _as_function_of_some_federated_subparameters(
        before_aggregate, input_paths)

    # Paths of `work`'s outputs inside the result of `before_aggregate`: the
    # value argument of each aggregation intrinsic.
    output_paths = [
        ('federated_aggregate_param', 0),
        ('federated_secure_sum_bitwidth_param', 0),
        ('federated_secure_sum_param', 0),
        ('federated_secure_modular_sum_param', 0),
    ]
    selected = building_block_factory.select_output_from_lambda(
        work_to_before_aggregate, output_paths)
    # Combine the separately-placed outputs into one federated value.
    zipped_result = building_block_factory.create_federated_zip(
        selected.result)
    work = building_blocks.Lambda(selected.parameter_name,
                                  selected.parameter_type, zipped_result)
    return compiler.consolidate_and_extract_local_processing(
        work, grappler_config)
Beispiel #13
0
 def test_reduces_federated_apply_to_equivalent_function(self):
     """A federated map of the identity compiles to equivalent TensorFlow."""
     identity = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     clients_int = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     arg_ref = building_blocks.Reference('arg', clients_int)
     mapped = building_block_factory.create_federated_map_or_apply(
         identity, arg_ref)
     mapping_fn = building_blocks.Lambda('arg', clients_int, mapped)
     compiled = compiler.consolidate_and_extract_local_processing(
         mapping_fn, DEFAULT_GRAPPLER_CONFIG)
     self.assertIsInstance(compiled, building_blocks.CompiledComputation)
     to_concrete = computation_impl.ConcreteComputation.from_building_block
     run_compiled = to_concrete(compiled)
     run_identity = to_concrete(identity)
     for value in range(10):
         self.assertEqual(run_compiled(value), run_identity(value))
Beispiel #14
0
 def test_raises_on_none(self):
     """Passing `None` instead of a building block raises `TypeError`."""
     self.assertRaises(TypeError,
                       compiler.consolidate_and_extract_local_processing,
                       None, DEFAULT_GRAPPLER_CONFIG)
Beispiel #15
0
def get_map_reduce_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG,
    *,
    tff_internal_preprocessing: Optional[BuildingBlockFn] = None,
) -> forms.MapReduceForm:
    """Constructs `tff.backends.mapreduce.MapReduceForm` given iterative process.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` that is compatible
        with MapReduce form. Iterative processes are only compatible if:
        - `initialize_fn` returns a single federated value placed at `SERVER`;
        - `next` takes exactly two arguments, the first of which is the state
          value placed at `SERVER`; and
        - `next` returns exactly two values.
      grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
        configure Grappler graph optimization of the TensorFlow graphs backing
        the resulting `tff.backends.mapreduce.MapReduceForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If the input `grappler_config` has
        `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
        is bypassed.
      tff_internal_preprocessing: An optional function to transform the AST of
        the iterative process.

    Returns:
      An instance of `tff.backends.mapreduce.MapReduceForm` equivalent to the
      provided `tff.templates.IterativeProcess`.

    Raises:
      TypeError: If the arguments are of the wrong types.
      compiler.MapReduceFormCompilationError: If the compilation process fails.
    """
    # Validate both arguments' types up front, so an ill-typed
    # `grappler_config` is reported before the compatibility check runs.
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    initialize_bb, next_bb = (
        check_iterative_process_compatible_with_map_reduce_form(
            ip, tff_internal_preprocessing=tff_internal_preprocessing))
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    # The subparameter-extraction helpers used below need unique reference
    # names in the AST.
    next_bb, _ = tree_transformations.uniquify_reference_names(next_bb)
    before_broadcast, after_broadcast = _split_ast_on_broadcast(next_bb)
    before_aggregate, after_aggregate = _split_ast_on_aggregate(
        after_broadcast)

    initialize = compiler.consolidate_and_extract_local_processing(
        initialize_bb, grappler_config)
    prepare = _extract_prepare(before_broadcast, grappler_config)
    work = _extract_work(before_aggregate, grappler_config)
    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate, grappler_config)
    # Element 1 of each secure-sum intrinsic's parameter (its bitwidth, max
    # input or modulus) is compiled as a standalone no-argument computation.
    secure_sum_bitwidth = _compile_selected_output_to_no_argument_tensorflow(
        before_aggregate, ('federated_secure_sum_bitwidth_param', 1),
        grappler_config)
    secure_sum_max_input = _compile_selected_output_to_no_argument_tensorflow(
        before_aggregate, ('federated_secure_sum_param', 1), grappler_config)
    secure_sum_modulus = _compile_selected_output_to_no_argument_tensorflow(
        before_aggregate, ('federated_secure_modular_sum_param', 1),
        grappler_config)
    update = _extract_update(after_aggregate, grappler_config)

    next_parameter_names = structure.name_list_with_nones(
        ip.next.type_signature.parameter)
    server_state_label, client_data_label = next_parameter_names
    # Positional order must match the `forms.MapReduceForm` constructor.
    blocks = (initialize, prepare, work, zero, accumulate, merge, report,
              secure_sum_bitwidth, secure_sum_max_input, secure_sum_modulus,
              update)
    comps = (computation_impl.ConcreteComputation.from_building_block(bb)
             for bb in blocks)
    return forms.MapReduceForm(*comps,
                               server_state_label=server_state_label,
                               client_data_label=client_data_label)
Beispiel #16
0
def _extract_update(after_aggregate, grappler_config):
    """Extracts `update` from `after_aggregate`.

    Only `get_map_reduce_form_for_iterative_process` is meant to call this
    helper, so it does not validate the structure of `after_aggregate`; the
    caller is responsible for performing those checks beforehand.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        aggregate intrinsics.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization.

    Returns:
      `update` as specified by `forms.MapReduceForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      compiler.MapReduceFormCompilationError: If we extract an AST of the wrong
        type.
    """
    zipped = building_blocks.Lambda(
        after_aggregate.parameter_name, after_aggregate.parameter_type,
        building_block_factory.create_federated_zip(after_aggregate.result))
    # `create_federated_zip` doesn't produce unique reference names, but
    # `_as_function_of_some_federated_subparameters` requires them.
    zipped, _ = tree_transformations.uniquify_reference_names(zipped)

    # Paths of `update`'s inputs within the parameter of `after_aggregate`:
    # the server state first, then the result of every aggregation intrinsic.
    input_paths = (
        ('original_arg', 'original_arg', 0),
        ('intrinsic_results', 'federated_aggregate_result'),
        ('intrinsic_results', 'federated_secure_sum_bitwidth_result'),
        ('intrinsic_results', 'federated_secure_sum_result'),
        ('intrinsic_results', 'federated_secure_modular_sum_result'),
    )
    flat_update = _as_function_of_some_federated_subparameters(
        zipped, input_paths)

    # TODO(b/148942011): The transformation
    # `zip_selection_as_argument_to_lower_level_lambda` does not support
    # selecting from nested structures, therefore the input is transformed
    # from <server_state, <aggregation_results...>> into
    # <server_state, aggregation_results...>.
    # unpack = <v, <...>> -> <v, ...>
    names = building_block_factory.unique_name_generator(flat_update)
    unpack_name = next(names)
    flat_type = flat_update.parameter_type.member
    nested_type = computation_types.StructType([
        flat_type[0],
        computation_types.StructType(flat_type[1:]),
    ])
    nested_ref = building_blocks.Reference(unpack_name, nested_type)
    unpacked_elements = [building_blocks.Selection(nested_ref, index=0)]
    for position in range(len(flat_type) - 1):
        # A fresh selection of the inner struct is built for every element.
        inner_struct = building_blocks.Selection(nested_ref, index=1)
        unpacked_elements.append(
            building_blocks.Selection(inner_struct, index=position))
    unpack = building_blocks.Lambda(unpack_name, nested_type,
                                    building_blocks.Struct(unpacked_elements))

    # update = v -> flat_update(federated_map(unpack, v))
    update_param_name = next(names)
    update_param_type = computation_types.at_server(nested_type)
    update_param_ref = building_blocks.Reference(update_param_name,
                                                 update_param_type)
    unpacked_arg = building_block_factory.create_federated_map_or_apply(
        unpack, update_param_ref)
    update = building_blocks.Lambda(
        update_param_name, update_param_type,
        building_blocks.Call(flat_update, unpacked_arg))
    return compiler.consolidate_and_extract_local_processing(
        update, grappler_config)