Ejemplo n.º 1
0
 def test_passes_with_federated_map(self):
   """Checks that a `federated_map` intrinsic passes the reduction whitelist.

   The whitelist check is expected to complete without raising.
   """
   mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
   clients_int_type = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
   clients_float_type = computation_types.FederatedType(tf.float32,
                                                        placements.CLIENTS)
   map_type = computation_types.FunctionType(
       [mapped_fn_type, clients_int_type], clients_float_type)
   intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                         map_type)
   tree_analysis.check_intrinsics_whitelisted_for_reduction(intrinsic)
Ejemplo n.º 2
0
  def test_raises_with_federated_mean(self):
    """Checks that a `federated_mean` intrinsic is rejected by the whitelist.

    Expects a `ValueError` whose message matches (as a regex) the intrinsic's
    compact representation.
    """
    clients_int_type = computation_types.FederatedType(tf.int32,
                                                       placements.CLIENTS)
    server_int_type = computation_types.FederatedType(tf.int32,
                                                      placements.SERVER)
    mean_type = computation_types.FunctionType(clients_int_type,
                                               server_int_type)
    intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MEAN.uri,
                                          mean_type)
    expected_pattern = intrinsic.compact_representation()

    with self.assertRaisesRegex(ValueError, expected_pattern):
      tree_analysis.check_intrinsics_whitelisted_for_reduction(intrinsic)
Ejemplo n.º 3
0
  def test_generic_divide_reduces(self):
    """Checks that `generic_divide` is fully replaced by its body.

    Before reduction the intrinsic must be present; afterwards no occurrence
    may remain, the result must pass the reduction whitelist, and the
    transformation must report that it modified the computation.
    """
    divide_uri = intrinsic_defs.GENERIC_DIVIDE.uri
    divide_type = computation_types.FunctionType([tf.float32, tf.float32],
                                                 tf.float32)
    divide = building_blocks.Intrinsic(divide_uri, divide_type)

    occurrences_before = _count_intrinsics(divide, divide_uri)
    reduced_comp, was_modified = (
        value_transformations.replace_all_intrinsics_with_bodies(
            divide, context_stack_impl.context_stack))
    occurrences_after = _count_intrinsics(reduced_comp, divide_uri)

    self.assertGreater(occurrences_before, 0)
    self.assertEqual(occurrences_after, 0)
    tree_analysis.check_intrinsics_whitelisted_for_reduction(reduced_comp)
    self.assertTrue(was_modified)
Ejemplo n.º 4
0
 def test_raises_on_none(self):
     """Checks that passing `None` to the whitelist check raises `TypeError`."""
     self.assertRaises(TypeError,
                       tree_analysis.check_intrinsics_whitelisted_for_reduction,
                       None)
Ejemplo n.º 5
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `iterative_process`
    into an instance of `tff.backends.mapreduce.CanonicalForm`.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the `next` computation contains neither a
        `federated_aggregate` nor a `federated_secure_sum`.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.initialize._computation_proto)  # pylint: disable=protected-access
    next_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.next._computation_proto)  # pylint: disable=protected-access
    _check_iterative_process_compatible_with_canonical_form(
        initialize_comp, next_comp)

    # A `next` whose result has exactly 2 elements is rewritten by
    # `_create_next_with_fake_client_output`; presumably this gives later
    # stages a uniform result structure — confirm against that helper.
    if len(next_comp.type_signature.result) == 2:
        next_comp = _create_next_with_fake_client_output(next_comp)

    initialize_comp = _replace_intrinsics_with_bodies(initialize_comp)
    next_comp = _replace_intrinsics_with_bodies(next_comp)
    tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
    tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    if tree_analysis.contains_called_intrinsic(
            next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    contains_federated_aggregate = tree_analysis.contains_called_intrinsic(
        next_comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
    contains_federated_secure_sum = tree_analysis.contains_called_intrinsic(
        next_comp, intrinsic_defs.FEDERATED_SECURE_SUM.uri)

    # Handle the "neither intrinsic" case first so this error is reachable;
    # every branch below then has at least one aggregation intrinsic.
    if not (contains_federated_aggregate or contains_federated_secure_sum):
        raise ValueError(
            'Expected an `tff.utils.IterativeProcess` containing at least one '
            '`federated_aggregate` or `federated_secure_sum`, found none.')

    if contains_federated_aggregate and contains_federated_secure_sum:
        before_aggregate, after_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                after_broadcast, [
                    intrinsic_defs.FEDERATED_AGGREGATE.uri,
                    intrinsic_defs.FEDERATED_SECURE_SUM.uri,
                ]))
    elif contains_federated_secure_sum:
        # Only `federated_secure_sum` is present.
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_aggregate(
                after_broadcast))
    else:
        # Only `federated_aggregate` is present.
        before_aggregate, after_aggregate = (
            _create_before_and_after_aggregate_for_no_federated_secure_sum(
                after_broadcast))

    type_info = _get_type_info(initialize_comp, before_broadcast,
                               after_broadcast, before_aggregate,
                               after_aggregate)

    # Each extracted component is checked against the type computed above.
    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)
    _check_type_equal(initialize.type_signature, type_info['initialize_type'])

    prepare = _extract_prepare(before_broadcast)
    _check_type_equal(prepare.type_signature, type_info['prepare_type'])

    work = _extract_work(before_aggregate, after_aggregate)
    _check_type_equal(work.type_signature, type_info['work_type'])

    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate)
    _check_type_equal(zero.type_signature, type_info['zero_type'])
    _check_type_equal(accumulate.type_signature, type_info['accumulate_type'])
    _check_type_equal(merge.type_signature, type_info['merge_type'])
    _check_type_equal(report.type_signature, type_info['report_type'])

    bitwidth = _extract_federated_secure_sum_functions(before_aggregate)
    _check_type_equal(bitwidth.type_signature, type_info['bitwidth_type'])

    update = _extract_update(after_aggregate)
    _check_type_equal(update.type_signature, type_info['update_type'])

    return canonical_form.CanonicalForm(
        computation_wrapper_instances.building_block_to_computation(
            initialize),
        computation_wrapper_instances.building_block_to_computation(prepare),
        computation_wrapper_instances.building_block_to_computation(work),
        computation_wrapper_instances.building_block_to_computation(zero),
        computation_wrapper_instances.building_block_to_computation(
            accumulate),
        computation_wrapper_instances.building_block_to_computation(merge),
        computation_wrapper_instances.building_block_to_computation(report),
        computation_wrapper_instances.building_block_to_computation(bitwidth),
        computation_wrapper_instances.building_block_to_computation(update))
Ejemplo n.º 6
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `iterative_process`
    into an instance of `tff.backends.mapreduce.CanonicalForm`.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types, or if the `next`
        computation does not both take and return a `tff.NamedTupleType`.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    initialize_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

    next_comp = building_blocks.ComputationBuildingBlock.from_proto(
        iterative_process.next._computation_proto)  # pylint: disable=protected-access

    if not (isinstance(next_comp.type_signature.parameter,
                       computation_types.NamedTupleType)
            and isinstance(next_comp.type_signature.result,
                           computation_types.NamedTupleType)):
        raise TypeError(
            'Any IterativeProcess compatible with CanonicalForm must '
            'have a `next` function which takes and returns instances '
            'of `tff.NamedTupleType`; your next function takes '
            'parameters of type {} and returns results of type {}'.format(
                next_comp.type_signature.parameter,
                next_comp.type_signature.result))

    # A `next` whose result has exactly 2 elements is rewritten by
    # `_create_next_with_fake_client_output`; presumably this gives later
    # stages a uniform result structure — confirm against that helper.
    if len(next_comp.type_signature.result) == 2:
        next_comp = _create_next_with_fake_client_output(next_comp)

    initialize_comp = replace_intrinsics_with_bodies(initialize_comp)
    next_comp = replace_intrinsics_with_bodies(next_comp)

    tree_analysis.check_intrinsics_whitelisted_for_reduction(initialize_comp)
    tree_analysis.check_intrinsics_whitelisted_for_reduction(next_comp)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(next_comp)

    if tree_analysis.contains_called_intrinsic(
            next_comp, intrinsic_defs.FEDERATED_BROADCAST.uri):
        before_broadcast, after_broadcast = (
            transformations.force_align_and_split_by_intrinsics(
                next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    else:
        before_broadcast, after_broadcast = (
            _create_before_and_after_broadcast_for_no_broadcast(next_comp))

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

    # Thread type information through each stage: every `*_packed` result is
    # passed into the next check-and-pack call.
    init_info_packed = pack_initialize_comp_type_signature(
        initialize_comp.type_signature)

    next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                     init_info_packed)

    before_broadcast_info_packed = (
        check_and_pack_before_broadcast_type_signature(
            before_broadcast.type_signature, next_info_packed))

    before_aggregate_info_packed = (
        check_and_pack_before_aggregate_type_signature(
            before_aggregate.type_signature, before_broadcast_info_packed))

    canonical_form_types = check_and_pack_after_aggregate_type_signature(
        after_aggregate.type_signature, before_aggregate_info_packed)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)

    # Report the type that was actually compared against, so the error
    # message matches the failed check.
    expected_initialize_type = canonical_form_types['initialize_type'].member
    if initialize.type_signature.result != expected_initialize_type:
        raise transformations.CanonicalFormCompilationError(
            'Compilation of initialize has failed. Expected to extract a '
            '`building_blocks.CompiledComputation` of type {}, instead we extracted '
            'a {} of type {}.'.format(expected_initialize_type,
                                      type(initialize),
                                      initialize.type_signature.result))

    prepare = extract_prepare(before_broadcast, canonical_form_types)

    work = extract_work(before_aggregate, after_aggregate,
                        canonical_form_types)

    zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
        before_aggregate, canonical_form_types)

    update = extract_update(after_aggregate, canonical_form_types)

    return canonical_form.CanonicalForm(
        computation_wrapper_instances.building_block_to_computation(
            initialize),
        computation_wrapper_instances.building_block_to_computation(prepare),
        computation_wrapper_instances.building_block_to_computation(work),
        computation_wrapper_instances.building_block_to_computation(
            zero_noarg_function),
        computation_wrapper_instances.building_block_to_computation(
            accumulate),
        computation_wrapper_instances.building_block_to_computation(merge),
        computation_wrapper_instances.building_block_to_computation(report),
        computation_wrapper_instances.building_block_to_computation(update))