def iterator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN, server_optimizer_fn: OPTIMIZER_FN, client_optimizer_fn: OPTIMIZER_FN):
    """Builds a `tff.templates.IterativeProcess` for stateful federated averaging.

    Args:
      coefficient_fn: Factory forwarded to the per-client update to weight
        client contributions.
      model_fn: No-arg callable returning the model; called once here (for
        types) and again inside each `tff.tf_computation`.
      client_state_fn: No-arg callable returning a fresh client state, used
        only to derive the client-state TFF type.
      server_optimizer_fn: Factory for the server optimizer.
      client_optimizer_fn: Factory for the client optimizer.

    Returns:
      A `tff.templates.IterativeProcess` whose `next` maps
      (server_state, datasets, client_states) to
      (next_server_state, metrics, updated_client_states).
    """
    # Instances built only to derive TFF type signatures below.
    model = model_fn()
    client_state = client_state_fn()
    # Server initialization; its result type seeds all server-side types.
    init_tf = tff.tf_computation(
        lambda: __initialize_server(model_fn, server_optimizer_fn))
    server_state_type = init_tf.type_signature.result
    client_state_type = tff.framework.type_from_tensors(client_state)
    # Applies an aggregated weights delta to the server state.
    # NOTE(review): the delta is typed from `server_state_type.model.trainable`,
    # i.e. the trainable portion of the server model — confirm against
    # `__update_server`'s expectations.
    update_server_tf = tff.tf_computation(
        lambda state, weights_delta: __update_server(
            state, weights_delta, model_fn, server_optimizer_fn,
            tf.function(server.update)),
        (server_state_type, server_state_type.model.trainable))
    # Projects server state into the (smaller) message broadcast to clients.
    state_to_message_tf = tff.tf_computation(
        lambda state: __state_to_message(state, tf.function(server.to_message)),
        server_state_type)
    dataset_type = tff.SequenceType(model.input_spec)
    server_message_type = state_to_message_tf.type_signature.result
    # Local client training step: dataset + client state + server message.
    update_client_tf = tff.tf_computation(
        lambda dataset, state, message: __update_client(
            dataset, state, message, coefficient_fn, model_fn,
            client_optimizer_fn, tf.function(client.update)),
        (dataset_type, client_state_type, server_message_type))
    # Placed (federated) versions of the types above.
    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(dataset_type)
    federated_client_state_type = tff.type_at_clients(client_state_type)

    def init_tff():
        # Places the server-initialized state at tff.SERVER.
        return tff.federated_value(init_tf(), tff.SERVER)

    def next_tff(server_state, datasets, client_states):
        # One federated round: broadcast -> client updates -> weighted mean
        # of deltas -> metric aggregation -> server update.
        message = tff.federated_map(state_to_message_tf, server_state)
        broadcast = tff.federated_broadcast(message)
        outputs = tff.federated_map(
            update_client_tf, (datasets, client_states, broadcast))
        # Deltas are averaged using each client's self-reported weight.
        weights_delta = tff.federated_mean(
            outputs.weights_delta, weight=outputs.client_weight)
        metrics = model.federated_output_computation(outputs.metrics)
        next_state = tff.federated_map(
            update_server_tf, (server_state, weights_delta))
        # Updated client states flow back to the caller for the next round.
        return next_state, metrics, outputs.client_state

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(init_tff),
        next_fn=tff.federated_computation(
            next_tff,
            (federated_server_state_type, federated_dataset_type,
             federated_client_state_type)))
def test_twice_used_variable_keeps_separate_state(self):
    """Reusing one tf_computation must not share its variable across calls."""

    def increment_and_read():
        # Each invocation should get its own fresh variable starting at 0.
        variable = tf.Variable(initial_value=0, name='var_of_interest')
        with tf.control_dependencies([variable.assign_add(1)]):
            return variable.read_value()

    count_one_1 = tff.tf_computation(increment_and_read)
    count_one_2 = tff.tf_computation(increment_and_read)

    @tff.tf_computation
    def count_one_twice():
        # The same computation is invoked twice; a second, structurally
        # identical computation is invoked once.
        return count_one_1(), count_one_1(), count_one_2()

    # All three calls observe an isolated counter, so each reads 1.
    self.assertEqual((1, 1, 1), count_one_twice())
def validator(model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN):
    """Builds a federated computation that validates clients on local data.

    Args:
      model_fn: No-arg callable returning the model (used for types/metrics).
      client_state_fn: No-arg callable returning a fresh client state, used
        only to derive its TFF type.

    Returns:
      A `tff.federated_computation` mapping (datasets, client_states) to
      aggregated validation metrics.
    """
    model = model_fn()
    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state_fn())

    client_validation = tff.tf_computation(
        lambda dataset, state: __validate_client(
            dataset, state, model_fn, tf.function(client.validate)),
        (dataset_type, client_state_type))

    def validate(datasets, client_states):
        per_client = tff.federated_map(
            client_validation, (datasets, client_states))
        return model.federated_output_computation(per_client.metrics)

    return tff.federated_computation(
        validate,
        (tff.type_at_clients(dataset_type),
         tff.type_at_clients(client_state_type)))
def test_dp_momentum_training(self, model_fn, optimzer_fn, total_rounds=3):
    """DP momentum training lowers loss and advances the DP tree per round.

    Args:
      model_fn: No-arg callable returning the model under test.
      optimzer_fn: Server optimizer factory accepting learning_rate, momentum,
        noise_std and model_weight_specs.
      total_rounds: Number of federated rounds to run.
    """

    def server_optimzier_fn(model_weights):
        model_weight_specs = tf.nest.map_structure(
            lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights)
        return optimzer_fn(
            learning_rate=1.0,
            momentum=0.9,
            noise_std=1e-5,
            model_weight_specs=model_weight_specs)

    it_process = dp_fedavg.build_federated_averaging_process(
        model_fn, server_optimizer_fn=server_optimzier_fn)
    server_state = it_process.initialize()

    def deterministic_batch():
        return collections.OrderedDict(
            x=np.ones([1, 28, 28, 1], dtype=np.float32),
            y=np.ones([1], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    loss_list = []
    for i in range(total_rounds):
        server_state, loss = it_process.next(server_state, federated_data)
        loss_list.append(loss)
        self.assertEqual(i + 1, server_state.round_num)
        # BUG FIX: the original used `optimizer_state is optimizer_utils.FTRLState`,
        # which compares an *instance* to the *class* by identity and is always
        # False, so the DP-tree step assertion never ran. Use isinstance.
        if isinstance(server_state.optimizer_state, optimizer_utils.FTRLState):
            self.assertEqual(
                i + 1,
                tree_aggregation.get_step_idx(
                    server_state.optimizer_state.dp_tree_state))
    # Training on a constant batch must reduce loss relative to round one.
    self.assertLess(np.mean(loss_list[1:]), loss_list[0])
def evaluator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN):
    """Builds a federated evaluation computation over broadcast server weights.

    Args:
      coefficient_fn: Factory forwarded to the per-client evaluation.
      model_fn: No-arg callable returning the model (used for types/metrics).
      client_state_fn: No-arg callable returning a fresh client state, used
        only to derive its TFF type.

    Returns:
      A `tff.federated_computation` mapping (weights, datasets, client_states)
      to (summed confusion matrix, aggregated metrics, collected metrics).
    """
    model = model_fn()
    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state_fn())
    weights_type = tff.framework.type_from_tensors(
        tff.learning.ModelWeights.from_model(model))

    client_evaluation = tff.tf_computation(
        lambda dataset, state, weights: __evaluate_client(
            dataset, state, weights, coefficient_fn, model_fn,
            tf.function(client.evaluate)),
        (dataset_type, client_state_type, weights_type))

    def evaluate(weights, datasets, client_states):
        # Ship the server weights to every client, then evaluate locally.
        shared_weights = tff.federated_broadcast(weights)
        per_client = tff.federated_map(
            client_evaluation, (datasets, client_states, shared_weights))
        summed_confusion = tff.federated_sum(per_client.confusion_matrix)
        aggregated = model.federated_output_computation(per_client.metrics)
        collected = tff.federated_collect(per_client.metrics)
        return summed_confusion, aggregated, collected

    return tff.federated_computation(
        evaluate,
        (tff.type_at_server(weights_type),
         tff.type_at_clients(dataset_type),
         tff.type_at_clients(client_state_type)))
def validator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN):
    """Builds a federated validation computation over broadcast server weights.

    Args:
      coefficient_fn: Factory forwarded to the per-client validation.
      model_fn: No-arg callable returning the model (used for types/metrics).
      client_state_fn: No-arg callable returning a fresh client state, used
        only to derive its TFF type.

    Returns:
      A `tff.federated_computation` mapping (weights, datasets, client_states)
      to aggregated validation metrics.
    """
    model = model_fn()
    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state_fn())
    weights_type = tff.learning.framework.weights_type_from_model(model)

    client_validation = tff.tf_computation(
        lambda dataset, state, weights: __validate_client(
            dataset, state, weights, coefficient_fn, model_fn,
            tf.function(client.validate)),
        (dataset_type, client_state_type, weights_type))

    def validate(weights, datasets, client_states):
        # Ship the server weights to every client, then validate locally.
        shared_weights = tff.federated_broadcast(weights)
        per_client = tff.federated_map(
            client_validation, (datasets, client_states, shared_weights))
        return model.federated_output_computation(per_client.metrics)

    return tff.federated_computation(
        validate,
        (tff.type_at_server(weights_type),
         tff.type_at_clients(dataset_type),
         tff.type_at_clients(client_state_type)))
def test_dpftal_training(self, total_rounds=5):
    """DP-FTRL-M training lowers loss and advances the DP tree every round."""

    def server_optimizer_fn(model_weights):
        weight_shapes = tf.nest.map_structure(tf.shape, model_weights)
        return optimizer_utils.DPFTRLMServerOptimizer(
            learning_rate=0.1,
            momentum=0.9,
            noise_std=1e-5,
            model_weight_shape=weight_shapes)

    it_process = dp_fedavg.build_federated_averaging_process(
        _rnn_model_fn, server_optimizer_fn=server_optimizer_fn)
    server_state = it_process.initialize()

    def deterministic_batch():
        return collections.OrderedDict(
            x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
            y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    losses = []
    for round_idx in range(1, total_rounds + 1):
        server_state, round_loss = it_process.next(server_state, federated_data)
        losses.append(round_loss)
        self.assertEqual(round_idx, server_state.round_num)
        # The tree-aggregation step index tracks the round number.
        self.assertEqual(
            round_idx,
            tree_aggregation.get_step_idx(
                server_state.optimizer_state['dp_tree_state'].level_state))
    # Loss on a constant batch must decrease relative to the first round.
    self.assertLess(np.mean(losses[1:]), losses[0])
def test_simple_training(self):
    """Three FedAvg rounds on a constant batch reduce the training loss."""
    it_process = simple_fedavg_tff.build_federated_averaging_process(_model_fn)
    server_state = it_process.initialize()

    Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name
    # Keras model used to sanity-check that server weights can be assigned.
    keras_model = _create_test_cnn_model(only_digits=True)

    def deterministic_batch():
        return Batch(
            x=np.ones([1, 28, 28, 1], dtype=np.float32),
            y=np.ones([1], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    def keras_evaluate(state):
        # Copy the federated weights into the Keras model and run inference.
        tff.learning.assign_weights_to_keras_model(keras_model,
                                                   state.model_weights)
        keras_model.predict(batch.x)

    losses = []
    for _ in range(3):
        keras_evaluate(server_state)
        server_state, round_loss = it_process.next(server_state, federated_data)
        losses.append(round_loss)
    keras_evaluate(server_state)

    self.assertLess(np.mean(losses[1:]), losses[0])
def test_dp_momentum_training(self, model_fn, optimzer_fn, total_rounds=3):
    """DP momentum training lowers loss and advances the DP tree per round.

    Args:
      model_fn: No-arg callable returning the model under test.
      optimzer_fn: Server optimizer factory accepting learning_rate, momentum,
        noise_std and model_weight_shape.
      total_rounds: Number of federated rounds to run.
    """

    def server_optimzier_fn(model_weights):
        model_weight_shape = tf.nest.map_structure(tf.shape, model_weights)
        return optimzer_fn(
            learning_rate=1.0,
            momentum=0.9,
            noise_std=1e-5,
            model_weight_shape=model_weight_shape)

    it_process = dp_fedavg.build_federated_averaging_process(
        model_fn, server_optimizer_fn=server_optimzier_fn)
    server_state = it_process.initialize()

    def deterministic_batch():
        return collections.OrderedDict(
            x=np.ones([1, 28, 28, 1], dtype=np.float32),
            y=np.ones([1], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    loss_list = []
    for i in range(total_rounds):
        server_state, loss = it_process.next(server_state, federated_data)
        loss_list.append(loss)
        self.assertEqual(i + 1, server_state.round_num)
        # BUG FIX: the original tested `'server_state_type' in optimizer_state`,
        # a key that is never present, so the DP-tree assertion was dead code.
        # Guard on the key that is actually read below. (Leftover debug prints
        # were also removed.)
        if 'dp_tree_state' in server_state.optimizer_state:
            self.assertEqual(
                i + 1,
                tree_aggregation.get_step_idx(
                    server_state.optimizer_state['dp_tree_state']))
    # Loss on a constant batch must decrease relative to the first round.
    self.assertLess(np.mean(loss_list[1:]), loss_list[0])
def test_simple_training(self):
    """Three FedAvg rounds on a constant batch reduce the training loss."""
    it_process = build_federated_averaging_process(models.model_fn)
    server_state = it_process.initialize()

    Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name
    # Compiled Keras model used to sanity-check weight assignment.
    keras_model = models.create_keras_model(compile_model=True)

    def deterministic_batch():
        return Batch(
            x=np.ones([1, 784], dtype=np.float32),
            y=np.ones([1, 1], dtype=np.int64))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    def keras_evaluate(state):
        tff.learning.assign_weights_to_keras_model(keras_model, state.model)
        # N.B. The loss computed here won't match the loss computed by TFF
        # because of the Dropout layer.
        keras_model.test_on_batch(batch.x, batch.y)

    losses = []
    for _ in range(3):
        keras_evaluate(server_state)
        server_state, round_loss = it_process.next(server_state, federated_data)
        losses.append(round_loss)
    keras_evaluate(server_state)

    self.assertLess(np.mean(losses[1:]), losses[0])
def _create_next_fn(self, inner_agg_next, state_type):
    """Wraps an inner aggregator's `next` with modular clipping on both sides.

    Args:
      inner_agg_next: The inner aggregation process's `next` computation.
      state_type: Federated type of the aggregation state.

    Returns:
      A `tff.federated_computation` producing a `MeasuredProcessOutput`.
    """
    value_type = inner_agg_next.type_signature.parameter[1]
    clip_comp = tff.tf_computation(modular_clip_by_value)

    @tff.federated_computation(state_type, value_type)
    def next_fn(state, value):
        lower, upper = self._get_clip_range()

        # Modular clip values before aggregation.
        pre_clipped = tff.federated_map(
            clip_comp,
            (value,
             tff.federated_broadcast(lower),
             tff.federated_broadcast(upper)))

        agg_state, agg_result, agg_measurements = inner_agg_next(
            state, pre_clipped)

        # Clip the aggregate to the same range again (not considering summands).
        post_clipped = tff.federated_map(
            clip_comp, (agg_result, lower, upper))

        measurements = collections.OrderedDict(agg_process=agg_measurements)
        return tff.templates.MeasuredProcessOutput(
            state=agg_state,
            result=post_clipped,
            measurements=tff.federated_zip(measurements))

    return next_fn
def iterator(model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN, client_optimizer_fn: OPTIMIZER_FN):
    """Builds an iterative process with client-only training (empty server state).

    Args:
      model_fn: No-arg callable returning the model (used for types/metrics).
      client_state_fn: No-arg callable returning a fresh client state, used
        only to derive its TFF type.
      client_optimizer_fn: Factory for the client optimizer.

    Returns:
      A `tff.templates.IterativeProcess` whose `next` maps
      (server_state, datasets, client_states) to
      (server_state, metrics, updated_client_states); the server state is an
      empty tuple that passes through unchanged.
    """
    model = model_fn()
    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state_fn())

    # The server holds no state: initialization yields an empty tuple.
    init_tf = tff.tf_computation(lambda: ())
    server_state_type = init_tf.type_signature.result

    client_update = tff.tf_computation(
        lambda dataset, state: __update_client(
            dataset, state, model_fn, client_optimizer_fn,
            tf.function(client.update)),
        (dataset_type, client_state_type))

    def init_tff():
        return tff.federated_value(init_tf(), tff.SERVER)

    def next_tff(server_state, datasets, client_states):
        per_client = tff.federated_map(
            client_update, (datasets, client_states))
        metrics = model.federated_output_computation(per_client.metrics)
        # Server state is returned untouched.
        return server_state, metrics, per_client.client_state

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(init_tff),
        next_fn=tff.federated_computation(
            next_tff,
            (tff.type_at_server(server_state_type),
             tff.type_at_clients(dataset_type),
             tff.type_at_clients(client_state_type))))
def test_inferred_type_assignable_to_type_spec(self):
    """TFF type inferred from a sparse return is assignable to the spec-derived type."""
    inferred_type = tff.tf_computation(create_sparse).type_signature.result
    spec_type = tff.to_type(tf.SparseTensorSpec.from_value(create_sparse()))
    # Raises if the inferred type cannot be assigned to the spec type.
    spec_type.check_assignable_from(inferred_type)
def test_inferred_type_assignable_to_type_spec(self):
    """TFF type inferred from a ragged return is assignable to the spec-derived type."""
    inferred_type = tff.tf_computation(create_ragged).type_signature.result
    spec_type = tff.to_type(tf.RaggedTensorSpec.from_value(create_ragged()))
    # Raises if the inferred type cannot be assigned to the spec type.
    spec_type.check_assignable_from(inferred_type)
def test_dpftal_restart(self, total_rounds=3):
    """Restarting the DP tree after each round resets its step index to 0."""

    def server_optimizer_fn(model_weights):
        specs = tf.nest.map_structure(
            lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights)
        return optimizer_utils.DPFTRLMServerOptimizer(
            learning_rate=0.1,
            momentum=0.9,
            noise_std=1e-5,
            model_weight_specs=specs,
            efficient_tree=True,
            use_nesterov=True)

    it_process = dp_fedavg.build_federated_averaging_process(
        _rnn_model_fn,
        server_optimizer_fn=server_optimizer_fn,
        use_simulation_loop=True)
    server_state = it_process.initialize()

    # A standalone optimizer instance used only to restart the DP tree.
    model = _rnn_model_fn()
    optimizer = server_optimizer_fn(model.weights.trainable)

    def server_state_update(state):
        # Replace the optimizer state with a freshly restarted DP tree.
        return tff.structure.update_struct(
            state,
            model=state.model,
            optimizer_state=optimizer.restart_dp_tree(state.model.trainable),
            round_num=state.round_num)

    def deterministic_batch():
        return collections.OrderedDict(
            x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
            y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    losses = []
    for round_idx in range(1, total_rounds + 1):
        server_state, round_loss = it_process.next(server_state, federated_data)
        server_state = server_state_update(server_state)
        losses.append(round_loss)
        self.assertEqual(round_idx, server_state.round_num)
        # After a restart the tree step index must be back at zero.
        self.assertEqual(
            0,
            tree_aggregation.get_step_idx(
                server_state.optimizer_state.dp_tree_state))
    self.assertLess(np.mean(losses[1:]), losses[0])
def test_simple_training(self, model_fn):
    """Three rounds of DP-FedAvg on a constant batch reduce the loss."""
    it_process = dp_fedavg.build_federated_averaging_process(model_fn)
    server_state = it_process.initialize()

    def deterministic_batch():
        return collections.OrderedDict(
            x=np.ones([1, 28, 28, 1], dtype=np.float32),
            y=np.ones([1], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    losses = []
    for _ in range(3):
        server_state, round_loss = it_process.next(server_state, federated_data)
        losses.append(round_loss)

    self.assertLess(np.mean(losses[1:]), losses[0])
def test_simple_training(self, model_fn):
    """Three rounds of simple FedAvg on a constant batch reduce the loss."""
    it_process = simple_fedavg_tff.build_federated_averaging_process(model_fn)
    server_state = it_process.initialize()

    Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

    def deterministic_batch():
        return Batch(
            x=np.ones([1, 28, 28, 1], dtype=np.float32),
            y=np.ones([1], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    losses = []
    for _ in range(3):
        server_state, round_loss = it_process.next(server_state, federated_data)
        losses.append(round_loss)

    self.assertLess(np.mean(losses[1:]), losses[0])
def test_client_adagrad_train(self):
    """FedAvg with an Adagrad client optimizer reduces loss over three rounds."""
    it_process = simple_fedavg_tff.build_federated_averaging_process(
        _rnn_model_fn,
        client_optimizer_fn=functools.partial(
            tf.keras.optimizers.Adagrad, learning_rate=0.01))
    server_state = it_process.initialize()

    def deterministic_batch():
        return collections.OrderedDict(
            x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
            y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    losses = []
    for _ in range(3):
        server_state, round_loss = it_process.next(server_state, federated_data)
        losses.append(round_loss)

    self.assertLess(np.mean(losses[1:]), losses[0])
def run_one_round(server_state, federated_dataset, client_weight):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      client_weight: A federated value at `tff.CLIENTS` multiplied into each
        client's self-reported weight before aggregation.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)
    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, client_model, client_round_num))
    # Scale each client's own weight by the externally supplied one; the
    # product drives the aggregation below. (Dead commented-out code that
    # computed an unscaled federated mean was removed.)
    participant_client_weight = tff.federated_map(
        tff.tf_computation(lambda x, y: x * y),
        (client_weight, client_outputs.client_weight))
    aggregation_output = aggregation_process.next(
        server_state.aggregation_state, client_outputs.weights_delta,
        participant_client_weight)
    server_state = tff.federated_map(
        server_update_fn, (server_state, aggregation_output.result))
    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    # Struct-typed metric outputs must be zipped to a single federated value.
    if aggregated_outputs.type_signature.is_struct():
        aggregated_outputs = tff.federated_zip(aggregated_outputs)
    return server_state, aggregated_outputs
def test_training_keras_model_converges(self):
    """Loss is non-increasing across ten rounds and ends below 0.1."""
    it_process = simple_fedavg_tff.build_federated_averaging_process(
        _tff_learning_model_fn)
    server_state = it_process.initialize()

    def deterministic_batch():
        return collections.OrderedDict(
            x=np.ones([1, 28, 28, 1], dtype=np.float32),
            y=np.ones([1], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [tf.data.Dataset.from_tensor_slices(batch).batch(1)]

    previous_loss = None
    for _ in range(10):
        server_state, outputs = it_process.next(server_state, federated_data)
        loss = outputs['loss']
        # Monotone non-increase round over round.
        if previous_loss is not None:
            self.assertLessEqual(loss, previous_loss)
        previous_loss = loss

    self.assertLess(loss, 0.1)
def test_simple_training(self, model_fn):
    """Stateful FedAvg reduces loss and keeps server/client counters in sync."""
    it_process = stateful_fedavg_tff.build_federated_averaging_process(
        model_fn, _create_one_client_state)
    server_state = it_process.initialize()

    def deterministic_batch():
        return collections.OrderedDict(
            x=np.ones([1, 28, 28, 1], dtype=np.float32),
            y=np.ones([1], dtype=np.int32))

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]
    client_states = [_create_one_client_state()]

    losses = []
    for _ in range(3):
        server_state, round_loss, client_states = it_process.next(
            server_state, federated_data, client_states)
        losses.append(round_loss)
        # With a single client, the server's total must equal that client's count.
        self.assertEqual(server_state.total_iters_count,
                         client_states[0].iters_count)

    self.assertLess(np.mean(losses[1:]), losses[0])
def evaluator(model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN):
    """Builds a federated evaluation computation over local client state.

    Args:
      model_fn: No-arg callable returning the model (used for types/metrics).
      client_state_fn: No-arg callable returning a fresh client state, used
        only to derive its TFF type.

    Returns:
      A `tff.federated_computation` mapping (datasets, client_states) to
      (summed confusion matrix, aggregated metrics, collected metrics).
    """
    model = model_fn()
    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state_fn())

    client_evaluation = tff.tf_computation(
        lambda dataset, state: __evaluate_client(
            dataset, state, model_fn, tf.function(client.evaluate)),
        (dataset_type, client_state_type))

    def evaluate(datasets, client_states):
        per_client = tff.federated_map(
            client_evaluation, (datasets, client_states))
        summed_confusion = tff.federated_sum(per_client.confusion_matrix)
        aggregated = model.federated_output_computation(per_client.metrics)
        collected = tff.federated_collect(per_client.metrics)
        return summed_confusion, aggregated, collected

    return tff.federated_computation(
        evaluate,
        (tff.type_at_clients(dataset_type),
         tff.type_at_clients(client_state_type)))