示例#1
0
    def test_iterative_process_with_encoding(self):
        """Runs one round of a delta-optimizer process with encoded aggregation.

        The delta aggregation uses an encoded mean built with the 'gather'
        test encoder. The model broadcast uses an encoded broadcast built with
        the 'simple' test encoder. The test checks `trainable[0][0]` of the
        broadcast state: 1 after `initialize()` and 2 after a single `next()`
        over three identical client datasets.
        """
        model_fn = model_examples.LinearRegression
        # Aggregation is built before broadcast, matching the original order.
        encoded_mean = encoding_utils.build_encoded_mean_from_model(
            model_fn, _test_encoder_fn('gather'))
        encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
            model_fn, _test_encoder_fn('simple'))

        def make_server_optimizer():
            return tf.keras.optimizers.SGD(learning_rate=1.0)

        process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_fn,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=make_server_optimizer,
            stateful_delta_aggregate_fn=encoded_mean,
            stateful_model_broadcast_fn=encoded_broadcast)

        client_examples = collections.OrderedDict([
            ('x', [[1.0, 2.0], [3.0, 4.0]]),
            ('y', [[5.0], [6.0]]),
        ])
        client_dataset = tf.data.Dataset.from_tensor_slices(
            client_examples).batch(2)
        # The same dataset object is reused for each of the three clients.
        client_datasets = [client_dataset for _ in range(3)]

        def broadcast_probe(server_state):
            return server_state.model_broadcast_state.trainable[0][0]

        server_state = process.initialize()
        self.assertEqual(broadcast_probe(server_state), 1)

        server_state, _ = process.next(server_state, client_datasets)
        self.assertEqual(broadcast_probe(server_state), 2)
示例#2
0
  def test_broadcast_from_model(self):
    """Building an encoded broadcast records two warnings and returns the right type.

    Every warning is captured while the broadcast function is built. Exactly
    two must be recorded, and the result must be a
    `computation_utils.StatefulBroadcastFn`.
    """
    model_ctor = model_examples.LinearRegression

    with warnings.catch_warnings(record=True) as recorded:
      # 'always' makes repeated warnings count rather than being deduplicated.
      warnings.simplefilter('always')
      encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
          model_ctor, _test_encoder_fn('simple'))
      self.assertLen(recorded, 2)

    expected_type = computation_utils.StatefulBroadcastFn
    self.assertIsInstance(encoded_broadcast, expected_type)
示例#3
0
 def test_iterative_process_with_encoding(self):
   """Builds a delta-optimizer process with encoded aggregation and broadcast.

   Aggregation uses the 'gather' test encoder and broadcast uses the 'simple'
   test encoder. The resulting process is handed to
   `self._verify_iterative_process`, whose checks are defined outside this
   method.
   """
   model_fn = model_examples.LinearRegression
   encoded_mean = encoding_utils.build_encoded_mean_from_model(
       model_fn, _test_encoder_fn('gather'))
   encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
       model_fn, _test_encoder_fn('simple'))

   def make_server_optimizer():
     return tf.keras.optimizers.SGD(learning_rate=1.0)

   process = optimizer_utils.build_model_delta_optimizer_process(
       model_fn=model_fn,
       model_to_client_delta_fn=DummyClientDeltaFn,
       server_optimizer_fn=make_server_optimizer,
       stateful_delta_aggregate_fn=encoded_mean,
       stateful_model_broadcast_fn=encoded_broadcast)
   self._verify_iterative_process(process)
示例#4
0
  def test_iterative_process_with_encoding(self):
    """Runs one round of a delta-optimizer process with an encoded broadcast.

    Only the model broadcast is encoded, using the default `_test_encoder_fn()`.
    The test checks `trainable.a[0]` of the broadcast state: 1 after
    `initialize()` and 2 after a single `next()` over three identical client
    datasets.
    """
    model_fn = model_examples.TrainableLinearRegression
    encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
        model_fn, _test_encoder_fn())

    def make_server_optimizer():
      return tf.keras.optimizers.SGD(learning_rate=1.0)

    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=make_server_optimizer,
        stateful_model_broadcast_fn=encoded_broadcast)

    client_examples = {
        'x': [[1.0, 2.0], [3.0, 4.0]],
        'y': [[5.0], [6.0]],
    }
    client_dataset = tf.data.Dataset.from_tensor_slices(
        client_examples).batch(2)
    # The same dataset object is reused for each of the three clients.
    client_datasets = [client_dataset for _ in range(3)]

    def broadcast_probe(server_state):
      return server_state.model_broadcast_state.trainable.a[0]

    server_state = process.initialize()
    self.assertEqual(broadcast_probe(server_state), 1)

    server_state, _ = process.next(server_state, client_datasets)
    self.assertEqual(broadcast_probe(server_state), 2)
示例#5
0
 def test_broadcast_from_model(self):
   """`build_encoded_broadcast_from_model` yields a `tff.utils.StatefulBroadcastFn`."""
   model_ctor = model_examples.LinearRegression
   encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
       model_ctor, _test_encoder_fn('simple'))
   expected_type = tff.utils.StatefulBroadcastFn
   self.assertIsInstance(encoded_broadcast, expected_type)