示例#1
0
    def test_client_tf(self, simulation):
        """Runs one ClientSgd pass and checks deltas, weights and metrics.

        Args:
          simulation: Forwarded as `use_experimental_simulation_loop`.
        """
        tff_model = self.model()
        client_data = self.dataset()
        sgd_client = federated_sgd.ClientSgd(
            tff_model, use_experimental_simulation_loop=simulation)
        result = self.evaluate(sgd_client(client_data, self.initial_weights()))

        expected_metrics = {
            'num_examples': 8,
            'num_examples_float': 8.0,
            'num_batches': 3,
            'loss': 0.5,
        }
        expected_optimizer_output = {
            'client_weight': 8.0,
            'has_non_finite_delta': 0,
        }

        # Deltas come back only for the two trainable parameters; the
        # non-trainable 'c' is not returned. Values follow from squared error.
        self.assertAllClose(result.weights_delta, [[[1.0], [0.0]], 1.0])
        self.assertAllClose(result.weights_delta_weight, 8.0)
        self.assertEqual(result.model_output, expected_metrics)
        self.assertEqual(result.optimizer_output, expected_optimizer_output)
示例#2
0
    def test_client_tf(self):
        """Graph-mode check of ClientSgd: build the ops, then run them.

        The default graph is finalized before the session starts, so no new
        ops can be added while the outputs are evaluated.
        """
        tff_model = self.model()
        client_data = self.dataset()
        sgd_client = federated_sgd.ClientSgd(tff_model)
        initializer = tf.group(
            model_utils.model_initializer(tff_model),
            tf.compat.v1.initializers.variables(sgd_client.variables),
            name='fedsgd_initializer')
        outputs_op = sgd_client(client_data, self.initial_weights())

        # Freeze the graph; everything below only runs existing ops.
        tf.compat.v1.get_default_graph().finalize()
        with self.session() as session:
            session.run(initializer)
            result = session.run(outputs_op)

            deltas = result.weights_delta
            # Only the trainable 'a' and 'b' get gradients; the
            # non-trainable 'c' is not part of the returned deltas.
            self.assertCountEqual(['a', 'b'], list(deltas.keys()))
            # Expected deltas under the squared-error loss.
            self.assertAllClose(deltas['a'], [[1.0], [0.0]])
            self.assertAllClose(deltas['b'], 1.0)
            self.assertAllClose(result.weights_delta_weight, 8.0)

            metrics = result.model_output
            self.assertEqual(metrics['num_examples'], 8)
            self.assertEqual(metrics['num_batches'], 3)
            self.assertAlmostEqual(metrics['loss'], 0.5)

            optimizer_stats = result.optimizer_output
            self.assertEqual(optimizer_stats['client_weight'], 8.0)
            self.assertEqual(optimizer_stats['has_non_finite_delta'], 0)
  def test_client_tf(self, simulation, weighted):
    """Runs ClientSgd under both weighting schemes and checks its outputs.

    Args:
      simulation: Forwarded as `use_experimental_simulation_loop`.
      weighted: If True, use `ClientWeighting.NUM_EXAMPLES`; otherwise use
        `ClientWeighting.UNIFORM`.
    """
    model = self.model()
    dataset = self.dataset()
    if weighted:
      client_weighting = client_weight_lib.ClientWeighting.NUM_EXAMPLES
    else:
      client_weighting = client_weight_lib.ClientWeighting.UNIFORM
    client_tf = federated_sgd.ClientSgd(
        model,
        client_weighting=client_weighting,
        use_experimental_simulation_loop=simulation)
    client_outputs = self.evaluate(client_tf(dataset, self.initial_weights()))

    # Both trainable parameters should have gradients, and we don't return the
    # non-trainable 'c'. Model deltas for squared error:
    self.assertAllClose(client_outputs.weights_delta, [[[1.0], [0.0]], 1.0])
    # NUM_EXAMPLES weighting yields the example count (8); UNIFORM yields 1.
    expected_delta_weight = 8.0 if weighted else 1.0
    self.assertAllClose(client_outputs.weights_delta_weight,
                        expected_delta_weight)

    # The expected metrics must be contained in the reported ones (extra
    # reported metrics are allowed). `assertDictContainsSubset` is avoided:
    # it takes `(subset, dictionary)`, which is easy to pass reversed, and it
    # was removed in Python 3.12.
    expected_metrics = {
        'num_examples': 8,
        'num_examples_float': 8.0,
        'num_batches': 3,
        'loss': 0.5,
    }
    model_output = client_outputs.model_output
    # A missing key shows up as None in the failure diff instead of a KeyError.
    self.assertEqual(
        expected_metrics,
        {key: model_output.get(key) for key in expected_metrics})
    self.assertEqual(client_outputs.optimizer_output['has_non_finite_delta'], 0)
示例#4
0
 def test_client_tf_custom_batch_weight(self):
     """Checks that a custom `batch_weight_fn` determines the delta weight."""

     def double_x_sum(batch):
         # Twice the sum of the batch's `x` values.
         return 2.0 * tf.reduce_sum(batch.x)

     model = self.model()
     dataset = self.dataset()
     sgd_client = federated_sgd.ClientSgd(model, batch_weight_fn=double_x_sum)
     result = sgd_client(dataset, self.initial_weights())
     # Expected weight is 16.0, i.e. 2 * 8.
     self.assertEqual(result.weights_delta_weight.numpy(), 16.0)
示例#5
0
    def test_client_tf(self):
        """Runs one ClientSgd pass and checks dict-keyed deltas and metrics."""
        tff_model = self.model()
        client_data = self.dataset()
        sgd_client = federated_sgd.ClientSgd(tff_model)
        result = self.evaluate(sgd_client(client_data, self.initial_weights()))

        deltas = result.weights_delta
        # Only the trainable 'a' and 'b' get gradients; the non-trainable
        # 'c' is not returned.
        self.assertCountEqual(['a', 'b'], deltas.keys())
        # Expected deltas under the squared-error loss.
        self.assertAllClose(deltas['a'], [[1.0], [0.0]])
        self.assertAllClose(deltas['b'], 1.0)
        self.assertAllClose(result.weights_delta_weight, 8.0)

        expected_metrics = dict(
            num_examples=8,
            num_examples_float=8.0,
            num_batches=3,
            loss=0.5,
        )
        expected_optimizer_output = dict(
            client_weight=8.0,
            has_non_finite_delta=0,
        )
        self.assertEqual(result.model_output, expected_metrics)
        self.assertEqual(result.optimizer_output, expected_optimizer_output)
示例#6
0
 def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
     """Checks `mock_method` is called only without the simulation loop.

     Args:
       simulation: Forwarded as `use_experimental_simulation_loop`.
       mock_method: A mock (presumably patched in by a decorator on this
         test) whose call record is inspected after the client run.
     """
     model = self.model()
     dataset = self.dataset()
     sgd_client = federated_sgd.ClientSgd(
         model, use_experimental_simulation_loop=simulation)
     sgd_client(dataset, self.initial_weights())
     # Choose the matching mock assertion, then run it.
     verify = (mock_method.assert_not_called
               if simulation else mock_method.assert_called)
     verify()
示例#7
0
 def test_non_finite_aggregation(self, bad_value):
   """A bad initial weight must zero out the client's update and flag it.

   Args:
     bad_value: Value written into the second trainable weight (presumably a
       non-finite value such as inf/nan supplied by test parameters).
   """
   model = self.model()
   dataset = self.dataset()
   client_tf = federated_sgd.ClientSgd(model)
   init_weights = self.initial_weights()
   init_weights.trainable[1] = bad_value
   client_outputs = client_tf(dataset, init_weights)
   # The update is dropped: zero weight and an all-zero delta.
   self.assertEqual(self.evaluate(client_outputs.weights_delta_weight), 0.0)
   self.assertAllClose(
       self.evaluate(client_outputs.weights_delta), [[[0.0], [0.0]], 0.0])
   # The optimizer output must also report that a non-finite delta occurred.
   self.assertEqual(
       self.evaluate(client_outputs.optimizer_output['has_non_finite_delta']),
       1)
 def test_non_finite_aggregation(self, bad_value):
   """A bad 'b' weight must zero out the client's update and flag it.

   Args:
     bad_value: Value written into trainable weight 'b' (presumably a
       non-finite value such as inf/nan supplied by test parameters).
   """
   model = self.model()
   dataset = self.dataset()
   sgd_client = federated_sgd.ClientSgd(model)
   weights = self.initial_weights()
   weights.trainable['b'] = bad_value
   result = sgd_client(dataset, weights)

   deltas = result.weights_delta
   self.assertEqual(result.weights_delta_weight.numpy(), 0.0)
   self.assertAllClose(deltas['a'].numpy(), np.zeros((2, 1)))
   self.assertAllClose(deltas['b'].numpy(), 0.0)
   self.assertEqual(result.optimizer_output['has_non_finite_delta'].numpy(), 1)
示例#9
0
    def test_client_tf(self):
        """Eager check of ClientSgd; outputs are converted to NumPy first."""
        tff_model = self.model()
        client_data = self.dataset()
        sgd_client = federated_sgd.ClientSgd(tff_model)
        raw_result = sgd_client(client_data, self.initial_weights())
        result = nest.map_structure(lambda tensor: tensor.numpy(), raw_result)

        deltas = result.weights_delta
        # Only the trainable 'a' and 'b' get gradients; the non-trainable
        # 'c' is not returned.
        self.assertCountEqual(['a', 'b'], deltas.keys())
        # Expected deltas under the squared-error loss.
        self.assertAllClose(deltas['a'], [[1.0], [0.0]])
        self.assertAllClose(deltas['b'], 1.0)
        self.assertAllClose(result.weights_delta_weight, 8.0)

        metrics = result.model_output
        self.assertEqual(metrics['num_examples'], 8)
        self.assertEqual(metrics['num_batches'], 3)
        self.assertAlmostEqual(metrics['loss'], 0.5)

        optimizer_stats = result.optimizer_output
        self.assertEqual(optimizer_stats['client_weight'], 8.0)
        self.assertEqual(optimizer_stats['has_non_finite_delta'], 0)
 def test_clietsgd_fails_for_non_tff_model(self):
   """ClientSgd must reject a plain Keras model with a TypeError."""
   # NOTE(review): "clietsgd" in the method name is a typo; kept so existing
   # test selectors still match.
   dense_layer = tf.keras.layers.Dense(1)
   not_a_tff_model = tf.keras.Sequential([dense_layer])
   with self.assertRaisesRegex(TypeError, 'Model'):
     federated_sgd.ClientSgd(not_a_tff_model)