def test_execution_stateful_optimizer(self):
    """Runs one round of model-delta client work with SGD plus momentum.

    Two clients start from the same zero-valued model weights; the second
    client's dataset is the first client's dataset repeated twice. Each
    client's update and update weight are compared against hard-coded
    expectations, and the process state and measurements must be `()`.
    """
    process = client_works.build_model_delta_client_work(
        model_examples.LinearRegression,
        sgdm.build_sgdm(0.1, momentum=0.9))

    single_batch = tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(
            x=[[1.0, 2.0], [3.0, 4.0]],
            y=[[5.0], [6.0]],
        )).batch(2)
    # 1st client has 2 examples, 2nd has 4.
    client_datasets = [single_batch, single_batch.repeat(2)]

    initial_weights = model_utils.ModelWeights(
        trainable=[[[0.0], [0.0]], 0.0], non_trainable=[0.0])
    # Both clients receive the very same weights object.
    per_client_weights = [initial_weights, initial_weights]

    output = process.next(
        process.initialize(), per_client_weights, client_datasets)

    expected_results = (
        client_works.ClientResult([[[-1.15], [-1.7]], -0.55], 2.0),
        client_works.ClientResult([[[-1.46], [-2.26]], -0.8], 4.0),
    )
    self.assertEqual((), output.state)
    for index, want in enumerate(expected_results):
        got = output.result[index]
        self.assertAllClose(want.update, got.update)
        self.assertAllClose(want.update_weight, got.update_weight)
    self.assertEqual((), output.measurements)
def test_type_properties(self):
    """Checks the type signatures of an apply-optimizer finalizer.

    The finalizer built from an SGD optimizer must be a `FinalizerProcess`
    whose `initialize` and `next` signatures are equivalent to the
    server-placed types assembled below.
    """
    weights_type = computation_types.to_type(
        model_utils.ModelWeights(
            trainable=(tf.float32, tf.float32),
            non_trainable=tf.float32))

    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0), weights_type)
    self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

    # Inputs and result of `next`, all placed at the server.
    server_weights = computation_types.at_server(weights_type)
    server_update = computation_types.at_server(weights_type.trainable)
    server_result = computation_types.at_server(weights_type)
    # State and measurements are both expected to be empty structures.
    server_state = computation_types.at_server(())
    server_measurements = computation_types.at_server(())

    initialize_signature = computation_types.FunctionType(
        parameter=None, result=server_state)
    initialize_signature.check_equivalent_to(
        finalizer.initialize.type_signature)

    next_parameter = collections.OrderedDict(
        state=server_state,
        weights=server_weights,
        update=server_update)
    next_result = MeasuredProcessOutput(
        server_state, server_result, server_measurements)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    next_signature.check_equivalent_to(finalizer.next.type_signature)
def test_state_structure_momentum(self):
    """Checks the keys of the initial state of SGD with momentum.

    The state must hold exactly three entries: the learning rate, the
    momentum and the accumulator.
    """
    state = sgdm.build_sgdm(0.01, momentum=0.9).initialize(_SCALAR_SPEC)
    self.assertLen(state, 3)
    required_keys = (
        optimizer_base.LEARNING_RATE_KEY,
        sgdm._MOMENTUM_KEY,
        sgdm._ACCUMULATOR_KEY,
    )
    for key in required_keys:
        self.assertIn(key, state)
def test_delta_regularizer_yields_smaller_model_delta(self, optimizer):
    """Checks that L2-regularizing the delta shrinks the client update.

    Two client-work processes are built that differ only in
    `delta_l2_regularizer` (0.0 versus 1.0). Both use the `optimizer`
    supplied to the test, so the regularizer strength is the only variable
    in the comparison. After one round on the same data and initial
    weights, the regularized update must have a strictly smaller global
    norm, while the reported number of training examples must be equal.

    Args:
      optimizer: The client optimizer used by both processes.
    """

    def build_process(delta_l2_regularizer):
        # Everything except the regularizer strength is held fixed.
        return model_delta_client_work.build_model_delta_client_work(
            self.create_model,
            optimizer,
            client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
            delta_l2_regularizer=delta_l2_regularizer)

    def update_norm(process_output):
        # Global norm across all tensors of the first client's update.
        return tf.linalg.global_norm(
            tf.nest.flatten(process_output.result[0].update))

    simple_process = build_process(0.0)
    proximal_process = build_process(1.0)

    client_data = [create_test_dataset()]
    client_model_weights = [create_test_initial_weights()]

    simple_output = simple_process.next(
        simple_process.initialize(), client_model_weights, client_data)
    proximal_output = proximal_process.next(
        proximal_process.initialize(), client_model_weights, client_data)

    self.assertGreater(
        update_norm(simple_output), update_norm(proximal_output))
    self.assertEqual(
        simple_output.measurements['train']['num_examples'],
        proximal_output.measurements['train']['num_examples'])
def test_scheduling_scheduled_optimizer_raises(self):
    """Initializing a twice-scheduled optimizer must raise a KeyError.

    Wrapping the optimizer a second time is done outside the `with` block;
    the error is expected from `initialize`, with a message matching
    'must have learning rate'.
    """
    doubly_scheduled = scheduling.schedule_learning_rate(
        scheduling.schedule_learning_rate(
            sgdm.build_sgdm(1.0), _example_schedule_fn),
        _example_schedule_fn)
    with self.assertRaisesRegex(KeyError, 'must have learning rate'):
        doubly_scheduled.initialize(tf.TensorSpec((), tf.float32))
def test_aggregation_wrapper_called_by_unweighted(self, _, mock_as_weighted):
    """Building unweighted FedProx calls `mock_as_weighted` exactly once.

    Args:
      _: An injected argument this test does not use.
      mock_as_weighted: The injected mock whose call count is checked.
    """
    builder_kwargs = dict(
        model_fn=model_examples.LinearRegression,
        proximal_strength=1.0,
        client_optimizer_fn=sgdm.build_sgdm(1.0),
    )
    fed_prox.build_unweighted_fed_prox(**builder_kwargs)
    self.assertEqual(mock_as_weighted.call_count, 1)
def test_build_weighted_fed_prox_called_by_unweighted_fed_prox(
        self, mock_fed_avg):
    """Building unweighted FedProx calls `mock_fed_avg` exactly once.

    Args:
      mock_fed_avg: The injected mock whose call count is checked.
    """
    builder_kwargs = dict(
        model_fn=model_examples.LinearRegression,
        proximal_strength=1.0,
        client_optimizer_fn=sgdm.build_sgdm(1.0),
    )
    fed_prox.build_unweighted_fed_prox(**builder_kwargs)
    self.assertEqual(mock_fed_avg.call_count, 1)
def test_negative_proximal_strength_raises(self):
    """A negative `delta_l2_regularizer` must be rejected with ValueError.

    The optimizer fixture is created before entering `assertRaises`, so
    only the builder call under test can satisfy the expectation; a
    ValueError raised while constructing the fixture would surface as a
    test error instead of a false pass.
    """
    client_optimizer = sgdm.build_sgdm(1.0)
    with self.assertRaises(ValueError):
        model_delta_client_work.build_model_delta_client_work(
            model_examples.LinearRegression,
            client_optimizer,
            client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
            delta_l2_regularizer=-1.0)
def test_weighted_fed_avg_raises_on_unweighted_aggregator(self):
    """Weighted FedProx must reject an aggregator built as unweighted.

    The builder is expected to raise a TypeError whose message matches
    'WeightedAggregationFactory'.
    """
    unweighted_aggregator = model_update_aggregator.robust_aggregator(
        weighted=False)
    expected_error = self.assertRaisesRegex(
        TypeError, 'WeightedAggregationFactory')
    with expected_error:
        fed_prox.build_weighted_fed_prox(
            model_fn=model_examples.LinearRegression,
            proximal_strength=1.0,
            client_optimizer_fn=sgdm.build_sgdm(1.0),
            model_aggregator=unweighted_aggregator)
def test_execution(self):
    """Applies one update through an SGD finalizer and checks the result.

    Starting from a trainable weight of 1.0 and an update of 0.1 with
    learning rate 1.0, the resulting trainable weight must be close to
    0.9. State and measurements must both be `()`.
    """
    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0), MODEL_WEIGHTS_TYPE.member)
    initial_weights = model_utils.ModelWeights(1.0, ())

    output = finalizer.next(finalizer.initialize(), initial_weights, 0.1)

    self.assertEqual((), output.state)
    self.assertAllClose(0.9, output.result.trainable)
    self.assertEqual((), output.measurements)
class ScheduledLROptimizerTest(parameterized.TestCase, tf.test.TestCase):
    """Tests for optimizers wrapped by `scheduling.schedule_learning_rate`."""

    def test_scheduled_sgd_computes_correctly(self):
        """Steps a scheduled SGD four times and checks every new weight."""
        scheduled_sgd = scheduling.schedule_learning_rate(
            sgdm.build_sgdm(1.0), _example_schedule_fn)

        weight = tf.constant(1.0)
        gradient = tf.constant(1.0)
        state = scheduled_sgd.initialize(tf.TensorSpec((), tf.float32))

        # Expected weight after each successive step with a unit gradient.
        expected_weights = (
            0.9,  # Learning rate initially 0.1.
            0.8,
            0.79,  # Learning rate has decreased to 0.01.
            0.78,
        )
        for expected in expected_weights:
            state, weight = scheduled_sgd.next(state, weight, gradient)
            self.assertAllClose(expected, weight)

    @parameterized.named_parameters(
        ('adagrad', adagrad.build_adagrad(1.0)),
        ('adam', adam.build_adam(1.0)),
        ('rmsprop', rmsprop.build_rmsprop(1.0)),
        ('sgd', sgdm.build_sgdm(1.0)),
        ('sgdm', sgdm.build_sgdm(1.0, momentum=0.9)),
        ('yogi', yogi.build_yogi(1.0)),
    )
    def test_schedule_learning_rate_integrates_with(self, optimizer):
        """Scheduling any of the listed optimizers yields an `Optimizer`."""
        wrapped = scheduling.schedule_learning_rate(
            optimizer, _example_schedule_fn)
        self.assertIsInstance(wrapped, optimizer_base.Optimizer)

    def test_keras_optimizer_raises(self):
        """Scheduling a Keras optimizer must raise a TypeError."""
        keras_sgd = tf.keras.optimizers.SGD(1.0)
        with self.assertRaises(TypeError):
            scheduling.schedule_learning_rate(keras_sgd, _example_schedule_fn)

    def test_scheduling_scheduled_optimizer_raises(self):
        """Initializing a twice-scheduled optimizer must raise a KeyError."""
        doubly_scheduled = scheduling.schedule_learning_rate(
            scheduling.schedule_learning_rate(
                sgdm.build_sgdm(1.0), _example_schedule_fn),
            _example_schedule_fn)
        with self.assertRaisesRegex(KeyError, 'must have learning rate'):
            doubly_scheduled.initialize(tf.TensorSpec((), tf.float32))
def test_executes_with(self, spec, momentum):
    """Runs ten SGD steps on all-ones inputs and checks weights stay finite.

    Weights and gradients are both initialized to ones matching `spec`.
    Finiteness is checked with `tf.reduce_all` on each weight tensor, which
    works for tensors of any rank, whereas iterating a tensor with the
    builtin `all` only works for rank-1 tensors.

    Args:
      spec: A (possibly nested) structure of specs exposing `shape` and
        `dtype`, describing the weights.
      momentum: The momentum value passed to `sgdm.build_sgdm`.
    """

    def ones_like_spec():
        # A fresh structure of all-ones tensors, one per leaf of `spec`.
        return tf.nest.map_structure(
            lambda s: tf.ones(s.shape, s.dtype), spec)

    weights = ones_like_spec()
    gradients = ones_like_spec()

    optimizer = sgdm.build_sgdm(0.01, momentum=momentum)
    state = optimizer.initialize(spec)
    for _ in range(10):
        state, weights = optimizer.next(state, weights, gradients)

    for weight in tf.nest.flatten(weights):
        self.assertTrue(tf.reduce_all(tf.math.is_finite(weight)))
def test_raises_on_invalid_distributor(self):
    """A plain `IterativeProcess` must be rejected as a model distributor.

    A broadcast distributor is built for the model's weights type and then
    re-wrapped as a bare `IterativeProcess`; passing that wrapper as
    `model_distributor` must raise a TypeError.

    The optimizer fixture is created before entering `assertRaises`, so a
    TypeError raised while building the fixture cannot make this test pass
    for the wrong reason.
    """
    model_weights_type = type_conversions.type_from_tensors(
        model_utils.ModelWeights.from_model(
            model_examples.LinearRegression()))
    distributor = distributors.build_broadcast_process(model_weights_type)
    invalid_distributor = iterative_process.IterativeProcess(
        distributor.initialize, distributor.next)
    client_optimizer = sgdm.build_sgdm(1.0)

    with self.assertRaises(TypeError):
        fed_avg.build_weighted_fed_avg(
            model_fn=model_examples.LinearRegression,
            client_optimizer_fn=client_optimizer,
            model_distributor=invalid_distributor)
def test_math_momentum_0_5(self):
    """Checks four SGD steps with learning rate 0.01 and momentum 0.5.

    With a constant gradient of 2.0 and an initial weight of 1.0, the
    expected trajectory is consistent with accumulating
    `acc = 0.5 * acc + gradient` and stepping `weight -= 0.01 * acc`.
    """
    optimizer = sgdm.build_sgdm(0.01, momentum=0.5)
    state = optimizer.initialize(_SCALAR_SPEC)

    current = tf.constant([1.0], tf.float32)
    constant_gradient = tf.constant([2.0], tf.float32)

    # Record the initial weights followed by the weights after each step.
    trajectory = [current]
    for _ in range(4):
        state, current = optimizer.next(state, current, constant_gradient)
        trajectory.append(current)

    self.assertAllClose(
        [[1.0], [0.98], [0.95], [0.915], [0.8775]], trajectory)
def test_convergence(self, momentum):
    """Checks that 100 SGD steps drive the test objective below 0.005.

    The objective must start above 5.0 at the initial weights, so the
    final check reflects progress made by the optimizer.

    Args:
      momentum: The momentum value passed to `sgdm.build_sgdm`.
    """
    make_initial_weights, objective, gradient_of = (
        optimizer_test_utils.test_quadratic_problem())
    weights = make_initial_weights()
    self.assertGreater(objective(weights), 5.0)

    optimizer = sgdm.build_sgdm(0.1, momentum=momentum)
    state = optimizer.initialize(
        tf.TensorSpec(weights.shape, weights.dtype))

    for _ in range(100):
        state, weights = optimizer.next(
            state, weights, gradient_of(weights))

    self.assertLess(objective(weights), 0.005)
class VanillaFedAvgTest(test_case.TestCase, parameterized.TestCase):
    """Tests for `composers.build_basic_fedavg_process`."""

    def _test_data(self):
        """Returns a dataset holding one batch of two (x, y) examples."""
        return tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[[1.0, 2.0], [3.0, 4.0]],
                y=[[5.0], [6.0]],
            )).batch(2)

    def _test_batch_loss(self, model, weights):
        """Assigns `weights` to `model` and returns its loss on the test batch.

        Args:
          model: The model whose variables are overwritten and evaluated.
          weights: Values matching the structure of
            `model_utils.ModelWeights.from_model(model)`.

        Returns:
          The `loss` of a forward pass (with `training=False`) on the first
          batch of `_test_data()`.
        """
        tf.nest.map_structure(
            lambda w, v: w.assign(v),
            model_utils.ModelWeights.from_model(model), weights)
        for batch in self._test_data().take(1):
            batch_output = model.forward_pass(batch, training=False)
        return batch_output.loss

    def test_loss_decreases(self):
        """Runs five rounds and checks the loss strictly drops every round.

        Also checks the final state type and that the last round's metrics
        contain exactly the four per-component entries.
        """
        model_fn = model_examples.LinearRegression
        test_model = model_fn()
        fedavg = composers.build_basic_fedavg_process(
            model_fn=model_fn, client_learning_rate=0.1)
        client_data = [self._test_data()] * 3  # 3 clients, identical data.

        state = fedavg.initialize()
        last_loss = self._test_batch_loss(
            test_model, state.global_model_weights)
        for _ in range(5):
            fedavg_result = fedavg.next(state, client_data)
            state = fedavg_result.state
            metrics = fedavg_result.metrics
            loss = self._test_batch_loss(
                test_model, state.global_model_weights)
            self.assertLess(loss, last_loss)
            last_loss = loss

        self.assertIsInstance(state, composers.LearningAlgorithmState)
        self.assertLen(metrics, 4)
        for key in ['distributor', 'client_work', 'aggregator', 'finalizer']:
            self.assertIn(key, metrics)

    def test_created_model_raises(self):
        """Passing a model instance instead of a `model_fn` is a TypeError."""
        # Build the fixture outside `assertRaises` so that only the builder
        # call under test can satisfy the expectation.
        created_model = model_examples.LinearRegression()
        with self.assertRaises(TypeError):
            composers.build_basic_fedavg_process(created_model, 0.1)

    @parameterized.named_parameters(('int', 1),
                                    ('optimizer', sgdm.build_sgdm(0.1)))
    def test_wrong_client_learning_rate_raises(self, bad_client_lr):
        """An invalid client learning rate must raise a TypeError.

        A valid `model_fn` (the model class, not an instance) is passed, so
        the learning rate is the only invalid argument. Passing an
        instantiated model here would trigger the TypeError that
        `test_created_model_raises` covers, leaving the learning rate
        untested.

        Args:
          bad_client_lr: The invalid value passed as the client learning
            rate.
        """
        with self.assertRaises(TypeError):
            composers.build_basic_fedavg_process(
                model_examples.LinearRegression, bad_client_lr)