def test_executes_with(self, spec):
  weights = tf.nest.map_structure(lambda s: tf.ones(s.shape, s.dtype), spec)
  gradients = tf.nest.map_structure(lambda s: tf.ones(s.shape, s.dtype), spec)
  optimizer = adam.build_adam(0.01)
  state = optimizer.initialize(spec)
  for _ in range(10):
    state, weights = optimizer.next(state, weights, gradients)
  # `tf.reduce_all` works for tensors of any rank, unlike Python's `all`.
  tf.nest.map_structure(
      lambda w: self.assertTrue(tf.reduce_all(tf.math.is_finite(w))), weights)
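# The specs exercised by these tests are module-level fixtures not shown in
# this excerpt. A plausible reconstruction (hypothetical shapes; only the
# scalar/struct/nested distinction matters to the tests):
#
#   _SCALAR_SPEC = tf.TensorSpec([1], tf.float32)
#   _STRUCT_SPEC = [tf.TensorSpec([2], tf.float32),
#                   tf.TensorSpec([3], tf.float32)]
#   _NESTED_SPEC = [tf.TensorSpec([2], tf.float32),
#                   [tf.TensorSpec([3], tf.float32),
#                    [tf.TensorSpec([4], tf.float32)]]]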
def test_state_structure(self):
  optimizer = adam.build_adam(0.01)
  state = optimizer.initialize(_SCALAR_SPEC)
  self.assertLen(state, 7)
  self.assertIn(optimizer_base.LEARNING_RATE_KEY, state)
  self.assertIn(adam._BETA_1_KEY, state)
  self.assertIn(adam._BETA_2_KEY, state)
  self.assertIn(adam._EPSILON_KEY, state)
  self.assertIn(adam._STEP_KEY, state)
  self.assertIn(adam._PRECONDITIONER_KEY, state)
  self.assertIn(adam._ACCUMULATOR_KEY, state)
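# For orientation, the initialized Adam state checked above is expected to
# resemble the following (a sketch, assuming default hyperparameters;
# 'accumulator' holds the first moment m and 'preconditioner' the second
# moment v, both zero-initialized with the shape of the weights):
#
#   collections.OrderedDict(
#       learning_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-7,
#       step=0, accumulator=tf.zeros([1]), preconditioner=tf.zeros([1]))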
def test_convergence(self):
  init_w, fn, grad_fn = optimizer_test_utils.test_quadratic_problem()
  weights = init_w()
  self.assertGreater(fn(weights), 5.0)
  optimizer = adam.build_adam(0.5)
  state = optimizer.initialize(tf.TensorSpec(weights.shape, weights.dtype))
  for _ in range(100):
    gradients = grad_fn(weights)
    state, weights = optimizer.next(state, weights, gradients)
  self.assertLess(fn(weights), 0.005)
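# `optimizer_test_utils.test_quadratic_problem` is defined elsewhere; a
# minimal stand-in with the same interface (hypothetical dimensions and seed)
# would minimize f(w) = 0.5 * ||A w - b||^2, whose gradient is A^T (A w - b):
def _toy_quadratic_problem(dim=20, seed=2021):
  rng = tf.random.Generator.from_seed(seed)
  a = rng.normal([dim, dim])
  b = rng.normal([dim, 1])
  init_w = lambda: tf.ones([dim, 1])
  fn = lambda w: 0.5 * tf.reduce_sum(tf.square(tf.matmul(a, w) - b))
  grad_fn = lambda w: tf.matmul(a, tf.matmul(a, w) - b, transpose_a=True)
  return init_w, fn, grad_fn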
def test_math(self):
  weights = tf.constant([1.0], tf.float32)
  gradients = tf.constant([2.0], tf.float32)
  optimizer = adam.build_adam(0.1, beta_1=0.9, beta_2=0.999, epsilon=0.0)
  history = [weights]
  state = optimizer.initialize(_SCALAR_SPEC)
  for _ in range(4):
    state, weights = optimizer.next(state, weights, gradients)
    history.append(weights)
  self.assertAllClose(
      [[1.0], [0.9000007], [0.8000017], [0.700002], [0.600003]], history)
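# Hand-check of the expected values (standard bias-corrected Adam, which the
# expectations above are consistent with): with a constant gradient g = 2.0,
# the bias-corrected first moment is m_hat = g = 2.0 and the bias-corrected
# second moment is v_hat = g**2 = 4.0 at every step, so each round moves the
# weight by lr * m_hat / (sqrt(v_hat) + epsilon) = 0.1 * 2.0 / 2.0 = 0.1
# exactly; the trailing digits in the expected values (e.g. 0.9000007) are
# float32 rounding.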
class IntegrationTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('adagrad_scalar', adagrad.build_adagrad(0.1), _SCALAR_SPEC),
      ('adagrad_struct', adagrad.build_adagrad(0.1), _STRUCT_SPEC),
      ('adagrad_nested', adagrad.build_adagrad(0.1), _NESTED_SPEC),
      ('adam_scalar', adam.build_adam(0.1), _SCALAR_SPEC),
      ('adam_struct', adam.build_adam(0.1), _STRUCT_SPEC),
      ('adam_nested', adam.build_adam(0.1), _NESTED_SPEC),
      ('rmsprop_scalar', rmsprop.build_rmsprop(0.1), _SCALAR_SPEC),
      ('rmsprop_struct', rmsprop.build_rmsprop(0.1), _STRUCT_SPEC),
      ('rmsprop_nested', rmsprop.build_rmsprop(0.1), _NESTED_SPEC),
      ('scheduled_sgd_scalar', _scheduled_sgd(), _SCALAR_SPEC),
      ('scheduled_sgd_struct', _scheduled_sgd(), _STRUCT_SPEC),
      ('scheduled_sgd_nested', _scheduled_sgd(), _NESTED_SPEC),
      ('sgd_scalar', sgdm.build_sgdm(0.1), _SCALAR_SPEC),
      ('sgd_struct', sgdm.build_sgdm(0.1), _STRUCT_SPEC),
      ('sgd_nested', sgdm.build_sgdm(0.1), _NESTED_SPEC),
      ('sgdm_scalar', sgdm.build_sgdm(0.1, 0.9), _SCALAR_SPEC),
      ('sgdm_struct', sgdm.build_sgdm(0.1, 0.9), _STRUCT_SPEC),
      ('sgdm_nested', sgdm.build_sgdm(0.1, 0.9), _NESTED_SPEC),
      ('yogi_scalar', yogi.build_yogi(0.1), _SCALAR_SPEC),
      ('yogi_struct', yogi.build_yogi(0.1), _STRUCT_SPEC),
      ('yogi_nested', yogi.build_yogi(0.1), _NESTED_SPEC),
  )
  def test_integration_produces_identical_results(self, optimizer, spec):
    eager_history = _run_in_eager_mode(optimizer, spec)
    tf_comp_history = _run_in_tf_computation(optimizer, spec)
    federated_comp_history = _run_in_federated_computation(optimizer, spec)
    self.assertAllClose(eager_history, tf_comp_history, rtol=1e-5, atol=1e-5)
    self.assertAllClose(
        eager_history, federated_comp_history, rtol=1e-5, atol=1e-5)
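# The three runners compared above are module-level helpers not shown in this
# excerpt. A minimal sketch of the eager variant (hypothetical name and step
# count; `_run_in_tf_computation` and `_run_in_federated_computation` are
# assumed to wrap the same loop in the respective TFF computation contexts):
def _run_in_eager_mode_sketch(optimizer, spec, num_steps=5):
  weights = tf.nest.map_structure(lambda s: tf.ones(s.shape, s.dtype), spec)
  gradients = tf.nest.map_structure(lambda s: tf.ones(s.shape, s.dtype), spec)
  state = optimizer.initialize(spec)
  history = []
  for _ in range(num_steps):
    state, weights = optimizer.next(state, weights, gradients)
    history.append(weights)
  return history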
def test_executes_with_indexed_slices(self):
  # TF can represent gradients as tf.IndexedSlices. This test makes sure this
  # case is supported by the optimizer.
  weights = tf.ones([4, 2])
  gradients = tf.IndexedSlices(
      values=tf.constant([[1.0, 1.0], [1.0, 1.0]]),
      indices=tf.constant([0, 2]),
      dense_shape=tf.constant([4, 2]))
  # Always-zero preconditioner and accumulator, for simplicity of this test.
  optimizer = adam.build_adam(0.5, beta_1=0.0, beta_2=0.0)
  state = optimizer.initialize(tf.TensorSpec([4, 2]))
  _, weights = optimizer.next(state, weights, gradients)
  self.assertAllClose([[0.5, 0.5], [1.0, 1.0], [0.5, 0.5], [1.0, 1.0]],
                      weights)
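# Why [[0.5, 0.5], [1.0, 1.0], [0.5, 0.5], [1.0, 1.0]]: with
# beta_1 = beta_2 = 0.0 the moments reduce to m = g and v = g**2, so rows
# touched by the slices move by lr * g / (sqrt(g**2) + epsilon) ~= 0.5, while
# rows 1 and 3 receive a zero gradient and stay at 1.0. For intuition, the
# slices above densify as follows (a sketch; the optimizer may handle
# tf.IndexedSlices natively rather than densifying):
dense_gradient = tf.scatter_nd(
    indices=tf.constant([[0], [2]]),
    updates=tf.constant([[1.0, 1.0], [1.0, 1.0]]),
    shape=[4, 2])  # [[1., 1.], [0., 0.], [1., 1.], [0., 0.]]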
class ScheduledLROptimizerTest(parameterized.TestCase, tf.test.TestCase):

  def test_scheduled_sgd_computes_correctly(self):
    scheduled_sgd = scheduling.schedule_learning_rate(
        sgdm.build_sgdm(1.0), _example_schedule_fn)
    weight = tf.constant(1.0)
    gradient = tf.constant(1.0)
    state = scheduled_sgd.initialize(tf.TensorSpec((), tf.float32))
    state, weight = scheduled_sgd.next(state, weight, gradient)
    self.assertAllClose(0.9, weight)  # Learning rate initially 0.1.
    state, weight = scheduled_sgd.next(state, weight, gradient)
    self.assertAllClose(0.8, weight)
    state, weight = scheduled_sgd.next(state, weight, gradient)
    self.assertAllClose(0.79, weight)  # Learning rate has decreased to 0.01.
    state, weight = scheduled_sgd.next(state, weight, gradient)
    self.assertAllClose(0.78, weight)

  @parameterized.named_parameters(
      ('adagrad', adagrad.build_adagrad(1.0)),
      ('adam', adam.build_adam(1.0)),
      ('rmsprop', rmsprop.build_rmsprop(1.0)),
      ('sgd', sgdm.build_sgdm(1.0)),
      ('sgdm', sgdm.build_sgdm(1.0, momentum=0.9)),
      ('yogi', yogi.build_yogi(1.0)),
  )
  def test_schedule_learning_rate_integrates_with(self, optimizer):
    scheduled_optimizer = scheduling.schedule_learning_rate(
        optimizer, _example_schedule_fn)
    self.assertIsInstance(scheduled_optimizer, optimizer_base.Optimizer)

  def test_keras_optimizer_raises(self):
    keras_optimizer = tf.keras.optimizers.SGD(1.0)
    with self.assertRaises(TypeError):
      scheduling.schedule_learning_rate(keras_optimizer, _example_schedule_fn)

  def test_scheduling_scheduled_optimizer_raises(self):
    scheduled_optimizer = scheduling.schedule_learning_rate(
        sgdm.build_sgdm(1.0), _example_schedule_fn)
    twice_scheduled_optimizer = scheduling.schedule_learning_rate(
        scheduled_optimizer, _example_schedule_fn)
    with self.assertRaisesRegex(KeyError, 'must have learning rate'):
      twice_scheduled_optimizer.initialize(tf.TensorSpec((), tf.float32))
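# `_example_schedule_fn` is a module-level helper not shown in this excerpt.
# A schedule consistent with the weight trajectory asserted above (learning
# rate 0.1 for the first two rounds, 0.01 afterwards) would be, sketched under
# that assumption:
def _example_schedule_fn_sketch(round_num):
  return tf.cond(round_num < 2, lambda: 0.1, lambda: 0.01)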
def test_match_keras(self):
  weight_spec = [
      tf.TensorSpec([10, 2], tf.float32),
      tf.TensorSpec([2], tf.float32)
  ]
  steps = 10
  generator = tf.random.Generator.from_seed(2021)

  def random_vector():
    return [
        generator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
    ]

  initial_weight = random_vector()
  model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
  gradients = [random_vector() for _ in range(steps)]
  tff_optimizer_fn = lambda: adam.build_adam(0.01, 0.9, 0.999)
  keras_optimizer_fn = lambda: tf.keras.optimizers.Adam(0.01, 0.9, 0.999)
  self.assert_optimizers_numerically_close(model_variables_fn, gradients,
                                           tff_optimizer_fn,
                                           keras_optimizer_fn)
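# `assert_optimizers_numerically_close` comes from the shared optimizer test
# base class. It is assumed to drive both optimizers through the same gradient
# sequence and compare the resulting weights step by step, roughly:
#
#   variables = model_variables_fn()
#   keras_opt = keras_optimizer_fn()
#   tff_opt = tff_optimizer_fn()
#   state = tff_opt.initialize(weight_spec)
#   weights = [v.value() for v in variables]
#   for g in gradients:
#     state, weights = tff_opt.next(state, weights, g)
#     keras_opt.apply_gradients(zip(g, variables))
#     self.assertAllClose(weights, variables)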
class MimeLiteTest(tf.test.TestCase, parameterized.TestCase):
  """Tests construction of the Mime Lite training process."""

  def test_construction_calls_model_fn(self):
    # Assert that the process building does not call `model_fn` too many
    # times. `model_fn` can potentially be expensive (loading weights,
    # processing, etc.).
    mock_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
    mime.build_weighted_mime_lite(
        model_fn=mock_model_fn,
        base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9))
    self.assertEqual(mock_model_fn.call_count, 3)

  @parameterized.named_parameters(
      ('non-simulation_tff_optimizer', False),
      ('simulation_tff_optimizer', True),
  )
  @mock.patch.object(
      dataset_reduce,
      '_dataset_reduce_fn',
      wraps=dataset_reduce._dataset_reduce_fn)
  def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
    mime.build_weighted_mime_lite(
        model_fn=model_examples.LinearRegression,
        base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
        use_experimental_simulation_loop=simulation)
    if simulation:
      mock_method.assert_not_called()
    else:
      mock_method.assert_called()

  @mock.patch.object(mime, 'build_weighted_mime_lite')
  def test_build_weighted_mime_lite_called_by_unweighted_mime_lite(
      self, mock_mime_lite):
    mime.build_unweighted_mime_lite(
        model_fn=model_examples.LinearRegression,
        base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9))
    self.assertEqual(mock_mime_lite.call_count, 1)

  @mock.patch.object(mime, 'build_weighted_mime_lite')
  @mock.patch.object(factory_utils, 'as_weighted_aggregator')
  def test_aggregation_wrapper_called_by_unweighted(self, _, mock_as_weighted):
    mime.build_unweighted_mime_lite(
        model_fn=model_examples.LinearRegression,
        base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9))
    self.assertEqual(mock_as_weighted.call_count, 1)

  def test_raises_on_non_callable_model_fn(self):
    with self.assertRaises(TypeError):
      mime.build_weighted_mime_lite(
          model_fn=model_examples.LinearRegression(),
          base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9))

  def test_raises_on_invalid_client_weighting(self):
    with self.assertRaises(TypeError):
      mime.build_weighted_mime_lite(
          model_fn=model_examples.LinearRegression,
          base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
          client_weighting='uniform')

  def test_raises_on_invalid_distributor(self):
    model_weights_type = type_conversions.type_from_tensors(
        model_utils.ModelWeights.from_model(model_examples.LinearRegression()))
    distributor = distributors.build_broadcast_process(model_weights_type)
    invalid_distributor = iterative_process.IterativeProcess(
        distributor.initialize, distributor.next)
    with self.assertRaises(TypeError):
      mime.build_weighted_mime_lite(
          model_fn=model_examples.LinearRegression,
          base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
          model_distributor=invalid_distributor)

  def test_weighted_mime_lite_raises_on_unweighted_aggregator(self):
    aggregator = model_update_aggregator.robust_aggregator(weighted=False)
    with self.assertRaisesRegex(TypeError, 'WeightedAggregationFactory'):
      mime.build_weighted_mime_lite(
          model_fn=model_examples.LinearRegression,
          base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
          model_aggregator=aggregator)
    with self.assertRaisesRegex(TypeError, 'WeightedAggregationFactory'):
      mime.build_weighted_mime_lite(
          model_fn=model_examples.LinearRegression,
          base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
          full_gradient_aggregator=aggregator)

  def test_unweighted_mime_lite_raises_on_weighted_aggregator(self):
    aggregator = model_update_aggregator.robust_aggregator(weighted=True)
    with self.assertRaisesRegex(TypeError, 'UnweightedAggregationFactory'):
      mime.build_unweighted_mime_lite(
          model_fn=model_examples.LinearRegression,
          base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
          model_aggregator=aggregator)
    with self.assertRaisesRegex(TypeError, 'UnweightedAggregationFactory'):
      mime.build_unweighted_mime_lite(
          model_fn=model_examples.LinearRegression,
          base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
          full_gradient_aggregator=aggregator)

  def test_weighted_mime_lite_with_only_secure_aggregation(self):
    aggregator = model_update_aggregator.secure_aggregator(weighted=True)
    learning_process = mime.build_weighted_mime_lite(
        model_examples.LinearRegression,
        base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
        model_aggregator=aggregator,
        full_gradient_aggregator=aggregator,
        metrics_aggregator=metrics_aggregator.secure_sum_then_finalize)
    static_assert.assert_not_contains_unsecure_aggregation(
        learning_process.next)

  def test_unweighted_mime_lite_with_only_secure_aggregation(self):
    aggregator = model_update_aggregator.secure_aggregator(weighted=False)
    learning_process = mime.build_unweighted_mime_lite(
        model_examples.LinearRegression,
        base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
        model_aggregator=aggregator,
        full_gradient_aggregator=aggregator,
        metrics_aggregator=metrics_aggregator.secure_sum_then_finalize)
    static_assert.assert_not_contains_unsecure_aggregation(
        learning_process.next)

  @tensorflow_test_utils.skip_test_for_multi_gpu
  def test_equivalent_to_vanilla_fed_avg(self):
    # Mime Lite with no-momentum SGD should reduce to FedAvg.
    mime_process = mime.build_weighted_mime_lite(
        model_fn=_create_model, base_optimizer=sgdm.build_sgdm(0.1))
    fed_avg_process = fed_avg.build_weighted_fed_avg(
        model_fn=_create_model, client_optimizer_fn=sgdm.build_sgdm(0.1))
    client_data = [_create_dataset()]
    mime_state = mime_process.initialize()
    fed_avg_state = fed_avg_process.initialize()
    for _ in range(3):
      mime_output = mime_process.next(mime_state, client_data)
      mime_state = mime_output.state
      mime_metrics = mime_output.metrics
      fed_avg_output = fed_avg_process.next(fed_avg_state, client_data)
      fed_avg_state = fed_avg_output.state
      fed_avg_metrics = fed_avg_output.metrics
      self.assertAllClose(
          tf.nest.flatten(mime_process.get_model_weights(mime_state)),
          tf.nest.flatten(fed_avg_process.get_model_weights(fed_avg_state)))
      self.assertAllClose(mime_metrics['client_work']['train']['loss'],
                          fed_avg_metrics['client_work']['train']['loss'])
      self.assertAllClose(
          mime_metrics['client_work']['train']['num_examples'],
          fed_avg_metrics['client_work']['train']['num_examples'])

  @parameterized.named_parameters(
      ('sgdm_sgd', sgdm.build_sgdm(0.1, 0.9), sgdm.build_sgdm(1.0)),
      ('sgdm_sgdm', sgdm.build_sgdm(0.1, 0.9), sgdm.build_sgdm(1.0, 0.9)),
      ('sgdm_adam', sgdm.build_sgdm(0.1, 0.9), adam.build_adam(1.0)),
      ('adagrad_sgdm', adagrad.build_adagrad(0.1), sgdm.build_sgdm(1.0, 0.9)),
  )
  @tensorflow_test_utils.skip_test_for_multi_gpu
  def test_execution_with_optimizers(self, base_optimizer, server_optimizer):
    learning_process = mime.build_weighted_mime_lite(
        model_fn=_create_model,
        base_optimizer=base_optimizer,
        server_optimizer=server_optimizer)
    client_data = [_create_dataset()]
    state = learning_process.initialize()
    for _ in range(3):
      output = learning_process.next(state, client_data)
      state = output.state
      metrics = output.metrics
      self.assertEqual(8, metrics['client_work']['train']['num_examples'])
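# `_create_model`, `_create_dataset` and `_initial_weights` are module-level
# fixtures not shown in this excerpt. A plausible reconstruction of the
# dataset, consistent with the num_examples == 8 assertions above
# (hypothetical values and batching):
def _create_dataset_sketch():
  # Eight examples for the two-feature LinearRegression example model.
  x = [[float(i), 0.0] for i in range(8)]
  y = [[float(i)] for i in range(8)]
  return tf.data.Dataset.from_tensor_slices(
      collections.OrderedDict(x=x, y=y)).batch(2)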
class MimeLiteClientWorkExecutionTest(tf.test.TestCase,
                                      parameterized.TestCase):

  @parameterized.named_parameters(('non-simulation', False),
                                  ('simulation', True))
  @mock.patch.object(
      dataset_reduce,
      '_dataset_reduce_fn',
      wraps=dataset_reduce._dataset_reduce_fn)
  @tensorflow_test_utils.skip_test_for_multi_gpu
  def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
    process = mime._build_mime_lite_client_work(
        model_fn=_create_model,
        optimizer=sgdm.build_sgdm(learning_rate=0.1, momentum=0.9),
        client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
        use_experimental_simulation_loop=simulation)
    client_data = [_create_dataset()]
    client_model_weights = [_initial_weights()]
    process.next(process.initialize(), client_model_weights, client_data)
    if simulation:
      mock_method.assert_not_called()
    else:
      mock_method.assert_called()

  @parameterized.named_parameters(
      ('adagrad', adagrad.build_adagrad(0.1)),
      ('adam', adam.build_adam(0.1)),
      ('rmsprop', rmsprop.build_rmsprop(0.1)),
      ('sgd', sgdm.build_sgdm(0.1)),
      ('sgdm', sgdm.build_sgdm(0.1, momentum=0.9)),
      ('yogi', yogi.build_yogi(0.1)))
  @tensorflow_test_utils.skip_test_for_multi_gpu
  def test_execution_with_optimizer(self, optimizer):
    process = mime._build_mime_lite_client_work(
        _create_model,
        optimizer,
        client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES)
    client_data = [_create_dataset()]
    client_model_weights = [_initial_weights()]
    state = process.initialize()
    output = process.next(state, client_model_weights, client_data)
    self.assertEqual(8, output.measurements['train']['num_examples'])

  @tensorflow_test_utils.skip_test_for_multi_gpu
  def test_custom_metrics_aggregator(self):

    def sum_then_finalize_then_times_two(metric_finalizers,
                                         local_unfinalized_metrics_type):

      @federated_computation.federated_computation(
          computation_types.at_clients(local_unfinalized_metrics_type))
      def aggregation_computation(client_local_unfinalized_metrics):
        unfinalized_metrics_sum = intrinsics.federated_sum(
            client_local_unfinalized_metrics)

        @tensorflow_computation.tf_computation(local_unfinalized_metrics_type)
        def finalizer_computation(unfinalized_metrics):
          finalized_metrics = collections.OrderedDict()
          for metric_name, metric_finalizer in metric_finalizers.items():
            finalized_metrics[metric_name] = metric_finalizer(
                unfinalized_metrics[metric_name]) * 2
          return finalized_metrics

        return intrinsics.federated_map(finalizer_computation,
                                        unfinalized_metrics_sum)

      return aggregation_computation

    process = mime._build_mime_lite_client_work(
        model_fn=_create_model,
        optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
        client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
        metrics_aggregator=sum_then_finalize_then_times_two)
    client_model_weights = [_initial_weights()]
    client_data = [_create_dataset()]
    output = process.next(process.initialize(), client_model_weights,
                          client_data)
    # Train metrics should be multiplied by two by the custom aggregator.
    self.assertEqual(output.measurements['train']['num_examples'], 16)
def test_initialize_next_weights_mismatch_raises(self):
  optimizer = adam.build_adam(0.1)
  state = optimizer.initialize(_SCALAR_SPEC)
  with self.assertRaises(ValueError):
    optimizer.next(state, tf.zeros([2]), tf.zeros([2]))
def test_invalid_args_raises(self, lr, beta_1, beta_2, epsilon, regex):
  with self.assertRaisesRegex(ValueError, regex):
    adam.build_adam(lr, beta_1, beta_2, epsilon)
def test_build_adam(self):
  optimizer = adam.build_adam(0.01)
  self.assertIsInstance(optimizer, optimizer_base.Optimizer)