Exemplo n.º 1
0
def check_iterative_process_compatible_with_map_reduce_form(
        ip: iterative_process.IterativeProcess):
    """Tests compatibility with `tff.backends.mapreduce.MapReduceForm`.

    Note: the conditions here are specified in the documentation for
    `get_map_reduce_form_for_iterative_process`. Changes to this function
    should be propagated to that documentation.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` to check for
        compatibility with `tff.backends.mapreduce.MapReduceForm`.

    Returns:
      A tuple `(initialize_tree, next_tree)` of TFF-internal building blocks
      representing the validated and simplified `initialize` and `next`
      computations.

    Raises:
      TypeError: If `ip` is not an `IterativeProcess`, if `initialize` fails
        the no-argument-function check or does not return a single value of
        type `T@SERVER`, or if `next` does not take two arguments and return
        two values. Further errors may be raised by the `tree_analysis`
        checks on the reduced trees.
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    initialize_tree = ip.initialize.to_building_block()
    next_tree = ip.next.to_building_block()

    # `initialize` must be argument-free and produce a single `T@SERVER`.
    init_type = initialize_tree.type_signature
    _check_type_is_no_arg_fn(init_type, '`initialize`', TypeError)
    if (not init_type.result.is_federated()
            or init_type.result.placement != placements.SERVER):
        raise TypeError(
            'Expected `initialize` to return a single federated value '
            'placed at server (type `T@SERVER`), found return type:\n'
            f'{init_type.result}')

    # `next` must accept exactly two arguments and return exactly two values.
    next_type = next_tree.type_signature
    _check_type_is_fn(next_type, '`next`', TypeError)
    if not next_type.parameter.is_struct() or len(next_type.parameter) != 2:
        raise TypeError(
            'Expected `next` to take two arguments, found parameter '
            f'type:\n{next_type.parameter}')
    if not next_type.result.is_struct() or len(next_type.result) != 2:
        raise TypeError('Expected `next` to return two values, found result '
                        f'type:\n{next_type.result}')

    # Replace intrinsics with their bodies before verifying that only
    # reducible intrinsics remain in either tree.
    initialize_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(
        initialize_tree)
    next_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(
        next_tree)
    next_tree = _replace_lambda_body_with_call_dominant_form(next_tree)

    tree_analysis.check_contains_only_reducible_intrinsics(initialize_tree)
    tree_analysis.check_contains_only_reducible_intrinsics(next_tree)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_tree)

    return initialize_tree, next_tree
Exemplo n.º 2
0
 def test_federated_secure_select(self):
     """Secure select is reduced only by the secure-to-insecure reduction.

     The plain `replace_intrinsics_with_bodies` must leave the intrinsic in
     place. `replace_secure_intrinsics_with_insecure_bodies` must rewrite it
     into a tree that contains `federated_select`, and both reductions must
     preserve the type signature.
     """
     uri = intrinsic_defs.FEDERATED_SECURE_SELECT.uri
     comp = building_blocks.Intrinsic(
         uri,
         computation_types.FunctionType(
             [
                 computation_types.at_clients(tf.int32),  # client_keys
                 computation_types.at_server(tf.int32),  # max_key
                 computation_types.at_server(tf.float32),  # server_state
                 computation_types.FunctionType([tf.float32, tf.int32],
                                                tf.float32)  # select_fn
             ],
             computation_types.at_clients(
                 computation_types.SequenceType(tf.float32))))
     self.assertGreater(_count_intrinsics(comp, uri), 0)
     # First without secure intrinsics shouldn't modify anything.
     reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
         comp)
     self.assertFalse(modified)
     # Check the *reduced* tree: re-checking `comp` would be vacuous.
     self.assertGreater(_count_intrinsics(reduced, uri), 0)
     self.assert_types_identical(comp.type_signature,
                                 reduced.type_signature)
     # Now replace bodies including secure intrinsics.
     reduced, modified = intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(
         comp)
     self.assertTrue(modified)
     self.assert_types_identical(comp.type_signature,
                                 reduced.type_signature)
     self.assertGreater(
         _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SELECT.uri), 0)
Exemplo n.º 3
0
 def test_federated_secure_sum_bitwidth(self, value_dtype, bitwidth_type):
     """Secure bitwidth sum is reduced only by the secure-to-insecure pass.

     Args:
       value_dtype: Dtype of the client-placed value being summed.
       bitwidth_type: Type of the bitwidth argument of the intrinsic.
     """
     uri = intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH.uri
     comp = building_blocks.Intrinsic(
         uri,
         computation_types.FunctionType(
             parameter=[
                 computation_types.at_clients(value_dtype),
                 computation_types.to_type(bitwidth_type)
             ],
             result=computation_types.at_server(value_dtype)))
     # First without secure intrinsics shouldn't modify anything.
     reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
         comp)
     self.assertFalse(modified)
     # Check the *reduced* tree; `comp` itself is never mutated here.
     self.assertGreater(_count_intrinsics(reduced, uri), 0)
     self.assert_types_identical(comp.type_signature,
                                 reduced.type_signature)
     # Now replace bodies including secure intrinsics.
     reduced, modified = intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(
         comp)
     self.assertTrue(modified)
     self.assert_types_identical(comp.type_signature,
                                 reduced.type_signature)
     self.assertGreater(
         _count_intrinsics(reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri),
         0)
Exemplo n.º 4
0
    def test_federated_mean_reduces_to_aggregate(self):
        """`federated_mean` is fully replaced by a `federated_aggregate`."""
        uri = intrinsic_defs.FEDERATED_MEAN.uri

        @computations.federated_computation(
            computation_types.FederatedType(tf.float32,
                                            placement_literals.CLIENTS))
        def foo(x):
            return intrinsics.federated_mean(x)

        foo_building_block = building_blocks.ComputationBuildingBlock.from_proto(
            foo._computation_proto)
        count_means_before_reduction = _count_intrinsics(
            foo_building_block, uri)
        reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
            foo_building_block)
        # Computed once; the original recomputed this identical value twice.
        count_means_after_reduction = _count_intrinsics(reduced, uri)
        count_aggregations = _count_intrinsics(
            reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri)
        self.assertTrue(modified)
        self.assert_types_identical(foo_building_block.type_signature,
                                    reduced.type_signature)
        self.assertGreater(count_means_before_reduction, 0)
        self.assertEqual(count_means_after_reduction, 0)
        self.assertGreater(count_aggregations, 0)
Exemplo n.º 5
0
    def test_federated_reduce_not_reduced(self):
        # Decomposing a federated_reduce into a federated_aggregate is unsafe in
        # general: not every reduction function can be split into accumulate
        # and merge calls without writing fundamentally new logic, e.g.
        # 'return 1 if both arguments are 0'. So the reduce must survive.
        reduce_uri = intrinsic_defs.FEDERATED_REDUCE.uri

        @computations.tf_computation(tf.float32, tf.float32)
        def add(x, y):
            return x + y

        @computations.federated_computation(
            computation_types.FederatedType(tf.float32,
                                            placement_literals.CLIENTS))
        def foo(x):
            return intrinsics.federated_reduce(x, 0., add)

        original_tree = building_blocks.ComputationBuildingBlock.from_proto(
            foo._computation_proto)
        reduces_before = _count_intrinsics(original_tree, reduce_uri)

        reduced_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(
            original_tree)
        self.assert_types_identical(original_tree.type_signature,
                                    reduced_tree.type_signature)

        reduces_after = _count_intrinsics(reduced_tree, reduce_uri)
        self.assertGreater(reduces_before, 0)
        self.assertEqual(reduces_after, reduces_before)
Exemplo n.º 6
0
def get_broadcast_form_for_computation(
    comp: computation_base.Computation,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.BroadcastForm:
    """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation.

    Args:
      comp: An instance of `tff.Computation` that is compatible with broadcast
        form. Computations are only compatible if they take two arguments
        (server data followed by client data), return a single value placed
        at clients, and do not contain any aggregations.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization of the Tensorflow graphs backing the
        resulting `tff.backends.mapreduce.BroadcastForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If `grappler_config` has
        `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
        is bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the
      provided `tff.Computation`.

    Raises:
      TypeError: If `comp` or `grappler_config` is of the wrong type.
      ValueError: If `comp` contains any aggregations after intrinsic
        reduction.
    """
    py_typecheck.check_type(comp, computation_base.Computation)
    _check_function_signature_compatible_with_broadcast_form(
        comp.type_signature)
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    bb = comp.to_building_block()
    bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(bb)
    bb = _replace_lambda_body_with_call_dominant_form(bb)

    tree_analysis.check_contains_only_reducible_intrinsics(bb)
    aggregations = tree_analysis.find_aggregations_in_tree(bb)
    if aggregations:
        # Every piece that interpolates a value must be an f-string; the last
        # piece previously printed a literal `{aggregations}`.
        raise ValueError(
            '`get_broadcast_form_for_computation` called with computation '
            f'containing {len(aggregations)} aggregations, but broadcast form '
            f'does not allow aggregation. Full list of aggregations:\n{aggregations}'
        )

    before_broadcast, after_broadcast = _split_ast_on_broadcast(bb)
    compute_server_context = _extract_compute_server_context(
        before_broadcast, grappler_config)
    client_processing = _extract_client_processing(after_broadcast,
                                                   grappler_config)

    # Wrap both extracted building blocks as computations. A distinct loop name
    # avoids visually shadowing `bb`.
    compute_server_context, client_processing = (
        computation_wrapper_instances.building_block_to_computation(block)
        for block in (compute_server_context, client_processing))

    # The parameter must name exactly two entries: server data, client data.
    comp_param_names = structure.name_list_with_nones(
        comp.type_signature.parameter)
    server_data_label, client_data_label = comp_param_names
    return forms.BroadcastForm(compute_server_context,
                               client_processing,
                               server_data_label=server_data_label,
                               client_data_label=client_data_label)
Exemplo n.º 7
0
  def test_generic_plus_reduces(self):
    # A scalar generic_plus must disappear entirely after reduction, leaving
    # only reducible intrinsics behind.
    plus_uri = intrinsic_defs.GENERIC_PLUS.uri
    plus_type = computation_types.FunctionType([tf.float32, tf.float32],
                                               tf.float32)
    plus = building_blocks.Intrinsic(plus_uri, plus_type)

    pluses_before = _count_intrinsics(plus, plus_uri)
    result, changed = intrinsic_reductions.replace_intrinsics_with_bodies(plus)
    pluses_after = _count_intrinsics(result, plus_uri)

    self.assertTrue(changed)
    self.assert_types_identical(plus.type_signature, result.type_signature)
    self.assertGreater(pluses_before, 0)
    self.assertEqual(pluses_after, 0)
    tree_analysis.check_contains_only_reducible_intrinsics(result)
Exemplo n.º 8
0
  def test_federated_sum_reduces_to_aggregate(self):
    # federated_sum (CLIENTS -> SERVER) should be rewritten in terms of
    # federated_aggregate, with no federated_sum left over.
    sum_uri = intrinsic_defs.FEDERATED_SUM.uri
    sum_type = computation_types.FunctionType(
        computation_types.at_clients(tf.float32),
        computation_types.at_server(tf.float32))
    sum_intrinsic = building_blocks.Intrinsic(sum_uri, sum_type)

    sums_before = _count_intrinsics(sum_intrinsic, sum_uri)
    result, changed = intrinsic_reductions.replace_intrinsics_with_bodies(
        sum_intrinsic)
    sums_after = _count_intrinsics(result, sum_uri)
    aggregates = _count_intrinsics(result,
                                   intrinsic_defs.FEDERATED_AGGREGATE.uri)

    self.assertTrue(changed)
    self.assert_types_identical(sum_intrinsic.type_signature,
                                result.type_signature)
    self.assertGreater(sums_before, 0)
    self.assertEqual(sums_after, 0)
    self.assertGreater(aggregates, 0)
Exemplo n.º 9
0
 def test_raises_on_none(self):
     # Passing `None` instead of a building block must raise TypeError.
     self.assertRaises(TypeError,
                       intrinsic_reductions.replace_intrinsics_with_bodies,
                       None)