示例#1
0
def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields an executor wired to an in-process gRPC executor service.

  Starts a gRPC server (on a port from `portpicker.pick_unused_port()`) that
  hosts an `ExecutorService` wrapping a `TracingExecutor` around a local
  executor stack created with `num_clients=3`. It then connects a
  `RemoteExecutor` to that server over an insecure channel and installs a
  `LambdaExecutor` over the remote executor as the default executor.

  Args:
    rpc_mode: Forwarded as the second argument of
      `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` (the installed `LambdaExecutor`) and
    `tracer` (the server-side `TracingExecutor`).

  On exit the default executor is reset and the channel is closed where the
  gRPC build supports it. The server is always stopped. All setup after
  `server.start()` lives inside the `try`, so the server is stopped even if
  that setup raises.
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = executor_stacks.local_executor_factory(
      num_clients=3).create_executor({})
  tracer = executor_test_utils.TracingExecutor(target_executor)
  service = executor_service.ExecutorService(tracer)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  # Stays None if creating the channel fails, so cleanup can skip close().
  channel = None
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
    executor = lambda_executor.LambdaExecutor(remote_exec)
    set_default_executor.set_default_executor(
        executor_factory.ExecutorFactoryImpl(lambda _: executor))
    yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
  finally:
    set_default_executor.set_default_executor()
    try:
      if channel is not None:
        channel.close()
    except AttributeError:
      pass  # Public gRPC channel doesn't support close()
    finally:
      server.stop(None)
示例#2
0
 def _get_losses_before_and_after_training_single_batch(ex):
     """Installs `ex` as the default executor and records losses on one batch.

     Returns a list of 21 losses. The first is `_mnist_batch_loss` of
     `_mnist_initial_model` on `_mnist_sample_batch`. Each of the next 20 is
     the loss after one more `_mnist_batch_train` step on that same batch.
     `ex` is left installed as the default executor.
     """
     set_default_executor.set_default_executor(ex)
     current_model = _mnist_initial_model
     history = [_mnist_batch_loss(current_model, _mnist_sample_batch)]
     for _step in range(20):
         current_model = _mnist_batch_train(current_model, _mnist_sample_batch)
         history.append(_mnist_batch_loss(current_model, _mnist_sample_batch))
     return history
 def setUp(self):
     """Installs a nested composite executor stack as the default executor.

     A root middle stack combines two middle stacks, each built over three
     worker stacks.
     """
     super(CompositeExecutorTest, self).setUp()

     def _middle_over_three_workers():
         workers = [_create_worker_stack() for _ in range(3)]
         return _create_middle_stack(workers)

     root = _create_middle_stack(
         [_middle_over_three_workers(), _middle_over_three_workers()])
     set_default_executor.set_default_executor(root)
 def setUp(self):
    """Installs a composite executor stack and records the client count.

    The root middle stack fans out to two middle stacks, each of which fans
    out to three worker stacks.
    """
    super().setUp()
    # 2 clients per worker stack * 3 worker stacks * 2 middle stacks
    self._num_clients = 2 * 3 * 2
    children = [
        _create_middle_stack([_create_worker_stack() for _w in range(3)])
        for _m in range(2)
    ]
    set_default_executor.set_default_executor(_create_middle_stack(children))
示例#5
0
    def test_with_num_clients_larger_than_fanout(self):
        """Sums ten client values on a local executor with `max_fanout=3`.

        The default executor is reset in a `finally` clause (no-argument
        `set_default_executor()`, as in the sibling tests). This keeps the
        local executor from leaking into later tests, including when the
        assertion fails.
        """
        set_default_executor.set_default_executor(
            executor_stacks.create_local_executor(max_fanout=3))
        try:

            @computations.federated_computation(
                type_factory.at_clients(tf.int32))
            def foo(x):
                return intrinsics.federated_sum(x)

            self.assertEqual(foo([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 55)
        finally:
            set_default_executor.set_default_executor()
示例#6
0
    def test_with_no_args(self):
        """Runs a zero-argument TF computation on a local executor.

        The default executor is reset in a `finally` clause, so a failing
        assertion does not leave the local executor installed for later tests.
        """
        set_default_executor.set_default_executor(
            executor_stacks.create_local_executor())
        try:

            @computations.tf_computation
            def foo():
                return tf.constant(10)

            self.assertEqual(foo(), 10)
        finally:
            set_default_executor.set_default_executor()
  def test_basic_functionality(self):
    """Reduces a repeated dataset on an `EagerExecutor`, then resets it.

    The reset happens in a `finally` clause so the eager executor is not left
    installed if the reduction assertion fails. The final check asserts that
    the context stack's current context is an ExecutionContext after reset.
    """

    @computations.tf_computation(computation_types.SequenceType(tf.int32))
    def comp(ds):
      return ds.take(5).reduce(np.int32(0), lambda x, y: x + y)

    set_default_executor.set_default_executor(eager_executor.EagerExecutor())
    try:
      # Infinite stream of the constant 5; `comp` sums the first five -> 25.
      ds = tf.data.Dataset.range(1).map(lambda x: tf.constant(5)).repeat()
      v = comp(ds)
      self.assertEqual(v, 25)
    finally:
      set_default_executor.set_default_executor()
    self.assertIn('ExecutionContext',
                  str(type(context_stack_impl.context_stack.current).__name__))
    def setUp(self):
        """Installs a factory that builds a nested composite executor stack.

        Every call to the factory builds a root middle stack over two middle
        stacks, each of which sits over three worker stacks.
        """
        super().setUp()
        # 2 clients per worker stack * 3 worker stacks * 2 middle stacks
        self._num_clients = 12

        def _stack_fn(x):
            del x  # Unused
            middles = [
                _create_middle_stack(
                    [_create_worker_stack() for _w in range(3)])
                for _m in range(2)
            ]
            return _create_middle_stack(middles)

        factory = executor_factory.ExecutorFactoryImpl(_stack_fn)
        set_default_executor.set_default_executor(factory)
示例#9
0
    def test_end_to_end(self):
        """Invokes a TF computation twice through a `ConcurrentExecutor`.

        The default executor is reset in a `finally` clause, so a failing
        assertion does not leave the concurrent executor installed.
        """
        @computations.tf_computation(tf.int32)
        def add_one(x):
            return tf.add(x, 1)

        ex = concurrent_executor.ConcurrentExecutor(
            eager_executor.EagerExecutor())

        set_default_executor.set_default_executor(ex)
        try:
            self.assertEqual(add_one(7), 8)

            # After this invocation, the ConcurrentExecutor has been closed, and
            # needs to be re-initialized; the second call checks that a later
            # invocation still returns the right result.
            self.assertEqual(add_one(8), 9)
        finally:
            set_default_executor.set_default_executor()
示例#10
0
    def test_executor_factory_raises_with_wrong_cardinalities(self):
        """A factory built for five clients rejects a request for one client."""
        factory = executor_stacks.local_executor_factory(num_clients=5)
        mismatched = {
            placement_literals.SERVER: 1,
            None: 1,
            placement_literals.CLIENTS: 1,
        }
        expected_message = 'construct an executor with 1 clients'
        with self.assertRaisesRegex(ValueError, expected_message):
            factory.create_executor(mismatched)

        # NOTE(review): the code below does not use `factory`. It runs a TF
        # computation on whichever default executor is currently installed
        # and then resets it, which looks like it belongs to a separate test.
        # Confirm.
        @computations.tf_computation
        def foo():
            return tf.constant(10)

        self.assertEqual(foo(), 10)

        set_default_executor.set_default_executor()
示例#11
0
  def test_with_temperature_sensor_example(self):
    """Temperature-sensor federated computation on two local executors.

    Each client holds a dataset of float readings. `comp` maps `count_over`
    (readings above the broadcast threshold) and `count_total` over the
    clients, then combines them with `federated_mean`, passing the per-client
    `count_total` as the second argument. The check runs once with
    `create_local_executor(3)` and once with `create_local_executor()`. The
    default executor is reset in a `finally` clause even if an assertion fails.
    """

    @computations.tf_computation(
        computation_types.SequenceType(tf.float32), tf.float32)
    def count_over(ds, t):
      return ds.reduce(
          np.float32(0), lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32))

    @computations.tf_computation(computation_types.SequenceType(tf.float32))
    def count_total(ds):
      return ds.reduce(np.float32(0.0), lambda n, _: n + 1.0)

    @computations.federated_computation(
        type_factory.at_clients(computation_types.SequenceType(tf.float32)),
        type_factory.at_server(tf.float32))
    def comp(temperatures, threshold):
      return intrinsics.federated_mean(
          intrinsics.federated_map(
              count_over,
              intrinsics.federated_zip(
                  [temperatures,
                   intrinsics.federated_broadcast(threshold)])),
          intrinsics.federated_map(count_total, temperatures))

    to_float = lambda x: tf.cast(x, tf.float32)

    def _check_with_executor(executor):
      # Installs `executor`, then evaluates `comp` for three clients holding
      # 10, 20 and 30 readings with a threshold of 15.0.
      set_default_executor.set_default_executor(executor)
      temperatures = [
          tf.data.Dataset.range(size).map(to_float) for size in (10, 20, 30)
      ]
      threshold = 15.0
      result = comp(temperatures, threshold)
      self.assertAlmostEqual(result, 8.333, places=3)

    try:
      _check_with_executor(executor_stacks.create_local_executor(3))
      _check_with_executor(executor_stacks.create_local_executor())
    finally:
      set_default_executor.set_default_executor()
示例#12
0
 def tearDown(self):
     """Clears the default executor, then runs the base-class tearDown.

     The base-class tearDown runs in a `finally` clause, so its cleanup still
     happens if resetting the default executor raises.
     """
     try:
         set_default_executor.set_default_executor(None)
     finally:
         super().tearDown()
 def tearDown(self):
     """Clears the default executor, then runs the base-class tearDown.

     The base-class tearDown runs in a `finally` clause, so its cleanup still
     happens if resetting the default executor raises.
     """
     try:
         set_default_executor.set_default_executor(None)
     finally:
         super(CompositeExecutorTest, self).tearDown()
示例#14
0
 def wrapped_fn(self, executor):
     """Installs `executor` as the default, then runs the enclosing `fn`.

     `fn` is called with `self` and its return value is discarded. `executor`
     remains the default executor afterwards.
     """
     install_default = set_default_executor.set_default_executor
     install_default(executor)
     fn(self)
示例#15
0
        state, mean = federated_aggregate_test([1.0, 2.0, 3.0])
        self.assertAlmostEqual(mean, 2.0)  # (1 + 2 + 3) / (1 + 1 + 1)
        self.assertDictEqual(state._asdict(), {'call_count': 1})

    def test_execute_with_explicit_weights(self):
        """Checks a weighted stateful aggregation and its resulting state.

        Values [1, 2, 3] with weights [4, 1, 1] should aggregate to 1.5, and
        the returned state should report a single call.
        """
        aggregate_fn = computation_utils.StatefulAggregateFn(
            initialize_fn=agg_initialize_fn, next_fn=agg_next_fn)

        @computations.federated_computation(
            computation_types.FederatedType(tf.float32, placements.CLIENTS),
            computation_types.FederatedType(tf.float32, placements.CLIENTS))
        def federated_aggregate_test(args, weights):
            # Place the aggregator's initial state on the server, then apply it.
            initial_state = aggregate_fn.initialize()
            server_state = intrinsics.federated_value(initial_state,
                                                      placements.SERVER)
            return aggregate_fn(server_state, args, weights)

        client_values = [1.0, 2.0, 3.0]
        client_weights = [4.0, 1.0, 1.0]
        state, mean = federated_aggregate_test(client_values, client_weights)
        self.assertAlmostEqual(mean, 1.5)  # (1*4 + 2*1 + 3*1) / (4 + 1 + 1)
        self.assertDictEqual(state._asdict(), {'call_count': 1})


if __name__ == '__main__':
    tf.compat.v1.enable_v2_behavior()
    # NOTE: num_clients is set explicitly so that the broadcast behavior is
    # tested correctly. Without it, TFF infers zero clients, which is an
    # error.
    _default_factory = executor_stacks.local_executor_factory(num_clients=3)
    set_default_executor.set_default_executor(_default_factory)
    test.main()
示例#16
0
def test_runs_tf(test_obj, executor):
    """Tests `executor` can run a minimal TF computation.

    Type-checks `executor` against `executor_base.Executor` and installs a
    factory that always returns it as the default. It then asserts via
    `test_obj` that `_dummy_tf_computation()` evaluates to 10. The factory
    stays installed afterwards.
    """
    py_typecheck.check_type(executor, executor_base.Executor)
    factory = executor_factory.ExecutorFactoryImpl(lambda _: executor)
    set_default_executor.set_default_executor(factory)
    result = _dummy_tf_computation()
    test_obj.assertEqual(result, 10)
示例#17
0
def test_runs_tf(test_obj, executor):
  """Tests `executor` can run a minimal TF computation.

  Installs `executor` as the default (it stays installed afterwards) and
  asserts via `test_obj` that `_dummy_tf_computation()` evaluates to 10.
  """
  set_default_executor.set_default_executor(executor)
  observed = _dummy_tf_computation()
  test_obj.assertEqual(observed, 10)
示例#18
0
            _ = next(iter(x[1]))

    def test_fetch_value_with_empty_structured_dataset_and_tensors(self):
        """Fetches a tensor together with an empty, batched structured dataset.

        The returned tensor should equal [0., 0.]. The dataset should keep its
        OrderedDict element structure (two int32 vectors of unknown length)
        and yield no elements.
        """
        def return_dataset():
            ds1 = tf.data.Dataset.from_tensor_slices(
                collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])]))
            return [tf.constant([0., 0.]), ds1.batch(5).take(0)]

        stack = context_stack_impl.context_stack
        serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            return_dataset, None, stack)[0]
        executable_return_dataset = computation_impl.ComputationImpl(
            serialized, stack)

        fetched = executable_return_dataset()
        tensor_part = fetched[0]
        dataset_part = fetched[1]
        self.assertAllEqual(tensor_part, [0., 0.])
        expected_structure = collections.OrderedDict([
            ('a', tf.TensorSpec(shape=(None, ), dtype=tf.int32)),
            ('b', tf.TensorSpec(shape=(None, ), dtype=tf.int32)),
        ])
        self.assertEqual(
            tf.data.experimental.get_structure(dataset_part),
            expected_structure)
        with self.assertRaises(StopIteration):
            _ = next(iter(dataset_part))


if __name__ == '__main__':
    # Run the tests with a local executor installed as the default.
    _local_executor = executor_stacks.create_local_executor()
    set_default_executor.set_default_executor(_local_executor)
    test.main()
        state, mean = federated_aggregate_test([1.0, 2.0, 3.0])
        self.assertAlmostEqual(mean, 2.0)  # (1 + 2 + 3) / (1 + 1 + 1)
        self.assertDictEqual(state._asdict(), {'call_count': 1})

    def test_execute_with_explicit_weights(self):
        """Weighted stateful aggregation of three client values.

        With values [1, 2, 3] and weights [4, 1, 1], the mean is
        (1*4 + 2*1 + 3*1) / (4 + 1 + 1) == 1.5 and the state's call_count is 1.
        """
        aggregate_fn = computation_utils.StatefulAggregateFn(
            initialize_fn=agg_initialize_fn, next_fn=agg_next_fn)

        @computations.federated_computation(
            computation_types.FederatedType(tf.float32, placements.CLIENTS),
            computation_types.FederatedType(tf.float32, placements.CLIENTS))
        def federated_aggregate_test(args, weights):
            init_value = aggregate_fn.initialize()
            state_at_server = intrinsics.federated_value(
                init_value, placements.SERVER)
            return aggregate_fn(state_at_server, args, weights)

        outputs = federated_aggregate_test([1.0, 2.0, 3.0], [4.0, 1.0, 1.0])
        state, mean = outputs
        expected_state = {'call_count': 1}
        self.assertAlmostEqual(mean, 1.5)
        self.assertDictEqual(state._asdict(), expected_state)


if __name__ == '__main__':
    tf.compat.v1.enable_v2_behavior()
    # NOTE: num_clients is passed explicitly so the broadcast behavior is
    # tested correctly. Otherwise TFF infers zero clients, which is an error.
    _three_client_executor = executor_stacks.create_local_executor(
        num_clients=3)
    set_default_executor.set_default_executor(_three_client_executor)
    absltest.main()
示例#20
0
 def test_with_no_args(self):
     set_default_executor.set_default_executor(
         executor_stacks.local_executor_factory())