Beispiel #1
0
def extract_work(before_aggregate, after_aggregate, canonical_form_types):
    """Builds the canonical-form `work` computation.

    `work` is stitched together from two pieces: the `c3 -> c5` portion of
    `before_aggregate` and the `c3 -> c6` portion of `after_aggregate`. The
    outputs of both are concatenated, federated-zipped into `c4`, and the
    resulting lambda is consolidated into a single TF block.

    Args:
      before_aggregate: The first result of splitting `after_broadcast` on
        `tff_framework.FEDERATED_AGGREGATE`.
      after_aggregate: The second result of splitting `after_broadcast` on
        `tff_framework.FEDERATED_AGGREGATE`.
      canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
        type signatures specified by the `tff.utils.IterativeProcess` we are
        compiling; its `'work_type'` entry is the type `work` must have.

    Returns:
      `work` as specified by `canonical_form.CanonicalForm`, an instance of
      `tff_framework.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we fail to extract a
        `tff_framework.CompiledComputation`, or we extract one of the wrong
        type.
    """
    # Variable names follow `get_iterative_process_for_canonical_form()`.

    # c3 -> c5: rebind `before_aggregate` to take only the c3 elements of its
    # parameter, then keep output 0 of its result.
    c3_to_before_aggregate = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            before_aggregate, [[0, 1], [1]]).result.function)
    c3_to_c5 = transformations.select_output_from_lambda(
        c3_to_before_aggregate, 0)

    # c3 -> c6: keep output 2 of `after_aggregate`, then rebind it to take
    # only the c3 elements of its parameter.
    after_aggregate_to_c6 = transformations.select_output_from_lambda(
        after_aggregate, 2)
    c3_to_c6 = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            after_aggregate_to_c6, [[0, 0, 1], [0, 1]]).result.function)

    # c3 -> c4: concatenate both outputs and federated-zip them together.
    unzipped = transformations.concatenate_function_outputs(c3_to_c5, c3_to_c6)
    c3_to_c4 = tff_framework.Lambda(
        unzipped.parameter_name,
        unzipped.parameter_type,
        tff_framework.create_federated_zip(unzipped.result))

    work = transformations.consolidate_and_extract_local_processing(c3_to_c4)

    if not isinstance(work, tff_framework.CompiledComputation):
        raise transformations.CanonicalFormCompilationError(
            'Failed to extract a `tff_framework.CompiledComputation` from '
            'work, instead received a {} (of type {}).'.format(
                type(work), work.type_signature))

    expected_work_type = canonical_form_types['work_type']
    if work.type_signature != expected_work_type:
        raise transformations.CanonicalFormCompilationError(
            'Extracted a TF block of the wrong type. Expected a function with type '
            '{}, but the type signature of the TF block was {}'.format(
                expected_work_type, work.type_signature))
    return work
Beispiel #2
0
def extract_aggregate_functions(before_aggregate, canonical_form_types):
    """Pulls `zero`, `accumulate`, `merge` and `report` out of `before_aggregate`.

    Outputs 1 through 4 of `before_aggregate` hold, in order, the `zero`,
    `accumulate`, `merge` and `report` computations. Each is selected,
    consolidated into a TF block, and checked against the matching
    `'<name>_type'` entry of `canonical_form_types`.

    Args:
      before_aggregate: The first result of splitting `after_broadcast` on
        `tff_framework.FEDERATED_AGGREGATE`.
      canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
        type signatures specified by the `tff.utils.IterativeProcess` we are
        compiling.

    Returns:
      `zero`, `accumulate`, `merge` and `report` as specified by
      `canonical_form.CanonicalForm`. All are instances of
      `tff_framework.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: if we fail to extract
        `tff_framework.CompiledComputation`s, or we extract one of the wrong
        type.
    """
    # (name, output index in `before_aggregate`); names follow
    # `get_iterative_process_for_canonical_form()`.
    layout = (('zero', 1), ('accumulate', 2), ('merge', 3), ('report', 4))

    # Select every output first, then consolidate each into a TF block.
    selected_comps = [
        transformations.select_output_from_lambda(before_aggregate, idx).result
        for _, idx in layout
    ]
    tf_blocks = [
        transformations.consolidate_and_extract_local_processing(comp)
        for comp in selected_comps
    ]

    for (name, _), block in zip(layout, tf_blocks):
        if not isinstance(block, tff_framework.CompiledComputation):
            raise transformations.CanonicalFormCompilationError(
                'Failed to extract a `tff_framework.CompiledComputation` from '
                '{}, instead received a {} (of type {}).'.format(
                    name, type(block), block.type_signature))
        type_key = '{}_type'.format(name)
        if block.type_signature != canonical_form_types[type_key]:
            raise transformations.CanonicalFormCompilationError(
                'Extracted a TF block of the wrong type. Expected a function with type '
                '{}, but the type signature of the TF block was {}'.format(
                    canonical_form_types[type_key], block.type_signature))

    zero, accumulate, merge, report = tf_blocks
    return zero, accumulate, merge, report
def extract_prepare(before_broadcast, canonical_form_types):
    """Converts `before_broadcast` into `prepare`.

    Rebinds `before_broadcast` so that it takes only element 0 of its
    parameter (`s1`), then consolidates the resulting `s1 -> s2` lambda into a
    single TF block.

    Args:
      before_broadcast: The first result of splitting `next_comp` on
        `intrinsic_defs.FEDERATED_BROADCAST`.
      canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
        type signatures specified by the `tff.utils.IterativeProcess` we are
        compiling; its `'prepare_type'` entry is the type `prepare` must have.

    Returns:
      `prepare` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we fail to extract a
        `building_blocks.CompiledComputation`, or we extract one of the wrong
        type.
    """
    # See `get_iterative_process_for_canonical_form()` above for the meaning of
    # variable names used in the code below.
    s1_index_in_before_broadcast = 0
    s1_to_s2_computation = (
        transformations.bind_single_selection_as_argument_to_lower_level_lambda(
            before_broadcast, s1_index_in_before_broadcast)).result.function
    prepare = transformations.consolidate_and_extract_local_processing(
        s1_to_s2_computation)
    # The other extractors (`work`, `update`, aggregation functions) all reject
    # a non-compiled result; do the same here so the documented return type
    # holds rather than silently returning an unconsolidated building block.
    if not isinstance(prepare, building_blocks.CompiledComputation):
        raise transformations.CanonicalFormCompilationError(
            'Failed to extract a `building_blocks.CompiledComputation` from '
            'prepare, instead received a {} (of type {}).'.format(
                type(prepare), prepare.type_signature))
    if prepare.type_signature != canonical_form_types['prepare_type']:
        raise transformations.CanonicalFormCompilationError(
            'Extracted a TF block of the wrong type. Expected a function with type '
            '{}, but the type signature of the TF block was {}'.format(
                canonical_form_types['prepare_type'], prepare.type_signature))
    return prepare
def extract_update(after_aggregate, canonical_form_types):
    """Builds the canonical-form `update` computation from `after_aggregate`.

    Outputs 0 and 1 of `after_aggregate` (`s5`) are selected and
    federated-zipped. The lambda is then rebound to take only the `s4`
    elements of its parameter, and the result is consolidated into a single
    TF block.

    Args:
      after_aggregate: The second result of splitting `after_broadcast` on
        `intrinsic_defs.FEDERATED_AGGREGATE`.
      canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
        type signatures specified by the `tff.utils.IterativeProcess` we are
        compiling; its `'update_type'` entry is the type `update` must have.

    Returns:
      `update` as specified by `canonical_form.CanonicalForm`, an instance of
      `building_blocks.CompiledComputation`.

    Raises:
      transformations.CanonicalFormCompilationError: If we fail to extract a
        `building_blocks.CompiledComputation`, or we extract one of the wrong
        type.
    """
    # Variable names follow `get_iterative_process_for_canonical_form()`.

    # Keep outputs 0 and 1 (s5) and zip them into one federated value.
    s5_selected = transformations.select_output_from_lambda(
        after_aggregate, [0, 1])
    s5_zipped = building_blocks.Lambda(
        s5_selected.parameter_name,
        s5_selected.parameter_type,
        building_block_factory.create_federated_zip(s5_selected.result))

    # Rebind the lambda so its argument is just the s4 elements.
    s4_to_s5 = (
        transformations.zip_selection_as_argument_to_lower_level_lambda(
            s5_zipped, [[0, 0, 0], [1]]).result.function)

    update = transformations.consolidate_and_extract_local_processing(s4_to_s5)

    if not isinstance(update, building_blocks.CompiledComputation):
        raise transformations.CanonicalFormCompilationError(
            'Failed to extract a `building_blocks.CompiledComputation` from '
            'update, instead received a {} (of type {}).'.format(
                type(update), update.type_signature))

    expected_update_type = canonical_form_types['update_type']
    if update.type_signature != expected_update_type:
        raise transformations.CanonicalFormCompilationError(
            'Extracted a TF block of the wrong type. Expected a function with type '
            '{}, but the type signature of the TF block was {}'.format(
                expected_update_type, update.type_signature))
    return update
Beispiel #5
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `iterative_process`
    into an instance of `tff.backends.mapreduce.CanonicalForm`.

    If `next` returns only two values, an empty `tff.CLIENTS`-placed value is
    appended as a third output before compilation, so the rest of the
    pipeline can work with a three-element result.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    initialize_comp = tff_framework.ComputationBuildingBlock.from_proto(
        iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

    next_comp = tff_framework.ComputationBuildingBlock.from_proto(
        iterative_process.next._computation_proto)  # pylint: disable=protected-access

    if not (isinstance(next_comp.type_signature.parameter, tff.NamedTupleType)
            and isinstance(next_comp.type_signature.result,
                           tff.NamedTupleType)):
        raise TypeError(
            'Any IterativeProcess compatible with CanonicalForm must '
            'have a `next` function which takes and returns instances '
            'of `tff.NamedTupleType`; your next function takes '
            'parameters of type {} and returns results of type {}'.format(
                next_comp.type_signature.parameter,
                next_comp.type_signature.result))

    if len(next_comp.type_signature.result) == 2:
        next_result = next_comp.result
        if isinstance(next_result, tff_framework.Tuple):
            first_output, second_output = next_result[0], next_result[1]
        else:
            # The result has a two-element tuple type but is not a literal
            # `Tuple` building block (e.g. a reference or call), so it cannot
            # be indexed directly; select its elements explicitly instead.
            first_output = tff_framework.Selection(next_result, index=0)
            second_output = tff_framework.Selection(next_result, index=1)
        dummy_clients_metrics_appended = tff_framework.Tuple([
            first_output,
            second_output,
            tff.federated_value([], tff.CLIENTS)._comp  # pylint: disable=protected-access
        ])
        next_comp = tff_framework.Lambda(next_comp.parameter_name,
                                         next_comp.parameter_type,
                                         dummy_clients_metrics_appended)

    initialize_comp = tff_framework.replace_intrinsics_with_bodies(
        initialize_comp)
    next_comp = tff_framework.replace_intrinsics_with_bodies(next_comp)

    tff_framework.check_intrinsics_whitelisted_for_reduction(initialize_comp)
    tff_framework.check_intrinsics_whitelisted_for_reduction(next_comp)
    tff_framework.check_broadcast_not_dependent_on_aggregate(next_comp)

    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsic(
            next_comp, tff_framework.FEDERATED_BROADCAST.uri))

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsic(
            after_broadcast, tff_framework.FEDERATED_AGGREGATE.uri))

    # Each packing step validates one piece of the split computation and
    # threads the accumulated type information into the next step.
    init_info_packed = pack_initialize_comp_type_signature(
        initialize_comp.type_signature)

    next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                     init_info_packed)

    before_broadcast_info_packed = (
        check_and_pack_before_broadcast_type_signature(
            before_broadcast.type_signature, next_info_packed))

    before_aggregate_info_packed = (
        check_and_pack_before_aggregate_type_signature(
            before_aggregate.type_signature, before_broadcast_info_packed))

    canonical_form_types = check_and_pack_after_aggregate_type_signature(
        after_aggregate.type_signature, before_aggregate_info_packed)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)

    expected_initialize_result_type = (
        canonical_form_types['initialize_type'].member)
    if not (isinstance(initialize, tff_framework.CompiledComputation)
            and initialize.type_signature.result
            == expected_initialize_result_type):
        # Report the type the check above actually compared against, not the
        # (placed) type of `next`'s first parameter.
        raise transformations.CanonicalFormCompilationError(
            'Compilation of initialize has failed. Expected to extract a '
            '`tff_framework.CompiledComputation` of type {}, instead we extracted '
            'a {} of type {}.'.format(expected_initialize_result_type,
                                      type(initialize),
                                      initialize.type_signature.result))

    prepare = extract_prepare(before_broadcast, canonical_form_types)

    work = extract_work(before_aggregate, after_aggregate,
                        canonical_form_types)

    zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
        before_aggregate, canonical_form_types)

    update = extract_update(after_aggregate, canonical_form_types)

    cf = canonical_form.CanonicalForm(
        tff_framework.building_block_to_computation(initialize),
        tff_framework.building_block_to_computation(prepare),
        tff_framework.building_block_to_computation(work),
        tff_framework.building_block_to_computation(zero_noarg_function),
        tff_framework.building_block_to_computation(accumulate),
        tff_framework.building_block_to_computation(merge),
        tff_framework.building_block_to_computation(report),
        tff_framework.building_block_to_computation(update))
    return cf
Beispiel #6
0
def _check_type_equal(actual, expected):
    """Raises `CanonicalFormCompilationError` unless `actual` equals `expected`.

    Both arguments are first validated with `py_typecheck.check_type` as
    `computation_types.Type` instances (`actual` first, then `expected`).
    """
    for type_arg in (actual, expected):
        py_typecheck.check_type(type_arg, computation_types.Type)
    if actual != expected:
        message = 'Expected type of {}, found {}.'.format(expected, actual)
        raise transformations.CanonicalFormCompilationError(message)
Beispiel #7
0
def _check_placement_equal(actual, expected):
    """Raises `CanonicalFormCompilationError` unless the two placements match.

    Both arguments are first validated with `py_typecheck.check_type` as
    `placement_literals.PlacementLiteral` instances (`actual` first).
    """
    for placement in (actual, expected):
        py_typecheck.check_type(placement, placement_literals.PlacementLiteral)
    if actual != expected:
        message = 'Expected placement of {}, found {}.'.format(expected, actual)
        raise transformations.CanonicalFormCompilationError(message)
Beispiel #8
0
def _check_len_equal(target, length):
    """Raises `CanonicalFormCompilationError` if `len(target) != length`.

    `length` is validated as an `int` via `py_typecheck.check_type` first.
    """
    py_typecheck.check_type(length, int)
    actual_length = len(target)
    if actual_length != length:
        message = 'Expected length of {}, found {}.'.format(length, actual_length)
        raise transformations.CanonicalFormCompilationError(message)
Beispiel #9
0
def _check_type(target, type_spec):
    """Raises `CanonicalFormCompilationError` unless `target` is a `type_spec`.

    `type_spec` itself is validated as a Python class via
    `py_typecheck.check_type` before the `isinstance` test.
    """
    py_typecheck.check_type(type_spec, type)
    if isinstance(target, type_spec):
        return
    message = 'Expected type of {}, found {}.'.format(type_spec, type(target))
    raise transformations.CanonicalFormCompilationError(message)
Beispiel #10
0
def _check_type_equal(actual, expected, label):
  if actual != expected:
    raise transformations.CanonicalFormCompilationError(
        'Expected \'{}\' to have a type signature of {}, found {}.'.format(
            label, expected, actual))