def _get_all_contexts():
  """Returns every execution context the tests in this module run against.

  Each call constructs a brand-new set of contexts, in a fixed order.

  Returns:
    A list of tuples. The first element of each tuple is a test-name suffix and
    the second is the context built for it. The two remote entries
    ('native_remote' and 'native_remote_intermediate_aggregator') carry a third
    element: whatever the matching
    `remote_runtime_test_utils.create_inprocess_*_contexts` helper returns
    (presumably the in-process servers backing the remote context -- TODO
    confirm).
  """
  native = tff.backends.native
  remote_utils = remote_runtime_test_utils

  contexts = []
  contexts.append(('native_local', native.create_local_execution_context()))
  contexts.append(
      ('native_local_caching', _create_native_local_caching_context()))
  contexts.append((
      'native_remote',
      remote_utils.create_localhost_remote_context(_WORKER_PORTS),
      remote_utils.create_inprocess_worker_contexts(_WORKER_PORTS),
  ))
  contexts.append((
      'native_remote_intermediate_aggregator',
      remote_utils.create_localhost_remote_context(_AGGREGATOR_PORTS),
      remote_utils.create_inprocess_aggregator_contexts(
          _WORKER_PORTS, _AGGREGATOR_PORTS),
  ))
  contexts.append(('native_sizing', native.create_sizing_execution_context()))
  contexts.append(('native_thread_debug',
                   native.create_thread_debugging_execution_context()))
  contexts.append(
      ('reference', tff.backends.reference.create_reference_context()))
  contexts.append(('test', tff.backends.test.create_test_execution_context()))
  return contexts
# With both workers live, we should get 10 back. self.assertEqual(sum_arg(1), 10) # Leaving the inner context kills the second worker, but should leave # the result untouched. self.assertEqual(sum_arg(1), 10) @parameterized.named_parameters(( 'native_remote', remote_runtime_test_utils.create_localhost_remote_context(_WORKER_PORTS), remote_runtime_test_utils.create_inprocess_worker_contexts(_WORKER_PORTS), ), ( 'native_remote_intermediate_aggregator', remote_runtime_test_utils.create_localhost_remote_context( _AGGREGATOR_PORTS), remote_runtime_test_utils.create_inprocess_aggregator_contexts( _WORKER_PORTS, _AGGREGATOR_PORTS), )) class RemoteRuntimeConfigurationChangeTest(parameterized.TestCase): def test_computations_run_with_changing_clients(self, context, server_contexts): @tff.tf_computation(tf.int32) @tf.function def add_one(x): return x + 1 @tff.federated_computation(tff.type_at_clients(tf.int32)) def map_add_one(federated_arg): return tff.federated_map(add_one, federated_arg)
class TensorFlowComputationTest(parameterized.TestCase):
  """Invokes `tff.tf_computation`s end-to-end under various execution contexts."""

  def _infinite_dataset_contexts():  # pylint: disable=no-method-argument
    """Builds a fresh set of contexts for the infinite-dataset tests.

    This is called directly from the class body while the class is being
    defined and is deleted right after its last use, so it never becomes a
    method. Every call constructs new context objects, so each decorated test
    receives its own instances.
    """
    native = tff.backends.native
    return (
        ('native_local', native.create_local_execution_context()),
        ('native_local_caching',
         test_contexts.create_native_local_caching_context()),
        ('native_remote',
         remote_runtime_test_utils.create_localhost_remote_context(
             test_contexts.WORKER_PORTS),
         remote_runtime_test_utils.create_inprocess_worker_contexts(
             test_contexts.WORKER_PORTS)),
        ('native_remote_intermediate_aggregator',
         remote_runtime_test_utils.create_localhost_remote_context(
             test_contexts.AGGREGATOR_PORTS),
         remote_runtime_test_utils.create_inprocess_aggregator_contexts(
             test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS)),
        ('native_sizing', native.create_sizing_execution_context()),
        ('native_thread_debug',
         native.create_thread_debugging_execution_context()),
    )

  @test_contexts.with_contexts
  def test_returns_constant(self):
    """A no-argument computation returning a Python int yields that int."""

    @tff.tf_computation
    def foo():
      return 10

    self.assertEqual(foo(), 10)

  @test_contexts.with_contexts
  def test_returns_empty_tuple(self):
    """A computation returning `()` yields an empty tuple."""

    @tff.tf_computation
    def foo():
      return ()

    self.assertEqual(foo(), ())

  @test_contexts.with_contexts
  def test_returns_variable(self):
    """Returning a `tf.Variable` yields its initial value."""

    @tff.tf_computation
    def foo():
      return tf.Variable(10, name='var')

    self.assertEqual(foo(), 10)

  @test_contexts.with_contexts(*_infinite_dataset_contexts())
  def test_takes_infinite_dataset(self):
    """An endlessly repeating dataset can be passed in as an argument."""

    @tff.tf_computation
    def foo(ds):
      return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

    endless = tf.data.Dataset.range(10).repeat()
    computed = foo(endless)
    reference = endless.take(10).reduce(np.int64(0), lambda x, y: x + y)
    self.assertEqual(computed, reference)

  @test_contexts.with_contexts(*_infinite_dataset_contexts())
  def test_returns_infinite_dataset(self):
    """An endlessly repeating dataset can be returned from a computation."""

    @tff.tf_computation
    def foo():
      return tf.data.Dataset.range(10).repeat()

    computed = foo()
    reference = tf.data.Dataset.range(10).repeat()
    self.assertEqual(
        computed.take(100).reduce(np.int64(0), lambda x, y: x + y),
        reference.take(100).reduce(np.int64(0), lambda x, y: x + y))

  # The helper was only needed while the decorators above were evaluated.
  del _infinite_dataset_contexts

  @test_contexts.with_contexts
  def test_returns_result_with_typed_fn(self):
    """An explicitly typed two-argument computation adds its int inputs."""

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    self.assertEqual(foo(1, 2), 3)

  @test_contexts.with_contexts
  def test_raises_type_error_with_typed_fn(self):
    """Float arguments to an int32-typed computation raise TypeError."""

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    with self.assertRaises(TypeError):
      foo(1.0, 2.0)

  @test_contexts.with_contexts
  def test_returns_result_with_polymorphic_fn(self):
    """An untyped computation accepts both int and float arguments."""

    @tff.tf_computation
    def foo(x, y):
      return x + y

    self.assertEqual(foo(1, 2), 3)
    self.assertEqual(foo(1.0, 2.0), 3.0)
class TensorFlowComputationTest(tf.test.TestCase, parameterized.TestCase):
  """Invokes `tff.tf_computation`s end-to-end under various execution contexts."""

  def _dataset_context_factories():  # pylint: disable=no-method-argument,unnecessary-lambda
    """Returns (name, context_factory[, server_factory]) tuples.

    This is called directly from the class body while the class is being
    defined and is deleted right after its last use, so it never becomes a
    method. Every entry is a zero-argument lambda, so no context is built
    until the test harness invokes it.
    """
    return (
        ('native_local',
         lambda: tff.backends.native.create_local_python_execution_context()),
        ('native_remote',
         lambda: remote_runtime_test_utils.create_localhost_remote_context(
             test_contexts.WORKER_PORTS),
         lambda: remote_runtime_test_utils.create_inprocess_worker_contexts(
             test_contexts.WORKER_PORTS)),
        ('native_remote_intermediate_aggregator',
         lambda: remote_runtime_test_utils.create_localhost_remote_context(
             test_contexts.AGGREGATOR_PORTS),
         lambda: remote_runtime_test_utils.
         create_inprocess_aggregator_contexts(
             test_contexts.WORKER_PORTS, test_contexts.AGGREGATOR_PORTS)),
        ('native_sizing',
         lambda: tff.backends.native.create_sizing_execution_context()),
        ('native_thread_debug',
         lambda: tff.backends.native.
         create_thread_debugging_execution_context()),
    )

  @test_contexts.with_contexts
  def test_create_call_take_two_from_stateful_dataset(self):
    """A vocabulary table built inside the computation maps the first tokens."""
    vocabulary = ['a', 'b', 'c', 'd', 'e', 'f']

    @tff.tf_computation(tff.SequenceType(tf.string))
    def take_two(ds):
      table = tf.lookup.StaticVocabularyTable(
          tf.lookup.KeyValueTensorInitializer(
              vocabulary, tf.range(len(vocabulary), dtype=tf.int64)),
          num_oov_buckets=1)
      return ds.map(table.lookup).take(2)

    taken = take_two(tf.data.Dataset.from_tensor_slices(vocabulary))
    self.assertCountEqual([element.numpy() for element in taken], [0, 1])

  @test_contexts.with_contexts
  def test_twice_used_variable_keeps_separate_state(self):
    """Each invocation of a variable-owning computation starts from zero."""

    def count_one_body():
      counter = tf.Variable(initial_value=0, name='var_of_interest')
      with tf.control_dependencies([counter.assign_add(1)]):
        return counter.read_value()

    count_one_1 = tff.tf_computation(count_one_body)
    count_one_2 = tff.tf_computation(count_one_body)

    @tff.tf_computation
    def count_one_twice():
      return count_one_1(), count_one_1(), count_one_2()

    self.assertEqual((1, 1, 1), count_one_twice())

  @test_contexts.with_contexts
  def test_dynamic_lookup_table(self):
    """Keys map to their position; unknown keys map to the default, 101."""

    @tff.tf_computation(tff.TensorType(shape=[None], dtype=tf.string),
                        tff.TensorType(shape=[None], dtype=tf.string))
    def comp(table_args, to_lookup):
      positions = tf.range(tf.shape(table_args)[0])
      init = tf.lookup.KeyValueTensorInitializer(table_args, positions)
      table = tf.lookup.StaticHashTable(init, default_value=101)
      return table.lookup(to_lookup)

    looked_up = comp(
        tf.constant(['a', 'b', 'c']), tf.constant(['a', 'z', 'c']))
    self.assertAllEqual(looked_up, [0, 101, 2])

  @test_contexts.with_contexts
  def test_reinitialize_dynamic_lookup_table(self):
    """Each call rebuilds the table from that call's own key argument."""

    @tff.tf_computation(tff.TensorType(shape=[None], dtype=tf.string),
                        tff.TensorType(shape=[], dtype=tf.string))
    def comp(table_args, to_lookup):
      positions = tf.range(tf.shape(table_args)[0])
      init = tf.lookup.KeyValueTensorInitializer(table_args, positions)
      table = tf.lookup.StaticHashTable(init, default_value=101)
      return table.lookup(to_lookup)

    index_of_a = comp(tf.constant(['a', 'b', 'c']), tf.constant('a'))
    index_of_d = comp(tf.constant(['a', 'b', 'c', 'd']), tf.constant('d'))
    self.assertEqual(index_of_a, 0)
    self.assertEqual(index_of_d, 3)

  @test_contexts.with_contexts
  def test_returns_constant(self):
    """A no-argument computation returning a Python int yields that int."""

    @tff.tf_computation
    def foo():
      return 10

    self.assertEqual(foo(), 10)

  @test_contexts.with_contexts
  def test_returns_empty_tuple(self):
    """A computation returning `()` yields an empty tuple."""

    @tff.tf_computation
    def foo():
      return ()

    self.assertEqual(foo(), ())

  @test_contexts.with_contexts
  def test_returns_variable(self):
    """Returning a `tf.Variable` yields its initial value."""

    @tff.tf_computation
    def foo():
      return tf.Variable(10, name='var')

    self.assertEqual(foo(), 10)

  @test_contexts.with_contexts(*_dataset_context_factories())
  def test_takes_infinite_dataset(self):
    """An endlessly repeating dataset can be passed in as an argument."""

    @tff.tf_computation
    def foo(ds):
      return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

    endless = tf.data.Dataset.range(10).repeat()
    computed = foo(endless)
    reference = endless.take(10).reduce(np.int64(0), lambda x, y: x + y)
    self.assertEqual(computed, reference)

  @test_contexts.with_contexts(*_dataset_context_factories())
  def test_returns_infinite_dataset(self):
    """An endlessly repeating dataset can be returned from a computation."""

    @tff.tf_computation
    def foo():
      return tf.data.Dataset.range(10).repeat()

    computed = foo()
    reference = tf.data.Dataset.range(10).repeat()
    self.assertEqual(
        computed.take(100).reduce(np.int64(0), lambda x, y: x + y),
        reference.take(100).reduce(np.int64(0), lambda x, y: x + y))

  # The helper was only needed while the decorators above were evaluated.
  del _dataset_context_factories

  @test_contexts.with_contexts
  def test_returns_result_with_typed_fn(self):
    """An explicitly typed two-argument computation adds its int inputs."""

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    self.assertEqual(foo(1, 2), 3)

  @test_contexts.with_contexts
  def test_raises_type_error_with_typed_fn(self):
    """Float arguments to an int32-typed computation raise TypeError."""

    @tff.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    with self.assertRaises(TypeError):
      foo(1.0, 2.0)

  @test_contexts.with_contexts
  def test_returns_result_with_polymorphic_fn(self):
    """An untyped computation accepts both int and float arguments."""

    @tff.tf_computation
    def foo(x, y):
      return x + y

    self.assertEqual(foo(1, 2), 3)
    self.assertEqual(foo(1.0, 2.0), 3.0)