def test_np_values():
  """Places numpy scalars at SERVER and checks their dtype-bearing signatures.

  `self` is not a parameter here; it is read from the enclosing scope
  (presumably an enclosing test method — confirm at the call site).

  Returns:
    A 2-tuple of the SERVER-placed float64 and int64 values, in that order.
  """
  cases = (
      (np.float64(0), 'float64@SERVER'),
      (np.int64(0), 'int64@SERVER'),
  )
  placed_values = []
  for scalar, expected_signature in cases:
    placed = tff.federated_value(scalar, tff.SERVER)
    self.assertEqual(str(placed.type_signature), expected_signature)
    placed_values.append(placed)
  return tuple(placed_values)
def next_fn(server_state, client_data):
  """Broadcasts state, adds one per client, sums, and emits a constant.

  Args:
    server_state: The server-placed state to broadcast to clients.
    client_data: Client-placed data; passed into the per-client computation,
      which ignores it.

  Returns:
    A pair of (the federated sum of the per-client updates, the constant 1234
    placed at SERVER).
  """
  state_at_clients = tff.federated_broadcast(server_state)

  @tff.tf_computation(tf.int32, tff.SequenceType(tf.float32))
  @tf.function
  def add_one(x, y):
    # The sequence argument is only part of the signature; it is not used.
    del y
    return x + 1

  per_client_update = tff.federated_map(add_one,
                                        (state_at_clients, client_data))
  summed_update = tff.federated_sum(per_client_update)
  constant_output = tff.federated_value(1234, tff.SERVER)
  return summed_update, constant_output
def _append_dummy_client_metrics(next_comp):
  """Returns `next_comp` with an empty CLIENTS value appended to its result.

  This is used when `next` returns only two elements. The appended element is
  presumably a stand-in for client metrics, so that later type-signature
  checks see a three-element result. Confirm this against
  `pack_next_comp_type_signature`.

  Args:
    next_comp: The building block for `next`. It must expose
      `parameter_name`, `parameter_type` and `result`, i.e. it is assumed to
      be a `tff_framework.Lambda`. Its result type has two elements.

  Returns:
    A `tff_framework.Lambda` with the same parameter as `next_comp`. Its
    result is a `tff_framework.Tuple` of the two original result elements
    followed by `tff.federated_value([], tff.CLIENTS)`.
  """
  next_result = next_comp.result
  if isinstance(next_result, tff_framework.Tuple):
    first_element = next_result[0]
    second_element = next_result[1]
  else:
    # A non-`Tuple` result block (e.g. a call or reference of tuple type)
    # cannot be indexed like a `Tuple`, so select its elements explicitly.
    first_element = tff_framework.Selection(next_result, index=0)
    second_element = tff_framework.Selection(next_result, index=1)
  dummy_clients_metrics_appended = tff_framework.Tuple([
      first_element,
      second_element,
      tff.federated_value([], tff.CLIENTS)._comp  # pylint: disable=protected-access
  ])
  return tff_framework.Lambda(next_comp.parameter_name,
                              next_comp.parameter_type,
                              dummy_clients_metrics_appended)


def get_canonical_form_for_iterative_process(iterative_process):
  """Constructs `tff.backends.mapreduce.CanonicalForm` given iterative process.

  This function transforms computations from the input `iterative_process`
  into an instance of `tff.backends.mapreduce.CanonicalForm`.

  Args:
    iterative_process: An instance of `tff.utils.IterativeProcess`.

  Returns:
    An instance of `tff.backends.mapreduce.CanonicalForm` equivalent to this
    process.

  Raises:
    TypeError: If the arguments are of the wrong types, or if `next` does not
      both take and return a `tff.NamedTupleType`.
    transformations.CanonicalFormCompilationError: If the compilation
      process fails.
  """
  py_typecheck.check_type(iterative_process, computation_utils.IterativeProcess)

  initialize_comp = tff_framework.ComputationBuildingBlock.from_proto(
      iterative_process.initialize._computation_proto)  # pylint: disable=protected-access
  next_comp = tff_framework.ComputationBuildingBlock.from_proto(
      iterative_process.next._computation_proto)  # pylint: disable=protected-access

  if not (isinstance(next_comp.type_signature.parameter, tff.NamedTupleType) and
          isinstance(next_comp.type_signature.result, tff.NamedTupleType)):
    raise TypeError(
        'Any IterativeProcess compatible with CanonicalForm must '
        'have a `next` function which takes and returns instances '
        'of `tff.NamedTupleType`; your next function takes '
        'parameters of type {} and returns results of type {}'.format(
            next_comp.type_signature.parameter,
            next_comp.type_signature.result))

  # A two-element `next` result gets an empty CLIENTS value appended.
  if len(next_comp.type_signature.result) == 2:
    next_comp = _append_dummy_client_metrics(next_comp)

  # Expand intrinsics, then verify that only whitelisted intrinsics remain and
  # that the broadcast does not depend on the aggregate.
  initialize_comp = tff_framework.replace_intrinsics_with_bodies(
      initialize_comp)
  next_comp = tff_framework.replace_intrinsics_with_bodies(next_comp)
  tff_framework.check_intrinsics_whitelisted_for_reduction(initialize_comp)
  tff_framework.check_intrinsics_whitelisted_for_reduction(next_comp)
  tff_framework.check_broadcast_not_dependent_on_aggregate(next_comp)

  # Split `next` at the broadcast, then split the remainder at the aggregate.
  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsic(
          next_comp, tff_framework.FEDERATED_BROADCAST.uri))
  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsic(
          after_broadcast, tff_framework.FEDERATED_AGGREGATE.uri))

  # Each packing step validates one split's type signature against the
  # information packed by the previous step.
  init_info_packed = pack_initialize_comp_type_signature(
      initialize_comp.type_signature)
  next_info_packed = pack_next_comp_type_signature(next_comp.type_signature,
                                                   init_info_packed)
  before_broadcast_info_packed = (
      check_and_pack_before_broadcast_type_signature(
          before_broadcast.type_signature, next_info_packed))
  before_aggregate_info_packed = (
      check_and_pack_before_aggregate_type_signature(
          before_aggregate.type_signature, before_broadcast_info_packed))
  canonical_form_types = check_and_pack_after_aggregate_type_signature(
      after_aggregate.type_signature, before_aggregate_info_packed)

  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_comp)
  # Report the same type that the check compares against, so the error
  # message names the actual expectation.
  expected_initialize_type = canonical_form_types['initialize_type'].member
  if not (isinstance(initialize, tff_framework.CompiledComputation) and
          initialize.type_signature.result == expected_initialize_type):
    raise transformations.CanonicalFormCompilationError(
        'Compilation of initialize has failed. Expected to extract a '
        '`tff_framework.CompiledComputation` of type {}, instead we extracted '
        'a {} of type {}.'.format(expected_initialize_type, type(initialize),
                                  initialize.type_signature.result))

  prepare = extract_prepare(before_broadcast, canonical_form_types)
  work = extract_work(before_aggregate, after_aggregate, canonical_form_types)
  zero_noarg_function, accumulate, merge, report = extract_aggregate_functions(
      before_aggregate, canonical_form_types)
  update = extract_update(after_aggregate, canonical_form_types)

  return canonical_form.CanonicalForm(
      tff_framework.building_block_to_computation(initialize),
      tff_framework.building_block_to_computation(prepare),
      tff_framework.building_block_to_computation(work),
      tff_framework.building_block_to_computation(zero_noarg_function),
      tff_framework.building_block_to_computation(accumulate),
      tff_framework.building_block_to_computation(merge),
      tff_framework.building_block_to_computation(report),
      tff_framework.building_block_to_computation(update))
def init_computation():
  """Returns the value produced by `cf.initialize()`, placed at SERVER."""
  initial_state = cf.initialize()
  return tff.federated_value(initial_state, tff.SERVER)
def init_fn():
  """Returns the constant 42 placed at SERVER."""
  server_value = 42
  return tff.federated_value(server_value, tff.SERVER)
def run_initialize():
  """Returns the value of `dp_aggregate_fn.initialize()`, placed at SERVER."""
  aggregate_state = dp_aggregate_fn.initialize()
  return tff.federated_value(aggregate_state, tff.SERVER)
def federated_aggregate_test(values, weights):
  """Builds a `StatefulAggregateFn` and applies it once from its initial state.

  Args:
    values: Values passed through to the aggregate function.
    weights: Weights passed through to the aggregate function.

  Returns:
    Whatever calling the aggregate function with the SERVER-placed initial
    state, `values` and `weights` returns.
  """
  stateful_aggregate = computation_utils.StatefulAggregateFn(
      initialize_fn=agg_initialize_fn, next_fn=agg_next_fn)
  initial_state = stateful_aggregate.initialize()
  server_state = tff.federated_value(initial_state, tff.SERVER)
  return stateful_aggregate(server_state, values, weights)
def federated_broadcast_test(values):
  """Builds a `StatefulBroadcastFn` and applies it once from its initial state.

  Args:
    values: Values passed through to the broadcast function.

  Returns:
    Whatever calling the broadcast function with the SERVER-placed initial
    state and `values` returns.
  """
  stateful_broadcast = computation_utils.StatefulBroadcastFn(
      initialize_fn=broadcast_initialize_fn, next_fn=broadcast_next_fn)
  initial_state = stateful_broadcast.initialize()
  server_state = tff.federated_value(initial_state, tff.SERVER)
  return stateful_broadcast(server_state, values)
def foo(x):
  """Returns `x` placed at SERVER."""
  placement = tff.SERVER
  return tff.federated_value(x, placement)
def test_federated_value_raw_tf_scalar_variable(self):
  """Checks that placing a raw `tf.Variable` raises a `TypeError`."""
  expected_message_pattern = ('TensorFlow construct (.*) has been '
                              'encountered in a federated context.')
  scalar_variable = tf.Variable(initial_value=0., name='test_var')
  with self.assertRaisesRegex(TypeError, expected_message_pattern):
    _ = tff.federated_value(scalar_variable, tff.SERVER)
def foo(x):
  """Returns `x` placed at CLIENTS."""
  placement = tff.CLIENTS
  return tff.federated_value(x, placement)
def build_federated_zero():
  """Returns the integer 0 placed at SERVER."""
  zero = 0
  return tff.federated_value(zero, tff.SERVER)