예제 #1
0
class WorkerFailureTest(parameterized.TestCase):
    """Exercises a remote execution context across simulated worker restarts."""

    @parameterized.named_parameters(
        (
            'native_remote_request_reply',
            remote_runtime_test_utils.create_localhost_remote_context(
                _WORKER_PORTS),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
        ),
        (
            'native_remote_streaming',
            remote_runtime_test_utils.create_localhost_remote_context(
                _WORKER_PORTS, rpc_mode='STREAMING'),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
        ),
        (
            'native_remote_intermediate_aggregator',
            remote_runtime_test_utils.create_localhost_remote_context(
                _AGGREGATOR_PORTS),
            remote_runtime_test_utils.create_localhost_aggregator_contexts(
                _WORKER_PORTS, _AGGREGATOR_PORTS),
            remote_runtime_test_utils.create_localhost_aggregator_contexts(
                _WORKER_PORTS, _AGGREGATOR_PORTS),
        ),
    )
    def test_computations_run_with_worker_restarts(self, context,
                                                   first_contexts,
                                                   second_contexts):
        """Runs the same federated map before and after the servers restart.

        Args:
            context: Execution context installed for the whole test body.
            first_contexts: Server contexts entered for the first invocation.
            second_contexts: Server contexts entered for the second
                invocation, once the first set has been exited.
        """

        @tff.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one(federated_arg):
            return tff.federated_map(add_one, federated_arg)

        with tff.framework.get_context_stack().install(context):
            # Closing and re-entering the server contexts serves to simulate
            # failures and restarts at the workers. Restarts leave the workers
            # in a state that needs initialization again; entering the second
            # set of contexts ensures that the servers need to be
            # reinitialized by the controller.
            for generation in (first_contexts, second_contexts):
                with contextlib.ExitStack() as servers:
                    for server in generation:
                        servers.enter_context(server)
                    self.assertEqual(map_add_one([0, 1]), [1, 2])
예제 #2
0
def _get_all_contexts():
  """Builds the list of named execution contexts.

  Every entry is a tuple that starts with a name followed by an execution
  context. The two remote entries carry a third element: the value returned by
  the matching `remote_runtime_test_utils` worker/aggregator context factory.

  Returns:
    A list of tuples, in the order the contexts are created below.
  """
  # pyformat: disable
  native = tff.backends.native
  test_utils = remote_runtime_test_utils

  # Entries are appended one at a time so the factories run in a fixed,
  # top-to-bottom order.
  contexts = [('native_local', native.create_local_execution_context())]
  contexts.append((
      'native_remote',
      test_utils.create_localhost_remote_context(_WORKER_PORTS),
      test_utils.create_localhost_worker_contexts(_WORKER_PORTS),
  ))
  contexts.append((
      'native_remote_intermediate_aggregator',
      test_utils.create_localhost_remote_context(_AGGREGATOR_PORTS),
      test_utils.create_localhost_aggregator_contexts(
          _WORKER_PORTS, _AGGREGATOR_PORTS),
  ))
  contexts.append(
      ('native_sizing', native.create_sizing_execution_context()))
  contexts.append((
      'native_thread_debug',
      native.create_thread_debugging_execution_context(),
  ))
  contexts.append(
      ('reference', tff.backends.reference.create_reference_context()))
  return contexts
예제 #3
0
  def test_computations_run_with_worker_restarts(self):
    """Runs one federated map twice, restarting the workers in between."""
    utils = remote_runtime_test_utils
    context = utils.create_localhost_remote_context(_WORKER_PORTS)
    first_contexts = utils.create_inprocess_worker_contexts(_WORKER_PORTS)
    second_contexts = utils.create_inprocess_worker_contexts(_WORKER_PORTS)

    @tff.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
      return tff.federated_map(add_one, federated_arg)

    def check_map_with(server_contexts):
      # Enters every server context, runs the computation, then exits them all.
      with contextlib.ExitStack() as servers:
        for server in server_contexts:
          servers.enter_context(server)
        self.assertEqual(map_add_one([0, 1]), [1, 2])

    with tff.framework.get_context_stack().install(context):
      check_map_with(first_contexts)

      # Closing and re-entering the server contexts serves to simulate failures
      # and restarts at the workers. Restarts leave the workers in a state that
      # needs initialization again; entering the second context ensures that the
      # servers need to be reinitialized by the controller.
      check_map_with(second_contexts)
예제 #4
0
def _get_all_contexts():
  """Builds the list of named execution contexts.

  Every entry is a tuple that starts with a name followed by an execution
  context. The three remote entries carry a third element: the value returned
  by the matching `remote_runtime_test_utils` worker/aggregator context factory.

  Returns:
    A list of tuples, in the order the contexts are created below.
  """
  # pyformat: disable
  native = tff.backends.native
  test_utils = remote_runtime_test_utils

  # Entries are appended one at a time so the factories run in a fixed,
  # top-to-bottom order.
  contexts = [('native_local', native.create_local_execution_context())]
  contexts.append(
      ('native_local_caching', _create_native_local_caching_context()))
  contexts.append((
      'native_remote_request_reply',
      test_utils.create_localhost_remote_context(
          _WORKER_PORTS, rpc_mode='REQUEST_REPLY'),
      test_utils.create_localhost_worker_contexts(_WORKER_PORTS),
  ))
  contexts.append((
      'native_remote_streaming',
      test_utils.create_localhost_remote_context(
          _WORKER_PORTS, rpc_mode='STREAMING'),
      test_utils.create_localhost_worker_contexts(_WORKER_PORTS),
  ))
  contexts.append((
      'native_remote_intermediate_aggregator',
      test_utils.create_localhost_remote_context(_AGGREGATOR_PORTS),
      test_utils.create_localhost_aggregator_contexts(
          _WORKER_PORTS, _AGGREGATOR_PORTS),
  ))
  contexts.append(
      ('native_sizing', native.create_sizing_execution_context()))
  contexts.append((
      'native_thread_debug',
      native.create_thread_debugging_execution_context(),
  ))
  contexts.append(
      ('reference', tff.backends.reference.create_reference_context()))
  contexts.append(
      ('test', tff.backends.test.create_test_execution_context()))
  return contexts
예제 #5
0
  def test_worker_going_down_with_fixed_clients_per_round(self):
    """Checks a federated sum is unchanged after one of two workers exits.

    The execution context is created with `default_num_clients=10`. The test
    expects broadcasting 1 and summing it to give 10, both while two worker
    contexts are entered and after the second one has been exited.
    """
    utils = remote_runtime_test_utils
    execution_context = utils.create_localhost_remote_context(
        _WORKER_PORTS, default_num_clients=10)
    workers = utils.create_inprocess_worker_contexts(_WORKER_PORTS)

    @tff.federated_computation(tff.type_at_server(tf.int32))
    def sum_arg(x):
      return tff.federated_sum(tff.federated_broadcast(x))

    expected_total = 10
    with tff.framework.get_context_stack().install(execution_context):
      with workers[0]:
        with workers[1]:
          # With both workers live, we should get 10 back.
          self.assertEqual(sum_arg(1), expected_total)
        # Leaving the inner context kills the second worker, but should leave
        # the result untouched.
        self.assertEqual(sum_arg(1), expected_total)
예제 #6
0
  def test_computations_run_with_worker_restarts_and_aggregation(self):
    """Restarts the workers while the aggregator contexts stay entered."""
    utils = remote_runtime_test_utils
    context = utils.create_localhost_remote_context(_AGGREGATOR_PORTS)
    # TODO(b/180524229): Swap for inprocess aggregator when mutex
    # corruption on shutdown is understood.
    aggregation_contexts = (
        utils.create_standalone_subprocess_aggregator_contexts(
            _WORKER_PORTS, _AGGREGATOR_PORTS))
    first_worker_contexts = utils.create_inprocess_worker_contexts(
        _WORKER_PORTS)
    second_worker_contexts = utils.create_inprocess_worker_contexts(
        _WORKER_PORTS)

    @tff.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
      return tff.federated_map(add_one, federated_arg)

    def check_map_with(workers):
      # Enters every worker context, runs the computation, then exits them all.
      with contextlib.ExitStack() as worker_stack:
        for worker in workers:
          worker_stack.enter_context(worker)
        self.assertEqual(map_add_one([0, 1]), [1, 2])

    with tff.framework.get_context_stack().install(context):
      with contextlib.ExitStack() as aggregation_stack:
        for aggregator in aggregation_contexts:
          aggregation_stack.enter_context(aggregator)

        check_map_with(first_worker_contexts)

        # Reinitializing the workers without leaving the aggregation context
        # simulates a worker failure, while the aggregator keeps running.
        check_map_with(second_worker_contexts)
예제 #7
0
    def test_runs_computation_streaming_with_intermediate_agg(self):
        """Maps add-one over [0, 1] and sums it, using `rpc_mode='STREAMING'`.

        Both the execution context and the aggregator contexts are created
        with the streaming RPC mode; the expected sum is 3.
        """

        @tff.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one_and_sum(federated_arg):
            return tff.federated_sum(tff.federated_map(add_one, federated_arg))

        utils = remote_runtime_test_utils
        execution_context = utils.create_localhost_remote_context(
            _AGGREGATOR_PORTS, rpc_mode='STREAMING')
        worker_contexts = utils.create_localhost_aggregator_contexts(
            _WORKER_PORTS, _AGGREGATOR_PORTS, rpc_mode='STREAMING')

        with tff.framework.get_context_stack().install(execution_context):
            # Keep every server context entered while the computation runs.
            with contextlib.ExitStack() as servers:
                for server in worker_contexts:
                    servers.enter_context(server)
                self.assertEqual(map_add_one_and_sum([0, 1]), 3)
예제 #8
0
  def test_computations_run_with_partially_available_workers(self):
    """Runs a federated map with a worker started for only the first port.

    The remote context is configured with all of `_WORKER_PORTS`, while worker
    contexts are created solely for `_WORKER_PORTS[0]`.
    """
    utils = remote_runtime_test_utils
    tff_context = utils.create_localhost_remote_context(_WORKER_PORTS)
    server_contexts = utils.create_inprocess_worker_contexts(
        [_WORKER_PORTS[0]])

    @tff.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
      return tff.federated_map(add_one, federated_arg)

    with tff.framework.get_context_stack().install(tff_context):
      with contextlib.ExitStack() as servers:
        for server in server_contexts:
          servers.enter_context(server)
        self.assertEqual(map_add_one([0, 1]), [1, 2])
예제 #9
0
def _get_all_contexts():
    """Lists the named context factories.

    Each entry is a tuple whose first element is a name and whose second
    element is a zero-argument callable that creates an execution context.
    The two remote entries add a third zero-argument callable that creates
    the matching in-process worker/aggregator contexts. Nothing is
    constructed here; every factory is wrapped in a lambda.

    Returns:
        A list of `(name, context_fn)` or `(name, context_fn, servers_fn)`
        tuples.
    """
    # pylint: disable=unnecessary-lambda
    # pyformat: disable
    native_local_python = (
        'native_local_python',
        lambda: tff.backends.native.create_local_python_execution_context(),
    )
    native_mergeable = (
        'native_mergeable',
        lambda: _create_local_mergeable_comp_context(),
    )
    native_remote = (
        'native_remote',
        lambda: remote_runtime_test_utils.create_localhost_remote_context(
            WORKER_PORTS),
        lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(
            WORKER_PORTS),
    )
    native_remote_intermediate_aggregator = (
        'native_remote_intermediate_aggregator',
        lambda: remote_runtime_test_utils.create_localhost_remote_context(
            AGGREGATOR_PORTS),
        lambda: remote_runtime_test_utils.create_inprocess_aggregator_contexts(
            WORKER_PORTS, AGGREGATOR_PORTS),
    )
    native_sizing = (
        'native_sizing',
        lambda: tff.backends.native.create_sizing_execution_context(),
    )
    native_thread_debug = (
        'native_thread_debug',
        lambda: tff.backends.native.create_thread_debugging_execution_context(),
    )
    test_python = (
        'test_python',
        lambda: tff.backends.test.create_test_python_execution_context(),
    )
    return [
        native_local_python,
        native_mergeable,
        native_remote,
        native_remote_intermediate_aggregator,
        native_sizing,
        native_thread_debug,
        test_python,
    ]
예제 #10
0
class TensorFlowComputationTest(parameterized.TestCase):
  """Tests of `tff.tf_computation`, each wrapped by `with_contexts`.

  `with_contexts` is defined outside this class. It is applied either bare or
  with an explicit list of named contexts.
  """

  @with_contexts
  def test_returns_constant(self):
    """A no-argument computation returning 10 evaluates to 10."""

    @tff.tf_computation
    def foo():
      return 10

    self.assertEqual(foo(), 10)

  @with_contexts
  def test_returns_empty_tuple(self):
    """A no-argument computation returning `()` evaluates to `()`."""

    @tff.tf_computation
    def foo():
      return ()

    self.assertEqual(foo(), ())

  @with_contexts
  def test_returns_variable(self):
    """A computation returning a `tf.Variable` of 10 evaluates to 10."""

    @tff.tf_computation
    def foo():
      return tf.Variable(10, name='var')

    self.assertEqual(foo(), 10)

  # pyformat: disable
  @with_contexts(
      (
          'native_local',
          tff.backends.native.create_local_execution_context(),
      ),
      (
          'native_local_caching',
          _create_native_local_caching_context(),
      ),
      (
          'native_remote',
          remote_runtime_test_utils.create_localhost_remote_context(
              _WORKER_PORTS),
          remote_runtime_test_utils.create_localhost_worker_contexts(
              _WORKER_PORTS),
      ),
      (
          'native_remote_intermediate_aggregator',
          remote_runtime_test_utils.create_localhost_remote_context(
              _AGGREGATOR_PORTS),
          remote_runtime_test_utils.create_localhost_aggregator_contexts(
              _WORKER_PORTS, _AGGREGATOR_PORTS),
      ),
      (
          'native_sizing',
          tff.backends.native.create_sizing_execution_context(),
      ),
      (
          'native_thread_debug',
          tff.backends.native.create_thread_debugging_execution_context(),
      ),
  )
  # pyformat: enable
  def test_takes_infinite_dataset(self):
    """Reduces the first ten elements of a repeated dataset argument."""

    @tff.tf_computation
    def foo(ds):
      return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

    repeated_range = tf.data.Dataset.range(10).repeat()
    computed_sum = foo(repeated_range)

    # The same reduction applied directly to the dataset is the reference.
    reference_sum = repeated_range.take(10).reduce(
        np.int64(0), lambda x, y: x + y)
    self.assertEqual(computed_sum, reference_sum)

  # pyformat: disable
  @with_contexts(
      (
          'native_local',
          tff.backends.native.create_local_execution_context(),
      ),
      (
          'native_local_caching',
          _create_native_local_caching_context(),
      ),
      (
          'native_remote',
          remote_runtime_test_utils.create_localhost_remote_context(
              _WORKER_PORTS),
          remote_runtime_test_utils.create_localhost_worker_contexts(
              _WORKER_PORTS),
      ),
      (
          'native_remote_intermediate_aggregator',
          remote_runtime_test_utils.create_localhost_remote_context(
              _AGGREGATOR_PORTS),
          remote_runtime_test_utils.create_localhost_aggregator_contexts(
              _WORKER_PORTS, _AGGREGATOR_PORTS),
      ),
      (
          'native_sizing',
          tff.backends.native.create_sizing_execution_context(),
      ),
      (
          'native_thread_debug',
          tff.backends.native.create_thread_debugging_execution_context(),
      ),
  )
  # pyformat: enable
  def test_returns_infinite_dataset(self):
    """Compares a returned repeated dataset with one built locally."""

    @tff.tf_computation
    def foo():
      return tf.data.Dataset.range(10).repeat()

    returned_dataset = foo()

    local_dataset = tf.data.Dataset.range(10).repeat()
    # Both datasets repeat forever, so only the sum of the first 100 elements
    # of each is compared.
    self.assertEqual(
        returned_dataset.take(100).reduce(np.int64(0), lambda x, y: x + y),
        local_dataset.take(100).reduce(np.int64(0), lambda x, y: x + y))

  @with_contexts
  def test_returns_result_with_typed_fn(self):
    """A computation typed as (int32, int32) adds 1 and 2 to give 3."""

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    self.assertEqual(foo(1, 2), 3)

  @with_contexts
  def test_raises_type_error_with_typed_fn(self):
    """Float arguments to an (int32, int32) computation raise `TypeError`."""

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    with self.assertRaises(TypeError):
      foo(1.0, 2.0)

  @with_contexts
  def test_returns_result_with_polymorphic_fn(self):
    """An untyped computation accepts both int and float arguments."""

    @tff.tf_computation
    def foo(x, y):
      return x + y

    int_sum = foo(1, 2)
    self.assertEqual(int_sum, 3)
    float_sum = foo(1.0, 2.0)
    self.assertEqual(float_sum, 3.0)
예제 #11
0
        worker_contexts = remote_runtime_test_utils.create_localhost_aggregator_contexts(
            _WORKER_PORTS, _AGGREGATOR_PORTS, rpc_mode='STREAMING')

        context_stack = tff.framework.get_context_stack()
        with context_stack.install(execution_context):

            with contextlib.ExitStack() as stack:
                for server_context in worker_contexts:
                    stack.enter_context(server_context)
                result = map_add_one_and_sum([0, 1])
                self.assertEqual(result, 3)


@parameterized.named_parameters((
    'native_remote_request_reply',
    remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
    remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS),
), (
    'native_remote_streaming',
    remote_runtime_test_utils.create_localhost_remote_context(
        _WORKER_PORTS, rpc_mode='STREAMING'),
    remote_runtime_test_utils.create_localhost_worker_contexts(_WORKER_PORTS),
), (
    'native_remote_intermediate_aggregator',
    remote_runtime_test_utils.create_localhost_remote_context(
        _AGGREGATOR_PORTS),
    remote_runtime_test_utils.create_localhost_aggregator_contexts(
        _WORKER_PORTS, _AGGREGATOR_PORTS),
))
class RemoteRuntimeConfigurationChangeTest(absltest.TestCase):
    def test_computations_run_with_changing_clients(self, context,
예제 #12
0
    context_stack = tff.framework.get_context_stack()
    with context_stack.install(tff_context):

      with worker_contexts[0]:
        with worker_contexts[1]:
          # With both workers live, we should get 10 back.
          self.assertEqual(sum_arg(1), 10)
        # Leaving the inner context kills the second worker, but should leave
        # the result untouched.
        self.assertEqual(sum_arg(1), 10)


@parameterized.named_parameters((
    'native_remote',
    remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
    remote_runtime_test_utils.create_inprocess_worker_contexts(_WORKER_PORTS),
), (
    'native_remote_intermediate_aggregator',
    remote_runtime_test_utils.create_localhost_remote_context(
        _AGGREGATOR_PORTS),
    remote_runtime_test_utils.create_inprocess_aggregator_contexts(
        _WORKER_PORTS, _AGGREGATOR_PORTS),
))
class RemoteRuntimeConfigurationChangeTest(parameterized.TestCase):

  def test_computations_run_with_changing_clients(self, context,
                                                  server_contexts):

    @tff.tf_computation(tf.int32)
    @tf.function
예제 #13
0
class RemoteRuntimeIntegrationTest(parameterized.TestCase):
    """Integration tests that run federated maps against localhost servers."""

    @parameterized.named_parameters(
        (
            'native_remote',
            remote_runtime_test_utils.create_localhost_remote_context(
                _WORKER_PORTS),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
        ),
        (
            'native_remote_intermediate_aggregator',
            remote_runtime_test_utils.create_localhost_remote_context(
                _AGGREGATOR_PORTS),
            remote_runtime_test_utils.create_localhost_aggregator_contexts(
                _WORKER_PORTS, _AGGREGATOR_PORTS),
            remote_runtime_test_utils.create_localhost_aggregator_contexts(
                _WORKER_PORTS, _AGGREGATOR_PORTS),
        ),
    )
    def test_computations_run_with_worker_restarts(self, context,
                                                   first_contexts,
                                                   second_contexts):
        """Runs the same federated map before and after the servers restart.

        Args:
            context: Execution context installed for the whole test body.
            first_contexts: Server contexts entered for the first invocation.
            second_contexts: Server contexts entered for the second
                invocation, once the first set has been exited.
        """

        @tff.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one(federated_arg):
            return tff.federated_map(add_one, federated_arg)

        with tff.framework.get_context_stack().install(context):
            # Closing and re-entering the server contexts serves to simulate
            # failures and restarts at the workers. Restarts leave the workers
            # in a state that needs initialization again; entering the second
            # set of contexts ensures that the servers need to be
            # reinitialized by the controller.
            for generation in (first_contexts, second_contexts):
                with contextlib.ExitStack() as servers:
                    for server in generation:
                        servers.enter_context(server)
                    self.assertEqual(map_add_one([0, 1]), [1, 2])

    @parameterized.named_parameters(
        (
            'native_remote',
            remote_runtime_test_utils.create_localhost_remote_context(
                _WORKER_PORTS),
            remote_runtime_test_utils.create_localhost_worker_contexts(
                _WORKER_PORTS),
            list(range(len(_WORKER_PORTS) - 1)),
        ),
        (
            'native_remote_intermediate_aggregator',
            remote_runtime_test_utils.create_localhost_remote_context(
                _AGGREGATOR_PORTS),
            remote_runtime_test_utils.create_localhost_aggregator_contexts(
                _WORKER_PORTS, _AGGREGATOR_PORTS),
            list(range(len(_AGGREGATOR_PORTS) - 1)),
        ),
    )
    def test_computations_run_with_fewer_clients_than_remote_connections(
            self, context, serving_contexts, arg):
        """Maps add-one over an argument one element shorter than the ports.

        Args:
            context: Execution context installed for the test body.
            serving_contexts: Server contexts entered while the map runs.
            arg: List of ints passed as the clients-placed argument.
        """
        # The test is skipped unconditionally; `skipTest` raises, so nothing
        # below this line runs until the skip is removed.
        self.skipTest('b/170315887')

        @tff.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @tff.federated_computation(tff.type_at_clients(tf.int32))
        def map_add_one(federated_arg):
            return tff.federated_map(add_one, federated_arg)

        with tff.framework.get_context_stack().install(context):
            with contextlib.ExitStack() as servers:
                for server in serving_contexts:
                    servers.enter_context(server)
                incremented = map_add_one(arg)
                self.assertEqual(incremented, [value + 1 for value in arg])
예제 #14
0
class TensorFlowComputationTest(tf.test.TestCase, parameterized.TestCase):
    """Tests of `tff.tf_computation`, each wrapped by `with_contexts`.

    `test_contexts.with_contexts` is defined outside this class. It is
    applied either bare or with an explicit list of named context factories.
    """

    @test_contexts.with_contexts
    def test_create_call_take_two_from_stateful_dataset(self):
        """Maps a string dataset through a vocabulary table and takes two."""
        vocab = ['a', 'b', 'c', 'd', 'e', 'f']

        @tff.tf_computation(tff.SequenceType(tf.string))
        def take_two(ds):
            vocab_ids = tf.range(len(vocab), dtype=tf.int64)
            initializer = tf.lookup.KeyValueTensorInitializer(vocab, vocab_ids)
            table = tf.lookup.StaticVocabularyTable(
                initializer, num_oov_buckets=1)
            ds = ds.map(table.lookup)
            return ds.take(2)

        vocab_dataset = tf.data.Dataset.from_tensor_slices(vocab)
        first_two = take_two(vocab_dataset)
        self.assertCountEqual(
            [element.numpy() for element in first_two], [0, 1])

    @test_contexts.with_contexts
    def test_twice_used_variable_keeps_separate_state(self):
        """Every call of a variable-incrementing computation returns 1."""

        def count_one_body():
            counter = tf.Variable(initial_value=0, name='var_of_interest')
            # Read the value only after the increment has been applied.
            with tf.control_dependencies([counter.assign_add(1)]):
                return counter.read_value()

        count_one_1 = tff.tf_computation(count_one_body)
        count_one_2 = tff.tf_computation(count_one_body)

        @tff.tf_computation
        def count_one_twice():
            return count_one_1(), count_one_1(), count_one_2()

        self.assertEqual((1, 1, 1), count_one_twice())

    @test_contexts.with_contexts
    def test_dynamic_lookup_table(self):
        """Looks keys up in a hash table built from a runtime argument."""

        @tff.tf_computation(tff.TensorType(shape=[None], dtype=tf.string),
                            tff.TensorType(shape=[None], dtype=tf.string))
        def comp(table_args, to_lookup):
            positions = tf.range(tf.shape(table_args)[0])
            key_value_init = tf.lookup.KeyValueTensorInitializer(
                table_args, positions)
            hash_table = tf.lookup.StaticHashTable(
                key_value_init, default_value=101)
            return hash_table.lookup(to_lookup)

        keys = tf.constant(['a', 'b', 'c'])
        queries = tf.constant(['a', 'z', 'c'])
        # 'z' is not among the keys, so the table's default of 101 is expected.
        self.assertAllEqual(comp(keys, queries), [0, 101, 2])

    @test_contexts.with_contexts
    def test_reinitialize_dynamic_lookup_table(self):
        """Calls one table-building computation with two different key sets."""

        @tff.tf_computation(tff.TensorType(shape=[None], dtype=tf.string),
                            tff.TensorType(shape=[], dtype=tf.string))
        def comp(table_args, to_lookup):
            positions = tf.range(tf.shape(table_args)[0])
            key_value_init = tf.lookup.KeyValueTensorInitializer(
                table_args, positions)
            hash_table = tf.lookup.StaticHashTable(
                key_value_init, default_value=101)
            return hash_table.lookup(to_lookup)

        index_of_a = comp(tf.constant(['a', 'b', 'c']), tf.constant('a'))
        index_of_d = comp(tf.constant(['a', 'b', 'c', 'd']),
                          tf.constant('d'))

        self.assertEqual(index_of_a, 0)
        self.assertEqual(index_of_d, 3)

    @test_contexts.with_contexts
    def test_returns_constant(self):
        """A no-argument computation returning 10 evaluates to 10."""

        @tff.tf_computation
        def foo():
            return 10

        self.assertEqual(foo(), 10)

    @test_contexts.with_contexts
    def test_returns_empty_tuple(self):
        """A no-argument computation returning `()` evaluates to `()`."""

        @tff.tf_computation
        def foo():
            return ()

        self.assertEqual(foo(), ())

    @test_contexts.with_contexts
    def test_returns_variable(self):
        """A computation returning a `tf.Variable` of 10 evaluates to 10."""

        @tff.tf_computation
        def foo():
            return tf.Variable(10, name='var')

        self.assertEqual(foo(), 10)

    # pyformat: disable
    @test_contexts.with_contexts(
        # pylint: disable=unnecessary-lambda
        (
            'native_local',
            lambda: tff.backends.native.create_local_python_execution_context(),
        ),
        (
            'native_remote',
            lambda: remote_runtime_test_utils.create_localhost_remote_context(
                test_contexts.WORKER_PORTS),
            lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(
                test_contexts.WORKER_PORTS),
        ),
        (
            'native_remote_intermediate_aggregator',
            lambda: remote_runtime_test_utils.create_localhost_remote_context(
                test_contexts.AGGREGATOR_PORTS),
            lambda: remote_runtime_test_utils.create_inprocess_aggregator_contexts(
                test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS),
        ),
        (
            'native_sizing',
            lambda: tff.backends.native.create_sizing_execution_context(),
        ),
        (
            'native_thread_debug',
            lambda: tff.backends.native.create_thread_debugging_execution_context(),
        ),
    )
    # pyformat: enable
    def test_takes_infinite_dataset(self):
        """Reduces the first ten elements of a repeated dataset argument."""

        @tff.tf_computation
        def foo(ds):
            return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

        repeated_range = tf.data.Dataset.range(10).repeat()
        computed_sum = foo(repeated_range)

        # The same reduction applied directly to the dataset is the reference.
        reference_sum = repeated_range.take(10).reduce(
            np.int64(0), lambda x, y: x + y)
        self.assertEqual(computed_sum, reference_sum)

    # pyformat: disable
    @test_contexts.with_contexts(
        # pylint: disable=unnecessary-lambda
        (
            'native_local',
            lambda: tff.backends.native.create_local_python_execution_context(),
        ),
        (
            'native_remote',
            lambda: remote_runtime_test_utils.create_localhost_remote_context(
                test_contexts.WORKER_PORTS),
            lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(
                test_contexts.WORKER_PORTS),
        ),
        (
            'native_remote_intermediate_aggregator',
            lambda: remote_runtime_test_utils.create_localhost_remote_context(
                test_contexts.AGGREGATOR_PORTS),
            lambda: remote_runtime_test_utils.create_inprocess_aggregator_contexts(
                test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS),
        ),
        (
            'native_sizing',
            lambda: tff.backends.native.create_sizing_execution_context(),
        ),
        (
            'native_thread_debug',
            lambda: tff.backends.native.create_thread_debugging_execution_context(),
        ),
    )
    # pyformat: enable
    def test_returns_infinite_dataset(self):
        """Compares a returned repeated dataset with one built locally."""

        @tff.tf_computation
        def foo():
            return tf.data.Dataset.range(10).repeat()

        returned_dataset = foo()

        local_dataset = tf.data.Dataset.range(10).repeat()
        # Both datasets repeat forever, so only the sum of the first 100
        # elements of each is compared.
        self.assertEqual(
            returned_dataset.take(100).reduce(
                np.int64(0), lambda x, y: x + y),
            local_dataset.take(100).reduce(
                np.int64(0), lambda x, y: x + y))

    @test_contexts.with_contexts
    def test_returns_result_with_typed_fn(self):
        """A computation typed as (int32, int32) adds 1 and 2 to give 3."""

        @tff.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        self.assertEqual(foo(1, 2), 3)

    @test_contexts.with_contexts
    def test_raises_type_error_with_typed_fn(self):
        """Float arguments to an (int32, int32) computation raise TypeError."""

        @tff.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        with self.assertRaises(TypeError):
            foo(1.0, 2.0)

    @test_contexts.with_contexts
    def test_returns_result_with_polymorphic_fn(self):
        """An untyped computation accepts both int and float arguments."""

        @tff.tf_computation
        def foo(x, y):
            return x + y

        int_sum = foo(1, 2)
        self.assertEqual(int_sum, 3)
        float_sum = foo(1.0, 2.0)
        self.assertEqual(float_sum, 3.0)