Example #1
0
 def setUp(self):
   """Resets the eager context and connects a fresh two-worker cluster."""
   super(ClusterPlacementTest, self).setUp()
   context._reset_context()
   config.set_soft_device_placement(enabled=True)
   context.context().log_device_placement = True
   cluster_workers, _ = test_util.create_local_cluster(2, 0)
   remote.connect_to_remote_host([w.target for w in cluster_workers])
Example #2
0
    def testConnectToRemoteServer(self):
        """Basic server connection."""
        remote.connect_to_remote_host(self._cached_server1_target)

        expected = [[2, 2], [2, 2]]
        with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
            lhs = array_ops.ones([2, 2])
            rhs = array_ops.ones([2, 2])
            product = math_ops.matmul(lhs, rhs)
        np.testing.assert_array_equal(expected, product.numpy())
Example #3
0
  def testConnectToRemoteServer(self):
    """Basic server connection."""
    remote.connect_to_remote_host(self._cached_server1_target)

    with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
      ones_a = array_ops.ones([2, 2])
      ones_b = array_ops.ones([2, 2])
      result = math_ops.matmul(ones_a, ones_b)
    np.testing.assert_array_equal([[2, 2], [2, 2]], result.numpy())
Example #4
0
 def testTPUInitializationMultiHost(self):
     """Checks per-job TPU topologies after connecting a second host."""
     eager_ctx = context.context()
     if not eager_ctx.list_physical_devices('TPU'):
         self.assertEmpty(eager_ctx.tpu_topologies_by_job)
         self.skipTest('A TPU is required to run this test.')
     # Before connecting, only the local host's topology is tracked.
     self.assertEqual(['localhost'], list(eager_ctx.tpu_topologies_by_job.keys()))
     local_server = server_lib.Server.create_local_server()
     grpc_prefix = 'grpc://'
     target = local_server.target[len(grpc_prefix):]
     remote.connect_to_remote_host([target])
     self.assertIn('localhost', eager_ctx.tpu_topologies_by_job)
     self.assertIn('worker', eager_ctx.tpu_topologies_by_job)
     self.assertLen(eager_ctx.tpu_topologies, 2)
Example #5
0
    def _DISABLED_benchmark_send_mirroring_on(self):
        """Benchmarks sending a tensor to a worker with mirroring enabled."""
        remote.connect_to_remote_host(self._cached_server_target1)

        operand = random_ops.random_uniform((2, 2)).cpu()

        @def_function.function
        def square(m):
            return math_ops.matmul(m, m)

        def run_on_worker(m):
            with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
                return square(m)

        context.context().mirroring_policy = context.MIRRORING_ALL
        self._run(lambda: run_on_worker(operand))
Example #6
0
    def _DISABLED_benchmark_worker_mirroring_on(self):
        """Benchmarks reading a remote variable with mirroring enabled."""
        remote.connect_to_remote_host(
            [self._cached_server_target1, self._cached_server_target2])

        # The variable lives on task 1; the function runs on task 0.
        with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
            remote_var = variables.Variable(1.0)

        @def_function.function
        def add_one():
            return 1.0 + remote_var

        def run_on_worker():
            with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
                return add_one()

        context.context().mirroring_policy = context.MIRRORING_ALL
        self._run(run_on_worker)
Example #7
0
    def testOperationTimeout(self):
        """A blocking remote dequeue should exceed the 10ms op timeout."""
        context._reset_context()
        context.context().operation_timeout_in_ms = 10
        workers, _ = test_util.create_local_cluster(1, 0)
        remote.connect_to_remote_host(workers[0].target)

        # An empty queue: dequeue blocks until the deadline fires.
        queue = data_flow_ops.FIFOQueue(1, dtypes.int32)

        @def_function.function
        def blocking_dequeue():
            return queue.dequeue()

        with self.assertRaises(errors.DeadlineExceededError):
            with ops.device('/job:worker/replica:0/task:0'):
                blocking_dequeue()
            # If streaming RPC is enabled, fetch remote errors before end of execution
            context.async_wait()
    def benchmark_send(self):
        """Benchmarks dispatching a matmul of a local tensor to a worker."""
        remote.connect_to_remote_host(self._cached_server_target1)

        operand = random_ops.random_uniform((2, 2)).cpu()

        @def_function.function
        def square(m):
            return math_ops.matmul(m, m)

        def run_on_worker(m):
            with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
                return square(m)

        self._run(lambda: run_on_worker(operand))
        # NOTE(b/136184459): Force garbage collecting hanging resources before
        # subsequent calls to set_server_def, to ensure the destroy resource ops are
        # executed when their corresponding device and manager are still available.
        gc.collect()
    def benchmark_create_vars_inside_function(self):
        """Benchmarks creating layer variables inside a remote function."""
        remote.connect_to_remote_host(self._cached_server_target1)

        def build_and_call_layer():
            with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
                layer = Foo(50)

                @def_function.function
                def call_layer():
                    with ops.device(
                            "job:worker/replica:0/task:0/device:CPU:0"):
                        return layer(random_ops.random_uniform([]))

                return call_layer()

        self._run(
            build_and_call_layer, execution_mode=context.ASYNC, num_iters=100)
        # NOTE(b/136184459): Force garbage collecting hanging resources before
        # subsequent calls to set_server_def, to ensure the destroy resource ops are
        # executed when their corresponding device and manager are still available.
        gc.collect()
    def benchmark_worker_recv(self):
        """Benchmarks one worker reading a variable placed on another."""
        remote.connect_to_remote_host(
            [self._cached_server_target1, self._cached_server_target2])

        # The variable lives on task 1; the function runs on task 0.
        with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
            remote_var = variables.Variable(1.0)

        @def_function.function
        def add_one():
            return 1.0 + remote_var

        def run_on_worker():
            with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
                return add_one()

        self._run(run_on_worker)
        # NOTE(b/136184459): Force garbage collecting hanging resources before
        # subsequent calls to set_server_def, to ensure the destroy resource ops are
        # executed when their corresponding device and manager are still available.
        gc.collect()
Example #11
0
  def testMemoryLeakInLocalCopy(self):
    """Ensures running a local function on a remote input does not leak."""
    if memory_profiler is None:
      self.skipTest("memory_profiler required to run this test")

    remote.connect_to_remote_host(self._cached_server_target)

    # Run a function locally with the input on a remote worker and ensure we
    # do not leak a reference to the remote tensor.

    @def_function.function
    def identity(i):
      return i

    def copy_remote_tensor():
      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
        remote_zeros = array_ops.zeros([1000, 1000], dtypes.int32)

      identity(remote_zeros)

    assert_no_leak(
        copy_remote_tensor, num_iters=100, increase_threshold_absolute_mb=50)
Example #12
0
    def setUp(self):
        """Starts a one-worker cluster and connects the eager runtime to it."""
        super(SingleWorkerTest, self).setUp()

        cluster_workers, _ = test_util.create_local_cluster(1, 0)
        remote.connect_to_remote_host(cluster_workers[0].target)
Example #13
0
    def setUp(self):
        """Starts a three-worker cluster and connects to every worker."""
        super(MultiWorkersTest, self).setUp()

        cluster_workers, _ = test_util.create_local_cluster(3, 0)
        remote.connect_to_remote_host([w.target for w in cluster_workers])
Example #14
0
 def setUp(self):
     """Starts a single local worker and connects the eager context to it."""
     super(SingleWorkerTestBaseEager, self).setUp()
     cluster_workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
     remote.connect_to_remote_host(cluster_workers[0].target)
 def setUp(self):
     """Enables device-placement options and connects a two-worker cluster."""
     # NOTE(review): unlike the other setUp variants in this file, this one
     # does not call super().setUp() — confirm that is intentional.
     eager_ctx = context.context()
     eager_ctx.soft_device_placement = True
     eager_ctx.log_device_placement = True
     cluster_workers, _ = test_util.create_local_cluster(2, 0)
     remote.connect_to_remote_host([w.target for w in cluster_workers])