コード例 #1
0
def get_broadcast_form_for_computation(
    comp: computation_base.Computation,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.BroadcastForm:
    """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation.

    Args:
      comp: An instance of `tff.Computation` that is compatible with broadcast
        form. Its parameter must have exactly two elements (server data, then
        client data); their names become the `server_data_label` and
        `client_data_label` of the result. The computation must return a
        single value placed at clients and must not contain any aggregations.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization of the TensorFlow graphs backing the
        resulting `tff.backends.mapreduce.BroadcastForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If `grappler_config` has
        `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
        is bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the
      provided `tff.Computation`.

    Raises:
      TypeError: If `comp` or `grappler_config` is of the wrong Python type.
      ValueError: If `comp` contains any aggregations. Errors raised by the
        signature-compatibility and reducible-intrinsic checks are propagated.
    """
    py_typecheck.check_type(comp, computation_base.Computation)
    _check_function_signature_compatible_with_broadcast_form(
        comp.type_signature)
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    # Lower the computation: inline intrinsic bodies, then normalize the
    # lambda body so it can be split around the broadcast below.
    bb = comp.to_building_block()
    bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(bb)
    bb = _replace_lambda_body_with_call_dominant_form(bb)

    tree_analysis.check_contains_only_reducible_intrinsics(bb)
    aggregations = tree_analysis.find_aggregations_in_tree(bb)
    if aggregations:
        # Every fragment that contains a placeholder must be an f-string,
        # otherwise the aggregation list is printed as a literal "{...}".
        raise ValueError(
            '`get_broadcast_form_for_computation` called with computation '
            f'containing {len(aggregations)} aggregations, but broadcast form '
            'does not allow aggregation. Full list of aggregations:\n'
            f'{aggregations}')

    before_broadcast, after_broadcast = _split_ast_on_broadcast(bb)
    compute_server_context = _extract_compute_server_context(
        before_broadcast, grappler_config)
    client_processing = _extract_client_processing(after_broadcast,
                                                   grappler_config)

    # Wrap both extracted building blocks as `tff.Computation`s.
    compute_server_context, client_processing = (
        computation_wrapper_instances.building_block_to_computation(block)
        for block in (compute_server_context, client_processing))

    comp_param_names = structure.name_list_with_nones(
        comp.type_signature.parameter)
    server_data_label, client_data_label = comp_param_names
    return forms.BroadcastForm(compute_server_context,
                               client_processing,
                               server_data_label=server_data_label,
                               client_data_label=client_data_label)
コード例 #2
0
 def test_passes_with_federated_map(self):
   # A `federated_map` intrinsic that maps an int32 -> float32 function over
   # client-placed int32 values must pass the reducibility check unchanged.
   mapper_type = computation_types.FunctionType(tf.int32, tf.float32)
   arg_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
   result_type = computation_types.FederatedType(tf.float32,
                                                 placements.CLIENTS)
   map_type = computation_types.FunctionType([mapper_type, arg_type],
                                             result_type)
   map_intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                             map_type)
   tree_analysis.check_contains_only_reducible_intrinsics(map_intrinsic)
コード例 #3
0
  def test_raises_with_federated_mean(self):
    # A bare `federated_mean` intrinsic is expected to be rejected with a
    # ValueError whose message mentions the offending intrinsic.
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
    mean_intrinsic = building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_MEAN.uri,
        computation_types.FunctionType(clients_int, server_int))
    expected_pattern = mean_intrinsic.compact_representation()

    with self.assertRaisesRegex(ValueError, expected_pattern):
      tree_analysis.check_contains_only_reducible_intrinsics(mean_intrinsic)
コード例 #4
0
def check_iterative_process_compatible_with_map_reduce_form(
        ip: iterative_process.IterativeProcess):
    """Tests compatibility with `tff.backends.mapreduce.MapReduceForm`.

    Note: the conditions here are specified in the documentation for
    `get_map_reduce_form_for_iterative_process`. Changes to this function
    should be propagated to that documentation.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` to check for
        compatibility with `tff.backends.mapreduce.MapReduceForm`.

    Returns:
      A tuple `(initialize_tree, next_tree)` of TFF-internal building blocks
      representing the validated and simplified `initialize` and `next`
      computations.

    Raises:
      TypeError: If the arguments are of the wrong types; if `initialize` is
        not a no-argument function returning a single federated value placed
        at server; or if `next` is not a function taking two arguments and
        returning two values.
      ValueError: Propagated from the `tree_analysis` checks, e.g. when the
        simplified computations contain intrinsics that cannot be reduced.
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    initialize_tree = ip.initialize.to_building_block()
    next_tree = ip.next.to_building_block()

    init_type = initialize_tree.type_signature
    _check_type_is_no_arg_fn(init_type, '`initialize`', TypeError)
    if (not init_type.result.is_federated()
            or init_type.result.placement != placements.SERVER):
        raise TypeError(
            'Expected `initialize` to return a single federated value '
            'placed at server (type `T@SERVER`), found return type:\n'
            f'{init_type.result}')

    next_type = next_tree.type_signature
    _check_type_is_fn(next_type, '`next`', TypeError)
    next_param = next_type.parameter
    # Guard against a missing parameter (e.g. a no-argument `next`) so that
    # callers get the documented TypeError rather than an AttributeError.
    if (next_param is None or not next_param.is_struct()
            or len(next_param) != 2):
        raise TypeError(
            'Expected `next` to take two arguments, found parameter '
            f'type:\n{next_param}')
    if not next_type.result.is_struct() or len(next_type.result) != 2:
        raise TypeError('Expected `next` to return two values, found result '
                        f'type:\n{next_type.result}')

    # Inline intrinsic bodies; `next` is additionally normalized so the
    # subsequent structural checks can analyze it.
    initialize_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(
        initialize_tree)
    next_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(
        next_tree)
    next_tree = _replace_lambda_body_with_call_dominant_form(next_tree)

    tree_analysis.check_contains_only_reducible_intrinsics(initialize_tree)
    tree_analysis.check_contains_only_reducible_intrinsics(next_tree)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_tree)

    return initialize_tree, next_tree
コード例 #5
0
  def test_generic_plus_reduces(self):
    # Reducing a bare `generic_plus` must report a modification, keep the type
    # signature identical, remove every `generic_plus` occurrence, and leave
    # only reducible intrinsics behind.
    plus_uri = intrinsic_defs.GENERIC_PLUS.uri
    plus_type = computation_types.FunctionType([tf.float32, tf.float32],
                                               tf.float32)
    original = building_blocks.Intrinsic(plus_uri, plus_type)

    before = _count_intrinsics(original, plus_uri)
    reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
        original)
    after = _count_intrinsics(reduced, plus_uri)

    self.assertTrue(modified)
    self.assert_types_identical(original.type_signature,
                                reduced.type_signature)
    self.assertGreater(before, 0)
    self.assertEqual(after, 0)
    tree_analysis.check_contains_only_reducible_intrinsics(reduced)
コード例 #6
0
    def test_generic_divide_reduces(self):
        # Replacing intrinsics with bodies must eliminate every
        # `generic_divide` occurrence, leave only reducible intrinsics, and
        # report that the tree was modified.
        divide_uri = intrinsic_defs.GENERIC_DIVIDE.uri
        stack = context_stack_impl.context_stack
        divide_type = computation_types.FunctionType(
            [tf.float32, tf.float32], tf.float32)
        original = building_blocks.Intrinsic(divide_uri, divide_type)

        before = _count_intrinsics(original, divide_uri)
        reduced, modified = (
            value_transformations.replace_intrinsics_with_bodies(
                original, stack))
        after = _count_intrinsics(reduced, divide_uri)

        self.assertGreater(before, 0)
        self.assertEqual(after, 0)
        tree_analysis.check_contains_only_reducible_intrinsics(reduced)
        self.assertTrue(modified)
コード例 #7
0
 def test_raises_on_none(self):
     # `None` is not a building block, so the checker must raise TypeError.
     self.assertRaises(
         TypeError, tree_analysis.check_contains_only_reducible_intrinsics,
         None)
コード例 #8
0
def get_canonical_form_for_iterative_process(ip):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `ip` into
  an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    ip: An instance of `tff.templates.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `ip.next` contains neither a `federated_aggregate` nor a
      `federated_secure_sum`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)

  # Rebuild building blocks from the serialized protos of both computations.
  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` around its broadcast, using a dedicated helper when `next`
  # performs no broadcast at all.
  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not (contains_federated_aggregate or contains_federated_secure_sum):
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  # Split the post-broadcast part around the aggregation(s). A forced
  # two-intrinsic split is only used when both kinds are present; the
  # single-kind cases go through dedicated helpers.
  if contains_federated_aggregate and contains_federated_secure_sum:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))
  elif contains_federated_secure_sum:
    assert not contains_federated_aggregate
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    assert contains_federated_aggregate and not contains_federated_secure_sum
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))

  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  # Extract each canonical-form component and check it against the type
  # expected by `type_info` immediately after extraction.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  _check_type_equal(zero.type_signature, type_info['zero_type'])
  _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
  _check_type_equal(merge.type_signature, type_info['merge_type'])
  _check_type_equal(report.type_signature, type_info['report_type'])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  # Positional order here must match the `CanonicalForm` constructor.
  components = (initialize, prepare, work, zero, accumulate, merge, report,
                bitwidth, update)
  return canonical_form.CanonicalForm(
      *(computation_wrapper_instances.building_block_to_computation(component)
        for component in components))
コード例 #9
0
def get_canonical_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: Optional[
        tf.compat.v1.ConfigProto] = _GRAPPLER_DEFAULT_CONFIG):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `ip` into
    an instance of `tff.backends.mapreduce.CanonicalForm`.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` that is compatible
        with canonical form. Iterative processes are only compatible if:
          - `initialize` returns a single federated value placed at `SERVER`.
          - `next` takes exactly two arguments. The first must be the state
            value placed at `SERVER`.
          - `next` returns exactly two values.
        The names of the two parameters of `next` become the
        `server_state_label` and `client_data_label` of the result.
      grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
        configure Grappler graph optimization of the TensorFlow graphs backing
        the resulting `tff.backends.mapreduce.CanonicalForm`. These options
        are combined with a set of defaults that aggressively configure
        Grappler. If `None`, Grappler is bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If `ip.next` contains neither a `federated_aggregate` nor a
        `federated_secure_sum`.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    if grappler_config is not None:
        py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
        # Layer the caller's options over a fresh copy of the defaults so that
        # neither the caller's proto nor the module default is mutated.
        overridden_grappler_config = tf.compat.v1.ConfigProto()
        overridden_grappler_config.CopyFrom(_GRAPPLER_DEFAULT_CONFIG)
        overridden_grappler_config.MergeFrom(grappler_config)
        grappler_config = overridden_grappler_config

    # Rebuild building blocks from the serialized protos of both computations.
    initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
        ip.initialize._computation_proto)  # pylint: disable=protected-access
    next_comp = building_blocks.ComputationBuildingBlock.from_proto(
        ip.next._computation_proto)  # pylint: disable=protected-access
    _check_iterative_process_compatible_with_canonical_form(
        initialize_comp, next_comp)

    initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
    next_comp = _replace_intrinsics_with_bodies(next_comp)
    next_comp = _replace_lambda_body_with_call_dominant_form(next_comp)

    tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
    tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    # Split `next` around its broadcast, using a dedicated helper when `next`
    # performs no broadcast at all.
    if tree_analysis.contains_called_intrinsic(
            next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):

        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
        next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
    contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
        next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
    if not (contains_federated_aggregate or contains_federated_secure_sum):
        raise ValueError(
            'Expected an `tff.templates.IterativeProcess` containing at least one '
            '`federated_aggregate` or `federated_secure_sum`, found none.')

    # Split the post-broadcast part around the aggregation(s). A forced
    # two-intrinsic split is only used when both kinds are present; the
    # single-kind cases go through dedicated helpers.
    if contains_federated_aggregate and contains_federated_secure_sum:
        before_aggregate, after_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                after_broadcast, [
                    intrinsic_defs.FEDERATED_AGGREGATE.uri,
                    intrinsic_defs.FEDERATED_SECURE_SUM.uri,
                ]))
    elif contains_federated_secure_sum:
        assert not contains_federated_aggregate
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_aggregate(
                after_broadcast))
    else:
        assert contains_federated_aggregate and not contains_federated_secure_sum
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_secure_sum(
                after_broadcast))

    type_info = _get_type_info(initialize_comp, before_broadcast,
                               after_broadcast, before_aggregate,
                               after_aggregate)

    # Extract each canonical-form component (passing the effective Grappler
    # config, possibly `None`) and check it against the type expected by
    # `type_info` immediately after extraction.
    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp, grappler_config)
    _check_type_equal(initialize.type_signature, type_info['initialize_type'])

    prepare = _extract_prepare(before_broadcast, grappler_config)
    _check_type_equal(prepare.type_signature, type_info['prepare_type'])

    work = _extract_work(before_aggregate, grappler_config)
    _check_type_equal(work.type_signature, type_info['work_type'])

    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate, grappler_config)
    _check_type_equal(zero.type_signature, type_info['zero_type'])
    _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
    _check_type_equal(merge.type_signature, type_info['merge_type'])
    _check_type_equal(report.type_signature, type_info['report_type'])

    bitwidth = _extract_federated_secure_sum_functions(before_aggregate,
                                                       grappler_config)
    _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

    update = _extract_update(after_aggregate, grappler_config)
    _check_type_equal(update.type_signature, type_info['update_type'])

    # The two parameter names of `next` label the server state and client
    # data of the resulting canonical form.
    next_parameter_names = [
        name for name, _ in structure.iter_elements(
            ip.next.type_signature.parameter)
    ]
    server_state_label, client_data_label = next_parameter_names
    return canonical_form.CanonicalForm(
        computation_wrapper_instances.building_block_to_computation(
            initialize),
        computation_wrapper_instances.building_block_to_computation(prepare),
        computation_wrapper_instances.building_block_to_computation(work),
        computation_wrapper_instances.building_block_to_computation(zero),
        computation_wrapper_instances.building_block_to_computation(
            accumulate),
        computation_wrapper_instances.building_block_to_computation(merge),
        computation_wrapper_instances.building_block_to_computation(report),
        computation_wrapper_instances.building_block_to_computation(bitwidth),
        computation_wrapper_instances.building_block_to_computation(update),
        server_state_label=server_state_label,
        client_data_label=client_data_label)