def test_execution_stateful_optimizer(self):
        """Checks per-client model deltas with an SGD-with-momentum optimizer.

        Two clients start from the same all-zero `LinearRegression` weights.
        The first trains on one batch of 2 examples, the second on that batch
        repeated twice (4 examples). The expected second-client delta is not
        simply twice the first one, presumably because the momentum state
        persists across a client's batches.
        """
        client_work_process = client_works.build_model_delta_client_work(
            model_examples.LinearRegression,
            sgdm.build_sgdm(0.1, momentum=0.9))
        data = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[[1.0, 2.0], [3.0, 4.0]],
                y=[[5.0], [6.0]],
            )).batch(2)
        # 1st client has 2 examples, 2nd has 4.
        data = [data, data.repeat(2)]
        model_weights = model_utils.ModelWeights(
            trainable=[[[0.0], [0.0]], 0.0], non_trainable=[0.0])
        # Both clients start from the same initial weights.
        client_model_weights = [model_weights] * 2

        state = client_work_process.initialize()
        output = client_work_process.next(state, client_model_weights, data)

        expected_result = (
            client_works.ClientResult([[[-1.15], [-1.7]], -0.55], 2.0),
            client_works.ClientResult([[[-1.46], [-2.26]], -0.8], 4.0),
        )

        self.assertEqual((), output.state)
        # Guard the length explicitly: `zip` below would otherwise silently
        # skip missing client results.
        self.assertLen(output.result, len(expected_result))
        for expected, actual in zip(expected_result, output.result):
            self.assertAllClose(expected.update, actual.update)
            self.assertAllClose(expected.update_weight, actual.update_weight)
        self.assertEqual((), output.measurements)
    def test_type_properties(self):
        """Verifies the type signatures of the apply-optimizer finalizer."""
        model_weights_type = computation_types.to_type(
            model_utils.ModelWeights(
                trainable=(tf.float32, tf.float32), non_trainable=tf.float32))

        finalizer = finalizers.build_apply_optimizer_finalizer(
            sgdm.build_sgdm(1.0), model_weights_type)
        self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

        # State and measurements are both expected to be empty at the server.
        server_state_type = computation_types.at_server(())
        server_measurements_type = computation_types.at_server(())
        # Incoming and resulting weights share the full model-weights type.
        server_weights_type = computation_types.at_server(model_weights_type)
        # The update only covers the trainable part of the weights.
        server_update_type = computation_types.at_server(
            model_weights_type.trainable)

        initialize_signature = computation_types.FunctionType(
            parameter=None, result=server_state_type)
        initialize_signature.check_equivalent_to(
            finalizer.initialize.type_signature)

        next_parameter = collections.OrderedDict(
            state=server_state_type,
            weights=server_weights_type,
            update=server_update_type)
        next_result = MeasuredProcessOutput(server_state_type,
                                            server_weights_type,
                                            server_measurements_type)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_signature.check_equivalent_to(finalizer.next.type_signature)
Exemple #3
0
 def test_state_structure_momentum(self):
   """With momentum, the state holds learning rate, momentum and accumulator."""
   optimizer = sgdm.build_sgdm(0.01, momentum=0.9)
   state = optimizer.initialize(_SCALAR_SPEC)
   self.assertLen(state, 3)
   for expected_key in (optimizer_base.LEARNING_RATE_KEY,
                        sgdm._MOMENTUM_KEY,
                        sgdm._ACCUMULATOR_KEY):
     self.assertIn(expected_key, state)
Exemple #4
0
    def test_delta_regularizer_yields_smaller_model_delta(self, optimizer):
        """A positive delta L2 regularizer should shrink the client update.

        Both processes use the same parameterized client `optimizer` and
        differ only in `delta_l2_regularizer`. Any difference in update norm
        is therefore attributable to the regularizer alone.

        Args:
          optimizer: Client optimizer under test; presumably supplied by a
            parameterized decorator outside this view.
        """
        simple_process = model_delta_client_work.build_model_delta_client_work(
            self.create_model,
            optimizer,
            client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
            delta_l2_regularizer=0.0)
        # Previously hard-coded `sgdm.build_sgdm(1.0)` here, which made the
        # comparison depend on the optimizer as well as the regularizer.
        proximal_process = model_delta_client_work.build_model_delta_client_work(
            self.create_model,
            optimizer,
            client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
            delta_l2_regularizer=1.0)
        client_data = [create_test_dataset()]
        client_model_weights = [create_test_initial_weights()]

        simple_output = simple_process.next(simple_process.initialize(),
                                            client_model_weights, client_data)
        proximal_output = proximal_process.next(proximal_process.initialize(),
                                                client_model_weights,
                                                client_data)

        simple_update_norm = tf.linalg.global_norm(
            tf.nest.flatten(simple_output.result[0].update))
        proximal_update_norm = tf.linalg.global_norm(
            tf.nest.flatten(proximal_output.result[0].update))
        self.assertGreater(simple_update_norm, proximal_update_norm)

        # Both processes must have trained on the same number of examples.
        self.assertEqual(simple_output.measurements['train']['num_examples'],
                         proximal_output.measurements['train']['num_examples'])
 def test_scheduling_scheduled_optimizer_raises(self):
   """Scheduling an already-scheduled optimizer fails at initialization."""
   once_scheduled = scheduling.schedule_learning_rate(
       sgdm.build_sgdm(1.0), _example_schedule_fn)
   twice_scheduled = scheduling.schedule_learning_rate(once_scheduled,
                                                       _example_schedule_fn)
   scalar_spec = tf.TensorSpec((), tf.float32)
   with self.assertRaisesRegex(KeyError, 'must have learning rate'):
     twice_scheduled.initialize(scalar_spec)
 def test_aggregation_wrapper_called_by_unweighted(self, _,
                                                   mock_as_weighted):
     """Building unweighted FedProx invokes the mocked wrapper exactly once.

     `mock_as_weighted` is presumably injected by a `mock.patch` decorator
     outside this view.
     """
     client_optimizer = sgdm.build_sgdm(1.0)
     fed_prox.build_unweighted_fed_prox(
         model_fn=model_examples.LinearRegression,
         proximal_strength=1.0,
         client_optimizer_fn=client_optimizer)
     self.assertEqual(mock_as_weighted.call_count, 1)
 def test_build_weighted_fed_prox_called_by_unweighted_fed_prox(
         self, mock_fed_avg):
     """Unweighted FedProx delegates to the mocked builder exactly once.

     NOTE(review): despite its name, `mock_fed_avg` presumably patches
     `build_weighted_fed_prox` via a decorator outside this view.
     """
     client_optimizer = sgdm.build_sgdm(1.0)
     fed_prox.build_unweighted_fed_prox(
         model_fn=model_examples.LinearRegression,
         proximal_strength=1.0,
         client_optimizer_fn=client_optimizer)
     self.assertEqual(mock_fed_avg.call_count, 1)
Exemple #8
0
 def test_negative_proximal_strength_raises(self):
     """A negative `delta_l2_regularizer` is rejected with ValueError."""
     weighting = client_weight_lib.ClientWeighting.NUM_EXAMPLES
     with self.assertRaises(ValueError):
         model_delta_client_work.build_model_delta_client_work(
             model_examples.LinearRegression,
             sgdm.build_sgdm(1.0),
             client_weighting=weighting,
             delta_l2_regularizer=-1.0)
 def test_weighted_fed_avg_raises_on_unweighted_aggregator(self):
     """Weighted FedProx rejects an unweighted aggregation factory.

     NOTE(review): despite the test name, this exercises
     `fed_prox.build_weighted_fed_prox`.
     """
     unweighted_aggregator = model_update_aggregator.robust_aggregator(
         weighted=False)
     with self.assertRaisesRegex(TypeError, 'WeightedAggregationFactory'):
         fed_prox.build_weighted_fed_prox(
             model_fn=model_examples.LinearRegression,
             proximal_strength=1.0,
             client_optimizer_fn=sgdm.build_sgdm(1.0),
             model_aggregator=unweighted_aggregator)
    def test_execution(self):
        """One finalizer step with SGD(lr=1.0): 1.0 - 1.0 * 0.1 == 0.9."""
        sgd = sgdm.build_sgdm(1.0)
        finalizer = finalizers.build_apply_optimizer_finalizer(
            sgd, MODEL_WEIGHTS_TYPE.member)

        initial_weights = model_utils.ModelWeights(1.0, ())
        output = finalizer.next(finalizer.initialize(), initial_weights, 0.1)

        self.assertEqual((), output.state)
        self.assertAllClose(0.9, output.result.trainable)
        self.assertEqual((), output.measurements)
class ScheduledLROptimizerTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for wrapping optimizers with `scheduling.schedule_learning_rate`."""

  def test_scheduled_sgd_computes_correctly(self):
    """SGD with a scheduled learning rate follows the schedule step by step."""
    scheduled_sgd = scheduling.schedule_learning_rate(
        sgdm.build_sgdm(1.0), _example_schedule_fn)

    weight = tf.constant(1.0)
    gradient = tf.constant(1.0)
    state = scheduled_sgd.initialize(tf.TensorSpec((), tf.float32))
    # With a unit gradient the learning rate is 0.1 for the first two steps
    # and 0.01 for the next two.
    for expected_weight in (0.9, 0.8, 0.79, 0.78):
      state, weight = scheduled_sgd.next(state, weight, gradient)
      self.assertAllClose(expected_weight, weight)

  @parameterized.named_parameters(
      ('adagrad', adagrad.build_adagrad(1.0)),
      ('adam', adam.build_adam(1.0)),
      ('rmsprop', rmsprop.build_rmsprop(1.0)),
      ('sgd', sgdm.build_sgdm(1.0)),
      ('sgdm', sgdm.build_sgdm(1.0, momentum=0.9)),
      ('yogi', yogi.build_yogi(1.0)),
  )
  def test_schedule_learning_rate_integrates_with(self, optimizer):
    """Each built-in optimizer can be wrapped and stays an `Optimizer`."""
    wrapped = scheduling.schedule_learning_rate(optimizer,
                                                _example_schedule_fn)
    self.assertIsInstance(wrapped, optimizer_base.Optimizer)

  def test_keras_optimizer_raises(self):
    """Keras optimizers are rejected by the scheduler."""
    sgd_keras = tf.keras.optimizers.SGD(1.0)
    with self.assertRaises(TypeError):
      scheduling.schedule_learning_rate(sgd_keras, _example_schedule_fn)

  def test_scheduling_scheduled_optimizer_raises(self):
    """Scheduling an already-scheduled optimizer fails at initialization."""
    optimizer = sgdm.build_sgdm(1.0)
    for _ in range(2):
      optimizer = scheduling.schedule_learning_rate(optimizer,
                                                    _example_schedule_fn)
    with self.assertRaisesRegex(KeyError, 'must have learning rate'):
      optimizer.initialize(tf.TensorSpec((), tf.float32))
Exemple #12
0
  def test_executes_with(self, spec, momentum):
    """Ten SGD(M) steps on all-ones weights and gradients keep weights finite.

    Args:
      spec: A (possibly nested) structure of `tf.TensorSpec`s describing the
        weights; presumably supplied by a parameterized decorator outside view.
      momentum: Momentum value passed to `sgdm.build_sgdm`.
    """
    weights = tf.nest.map_structure(lambda s: tf.ones(s.shape, s.dtype), spec)
    gradients = tf.nest.map_structure(lambda s: tf.ones(s.shape, s.dtype), spec)
    optimizer = sgdm.build_sgdm(0.01, momentum=momentum)

    state = optimizer.initialize(spec)
    for _ in range(10):
      state, weights = optimizer.next(state, weights, gradients)

    # Reduce over every element so the check works for tensors of any rank.
    # Python's `all()` only iterates the first axis: it fails on 0-d tensors
    # and on rank >= 2 tensors, whose non-scalar rows have no truth value.
    for weight in tf.nest.flatten(weights):
      self.assertTrue(tf.reduce_all(tf.math.is_finite(weight)))
Exemple #13
0
 def test_raises_on_invalid_distributor(self):
     """A generic IterativeProcess is not accepted as the model distributor."""
     initial_weights = model_utils.ModelWeights.from_model(
         model_examples.LinearRegression())
     model_weights_type = type_conversions.type_from_tensors(initial_weights)
     broadcaster = distributors.build_broadcast_process(model_weights_type)
     # Re-wrap the broadcaster's computations in a plain IterativeProcess;
     # presumably this loses the distributor-specific process type.
     invalid_distributor = iterative_process.IterativeProcess(
         broadcaster.initialize, broadcaster.next)
     with self.assertRaises(TypeError):
         fed_avg.build_weighted_fed_avg(
             model_fn=model_examples.LinearRegression,
             client_optimizer_fn=sgdm.build_sgdm(1.0),
             model_distributor=invalid_distributor)
Exemple #14
0
  def test_math_momentum_0_5(self):
    """Hand-checked SGDM trajectory with lr=0.01, momentum=0.5, gradient=2.

    The expected values are consistent with an accumulator of 2, 3, 3.5, 3.75
    (acc = 0.5 * acc + 2), with each step subtracting 0.01 * acc.
    """
    gradients = tf.constant([2.0], tf.float32)
    optimizer = sgdm.build_sgdm(0.01, momentum=0.5)
    state = optimizer.initialize(_SCALAR_SPEC)

    weights = tf.constant([1.0], tf.float32)
    trajectory = [weights]
    for _ in range(4):
      state, weights = optimizer.next(state, weights, gradients)
      trajectory.append(weights)

    self.assertAllClose([[1.0], [0.98], [0.95], [0.915], [0.8775]], trajectory)
Exemple #15
0
  def test_convergence(self, momentum):
    """SGD(M) drives the quadratic test problem below 0.005 in 100 steps."""
    init_w, fn, grad_fn = optimizer_test_utils.test_quadratic_problem()
    weights = init_w()
    # Sanity check: the starting loss is well above the target.
    self.assertGreater(fn(weights), 5.0)

    optimizer = sgdm.build_sgdm(0.1, momentum=momentum)
    weight_spec = tf.TensorSpec(weights.shape, weights.dtype)
    state = optimizer.initialize(weight_spec)

    for _ in range(100):
      state, weights = optimizer.next(state, weights, grad_fn(weights))
    self.assertLess(fn(weights), 0.005)
Exemple #16
0
class VanillaFedAvgTest(test_case.TestCase, parameterized.TestCase):
    """Tests for `composers.build_basic_fedavg_process`."""

    def _test_data(self):
        """Returns a dataset with a single batch of 2 regression examples."""
        return tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[[1.0, 2.0], [3.0, 4.0]],
                y=[[5.0], [6.0]],
            )).batch(2)

    def _test_batch_loss(self, model, weights):
        """Loads `weights` into `model` and returns its loss on the test batch.

        Args:
          model: A model instance whose variables are overwritten in place.
          weights: A value with the same `ModelWeights` structure as `model`.

        Returns:
          The `loss` from `model.forward_pass` on the first test batch.
        """
        tf.nest.map_structure(lambda w, v: w.assign(v),
                              model_utils.ModelWeights.from_model(model),
                              weights)
        # Take the first batch directly instead of relying on a loop variable
        # that would be unbound if the dataset were ever empty.
        batch = next(iter(self._test_data().take(1)))
        return model.forward_pass(batch, training=False).loss

    def test_loss_decreases(self):
        """Each FedAvg round lowers the loss; metrics have the four stages."""
        model_fn = model_examples.LinearRegression
        test_model = model_fn()
        fedavg = composers.build_basic_fedavg_process(model_fn=model_fn,
                                                      client_learning_rate=0.1)
        client_data = [self._test_data()] * 3  # 3 clients with identical data.

        state = fedavg.initialize()
        last_loss = self._test_batch_loss(test_model,
                                          state.global_model_weights)
        for _ in range(5):
            fedavg_result = fedavg.next(state, client_data)
            state = fedavg_result.state
            metrics = fedavg_result.metrics
            loss = self._test_batch_loss(test_model,
                                         state.global_model_weights)
            self.assertLess(loss, last_loss)
            last_loss = loss

        self.assertIsInstance(state, composers.LearningAlgorithmState)
        self.assertLen(metrics, 4)
        for key in ['distributor', 'client_work', 'aggregator', 'finalizer']:
            self.assertIn(key, metrics)

    def test_created_model_raises(self):
        """Passing a model instance instead of a model_fn raises TypeError."""
        with self.assertRaises(TypeError):
            composers.build_basic_fedavg_process(
                model_examples.LinearRegression(), 0.1)

    @parameterized.named_parameters(('int', 1),
                                    ('optimizer', sgdm.build_sgdm(0.1)))
    def test_wrong_client_learning_rate_raises(self, bad_client_lr):
        """A learning rate of the wrong type raises TypeError.

        A valid model_fn (the class, not an instance) is passed so that the
        error can only come from `bad_client_lr`. Passing an instance would
        raise on its own (see `test_created_model_raises`) and mask the check.
        """
        with self.assertRaises(TypeError):
            composers.build_basic_fedavg_process(
                model_examples.LinearRegression, bad_client_lr)