def setUp(self):
    """Reset the eager context and connect to a fresh two-worker cluster.

    Soft device placement and device-placement logging are switched on so the
    placement tests in this class can exercise automatic placement.
    """
    super(ClusterPlacementTest, self).setUp()
    # Start from a clean context so settings leaked by earlier tests are gone.
    context._reset_context()
    config.set_soft_device_placement(enabled=True)
    context.context().log_device_placement = True
    workers, _ = test_util.create_local_cluster(2, 0)
    remote.connect_to_remote_host([w.target for w in workers])
def testConnectToRemoteServer(self):
    """Connect to one remote worker and run a matmul on its CPU device."""
    remote.connect_to_remote_host(self._cached_server1_target)
    with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
        lhs = array_ops.ones([2, 2])
        rhs = array_ops.ones([2, 2])
        product = math_ops.matmul(lhs, rhs)
    # ones(2x2) @ ones(2x2) is 2 in every position.
    np.testing.assert_array_equal([[2, 2], [2, 2]], product.numpy())
def testConnectToRemoteServer(self):
    """Basic server connection: remote matmul returns the expected result."""
    # NOTE(review): this method is byte-identical to another
    # testConnectToRemoteServer in this file; if both are defined in the same
    # class, the later definition silently shadows the earlier one — confirm
    # they belong to different test classes.
    remote.connect_to_remote_host(self._cached_server1_target)
    worker_cpu = "job:worker/replica:0/task:1/device:CPU:0"
    with ops.device(worker_cpu):
        result = math_ops.matmul(array_ops.ones([2, 2]),
                                 array_ops.ones([2, 2]))
    np.testing.assert_array_equal([[2, 2], [2, 2]], result.numpy())
def testTPUInitializationMultiHost(self):
    """Connecting to a new host registers a TPU topology under its job name."""
    ctx = context.context()
    if not ctx.list_physical_devices('TPU'):
        # With no TPU present there must be no recorded topologies either.
        self.assertEmpty(ctx.tpu_topologies_by_job)
        self.skipTest('A TPU is required to run this test.')
    # Before connecting, only the local job has a topology.
    self.assertEqual(['localhost'], list(ctx.tpu_topologies_by_job.keys()))
    server = server_lib.Server.create_local_server()
    # Strip the 'grpc://' scheme; connect_to_remote_host takes bare targets.
    target = server.target[len('grpc://'):]
    remote.connect_to_remote_host([target])
    # After connecting, both the local job and the new worker job appear.
    self.assertIn('localhost', ctx.tpu_topologies_by_job)
    self.assertIn('worker', ctx.tpu_topologies_by_job)
    self.assertLen(ctx.tpu_topologies, 2)
def _DISABLED_benchmark_send_mirroring_on(self):
    """Benchmark sending a local tensor to a remote matmul, mirroring ALL."""
    remote.connect_to_remote_host(self._cached_server_target1)
    operand = random_ops.random_uniform((2, 2)).cpu()

    @def_function.function
    def remote_func(m):
        return math_ops.matmul(m, m)

    def run_on_worker(m):
        with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
            return remote_func(m)

    context.context().mirroring_policy = context.MIRRORING_ALL
    self._run(lambda: run_on_worker(operand))
def _DISABLED_benchmark_worker_mirroring_on(self):
    """Benchmark reading a task-1 variable from task 0, mirroring ALL."""
    remote.connect_to_remote_host(
        [self._cached_server_target1, self._cached_server_target2])
    # The variable lives on worker task 1; the function below runs on task 0.
    with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
        remote_var = variables.Variable(1.0)

    @def_function.function
    def remote_func():
        return 1.0 + remote_var

    def step():
        with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
            return remote_func()

    context.context().mirroring_policy = context.MIRRORING_ALL
    self._run(step)
def testOperationTimeout(self):
    """A blocking remote op raises DeadlineExceededError once it times out."""
    context._reset_context()
    # 10ms deadline: the dequeue below blocks forever on an empty queue.
    context.context().operation_timeout_in_ms = 10
    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)
    empty_queue = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def dequeue_fn():
        return empty_queue.dequeue()

    with self.assertRaises(errors.DeadlineExceededError):
        with ops.device('/job:worker/replica:0/task:0'):
            dequeue_fn()
        # If streaming RPC is enabled, fetch remote errors before end of execution
        context.async_wait()
def benchmark_send(self):
    """Benchmark pushing a host tensor into a remote matmul function."""
    remote.connect_to_remote_host(self._cached_server_target1)
    operand = random_ops.random_uniform((2, 2)).cpu()

    @def_function.function
    def remote_func(m):
        return math_ops.matmul(m, m)

    def step(m):
        with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
            return remote_func(m)

    self._run(lambda: step(operand))
    # NOTE(b/136184459): Force garbage collecting hanging resources before
    # subsequent calls to set_server_def, to ensure the destroy resource ops are
    # executed when their corresponding device and manager are still available.
    gc.collect()
def benchmark_create_vars_inside_function(self):
    """Benchmark creating variables inside a remotely-placed tf.function."""
    remote.connect_to_remote_host(self._cached_server_target1)

    def step():
        with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
            # A new layer (and hence new variables) is built on every step.
            layer = Foo(50)

            @def_function.function
            def remote_func():
                with ops.device(
                    "job:worker/replica:0/task:0/device:CPU:0"):
                    return layer(random_ops.random_uniform([]))

            return remote_func()

    self._run(step, execution_mode=context.ASYNC, num_iters=100)
    # NOTE(b/136184459): Force garbage collecting hanging resources before
    # subsequent calls to set_server_def, to ensure the destroy resource ops are
    # executed when their corresponding device and manager are still available.
    gc.collect()
def benchmark_worker_recv(self):
    """Benchmark receiving a task-1 variable inside a task-0 function."""
    remote.connect_to_remote_host(
        [self._cached_server_target1, self._cached_server_target2])
    # Variable on worker task 1; the computation runs on worker task 0.
    with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
        remote_var = variables.Variable(1.0)

    @def_function.function
    def remote_func():
        return 1.0 + remote_var

    def step():
        with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
            return remote_func()

    self._run(step)
    # NOTE(b/136184459): Force garbage collecting hanging resources before
    # subsequent calls to set_server_def, to ensure the destroy resource ops are
    # executed when their corresponding device and manager are still available.
    gc.collect()
def testMemoryLeakInLocalCopy(self):
    """A local function fed a remote tensor must not leak the remote copy."""
    if memory_profiler is None:
        self.skipTest("memory_profiler required to run this test")

    remote.connect_to_remote_host(self._cached_server_target)

    # Run a function locally with the input on a remote worker and ensure we
    # do not leak a reference to the remote tensor.
    @def_function.function
    def local_func(i):
        return i

    def step():
        with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
            remote_tensor = array_ops.zeros([1000, 1000], dtypes.int32)
        # Deliberately called outside the device scope: the function runs
        # locally while its input lives on the worker.
        local_func(remote_tensor)

    assert_no_leak(step, num_iters=100, increase_threshold_absolute_mb=50)
def setUp(self):
    """Spin up a one-worker cluster and connect the eager context to it."""
    super(SingleWorkerTest, self).setUp()
    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)
def setUp(self):
    """Spin up a three-worker cluster and connect to all of its targets."""
    super(MultiWorkersTest, self).setUp()
    workers, _ = test_util.create_local_cluster(3, 0)
    remote.connect_to_remote_host([w.target for w in workers])
def setUp(self):
    """Start a single-worker, zero-PS cluster and connect to the worker."""
    super(SingleWorkerTestBaseEager, self).setUp()
    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    remote.connect_to_remote_host(workers[0].target)
def setUp(self):
    """Enable soft placement and placement logging, then connect 2 workers."""
    # NOTE(review): unlike the other setUp overrides in this file, this one
    # calls neither super().setUp() nor context._reset_context() — confirm
    # that is intentional for this test class.
    ctx = context.context()
    ctx.soft_device_placement = True
    ctx.log_device_placement = True
    workers, _ = test_util.create_local_cluster(2, 0)
    remote.connect_to_remote_host([w.target for w in workers])