def test_learning_process_can_be_reconstructed(self):
  """Checks a `LearningProcess` can be rebuilt from its own computations."""
  process = learning_process.LearningProcess(test_init_fn, test_next_fn,
                                             test_get_model_weights_fn,
                                             test_set_model_weights_fn)
  try:
    learning_process.LearningProcess(process.initialize, process.next,
                                     process.get_model_weights,
                                     process.set_model_weights)
  except Exception:  # pylint: disable=broad-except
    # `Exception` rather than a bare `except`, so that KeyboardInterrupt and
    # SystemExit still propagate instead of being reported as a test failure.
    self.fail('Could not reconstruct the LearningProcess.')
def test_federated_get_model_weights_raises(self):
  """A `get_model_weights` built from `at_server(tf.float32)` is rejected."""
  federated_weights_type = at_server(tf.float32)
  federated_get_model_weights = create_pass_through_get_model_weights(
      federated_weights_type)

  with self.assertRaises(learning_process.GetModelWeightsTypeSignatureError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=test_next_fn,
        get_model_weights=federated_get_model_weights,
        set_model_weights=test_set_model_weights_fn)
def test_next_not_tff_computation_raises(self):
  """Passing a plain Python callable as `next_fn` raises a `TypeError`."""
  # A bare lambda, deliberately not wrapped as a TFF computation.
  python_next_fn = lambda state, client_data: LearningProcessOutput(state, ())

  with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=python_next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_set_model_weights_result_not_assignable(self):
  """A `set_model_weights` built from (int32, float32) types is rejected."""
  mismatched_set_model_weights = create_take_arg_set_model_weights(
      tf.int32, tf.float32)

  with self.assertRaises(learning_process.SetModelWeightsTypeSignatureError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=test_next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=mismatched_set_model_weights)
def test_construction_with_nested_datasets_does_not_raise(self):
  """Construction succeeds when the client argument nests several sequences."""

  @federated_computation
  def initialize_fn():
    zero_fn = tf_computation(lambda: tf.constant(0.0, tf.float32))
    return intrinsics.federated_eval(zero_fn, placements.SERVER)

  # Test that clients can receive multiple datasets: one sequence plus a
  # nested pair of sequences.
  string_sequence = SequenceType(tf.string)
  datasets_type = (string_sequence, (SequenceType(tf.string),
                                     SequenceType(tf.string)))

  @federated_computation(at_server(tf.float32), at_clients(datasets_type))
  def next_fn(state, datasets):
    del datasets  # Unused.
    empty_metrics = intrinsics.federated_value((), placements.SERVER)
    return LearningProcessOutput(state, empty_metrics)

  get_weights = create_pass_through_get_model_weights(tf.float32)
  set_weights = create_take_arg_set_model_weights(tf.float32, tf.float32)
  try:
    learning_process.LearningProcess(initialize_fn, next_fn, get_weights,
                                     set_weights)
  except learning_process.LearningProcessSequenceTypeError:
    self.fail('Could not construct a LearningProcess with second parameter '
              'type having nested sequences.')
def test_construction_with_unknown_dimension_does_not_raise(self):
  """Construction succeeds when the state type has a `[None]` shape."""

  @federated_computation
  def initialize_fn():
    return intrinsics.federated_eval(
        tf_computation(lambda: tf.constant([], tf.string)), placements.SERVER)

  # This replicates a tensor that can grow in string length. The
  # `initialize_fn` will concretely start with shape `[0]`, but `next_fn` will
  # grow this, hence the need to define the shape as `[None]`.
  none_dimension_string_type = TensorType(shape=[None], dtype=tf.string)

  @federated_computation(
      at_server(none_dimension_string_type),
      at_clients(SequenceType(tf.string)))
  def next_fn(state, datasets):
    del datasets  # Unused.
    return LearningProcessOutput(
        state, intrinsics.federated_value((), placements.SERVER))

  try:
    learning_process.LearningProcess(
        initialize_fn, next_fn,
        create_pass_through_get_model_weights(none_dimension_string_type),
        create_take_arg_set_model_weights(none_dimension_string_type,
                                          none_dimension_string_type))
  except Exception:  # pylint: disable=broad-except
    # `Exception` rather than a bare `except`, so that KeyboardInterrupt and
    # SystemExit still propagate instead of being reported as a test failure.
    self.fail('Could not construct a LearningProcess with state type having '
              'statically unknown shape.')
def test_construction_does_not_raise(self):
  """Checks the shared test computations form a valid `LearningProcess`."""
  try:
    learning_process.LearningProcess(test_init_fn, test_next_fn,
                                     test_get_model_weights_fn,
                                     test_set_model_weights_fn)
  except Exception:  # pylint: disable=broad-except
    # `Exception` rather than a bare `except`, so that KeyboardInterrupt and
    # SystemExit still propagate instead of being reported as a test failure.
    self.fail('Could not construct a valid LearningProcess.')
def test_next_fn_with_one_parameter_raises(self):
  """A `next_fn` that accepts only the state argument is rejected."""

  @federated_computation(at_server(tf.int32))
  def next_fn(state):
    # Only one parameter: there is no client data argument.
    return LearningProcessOutput(state, 0)

  with self.assertRaises(errors.TemplateNextFnNumArgsError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_init_state_not_federated(self):
  """An `initialize_fn` returning a bare `0.0` is rejected."""

  @federated_computation
  def float_initialize_fn():
    # Returns an unplaced Python float rather than a federated value.
    return 0.0

  with self.assertRaises(errors.TemplateStateNotAssignableError):
    learning_process.LearningProcess(
        initialize_fn=float_initialize_fn,
        next_fn=test_next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_init_param_not_empty_raises(self):
  """An `initialize_fn` that takes an argument is rejected."""

  @federated_computation(at_server(tf.int32))
  def one_arg_initialize_fn(x):
    return x

  with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
    learning_process.LearningProcess(
        initialize_fn=one_arg_initialize_fn,
        next_fn=test_next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_next_fn_with_nonsequence_second_arg_raises(self):
  """A `next_fn` whose second argument is not a sequence type is rejected."""

  # The client argument is a plain `tf.int32`, not a `SequenceType`.
  @federated_computation(at_server(tf.int32), at_clients(tf.int32))
  def next_fn(state, client_values):
    summed = intrinsics.federated_sum(client_values)
    return LearningProcessOutput(state, summed)

  with self.assertRaises(learning_process.LearningProcessSequenceTypeError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_next_fn_with_client_placed_metrics_result_raises(self):
  """A `next_fn` that returns its client-placed argument as metrics fails."""

  @federated_computation(
      at_server(tf.int32), at_clients(SequenceType(tf.int32)))
  def next_fn(state, metrics):
    # `metrics` is the client-placed argument, returned unchanged.
    return LearningProcessOutput(state, metrics)

  with self.assertRaises(learning_process.LearningProcessPlacementError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_next_state_not_federated(self):
  """A `next_fn` whose state parameter is an unplaced `tf.float32` fails."""

  @federated_computation(tf.float32, at_clients(SequenceType(tf.float32)))
  def float_next_fn(state, datasets):
    del datasets  # Unused.
    return state

  # NOTE: `float_next_fn` used to be overwritten here by
  # `create_pass_through_get_model_weights(tf.float32)`, which made the
  # computation above dead code. The test now exercises the two-argument
  # `next_fn` it actually defines.
  with self.assertRaises(errors.TemplateStateNotAssignableError):
    learning_process.LearningProcess(test_init_fn, float_next_fn,
                                     test_get_model_weights_fn,
                                     test_set_model_weights_fn)
def test_next_fn_with_server_placed_second_arg_raises(self):
  """A `next_fn` whose second argument is placed at the server is rejected."""

  @federated_computation(
      at_server(tf.int32), at_server(SequenceType(tf.int32)))
  def next_fn(state, server_values):
    reduced = intrinsics.federated_map(sum_dataset, server_values)
    return LearningProcessOutput(state, reduced)

  with self.assertRaises(learning_process.LearningProcessPlacementError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_next_return_odict_raises(self):
  """A `next_fn` returning an `OrderedDict` instead of the output type fails."""

  @federated_computation(
      at_server(tf.int32), at_clients(SequenceType(tf.int32)))
  def odict_next_fn(state, client_values):
    per_client_sums = intrinsics.federated_map(sum_dataset, client_values)
    total = intrinsics.federated_sum(per_client_sums)
    # Same field names as `LearningProcessOutput`, but the wrong container.
    return collections.OrderedDict(state=state, metrics=total)

  with self.assertRaises(learning_process.LearningProcessOutputError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=odict_next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_init_fn_with_client_placed_state_raises(self):
  """A process whose state is placed at the clients is rejected."""

  @federated_computation
  def init_fn():
    # The state is created at `placements.CLIENTS`.
    return intrinsics.federated_value(0, placements.CLIENTS)

  @federated_computation(
      at_clients(tf.int32), at_clients(SequenceType(tf.int32)))
  def next_fn(state, client_values):
    return LearningProcessOutput(state, client_values)

  with self.assertRaises(learning_process.LearningProcessPlacementError):
    learning_process.LearningProcess(
        initialize_fn=init_fn,
        next_fn=next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_next_fn_with_three_parameters_raises(self):
  """A `next_fn` that accepts three arguments is rejected."""

  @federated_computation(
      at_server(tf.int32), at_clients(SequenceType(tf.int32)),
      at_server(tf.int32))
  def next_fn(state, client_values, second_state):
    del second_state  # Unused.
    per_client_sums = intrinsics.federated_map(sum_dataset, client_values)
    total = intrinsics.federated_sum(per_client_sums)
    return LearningProcessOutput(state, total)

  with self.assertRaises(errors.TemplateNextFnNumArgsError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_next_return_namedtuple_raises(self):
  """A `next_fn` returning a look-alike namedtuple is rejected."""
  # Same type name and fields as the real output, but a different class.
  lookalike_output = collections.namedtuple('LearningProcessOutput',
                                            ['state', 'metrics'])

  @federated_computation(
      at_server(tf.int32), at_clients(SequenceType(tf.int32)))
  def namedtuple_next_fn(state, client_values):
    per_client_sums = intrinsics.federated_map(sum_dataset, client_values)
    total = intrinsics.federated_sum(per_client_sums)
    return lookalike_output(state, total)

  with self.assertRaises(learning_process.LearningProcessOutputError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=namedtuple_next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_returns_iterative_process_with_same_non_next_type_signatures(self):
  """Composition keeps the model-weights accessors and their signatures."""

  @tensorflow_computation.tf_computation()
  def make_zero():
    return tf.cast(0, tf.int64)

  @federated_computation.federated_computation()
  def initialize_fn():
    return intrinsics.federated_eval(make_zero, placements.SERVER)

  server_state_type = initialize_fn.type_signature.result
  client_data_type = computation_types.FederatedType(
      computation_types.SequenceType(tf.int64), placements.CLIENTS)

  # The parameter types are passed as one tuple, as in the original test.
  @federated_computation.federated_computation(
      (server_state_type, client_data_type))
  def next_fn(server_state, client_data):
    del client_data
    return learning_process.LearningProcessOutput(
        state=server_state,
        metrics=intrinsics.federated_value((), placements.SERVER))

  @tensorflow_computation.tf_computation(tf.int64)
  def get_model_weights(server_state):
    return server_state + 1

  @tensorflow_computation.tf_computation(tf.int64, tf.int64)
  def set_model_weights(state, state_update):
    return state + state_update

  process = learning_process.LearningProcess(
      initialize_fn=initialize_fn,
      next_fn=next_fn,
      get_model_weights=get_model_weights,
      set_model_weights=set_model_weights)

  compose = (
      iterative_process_compositions
      .compose_dataset_computation_with_learning_process)
  new_process = compose(int_dataset_computation, process)

  self.assertIsInstance(new_process, iterative_process.IterativeProcess)
  # Both accessors must survive composition with an equivalent signature.
  for accessor_name in ('get_model_weights', 'set_model_weights'):
    self.assertTrue(hasattr(new_process, accessor_name))
    composed_signature = getattr(new_process, accessor_name).type_signature
    original_signature = getattr(process, accessor_name).type_signature
    self.assertTrue(composed_signature.is_equivalent_to(original_signature))
def test_construction_with_empty_state_does_not_raise(self):
  """Construction succeeds when the state is an empty tuple."""
  empty_tuple = ()

  @federated_computation
  def empty_initialize_fn():
    return intrinsics.federated_value(empty_tuple, placements.SERVER)

  @federated_computation(
      at_server(empty_tuple), at_clients(SequenceType(tf.int32)))
  def next_fn(state, value):
    del value  # Unused.
    return LearningProcessOutput(
        state=state,
        metrics=intrinsics.federated_value(empty_tuple, placements.SERVER))

  try:
    learning_process.LearningProcess(
        empty_initialize_fn, next_fn,
        create_pass_through_get_model_weights(empty_tuple),
        create_take_arg_set_model_weights(empty_tuple, empty_tuple))
  except Exception:  # pylint: disable=broad-except
    # `Exception` rather than a bare `except`, so that KeyboardInterrupt and
    # SystemExit still propagate instead of being reported as a test failure.
    self.fail('Could not construct a LearningProcess with empty state.')
def compose_learning_process(
    initial_model_weights_fn: computation_base.Computation,
    model_weights_distributor: distributors.DistributionProcess,
    client_work: client_works.ClientWorkProcess,
    model_update_aggregator: aggregation_process.AggregationProcess,
    model_finalizer: finalizers.FinalizerProcess
) -> learning_process.LearningProcess:
  """Composes specialized measured processes into a learning process.

  Given 4 specialized measured processes (described below) that make a learning
  process, and a computation that returns initial model weights to be used for
  training, this method validates that the processes fit together, and returns
  a `LearningProcess`. Please see the tutorial at
  https://www.tensorflow.org/federated/tutorials/composing_learning_algorithms
  for more details on composing learning processes.

  The main purpose of the 4 measured processes are:
    * `model_weights_distributor`: Make global model weights at server
      available as the starting point for learning work to be done at clients.
    * `client_work`: Produce an update to the model received at clients.
    * `model_update_aggregator`: Aggregates the model updates from clients to
      the server.
    * `model_finalizer`: Updates the global model weights using the aggregated
      model update at server.

  The `next` computation of the created learning process is composed from the
  `next` computations of the 4 measured processes, in order as visualized
  below. The type signatures of the processes must be such that this chaining
  is possible. Each process also reports its own metrics.

  ```
  ┌─────────────────────────┐
  │model_weights_distributor│
  └△─┬─┬────────────────────┘
   │ │┌▽──────────┐
   │ ││client_work│
   │ │└┬─────┬────┘
   │┌▽─▽────┐│
   ││metrics││
   │└△─△────┘│
   │ │┌┴─────▽────────────────┐
   │ ││model_update_aggregator│
   │ │└┬──────────────────────┘
  ┌┴─┴─▽──────────┐
  │model_finalizer│
  └┬──────────────┘
  ┌▽─────┐
  │result│
  └──────┘
  ```

  Args:
    initial_model_weights_fn: A `tff.Computation` that returns (unplaced)
      initial model weights.
    model_weights_distributor: A `DistributionProcess`.
    client_work: A `ClientWorkProcess`.
    model_update_aggregator: A `tff.templates.AggregationProcess`.
    model_finalizer: A `FinalizerProcess`.

  Returns:
    A `LearningProcess`.
  """
  # pyformat: enable
  _validate_args(initial_model_weights_fn, model_weights_distributor,
                 client_work, model_update_aggregator, model_finalizer)
  # The third parameter of `client_work.next` is the client data.
  client_data_type = client_work.next.type_signature.parameter[2]

  @federated_computation.federated_computation()
  def init_fn():
    server_weights = intrinsics.federated_eval(initial_model_weights_fn,
                                               placements.SERVER)
    initial_state = LearningAlgorithmState(
        server_weights,
        model_weights_distributor.initialize(),
        client_work.initialize(),
        model_update_aggregator.initialize(),
        model_finalizer.initialize())
    return intrinsics.federated_zip(initial_state)

  @federated_computation.federated_computation(init_fn.type_signature.result,
                                               client_data_type)
  def next_fn(state, client_data):
    # Chain the processes: distribute -> client work -> aggregate -> finalize.
    distributed = model_weights_distributor.next(state.distributor,
                                                 state.global_model_weights)
    worked = client_work.next(state.client_work, distributed.result,
                              client_data)
    aggregated = model_update_aggregator.next(state.aggregator,
                                              worked.result.update,
                                              worked.result.update_weight)
    finalized = model_finalizer.next(state.finalizer,
                                     state.global_model_weights,
                                     aggregated.result)

    # Assemble the new state and the per-process measurements.
    updated_weights = finalized.result
    next_state = intrinsics.federated_zip(
        LearningAlgorithmState(updated_weights, distributed.state,
                               worked.state, aggregated.state,
                               finalized.state))
    measurements = collections.OrderedDict(
        distributor=distributed.measurements,
        client_work=worked.measurements,
        aggregator=aggregated.measurements,
        finalizer=finalized.measurements)
    zipped_metrics = intrinsics.federated_zip(measurements)
    return learning_process.LearningProcessOutput(next_state, zipped_metrics)

  # Unplaced (member) type of the server state, used by the accessors below.
  state_parameter_type = next_fn.type_signature.parameter[0].member

  @tensorflow_computation.tf_computation(state_parameter_type)
  def get_model_weights_fn(state):
    return state.global_model_weights

  @tensorflow_computation.tf_computation(
      state_parameter_type, state_parameter_type.global_model_weights)
  def set_model_weights_fn(state, model_weights):
    return attr.evolve(state, global_model_weights=model_weights)

  return learning_process.LearningProcess(init_fn, next_fn,
                                          get_model_weights_fn,
                                          set_model_weights_fn)
def test_init_not_tff_computation_raises(self):
  """Passing a plain Python callable as `initialize_fn` raises `TypeError`."""
  # A bare lambda, deliberately not wrapped as a TFF computation.
  python_init_fn = lambda: 0

  with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
    learning_process.LearningProcess(
        initialize_fn=python_init_fn,
        next_fn=test_next_fn,
        get_model_weights=test_get_model_weights_fn,
        set_model_weights=test_set_model_weights_fn)
def test_get_model_weights_param_not_equivalent_to_next_fn(self):
  """A `get_model_weights` built for `tf.float32` fails with `test_next_fn`."""
  float_get_model_weights = create_pass_through_get_model_weights(tf.float32)

  with self.assertRaises(learning_process.GetModelWeightsTypeSignatureError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=test_next_fn,
        get_model_weights=float_get_model_weights,
        set_model_weights=test_set_model_weights_fn)
def test_non_tff_computation_get_model_weights_raises(self):
  """Passing a plain Python callable as `get_model_weights` raises."""
  # An identity lambda, deliberately not wrapped as a TFF computation.
  python_get_model_weights = lambda x: x

  with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=test_next_fn,
        get_model_weights=python_get_model_weights,
        set_model_weights=test_set_model_weights_fn)
def test_non_functional_get_model_weights_raises(self):
  """Passing the result of `at_server(tf.int32)` as `get_model_weights` fails."""
  # The value passed is whatever `at_server` returns, not a computation.
  not_a_computation = at_server(tf.int32)

  with self.assertRaises(TypeError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=test_next_fn,
        get_model_weights=not_a_computation,
        set_model_weights=test_set_model_weights_fn)