コード例 #1
0
 def test_federated_secure_modular_sum(self, value_dtype, modulus_type):
     """Secure modular sum is lowered only by the secure-aware replacement.

     The plain `replace_intrinsics_with_bodies` pass must leave the intrinsic
     untouched, while `replace_secure_intrinsics_with_insecure_bodies` must
     replace it with a body that contains `federated_sum`.
     """
     uri = intrinsic_defs.FEDERATED_SECURE_MODULAR_SUM.uri
     comp = building_blocks.Intrinsic(
         uri,
         computation_types.FunctionType(
             parameter=[
                 computation_types.at_clients(value_dtype),
                 computation_types.to_type(modulus_type)
             ],
             result=computation_types.at_server(value_dtype)))
     # First without secure intrinsics shouldn't modify anything.
     reduced, modified = tree_transformations.replace_intrinsics_with_bodies(
         comp)
     self.assertFalse(modified)
     # Check the transformation *output*: asserting on `comp` would be
     # vacuous, since it was just constructed from `uri`.
     self.assertGreater(_count_intrinsics(reduced, uri), 0)
     type_test_utils.assert_types_identical(comp.type_signature,
                                            reduced.type_signature)
     # Now replace bodies including secure intrinsics.
     reduced, modified = tree_transformations.replace_secure_intrinsics_with_insecure_bodies(
         comp)
     self.assertTrue(modified)
     # Inserting tensorflow, as we do here, does not preserve python containers
     # currently, so only type equivalence (not identity) is checked.
     type_test_utils.assert_types_equivalent(comp.type_signature,
                                             reduced.type_signature)
     self.assertGreater(
         _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SUM.uri), 0)
コード例 #2
0
 def test_federated_secure_select(self):
     """Secure select is lowered to `federated_select` only by the secure pass.

     The plain `replace_intrinsics_with_bodies` pass must leave the intrinsic
     untouched, while `replace_secure_intrinsics_with_insecure_bodies` must
     replace it with a body that contains `federated_select`, preserving the
     exact type signature.
     """
     uri = intrinsic_defs.FEDERATED_SECURE_SELECT.uri
     comp = building_blocks.Intrinsic(
         uri,
         computation_types.FunctionType(
             [
                 computation_types.at_clients(tf.int32),  # client_keys
                 computation_types.at_server(tf.int32),  # max_key
                 computation_types.at_server(tf.float32),  # server_state
                 computation_types.FunctionType([tf.float32, tf.int32],
                                                tf.float32)  # select_fn
             ],
             computation_types.at_clients(
                 computation_types.SequenceType(tf.float32))))
     self.assertGreater(_count_intrinsics(comp, uri), 0)
     # First without secure intrinsics shouldn't modify anything.
     reduced, modified = tree_transformations.replace_intrinsics_with_bodies(
         comp)
     self.assertFalse(modified)
     # Check the transformation *output*; the input was already checked above.
     self.assertGreater(_count_intrinsics(reduced, uri), 0)
     type_test_utils.assert_types_identical(comp.type_signature,
                                            reduced.type_signature)
     # Now replace bodies including secure intrinsics.
     reduced, modified = tree_transformations.replace_secure_intrinsics_with_insecure_bodies(
         comp)
     self.assertTrue(modified)
     type_test_utils.assert_types_identical(comp.type_signature,
                                            reduced.type_signature)
     self.assertGreater(
         _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SELECT.uri), 0)
コード例 #3
0
 def test_federated_secure_sum_bitwidth(self, value_dtype, bitwidth_type):
     """Secure bitwidth sum is lowered to an aggregate only by the secure pass.

     The plain `replace_intrinsics_with_bodies` pass must leave the intrinsic
     untouched, while `replace_secure_intrinsics_with_insecure_bodies` must
     replace it with a body that contains `federated_aggregate`, preserving the
     exact type signature.
     """
     uri = intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH.uri
     comp = building_blocks.Intrinsic(
         uri,
         computation_types.FunctionType(
             parameter=[
                 computation_types.at_clients(value_dtype),
                 computation_types.to_type(bitwidth_type)
             ],
             result=computation_types.at_server(value_dtype)))
     # First without secure intrinsics shouldn't modify anything.
     reduced, modified = tree_transformations.replace_intrinsics_with_bodies(
         comp)
     self.assertFalse(modified)
     # Check the transformation *output*: asserting on `comp` would be
     # vacuous, since it was just constructed from `uri`.
     self.assertGreater(_count_intrinsics(reduced, uri), 0)
     type_test_utils.assert_types_identical(comp.type_signature,
                                            reduced.type_signature)
     # Now replace bodies including secure intrinsics.
     reduced, modified = tree_transformations.replace_secure_intrinsics_with_insecure_bodies(
         comp)
     self.assertTrue(modified)
     type_test_utils.assert_types_identical(comp.type_signature,
                                            reduced.type_signature)
     self.assertGreater(
         _count_intrinsics(reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri),
         0)
コード例 #4
0
    def test_generic_plus_reduces(self):
        """`generic_plus` is fully lowered away by intrinsic body replacement."""
        plus_uri = intrinsic_defs.GENERIC_PLUS.uri
        plus_type = computation_types.FunctionType([tf.float32, tf.float32],
                                                   tf.float32)
        intrinsic = building_blocks.Intrinsic(plus_uri, plus_type)

        occurrences_before = _count_intrinsics(intrinsic, plus_uri)
        lowered, was_modified = (
            tree_transformations.replace_intrinsics_with_bodies(intrinsic))
        occurrences_after = _count_intrinsics(lowered, plus_uri)

        # The rewrite must happen, keep the exact type, and remove every
        # `generic_plus` occurrence, leaving only reducible intrinsics behind.
        self.assertTrue(was_modified)
        type_test_utils.assert_types_identical(intrinsic.type_signature,
                                               lowered.type_signature)
        self.assertGreater(occurrences_before, 0)
        self.assertEqual(occurrences_after, 0)
        tree_analysis.check_contains_only_reducible_intrinsics(lowered)
コード例 #5
0
    def test_federated_sum_reduces_to_aggregate(self):
        """`federated_sum` is rewritten into a `federated_aggregate`."""
        sum_uri = intrinsic_defs.FEDERATED_SUM.uri
        aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
        sum_type = computation_types.FunctionType(
            computation_types.at_clients(tf.float32),
            computation_types.at_server(tf.float32))
        intrinsic = building_blocks.Intrinsic(sum_uri, sum_type)

        sums_before = _count_intrinsics(intrinsic, sum_uri)
        lowered, was_modified = (
            tree_transformations.replace_intrinsics_with_bodies(intrinsic))
        sums_after = _count_intrinsics(lowered, sum_uri)
        aggregates_after = _count_intrinsics(lowered, aggregate_uri)

        # Every sum must disappear, replaced by at least one aggregate, with
        # the type signature left exactly as it was.
        self.assertTrue(was_modified)
        type_test_utils.assert_types_identical(intrinsic.type_signature,
                                               lowered.type_signature)
        self.assertGreater(sums_before, 0)
        self.assertEqual(sums_after, 0)
        self.assertGreater(aggregates_after, 0)
コード例 #6
0
def desugar_and_transform_to_native(comp):
    """Lowers intrinsics to their bodies, then converts `comp` to native form.

    Args:
      comp: A computation exposing `to_building_block()`.

    Returns:
      The result of `transform_to_native_form` on the intrinsic-free
      computation, configured with an aggressive Grappler `ConfigProto`.
    """
    # Static Grappler is enabled on purpose. Function inlining is critical for
    # GPU support: without it, variant placeholders that receive datasets get
    # placed on GPUs, which have no kernels for datasets, and TF errors out.
    grappler_config = tf.compat.v1.ConfigProto()
    rewrite_options = grappler_config.graph_options.rewrite_options
    aggressive = rewrite_options.AGGRESSIVE
    for option_name in ('memory_optimization', 'constant_folding',
                        'arithmetic_optimization', 'loop_optimization',
                        'function_optimization'):
        setattr(rewrite_options, option_name, aggressive)

    # Desugaring intrinsics injects TF computations, and the native-form
    # transformation then attaches TF cache IDs to them, so these two steps
    # must run in exactly this order.
    desugared_bb, _ = tree_transformations.replace_intrinsics_with_bodies(
        comp.to_building_block())
    desugared_comp = computation_impl.ConcreteComputation.from_building_block(
        desugared_bb)
    return transform_to_native_form(desugared_comp,
                                    grappler_config=grappler_config)
コード例 #7
0
def get_broadcast_form_for_computation(
    comp: computation_base.Computation,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG,
    *,
    tff_internal_preprocessing: Optional[BuildingBlockFn] = None,
) -> forms.BroadcastForm:
    """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation.

    Args:
      comp: An instance of `tff.Computation` that is compatible with broadcast
        form. Compatible computations return a single value placed at clients
        and contain no aggregations. Their parameter is a two-element
        structure whose element names become the `server_data_label` and
        `client_data_label` of the resulting form.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization of the TensorFlow graphs backing the
        resulting `tff.backends.mapreduce.BroadcastForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If `grappler_config` has
        `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
        is bypassed.
      tff_internal_preprocessing: An optional function to transform the AST of
        the computation before intrinsics are lowered.

    Returns:
      An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the
      provided `tff.Computation`.

    Raises:
      ValueError: If the lowered computation contains any aggregations, or if
        its parameter does not unpack into exactly two names.
    """
    py_typecheck.check_type(comp, computation_base.Computation)
    _check_function_signature_compatible_with_broadcast_form(
        comp.type_signature)
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    bb = comp.to_building_block()
    if tff_internal_preprocessing:
        bb = tff_internal_preprocessing(bb)
    bb, _ = tree_transformations.replace_intrinsics_with_bodies(bb)
    bb = _replace_lambda_body_with_call_dominant_form(bb)

    tree_analysis.check_contains_only_reducible_intrinsics(bb)
    aggregations = tree_analysis.find_aggregations_in_tree(bb)
    if aggregations:
        # Every fragment is an f-string so the aggregation list is actually
        # interpolated into the message.
        raise ValueError(
            f'`get_broadcast_form_for_computation` called with computation '
            f'containing {len(aggregations)} aggregations, but broadcast form '
            f'does not allow aggregation. Full list of aggregations:\n'
            f'{aggregations}')

    before_broadcast, after_broadcast = _split_ast_on_broadcast(bb)
    compute_server_context_bb = _extract_compute_server_context(
        before_broadcast, grappler_config)
    client_processing_bb = _extract_client_processing(after_broadcast,
                                                      grappler_config)

    compute_server_context = (
        computation_impl.ConcreteComputation.from_building_block(
            compute_server_context_bb))
    client_processing = computation_impl.ConcreteComputation.from_building_block(
        client_processing_bb)

    comp_param_names = structure.name_list_with_nones(
        comp.type_signature.parameter)
    server_data_label, client_data_label = comp_param_names
    return forms.BroadcastForm(compute_server_context,
                               client_processing,
                               server_data_label=server_data_label,
                               client_data_label=client_data_label)
コード例 #8
0
def check_iterative_process_compatible_with_map_reduce_form(
    ip: iterative_process.IterativeProcess,
    *,
    tff_internal_preprocessing: Optional[BuildingBlockFn] = None,
) -> Tuple[building_blocks.ComputationBuildingBlock,
           building_blocks.ComputationBuildingBlock]:
    """Tests compatibility with `tff.backends.mapreduce.MapReduceForm`.

    Note: the conditions here are specified in the documentation for
      `get_map_reduce_form_for_iterative_process`. Changes to this function
      should be propagated to that documentation.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` to check for
        compatibility with `tff.backends.mapreduce.MapReduceForm`.
      tff_internal_preprocessing: An optional function to transform the AST of
        both the `initialize` and `next` computations before validation.

    Returns:
      A tuple `(initialize_tree, next_tree)` of TFF-internal building blocks
      representing the validated and simplified `initialize` and `next`
      computations. Both have intrinsics replaced with their bodies. `next`
      additionally has its lambda body converted to call-dominant form.

    Raises:
      TypeError: If the arguments are of the wrong types, if `initialize` does
        not return a single federated value placed at server, or if `next`
        does not take exactly two arguments and return exactly two values.
        The `tree_analysis` checks at the end may raise their own errors
        (not visible here).
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    initialize_tree = ip.initialize.to_building_block()
    next_tree = ip.next.to_building_block()
    if tff_internal_preprocessing:
        initialize_tree = tff_internal_preprocessing(initialize_tree)
        next_tree = tff_internal_preprocessing(next_tree)

    init_type = initialize_tree.type_signature
    _check_type_is_no_arg_fn(init_type, '`initialize`', TypeError)
    if (not init_type.result.is_federated()
            or init_type.result.placement != placements.SERVER):
        raise TypeError(
            'Expected `initialize` to return a single federated value '
            'placed at server (type `T@SERVER`), found return type:\n'
            f'{init_type.result}')

    next_type = next_tree.type_signature
    _check_type_is_fn(next_type, '`next`', TypeError)
    next_param = next_type.parameter
    # A function type with no argument carries a `None` parameter. Reject it
    # with the documented TypeError instead of crashing on `.is_struct()`.
    if (next_param is None or not next_param.is_struct()
            or len(next_param) != 2):
        raise TypeError(
            'Expected `next` to take two arguments, found parameter '
            f'type:\n{next_param}')
    if not next_type.result.is_struct() or len(next_type.result) != 2:
        raise TypeError('Expected `next` to return two values, found result '
                        f'type:\n{next_type.result}')

    initialize_tree, _ = tree_transformations.replace_intrinsics_with_bodies(
        initialize_tree)
    next_tree, _ = tree_transformations.replace_intrinsics_with_bodies(
        next_tree)
    next_tree = _replace_lambda_body_with_call_dominant_form(next_tree)

    tree_analysis.check_contains_only_reducible_intrinsics(initialize_tree)
    tree_analysis.check_contains_only_reducible_intrinsics(next_tree)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_tree)

    return initialize_tree, next_tree
コード例 #9
0
def compile_to_mergeable_comp_form(
    comp: computation_impl.ConcreteComputation
) -> mergeable_comp_execution_context.MergeableCompForm:
    """Compiles a computation with a single aggregation to `MergeableCompForm`.

    Compilation proceeds by splitting on the lone aggregation, and using the
    aggregation's internal functions to generate a semantically equivalent
    instance of `mergeable_comp_execution_context.MergeableCompForm`.

    The resulting form has three pieces:
      * `up_to_merge`: runs everything before the aggregation and performs the
        aggregation with an identity report, so its result is the merged
        (not yet reported) accumulator.
      * `merge`: the aggregation's own merge computation.
      * `after_merge`: applies the aggregation's real report function to the
        merged value and runs everything after the aggregation. Its result
        type is re-annotated to the original computation's result type.

    Args:
      comp: Instance of `computation_impl.ConcreteComputation` to compile.
        Assumed to be representable as a computation with a single
        aggregation in its body, so that for example two parallel
        aggregations are allowed, but multiple dependent aggregations are
        disallowed. Additionally assumed to be of functional type.

    Returns:
      A semantically equivalent instance of
      `mergeable_comp_execution_context.MergeableCompForm`.

    Raises:
      TypeError: If `comp` is not of functional TFF type. NOTE(review): no
        check is visible in this function; presumably raised by
        `_ensure_lambda` or a downstream helper. Confirm.
      ValueError: If `comp` cannot be represented as a computation with at most
        one aggregation in its body. NOTE(review): presumably raised by
        `check_aggregate_not_dependent_on_aggregate` or the force-align split.
        Confirm.
    """
    original_return_type = comp.type_signature.result
    building_block = comp.to_building_block()
    lam = _ensure_lambda(building_block)
    lowered_bb, _ = tree_transformations.replace_intrinsics_with_bodies(lam)

    # We transform the body of this computation to easily preserve the top-level
    # lambda required by force-aligning.
    call_dominant_body_bb = transformations.to_call_dominant(lowered_bb.result)
    call_dominant_bb = building_blocks.Lambda(lowered_bb.parameter_name,
                                              lowered_bb.parameter_type,
                                              call_dominant_body_bb)

    # This check should not throw false positives because we just ensured we are
    # in call-dominant form.
    tree_analysis.check_aggregate_not_dependent_on_aggregate(call_dominant_bb)

    # Split the lambda into the part computing the aggregation's arguments
    # (`before_agg`) and the part consuming its result (`after_agg`).
    before_agg, after_agg = transformations.force_align_and_split_by_intrinsics(
        call_dominant_bb,
        [building_block_factory.create_null_federated_aggregate()])

    # Construct a report function which accepts the result of merge.
    # Index 3 is presumably the merge fn, following the
    # (value, zero, accumulate, merge, report) argument order used below.
    merge_fn_type = before_agg.type_signature.result[
        'federated_aggregate_param'][3]
    # Identity report: lets `up_to_merge` return the merged accumulator as-is;
    # the real `report_comp` is applied later in `after_merge_computation`.
    identity_report = computation_impl.ConcreteComputation.from_building_block(
        building_block_factory.create_compiled_identity(merge_fn_type.result))

    zero_comp, accumulate_comp, merge_comp, report_comp = _extract_federated_aggregate_computations(
        before_agg)

    before_agg_callable = computation_impl.ConcreteComputation.from_building_block(
        before_agg)
    after_agg_callable = computation_impl.ConcreteComputation.from_building_block(
        after_agg)

    if before_agg.type_signature.parameter is not None:
        # TODO(b/147499373): If None-arguments were uniformly represented as empty
        # tuples, we would be able to avoid this (and related) ugly casing.

        @federated_computation.federated_computation(
            before_agg.type_signature.parameter)
        def up_to_merge_computation(arg):
            federated_aggregate_args = before_agg_callable(
                arg)['federated_aggregate_param']
            value_to_aggregate = federated_aggregate_args[0]
            zero = zero_comp()
            return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                                  accumulate_comp, merge_comp,
                                                  identity_report)

        # Receives the original top-level argument plus the server-placed
        # merge result, so the post-aggregation code can still read the input.
        @federated_computation.federated_computation(
            before_agg.type_signature.parameter,
            computation_types.at_server(identity_report.type_signature.result))
        def after_merge_computation(top_level_arg, merge_result):
            reported_result = intrinsics.federated_map(report_comp,
                                                       merge_result)
            return after_agg_callable(top_level_arg, [reported_result])

    else:

        # No-argument variant of the branch above.
        @federated_computation.federated_computation()
        def up_to_merge_computation():
            federated_aggregate_args = before_agg_callable(
            )['federated_aggregate_param']
            value_to_aggregate = federated_aggregate_args[0]
            zero = zero_comp()
            return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                                  accumulate_comp, merge_comp,
                                                  identity_report)

        # NOTE(review): `after_agg_callable` is called here with a doubly
        # nested list, unlike the `[reported_result]` argument in the branch
        # above. Presumably this matches `after_agg`'s parameter shape when
        # there is no top-level argument. Confirm.
        @federated_computation.federated_computation(
            computation_types.at_server(identity_report.type_signature.result))
        def after_merge_computation(merge_result):
            reported_result = intrinsics.federated_map(report_comp,
                                                       merge_result)
            return after_agg_callable([[reported_result]])

    # Re-annotate `after_merge` so its declared result type matches the
    # original computation's result type exactly.
    annotated_type_signature = computation_types.FunctionType(
        after_merge_computation.type_signature.parameter, original_return_type)
    after_merge_computation = computation_impl.ConcreteComputation.with_type(
        after_merge_computation, annotated_type_signature)

    return mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=up_to_merge_computation,
        merge=merge_comp,
        after_merge=after_merge_computation)
コード例 #10
0
 def test_raises_on_none(self):
     """Passing `None` instead of a building block raises `TypeError`."""
     self.assertRaises(TypeError,
                       tree_transformations.replace_intrinsics_with_bodies,
                       None)