# Example #1
def _get_all_contexts():
  """Returns the full list of named execution contexts.

  Each entry is a tuple whose first element is the parameterization name and
  whose second element is the execution context. The two remote entries carry
  a third element: the result of
  `remote_runtime_test_utils.create_inprocess_worker_contexts` or
  `create_inprocess_aggregator_contexts`. NOTE(review): presumably these are
  the in-process servers the remote context talks to; confirm against
  `remote_runtime_test_utils`.

  Every context is constructed eagerly, in list order, on each call.
  """
  # pyformat: disable
  return [
      ('native_local', tff.backends.native.create_local_execution_context()),
      ('native_local_caching', _create_native_local_caching_context()),
      ('native_remote',
       remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
       remote_runtime_test_utils.create_inprocess_worker_contexts(_WORKER_PORTS)),
      ('native_remote_intermediate_aggregator',
       remote_runtime_test_utils.create_localhost_remote_context(_AGGREGATOR_PORTS),
       remote_runtime_test_utils.create_inprocess_aggregator_contexts(_WORKER_PORTS, _AGGREGATOR_PORTS)),
      ('native_sizing', tff.backends.native.create_sizing_execution_context()),
      ('native_thread_debug',
       tff.backends.native.create_thread_debugging_execution_context()),
      ('reference', tff.backends.reference.create_reference_context()),
      ('test', tff.backends.test.create_test_execution_context()),
  ]
  # pyformat: enable
# Example #2
          # With both workers live, we should get 10 back.
          self.assertEqual(sum_arg(1), 10)
        # Leaving the inner context kills the second worker, but should leave
        # the result untouched.
        self.assertEqual(sum_arg(1), 10)


@parameterized.named_parameters((
    'native_remote',
    remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS),
    remote_runtime_test_utils.create_inprocess_worker_contexts(_WORKER_PORTS),
), (
    'native_remote_intermediate_aggregator',
    remote_runtime_test_utils.create_localhost_remote_context(
        _AGGREGATOR_PORTS),
    remote_runtime_test_utils.create_inprocess_aggregator_contexts(
        _WORKER_PORTS, _AGGREGATOR_PORTS),
))
class RemoteRuntimeConfigurationChangeTest(parameterized.TestCase):

  def test_computations_run_with_changing_clients(self, context,
                                                  server_contexts):

    @tff.tf_computation(tf.int32)
    @tf.function
    def add_one(x):
      return x + 1

    @tff.federated_computation(tff.type_at_clients(tf.int32))
    def map_add_one(federated_arg):
      return tff.federated_map(add_one, federated_arg)
# Example #3
# pyformat: disable
def _dataset_test_contexts():
    """Builds the named contexts used by the infinite-dataset tests.

    Every call constructs brand-new context objects, so each decorator that
    unpacks this result receives its own instances. This matches the behaviour
    of writing the argument list inline at each use site, without duplicating
    it.

    NOTE(review): this set is presumably narrower than the one the bare
    `test_contexts.with_contexts` decorator uses, e.g. to skip backends that
    cannot handle infinite datasets. Confirm before adding backends here.

    Returns:
      A tuple of `(name, context)` entries. The two remote entries carry a
      third element holding the result of
      `create_inprocess_worker_contexts` / `create_inprocess_aggregator_contexts`.
    """
    return (
        ('native_local', tff.backends.native.create_local_execution_context()),
        ('native_local_caching',
         test_contexts.create_native_local_caching_context()),
        ('native_remote',
         remote_runtime_test_utils.create_localhost_remote_context(
             test_contexts.WORKER_PORTS),
         remote_runtime_test_utils.create_inprocess_worker_contexts(
             test_contexts.WORKER_PORTS)),
        ('native_remote_intermediate_aggregator',
         remote_runtime_test_utils.create_localhost_remote_context(
             test_contexts.AGGREGATOR_PORTS),
         remote_runtime_test_utils.create_inprocess_aggregator_contexts(
             test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS)),
        ('native_sizing',
         tff.backends.native.create_sizing_execution_context()),
        ('native_thread_debug',
         tff.backends.native.create_thread_debugging_execution_context()),
    )
# pyformat: enable


class TensorFlowComputationTest(parameterized.TestCase):
    """Invokes `tff.tf_computation`s under the contexts from `with_contexts`."""

    @test_contexts.with_contexts
    def test_returns_constant(self):
        @tff.tf_computation
        def foo():
            return 10

        result = foo()
        self.assertEqual(result, 10)

    @test_contexts.with_contexts
    def test_returns_empty_tuple(self):
        @tff.tf_computation
        def foo():
            return ()

        result = foo()
        self.assertEqual(result, ())

    @test_contexts.with_contexts
    def test_returns_variable(self):
        @tff.tf_computation
        def foo():
            return tf.Variable(10, name='var')

        # The computation's result is the variable's value, not the variable.
        result = foo()
        self.assertEqual(result, 10)

    # The builder is called separately for each decorator so the two tests
    # never share context instances.
    @test_contexts.with_contexts(*_dataset_test_contexts())
    def test_takes_infinite_dataset(self):
        @tff.tf_computation
        def foo(ds):
            # Only a finite prefix is consumed, so an infinite input is safe.
            return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

        ds = tf.data.Dataset.range(10).repeat()
        actual_result = foo(ds)

        expected_result = ds.take(10).reduce(np.int64(0), lambda x, y: x + y)
        self.assertEqual(actual_result, expected_result)

    @test_contexts.with_contexts(*_dataset_test_contexts())
    def test_returns_infinite_dataset(self):
        @tff.tf_computation
        def foo():
            return tf.data.Dataset.range(10).repeat()

        actual_result = foo()

        # Infinite datasets cannot be compared directly, so compare the
        # reductions of equal-length finite prefixes instead.
        expected_result = tf.data.Dataset.range(10).repeat()
        self.assertEqual(
            actual_result.take(100).reduce(np.int64(0), lambda x, y: x + y),
            expected_result.take(100).reduce(np.int64(0), lambda x, y: x + y))

    @test_contexts.with_contexts
    def test_returns_result_with_typed_fn(self):
        @tff.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        result = foo(1, 2)
        self.assertEqual(result, 3)

    @test_contexts.with_contexts
    def test_raises_type_error_with_typed_fn(self):
        @tff.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        # Float arguments do not match the declared int32 parameter types.
        with self.assertRaises(TypeError):
            foo(1.0, 2.0)

    @test_contexts.with_contexts
    def test_returns_result_with_polymorphic_fn(self):
        @tff.tf_computation
        def foo(x, y):
            return x + y

        # Without declared types the same computation accepts ints and floats.
        result = foo(1, 2)
        self.assertEqual(result, 3)
        result = foo(1.0, 2.0)
        self.assertEqual(result, 3.0)
# Example #4
# Context *factories* for the infinite-dataset tests. Each entry is
# `(name, context_factory)` or, for the remote entries,
# `(name, context_factory, server_contexts_factory)`.
# NOTE(review): the lambdas presumably defer context creation until
# `test_contexts.with_contexts` invokes them; confirm. The tuple holds only
# lambdas with no state of their own, so both decorators can share it.
# pyformat: disable
# pylint: disable=unnecessary-lambda
_DATASET_CONTEXT_FACTORIES = (
    ('native_local',
     lambda: tff.backends.native.create_local_python_execution_context()),
    ('native_remote',
     lambda: remote_runtime_test_utils.create_localhost_remote_context(
         test_contexts.WORKER_PORTS),
     lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(
         test_contexts.WORKER_PORTS)),
    ('native_remote_intermediate_aggregator',
     lambda: remote_runtime_test_utils.create_localhost_remote_context(
         test_contexts.AGGREGATOR_PORTS),
     lambda: remote_runtime_test_utils.create_inprocess_aggregator_contexts(
         test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS)),
    ('native_sizing',
     lambda: tff.backends.native.create_sizing_execution_context()),
    ('native_thread_debug',
     lambda: tff.backends.native.create_thread_debugging_execution_context()),
)
# pylint: enable=unnecessary-lambda
# pyformat: enable


class TensorFlowComputationTest(tf.test.TestCase, parameterized.TestCase):
    """Invokes `tff.tf_computation`s under the contexts from `with_contexts`."""

    @test_contexts.with_contexts
    def test_create_call_take_two_from_stateful_dataset(self):

        vocab = ['a', 'b', 'c', 'd', 'e', 'f']

        @tff.tf_computation(tff.SequenceType(tf.string))
        def take_two(ds):
            # The lookup table is built inside the computation, which makes
            # the mapped dataset depend on resource state.
            table = tf.lookup.StaticVocabularyTable(
                tf.lookup.KeyValueTensorInitializer(
                    vocab, tf.range(len(vocab), dtype=tf.int64)),
                num_oov_buckets=1)
            ds = ds.map(table.lookup)
            return ds.take(2)

        ds = tf.data.Dataset.from_tensor_slices(vocab)
        result = take_two(ds)
        self.assertCountEqual([x.numpy() for x in result], [0, 1])

    @test_contexts.with_contexts
    def test_twice_used_variable_keeps_separate_state(self):
        def count_one_body():
            variable = tf.Variable(initial_value=0, name='var_of_interest')
            with tf.control_dependencies([variable.assign_add(1)]):
                return variable.read_value()

        # Two independent computations wrap the same Python body.
        count_one_1 = tff.tf_computation(count_one_body)
        count_one_2 = tff.tf_computation(count_one_body)

        @tff.tf_computation
        def count_one_twice():
            return count_one_1(), count_one_1(), count_one_2()

        # Every invocation observes a freshly initialized variable.
        self.assertEqual((1, 1, 1), count_one_twice())

    @test_contexts.with_contexts
    def test_dynamic_lookup_table(self):
        @tff.tf_computation(tff.TensorType(shape=[None], dtype=tf.string),
                            tff.TensorType(shape=[None], dtype=tf.string))
        def comp(table_args, to_lookup):
            values = tf.range(tf.shape(table_args)[0])
            initializer = tf.lookup.KeyValueTensorInitializer(
                table_args, values)
            table = tf.lookup.StaticHashTable(initializer, default_value=101)
            return table.lookup(to_lookup)

        # Keys absent from the table ('z') map to the default value 101.
        result = comp(tf.constant(['a', 'b', 'c']),
                      tf.constant(['a', 'z', 'c']))
        self.assertAllEqual(result, [0, 101, 2])

    @test_contexts.with_contexts
    def test_reinitialize_dynamic_lookup_table(self):
        @tff.tf_computation(tff.TensorType(shape=[None], dtype=tf.string),
                            tff.TensorType(shape=[], dtype=tf.string))
        def comp(table_args, to_lookup):
            values = tf.range(tf.shape(table_args)[0])
            initializer = tf.lookup.KeyValueTensorInitializer(
                table_args, values)
            table = tf.lookup.StaticHashTable(initializer, default_value=101)
            return table.lookup(to_lookup)

        # The second call must see a table rebuilt from its own arguments,
        # not the table from the first call.
        expected_zero = comp(tf.constant(['a', 'b', 'c']), tf.constant('a'))
        expected_three = comp(tf.constant(['a', 'b', 'c', 'd']),
                              tf.constant('d'))

        self.assertEqual(expected_zero, 0)
        self.assertEqual(expected_three, 3)

    @test_contexts.with_contexts
    def test_returns_constant(self):
        @tff.tf_computation
        def foo():
            return 10

        result = foo()
        self.assertEqual(result, 10)

    @test_contexts.with_contexts
    def test_returns_empty_tuple(self):
        @tff.tf_computation
        def foo():
            return ()

        result = foo()
        self.assertEqual(result, ())

    @test_contexts.with_contexts
    def test_returns_variable(self):
        @tff.tf_computation
        def foo():
            return tf.Variable(10, name='var')

        result = foo()
        self.assertEqual(result, 10)

    @test_contexts.with_contexts(*_DATASET_CONTEXT_FACTORIES)
    def test_takes_infinite_dataset(self):
        @tff.tf_computation
        def foo(ds):
            # Only a finite prefix is consumed, so an infinite input is safe.
            return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

        ds = tf.data.Dataset.range(10).repeat()
        actual_result = foo(ds)

        expected_result = ds.take(10).reduce(np.int64(0), lambda x, y: x + y)
        self.assertEqual(actual_result, expected_result)

    @test_contexts.with_contexts(*_DATASET_CONTEXT_FACTORIES)
    def test_returns_infinite_dataset(self):
        @tff.tf_computation
        def foo():
            return tf.data.Dataset.range(10).repeat()

        actual_result = foo()

        # Infinite datasets cannot be compared directly, so compare the
        # reductions of equal-length finite prefixes instead.
        expected_result = tf.data.Dataset.range(10).repeat()
        self.assertEqual(
            actual_result.take(100).reduce(np.int64(0), lambda x, y: x + y),
            expected_result.take(100).reduce(np.int64(0), lambda x, y: x + y))

    @test_contexts.with_contexts
    def test_returns_result_with_typed_fn(self):
        @tff.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        result = foo(1, 2)
        self.assertEqual(result, 3)

    @test_contexts.with_contexts
    def test_raises_type_error_with_typed_fn(self):
        @tff.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        # Float arguments do not match the declared int32 parameter types.
        with self.assertRaises(TypeError):
            foo(1.0, 2.0)

    @test_contexts.with_contexts
    def test_returns_result_with_polymorphic_fn(self):
        @tff.tf_computation
        def foo(x, y):
            return x + y

        # Without declared types the same computation accepts ints and floats.
        result = foo(1, 2)
        self.assertEqual(result, 3)
        result = foo(1.0, 2.0)
        self.assertEqual(result, 3.0)