Exemple #1
0
 def test_np_values():
   """Places NumPy float64/int64 zeros at the server and checks their types.

   Note: `self` is not a parameter of this function; it is a free variable
   that must be supplied by the surrounding scope.

   Returns:
     A `(float_value, int_value)` pair of the two server-placed values.
   """
   float_at_server = tff.federated_value(np.float64(0), tff.SERVER)
   self.assertEqual(str(float_at_server.type_signature), 'float64@SERVER')

   int_at_server = tff.federated_value(np.int64(0), tff.SERVER)
   self.assertEqual(str(int_at_server.type_signature), 'int64@SERVER')

   return float_at_server, int_at_server
Exemple #2
0
        def next_fn(server_state, client_data):
            """Broadcasts the server state, maps it on clients, and sums the result.

            Args:
              server_state: Value passed to `tff.federated_broadcast`.
              client_data: Value paired with the broadcast state as the
                argument of `tff.federated_map`.

            Returns:
              A pair `(sum_of_client_results, server_constant)`, where the
              second element is the constant 1234 placed at `tff.SERVER`.
            """
            state_at_clients = tff.federated_broadcast(server_state)

            @tff.tf_computation(tf.int32, tff.SequenceType(tf.float32))
            @tf.function
            def some_transform(x, y):
                # The second argument is accepted but deliberately discarded.
                del y
                return x + 1

            per_client_result = tff.federated_map(
                some_transform, (state_at_clients, client_data))
            summed_result = tff.federated_sum(per_client_result)
            constant_at_server = tff.federated_value(1234, tff.SERVER)
            return summed_result, constant_at_server
Exemple #3
0
def get_canonical_form_for_iterative_process(iterative_process):
    """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

    This function transforms computations from the input `iterative_process`
    into an instance of `tff.backends.mapreduce.CanonicalForm`.

    Args:
      iterative_process: An instance of `tff.utils.IterativeProcess`.

    Returns:
      An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
      process.

    Raises:
      TypeError: If the arguments are of the wrong types, or if the `next`
        computation does not take and return `tff.NamedTupleType`s.
      transformations.CanonicalFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(iterative_process,
                            computation_utils.IterativeProcess)

    # Rebuild both computations as building blocks from their protos.
    initialize_comp = tff_framework.ComputationBuildingBlock.from_proto(
        iterative_process.initialize._computation_proto)  # pylint: disable=protected-access

    next_comp = tff_framework.ComputationBuildingBlock.from_proto(
        iterative_process.next._computation_proto)  # pylint: disable=protected-access

    if not (isinstance(next_comp.type_signature.parameter, tff.NamedTupleType)
            and isinstance(next_comp.type_signature.result,
                           tff.NamedTupleType)):
        raise TypeError(
            'Any IterativeProcess compatible with CanonicalForm must '
            'have a `next` function which takes and returns instances '
            'of `tff.NamedTupleType`; your next function takes '
            'parameters of type {} and returns results of type {}'.format(
                next_comp.type_signature.parameter,
                next_comp.type_signature.result))

    # A two-element result is padded with an empty CLIENTS-placed value as a
    # third element, and `next_comp` is rebuilt as a lambda returning the
    # padded tuple.
    if len(next_comp.type_signature.result) == 2:
        next_result = next_comp.result
        dummy_clients_metrics_appended = tff_framework.Tuple([
            next_result[0],
            next_result[1],
            tff.federated_value([], tff.CLIENTS)._comp  # pylint: disable=protected-access
        ])
        next_comp = tff_framework.Lambda(next_comp.parameter_name,
                                         next_comp.parameter_type,
                                         dummy_clients_metrics_appended)

    initialize_comp = tff_framework.replace_intrinsics_with_bodies(
        initialize_comp)
    next_comp = tff_framework.replace_intrinsics_with_bodies(next_comp)

    tff_framework.check_intrinsics_whitelisted_for_reduction(initialize_comp)
    tff_framework.check_intrinsics_whitelisted_for_reduction(next_comp)
    tff_framework.check_broadcast_not_dependent_on_aggregate(next_comp)

    # Split `next` first around the broadcast intrinsic, then split the
    # post-broadcast part around the aggregate intrinsic.
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsic(
            next_comp, tff_framework.FEDERATED_BROADCAST.uri))

    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsic(
            after_broadcast, tff_framework.FEDERATED_AGGREGATE.uri))

    # Each packing step receives the result of the previous one, so the
    # order of these calls must be preserved.
    init_info_packed = pack_initialize_comp_type_signature(
        initialize_comp.type_signature)

    next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                     init_info_packed)

    before_broadcast_info_packed = (
        check_and_pack_before_broadcast_type_signature(
            before_broadcast.type_signature, next_info_packed))

    before_aggregate_info_packed = (
        check_and_pack_before_aggregate_type_signature(
            before_aggregate.type_signature, before_broadcast_info_packed))

    canonical_form_types = check_and_pack_after_aggregate_type_signature(
        after_aggregate.type_signature, before_aggregate_info_packed)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_comp)

    # The extracted `initialize` result is compared against the *member* of
    # the packed initialize type; the error below reports that same type so
    # the "expected" and "actual" values in the message are comparable.
    expected_initialize_type = canonical_form_types['initialize_type'].member
    if not (isinstance(initialize, tff_framework.CompiledComputation)
            and initialize.type_signature.result == expected_initialize_type):
        raise transformations.CanonicalFormCompilationError(
            'Compilation of initialize has failed. Expected to extract a '
            '`tff_framework.CompiledComputation` of type {}, instead we extracted '
            'a {} of type {}.'.format(expected_initialize_type,
                                      type(initialize),
                                      initialize.type_signature.result))

    prepare = extract_prepare(before_broadcast, canonical_form_types)

    work = extract_work(before_aggregate, after_aggregate,
                        canonical_form_types)

    zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
        before_aggregate, canonical_form_types)

    update = extract_update(after_aggregate, canonical_form_types)

    # Convert every extracted building block to a computation, in the
    # positional order passed to the `CanonicalForm` constructor.
    cf = canonical_form.CanonicalForm(
        tff_framework.building_block_to_computation(initialize),
        tff_framework.building_block_to_computation(prepare),
        tff_framework.building_block_to_computation(work),
        tff_framework.building_block_to_computation(zero_noarg_function),
        tff_framework.building_block_to_computation(accumulate),
        tff_framework.building_block_to_computation(merge),
        tff_framework.building_block_to_computation(report),
        tff_framework.building_block_to_computation(update))
    return cf
Exemple #4
0
 def init_computation():
     """Returns the result of `cf.initialize()` placed at `tff.SERVER`."""
     initial_state = cf.initialize()
     return tff.federated_value(initial_state, tff.SERVER)
Exemple #5
0
 def init_fn():
     """Returns the constant 42 as a value placed at `tff.SERVER`."""
     initial_value = 42
     return tff.federated_value(initial_value, tff.SERVER)
Exemple #6
0
 def run_initialize():
     """Returns the result of `dp_aggregate_fn.initialize()` placed at `tff.SERVER`."""
     initial_state = dp_aggregate_fn.initialize()
     return tff.federated_value(initial_state, tff.SERVER)
 def federated_aggregate_test(values, weights):
   """Runs a `StatefulAggregateFn` on `values` and `weights` from its initial state.

   The aggregate function is built from `agg_initialize_fn` and `agg_next_fn`,
   and its initial state is placed at `tff.SERVER` before the call.

   Args:
     values: Forwarded as the second argument of the aggregate function.
     weights: Forwarded as the third argument of the aggregate function.

   Returns:
     Whatever the aggregate function returns for `(state, values, weights)`.
   """
   stateful_aggregate = computation_utils.StatefulAggregateFn(
       next_fn=agg_next_fn, initialize_fn=agg_initialize_fn)
   initial_state = stateful_aggregate.initialize()
   server_state = tff.federated_value(initial_state, tff.SERVER)
   return stateful_aggregate(server_state, values, weights)
 def federated_broadcast_test(values):
   """Runs a `StatefulBroadcastFn` on `values` from its initial state.

   The broadcast function is built from `broadcast_initialize_fn` and
   `broadcast_next_fn`, and its initial state is placed at `tff.SERVER`
   before the call.

   Args:
     values: Forwarded as the second argument of the broadcast function.

   Returns:
     Whatever the broadcast function returns for `(state, values)`.
   """
   stateful_broadcast = computation_utils.StatefulBroadcastFn(
       next_fn=broadcast_next_fn, initialize_fn=broadcast_initialize_fn)
   initial_state = stateful_broadcast.initialize()
   server_state = tff.federated_value(initial_state, tff.SERVER)
   return stateful_broadcast(server_state, values)
Exemple #9
0
 def foo(x):
   """Returns `x` as a federated value placed at `tff.SERVER`."""
   placement = tff.SERVER
   return tff.federated_value(x, placement)
Exemple #10
0
 def test_federated_value_raw_tf_scalar_variable(self):
   """Checks that passing a raw `tf.Variable` to `tff.federated_value` raises `TypeError`."""
   raw_variable = tf.Variable(initial_value=0., name='test_var')
   expected_message = ('TensorFlow construct (.*) has been '
                       'encountered in a federated context.')
   with self.assertRaisesRegex(TypeError, expected_message):
     tff.federated_value(raw_variable, tff.SERVER)
Exemple #11
0
 def foo(x):
   """Returns `x` as a federated value placed at `tff.CLIENTS`."""
   placement = tff.CLIENTS
   return tff.federated_value(x, placement)
Exemple #12
0
 def build_federated_zero():
     """Returns the integer 0 as a value placed at `tff.SERVER`."""
     zero = 0
     return tff.federated_value(zero, tff.SERVER)