Example #1
0
 def test_state_structure(self):
   """Checks the freshly initialized Adagrad state has the expected entries.

   The state must have length 3 and contain the learning rate, epsilon and
   preconditioner keys.
   """
   initial_state = adagrad.build_adagrad(0.01).initialize(_SCALAR_SPEC)
   self.assertLen(initial_state, 3)
   expected_keys = (
       optimizer_base.LEARNING_RATE_KEY,
       adagrad._EPSILON_KEY,
       adagrad._PRECONDITIONER_KEY,
   )
   for key in expected_keys:
     self.assertIn(key, initial_state)
Example #2
0
  def test_executes_with(self, spec):
    """Runs ten Adagrad steps on all-ones inputs and checks weights stay finite.

    Args:
      spec: A (possibly nested) structure of specs exposing `shape` and `dtype`,
        describing the weights to optimize.
    """
    ones_like_spec = lambda s: tf.ones(s.shape, s.dtype)
    weights = tf.nest.map_structure(ones_like_spec, spec)
    gradients = tf.nest.map_structure(ones_like_spec, spec)
    optimizer = adagrad.build_adagrad(0.01)

    state = optimizer.initialize(spec)
    for _ in range(10):
      state, weights = optimizer.next(state, weights, gradients)

    # `tf.reduce_all` collapses a tensor of any rank to a single boolean. The
    # builtin `all` iterates along the first axis only, so it cannot handle
    # rank-0 tensors and is ambiguous for tensors of rank 2 or higher.
    tf.nest.map_structure(
        lambda w: self.assertTrue(tf.reduce_all(tf.math.is_finite(w))),
        weights)
Example #3
0
  def test_convergence(self):
    """Checks Adagrad drives the test quadratic problem's value below 0.005.

    The objective must start above 5.0 and, after 100 optimizer steps with
    learning rate 0.5, end below 0.005.
    """
    initial_weights_fn, objective_fn, gradient_fn = (
        optimizer_test_utils.test_quadratic_problem())
    weights = initial_weights_fn()
    # Sanity check: the starting point must be far from the optimum.
    self.assertGreater(objective_fn(weights), 5.0)

    optimizer = adagrad.build_adagrad(0.5)
    weights_spec = tf.TensorSpec(weights.shape, weights.dtype)
    state = optimizer.initialize(weights_spec)

    num_steps = 100
    for _ in range(num_steps):
      state, weights = optimizer.next(state, weights, gradient_fn(weights))

    self.assertLess(objective_fn(weights), 0.005)
Example #4
0
  def test_executes_with_indexed_slices(self):
    """Applies one Adagrad step to a gradient given as `tf.IndexedSlices`."""
    # TF can represent gradients as tf.IndexedSlices. This test makes sure this
    # case is supported by the optimizer.
    dense_shape = [4, 2]
    weights = tf.ones(dense_shape)
    sparse_gradients = tf.IndexedSlices(
        values=tf.constant([[1.0, 1.0], [1.0, 1.0]]),
        indices=tf.constant([0, 2]),
        dense_shape=tf.constant(dense_shape))
    optimizer = adagrad.build_adagrad(0.5, initial_preconditioner_value=0.0)

    initial_state = optimizer.initialize(tf.TensorSpec(dense_shape))
    _, updated_weights = optimizer.next(initial_state, weights,
                                        sparse_gradients)

    # Only rows 0 and 2 are named by the gradient's indices; the expected
    # values leave rows 1 and 3 at their initial value of 1.0.
    expected_weights = [[0.5, 0.5], [1.0, 1.0], [0.5, 0.5], [1.0, 1.0]]
    self.assertAllClose(expected_weights, updated_weights)
Example #5
0
class IntegrationTest(tf.test.TestCase, parameterized.TestCase):
    """Compares optimizer histories produced by the three `_run_in_*` helpers."""

    @parameterized.named_parameters(
        ('adagrad_scalar', adagrad.build_adagrad(0.1), _SCALAR_SPEC),
        ('adagrad_struct', adagrad.build_adagrad(0.1), _STRUCT_SPEC),
        ('adagrad_nested', adagrad.build_adagrad(0.1), _NESTED_SPEC),
        ('adam_scalar', adam.build_adam(0.1), _SCALAR_SPEC),
        ('adam_struct', adam.build_adam(0.1), _STRUCT_SPEC),
        ('adam_nested', adam.build_adam(0.1), _NESTED_SPEC),
        ('rmsprop_scalar', rmsprop.build_rmsprop(0.1), _SCALAR_SPEC),
        ('rmsprop_struct', rmsprop.build_rmsprop(0.1), _STRUCT_SPEC),
        ('rmsprop_nested', rmsprop.build_rmsprop(0.1), _NESTED_SPEC),
        ('scheduled_sgd_scalar', _scheduled_sgd(), _SCALAR_SPEC),
        ('scheduled_sgd_struct', _scheduled_sgd(), _STRUCT_SPEC),
        ('scheduled_sgd_nested', _scheduled_sgd(), _NESTED_SPEC),
        ('sgd_scalar', sgdm.build_sgdm(0.1), _SCALAR_SPEC),
        ('sgd_struct', sgdm.build_sgdm(0.1), _STRUCT_SPEC),
        ('sgd_nested', sgdm.build_sgdm(0.1), _NESTED_SPEC),
        ('sgdm_scalar', sgdm.build_sgdm(0.1, 0.9), _SCALAR_SPEC),
        ('sgdm_struct', sgdm.build_sgdm(0.1, 0.9), _STRUCT_SPEC),
        ('sgdm_nested', sgdm.build_sgdm(0.1, 0.9), _NESTED_SPEC),
        ('yogi_scalar', yogi.build_yogi(0.1), _SCALAR_SPEC),
        ('yogi_struct', yogi.build_yogi(0.1), _STRUCT_SPEC),
        ('yogi_nested', yogi.build_yogi(0.1), _NESTED_SPEC),
    )
    def test_integration_produces_identical_results(self, optimizer, spec):
        """Checks all execution contexts yield numerically close histories.

        Args:
          optimizer: The optimizer under test.
          spec: The weights spec passed to each `_run_in_*` helper.
        """
        reference_history = _run_in_eager_mode(optimizer, spec)
        # Both alternative histories are computed before any comparison runs.
        other_histories = (
            _run_in_tf_computation(optimizer, spec),
            _run_in_federated_computation(optimizer, spec),
        )

        tolerance = 1e-5
        for history in other_histories:
            self.assertAllClose(reference_history,
                                history,
                                rtol=tolerance,
                                atol=tolerance)
Example #6
0
class ScheduledLROptimizerTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for optimizers wrapped by `scheduling.schedule_learning_rate`."""

  def test_scheduled_sgd_computes_correctly(self):
    """Steps a scheduled SGD four times and checks the weight after each step."""
    scheduled_sgd = scheduling.schedule_learning_rate(
        sgdm.build_sgdm(1.0), _example_schedule_fn)

    weight = tf.constant(1.0)
    gradient = tf.constant(1.0)
    state = scheduled_sgd.initialize(tf.TensorSpec((), tf.float32))

    expected_weights = (
        0.9,  # Learning rate initially 0.1.
        0.8,
        0.79,  # Learning rate has decreased to 0.01.
        0.78,
    )
    for expected_weight in expected_weights:
      state, weight = scheduled_sgd.next(state, weight, gradient)
      self.assertAllClose(expected_weight, weight)

  @parameterized.named_parameters(
      ('adagrad', adagrad.build_adagrad(1.0)),
      ('adam', adam.build_adam(1.0)),
      ('rmsprop', rmsprop.build_rmsprop(1.0)),
      ('sgd', sgdm.build_sgdm(1.0)),
      ('sgdm', sgdm.build_sgdm(1.0, momentum=0.9)),
      ('yogi', yogi.build_yogi(1.0)),
  )
  def test_schedule_learning_rate_integrates_with(self, optimizer):
    """Checks that scheduling `optimizer` still yields an `Optimizer`."""
    wrapped = scheduling.schedule_learning_rate(optimizer,
                                                _example_schedule_fn)
    self.assertIsInstance(wrapped, optimizer_base.Optimizer)

  def test_keras_optimizer_raises(self):
    """Checks that scheduling a Keras optimizer raises `TypeError`."""
    keras_sgd = tf.keras.optimizers.SGD(1.0)
    with self.assertRaises(TypeError):
      scheduling.schedule_learning_rate(keras_sgd, _example_schedule_fn)

  def test_scheduling_scheduled_optimizer_raises(self):
    """Checks that a doubly scheduled optimizer fails when initialized."""
    once_scheduled = scheduling.schedule_learning_rate(
        sgdm.build_sgdm(1.0), _example_schedule_fn)
    # Wrapping a second time is accepted here; the error is expected from
    # `initialize`.
    twice_scheduled = scheduling.schedule_learning_rate(
        once_scheduled, _example_schedule_fn)
    scalar_spec = tf.TensorSpec((), tf.float32)
    with self.assertRaisesRegex(KeyError, 'must have learning rate'):
      twice_scheduled.initialize(scalar_spec)
Example #7
0
  def test_math_no_momentum(self):
    """Checks four Adagrad steps against hand-computed weight values.

    Uses a constant gradient of 2.0, learning rate 0.01, and zero initial
    preconditioner and epsilon.
    """
    expected_history = [
        [1.0],  # w0
        [0.99],  # w1 = w0 - 0.01 * 2.0 / sqrt(4)
        [0.9829289],  # w2 = w1 - 0.01 * 2.0 / sqrt(8)
        [0.9771554],  # w3 = w2 - 0.01 * 2.0 / sqrt(12)
        [0.9721554],  # w4 = w3 - 0.01 * 2.0 / sqrt(16)
    ]
    weights = tf.constant([1.0], tf.float32)
    gradients = tf.constant([2.0], tf.float32)
    optimizer = adagrad.build_adagrad(
        learning_rate=0.01, initial_preconditioner_value=0.0, epsilon=0.0)
    state = optimizer.initialize(_SCALAR_SPEC)

    # Record the initial weights, then the weights after every step.
    observed_history = [weights]
    num_steps = 4
    for _ in range(num_steps):
      state, weights = optimizer.next(state, weights, gradients)
      observed_history.append(weights)

    self.assertAllClose(expected_history, observed_history)
Example #8
0
  def test_match_keras(self):
    """Compares Adagrad with `tf.keras.optimizers.Adagrad` on random inputs.

    Initial weights and ten steps of gradients are drawn from a generator
    seeded with 2021 and handed to `assert_optimizers_numerically_close`.
    """
    weight_spec = [
        tf.TensorSpec([10, 2], tf.float32),
        tf.TensorSpec([2], tf.float32)
    ]
    steps = 10
    generator = tf.random.Generator.from_seed(2021)

    def random_vector():
      """Draws one normal sample per entry of `weight_spec`."""
      return [
          generator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
      ]

    # The generator is stateful, so initial weights are drawn before gradients.
    initial_weights = random_vector()
    gradients = [random_vector() for _ in range(steps)]

    def model_variables_fn():
      return [tf.Variable(v) for v in initial_weights]

    def tff_optimizer_fn():
      return adagrad.build_adagrad(0.01)

    def keras_optimizer_fn():
      return tf.keras.optimizers.Adagrad(0.01)

    self.assert_optimizers_numerically_close(model_variables_fn, gradients,
                                             tff_optimizer_fn,
                                             keras_optimizer_fn)
Example #9
0
class MimeLiteTest(tf.test.TestCase, parameterized.TestCase):
    """Tests construction of the Mime Lite training process."""

    def test_construction_calls_model_fn(self):
        """Checks that building the process calls `model_fn` three times."""
        # Assert that the process building does not call `model_fn` too many
        # times. `model_fn` can potentially be expensive (loading weights,
        # processing, etc.).
        mock_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
        mime.build_weighted_mime_lite(model_fn=mock_model_fn,
                                      base_optimizer=sgdm.build_sgdm(
                                          learning_rate=0.01, momentum=0.9))
        self.assertEqual(mock_model_fn.call_count, 3)

    @parameterized.named_parameters(
        ('non-simulation_tff_optimizer', False),
        ('simulation_tff_optimizer', True),
    )
    @mock.patch.object(dataset_reduce,
                       '_dataset_reduce_fn',
                       wraps=dataset_reduce._dataset_reduce_fn)
    def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
        """Checks `_dataset_reduce_fn` is used only outside simulation mode.

        Args:
          simulation: Value passed as `use_experimental_simulation_loop`.
          mock_method: Mock wrapping `dataset_reduce._dataset_reduce_fn`.
        """
        mime.build_weighted_mime_lite(
            model_fn=model_examples.LinearRegression,
            base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
            use_experimental_simulation_loop=simulation)
        if simulation:
            mock_method.assert_not_called()
        else:
            mock_method.assert_called()

    @mock.patch.object(mime, 'build_weighted_mime_lite')
    def test_build_weighted_mime_lite_called_by_unweighted_mime_lite(
            self, mock_mime_lite):
        """Checks the unweighted builder calls the weighted builder once."""
        mime.build_unweighted_mime_lite(
            model_fn=model_examples.LinearRegression,
            base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9))
        self.assertEqual(mock_mime_lite.call_count, 1)

    @mock.patch.object(mime, 'build_weighted_mime_lite')
    @mock.patch.object(factory_utils, 'as_weighted_aggregator')
    def test_aggregation_wrapper_called_by_unweighted(self, mock_as_weighted,
                                                      _):
        """Checks the unweighted builder calls `as_weighted_aggregator`."""
        # Stacked `mock.patch` decorators inject mocks bottom-up: the decorator
        # closest to the function (`as_weighted_aggregator`) supplies the first
        # mock argument and the outer one (`build_weighted_mime_lite`) the
        # second.
        mime.build_unweighted_mime_lite(
            model_fn=model_examples.LinearRegression,
            base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9))
        # Only require that the wrapper ran; how many aggregators the builder
        # wraps is not something this test pins down.
        mock_as_weighted.assert_called()

    def test_raises_on_non_callable_model_fn(self):
        """Checks passing a model instance as `model_fn` raises `TypeError`."""
        with self.assertRaises(TypeError):
            mime.build_weighted_mime_lite(
                model_fn=model_examples.LinearRegression(),
                base_optimizer=sgdm.build_sgdm(learning_rate=0.01,
                                               momentum=0.9))

    def test_raises_on_invalid_client_weighting(self):
        """Checks a string `client_weighting` raises `TypeError`."""
        with self.assertRaises(TypeError):
            mime.build_weighted_mime_lite(
                model_fn=model_examples.LinearRegression,
                base_optimizer=sgdm.build_sgdm(learning_rate=0.01,
                                               momentum=0.9),
                client_weighting='uniform')

    def test_raises_on_invalid_distributor(self):
        """Checks a plain `IterativeProcess` distributor raises `TypeError`."""
        model_weights_type = type_conversions.type_from_tensors(
            model_utils.ModelWeights.from_model(
                model_examples.LinearRegression()))
        distributor = distributors.build_broadcast_process(model_weights_type)
        # Re-wrap the distributor's computations in a bare `IterativeProcess`.
        invalid_distributor = iterative_process.IterativeProcess(
            distributor.initialize, distributor.next)
        with self.assertRaises(TypeError):
            mime.build_weighted_mime_lite(
                model_fn=model_examples.LinearRegression,
                base_optimizer=sgdm.build_sgdm(learning_rate=0.01,
                                               momentum=0.9),
                model_distributor=invalid_distributor)

    def test_weighted_mime_lite_raises_on_unweighted_aggregator(self):
        """Checks the weighted builder rejects an unweighted aggregator.

        The aggregator is rejected both as `model_aggregator` and as
        `full_gradient_aggregator`.
        """
        aggregator = model_update_aggregator.robust_aggregator(weighted=False)
        with self.assertRaisesRegex(TypeError, 'WeightedAggregationFactory'):
            mime.build_weighted_mime_lite(
                model_fn=model_examples.LinearRegression,
                base_optimizer=sgdm.build_sgdm(learning_rate=0.01,
                                               momentum=0.9),
                model_aggregator=aggregator)
        with self.assertRaisesRegex(TypeError, 'WeightedAggregationFactory'):
            mime.build_weighted_mime_lite(
                model_fn=model_examples.LinearRegression,
                base_optimizer=sgdm.build_sgdm(learning_rate=0.01,
                                               momentum=0.9),
                full_gradient_aggregator=aggregator)

    def test_unweighted_mime_lite_raises_on_weighted_aggregator(self):
        """Checks the unweighted builder rejects a weighted aggregator.

        The aggregator is rejected both as `model_aggregator` and as
        `full_gradient_aggregator`.
        """
        aggregator = model_update_aggregator.robust_aggregator(weighted=True)
        with self.assertRaisesRegex(TypeError, 'UnweightedAggregationFactory'):
            mime.build_unweighted_mime_lite(
                model_fn=model_examples.LinearRegression,
                base_optimizer=sgdm.build_sgdm(learning_rate=0.01,
                                               momentum=0.9),
                model_aggregator=aggregator)
        with self.assertRaisesRegex(TypeError, 'UnweightedAggregationFactory'):
            mime.build_unweighted_mime_lite(
                model_fn=model_examples.LinearRegression,
                base_optimizer=sgdm.build_sgdm(learning_rate=0.01,
                                               momentum=0.9),
                full_gradient_aggregator=aggregator)

    def test_weighted_mime_lite_with_only_secure_aggregation(self):
        """Checks the weighted process has no unsecure aggregation in `next`."""
        aggregator = model_update_aggregator.secure_aggregator(weighted=True)
        learning_process = mime.build_weighted_mime_lite(
            model_examples.LinearRegression,
            base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
            model_aggregator=aggregator,
            full_gradient_aggregator=aggregator,
            metrics_aggregator=metrics_aggregator.secure_sum_then_finalize)
        static_assert.assert_not_contains_unsecure_aggregation(
            learning_process.next)

    def test_unweighted_mime_lite_with_only_secure_aggregation(self):
        """Checks the unweighted process has no unsecure aggregation in `next`."""
        aggregator = model_update_aggregator.secure_aggregator(weighted=False)
        learning_process = mime.build_unweighted_mime_lite(
            model_examples.LinearRegression,
            base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
            model_aggregator=aggregator,
            full_gradient_aggregator=aggregator,
            metrics_aggregator=metrics_aggregator.secure_sum_then_finalize)
        static_assert.assert_not_contains_unsecure_aggregation(
            learning_process.next)

    @tensorflow_test_utils.skip_test_for_multi_gpu
    def test_equivalent_to_vanilla_fed_avg(self):
        """Compares Mime Lite with FedAvg over three rounds on one client.

        After each round the model weights, training loss and number of
        examples of both processes must be close.
        """
        # Mime Lite with no-momentum SGD should reduce to FedAvg.
        mime_process = mime.build_weighted_mime_lite(
            model_fn=_create_model, base_optimizer=sgdm.build_sgdm(0.1))
        fed_avg_process = fed_avg.build_weighted_fed_avg(
            model_fn=_create_model, client_optimizer_fn=sgdm.build_sgdm(0.1))

        client_data = [_create_dataset()]
        mime_state = mime_process.initialize()
        fed_avg_state = fed_avg_process.initialize()

        for _ in range(3):
            mime_output = mime_process.next(mime_state, client_data)
            mime_state = mime_output.state
            mime_metrics = mime_output.metrics
            fed_avg_output = fed_avg_process.next(fed_avg_state, client_data)
            fed_avg_state = fed_avg_output.state
            fed_avg_metrics = fed_avg_output.metrics
            self.assertAllClose(
                tf.nest.flatten(mime_process.get_model_weights(mime_state)),
                tf.nest.flatten(
                    fed_avg_process.get_model_weights(fed_avg_state)))
            self.assertAllClose(
                mime_metrics['client_work']['train']['loss'],
                fed_avg_metrics['client_work']['train']['loss'])
            self.assertAllClose(
                mime_metrics['client_work']['train']['num_examples'],
                fed_avg_metrics['client_work']['train']['num_examples'])

    @parameterized.named_parameters(
        ('sgdm_sgd', sgdm.build_sgdm(0.1, 0.9), sgdm.build_sgdm(1.0)),
        ('sgdm_sgdm', sgdm.build_sgdm(0.1, 0.9), sgdm.build_sgdm(1.0, 0.9)),
        ('sgdm_adam', sgdm.build_sgdm(0.1, 0.9), adam.build_adam(1.0)),
        ('adagrad_sgdm', adagrad.build_adagrad(0.1), sgdm.build_sgdm(1.0,
                                                                     0.9)),
    )
    @tensorflow_test_utils.skip_test_for_multi_gpu
    def test_execution_with_optimizers(self, base_optimizer, server_optimizer):
        """Runs three rounds and checks each reports 8 training examples.

        Args:
          base_optimizer: Optimizer passed as `base_optimizer`.
          server_optimizer: Optimizer passed as `server_optimizer`.
        """
        learning_process = mime.build_weighted_mime_lite(
            model_fn=_create_model,
            base_optimizer=base_optimizer,
            server_optimizer=server_optimizer)

        client_data = [_create_dataset()]
        state = learning_process.initialize()

        for _ in range(3):
            output = learning_process.next(state, client_data)
            state = output.state
            metrics = output.metrics
            self.assertEqual(8,
                             metrics['client_work']['train']['num_examples'])
Example #10
0
class MimeLiteClientWorkExecutionTest(tf.test.TestCase,
                                      parameterized.TestCase):
    """Executes the process built by `mime._build_mime_lite_client_work`."""

    @parameterized.named_parameters(('non-simulation', False),
                                    ('simulation', True))
    @mock.patch.object(dataset_reduce,
                       '_dataset_reduce_fn',
                       wraps=dataset_reduce._dataset_reduce_fn)
    @tensorflow_test_utils.skip_test_for_multi_gpu
    def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
        """Checks `_dataset_reduce_fn` is used only outside simulation mode.

        Args:
          simulation: Value passed as `use_experimental_simulation_loop`.
          mock_method: Mock wrapping `dataset_reduce._dataset_reduce_fn`.
        """
        client_work = mime._build_mime_lite_client_work(
            model_fn=_create_model,
            optimizer=sgdm.build_sgdm(learning_rate=0.1, momentum=0.9),
            client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
            use_experimental_simulation_loop=simulation)
        datasets = [_create_dataset()]
        weights_per_client = [_initial_weights()]
        initial_state = client_work.initialize()
        client_work.next(initial_state, weights_per_client, datasets)

        if not simulation:
            mock_method.assert_called()
        else:
            mock_method.assert_not_called()

    @parameterized.named_parameters(
        ('adagrad', adagrad.build_adagrad(0.1)),
        ('adam', adam.build_adam(0.1)),
        ('rmsprop', rmsprop.build_rmsprop(0.1)),
        ('sgd', sgdm.build_sgdm(0.1)),
        ('sgdm', sgdm.build_sgdm(0.1, momentum=0.9)),
        ('yogi', yogi.build_yogi(0.1)),
    )
    @tensorflow_test_utils.skip_test_for_multi_gpu
    def test_execution_with_optimizer(self, optimizer):
        """Runs one round with `optimizer` and checks 8 examples are reported.

        Args:
          optimizer: The optimizer handed to the client work builder.
        """
        client_work = mime._build_mime_lite_client_work(
            _create_model,
            optimizer,
            client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES)
        datasets = [_create_dataset()]
        weights_per_client = [_initial_weights()]
        initial_state = client_work.initialize()
        result = client_work.next(initial_state, weights_per_client, datasets)
        self.assertEqual(8, result.measurements['train']['num_examples'])

    @tensorflow_test_utils.skip_test_for_multi_gpu
    def test_custom_metrics_aggregator(self):
        """Checks a custom `metrics_aggregator` shapes the reported metrics."""

        def sum_then_finalize_then_times_two(metric_finalizers,
                                             local_unfinalized_metrics_type):
            """Builds an aggregator that sums, finalizes, then doubles metrics."""

            @federated_computation.federated_computation(
                computation_types.at_clients(local_unfinalized_metrics_type))
            def aggregation_computation(client_local_unfinalized_metrics):
                summed_metrics = intrinsics.federated_sum(
                    client_local_unfinalized_metrics)

                @tensorflow_computation.tf_computation(
                    local_unfinalized_metrics_type)
                def finalizer_computation(unfinalized_metrics):
                    # Finalize each metric and multiply it by two, keeping the
                    # iteration order of `metric_finalizers`.
                    return collections.OrderedDict(
                        (name, finalize(unfinalized_metrics[name]) * 2)
                        for name, finalize in metric_finalizers.items())

                return intrinsics.federated_map(finalizer_computation,
                                                summed_metrics)

            return aggregation_computation

        client_work = mime._build_mime_lite_client_work(
            model_fn=_create_model,
            optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
            client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
            metrics_aggregator=sum_then_finalize_then_times_two)
        weights_per_client = [_initial_weights()]
        datasets = [_create_dataset()]
        initial_state = client_work.initialize()
        result = client_work.next(initial_state, weights_per_client, datasets)
        # Train metrics should be multiplied by two by the custom aggregator.
        self.assertEqual(result.measurements['train']['num_examples'], 16)
Example #11
0
 def test_initialize_next_weights_mismatch_raises(self):
   """Checks `next` raises `ValueError` for shape-[2] weights and gradients.

   The state is initialized from `_SCALAR_SPEC`; the tensors passed to `next`
   are expected not to match it.
   """
   optimizer = adagrad.build_adagrad(0.1)
   initial_state = optimizer.initialize(_SCALAR_SPEC)
   mismatched_weights = tf.zeros([2])
   mismatched_gradients = tf.zeros([2])
   with self.assertRaises(ValueError):
     optimizer.next(initial_state, mismatched_weights, mismatched_gradients)
Example #12
0
 def test_invalid_args_raises(self, lr, preconditioner, epsilon, regex):
   """Checks `build_adagrad` rejects the given arguments with `ValueError`.

   Args:
     lr: First positional argument passed to `build_adagrad`.
     preconditioner: Second positional argument passed to `build_adagrad`.
     epsilon: Third positional argument passed to `build_adagrad`.
     regex: Pattern the error message must match.
   """
   # Callable form of `assertRaisesRegex`: the builder is invoked with the
   # trailing arguments inside the assertion.
   self.assertRaisesRegex(ValueError, regex, adagrad.build_adagrad, lr,
                          preconditioner, epsilon)
Example #13
0
 def test_build_adagrad(self):
   """Checks `build_adagrad` returns an `optimizer_base.Optimizer`."""
   self.assertIsInstance(
       adagrad.build_adagrad(0.01), optimizer_base.Optimizer)