Ejemplo n.º 1
0
        def foo1(x):
            val = intrinsics.sequence_reduce(x, 0, add_numbers)
            self.assertIsInstance(val, value_base.Value)
            return val

        self.assert_type(foo1, '(int32* -> int32)')

        @computations.federated_computation(
            computation_types.FederatedType(
                computation_types.SequenceType(tf.int32), placements.SERVER))
        def foo2(x):
            val = intrinsics.sequence_reduce(x, 0, add_numbers)
            self.assertIsInstance(val, value_base.Value)
            return val

        self.assert_type(foo2, '(int32*@SERVER -> int32@SERVER)')

        @computations.federated_computation(
            computation_types.FederatedType(
                computation_types.SequenceType(tf.int32), placements.CLIENTS))
        def foo3(x):
            val = intrinsics.sequence_reduce(x, 0, add_numbers)
            self.assertIsInstance(val, value_base.Value)
            return val

        self.assert_type(foo3, '({int32*}@CLIENTS -> {int32}@CLIENTS)')


if __name__ == '__main__':
    # Delegate to `common_test.main()` when this file is executed directly.
    common_test.main()
Ejemplo n.º 2
0
      server_state, metrics = iterative_process.next(server_state, federated_ds)
      self.assertLess(metrics.loss, prev_loss)
      prev_loss = metrics.loss

  def test_execute_empty_data(self):
    """Runs one federated-averaging round where every dataset is empty.

    Expects the round to report `num_examples == 0` and a NaN loss, and the
    sum of the trainable weights `a` and `b` to equal 0 afterwards.
    """
    iterative_process = federated_averaging.build_federated_averaging_process(
        model_fn=model_examples.TrainableLinearRegression)

    # A single example batched with size 5 and `drop_remainder=True` yields
    # zero batches, so the dataset is empty while still carrying the element
    # structure ({'x', 'y'}), types and shapes of a real batch.
    ds = tf.data.Dataset.from_tensor_slices({
        'x': [[1., 2.]],
        'y': [[5.]]
    }).batch(
        5, drop_remainder=True)

    # Two copies of the empty dataset (presumably one per client).
    federated_ds = [ds] * 2

    server_state = iterative_process.initialize()

    first_state, metric_outputs = iterative_process.next(
        server_state, federated_ds)
    self.assertEqual(
        self.evaluate(tf.reduce_sum(first_state.model.trainable.a)) +
        self.evaluate(tf.reduce_sum(first_state.model.trainable.b)), 0)
    self.assertEqual(metric_outputs.num_examples, 0)
    # `tf.is_nan` returns a tensor. Evaluate it, as done for the weights above,
    # instead of relying on implicit tensor truthiness, which only works in
    # eager mode and raises TypeError in graph mode.
    self.assertTrue(self.evaluate(tf.is_nan(metric_outputs.loss)))


if __name__ == '__main__':
  # Delegate to `test.main()` when this file is executed directly.
  test.main()