示例#1
0
  def test_input_specs(self):
    """Dict-structured input signatures are honored by the RPC client.

    The server method is declared with a dict of two scalar int32 specs. A
    matching dict must be accepted and summed; a positional list must raise
    TypeError.
    """

    @eager_def_function.function(input_signature=[{
        "a": tensor_spec.TensorSpec([], dtypes.int32),
        "b": tensor_spec.TensorSpec([], dtypes.int32)
    }])
    def test_input_dict(value):
      return math_ops.add(value["a"], value["b"])

    address = "localhost:{}".format(portpicker.pick_unused_port())
    server = rpc_ops.Server(address)
    server.register("test_input_dict", test_input_dict)
    server.start()

    client = rpc_ops.Client(
        address=address, name="test_client", list_registered_methods=True)
    lhs = variables.Variable(2, dtype=dtypes.int32)
    rhs = variables.Variable(3, dtype=dtypes.int32)

    # Structure matches the signature: the call succeeds with 2 + 3.
    sum_or = client.test_input_dict({"a": lhs, "b": rhs})
    self.assertAllEqual(sum_or.is_ok(), True)
    self.assertAllEqual(sum_or.get_value(), 5)

    # A flat list does not match the dict signature.
    with self.assertRaises(TypeError):
      client.test_input_dict([lhs, rhs])
示例#2
0
  def test_captured_inputs(self):
    """Server functions can mutate and read a variable they capture."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    address = "localhost:{}".format(portpicker.pick_unused_port())
    server = rpc_ops.Server(address)
    for method_name, method_fn in (("assign_add", assign_add),
                                   ("read_var", read_var)):
      server.register(method_name, method_fn)
    server.start()

    client = rpc_ops.Client(address)

    # Two remote increments of 2; each must report success.
    for _ in range(2):
      increment_or = client.call("assign_add",
                                 [variables.Variable(2, dtype=dtypes.int64)])
      self.assertAllEqual(increment_or.is_ok(), True)

    # The captured variable now holds 0 + 2 + 2.
    read_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
    self.assertAllEqual(read_or.is_ok(), True)
    self.assertAllEqual(read_or.get_value(), [4])
示例#3
0
  def test_rpc_ops_wrapper(self):
    """Explicit `call`, unnamed clients and generated wrappers all agree."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = "localhost:{}".format(portpicker.pick_unused_port())
    server = rpc_ops.Server(address)
    server.register("multiply", _remote_fn)
    server.start()
    named_client = rpc_ops.Client(address=address, name="test_client")

    x = variables.Variable(2, dtype=dtypes.int32)
    y = variables.Variable(3, dtype=dtypes.int32)
    scalar_int32 = tensor_spec.TensorSpec((), dtypes.int32)

    def _check_product(result_or):
      # Every path below multiplies 2 * 3 remotely.
      self.assertAllEqual(result_or.is_ok(), True)
      self.assertAllEqual(result_or.get_value(), 6)

    _check_product(
        named_client.call(
            args=[x, y], method_name="multiply", output_specs=scalar_int32))

    # Client created without an explicit name.
    unnamed_client = rpc_ops.Client(address, list_registered_methods=True)
    _check_product(
        unnamed_client.call(
            args=[x, y], method_name="multiply", output_specs=scalar_int32))

    # Generated wrapper method, called without output_specs.
    _check_product(unnamed_client.multiply(x, y))

    self.assertEqual(unnamed_client.multiply.__doc__,
                     "RPC Call for multiply method to server " + address)
示例#4
0
  def test_output_specs(self):
    """Structured return values survive the RPC round trip.

    Covers a dict output, a boolean output, a nested dict/tuple/list output
    and an empty-list output. Each is called as an attribute of a client
    created with ``list_registered_methods=True``.
    """

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def test_dict(val):
      return {"key": val}

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def is_positive(a):
      if a > 0:
        return True
      return False

    @eager_def_function.function(input_signature=[])
    def do_nothing():
      return []

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def test_nested_structure(v):
      return {"test": (v, [v, v]), "test1": (v,)}

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server_resource = rpc_ops.Server(address)

    server_resource.register("test_dict", test_dict)
    server_resource.register("is_positive", is_positive)
    server_resource.register("test_nested_structure", test_nested_structure)
    server_resource.register("do_nothing", do_nothing)

    server_resource.start()

    client = rpc_ops.Client(
        address=address, name="test_client", list_registered_methods=True)

    a = variables.Variable(2, dtype=dtypes.int32)

    result_or = client.test_dict(a)
    self.assertAllEqual(result_or.is_ok(), True)
    nest.map_structure(self.assertAllEqual, result_or.get_value(), {"key": 2})

    # The wrapper returns a status-or object, like the other calls here.
    # assertTrue on that object checks only its Python truthiness, not the
    # RPC status or the returned value, so check both explicitly.
    result_or = client.is_positive(a)
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), True)

    result_or = client.test_nested_structure(a)
    self.assertAllEqual(result_or.is_ok(), True)
    nest.map_structure(self.assertAllEqual, result_or.get_value(), {
        "test": (2, [2, 2]),
        "test1": (2,)
    })

    result_or = client.do_nothing()
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [])
示例#5
0
  def test_queue_resource(self):
    """A server function fills a FIFOQueue that the test then drains."""
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])

    @eager_def_function.function()
    def populate_queue():
      queue.enqueue_many(elements)
      queue.close()

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server = rpc_ops.Server(address)
    server.register("populate_queue", populate_queue)
    server.start()

    client = rpc_ops.Client(address, list_registered_methods=True)
    # Check the RPC status. Otherwise a failed remote call would only surface
    # indirectly below, e.g. as dequeue() hanging or failing.
    result_or = client.populate_queue()
    self.assertAllEqual(result_or.is_ok(), True)

    # Elements must come back in the order they were enqueued.
    for e in elements:
      self.assertAllEqual(e, queue.dequeue())
示例#6
0
  def test_rpc_error(self):
    """Valid calls succeed; omitting a required argument is INVALID_ARGUMENT."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    address = "localhost:{}".format(portpicker.pick_unused_port())
    server = rpc_ops.Server(address)
    for method_name, method_fn in (("assign_add", assign_add),
                                   ("read_var", read_var)):
      server.register(method_name, method_fn)
    server.start()

    client = rpc_ops.Client(address, list_registered_methods=True)

    # Happy path via explicit `call`: increment by 2, then read back 2.
    status_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(status_or.is_ok(), True)
    status_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
    self.assertAllEqual(status_or.is_ok(), True)
    self.assertAllEqual(status_or.get_value(), [2])

    # Same operations through the generated wrappers: 2 + 2 == 4.
    status_or = client.assign_add(variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(True, status_or.is_ok())
    status_or = client.read_var()
    self.assertAllEqual(True, status_or.is_ok())
    self.assertAllEqual(status_or.get_value(), 4)

    # Calling without the required argument reports INVALID_ARGUMENT.
    status_or = client.call("assign_add")
    self.assertAllEqual(status_or.is_ok(), False)
    code, _ = status_or.get_error()
    self.assertAllEqual(code, errors.INVALID_ARGUMENT)
示例#7
0
  def test_multi_device_resource_cpu(self):
    """A server on CPU:0 can update a variable created on another CPU device."""
    with ops.device("/device:cpu:1"):
      v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    with ops.device("/device:CPU:0"):
      address = "localhost:{}".format(portpicker.pick_unused_port())
      server = rpc_ops.Server(address)
      server.register("assign_add", assign_add)
      server.start()

      client = rpc_ops.Client(address, list_registered_methods=True)
      increment = variables.Variable(2, dtype=dtypes.int64)
      status_or = client.assign_add(increment)
      self.assertAllEqual(status_or.is_ok(), True)

    # The remote increment landed on the variable placed under cpu:1.
    self.assertAllEqual(v, 2)
示例#8
0
  def test_rpc_call_op_in_tf_function(self):
    """`Client.call` can be traced and executed inside a tf.function."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = "localhost:{}".format(portpicker.pick_unused_port())
    server = rpc_ops.Server(address)
    server.register("remote_fn", _remote_fn)
    server.start()
    client = rpc_ops.Client(address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    @eager_def_function.function
    def call_fn():
      status_or = client.call(
          args=[a, b],
          method_name="remote_fn",
          output_specs=[tensor_spec.TensorSpec([], dtypes.int32)])

      self.assertAllEqual(True, status_or.is_ok())
      outputs = status_or.get_value()
      # With a list of output specs, `call` yields a list of tensors.
      self.assertEqual(len(outputs), 1)
      # TODO(ishark): Shape for output tensor is unknown currently.
      # Add attribute for capturing TensorSpec for output and enable
      # check below:
      # self.assertIsNotNone(outputs[0].shape.rank)
      return outputs

    self.assertAllEqual(call_fn(), [6])
示例#9
0
  def test_resource_deletion(self):
    """Future, client and server resources are destroyed along with their owners.

    Each check destroys a resource handle manually with
    ``ignore_lookup_error=False`` and expects ``NotFoundError``. That error
    shows the resource had already been deleted.
    """
    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server = rpc_ops.Server(address)
    server_handle = server._server_handle

    # Test Future resource deletion
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    server.register("read_var", read_var)

    server.start()
    client = rpc_ops.Client(address)

    client_handle = client._client_handle

    def _assert_resource_deleted(handle):
      """Asserts that the resource behind `handle` no longer exists."""
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

    def _create_and_delete_rpc_future():
      """Issues a call without reading it; returns the future's handle."""
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      return handle._status_or

    def _create_and_delete_with_future():
      """Issues a call, reads its value, and returns the future's handle."""
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      status_or_handle = handle._status_or
      handle.get_value()
      return status_or_handle

    # Check future deletion both without and with calling get_value. Each
    # variant runs eagerly and wrapped in a tf.function, twice each, in the
    # same order as the individual checks this replaces.
    for create_fn in (
        _create_and_delete_rpc_future,
        eager_def_function.function(_create_and_delete_rpc_future),
        _create_and_delete_with_future,
        eager_def_function.function(_create_and_delete_with_future)):
      for _ in range(2):
        _assert_resource_deleted(create_fn())

    # The client resource is gone once the Client object is released.
    del client
    _assert_resource_deleted(client_handle)

    # The server resource is gone once the Server object is released.
    del server
    _assert_resource_deleted(server_handle)
示例#10
0
  def test_call_register_ordering(self):
    """Client creation, server start and method registration ordering rules.

    A plain client can be created before the server exists. A client with
    ``list_registered_methods=True`` times out before the server starts but
    waits for a server that starts later. Registering after start fails.
    """
    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)

    # Create client succeeds before server start and registration
    client = rpc_ops.Client(address)

    # Create client with list_registered_methods fails before server is started.
    with self.assertRaises(errors.DeadlineExceededError):
      rpc_ops.Client(
          address,
          name="client1",
          list_registered_methods=True,
          timeout_in_ms=1)

    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    server = rpc_ops.Server(address)

    def start_server():
      # Delay server start to test whether client creation also waits
      # till server is up.
      time.sleep(1)
      server.register("assign_add", assign_add)
      server.start()

    t = threading.Thread(target=start_server)
    t.start()
    try:
      # Create same "client1" again should succeed.
      client1_with_listed_methods = rpc_ops.Client(
          address, name="client1", list_registered_methods=True)
    finally:
      # Always wait for the background thread so it does not outlive this
      # point (previously it was never joined), whether or not client
      # creation succeeded.
      t.join()

    result_or = client1_with_listed_methods.assign_add(
        variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(result_or.is_ok(), True)

    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)

    # Create client with registered methods
    client2_with_listed_methods = rpc_ops.Client(
        address=address, name="client2", list_registered_methods=True)

    result_or = client2_with_listed_methods.assign_add(
        variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(result_or.is_ok(), True)

    self.assertAllEqual(v, 6)

    # Register new method after server started.
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "All methods must be registered before starting the server"):
      server.register("read_var", read_var)