Пример #1
0
 def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
   """A broadcast of an aggregate fails with a ValueError mentioning 'acc_param'."""
   broadcast_of_aggregate = building_block_factory.create_federated_broadcast(
       test_utils.create_dummy_called_federated_aggregate())
   with self.assertRaisesRegex(ValueError, 'acc_param'):
     tree_analysis.check_broadcast_not_dependent_on_aggregate(
         broadcast_of_aggregate)
Пример #2
0
 def test_finds_broadcast_dependent_on_aggregate(self):
   """Broadcasting the result of an aggregate is rejected with ValueError."""
   dummy_aggregate = test_utils.create_dummy_called_federated_aggregate()
   comp = building_block_factory.create_federated_broadcast(dummy_aggregate)
   self.assertRaises(
       ValueError,
       tree_analysis.check_broadcast_not_dependent_on_aggregate,
       comp)
Пример #3
0
 def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
     """The ValueError for a broadcast of an aggregate names its accumulate param."""
     accumulate_name, merge_name, report_name = (
         'accumulate_parameter', 'merge_parameter', 'report_parameter')
     aggregate = computation_test_utils.create_dummy_called_federated_aggregate(
         accumulate_name, merge_name, report_name)
     broadcast = building_block_factory.create_federated_broadcast(aggregate)
     with self.assertRaisesRegex(ValueError, accumulate_name):
         tree_analysis.check_broadcast_not_dependent_on_aggregate(broadcast)
Пример #4
0
def check_iterative_process_compatible_with_map_reduce_form(
        ip: iterative_process.IterativeProcess):
    """Tests compatibility with `tff.backends.mapreduce.MapReduceForm`.

    Note: the conditions here are specified in the documentation for
    `get_map_reduce_form_for_iterative_process`. Changes to this function should
    be propagated to that documentation.

    The checks performed are:
      * `initialize` takes no arguments and returns a single federated value
        placed at `SERVER`.
      * `next` is a function whose parameter is a struct of exactly two
        elements and whose result is a struct of exactly two elements.
      * After replacing intrinsics with their bodies (and rewriting the body of
        `next` into call-dominant form), both computations contain only
        reducible intrinsics, and no broadcast in `next` depends on an
        aggregate.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` to check for
        compatibility with `tff.backends.mapreduce.MapReduceForm`.

    Returns:
      A `(initialize_tree, next_tree)` tuple of TFF-internal building blocks
      representing the validated and simplified `initialize` and `next`
      computations.

    Raises:
      TypeError: If `ip` is of the wrong type, or if the type signatures of
        `initialize` or `next` do not have the shapes described above
        (including a `next` that takes no parameter at all).
      ValueError: If a broadcast in `next` depends on the result of an
        aggregate (raised by
        `tree_analysis.check_broadcast_not_dependent_on_aggregate`).
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    initialize_tree = ip.initialize.to_building_block()
    next_tree = ip.next.to_building_block()

    init_type = initialize_tree.type_signature
    _check_type_is_no_arg_fn(init_type, '`initialize`', TypeError)
    if (not init_type.result.is_federated()
            or init_type.result.placement != placements.SERVER):
        raise TypeError(
            'Expected `initialize` to return a single federated value '
            'placed at server (type `T@SERVER`), found return type:\n'
            f'{init_type.result}')

    next_type = next_tree.type_signature
    _check_type_is_fn(next_type, '`next`', TypeError)
    next_param = next_type.parameter
    # `parameter` may be `None` for a no-argument `next`; check for that first
    # so callers get the documented `TypeError` instead of an `AttributeError`
    # from calling `.is_struct()` on `None`.
    if (next_param is None or not next_param.is_struct()
            or len(next_param) != 2):
        raise TypeError(
            'Expected `next` to take two arguments, found parameter '
            f'type:\n{next_param}')
    if not next_type.result.is_struct() or len(next_type.result) != 2:
        raise TypeError('Expected `next` to return two values, found result '
                        f'type:\n{next_type.result}')

    # Inline intrinsic definitions so that the reducibility checks below see
    # only the primitive intrinsics they are built from.
    initialize_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(
        initialize_tree)
    next_tree, _ = intrinsic_reductions.replace_intrinsics_with_bodies(
        next_tree)
    next_tree = _replace_lambda_body_with_call_dominant_form(next_tree)

    tree_analysis.check_contains_only_reducible_intrinsics(initialize_tree)
    tree_analysis.check_contains_only_reducible_intrinsics(next_tree)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_tree)

    return initialize_tree, next_tree
Пример #5
0
 def test_does_not_find_aggregate_dependent_on_broadcast(self):
   """An aggregate whose input is a broadcast passes the dependency check.

   Only a broadcast that depends on an aggregate is checked for, so building
   the reverse dependency must not raise.
   """
   broadcast = test_utils.create_dummy_called_federated_broadcast()
   member_type = broadcast.type_signature.member
   zero = building_blocks.Data('zero', member_type)
   accumulate = building_blocks.Lambda(
       'accumulate_parameter', [member_type, member_type],
       building_blocks.Data('accumulate_result', member_type))
   merge = building_blocks.Lambda(
       'merge_parameter', [member_type, member_type],
       building_blocks.Data('merge_result', member_type))
   report = building_blocks.Lambda(
       'report_parameter', member_type,
       building_blocks.Data('report_result', member_type))
   aggregate = building_block_factory.create_federated_aggregate(
       broadcast, zero, accumulate, merge, report)
   tree_analysis.check_broadcast_not_dependent_on_aggregate(aggregate)
Пример #6
0
 def test_raises_on_none_comp(self):
     """Passing `None` instead of a building block raises TypeError."""
     self.assertRaises(
         TypeError, tree_analysis.check_broadcast_not_dependent_on_aggregate,
         None)
def get_canonical_form_for_iterative_process(ip):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `ip` into
  an instance of `tff.backends.mapreduce.CanonicalForm`.

  The `next` computation is split in two stages: first around its
  `federated_broadcast` (if any), then the post-broadcast part around its
  `federated_aggregate` and/or `federated_secure_sum` calls. The pieces are
  then extracted into the `initialize`, `prepare`, `work`, `zero`,
  `accumulate`, `merge`, `report`, `bitwidth` and `update` computations, each
  of which is checked against the types computed by `_get_type_info`.

  Args:
    ip: An instance of `tff.templates.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If `next` contains neither a `federated_aggregate` nor a
      `federated_secure_sum`, or if a broadcast in `next` depends on the
      result of an aggregate.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)

  initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)  # pylint: disable=protected-access
  _check_iterative_process_compatible_with_canonical_form(
      initialize_comp, next_comp)

  # Replace intrinsics with their bodies, then verify that only reducible
  # intrinsics remain and that no broadcast depends on an aggregate.
  initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
  next_comp = _replace_intrinsics_with_bodies(next_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
  tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
  tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Stage 1: split `next` around its broadcast; without a broadcast, a
  # trivial before/after pair is synthesized instead.
  if tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  else:
    before_broadcast, after_broadcast = (
        _create_before_and_after_broadcast_for_no_broadcast(next_comp))

  # Stage 2: split the post-broadcast part around the aggregation intrinsics.
  # At least one of `federated_aggregate` / `federated_secure_sum` must exist.
  contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
      next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
  if not (contains_federated_aggregate or contains_federated_secure_sum):
    raise ValueError(
        'Expected an `tff.templates.IterativeProcess` containing at least one '
        '`federated_aggregate` or `federated_secure_sum`, found none.')

  if contains_federated_aggregate and contains_federated_secure_sum:
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))
  elif contains_federated_secure_sum:
    assert not contains_federated_aggregate
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_aggregate(
            after_broadcast))
  else:
    assert contains_federated_aggregate and not contains_federated_secure_sum
    before_aggregate, after_aggregate = (
        _create_before_and_after_aggregate_for_no_federated_secure_sum(
            after_broadcast))

  # Expected types for every extracted computation; each extraction below is
  # immediately validated against the corresponding entry.
  type_info = _get_type_info(initialize_comp, before_broadcast, after_broadcast,
                             before_aggregate, after_aggregate)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  _check_type_equal(initialize.type_signature, type_info['initialize_type'])

  prepare = _extract_prepare(before_broadcast)
  _check_type_equal(prepare.type_signature, type_info['prepare_type'])

  work = _extract_work(before_aggregate)
  _check_type_equal(work.type_signature, type_info['work_type'])

  zero, accumulate, merge, report = _extract_federated_aggregate_functions(
      before_aggregate)
  _check_type_equal(zero.type_signature, type_info['zero_type'])
  _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
  _check_type_equal(merge.type_signature, type_info['merge_type'])
  _check_type_equal(report.type_signature, type_info['report_type'])

  bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
  _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

  update = _extract_update(after_aggregate)
  _check_type_equal(update.type_signature, type_info['update_type'])

  return canonical_form.CanonicalForm(
      computation_wrapper_instances.building_block_to_computation(initialize),
      computation_wrapper_instances.building_block_to_computation(prepare),
      computation_wrapper_instances.building_block_to_computation(work),
      computation_wrapper_instances.building_block_to_computation(zero),
      computation_wrapper_instances.building_block_to_computation(accumulate),
      computation_wrapper_instances.building_block_to_computation(merge),
      computation_wrapper_instances.building_block_to_computation(report),
      computation_wrapper_instances.building_block_to_computation(bitwidth),
      computation_wrapper_instances.building_block_to_computation(update))
Пример #8
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `iterative_process`
    into an instance of `tff.backends.mapreduce.CanonicalForm`.

    The `next` computation is split first around its `federated_broadcast`
    (if any) and then around its `federated_aggregate`. The resulting pieces
    are extracted and each is checked against the types computed by
    `_get_type_info`.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If a broadcast in `next` depends on the result of an
        aggregate.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.initialize._computation_proto)  # pylint: disable=protected-access
    next_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.next._computation_proto)  # pylint: disable=protected-access
    _check_iterative_process_compatible_with_canonical_form(
        initialize_comp, next_comp)

    # A two-result `next` is wrapped by `_create_next_with_fake_client_output`
    # (by its name, this supplies a placeholder client output).
    if len(next_comp.type_signature.result) == 2:
        next_comp = _create_next_with_fake_client_output(next_comp)

    # Replace intrinsics with their bodies, then verify that only allowed
    # intrinsics remain and that no broadcast depends on an aggregate.
    initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
    next_comp = _replace_intrinsics_with_bodies(next_comp)
    tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
    tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    # Split around the broadcast; without one, synthesize a trivial split.
    if tree_analysis.contains_called_intrinsic(
            next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

    # Expected types for every extracted computation; each extraction below is
    # immediately validated against the corresponding entry.
    type_info = _get_type_info(initialize_comp, before_broadcast,
                               after_broadcast, before_aggregate,
                               after_aggregate)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)
    _check_type_equal(initialize.type_signature, type_info['initialize_type'])

    prepare = _extract_prepare(before_broadcast)
    _check_type_equal(prepare.type_signature, type_info['prepare_type'])

    work = _extract_work(before_aggregate, after_aggregate)
    _check_type_equal(work.type_signature, type_info['work_type'])

    zero, accumulate, merge, report = _extract_aggregate_functions(
        before_aggregate)
    _check_type_equal(zero.type_signature, type_info['zero_type'])
    _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
    _check_type_equal(merge.type_signature, type_info['merge_type'])
    _check_type_equal(report.type_signature, type_info['report_type'])

    update = _extract_update(after_aggregate)
    _check_type_equal(update.type_signature, type_info['update_type'])

    return canonical_form.CanonicalForm(
        computation_wrapper_instances.building_block_to_computation(
            initialize),
        computation_wrapper_instances.building_block_to_computation(prepare),
        computation_wrapper_instances.building_block_to_computation(work),
        computation_wrapper_instances.building_block_to_computation(zero),
        computation_wrapper_instances.building_block_to_computation(
            accumulate),
        computation_wrapper_instances.building_block_to_computation(merge),
        computation_wrapper_instances.building_block_to_computation(report),
        computation_wrapper_instances.building_block_to_computation(update))
Пример #9
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `iterative_process`
    into an instance of `tff.backends.mapreduce.CanonicalForm`.

    The `next` computation is split first around its `federated_broadcast`
    (if any) and then around its `federated_aggregate`. The type signatures of
    the pieces are checked and packed step by step into `canonical_form_types`,
    which the `extract_*` helpers then use.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types, or if `next` does not
        both take and return `tff.NamedTupleType` values.
      ValueError: If a broadcast in `next` depends on the result of an
        aggregate.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

    next_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.next._computation_proto)  # pylint: disable=protected-access

    if not (isinstance(next_comp.type_signature.parameter,
                       computation_types.NamedTupleType)
            and isinstance(next_comp.type_signature.result,
                           computation_types.NamedTupleType)):
        raise TypeError(
            'Any IterativeProcess compatible with CanonicalForm must '
            'have a `next` function which takes and returns instances '
            'of `tff.NamedTupleType`; your next function takes '
            'parameters of type {} and returns results of type {}'.format(
                next_comp.type_signature.parameter,
                next_comp.type_signature.result))

    # A two-result `next` is wrapped by `_create_next_with_fake_client_output`
    # (by its name, this supplies a placeholder client output).
    if len(next_comp.type_signature.result) == 2:
        next_comp = _create_next_with_fake_client_output(next_comp)

    initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
    next_comp = replace_intrinsics_with_bodies(next_comp)

    tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
    tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    # Split around the broadcast; without one, synthesize a trivial split.
    if tree_analysis.contains_called_intrinsic(
            next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

    # Each stage checks its piece's type signature against what earlier stages
    # packed, and adds its own entries to the packed type information.
    init_info_packed = pack_initialize_comp_type_signature(
        initialize_comp.type_signature)

    next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                     init_info_packed)

    before_broadcast_info_packed = (
        check_and_pack_before_broadcast_type_signature(
            before_broadcast.type_signature, next_info_packed))

    before_aggregate_info_packed = (
        check_and_pack_before_aggregate_type_signature(
            before_aggregate.type_signature, before_broadcast_info_packed))

    canonical_form_types = check_and_pack_after_aggregate_type_signature(
        after_aggregate.type_signature, before_aggregate_info_packed)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)

    # Report the same type the check compares against (the `.member` of
    # `initialize_type`), so the error message matches the failed condition.
    expected_initialize_result = canonical_form_types['initialize_type'].member
    if initialize.type_signature.result != expected_initialize_result:
        raise transformations.CanonicalFormCompilationError(
            'Compilation of initialize has failed. Expected to extract a '
            '`building_blocks.CompiledComputation` of type {}, instead we extracted '
            'a {} of type {}.'.format(expected_initialize_result,
                                      type(initialize),
                                      initialize.type_signature.result))

    prepare = extract_prepare(before_broadcast, canonical_form_types)

    work = extract_work(before_aggregate, after_aggregate,
                        canonical_form_types)

    zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
        before_aggregate, canonical_form_types)

    update = extract_update(after_aggregate, canonical_form_types)

    return canonical_form.CanonicalForm(
        computation_wrapper_instances.building_block_to_computation(
            initialize),
        computation_wrapper_instances.building_block_to_computation(prepare),
        computation_wrapper_instances.building_block_to_computation(work),
        computation_wrapper_instances.building_block_to_computation(
            zero_noarg_function),
        computation_wrapper_instances.building_block_to_computation(
            accumulate),
        computation_wrapper_instances.building_block_to_computation(merge),
        computation_wrapper_instances.building_block_to_computation(report),
        computation_wrapper_instances.building_block_to_computation(update))
Пример #10
0
def get_canonical_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: Optional[
        tf.compat.v1.ConfigProto] = _GRAPPLER_DEFAULT_CONFIG):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `ip` into
    an instance of `tff.backends.mapreduce.CanonicalForm`.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` that is compatible
        with canonical form. Iterative processes are only compatible if:
          - `initialize` returns a single federated value placed at `SERVER`.
          - `next` takes exactly two arguments. The first must be the state
            value placed at `SERVER`.
          - `next` returns exactly two values.
      grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
        configure Grappler graph optimization of the TensorFlow graphs backing
        the resulting `tff.backends.mapreduce.CanonicalForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If `None`, Grappler is bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If `next` contains neither a `federated_aggregate` nor a
        `federated_secure_sum`, or if a broadcast in `next` depends on the
        result of an aggregate.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    if grappler_config is not None:
        py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
        # Layer the caller's options on top of a fresh copy of the defaults, so
        # the module-level default proto itself is never mutated.
        overridden_grappler_config = tf.compat.v1.ConfigProto()
        overridden_grappler_config.CopyFrom(_GRAPPLER_DEFAULT_CONFIG)
        overridden_grappler_config.MergeFrom(grappler_config)
        grappler_config = overridden_grappler_config

    initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
        ip.initialize._computation_proto)  # pylint: disable=protected-access
    next_comp = building_blocks.ComputationBuildingBlock.from_proto(
        ip.next._computation_proto)  # pylint: disable=protected-access
    _check_iterative_process_compatible_with_canonical_form(
        initialize_comp, next_comp)

    initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
    next_comp = _replace_intrinsics_with_bodies(next_comp)
    next_comp = _replace_lambda_body_with_call_dominant_form(next_comp)

    tree_analysis.check_contains_only_reducible_intrinsics(initialize_comp)
    tree_analysis.check_contains_only_reducible_intrinsics(next_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    # Stage 1: split `next` around its broadcast; without a broadcast, a
    # trivial before/after pair is synthesized instead.
    if tree_analysis.contains_called_intrinsic(
            next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    # Stage 2: split the post-broadcast part around the aggregation intrinsics.
    # At least one of `federated_aggregate` / `federated_secure_sum` must exist.
    contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
        next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
    contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
        next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)
    if not (contains_federated_aggregate or contains_federated_secure_sum):
        raise ValueError(
            'Expected an `tff.templates.IterativeProcess` containing at least one '
            '`federated_aggregate` or `federated_secure_sum`, found none.')

    if contains_federated_aggregate and contains_federated_secure_sum:
        before_aggregate, after_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                after_broadcast, [
                    intrinsic_defs.FEDERATED_AGGREGATE.uri,
                    intrinsic_defs.FEDERATED_SECURE_SUM.uri,
                ]))
    elif contains_federated_secure_sum:
        assert not contains_federated_aggregate
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_aggregate(
                after_broadcast))
    else:
        assert contains_federated_aggregate and not contains_federated_secure_sum
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_secure_sum(
                after_broadcast))

    # Expected types for every extracted computation; each extraction below is
    # immediately validated against the corresponding entry.
    type_info = _get_type_info(initialize_comp, before_broadcast,
                               after_broadcast, before_aggregate,
                               after_aggregate)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp, grappler_config)
    _check_type_equal(initialize.type_signature, type_info['initialize_type'])

    prepare = _extract_prepare(before_broadcast, grappler_config)
    _check_type_equal(prepare.type_signature, type_info['prepare_type'])

    work = _extract_work(before_aggregate, grappler_config)
    _check_type_equal(work.type_signature, type_info['work_type'])

    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate, grappler_config)
    _check_type_equal(zero.type_signature, type_info['zero_type'])
    _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
    _check_type_equal(merge.type_signature, type_info['merge_type'])
    _check_type_equal(report.type_signature, type_info['report_type'])

    bitwidth = _extract_federated_secure_sum_functions(before_aggregate,
                                                       grappler_config)
    _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

    update = _extract_update(after_aggregate, grappler_config)
    _check_type_equal(update.type_signature, type_info['update_type'])

    # The two parameter names of `next` become the CanonicalForm's labels for
    # its server-state and client-data inputs.
    server_state_label, client_data_label = [
        name for name, _ in structure.iter_elements(
            ip.next.type_signature.parameter)
    ]
    return canonical_form.CanonicalForm(
        computation_wrapper_instances.building_block_to_computation(
            initialize),
        computation_wrapper_instances.building_block_to_computation(prepare),
        computation_wrapper_instances.building_block_to_computation(work),
        computation_wrapper_instances.building_block_to_computation(zero),
        computation_wrapper_instances.building_block_to_computation(
            accumulate),
        computation_wrapper_instances.building_block_to_computation(merge),
        computation_wrapper_instances.building_block_to_computation(report),
        computation_wrapper_instances.building_block_to_computation(bitwidth),
        computation_wrapper_instances.building_block_to_computation(update),
        server_state_label=server_state_label,
        client_data_label=client_data_label)