def test_execution_stateful_optimizer(self):
  # Runs a single `next` call of the SGD-with-momentum client work on two
  # clients and compares the per-client results against fixed values.
  process = client_works.build_model_delta_client_work(
      model_examples.LinearRegression, sgdm.build_sgdm(0.1, momentum=0.9))
  examples = collections.OrderedDict(
      x=[[1.0, 2.0], [3.0, 4.0]],
      y=[[5.0], [6.0]],
  )
  single_batch = tf.data.Dataset.from_tensor_slices(examples).batch(2)
  # 1st client has 2 examples, 2nd has 4.
  client_datasets = [single_batch, single_batch.repeat(2)]
  zero_weights = model_utils.ModelWeights(
      trainable=[[[0.0], [0.0]], 0.0], non_trainable=[0.0])
  per_client_weights = [zero_weights, zero_weights]

  initial_state = process.initialize()
  output = process.next(initial_state, per_client_weights, client_datasets)

  expected_results = (
      client_works.ClientResult([[[-1.15], [-1.7]], -0.55], 2.0),
      client_works.ClientResult([[[-1.46], [-2.26]], -0.8], 4.0),
  )
  self.assertEqual((), output.state)
  for index, expected in enumerate(expected_results):
    actual = output.result[index]
    self.assertAllClose(expected.update, actual.update)
    self.assertAllClose(expected.update_weight, actual.update_weight)
  self.assertEqual((), output.measurements)
def initial_weights(self):
  """Returns zero-valued `ModelWeights` whose variables are keyed by name."""
  trainable = dict(a=tf.constant([[0.0], [0.0]]), b=tf.constant(0.0))
  non_trainable = dict(c=0.0)
  return model_utils.ModelWeights(
      trainable=trainable, non_trainable=non_trainable)
def initial_weights(self):
  """Returns zero-valued `ModelWeights` whose variables are held in lists."""
  zero_matrix = tf.constant([[0.0], [0.0]])
  zero_scalar = tf.constant(0.0)
  return model_utils.ModelWeights(
      trainable=[zero_matrix, zero_scalar], non_trainable=[0.0])
def next_fn(state, weights, update):
  """Combines the trainable weights with `update`; `state` passes through.

  The result carries an empty non-trainable component, and the measurements
  are whatever `server_zero()` produces.
  """
  combined_trainable = federated_add(weights['trainable'], update)
  new_weights = intrinsics.federated_zip(
      model_utils.ModelWeights(combined_trainable, ()))
  measurements = server_zero()
  return MeasuredProcessOutput(state, new_weights, measurements)
def test_type_properties(self):
  # Verifies the `initialize` and `next` type signatures of a finalizer built
  # from `sgdm.build_sgdm(1.0)`; state and measurements are expected empty.
  mw_type = computation_types.to_type(
      model_utils.ModelWeights(
          trainable=(tf.float32, tf.float32), non_trainable=tf.float32))

  finalizer = finalizers.build_apply_optimizer_finalizer(
      sgdm.build_sgdm(1.0), mw_type)
  self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

  at_server = computation_types.at_server
  weights_param_type = at_server(mw_type)
  update_param_type = at_server(mw_type.trainable)
  result_type = at_server(mw_type)
  state_type = at_server(())
  measurements_type = at_server(())

  initialize_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  initialize_type.check_equivalent_to(finalizer.initialize.type_signature)

  next_parameter = collections.OrderedDict(
      state=state_type,
      weights=weights_param_type,
      update=update_param_type)
  next_result = MeasuredProcessOutput(state_type, result_type,
                                      measurements_type)
  next_type = computation_types.FunctionType(
      parameter=next_parameter, result=next_result)
  next_type.check_equivalent_to(finalizer.next.type_signature)
def state_with_new_model_weights(
    server_state: ServerState,
    trainable_weights: List[np.ndarray],
    non_trainable_weights: List[np.ndarray],
) -> ServerState:
  """Returns a `ServerState` with updated model weights.

  The new weights must mirror the structure of the weights currently stored in
  `server_state.model`: nested sequences must have equal lengths, and every
  leaf must agree in dtype and shape with the leaf it replaces.

  Args:
    server_state: A server state object returned by an iterative training
      process like `tff.learning.build_federated_averaging_process`.
    trainable_weights: A list of `numpy` arrays in the order of the original
      model's `trainable_variables`.
    non_trainable_weights: A list of `numpy` arrays in the order of the
      original model's `non_trainable_variables`.

  Returns:
    A new server `ServerState` object which can be passed to the `next` method
    of the iterative process.

  Raises:
    TypeError: If the new weights differ from the existing ones in length,
      dtype or shape, or contain values that are neither supported leaves nor
      sequences.
  """
  py_typecheck.check_type(server_state, ServerState)
  leaf_types = (int, float, np.ndarray, tf.Tensor)

  def leaf_signature(value):
    """Returns the `(dtype, shape)` pair used to compare two leaves."""
    if hasattr(value, 'dtype') and hasattr(value, 'shape'):
      return value.dtype, value.shape
    # Plain Python `int`/`float` leaves are accepted by `leaf_types` but have
    # neither attribute; describe them by their Python type and a scalar
    # shape so the comparison below cannot fail with `AttributeError`.
    return type(value), ()

  def assert_weight_lists_match(old_value, new_value):
    """Assert two flat lists of ndarrays or tensors match."""
    if isinstance(new_value, leaf_types) and isinstance(old_value, leaf_types):
      old_dtype, old_shape = leaf_signature(old_value)
      new_dtype, new_shape = leaf_signature(new_value)
      if old_dtype != new_dtype or old_shape != new_shape:
        raise TypeError('Element is not the same tensor type. old '
                        f'({old_dtype}, {old_shape}) != '
                        f'new ({new_dtype}, {new_shape})')
    elif (isinstance(new_value, collections.abc.Sequence) and
          isinstance(old_value, collections.abc.Sequence)):
      if len(old_value) != len(new_value):
        raise TypeError(
            'Model weights have different lengths: '
            f'(old) {len(old_value)} != (new) {len(new_value)})\n'
            f'Old values: {old_value}\nNew values: {new_value}')
      # Recurse pairwise so nested sequences are validated element by element.
      for old, new in zip(old_value, new_value):
        assert_weight_lists_match(old, new)
    else:
      raise TypeError(
          'Model weights structures contains types that cannot be '
          'handled.\nOld weights structure: {old}\n'
          'New weights structure: {new}\n'
          'Must be one of (int, float, np.ndarray, tf.Tensor, '
          'collections.abc.Sequence)'.format(
              old=tf.nest.map_structure(type, old_value),
              new=tf.nest.map_structure(type, new_value)))

  assert_weight_lists_match(server_state.model.trainable, trainable_weights)
  assert_weight_lists_match(server_state.model.non_trainable,
                            non_trainable_weights)
  # Only the model is replaced; every other field is carried over unchanged.
  new_server_state = ServerState(
      model=model_utils.ModelWeights(
          trainable=trainable_weights, non_trainable=non_trainable_weights),
      optimizer_state=server_state.optimizer_state,
      delta_aggregate_state=server_state.delta_aggregate_state,
      model_broadcast_state=server_state.model_broadcast_state)
  return new_server_state
def initial_weights(self):
  """Returns zero-valued `ModelWeights` stored in ordered dictionaries."""
  trainable = collections.OrderedDict()
  trainable['a'] = tf.constant([[0.0], [0.0]])
  trainable['b'] = tf.constant(0.0)
  non_trainable = collections.OrderedDict()
  non_trainable['c'] = 0.0
  return model_utils.ModelWeights(
      trainable=trainable,
      non_trainable=non_trainable,
  )
def next_fn(state, weights, updates):
  """Adds `updates` to the trainable weights and leaves `state` untouched.

  The returned weights have an empty non-trainable component.
  """
  add_tf = tensorflow_computation.tf_computation(lambda x, y: x + y)
  summed_trainable = intrinsics.federated_map(
      add_tf, (weights.trainable, updates))
  zipped_weights = intrinsics.federated_zip(
      model_utils.ModelWeights(summed_trainable, ()))
  return measured_process.MeasuredProcessOutput(
      state, zipped_weights, empty_at_server())
def test_construction(self, weighted):
  # Builds the model-delta optimizer process with either a mean (weighted) or
  # a sum aggregator and checks the `initialize`/`next` type signatures.
  if weighted:
    aggregation_factory = mean.MeanFactory()
  else:
    aggregation_factory = sum_factory.SumFactory()
  iterative_process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.LinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      model_update_aggregation_factory=aggregation_factory)

  # Only the weighted aggregator is expected to report structured (but empty)
  # state and metrics.
  aggregate_state = ()
  aggregate_metrics = ()
  if weighted:
    aggregate_state = collections.OrderedDict(
        value_sum_process=(), weight_sum_process=())
    aggregate_metrics = collections.OrderedDict(
        mean_value=(), mean_weight=())

  tensor_type = computation_types.TensorType
  model_weights_type = model_utils.ModelWeights(
      trainable=[
          tensor_type(tf.float32, [2, 1]),
          tensor_type(tf.float32),
      ],
      non_trainable=[tensor_type(tf.float32)])
  server_state_type = computation_types.FederatedType(
      optimizer_utils.ServerState(
          model=model_weights_type,
          optimizer_state=[tf.int64],
          delta_aggregate_state=aggregate_state,
          model_broadcast_state=()), placements.SERVER)

  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=server_state_type)
  self.assert_types_equivalent(expected_initialize_type,
                               iterative_process.initialize.type_signature)

  batch_type = collections.OrderedDict(
      x=tensor_type(tf.float32, [None, 2]),
      y=tensor_type(tf.float32, [None, 1]))
  dataset_type = computation_types.FederatedType(
      computation_types.SequenceType(batch_type), placements.CLIENTS)
  metrics_type = computation_types.FederatedType(
      collections.OrderedDict(
          broadcast=(),
          aggregation=aggregate_metrics,
          train=collections.OrderedDict(
              loss=tensor_type(tf.float32),
              num_examples=tensor_type(tf.int32)),
          stat=collections.OrderedDict(
              num_examples=tensor_type(tf.float32))), placements.SERVER)

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          server_state=server_state_type,
          federated_dataset=dataset_type,
      ),
      result=(server_state_type, metrics_type))
  self.assert_types_equivalent(expected_next_type,
                               iterative_process.next.type_signature)
class ApplyOptimizerFinalizerComputationTest(tf.test.TestCase,
                                             parameterized.TestCase):
  """Construction-time checks for `build_apply_optimizer_finalizer`."""

  def test_type_properties(self):
    # The optimizer state is expected to hold a float32 entry keyed by
    # `optimizer_base.LEARNING_RATE_KEY`; measurements are expected empty.
    mw_type = computation_types.to_type(
        model_utils.ModelWeights(
            trainable=(tf.float32, tf.float32), non_trainable=tf.float32))

    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0), mw_type)
    self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

    at_server = computation_types.at_server
    weights_param_type = at_server(mw_type)
    update_param_type = at_server(mw_type.trainable)
    result_type = at_server(mw_type)
    optimizer_state_struct = collections.OrderedDict([
        (optimizer_base.LEARNING_RATE_KEY, tf.float32),
    ])
    state_type = at_server(computation_types.to_type(optimizer_state_struct))
    measurements_type = at_server(())

    initialize_type = computation_types.FunctionType(
        parameter=None, result=state_type)
    initialize_type.check_equivalent_to(finalizer.initialize.type_signature)

    next_parameter = collections.OrderedDict(
        state=state_type,
        weights=weights_param_type,
        update=update_param_type)
    next_result = MeasuredProcessOutput(state_type, result_type,
                                        measurements_type)
    next_type = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    next_type.check_equivalent_to(finalizer.next.type_signature)

  @parameterized.named_parameters(
      ('not_struct', computation_types.TensorType(tf.float32)),
      ('federated_type', MODEL_WEIGHTS_TYPE),
      ('model_weights_of_federated_types',
       computation_types.to_type(
           model_utils.ModelWeights(SERVER_FLOAT, SERVER_FLOAT))),
      ('not_model_weights',
       computation_types.to_type((tf.float32, tf.float32))),
      ('function_type',
       computation_types.FunctionType(None, MODEL_WEIGHTS_TYPE)),
      ('sequence_type',
       computation_types.SequenceType(MODEL_WEIGHTS_TYPE.member)))
  def test_incorrect_value_type_raises(self, bad_type):
    # Every parameterized `bad_type` must be rejected with a `TypeError`.
    with self.assertRaises(TypeError):
      finalizers.build_apply_optimizer_finalizer(
          sgdm.build_sgdm(1.0), bad_type)

  def test_unexpected_optimizer_fn_raises(self):
    # A Keras optimizer instance is passed where the builder is expected to
    # reject it with a `TypeError`.
    keras_optimizer = tf.keras.optimizers.SGD(1.0)
    with self.assertRaises(TypeError):
      finalizers.build_apply_optimizer_finalizer(keras_optimizer,
                                                 MODEL_WEIGHTS_TYPE.member)
def test_execution(self):
  # One `next` call with learning rate 1.0: weight 1.0 and update 0.1 are
  # expected to yield 0.9, with empty state and measurements.
  finalizer = finalizers.build_apply_optimizer_finalizer(
      sgdm.build_sgdm(1.0), MODEL_WEIGHTS_TYPE.member)
  initial_weights = model_utils.ModelWeights(1.0, ())
  model_update = 0.1

  initial_state = finalizer.initialize()
  output = finalizer.next(initial_state, initial_weights, model_update)

  self.assertEqual((), output.state)
  self.assertAllClose(0.9, output.result.trainable)
  self.assertEqual((), output.measurements)
def next_fn(state, weights, update):
  """Maps `next_tf` over the state, trainable weights and update.

  Non-trainable weights are forwarded unchanged and the measurements are an
  empty value placed at the server.
  """
  mapped = intrinsics.federated_map(next_tf,
                                    (state, weights.trainable, update))
  optimizer_state, new_trainable_weights = mapped
  updated_weights = model_utils.ModelWeights(new_trainable_weights,
                                             weights.non_trainable)
  new_weights = intrinsics.federated_zip(updated_weights)
  empty_measurements = intrinsics.federated_value((), placements.SERVER)
  return measured_process.MeasuredProcessOutput(
      optimizer_state, new_weights, empty_measurements)
def _model_fn_with_zero_weights():
  """Builds a `LinearRegression` model with all of its weights set to zero."""
  model = model_examples.LinearRegression()
  current = model_utils.ModelWeights.from_model(model)
  zeros = model_utils.ModelWeights(
      trainable=list(map(tf.zeros_like, current.trainable)),
      non_trainable=list(map(tf.zeros_like, current.non_trainable)))
  zeros.assign_weights_to(model)
  return model
def _model_fn_with_one_weights():
  """Builds a `LinearRegression` model with all of its weights set to one.

  Returns:
    A `model_examples.LinearRegression` instance whose trainable and
    non-trainable weights have been overwritten with ones of matching
    shape and dtype.
  """
  # Instantiate the model: `from_model` and `assign_weights_to` operate on a
  # model instance, not on the `LinearRegression` class itself.
  linear_regression_model = model_examples.LinearRegression()
  weights = model_utils.ModelWeights.from_model(linear_regression_model)
  ones_trainable = [tf.ones_like(x) for x in weights.trainable]
  ones_non_trainable = [tf.ones_like(x) for x in weights.non_trainable]
  ones_weights = model_utils.ModelWeights(
      trainable=ones_trainable, non_trainable=ones_non_trainable)
  ones_weights.assign_weights_to(linear_regression_model)
  return linear_regression_model
def state_with_new_model_weights(server_state, trainable_weights,
                                 non_trainable_weights):
  """Returns a `ServerState` with updated model weights.

  Args:
    server_state: A server state object returned by an iterative training
      process like `tff.learning.build_federated_averaging_process`.
    trainable_weights: A list of `numpy` arrays in the order of the original
      model's `trainable_variables`.
    non_trainable_weights: A list of `numpy` arrays in the order of the
      original model's `non_trainable_variables`.

  Returns:
    A new server `ServerState` object which can be passed to the `next` method
    of the iterative process.
  """
  # TODO(b/123092620): Simplify this.
  py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)

  def pack_values(old, new_values, name):
    """Pairs each key of `old` with the matching entry of `new_values`."""
    old_count, new_count = len(old), len(new_values)
    if old_count != new_count:
      raise ValueError('Lengths differ for {} weights: {} vs {}'.format(
          name, old_count, new_count))
    packed = collections.OrderedDict()
    old_elements = anonymous_tuple.to_elements(old)
    for (key, old_value), new_value in zip(old_elements, new_values):
      dtype_differs = old_value.dtype != new_value.dtype
      if dtype_differs or old_value.shape != new_value.shape:
        raise ValueError(
            'The shapes or dtypes do not match for {} weight {}:\n'
            'current weights: shape {} dtype {}\n'
            ' new weights: shape {} dtype {}'.format(
                name, key, old_value.shape, old_value.dtype,
                new_value.shape, new_value.dtype))
      packed[key] = new_value
    return packed

  current_model = server_state.model
  packed_trainable = pack_values(current_model.trainable, trainable_weights,
                                 'trainable')
  packed_non_trainable = pack_values(current_model.non_trainable,
                                     non_trainable_weights, 'non_trainable')
  renamed_new_weights = model_utils.ModelWeights(
      trainable=packed_trainable, non_trainable=packed_non_trainable)
  # TODO(b/123092620): tff.utils.update_state cannot be used here because
  # `server_state` is an AnonymousTuple rather than a ServerState, so each
  # field is copied across by name. Find a way that does not list them all.
  return ServerState(
      model=renamed_new_weights,
      optimizer_state=server_state.optimizer_state,
      delta_aggregate_state=server_state.delta_aggregate_state,
      model_broadcast_state=server_state.model_broadcast_state)
def test_non_federated_init_next_raises(self):
  # Both computations are plain TF computations (no placements), which
  # `FinalizerProcess` must reject with `TemplateNotFederatedError`.
  initialize_fn = tensorflow_computation.tf_computation(lambda: 0)
  weights_type = computation_types.to_type(
      model_utils.ModelWeights(tf.float32, ()))

  @tensorflow_computation.tf_computation(tf.int32, weights_type, tf.float32)
  def next_fn(state, weights, update):
    updated_trainable = weights.trainable + update
    new_weights = model_utils.ModelWeights(updated_trainable, ())
    return MeasuredProcessOutput(state, new_weights, 0)

  with self.assertRaises(errors.TemplateNotFederatedError):
    finalizers.FinalizerProcess(initialize_fn, next_fn)
def test_model_weights_from_python_structure(self):
  # `from_python_structure` applied to an OrderedDict must yield the same
  # trainable/non-trainable contents as direct construction.
  trainable_weights = [tf.constant([1., 1.])]
  non_trainable_weights = [tf.constant(1)]

  expected = model_utils.ModelWeights(
      trainable=trainable_weights, non_trainable=non_trainable_weights)
  python_structure = collections.OrderedDict(
      trainable=trainable_weights, non_trainable=non_trainable_weights)
  converted = model_utils.ModelWeights.from_python_structure(
      python_structure)

  self.assertEqual(expected.trainable, converted.trainable)
  self.assertEqual(expected.non_trainable, converted.non_trainable)
def test_execution_with_stateless_tff_optimizer(self):
  # Applies the same 0.1 update five times with learning rate 1.0; the stored
  # learning rate must stay 1.0 and the weight must drop by 0.1 per step.
  finalizer = finalizers.build_apply_optimizer_finalizer(
      sgdm.build_sgdm(1.0), MODEL_WEIGHTS_TYPE.member)
  weights = model_utils.ModelWeights(1.0, ())
  update = 0.1
  optimizer_state = finalizer.initialize()

  for step in range(1, 6):
    output = finalizer.next(optimizer_state, weights, update)
    optimizer_state, weights = output.state, output.result
    self.assertEqual(1.0, optimizer_state[optimizer_base.LEARNING_RATE_KEY])
    self.assertAllClose(1.0 - 0.1 * step, weights.trainable)
    self.assertEqual((), output.measurements)
def test_orchestration_typecheck(self):
  # Checks the `initialize`/`next` type signatures of the federated SGD
  # process built for `LinearRegression`.
  iterative_process = federated_sgd.build_federated_sgd_process(
      model_fn=model_examples.LinearRegression)

  trainable_types = collections.OrderedDict([
      ('a', tff.TensorType(tf.float32, [2, 1])),
      ('b', tf.float32),
  ])
  non_trainable_types = collections.OrderedDict([('c', tf.float32)])
  expected_model_weights_type = model_utils.ModelWeights(
      trainable_types, non_trainable_types)

  # A ServerState holds the model and an optimizer_state. The latter comes
  # from TensorFlow and its concrete type is irrelevant to TFF here.
  server_state_type = tff.FederatedType(
      optimizer_utils.ServerState(expected_model_weights_type,
                                  test.AnyType()),
      placement=tff.SERVER,
      all_equal=True)

  batch_type = model_examples.LinearRegression.make_batch(
      tff.TensorType(tf.float32, [None, 2]),
      tff.TensorType(tf.float32, [None, 1]))
  dataset_type = tff.FederatedType(
      tff.SequenceType(batch_type), tff.CLIENTS, all_equal=False)

  model_output_types = tff.FederatedType(
      collections.OrderedDict([
          ('loss', tff.TensorType(tf.float32, [])),
          ('num_examples', tff.TensorType(tf.int32, [])),
      ]),
      tff.SERVER,
      all_equal=True)

  # `initialize` takes no arguments and produces a ServerState.
  expected_initialize_type = tff.FunctionType(
      parameter=None, result=server_state_type)
  self.assertEqual(expected_initialize_type,
                   iterative_process.initialize.type_signature)

  # `next` maps (ServerState, Datasets) to (ServerState, model outputs).
  expected_next_type = tff.FunctionType(
      parameter=[server_state_type, dataset_type],
      result=(server_state_type, model_output_types))
  self.assertEqual(expected_next_type,
                   iterative_process.next.type_signature)
def test_orchestration_type_signature(self):
  # Checks the `initialize`/`next` type signatures of the model-delta
  # optimizer process built for `TrainableLinearRegression`.
  iterative_process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.TrainableLinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=lambda: gradient_descent.SGD(learning_rate=1.0))

  trainable_types = collections.OrderedDict([
      ('a', tff.TensorType(tf.float32, [2, 1])),
      ('b', tf.float32),
  ])
  non_trainable_types = collections.OrderedDict([('c', tf.float32)])
  expected_model_weights_type = model_utils.ModelWeights(
      trainable_types, non_trainable_types)

  # A ServerState holds the model and an optimizer_state. The latter comes
  # from TensorFlow and its concrete type is irrelevant to TFF here.
  server_state_type = tff.FederatedType(
      optimizer_utils.ServerState(expected_model_weights_type,
                                  test.AnyType(), test.AnyType(),
                                  test.AnyType()),
      placement=tff.SERVER,
      all_equal=True)

  input_spec = model_examples.TrainableLinearRegression().input_spec
  dataset_type = tff.FederatedType(
      tff.SequenceType(input_spec), tff.CLIENTS, all_equal=False)

  model_output_types = tff.FederatedType(
      collections.OrderedDict([
          ('loss', tff.TensorType(tf.float32, [])),
          ('num_examples', tff.TensorType(tf.int32, [])),
      ]),
      tff.SERVER,
      all_equal=True)

  # `initialize` takes no arguments and produces a ServerState.
  expected_initialize_type = tff.FunctionType(
      parameter=None, result=server_state_type)
  self.assertEqual(expected_initialize_type,
                   iterative_process.initialize.type_signature)

  # `next` maps (ServerState, Datasets) to (ServerState, model outputs).
  expected_next_type = tff.FunctionType(
      parameter=[server_state_type, dataset_type],
      result=(server_state_type, model_output_types))
  self.assertEqual(expected_next_type,
                   iterative_process.next.type_signature)
def test_construction(self):
  # Compares the string forms of the `initialize`/`next` type signatures
  # against the expected function types.
  iterative_process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.LinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  tensor_type = computation_types.TensorType
  model_weights_type = model_utils.ModelWeights(
      trainable=[
          tensor_type(tf.float32, [2, 1]),
          tensor_type(tf.float32),
      ],
      non_trainable=[tensor_type(tf.float32)])
  server_state_type = computation_types.FederatedType(
      optimizer_utils.ServerState(
          model=model_weights_type,
          optimizer_state=[tf.int64],
          delta_aggregate_state=(),
          model_broadcast_state=()), placements.SERVER)

  actual_initialize = str(iterative_process.initialize.type_signature)
  expected_initialize = str(
      computation_types.FunctionType(
          parameter=None, result=server_state_type))
  self.assertEqual(actual_initialize, expected_initialize)

  batch_type = collections.OrderedDict(
      x=tensor_type(tf.float32, [None, 2]),
      y=tensor_type(tf.float32, [None, 1]))
  dataset_type = computation_types.FederatedType(
      computation_types.SequenceType(batch_type), placements.CLIENTS)
  metrics_type = computation_types.FederatedType(
      collections.OrderedDict(
          broadcast=(),
          aggregation=(),
          train=collections.OrderedDict(
              loss=tensor_type(tf.float32),
              num_examples=tensor_type(tf.int32))), placements.SERVER)

  actual_next = str(iterative_process.next.type_signature)
  expected_next = str(
      computation_types.FunctionType(
          parameter=collections.OrderedDict(
              server_state=server_state_type,
              federated_dataset=dataset_type,
          ),
          result=(server_state_type, metrics_type)))
  self.assertEqual(actual_next, expected_next)
def test_model_weights_from_tff_struct(self):
  # `from_tff_result` applied to a `structure.Struct` must yield the same
  # trainable/non-trainable contents as direct construction.
  trainable_weights = [tf.constant([1., 1.])]
  non_trainable_weights = [tf.constant(1)]
  expected = model_utils.ModelWeights(
      trainable=trainable_weights, non_trainable=non_trainable_weights)

  trainable_struct = structure.from_container(trainable_weights)
  non_trainable_struct = structure.from_container(non_trainable_weights)
  tff_struct = structure.Struct([
      ('trainable', trainable_struct),
      ('non_trainable', non_trainable_struct),
  ])
  converted = model_utils.ModelWeights.from_tff_result(tff_struct)

  self.assertEqual(expected.trainable, converted.trainable)
  self.assertEqual(expected.non_trainable, converted.non_trainable)
def test_type_properties(self, weighting):
  # Checks the `initialize`/`next` type signatures of the Mime Lite client
  # work process built with an SGD-with-momentum optimizer.
  model_fn = model_examples.LinearRegression
  optimizer = sgdm.build_sgdm(learning_rate=0.1, momentum=0.9)
  client_work_process = mime._build_mime_lite_client_work(
      model_fn, optimizer, weighting)
  self.assertIsInstance(client_work_process, client_works.ClientWorkProcess)

  to_type = computation_types.to_type
  at_clients = computation_types.at_clients
  at_server = computation_types.at_server

  mw_type = model_utils.ModelWeights(
      trainable=to_type([(tf.float32, (2, 1)), tf.float32]),
      non_trainable=to_type([tf.float32]))
  weights_param_type = at_clients(mw_type)
  data_param_type = at_clients(
      computation_types.SequenceType(to_type(model_fn().input_spec)))
  result_type = at_clients(
      client_works.ClientResult(
          update=mw_type.trainable,
          update_weight=computation_types.TensorType(tf.float32)))

  # The state pairs the optimizer state (derived from the trainable weight
  # specs) with the aggregator state.
  trainable_specs = type_conversions.type_to_tf_tensor_specs(
      mw_type.trainable)
  optimizer_state_type = type_conversions.type_from_tensors(
      optimizer.initialize(trainable_specs))
  aggregator_type = to_type(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  state_type = at_server((optimizer_state_type, aggregator_type))
  measurements_type = at_server(
      collections.OrderedDict(
          train=collections.OrderedDict(
              loss=tf.float32, num_examples=tf.int32)))

  initialize_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  initialize_type.check_equivalent_to(
      client_work_process.initialize.type_signature)

  next_parameter = collections.OrderedDict(
      state=state_type,
      weights=weights_param_type,
      client_data=data_param_type)
  next_result = measured_process.MeasuredProcessOutput(
      state_type, result_type, measurements_type)
  next_type = computation_types.FunctionType(
      parameter=next_parameter, result=next_result)
  next_type.check_equivalent_to(client_work_process.next.type_signature)
def test_state_with_new_model_weights(self):
  # Replaces list-structured weights and then checks that mismatched dtype,
  # shape, length and container type are each rejected with `TypeError`.
  trainable = [np.array([1.0, 2.0]), np.array([[1.0]])]
  non_trainable = [np.array(1)]
  server_state = optimizer_utils.ServerState(
      model=model_utils.ModelWeights(
          trainable=trainable, non_trainable=non_trainable),
      optimizer_state=[],
      delta_aggregate_state=tf.constant(0),
      model_broadcast_state=tf.constant(0))
  state = anonymous_tuple.from_container(server_state, recursive=True)

  new_state = optimizer_utils.state_with_new_model_weights(
      state,
      trainable_weights=[np.array([3.0, 3.0]),
                         np.array([[3.0]])],
      non_trainable_weights=[np.array(3)])
  expected_trainable = [np.array([3.0, 3.0]), np.array([[3.0]])]
  self.assertAllClose(new_state.model.trainable, expected_trainable)
  self.assertAllClose(new_state.model.non_trainable, [3])

  # Each case: (expected error regex, trainable weights, non-trainable).
  failure_cases = [
      ('tensor type', [np.array([3.0, 3.0]),
                       np.array([[3]])], [np.array(3.0)]),
      ('tensor type', [np.array([3.0, 3.0]),
                       np.array([3.0])], [np.array(3)]),
      ('different lengths', [np.array([3.0, 3.0])], [np.array(3)]),
      ('cannot be handled', {'a': np.array([3.0, 3.0])}, [np.array(3)]),
  ]
  for error_regex, bad_trainable, bad_non_trainable in failure_cases:
    with self.assertRaisesRegex(TypeError, error_regex):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=bad_trainable,
          non_trainable_weights=bad_non_trainable)
def test_execution_with_nearly_stateless_keras_optimizer(self):
  # Keras SGD keeps only a call counter as state, so after each of the five
  # steps the optimizer state must equal the number of calls made so far.
  server_optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, MODEL_WEIGHTS_TYPE.member)
  weights = model_utils.ModelWeights(1.0, ())
  update = 0.1
  optimizer_state = finalizer.initialize()

  for num_calls in range(1, 6):
    output = finalizer.next(optimizer_state, weights, update)
    optimizer_state, weights = output.state, output.result
    self.assertEqual([num_calls], optimizer_state)
    self.assertAllClose(1.0 - 0.1 * num_calls, weights.trainable)
    self.assertEqual((), output.measurements)
def test_execution_with_stateful_tff_optimizer(self):
  # Tracks the expected momentum accumulator by hand over five steps and
  # compares it with the optimizer state and the resulting weights.
  momentum = 0.5
  optimizer = sgdm.build_sgdm(1.0, momentum=momentum)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      optimizer, MODEL_WEIGHTS_TYPE.member)
  weights = model_utils.ModelWeights(1.0, ())
  update = 0.1
  optimizer_state = finalizer.initialize()

  expected_velocity = 0.0
  for _ in range(5):
    output = finalizer.next(optimizer_state, weights, update)
    optimizer_state = output.state
    expected_velocity = expected_velocity * momentum + update
    self.assertNear(expected_velocity, optimizer_state['accumulator'], 1e-6)
    expected_trainable = weights.trainable - expected_velocity
    self.assertAllClose(expected_trainable, output.result.trainable)
    self.assertEqual((), output.measurements)
    # Feed this round's result into the next iteration.
    weights = output.result
def test_state_with_new_model_weights(self):
  # Replaces name-keyed weights and checks that key order is preserved, then
  # that mismatched dtype, shape and length each raise `ValueError`.
  trainable = [('b', np.array([1.0, 2.0])), ('a', np.array([[1.0]]))]
  non_trainable = [('c', np.array(1))]
  state = anonymous_tuple.from_container(
      optimizer_utils.ServerState(
          model=model_utils.ModelWeights(
              trainable=collections.OrderedDict(trainable),
              non_trainable=collections.OrderedDict(non_trainable)),
          optimizer_state=[],
          delta_aggregate_state=tf.constant(0),
          model_broadcast_state=tf.constant(0)),
      recursive=True)

  new_state = optimizer_utils.state_with_new_model_weights(
      state,
      trainable_weights=[np.array([3.0, 3.0]),
                         np.array([[3.0]])],
      non_trainable_weights=[np.array(3)])
  self.assertEqual(list(new_state.model.trainable.keys()), ['b', 'a'])
  self.assertEqual(list(new_state.model.non_trainable.keys()), ['c'])
  self.assertAllClose(new_state.model.trainable['b'], [3.0, 3.0])
  self.assertAllClose(new_state.model.trainable['a'], [[3.0]])
  self.assertAllClose(new_state.model.non_trainable['c'], 3)

  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which newer Python versions no longer provide.
  with self.assertRaisesRegex(ValueError, 'dtype'):
    optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=[np.array([3.0, 3.0]),
                           np.array([[3]])],
        non_trainable_weights=[np.array(3.0)])

  with self.assertRaisesRegex(ValueError, 'shape'):
    optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=[np.array([3.0, 3.0]),
                           np.array([3.0])],
        non_trainable_weights=[np.array(3)])

  with self.assertRaisesRegex(ValueError, 'Lengths differ'):
    optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=[np.array([3.0, 3.0])],
        non_trainable_weights=[np.array(3)])
def test_state_with_new_model_weights_failure(self, new_trainable,
                                              new_non_trainable,
                                              expected_err_msg):
  # Expects `state_with_new_model_weights` to raise a `TypeError` matching
  # `expected_err_msg` for the supplied replacement weights.
  #
  # Args:
  #   new_trainable: Replacement trainable weights, or `None` to reuse the
  #     current trainable weights.
  #   new_non_trainable: Replacement non-trainable weights, or `None` to
  #     reuse the current non-trainable weights.
  #   expected_err_msg: Regex the raised error message must match.
  trainable = [np.array([1.0, 2.0]), np.array([[1.0]]), np.int64(3)]
  non_trainable = [np.array(1), b'bytes type', 5, 2.0]
  state = optimizer_utils.ServerState(
      model=model_utils.ModelWeights(
          trainable=trainable, non_trainable=non_trainable),
      optimizer_state=[],
      delta_aggregate_state=tf.constant(0),
      model_broadcast_state=tf.constant(0))

  # A `None` parameter falls back to the existing weights, so each case only
  # has to spell out the side that is meant to be invalid.
  if new_trainable is None:
    new_trainable = trainable
  if new_non_trainable is None:
    new_non_trainable = non_trainable

  with self.assertRaisesRegex(TypeError, expected_err_msg):
    optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=new_trainable,
        non_trainable_weights=new_non_trainable)
def test_state_with_model_weights_success(self):
  # Replaces weights that mix arrays, numpy scalars, bytes and Python
  # scalars with same-typed values and checks the new state holds them.
  old_trainable = [np.array([1.0, 2.0]), np.array([[1.0]]), np.int64(3)]
  old_non_trainable = [np.array(1), b'bytes type', 5, 2.0]
  old_model = model_utils.ModelWeights(
      trainable=old_trainable, non_trainable=old_non_trainable)
  state = optimizer_utils.ServerState(
      model=old_model,
      optimizer_state=[],
      delta_aggregate_state=tf.constant(0),
      model_broadcast_state=tf.constant(0))

  new_trainable = [np.array([3.0, 3.0]), np.array([[3.0]]), np.int64(4)]
  new_non_trainable = [np.array(3), b'bytes check', 6, 3.0]
  new_state = optimizer_utils.state_with_new_model_weights(
      state,
      trainable_weights=new_trainable,
      non_trainable_weights=new_non_trainable)

  updated_model = new_state.model
  self.assertAllClose(updated_model.trainable, new_trainable)
  self.assertEqual(updated_model.non_trainable, new_non_trainable)
def test_type_properties(self):
  # Checks the `initialize`/`next` type signatures of the model-delta client
  # work process; state and measurements are expected to be empty.
  model_fn = model_examples.LinearRegression
  client_work_process = client_works.build_model_delta_client_work(
      model_fn, sgdm.build_sgdm(1.0))
  self.assertIsInstance(client_work_process, client_works.ClientWorkProcess)

  to_type = computation_types.to_type
  at_clients = computation_types.at_clients
  at_server = computation_types.at_server

  mw_type = model_utils.ModelWeights(
      trainable=to_type([(tf.float32, (2, 1)), tf.float32]),
      non_trainable=to_type([tf.float32]))
  weights_param_type = at_clients(mw_type)
  data_param_type = at_clients(
      computation_types.SequenceType(to_type(model_fn().input_spec)))
  result_type = at_clients(
      client_works.ClientResult(
          update=mw_type.trainable,
          update_weight=computation_types.TensorType(tf.float32)))
  state_type = at_server(())
  measurements_type = at_server(())

  initialize_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  initialize_type.check_equivalent_to(
      client_work_process.initialize.type_signature)

  next_parameter = collections.OrderedDict(
      state=state_type,
      weights=weights_param_type,
      client_data=data_param_type)
  next_result = MeasuredProcessOutput(state_type, result_type,
                                      measurements_type)
  next_type = computation_types.FunctionType(
      parameter=next_parameter, result=next_result)
  next_type.check_equivalent_to(client_work_process.next.type_signature)