def test_constructor_with_type_mismatch(self):
  """`next_fn` signatures that disagree with `initialize` are rejected."""

  def build_process(next_fn):
    return iterative_process.IterativeProcess(
        initialize_fn=initialize, next_fn=next_fn)

  # The float32 state parameter cannot accept what `initialize` returns.
  with self.assertRaises(
      iterative_process.NextMustAcceptStateFromInitializeError):

    @computations.federated_computation(tf.float32, tf.float32)
    def add_float32(current, val):
      return current + val

    build_process(add_float32)

  # Returning a lone float does not reproduce the int32 state.
  with self.assertRaises(iterative_process.NextMustReturnStateError):

    @computations.federated_computation(tf.int32)
    def add_bad_result(_):
      return 0.0

    build_process(add_bad_result)

  # Returning a tuple whose first element is a float is rejected the same way.
  with self.assertRaises(iterative_process.NextMustReturnStateError):

    @computations.federated_computation(tf.int32)
    def add_bad_multi_result(_):
      return 0.0, 0

    build_process(add_bad_multi_result)
def test_constructor_with_type_mismatch(self):
  """Mismatched `initialize_fn`/`next_fn` types raise descriptive TypeErrors."""
  next_fn_mismatch = (
      'The return type of next_fn should match the first parameter')

  def build_process(next_fn):
    return iterative_process.IterativeProcess(
        initialize_fn=initialize, next_fn=next_fn)

  # The (float32, float32) parameter does not match `initialize`'s result.
  with self.assertRaisesRegex(
      TypeError, r'The return type of initialize_fn should match.*'):

    @computations.federated_computation([tf.float32, tf.float32])
    def add_float32(current, val):
      return current + val

    build_process(add_float32)

  # A float result cannot stand in for the int32 state.
  with self.assertRaisesRegex(TypeError, next_fn_mismatch):

    @computations.federated_computation(tf.int32)
    def add_bad_result(_):
      return 0.0

    build_process(add_bad_result)

  # Neither can a (float, int) pair.
  with self.assertRaisesRegex(TypeError, next_fn_mismatch):

    @computations.federated_computation(tf.int32)
    def add_bad_multi_result(_):
      return 0.0, 0

    build_process(add_bad_multi_result)
def test_constructor_with_initialize_bad_type(self):
  """`initialize_fn` must be a `Computation` that takes no arguments."""
  # Something that is not a computation at all.
  with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
    iterative_process.IterativeProcess(initialize_fn=None, next_fn=add_int32)

  # A computation, but one that declares a parameter.
  with self.assertRaises(iterative_process.InitializeFnHasArgsError):

    @computations.federated_computation(tf.int32)
    def one_arg_initialize(one_arg):
      del one_arg  # Unused.
      return values.to_value(0)

    iterative_process.IterativeProcess(
        next_fn=add_int32, initialize_fn=one_arg_initialize)
def get_iterative_process_for_sum_example_with_no_aggregation():
  """Returns an iterative process for a sum example.

  Neither "update" comes from an aggregation: both are server-side constants.
  The process therefore contains no call to `federated_aggregate` or
  `federated_secure_sum`, and it should fail to compile to
  `canonical_form.CanonicalForm`.
  """

  @computations.federated_computation
  def init_fn():
    """The `init` function for `tff.templates.IterativeProcess`."""
    return intrinsics.federated_value([0, 0], placements.SERVER)

  @computations.tf_computation([tf.int32, tf.int32], [tf.int32, tf.int32])
  def update(server_state, global_update):
    del server_state  # Unused
    return global_update, []

  server_state_type = computation_types.FederatedType([tf.int32, tf.int32],
                                                      placements.SERVER)
  client_data_type = computation_types.FederatedType(tf.int32,
                                                     placements.CLIENTS)

  @computations.federated_computation([server_state_type, client_data_type])
  def next_fn(server_state, client_data):
    """The `next` function for `tff.templates.IterativeProcess`."""
    del client_data
    # Stand-in for what would normally be a `federated_aggregate` result.
    unsecure_update = intrinsics.federated_value(1, placements.SERVER)
    # Stand-in for what would normally be a `federated_secure_sum` result.
    secure_update = intrinsics.federated_value(1, placements.SERVER)
    update_arg = intrinsics.federated_zip(
        [server_state, [unsecure_update, secure_update]])
    new_server_state, server_output = intrinsics.federated_map(
        update, update_arg)
    return new_server_state, server_output

  return iterative_process.IterativeProcess(init_fn, next_fn)
def test_returns_iterproc_accepting_dataset_in_third_index_of_next(self):
  """Composition works when the dataset argument is the third `next` input."""
  base_process = _create_stateless_int_dataset_reduction_iterative_process()
  base_param_type = base_process.next.type_signature.parameter
  # Insert an int32 in the middle, which pushes the client dataset to index 2.
  widened_param_type = computation_types.StructType(
      [base_param_type[0], tf.int32, base_param_type[1]])

  @computations.federated_computation(widened_param_type)
  def new_next(param):
    # Drop the inserted middle element and delegate to the original `next`.
    return base_process.next([param[0], param[2]])

  process_with_dataset_third = iterative_process.IterativeProcess(
      base_process.initialize, new_next)
  composed_process = iterative_process_compositions.compose_dataset_computation_with_iterative_process(
      int_dataset_computation, process_with_dataset_third)

  expected_next_type = computation_types.FunctionType([
      computation_types.FederatedType(tf.int64, placements.SERVER),
      tf.int32,
      computation_types.FederatedType(tf.string, placements.CLIENTS),
  ], computation_types.FederatedType(tf.int64, placements.SERVER))
  self.assertTrue(
      expected_next_type.is_equivalent_to(
          composed_process.next.type_signature))
def test_construction_with_empty_state_does_not_raise(self):
  """An `IterativeProcess` whose state is the empty tuple can be constructed."""
  initialize_fn = computations.tf_computation()(lambda: ())
  next_fn = computations.tf_computation(())(lambda x: (x, 1.0))
  try:
    iterative_process.IterativeProcess(initialize_fn, next_fn)
  except Exception:  # pylint: disable=broad-except
    # Catch `Exception`, not a bare `except:`, so KeyboardInterrupt and
    # SystemExit still propagate. The original error is kept as the
    # failure's `__context__`.
    self.fail('Could not construct an IterativeProcess with empty state.')
def _create_stateless_int_vector_unknown_dim_dataset_reduction_iterative_process(
):
  """Builds a process that sums client vectors of unknown length.

  The state is a server-placed int64 vector of statically unknown length,
  initialized to `[0]`. Each `next` ignores the incoming state and returns the
  federated sum of per-client reductions. Each reduction adds the total of
  every dataset element to a length-1 accumulator, so client data of unknown
  shape is summed to a fixed shape.
  """

  @computations.tf_computation()
  def make_zero():
    return tf.reshape(tf.cast(0, tf.int64), shape=[1])

  @computations.federated_computation()
  def init():
    return intrinsics.federated_eval(make_zero, placements.SERVER)

  vector_type = computation_types.TensorType(tf.int64, shape=[None])
  vector_sequence_type = computation_types.SequenceType(vector_type)

  @computations.tf_computation(vector_sequence_type)
  def reduce_dataset(x):
    initial_total = tf.cast(tf.constant([0]), tf.int64)
    return x.reduce(initial_total,
                    lambda running, element: running + tf.reduce_sum(element))

  @computations.federated_computation(
      computation_types.FederatedType(vector_type, placements.SERVER),
      computation_types.FederatedType(vector_sequence_type,
                                      placements.CLIENTS))
  def next_fn(server_state, client_data):
    del server_state  # Unused
    client_totals = intrinsics.federated_map(reduce_dataset, client_data)
    return intrinsics.federated_sum(client_totals)

  return iterative_process.IterativeProcess(initialize_fn=init, next_fn=next_fn)
def test_federated_init_state_not_assignable(self):
  """Server-placed init state cannot feed a clients-placed `next` parameter."""
  clients_int_type = FederatedType(tf.int32, placements.CLIENTS)
  initialize_fn = federated_computation.federated_computation()(
      lambda: intrinsics.federated_value(0, placements.SERVER))
  next_fn = federated_computation.federated_computation(clients_int_type)(
      lambda state: state)
  with self.assertRaises(errors.TemplateStateNotAssignableError):
    iterative_process.IterativeProcess(initialize_fn, next_fn)
def _create_stateless_int_dataset_reduction_iterative_process():
  """Builds a process whose `next` sums every client's int64 dataset.

  The server state is an int64 initialized to 0. `next` ignores it and returns
  the federated sum of per-client dataset totals.
  """

  @computations.tf_computation()
  def make_zero():
    return tf.cast(0, tf.int64)

  @computations.federated_computation()
  def init():
    return intrinsics.federated_eval(make_zero, placement_literals.SERVER)

  int64_sequence_type = computation_types.SequenceType(tf.int64)

  @computations.tf_computation(int64_sequence_type)
  def reduce_dataset(x):
    return x.reduce(tf.cast(0, tf.int64), lambda total, element: total + element)

  next_param_type = (
      init.type_signature.result,
      computation_types.FederatedType(int64_sequence_type,
                                      placement_literals.CLIENTS),
  )

  @computations.federated_computation(next_param_type)
  def next_fn(empty_tup, x):
    # Despite the name, `empty_tup` is the int64 state produced by `init`; it
    # is discarded because the process is stateless.
    del empty_tup  # Unused
    per_client_totals = intrinsics.federated_map(reduce_dataset, x)
    return intrinsics.federated_sum(per_client_totals)

  return iterative_process.IterativeProcess(initialize_fn=init, next_fn=next_fn)
def test_disallows_broadcast_dependent_on_aggregate(self):
  """A broadcast that consumes an aggregation result is not MapReduce-able."""

  @federated_computation.federated_computation
  def init_comp():
    return intrinsics.federated_value(0, placements.SERVER)

  @federated_computation.federated_computation(
      computation_types.at_server(tf.int32), computation_types.at_clients(()))
  def next_comp(server_state, client_data):
    del server_state, client_data
    # aggregate -> broadcast -> aggregate. The broadcast depends on the first
    # aggregation's result, which MapReduce form does not support.
    first_sum = intrinsics.federated_sum(
        intrinsics.federated_value(0, placements.CLIENTS))
    second_sum = intrinsics.federated_sum(
        intrinsics.federated_broadcast(first_sum))
    empty_output = intrinsics.federated_value((), placements.SERVER)
    # `next` must return two values.
    return second_sum, empty_output

  process = iterative_process.IterativeProcess(init_comp, next_comp)
  with self.assertRaises(ValueError):
    form_utils.check_iterative_process_compatible_with_map_reduce_form(process)
def test_next_state_not_assignable_tuple_result(self):
  """A `next_fn` returning a float32 state tuple is rejected.

  NOTE(review): this presumes `test_initialize_fn` does not produce float32
  state. Confirm against its definition.
  """

  def cast_state_and_echo(state, x):
    return tf.cast(state, tf.float32), x

  float_next_fn = computations.tf_computation(tf.float32,
                                              tf.float32)(cast_state_and_echo)
  with self.assertRaises(errors.TemplateStateNotAssignableError):
    iterative_process.IterativeProcess(test_initialize_fn, float_next_fn)
def get_map_reduce_form_for_client_to_server_fn(
    self, client_to_server_fn) -> forms.MapReduceForm:
  """Produces a `MapReduceForm` for the provided `client_to_server_fn`.

  Wraps `client_to_server_fn` in an `iterative_process.IterativeProcess`. The
  process has empty server state, passes that state through unchanged, and
  uses `client_to_server_fn` to turn `client_data` into `server_output`. The
  process is then compiled with `get_map_reduce_form_for_iterative_process`.

  Args:
    client_to_server_fn: A function from client-placed data to server-placed
      output.

  Returns:
    A `forms.MapReduceForm` which uses the embedded `client_to_server_fn`.
  """

  @federated_computation.federated_computation
  def init_fn():
    return intrinsics.federated_value((), placements.SERVER)

  next_param_type = [
      computation_types.at_server(()),
      computation_types.at_clients(tf.int32),
  ]

  @federated_computation.federated_computation(next_param_type)
  def next_fn(server_state, client_data):
    return server_state, client_to_server_fn(client_data)

  return form_utils.get_map_reduce_form_for_iterative_process(
      iterative_process.IterativeProcess(init_fn, next_fn))
def test_federated_next_state_not_assignable(self):
  """A `next_fn` that broadcasts the server state is rejected."""
  initialize_fn = computations.federated_computation()(
      lambda: intrinsics.federated_value(0, placements.SERVER))
  init_state_type = initialize_fn.type_signature.result
  # `federated_broadcast` itself is the body. It re-places the server state,
  # so its result presumably cannot be assigned back to the state type.
  next_fn = computations.federated_computation(init_state_type)(
      intrinsics.federated_broadcast)
  with self.assertRaises(errors.TemplateStateNotAssignableError):
    iterative_process.IterativeProcess(initialize_fn, next_fn)
def test_constructor_with_state_tuple_arg(self):
  """Feeding 0..9 through `add_int32` accumulates their sum in the state."""
  process = iterative_process.IterativeProcess(initialize, add_int32)
  num_rounds = 10
  state = process.initialize()
  for round_value in range(num_rounds):
    state = process.next(state, round_value)
  self.assertEqual(state, sum(range(num_rounds)))
def test_constructor_with_empty_tuple(self):
  """An empty-tuple state survives repeated `next` calls unchanged."""
  process = iterative_process.IterativeProcess(initialize_empty_tuple,
                                               next_empty_tuple)
  state = process.initialize()
  for _ in range(2):
    state = process.next(state)
  self.assertEqual(state, [])
def test_constructor_with_state_multiple_return_values(self):
  """`next` may return the new state together with an extra output."""
  process = iterative_process.IterativeProcess(initialize, add_mul_int32)
  num_rounds = 10
  state = process.initialize()
  for round_value in range(num_rounds):
    state, product = process.next(state, round_value)

  last_value = num_rounds - 1
  expected_state = sum(range(num_rounds))
  # `product` comes from the final round. Presumably it is the pre-round state
  # times that round's input; confirm against `add_mul_int32`.
  expected_product = sum(range(last_value)) * last_value
  self.assertEqual(state, expected_state)
  self.assertEqual(product, expected_product)
def get_iterative_process_with_nested_broadcasts():
  """Returns an iterative process with nested federated broadcasts.

  One broadcast happens inside `broadcast_and_return_arg_and_result` and a
  second one in `next_fn`. The process still contains every component required
  to compile to `forms.MapReduceForm`.
  """

  @federated_computation.federated_computation
  def init_fn():
    """The `init` function for `tff.templates.IterativeProcess`."""
    return intrinsics.federated_value([0, 0], placements.SERVER)

  @tensorflow_computation.tf_computation([tf.int32, tf.int32])
  def prepare(server_state):
    return server_state

  @tensorflow_computation.tf_computation(tf.int32, [tf.int32, tf.int32])
  def work(client_data, client_input):
    del client_data  # Unused
    del client_input  # Unused
    return 1, 1

  @tensorflow_computation.tf_computation([tf.int32, tf.int32],
                                         [tf.int32, tf.int32])
  def update(server_state, global_update):
    del server_state  # Unused
    return global_update, []

  server_pair_type = computation_types.FederatedType([tf.int32, tf.int32],
                                                     placements.SERVER)

  @federated_computation.federated_computation(server_pair_type)
  def broadcast_and_return_arg_and_result(x):
    broadcasted = intrinsics.federated_broadcast(x)
    return [broadcasted, x]

  @federated_computation.federated_computation([
      server_pair_type,
      computation_types.FederatedType(tf.int32, placements.CLIENTS),
  ])
  def next_fn(server_state, client_data):
    """The `next` function for `tff.templates.IterativeProcess`."""
    prepared = intrinsics.federated_map(prepare, server_state)
    # The nested computation broadcasts internally. Only its pass-through
    # (still server-placed) output is used, and it is broadcast again here.
    _, to_broadcast = broadcast_and_return_arg_and_result(prepared)
    client_input = intrinsics.federated_broadcast(to_broadcast)
    work_arg = intrinsics.federated_zip([client_data, client_input])
    client_updates = intrinsics.federated_map(work, work_arg)
    unsecure_update = intrinsics.federated_sum(client_updates[0])
    secure_update = intrinsics.federated_secure_sum_bitwidth(
        client_updates[1], 8)
    update_arg = intrinsics.federated_zip(
        [server_state, [unsecure_update, secure_update]])
    new_server_state, server_output = intrinsics.federated_map(
        update, update_arg)
    return new_server_state, server_output

  return iterative_process.IterativeProcess(init_fn, next_fn)
def _create_test_iterative_process(state_type, state_init):
  """Builds a process with identity `next` and `tf.constant(state_init)` state.

  Args:
    state_type: The type accepted (and returned) by `next_fn`.
    state_init: Value wrapped in `tf.constant` by the initializer.
  """

  @tensorflow_computation.tf_computation(state_type)
  def next_fn(state):
    return state

  @tensorflow_computation.tf_computation
  def initialize_fn():
    return tf.constant(state_init)

  return iterative_process.IterativeProcess(
      initialize_fn=initialize_fn, next_fn=next_fn)
def test_constructor_with_state_only(self):
  """Ten `count_int32` rounds leave the state equal to ten."""
  process = iterative_process.IterativeProcess(initialize, count_int32)
  num_rounds = 10
  state = process.initialize()
  for _ in range(num_rounds):
    # TODO(b/122321354): remove the .item() call on `state` once numpy.int32
    # type is supported.
    state = process.next(state.item())
  self.assertEqual(state, num_rounds)
def test_next_computation_returning_tensor_fails_well(self):
  """An identity `next` cannot be converted to MapReduce form."""
  form = mapreduce_test_utils.get_temperature_sensor_example()
  process = form_utils.get_iterative_process_for_map_reduce_form(form)
  state_type = process.initialize.type_signature.result
  # `next` echoes its single argument, so it presumably lacks the
  # (state, output) result expected of a MapReduce `next`.
  identity = building_blocks.Lambda('x', state_type,
                                    building_blocks.Reference('x', state_type))
  identity_next = computation_impl.ConcreteComputation.from_building_block(
      identity)
  bad_process = iterative_process.IterativeProcess(process.initialize,
                                                   identity_next)
  with self.assertRaises(TypeError):
    form_utils.get_map_reduce_form_for_iterative_process(bad_process)
def test_next_computation_returning_tensor_fails_well(self):
  """An identity `next` cannot be converted to canonical form."""
  form = test_utils.get_temperature_sensor_example()
  process = canonical_form_utils.get_iterative_process_for_canonical_form(form)
  state_type = process.initialize.type_signature.result
  # `next` just echoes its single argument back.
  identity = building_blocks.Lambda('x', state_type,
                                    building_blocks.Reference('x', state_type))
  identity_next = computation_wrapper_instances.building_block_to_computation(
      identity)
  bad_process = iterative_process.IterativeProcess(process.initialize,
                                                   identity_next)
  with self.assertRaises(TypeError):
    canonical_form_utils.get_canonical_form_for_iterative_process(bad_process)
def _create_whimsy_iterative_process():
  """Builds a trivial process: empty-list state and an identity `next`."""

  @computations.tf_computation()
  def init():
    return []

  state_type = init.type_signature.result

  @computations.tf_computation(state_type)
  def next_fn(x):
    return x

  return iterative_process.IterativeProcess(initialize_fn=init, next_fn=next_fn)
def test_raises_on_invalid_distributor(self):
  """A distributor re-wrapped as a plain `IterativeProcess` is rejected."""
  initial_weights = model_utils.ModelWeights.from_model(
      model_examples.LinearRegression())
  model_weights_type = type_conversions.type_from_tensors(initial_weights)
  distributor = distributors.build_broadcast_process(model_weights_type)
  # Same computations as the distributor, but not the process type the builder
  # expects.
  invalid_distributor = iterative_process.IterativeProcess(
      distributor.initialize, distributor.next)
  with self.assertRaises(TypeError):
    fed_avg.build_weighted_fed_avg(
        model_fn=model_examples.LinearRegression,
        client_optimizer_fn=sgdm.build_sgdm(1.0),
        model_distributor=invalid_distributor)
def get_iterative_process_for_map_reduce_form(
    mrf: forms.MapReduceForm) -> iterative_process.IterativeProcess:
  """Creates `tff.templates.IterativeProcess` from a MapReduce form.

  Each round runs `prepare` on the server and broadcasts the result. It zips
  the broadcast with client data and runs `work`, then performs the four
  aggregations (plain, secure-sum-bitwidth, secure-sum, secure-modular-sum).
  Finally it feeds the server state and all aggregates to `update`.

  Args:
    mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.

  Returns:
    An instance of `tff.templates.IterativeProcess` that corresponds to `mrf`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(mrf, forms.MapReduceForm)

  @federated_computation.federated_computation
  def init_computation():
    return intrinsics.federated_value(mrf.initialize(), placements.SERVER)

  server_state_type = init_computation.type_signature.result
  client_data_type = computation_types.FederatedType(
      mrf.work.type_signature.parameter[0], placements.CLIENTS)
  next_parameter_type = computation_types.StructType([
      (mrf.server_state_label, server_state_type),
      (mrf.client_data_label, client_data_type),
  ])

  @federated_computation.federated_computation(next_parameter_type)
  def next_computation(arg):
    """The logic of a single MapReduce processing round."""
    server_state, client_data = arg
    prepared = intrinsics.federated_map(mrf.prepare, server_state)
    client_broadcast = intrinsics.federated_broadcast(prepared)
    work_outputs = intrinsics.federated_map(
        mrf.work, intrinsics.federated_zip([client_data, client_broadcast]))
    (aggregate_input, secure_sum_bitwidth_input, secure_sum_input,
     secure_modular_sum_input) = work_outputs

    aggregate_result = intrinsics.federated_aggregate(aggregate_input,
                                                      mrf.zero(),
                                                      mrf.accumulate,
                                                      mrf.merge, mrf.report)
    secure_sum_bitwidth_result = intrinsics.federated_secure_sum_bitwidth(
        secure_sum_bitwidth_input, mrf.secure_sum_bitwidth())
    secure_sum_result = intrinsics.federated_secure_sum(
        secure_sum_input, mrf.secure_sum_max_input())
    secure_modular_sum_result = intrinsics.federated_secure_modular_sum(
        secure_modular_sum_input, mrf.secure_modular_sum_modulus())

    all_aggregates = (aggregate_result, secure_sum_bitwidth_result,
                      secure_sum_result, secure_modular_sum_result)
    update_arg = intrinsics.federated_zip((server_state, all_aggregates))
    updated_server_state, server_output = intrinsics.federated_map(
        mrf.update, update_arg)
    return updated_server_state, server_output

  return iterative_process.IterativeProcess(init_computation, next_computation)
def test_raises_on_invalid_distributor(self):
  """A distributor re-wrapped as a plain `IterativeProcess` is rejected."""
  initial_weights = model_utils.ModelWeights.from_model(
      model_examples.LinearRegression())
  model_weights_type = type_conversions.type_from_tensors(initial_weights)
  distributor = distributors.build_broadcast_process(model_weights_type)
  # Same computations as the distributor, but not the process type the builder
  # expects.
  invalid_distributor = iterative_process.IterativeProcess(
      distributor.initialize, distributor.next)
  base_optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
  with self.assertRaises(TypeError):
    mime.build_weighted_mime_lite(
        model_fn=model_examples.LinearRegression,
        base_optimizer=base_optimizer,
        model_distributor=invalid_distributor)
def test_construction_with_unknown_dimension_does_not_raise(self):
  """Parameter types with a statically unknown dimension are accepted."""
  initialize_fn = computations.tf_computation()(
      lambda: tf.constant([], dtype=tf.string))

  @computations.tf_computation(
      computation_types.TensorType(shape=[None], dtype=tf.string))
  def next_fn(strings):
    return tf.concat([strings, tf.constant(['abc'])], axis=0)

  try:
    iterative_process.IterativeProcess(initialize_fn, next_fn)
  except Exception:  # pylint: disable=broad-except
    # Catch `Exception`, not a bare `except:`, so KeyboardInterrupt and
    # SystemExit still propagate. The original error is kept as the
    # failure's `__context__`.
    self.fail('Could not construct an IterativeProcess with parameter types '
              'with statically unknown shape.')
def get_iterative_process_for_map_reduce_form(
    mrf: forms.MapReduceForm) -> iterative_process.IterativeProcess:
  """Creates `tff.templates.IterativeProcess` from a MapReduce form.

  Each round prepares and broadcasts server data and runs `work` on clients.
  It aggregates the first `work` output with `federated_aggregate` and the
  second with a bitwidth-bounded secure sum, then applies `update` to the
  server state and both aggregates.

  Args:
    mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.

  Returns:
    An instance of `tff.templates.IterativeProcess` that corresponds to `mrf`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(mrf, forms.MapReduceForm)

  @computations.federated_computation
  def init_computation():
    return intrinsics.federated_value(mrf.initialize(), placements.SERVER)

  server_state_type = init_computation.type_signature.result
  client_data_type = computation_types.FederatedType(
      mrf.work.type_signature.parameter[0], placements.CLIENTS)
  next_parameter_type = computation_types.StructType([
      (mrf.server_state_label, server_state_type),
      (mrf.client_data_label, client_data_type),
  ])

  @computations.federated_computation(next_parameter_type)
  def next_computation(arg):
    """The logic of a single MapReduce processing round."""
    server_state = arg[0]
    client_data = arg[1]
    broadcast_input = intrinsics.federated_map(mrf.prepare, server_state)
    broadcast_result = intrinsics.federated_broadcast(broadcast_input)
    work_arg = intrinsics.federated_zip([client_data, broadcast_result])
    work_output = intrinsics.federated_map(mrf.work, work_arg)
    aggregate_input = work_output[0]
    secure_sum_input = work_output[1]
    aggregate_result = intrinsics.federated_aggregate(aggregate_input,
                                                      mrf.zero(),
                                                      mrf.accumulate,
                                                      mrf.merge, mrf.report)
    secure_sum_result = intrinsics.federated_secure_sum_bitwidth(
        secure_sum_input, mrf.bitwidth())
    aggregates = intrinsics.federated_zip([aggregate_result, secure_sum_result])
    update_arg = intrinsics.federated_zip([server_state, aggregates])
    update_output = intrinsics.federated_map(mrf.update, update_arg)
    new_server_state = update_output[0]
    server_output = update_output[1]
    return new_server_state, server_output

  return iterative_process.IterativeProcess(init_computation, next_computation)
def _create_federated_int_dataset_identity_iterative_process():
  """Builds a process with identity `next` and clients-placed dataset state.

  The state is `tf.data.Dataset.range(5)` evaluated at CLIENTS.
  """

  @computations.tf_computation()
  def create_dataset():
    return tf.data.Dataset.range(5)

  @computations.federated_computation()
  def init():
    return intrinsics.federated_eval(create_dataset, placements.CLIENTS)

  clients_dataset_type = init.type_signature.result

  @computations.federated_computation(clients_dataset_type)
  def next_fn(x):
    return x

  return iterative_process.IterativeProcess(initialize_fn=init, next_fn=next_fn)
def test_constructor_with_tensors_unknown_dimensions(self):
  """State tensors with an unknown dimension do not block construction."""

  @computations.tf_computation
  def init():
    return tf.constant([], dtype=tf.string)

  @computations.tf_computation(
      computation_types.TensorType(shape=[None], dtype=tf.string))
  def next_fn(strings):
    return tf.concat([strings, tf.constant(['abc'])], axis=0)

  try:
    iterative_process.IterativeProcess(init, next_fn)
  except Exception:  # pylint: disable=broad-except
    # Catch `Exception`, not a bare `except:`, so KeyboardInterrupt and
    # SystemExit still propagate. The original error is kept as the
    # failure's `__context__`.
    self.fail(
        'Could not construct an IterativeProcess with parameter types '
        'including unknown dimension tensors.')
def get_iterative_process_for_canonical_form(cf):
  """Creates `tff.utils.IterativeProcess` from a canonical form.

  Each round prepares and broadcasts server data and runs `work` on clients.
  It aggregates the two halves of the first `work` output (plain aggregate and
  secure sum) and applies `update`. It returns the new state, the server
  output and the second, clients-placed `work` output.

  Args:
    cf: An instance of `tff.backends.mapreduce.CanonicalForm`.

  Returns:
    An instance of `tff.utils.IterativeProcess` that corresponds to `cf`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(cf, canonical_form.CanonicalForm)

  @computations.federated_computation
  def init_computation():
    return intrinsics.federated_value(cf.initialize(), placements.SERVER)

  server_state_type = init_computation.type_signature.result
  client_data_type = computation_types.FederatedType(
      cf.work.type_signature.parameter[0], placements.CLIENTS)

  @computations.federated_computation(server_state_type, client_data_type)
  def next_computation(arg):
    """The logic of a single MapReduce processing round."""
    server_state = arg[0]
    client_data = arg[1]
    broadcast_input = intrinsics.federated_map(cf.prepare, server_state)
    broadcast_result = intrinsics.federated_broadcast(broadcast_input)
    work_arg = intrinsics.federated_zip([client_data, broadcast_result])
    work_output = intrinsics.federated_map(cf.work, work_arg)
    client_updates = work_output[0]
    aggregate_input = client_updates[0]
    secure_sum_input = client_updates[1]
    client_output = work_output[1]
    aggregate_result = intrinsics.federated_aggregate(aggregate_input,
                                                      cf.zero(),
                                                      cf.accumulate, cf.merge,
                                                      cf.report)
    secure_sum_result = intrinsics.federated_secure_sum(
        secure_sum_input, cf.bitwidth())
    aggregates = intrinsics.federated_zip([aggregate_result, secure_sum_result])
    update_arg = intrinsics.federated_zip([server_state, aggregates])
    update_output = intrinsics.federated_map(cf.update, update_arg)
    new_server_state = update_output[0]
    server_output = update_output[1]
    return new_server_state, server_output, client_output

  return iterative_process.IterativeProcess(init_computation, next_computation)