def test_learning_process_can_be_reconstructed(self):
  """A LearningProcess can be rebuilt from another instance's computations."""
  process = learning_process.LearningProcess(test_init_fn, test_next_fn,
                                             test_report_fn)
  # Catch `Exception` rather than using a bare `except:` so that
  # KeyboardInterrupt and SystemExit propagate instead of being reported as
  # an ordinary test failure.
  try:
    learning_process.LearningProcess(process.initialize, process.next,
                                     process.report)
  except Exception:  # pylint: disable=broad-except
    self.fail('Could not reconstruct the LearningProcess.')
def test_next_not_tff_computation_raises(self):
  """A plain Python callable is rejected as `next_fn`."""

  def python_next_fn(state, client_data):
    del client_data  # Unused.
    return LearningProcessOutput(state, ())

  self.assertRaisesRegex(
      TypeError,
      r'Expected .*\.Computation, .*',
      learning_process.LearningProcess,
      initialize_fn=test_init_fn,
      next_fn=python_next_fn,
      report_fn=test_report_fn)
def test_next_fn_with_client_placed_metrics_result_raises(self):
  """A `next_fn` whose metrics are CLIENTS-placed is rejected."""

  @computations.federated_computation(test_state_type, ClientIntSequenceType)
  def next_fn(state, metrics):
    # The second argument (declared as `ClientIntSequenceType`) is passed
    # straight through as the metrics output.
    return LearningProcessOutput(state, metrics)

  self.assertRaises(
      learning_process.LearningProcessPlacementError,
      learning_process.LearningProcess,
      initialize_fn=test_init_fn,
      next_fn=next_fn,
      report_fn=test_report_fn)
def test_next_fn_with_one_parameter_raises(self):
  """A `next_fn` that accepts only the state is rejected."""

  @computations.federated_computation(test_state_type)
  def next_fn(state):
    return LearningProcessOutput(state, 0)

  with self.assertRaises(errors.TemplateNextFnNumArgsError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn, next_fn=next_fn, report_fn=test_report_fn)
def test_init_param_not_empty_raises(self):
  """An `initialize_fn` that takes an argument is rejected."""
  server_int_type = computation_types.FederatedType(tf.int32,
                                                    placements.SERVER)

  @computations.federated_computation(server_int_type)
  def one_arg_initialize_fn(x):
    return x

  with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
    learning_process.LearningProcess(
        initialize_fn=one_arg_initialize_fn,
        next_fn=test_next_fn,
        report_fn=test_report_fn)
def test_next_fn_with_non_sequence_second_arg_raises(self):
  """Client data that is a plain int rather than a sequence is rejected."""
  client_int_type = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS)

  @computations.federated_computation(test_state_type, client_int_type)
  def next_fn(state, client_values):
    return LearningProcessOutput(state,
                                 intrinsics.federated_sum(client_values))

  self.assertRaises(
      learning_process.LearningProcessSequenceTypeError,
      learning_process.LearningProcess,
      initialize_fn=test_init_fn,
      next_fn=next_fn,
      report_fn=test_report_fn)
def test_next_fn_with_three_parameters_raises(self):
  """A `next_fn` with an extra third parameter is rejected."""

  @computations.federated_computation(test_state_type, ClientIntSequenceType,
                                      test_state_type)
  def next_fn(state, client_values, second_state):  # pylint: disable=unused-argument
    client_results = intrinsics.federated_map(sum_sequence, client_values)
    server_metrics = intrinsics.federated_sum(client_results)
    return LearningProcessOutput(state, server_metrics)

  with self.assertRaises(errors.TemplateNextFnNumArgsError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn, next_fn=next_fn, report_fn=test_report_fn)
def test_next_return_odict_raises(self):
  """Returning an OrderedDict instead of `LearningProcessOutput` is rejected."""

  @computations.federated_computation(test_state_type, ClientIntSequenceType)
  def odict_next_fn(state, client_values):
    server_metrics = intrinsics.federated_sum(
        intrinsics.federated_map(sum_sequence, client_values))
    return collections.OrderedDict(state=state, metrics=server_metrics)

  with self.assertRaises(learning_process.LearningProcessOutputError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=odict_next_fn,
        report_fn=test_report_fn)
def test_init_fn_with_client_placed_state_raises(self):
  """An `initialize_fn` producing CLIENTS-placed state is rejected."""

  @computations.federated_computation
  def init_fn():
    return intrinsics.federated_value(0, placements.CLIENTS)

  client_state_type = init_fn.type_signature.result

  @computations.federated_computation(client_state_type, ClientIntSequenceType)
  def next_fn(state, client_values):
    return LearningProcessOutput(state, client_values)

  with self.assertRaises(learning_process.LearningProcessPlacementError):
    learning_process.LearningProcess(
        initialize_fn=init_fn, next_fn=next_fn, report_fn=test_report_fn)
def test_next_fn_with_non_client_placed_second_arg_raises(self):
  """Sequence data placed at SERVER instead of CLIENTS is rejected."""
  server_int_sequence_type = computation_types.FederatedType(
      computation_types.SequenceType(tf.int32), placements.SERVER)

  @computations.federated_computation(test_state_type,
                                      server_int_sequence_type)
  def next_fn(state, server_values):
    return LearningProcessOutput(
        state, intrinsics.federated_map(sum_sequence, server_values))

  self.assertRaises(
      learning_process.LearningProcessPlacementError,
      learning_process.LearningProcess,
      initialize_fn=test_init_fn,
      next_fn=next_fn,
      report_fn=test_report_fn)
def test_construction_with_empty_state_does_not_raise(self):
  """A LearningProcess whose state is an empty tuple at SERVER is valid."""

  @computations.federated_computation()
  def empty_initialize_fn():
    return intrinsics.federated_value((), placements.SERVER)

  next_fn = build_next_fn(empty_initialize_fn)
  report_fn = build_report_fn(empty_initialize_fn)
  # Catch `Exception` rather than using a bare `except:` so that
  # KeyboardInterrupt and SystemExit propagate instead of being reported as
  # an ordinary test failure.
  try:
    learning_process.LearningProcess(empty_initialize_fn, next_fn, report_fn)
  except Exception:  # pylint: disable=broad-except
    self.fail('Could not construct a LearningProcess with empty state.')
def test_next_return_namedtuple_raises(self):
  """A namedtuple lookalike of `LearningProcessOutput` is rejected."""
  lookalike_output_cls = collections.namedtuple('LearningProcessOutput',
                                                ['state', 'metrics'])

  @computations.federated_computation(test_state_type, ClientIntSequenceType)
  def namedtuple_next_fn(state, client_values):
    server_metrics = intrinsics.federated_sum(
        intrinsics.federated_map(sum_sequence, client_values))
    return lookalike_output_cls(state, server_metrics)

  with self.assertRaises(learning_process.LearningProcessOutputError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=namedtuple_next_fn,
        report_fn=test_report_fn)
def test_construction_with_unknown_dimension_does_not_raise(self):
  """Construction succeeds for a SERVER state holding a string tensor."""
  create_empty_string = computations.tf_computation()(
      lambda: tf.constant([], dtype=tf.string))
  initialize_fn = computations.federated_computation()(
      lambda: intrinsics.federated_value(create_empty_string(),
                                         placements.SERVER))
  # NOTE(review): `tf.constant([], dtype=tf.string)` appears to have static
  # shape [0]; confirm that `build_next_fn` / `build_report_fn` introduce the
  # statically unknown dimension this test's name refers to.
  next_fn = build_next_fn(initialize_fn)
  report_fn = build_report_fn(initialize_fn)
  # Catch `Exception` rather than using a bare `except:` so that
  # KeyboardInterrupt and SystemExit propagate instead of being reported as
  # an ordinary test failure.
  try:
    learning_process.LearningProcess(initialize_fn, next_fn, report_fn)
  except Exception:  # pylint: disable=broad-except
    self.fail('Could not construct a LearningProcess with state type having '
              'statically unknown shape.')
def test_construction_does_not_raise(self):
  """The shared `test_*_fn` computations form a valid LearningProcess."""
  # Catch `Exception` rather than using a bare `except:` so that
  # KeyboardInterrupt and SystemExit propagate instead of being reported as
  # an ordinary test failure.
  try:
    learning_process.LearningProcess(test_init_fn, test_next_fn,
                                     test_report_fn)
  except Exception:  # pylint: disable=broad-except
    self.fail('Could not construct a valid LearningProcess.')
def test_report_param_not_assignable(self):
  """A `report_fn` taking a `tf.float32` parameter is rejected."""

  @computations.tf_computation(tf.float32)
  def report_fn(x):
    return x

  with self.assertRaises(learning_process.ReportFnTypeSignatureError):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=test_next_fn,
        report_fn=report_fn)
def test_federated_report_fn_raises(self):
  """A federated computation is rejected as `report_fn`."""

  @computations.federated_computation(test_state_type)
  def report_fn(x):
    return x

  self.assertRaises(
      learning_process.ReportFnTypeSignatureError,
      learning_process.LearningProcess,
      initialize_fn=test_init_fn,
      next_fn=test_next_fn,
      report_fn=report_fn)
def test_non_tff_computation_report_fn_raises(self):
  """A plain Python callable is rejected as `report_fn`."""

  def report_fn(x):
    return x

  with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
    learning_process.LearningProcess(
        initialize_fn=test_init_fn,
        next_fn=test_next_fn,
        report_fn=report_fn)
def test_next_state_not_assignable(self):
  """A `next_fn` built from a float-returning initializer is rejected."""

  @computations.federated_computation()
  def float_initialize_fn():
    return 0.0

  float_next_fn = build_next_fn(float_initialize_fn)
  self.assertRaises(
      errors.TemplateStateNotAssignableError,
      learning_process.LearningProcess,
      initialize_fn=test_init_fn,
      next_fn=float_next_fn,
      report_fn=test_report_fn)
def compose_learning_process(
    initial_model_weights_fn: computation_base.Computation,
    model_weights_distributor: distributors.DistributionProcess,
    client_work: client_works.ClientWorkProcess,
    model_update_aggregator: aggregation_process.AggregationProcess,
    model_finalizer: finalizers.FinalizerProcess
) -> learning_process.LearningProcess:
  """Composes specialized measured processes into a learning process.

  Given the 4 specialized measured processes that make up a learning process,
  as documented in [TODO(b/190334722)], and a computation that returns the
  initial model weights to be used for training, this method validates that
  the processes fit together and returns a `LearningProcess`.

  The roles of the 4 measured processes are:
    * `model_weights_distributor`: Makes the global model weights at server
      available as the starting point for the learning work done at clients.
    * `client_work`: Produces an update to the model received at clients.
    * `model_update_aggregator`: Aggregates the model updates from clients to
      the server.
    * `model_finalizer`: Updates the global model weights using the aggregated
      model update at server.

  The `next` computation of the created learning process is composed from the
  `next` computations of the 4 measured processes, in the order visualized
  below. The type signatures of the processes must be such that this chaining
  is possible. Each process also reports its own metrics.

  ```
  ┌─────────────────────────┐
  │model_weights_distributor│
  └△─┬─┬────────────────────┘
   │ │┌▽──────────┐
   │ ││client_work│
   │ │└┬─────┬────┘
   │┌▽─▽────┐│
   ││metrics││
   │└△─△────┘│
   │ │┌┴─────▽────────────────┐
   │ ││model_update_aggregator│
   │ │└┬──────────────────────┘
  ┌┴─┴─▽──────────┐
  │model_finalizer│
  └┬──────────────┘
  ┌▽─────┐
  │result│
  └──────┘
  ```

  Args:
    initial_model_weights_fn: A `tff.Computation` that returns (unplaced)
      initial model weights.
    model_weights_distributor: A `DistributionProcess`.
    client_work: A `ClientWorkProcess`.
    model_update_aggregator: A `tff.templates.AggregationProcess`.
    model_finalizer: A `FinalizerProcess`.

  Returns:
    A `LearningProcess`. Its state is a server-placed `LearningAlgorithmState`
    holding the global model weights plus the state of each of the 4
    processes, and its `report` returns the global model weights.
  """
  # pyformat: enable
  _validate_args(initial_model_weights_fn, model_weights_distributor,
                 client_work, model_update_aggregator, model_finalizer)

  # The client data accepted by the composed `next` has the type of the third
  # parameter of `client_work.next`.
  client_data_type = client_work.next.type_signature.parameter[2]

  def zipped_state(global_model_weights, distributor_state, client_work_state,
                   aggregator_state, finalizer_state):
    # Packs the per-component states into one `LearningAlgorithmState` and
    # zips it into a single federated value.
    return intrinsics.federated_zip(
        LearningAlgorithmState(global_model_weights, distributor_state,
                               client_work_state, aggregator_state,
                               finalizer_state))

  @computations.federated_computation()
  def init_fn():
    server_weights = intrinsics.federated_eval(initial_model_weights_fn,
                                               placements.SERVER)
    return zipped_state(server_weights,
                        model_weights_distributor.initialize(),
                        client_work.initialize(),
                        model_update_aggregator.initialize(),
                        model_finalizer.initialize())

  @computations.federated_computation(init_fn.type_signature.result,
                                      client_data_type)
  def next_fn(state, client_data):
    # Chain the four stages; each one consumes the `result` of the previous.
    distributed = model_weights_distributor.next(state.distributor,
                                                 state.global_model_weights)
    trained = client_work.next(state.client_work, distributed.result,
                               client_data)
    aggregated = model_update_aggregator.next(state.aggregator,
                                              trained.result.update,
                                              trained.result.update_weight)
    finalized = model_finalizer.next(state.finalizer,
                                     state.global_model_weights,
                                     aggregated.result)

    # The finalizer's result becomes the new global model weights.
    updated_state = zipped_state(finalized.result, distributed.state,
                                 trained.state, aggregated.state,
                                 finalized.state)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(
            distributor=distributed.measurements,
            client_work=trained.measurements,
            aggregator=aggregated.measurements,
            finalizer=finalized.measurements))
    return learning_process.LearningProcessOutput(updated_state, measurements)

  # `report` operates on the unplaced (member) type of the server state.
  server_state_type = next_fn.type_signature.result.state.member

  @computations.tf_computation(server_state_type)
  def report_fn(state):
    return state.global_model_weights

  return learning_process.LearningProcess(init_fn, next_fn, report_fn)