def _assert_server_update_with_all_ones(self, model_fn):
  """Runs two server updates with an all-ones delta and checks the result.

  Starting from all-zero weights, applying an all-ones delta twice with an
  SGD server optimizer at learning rate 0.1 must leave every trainable
  variable equal to 0.2 everywhere, and the round counter at 2.
  """
  wrapped_model = model_fn()
  server_state = _server_init(wrapped_model,
                              optimizer_utils.SGDServerOptimizer(0.1))
  ones_delta = tf.nest.map_structure(tf.ones_like,
                                     wrapped_model.trainable_variables)
  sgd_optimizer = optimizer_utils.SGDServerOptimizer(0.10)

  num_rounds = 2
  for _ in range(num_rounds):
    server_state = dp_fedavg.server_update(wrapped_model, sgd_optimizer,
                                           server_state, ones_delta)

  evaluated = self.evaluate(server_state.model)
  trainable = evaluated.trainable
  self.assertLen(trainable, 2)
  self.assertEqual(server_state.round_num, 2)
  # weights are initialized with all-zeros, weights_delta is all ones,
  # SGD learning rate is 0.1. Updating server for 2 steps.
  expected = [np.ones_like(v) * 0.2 for v in trainable]
  self.assertAllClose(trainable, expected)
def test_deterministic_sgd(self):
  """Two SGD steps on an all-ones gradient move zero-init weights to -0.2."""
  variables = _create_model_variables()
  ones_grad = tf.nest.map_structure(tf.ones_like, variables)
  sgd = optimizer_utils.SGDServerOptimizer(learning_rate=0.1)

  opt_state = sgd.init_state()
  for step in range(2):
    opt_state = sgd.model_update(opt_state, variables, ones_grad, step)

  self.assertLen(variables, 2)
  # variables initialize with all zeros and update with all ones and learning
  # rate 0.1 for several steps.
  flat = tf.nest.flatten(variables)
  self.assertAllClose(flat, [-0.2 * np.ones_like(v) for v in flat])
def _server_optimizer_fn(model_weights, name, learning_rate, noise_std):
  """Returns server optimizer."""
  # Tensor specs (shape/dtype only) for the trainable weights; the DP
  # optimizers need them to allocate their internal state.
  weight_specs = tf.nest.map_structure(
      lambda variable: tf.TensorSpec(variable.shape, variable.dtype),
      model_weights)

  if name == 'sgd':
    return optimizer_utils.SGDServerOptimizer(learning_rate)
  if name == 'sgdm':
    # Non-private momentum: same machinery as DP-SGDM but with zero noise.
    return optimizer_utils.DPSGDMServerOptimizer(
        learning_rate,
        momentum=FLAGS.server_momentum,
        noise_std=0,
        model_weight_specs=weight_specs)
  if name == 'dpftrl':
    return optimizer_utils.DPFTRLMServerOptimizer(
        learning_rate,
        momentum=0,
        noise_std=noise_std,
        model_weight_specs=weight_specs)
  if name == 'dpsgd':
    return optimizer_utils.DPSGDMServerOptimizer(
        learning_rate,
        momentum=0,
        noise_std=noise_std,
        model_weight_specs=weight_specs)
  if name == 'dpsgdm':
    return optimizer_utils.DPSGDMServerOptimizer(
        learning_rate,
        momentum=FLAGS.server_momentum,
        noise_std=noise_std,
        model_weight_specs=weight_specs)
  if name == 'dpftrlm':
    return optimizer_utils.DPFTRLMServerOptimizer(
        learning_rate,
        momentum=FLAGS.server_momentum,
        noise_std=noise_std,
        model_weight_specs=weight_specs)
  raise ValueError('Unknown server optimizer name {}'.format(name))
def build_federated_averaging_process(
    model_fn,
    dp_clip_norm=1.0,
    server_optimizer_fn=lambda w: optimizer_utils.SGDServerOptimizer(  # pylint: disable=g-long-lambda
        learning_rate=1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1)):
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a
      `dp_fedavg_tf.KerasModelWrapper`.
    dp_clip_norm: if < 0, no clipping
    server_optimizer_fn: A function that takes the model's trainable weights
      and returns a server optimizer exposing `init_state()` (used below for
      both initialization and the per-round server update).
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for client update.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Built once, eagerly, only to read `input_spec` for the dataset type below.
  example_model = model_fn()

  @tff.tf_computation
  def server_init_tf():
    # Fresh model/optimizer inside the tf_computation; the clip norm is baked
    # into the initial ServerState and broadcast to clients each round.
    model = model_fn()
    optimizer = server_optimizer_fn(model.weights.trainable)
    return ServerState(
        model_weights=model.weights,
        optimizer_state=optimizer.init_state(),
        round_num=0,
        dp_clip_norm=dp_clip_norm)

  # Derive TFF types from the init computation so all inner computations agree.
  server_state_type = server_init_tf.type_signature.result
  model_weights_type = server_state_type.model_weights

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    # Applies one server optimizer step to the averaged client delta.
    model = model_fn()
    optimizer = server_optimizer_fn(model.weights.trainable)
    return server_update(model, optimizer, server_state, model_delta)

  @tff.tf_computation(server_state_type)
  def server_message_fn(server_state):
    # Packs the state the clients need (weights, round, clip norm) for
    # broadcast.
    return build_server_broadcast_message(server_state)

  server_message_type = server_message_fn.type_signature.result
  tf_dataset_type = tff.SequenceType(example_model.input_spec)

  @tff.tf_computation(tf_dataset_type, server_message_type)
  def client_update_fn(tf_dataset, server_message):
    # Local training on one client's dataset with a fresh client optimizer.
    model = model_fn()
    client_optimizer = client_optimizer_fn()
    return client_update(model, tf_dataset, server_message, client_optimizer)

  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(tf_dataset_type)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    # Model deltas are equally weighted in DP.
    round_model_delta = tff.federated_mean(client_outputs.weights_delta)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output)

    return server_state, round_loss_metric

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    return tff.federated_value(server_init_tf(), tff.SERVER)

  return tff.templates.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round)