def test_returns_model_weights_for_model_callable(self):
  """The weights type derived from the model *class* matches TestModel's variables."""
  actual_type = model_utils.weights_type_from_model(TestModel)
  # Expected: two trainable float tensors plus one non-trainable int scalar.
  trainable_spec = [
      tff_core.TensorType(tf.float32, [3]),
      tff_core.TensorType(tf.float32, [1]),
  ]
  non_trainable_spec = [tff_core.TensorType(tf.int32)]
  expected_type = tff_core.NamedTupleType([
      ('trainable', trainable_spec),
      ('non_trainable', non_trainable_spec),
  ])
  self.assertEqual(expected_type, actual_type)
def test_returns_model_weights_for_model_callable(self):
  """The weights type from the model class is a `ModelWeights`-typed struct."""
  # Each group of weights is typed as a Python `list` container.
  trainable_type = tff_core.StructWithPythonType([
      tff_core.TensorType(tf.float32, [3]),
      tff_core.TensorType(tf.float32, [1]),
  ], list)
  non_trainable_type = tff_core.StructWithPythonType(
      [tff_core.TensorType(tf.int32)], list)
  expected_type = tff_core.StructWithPythonType(
      [('trainable', trainable_type), ('non_trainable', non_trainable_type)],
      model_utils.ModelWeights)
  self.assertEqual(expected_type,
                   model_utils.weights_type_from_model(TestModel))
def test_orchestration_type_signature(self):
  """Validates the type signatures of `initialize` and `next`."""
  process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.TrainableLinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=lambda: gradient_descent.SGD(learning_rate=1.0))

  weights_type = model_utils.ModelWeights(
      collections.OrderedDict(
          a=tff.TensorType(tf.float32, [2, 1]), b=tf.float32),
      collections.OrderedDict(c=tf.float32))
  # ServerState carries the model plus optimizer state. The optimizer state
  # is produced by TensorFlow and TFF does not care about its exact value,
  # so it is matched with AnyType.
  server_state_type = tff.FederatedType(
      optimizer_utils.ServerState(weights_type, test.AnyType(),
                                  test.AnyType(), test.AnyType()),
      placement=tff.SERVER,
      all_equal=True)
  dataset_type = tff.FederatedType(
      tff.SequenceType(model_examples.TrainableLinearRegression().input_spec),
      tff.CLIENTS,
      all_equal=False)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(
          loss=tff.TensorType(tf.float32, []),
          num_examples=tff.TensorType(tf.int32, [])),
      tff.SERVER,
      all_equal=True)

  # `initialize` is expected to be a function of no arguments to a
  # ServerState.
  self.assertEqual(
      tff.FunctionType(parameter=None, result=server_state_type),
      process.initialize.type_signature)
  # `next` is expected to be a function of (ServerState, Datasets) to
  # (ServerState, Metrics).
  self.assertEqual(
      tff.FunctionType(
          parameter=[server_state_type, dataset_type],
          result=(server_state_type, metrics_type)),
      process.next.type_signature)
def test_returns_model_weights_for_model(self):
  """The weights type from a model *instance* mirrors its variable structure."""
  instance = TestModel()
  actual_type = model_utils.weights_type_from_model(instance)
  # Both weight groups use a Python `list` container type.
  trainable_type = tff_core.NamedTupleTypeWithPyContainerType([
      tff_core.TensorType(tf.float32, [3]),
      tff_core.TensorType(tf.float32, [1]),
  ], list)
  non_trainable_type = tff_core.NamedTupleTypeWithPyContainerType(
      [tff_core.TensorType(tf.int32)], list)
  expected_type = tff_core.NamedTupleTypeWithPyContainerType(
      [('trainable', trainable_type), ('non_trainable', non_trainable_type)],
      model_utils.ModelWeights)
  self.assertEqual(expected_type, actual_type)
def test_orchestration_typecheck(self):
  """Checks `initialize`/`next` type signatures of the federated SGD process."""
  process = federated_sgd.build_federated_sgd_process(
      model_fn=model_examples.LinearRegression)

  weights_type = model_utils.ModelWeights(
      collections.OrderedDict(
          a=tff.TensorType(tf.float32, [2, 1]), b=tf.float32),
      collections.OrderedDict(c=tf.float32))
  # ServerState carries the model plus optimizer state. The optimizer state
  # is produced by TensorFlow; TFF does not care about its exact value, so
  # it is matched with AnyType.
  server_state_type = tff.FederatedType(
      optimizer_utils.ServerState(weights_type, test.AnyType()),
      placement=tff.SERVER,
      all_equal=True)
  dataset_type = tff.FederatedType(
      tff.SequenceType(
          model_examples.LinearRegression.make_batch(
              tff.TensorType(tf.float32, [None, 2]),
              tff.TensorType(tf.float32, [None, 1]))),
      tff.CLIENTS,
      all_equal=False)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(
          loss=tff.TensorType(tf.float32, []),
          num_examples=tff.TensorType(tf.int32, [])),
      tff.SERVER,
      all_equal=True)

  # `initialize` is expected to be a function of no arguments to a
  # ServerState.
  self.assertEqual(
      tff.FunctionType(parameter=None, result=server_state_type),
      process.initialize.type_signature)
  # `next` is expected to be a function of (ServerState, Datasets) to
  # (ServerState, Metrics).
  self.assertEqual(
      tff.FunctionType(
          parameter=[server_state_type, dataset_type],
          result=(server_state_type, metrics_type)),
      process.next.type_signature)
def build_federated_evaluation(model_fn):
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as aggregated by `tff.learning.Model.federated_output_computation`.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    # Map each model variable to its TFF tensor type (dtype + static shape).
    model_weights_type = tff.to_type(
        tf.nest.map_structure(
            lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
            model.weights))
    batch_type = tff.to_type(model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `model_weights` on `dataset`."""
    # A fresh model is stamped per-client; metrics accumulate in its local
    # variables and are surfaced via `report_local_outputs`.
    model = model_utils.enhance(model_fn())

    # TODO(b/124477598): Remove dummy when b/121400757 has been fixed.
    # The reduction result itself is unused; the reduce only exists to force
    # a full pass over the dataset so the metric variables get updated.
    @tf.function
    def reduce_fn(dummy, batch):
      model_output = model.forward_pass(batch, training=False)
      return dummy + tf.cast(model_output.loss, tf.float64)

    # TODO(b/123898430): The control dependencies below have been inserted as
    # a temporary workaround. These control dependencies need to be removed,
    # and defuns and datasets supported together fully.
    # Ordering: assign incoming weights -> reduce over dataset -> report.
    with tf.control_dependencies(
        [tff.utils.assign(model.weights, incoming_model_weights)]):
      dummy = dataset.reduce(tf.constant(0.0, dtype=tf.float64), reduce_fn)
    with tf.control_dependencies([dummy]):
      return collections.OrderedDict([
          ('local_outputs', model.report_local_outputs()),
          ('workaround for b/121400757', dummy)
      ])

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    # Broadcast the server weights to clients, evaluate locally, then
    # aggregate the clients' local outputs on the server.
    client_outputs = tff.federated_map(
        client_eval,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    return model.federated_output_computation(client_outputs.local_outputs)

  return server_eval
def build_federated_evaluation(model_fn):
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as aggregated by `tff.learning.Model.federated_output_computation`.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    # Map each model variable to its TFF tensor type (dtype + static shape).
    model_weights_type = tff.to_type(
        tf.nest.map_structure(
            lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
            model.weights))
    batch_type = tff.to_type(model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `model_weights` on `dataset`."""
    model = model_utils.enhance(model_fn())

    @tf.function
    def _tf_client_eval(incoming_model_weights, dataset):
      """Evaluation TF work."""
      tff.utils.assign(model.weights, incoming_model_weights)

      # The reduce's accumulated loss is discarded; the reduction exists
      # only to force a full pass over the dataset so the model's local
      # metric variables get updated by `forward_pass`.
      def reduce_fn(prev_loss, batch):
        model_output = model.forward_pass(batch, training=False)
        return prev_loss + tf.cast(model_output.loss, tf.float64)

      dataset.reduce(tf.constant(0.0, dtype=tf.float64), reduce_fn)

      return collections.OrderedDict([('local_outputs',
                                       model.report_local_outputs())])

    return _tf_client_eval(incoming_model_weights, dataset)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    # Broadcast the server weights to clients, evaluate locally, then
    # aggregate the clients' local outputs on the server.
    client_outputs = tff.federated_map(
        client_eval,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    return model.federated_output_computation(client_outputs.local_outputs)

  return server_eval
def test_contruction_with_broadcast_state(self):
  """A stateful broadcaster's state must appear in init/next signatures."""
  # NOTE(review): "contruction" in the method name looks like a typo for
  # "construction"; kept as-is because renaming would change test discovery.
  process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.LinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      stateful_model_broadcast_fn=state_incrementing_broadcaster)
  broadcast_state_type = tff.TensorType(tf.int32)

  init_signature = process.initialize.type_signature
  next_signature = process.next.type_signature
  # The broadcast state must thread through: initialize's result, next's
  # first parameter, and next's first result.
  for actual in (
      init_signature.result.member.model_broadcast_state,
      next_signature.parameter[0].member.model_broadcast_state,
      next_signature.result[0].member.model_broadcast_state,
  ):
    self.assertEqual(actual, broadcast_state_type)
def test_construction(self):
  """Checks full `initialize`/`next` type signatures via string comparison."""
  process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.LinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  expected_state_type = tff.FederatedType(
      optimizer_utils.ServerState(
          model=model_utils.ModelWeights(
              trainable=[
                  tff.TensorType(tf.float32, [2, 1]),
                  tff.TensorType(tf.float32),
              ],
              non_trainable=[tff.TensorType(tf.float32)]),
          optimizer_state=[tf.int64],
          delta_aggregate_state=(),
          model_broadcast_state=()), tff.SERVER)
  # `initialize`: ( -> ServerState@SERVER).
  self.assertEqual(
      str(process.initialize.type_signature),
      str(tff.FunctionType(parameter=None, result=expected_state_type)))

  expected_dataset_type = tff.FederatedType(
      tff.SequenceType(
          collections.OrderedDict(
              x=tff.TensorType(tf.float32, [None, 2]),
              y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
  expected_metrics_type = tff.FederatedType(
      collections.OrderedDict(
          broadcast=(),
          aggregation=(),
          train=collections.OrderedDict(
              loss=tff.TensorType(tf.float32),
              num_examples=tff.TensorType(tf.int32))), tff.SERVER)
  # `next`: (<ServerState, Datasets> -> <ServerState, Metrics>).
  self.assertEqual(
      str(process.next.type_signature),
      str(
          tff.FunctionType(
              parameter=(expected_state_type, expected_dataset_type),
              result=(expected_state_type, expected_metrics_type))))
def benchmark_fc_api_mnist(self):
  """Code adapted from FC API tutorial ipynb.

  Benchmarks two things for a softmax-regression MNIST model built directly
  with the FC API: (1) wall time to build the TFF computations, and
  (2) average per-round execution time over `n_rounds` of federated training.
  """
  n_rounds = 10

  # A batch is a flat 784-pixel image tensor plus integer labels.
  batch_type = tff.NamedTupleType([("x",
                                    tff.TensorType(tf.float32, [None, 784])),
                                   ("y", tff.TensorType(tf.int32, [None]))])

  # The model is a single dense layer: 784x10 weights and a 10-dim bias.
  model_type = tff.NamedTupleType([("weights",
                                    tff.TensorType(tf.float32, [784, 10])),
                                   ("bias", tff.TensorType(tf.float32, [10]))])

  local_data_type = tff.SequenceType(batch_type)

  server_model_type = tff.FederatedType(model_type, tff.SERVER, all_equal=True)
  client_data_type = tff.FederatedType(local_data_type, tff.CLIENTS)
  server_float_type = tff.FederatedType(tf.float32, tff.SERVER, all_equal=True)

  computation_building_start = time.time()

  # pylint: disable=missing-docstring
  @tff.tf_computation(model_type, batch_type)
  def batch_loss(model, batch):
    # Cross-entropy loss for softmax regression.
    # NOTE(review): `tf.log` and `reduction_indices` are TF1-era APIs
    # (`tf.math.log` / `axis` in TF2); kept as-is since this benchmark
    # targets the TF version this snapshot was written against.
    predicted_y = tf.nn.softmax(tf.matmul(batch.x, model.weights) + model.bias)
    return -tf.reduce_mean(
        tf.reduce_sum(
            tf.one_hot(batch.y, 10) * tf.log(predicted_y),
            reduction_indices=[1]))

  initial_model = {
      "weights": np.zeros([784, 10], dtype=np.float32),
      "bias": np.zeros([10], dtype=np.float32)
  }

  @tff.tf_computation(model_type, batch_type, tf.float32)
  def batch_train(initial_model, batch, learning_rate):
    # Graph-mode SGD step: assign the incoming model into variables, run one
    # minimize step, then read the updated variables back out. The control
    # dependencies enforce assign -> train -> read ordering.
    model_vars = tff.utils.get_variables("v", model_type)
    init_model = tff.utils.assign(model_vars, initial_model)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    with tf.control_dependencies([init_model]):
      train_model = optimizer.minimize(batch_loss(model_vars, batch))
    with tf.control_dependencies([train_model]):
      return tff.utils.identity(model_vars)

  @tff.federated_computation(model_type, tf.float32, local_data_type)
  def local_train(initial_model, learning_rate, all_batches):

    # Folds `batch_train` over the client's batches, threading the model.
    @tff.federated_computation(model_type, batch_type)
    def batch_fn(model, batch):
      return batch_train(model, batch, learning_rate)

    return tff.sequence_reduce(all_batches, initial_model, batch_fn)

  @tff.federated_computation(server_model_type, server_float_type,
                             client_data_type)
  def federated_train(model, learning_rate, data):
    # Broadcast model and learning rate, train locally on each client, and
    # average the resulting client models on the server.
    return tff.federated_average(
        tff.federated_map(local_train, [
            tff.federated_broadcast(model),
            tff.federated_broadcast(learning_rate), data
        ]))

  computation_building_stop = time.time()
  building_time = computation_building_stop - computation_building_start
  self.report_benchmark(
      name="computation_building_time, FC API",
      wall_time=building_time,
      iters=1)

  model = initial_model
  learning_rate = 0.1
  federated_data = generate_fake_mnist_data()

  # Time each training round individually so we can report mean and std dev.
  execution_array = []
  for _ in range(n_rounds):
    execution_start = time.time()
    model = federated_train(model, learning_rate, federated_data)
    execution_stop = time.time()
    execution_array.append(execution_stop - execution_start)
  self.report_benchmark(
      name="Average per round execution time, FC API",
      wall_time=np.mean(execution_array),
      iters=n_rounds,
      extras={"std_dev": np.std(execution_array)})