示例#1
0
    def test_iterative_process_with_encoding(self):
        """Runs one round of a delta-optimizer process with encoded aggregation.

        The delta aggregation is built from `_test_encoder_fn('gather')` and the
        model broadcast from `_test_encoder_fn('simple')`. The test checks that
        `model_broadcast_state.trainable[0][0]` is 1 right after `initialize()`
        and 2 after a single call to `next()`.
        """

        def _server_optimizer():
            return tf.keras.optimizers.SGD(learning_rate=1.0)

        model_fn = model_examples.LinearRegression
        encoded_mean = encoding_utils.build_encoded_mean_from_model(
            model_fn, _test_encoder_fn('gather'))
        encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
            model_fn, _test_encoder_fn('simple'))
        process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_fn,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=_server_optimizer,
            stateful_delta_aggregate_fn=encoded_mean,
            stateful_model_broadcast_fn=encoded_broadcast)

        # One batch of two (x, y) examples; the same dataset object is repeated
        # three times to form the federated input passed to `next()`.
        features = [[1.0, 2.0], [3.0, 4.0]]
        labels = [[5.0], [6.0]]
        client_dataset = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([('x', features), ('y', labels)])).batch(2)
        client_datasets = [client_dataset for _ in range(3)]

        server_state = process.initialize()
        self.assertEqual(server_state.model_broadcast_state.trainable[0][0], 1)

        server_state, _ = process.next(server_state, client_datasets)
        self.assertEqual(server_state.model_broadcast_state.trainable[0][0], 2)
示例#2
0
    def test_mean_from_model_raise_warning(self):
        """Checks that building an encoded mean records exactly two warnings.

        `simplefilter('always')` ensures that every warning emitted while
        `build_encoded_mean_from_model` runs is recorded. The returned object
        must still be a `StatefulAggregateFn`.
        """
        linear_model_fn = model_examples.LinearRegression

        with warnings.catch_warnings(record=True) as recorded:
            warnings.simplefilter('always')
            aggregate_fn = encoding_utils.build_encoded_mean_from_model(
                linear_model_fn, _test_encoder_fn('gather'))
            self.assertLen(recorded, 2)

        self.assertIsInstance(aggregate_fn,
                              computation_utils.StatefulAggregateFn)
示例#3
0
 def test_iterative_process_with_encoding(self):
   """Builds a delta-optimizer process with encoded aggregation and broadcast.

   The delta aggregation is built from `_test_encoder_fn('gather')` and the
   model broadcast from `_test_encoder_fn('simple')`. The resulting process is
   passed to `self._verify_iterative_process`.
   """

   def server_optimizer_fn():
     return tf.keras.optimizers.SGD(learning_rate=1.0)

   model_fn = model_examples.LinearRegression
   delta_aggregate_fn = encoding_utils.build_encoded_mean_from_model(
       model_fn, _test_encoder_fn('gather'))
   model_broadcast_fn = encoding_utils.build_encoded_broadcast_from_model(
       model_fn, _test_encoder_fn('simple'))
   process = optimizer_utils.build_model_delta_optimizer_process(
       model_fn=model_fn,
       model_to_client_delta_fn=DummyClientDeltaFn,
       server_optimizer_fn=server_optimizer_fn,
       stateful_delta_aggregate_fn=delta_aggregate_fn,
       stateful_model_broadcast_fn=model_broadcast_fn)
   self._verify_iterative_process(process)
示例#4
0
 def test_mean_from_model(self):
   """Checks `build_encoded_mean_from_model` returns a StatefulAggregateFn."""
   aggregate_fn = encoding_utils.build_encoded_mean_from_model(
       model_examples.LinearRegression, _test_encoder_fn('gather'))
   self.assertIsInstance(aggregate_fn, tff.utils.StatefulAggregateFn)