Ejemplo n.º 1
0
  def test_client_tf(self):
    """Runs one ClientFedAvg pass in a session and checks its outputs.

    The client computation is built over `self.dataset()` starting from
    `self.initial_weights()`. The default graph is then finalized and the
    weight deltas, optimizer outputs and model metrics are verified.
    """
    fedavg_model = self.model()
    client_data = self.dataset()
    client_computation = federated_averaging.ClientFedAvg(fedavg_model)
    initializer = tf.group(
        model_utils.model_initializer(fedavg_model),
        tf.compat.v1.initializers.variables(client_computation.variables),
        name='fedavg_initializer')
    outputs_tensor = client_computation(client_data, self.initial_weights())

    # Make the graph read-only so running the session cannot add new ops.
    tf.compat.v1.get_default_graph().finalize()
    with self.session() as sess:
      sess.run(initializer)
      result = sess.run(outputs_tensor)

      deltas = result.weights_delta
      # Only the trainable variables 'a' and 'b' are reported; the
      # non-trainable 'c' is excluded.
      self.assertCountEqual(['a', 'b'], list(deltas.keys()))
      for var_name in ('a', 'b'):
        # Each trainable variable is expected to have moved noticeably.
        self.assertGreater(np.linalg.norm(deltas[var_name]), 0.1)
      self.assertEqual(result.weights_delta_weight, 8.0)

      optimizer_metrics = result.optimizer_output
      self.assertEqual(optimizer_metrics['num_examples'], 8)
      self.assertEqual(optimizer_metrics['has_non_finite_delta'], 0)

      model_metrics = result.model_output
      self.assertEqual(model_metrics['num_examples'], 8)
      self.assertEqual(model_metrics['num_batches'], 3)
      loss_floor = np.finfo(np.float32).eps
      self.assertBetween(model_metrics['loss'], loss_floor, 10.0)
Ejemplo n.º 2
0
    def test_client_tf(self):
        """Runs one ClientSgd pass in a session and checks its outputs.

        The client computation is built over `self.dataset()` starting from
        `self.initial_weights()`. The default graph is then finalized and the
        gradient deltas, model metrics and optimizer outputs are verified.
        """
        sgd_model = self.model()
        client_data = self.dataset()
        client_computation = federated_sgd.ClientSgd(sgd_model)
        initializer = tf.group(
            model_utils.model_initializer(sgd_model),
            tf.compat.v1.initializers.variables(client_computation.variables),
            name='fedsgd_initializer')
        outputs_tensor = client_computation(client_data,
                                            self.initial_weights())

        # Make the graph read-only so running the session cannot add new ops.
        tf.compat.v1.get_default_graph().finalize()
        with self.session() as sess:
            sess.run(initializer)
            result = sess.run(outputs_tensor)

            deltas = result.weights_delta
            # Gradients are reported only for the trainable 'a' and 'b';
            # the non-trainable 'c' is left out.
            self.assertCountEqual(['a', 'b'], list(deltas.keys()))
            # Expected deltas for a squared-error loss.
            self.assertAllClose(deltas['a'], [[1.0], [0.0]])
            self.assertAllClose(deltas['b'], 1.0)
            self.assertAllClose(result.weights_delta_weight, 8.0)

            model_metrics = result.model_output
            self.assertEqual(model_metrics['num_examples'], 8)
            self.assertEqual(model_metrics['num_batches'], 3)
            self.assertAlmostEqual(model_metrics['loss'], 0.5)

            optimizer_metrics = result.optimizer_output
            self.assertEqual(optimizer_metrics['client_weight'], 8.0)
            self.assertEqual(optimizer_metrics['has_non_finite_delta'], 0)
Ejemplo n.º 3
0
 def test_model_initializer(self):
   """Checks that model_initializer makes every model variable readable.

   After the initializer op runs, reading both the model's local variables
   and its weights must not raise tf.errors.FailedPreconditionError.
   """
   with tf.Graph().as_default() as graph:
     linear_model = model_utils.enhance(model_examples.LinearRegression(2))
     init_op = model_utils.model_initializer(linear_model)
     with self.session(graph=graph) as sess:
       sess.run(init_op)
       # Reading an uninitialized variable raises FailedPreconditionError,
       # so clean reads of both collections show initialization worked.
       try:
         sess.run(linear_model.local_variables)
         sess.run(linear_model.weights)
       except tf.errors.FailedPreconditionError:
         self.fail('Expected variables to be initialized, but got '
                   'tf.errors.FailedPreconditionError')