def test_input_specs(self):
  """Calls a method registered with a dict input signature.

  A matching dict argument must produce the sum of its two entries, while a
  list argument must be rejected with `TypeError`.
  """

  @eager_def_function.function(input_signature=[{
      "a": tensor_spec.TensorSpec([], dtypes.int32),
      "b": tensor_spec.TensorSpec([], dtypes.int32)
  }])
  def test_input_dict(value):
    return math_ops.add(value["a"], value["b"])

  address = "localhost:{}".format(portpicker.pick_unused_port())
  rpc_server = rpc_ops.Server(address)
  rpc_server.register("test_input_dict", test_input_dict)
  rpc_server.start()

  client = rpc_ops.Client(
      address=address, name="test_client", list_registered_methods=True)
  lhs = variables.Variable(2, dtype=dtypes.int32)
  rhs = variables.Variable(3, dtype=dtypes.int32)

  # Structure matches the registered signature: 2 + 3.
  result_or = client.test_input_dict({"a": lhs, "b": rhs})
  self.assertAllEqual(result_or.is_ok(), True)
  self.assertAllEqual(result_or.get_value(), 5)

  # A list does not match the dict structure of the signature.
  with self.assertRaises(TypeError):
    client.test_input_dict([lhs, rhs])
def test_captured_inputs(self):
  """Calls registered functions that capture a variable from this scope.

  Two `assign_add` calls of 2 each are issued, after which `read_var` is
  expected to report 4.
  """
  v = variables.Variable(initial_value=0, dtype=dtypes.int64)

  @eager_def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
  def assign_add(a):
    v.assign_add(a)

  @eager_def_function.function(input_signature=[])
  def read_var():
    return v.value()

  address = "localhost:{}".format(portpicker.pick_unused_port())
  server = rpc_ops.Server(address)
  server.register("assign_add", assign_add)
  server.register("read_var", read_var)
  server.start()

  client = rpc_ops.Client(address)

  # Increment the captured variable twice through the generic `call` API.
  for _ in range(2):
    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)

  result_or = client.call(
      "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
  self.assertAllEqual(result_or.is_ok(), True)
  self.assertAllEqual(result_or.get_value(), [4])
def test_rpc_ops_wrapper(self):
  """Invokes a registered multiply function through the client wrappers.

  Covers `call` on a named client, `call` on a client created without a
  name, the generated `multiply` method, and that method's docstring.
  """

  @eager_def_function.function(input_signature=[
      tensor_spec.TensorSpec([], dtypes.int32),
      tensor_spec.TensorSpec([], dtypes.int32)
  ])
  def _remote_fn(a, b):
    return math_ops.multiply(a, b)

  address = "localhost:{}".format(portpicker.pick_unused_port())
  rpc_server = rpc_ops.Server(address)
  rpc_server.register("multiply", _remote_fn)
  rpc_server.start()

  def _expect_six(status_or):
    # Every call below multiplies 2 by 3.
    self.assertAllEqual(status_or.is_ok(), True)
    self.assertAllEqual(status_or.get_value(), 6)

  client = rpc_ops.Client(address=address, name="test_client")
  a = variables.Variable(2, dtype=dtypes.int32)
  b = variables.Variable(3, dtype=dtypes.int32)

  mul_or = client.call(
      args=[a, b],
      method_name="multiply",
      output_specs=tensor_spec.TensorSpec((), dtypes.int32))
  _expect_six(mul_or)

  # Test empty client name
  client1 = rpc_ops.Client(address, list_registered_methods=True)
  mul_or = client1.call(
      args=[a, b],
      method_name="multiply",
      output_specs=tensor_spec.TensorSpec((), dtypes.int32))
  _expect_six(mul_or)

  # Test without output_spec
  mul_or = client1.multiply(a, b)
  _expect_six(mul_or)

  expected_doc = "RPC Call for multiply method to server " + address
  self.assertEqual(client1.multiply.__doc__, expected_doc)
def test_output_specs(self):
  """Checks the structure of values returned by listed client methods.

  Registers functions returning a dict, a boolean, a nested structure and an
  empty list, then calls each through a client created with
  `list_registered_methods=True` and compares the returned values.
  """

  @eager_def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def test_dict(val):
    return {"key": val}

  @eager_def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def is_positive(a):
    if a > 0:
      return True
    return False

  @eager_def_function.function(input_signature=[])
  def do_nothing():
    return []

  @eager_def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def test_nested_structure(v):
    return {"test": (v, [v, v]), "test1": (v,)}

  port = portpicker.pick_unused_port()
  address = "localhost:{}".format(port)
  server_resource = rpc_ops.Server(address)
  server_resource.register("test_dict", test_dict)
  server_resource.register("is_positive", is_positive)
  server_resource.register("test_nested_structure", test_nested_structure)
  server_resource.register("do_nothing", do_nothing)
  server_resource.start()

  client = rpc_ops.Client(
      address=address, name="test_client", list_registered_methods=True)
  a = variables.Variable(2, dtype=dtypes.int32)

  result_or = client.test_dict(a)
  self.assertAllEqual(result_or.is_ok(), True)
  nest.map_structure(self.assertAllEqual, result_or.get_value(), {"key": 2})

  # Inspect the status and the returned value explicitly. Asserting on the
  # object returned by the client method only tests that object's
  # truthiness, which would not detect a failed call or a False result.
  result_or = client.is_positive(a)
  self.assertAllEqual(result_or.is_ok(), True)
  self.assertAllEqual(result_or.get_value(), True)

  result_or = client.test_nested_structure(a)
  self.assertAllEqual(result_or.is_ok(), True)
  nest.map_structure(self.assertAllEqual, result_or.get_value(), {
      "test": (2, [2, 2]),
      "test1": (2,)
  })

  result_or = client.do_nothing()
  self.assertAllEqual(result_or.is_ok(), True)
  self.assertAllEqual(result_or.get_value(), [])
def test_queue_resource(self):
  """Fills a FIFOQueue through an RPC and checks the dequeued elements.

  The registered function enqueues 200 random integers and closes the queue;
  the test body then dequeues and compares them in order.
  """
  elements = np.random.randint(100, size=[200])
  queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])

  @eager_def_function.function()
  def populate_queue():
    queue.enqueue_many(elements)
    queue.close()

  address = "localhost:{}".format(portpicker.pick_unused_port())
  server = rpc_ops.Server(address)
  server.register("populate_queue", populate_queue)
  server.start()

  client = rpc_ops.Client(address, list_registered_methods=True)
  client.populate_queue()

  # Elements must come back in the order they were enqueued.
  for expected in elements:
    self.assertAllEqual(expected, queue.dequeue())
def test_rpc_error(self):
  """Checks that a call with missing arguments reports INVALID_ARGUMENT.

  Valid calls through both `call` and the generated methods are issued first
  to confirm the happy path, then `assign_add` is called without arguments.
  """
  v = variables.Variable(initial_value=0, dtype=dtypes.int64)

  @eager_def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
  def assign_add(a):
    v.assign_add(a)

  @eager_def_function.function(input_signature=[])
  def read_var():
    return v.value()

  address = "localhost:{}".format(portpicker.pick_unused_port())
  server = rpc_ops.Server(address)
  server.register("assign_add", assign_add)
  server.register("read_var", read_var)
  server.start()

  client = rpc_ops.Client(address, list_registered_methods=True)

  # confirm it works as expected when arguments are passed.
  status_or = client.call("assign_add",
                          [variables.Variable(2, dtype=dtypes.int64)])
  self.assertAllEqual(status_or.is_ok(), True)

  status_or = client.call(
      "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
  self.assertAllEqual(status_or.is_ok(), True)
  self.assertAllEqual(status_or.get_value(), [2])

  # Same operations through the generated client methods.
  status_or = client.assign_add(variables.Variable(2, dtype=dtypes.int64))
  self.assertAllEqual(True, status_or.is_ok())

  status_or = client.read_var()
  self.assertAllEqual(True, status_or.is_ok())
  self.assertAllEqual(status_or.get_value(), 4)

  # Fails with invalid argument error when no arguments are passed.
  status_or = client.call("assign_add")
  self.assertAllEqual(status_or.is_ok(), False)
  error_code, _ = status_or.get_error()
  self.assertAllEqual(error_code, errors.INVALID_ARGUMENT)
def test_multi_device_resource_cpu(self):
  """Updates a variable created under one device scope from another.

  The variable is created under "/device:cpu:1"; the server and client are
  created under "/device:CPU:0" and an `assign_add` of 2 is issued.
  """
  with ops.device("/device:cpu:1"):
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

  @eager_def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
  def assign_add(a):
    v.assign_add(a)

  # NOTE(review): the extent of this device scope was reconstructed from
  # whitespace-collapsed source; confirm which statements belong inside it.
  with ops.device("/device:CPU:0"):
    address = "localhost:{}".format(portpicker.pick_unused_port())
    server = rpc_ops.Server(address)
    server.register("assign_add", assign_add)
    server.start()

    client = rpc_ops.Client(address, list_registered_methods=True)
    delta = variables.Variable(2, dtype=dtypes.int64)
    result_or = client.assign_add(delta)
    self.assertAllEqual(result_or.is_ok(), True)

  self.assertAllEqual(v, 2)
def test_rpc_call_op_in_tf_function(self):
  """Issues `client.call` from inside a function-decorated caller.

  The remote function multiplies its two inputs; the caller returns the
  list produced by `get_value()`, which is compared against `[6]`.
  """

  @eager_def_function.function(input_signature=[
      tensor_spec.TensorSpec([], dtypes.int32),
      tensor_spec.TensorSpec([], dtypes.int32)
  ])
  def _remote_fn(a, b):
    return math_ops.multiply(a, b)

  address = "localhost:{}".format(portpicker.pick_unused_port())
  rpc_server = rpc_ops.Server(address)
  rpc_server.register("remote_fn", _remote_fn)
  rpc_server.start()

  client = rpc_ops.Client(address=address, name="test_client")
  a = variables.Variable(2, dtype=dtypes.int32)
  b = variables.Variable(3, dtype=dtypes.int32)

  @eager_def_function.function
  def call_fn():
    status_or = client.call(
        args=[a, b],
        method_name="remote_fn",
        output_specs=[tensor_spec.TensorSpec([], dtypes.int32)])

    self.assertAllEqual(True, status_or.is_ok())
    outputs = status_or.get_value()
    self.assertEqual(len(outputs), 1)  # Call returns a list(tensors)
    # TODO(ishark): Shape for output tensor is unknown currently.
    # Add attribute for capturing TensorSpec for output and enable
    # check below:
    # self.assertIsNotNone(result[0].shape.rank)
    return outputs

  self.assertAllEqual(call_fn(), [6])
def test_resource_deletion(self):
  """Checks that future, client and server resources are already destroyed.

  For each handle under test, an explicit `destroy_resource_op` with
  `ignore_lookup_error=False` is expected to raise `NotFoundError`.
  """
  address = "localhost:{}".format(portpicker.pick_unused_port())
  server = rpc_ops.Server(address)
  server_handle = server._server_handle

  # Test Future resource deletion
  v = variables.Variable(initial_value=0, dtype=dtypes.int64)

  @eager_def_function.function(input_signature=[])
  def read_var():
    return v.value()

  server.register("read_var", read_var)
  server.start()
  client = rpc_ops.Client(address)
  client_handle = client._client_handle

  def _expect_not_found(resource):
    # The explicit destroy must fail because the resource no longer exists.
    with self.assertRaises(errors.NotFoundError):
      resource_variable_ops.destroy_resource_op(
          resource, ignore_lookup_error=False)

  # Check future resource deletion without calling get_value.
  def _create_and_delete_rpc_future():
    future = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
    return future._status_or

  @eager_def_function.function
  def _create_and_delete_rpc_future_fn():
    future = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
    return future._status_or

  for _ in range(2):
    handle = _create_and_delete_rpc_future()
    _expect_not_found(handle)

  for _ in range(2):
    handle = _create_and_delete_rpc_future_fn()
    _expect_not_found(handle)

  # Check future resource deletion with calling get_value.
  def _create_and_delete_with_future():
    future = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
    status_or_handle = future._status_or
    future.get_value()
    return status_or_handle

  # Same check, with the call and get_value made inside a decorated function.
  @eager_def_function.function
  def _create_and_delete_with_future_fn():
    future = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
    status_or_handle = future._status_or
    future.get_value()
    return status_or_handle

  for _ in range(2):
    resource_handle = _create_and_delete_with_future()
    _expect_not_found(resource_handle)

  for _ in range(2):
    resource_handle = _create_and_delete_with_future_fn()
    _expect_not_found(resource_handle)

  # Test that the client resource gets deleted.
  del client
  _expect_not_found(client_handle)

  # Test that the server resource gets deleted.
  del server
  _expect_not_found(server_handle)
def test_call_register_ordering(self):
  """Exercises client creation and registration relative to server start.

  Covers: creating a plain client before the server exists, a listing client
  timing out before the server is started, listing clients succeeding once a
  delayed server start happens on another thread, and registration being
  rejected after the server has started.
  """
  port = portpicker.pick_unused_port()
  address = "localhost:{}".format(port)

  # Create client succeeds before server start and registration
  client = rpc_ops.Client(address)

  # Create client with list_registered_methods fails before server is started.
  with self.assertRaises(errors.DeadlineExceededError):
    rpc_ops.Client(
        address,
        name="client1",
        list_registered_methods=True,
        timeout_in_ms=1)

  v = variables.Variable(initial_value=0, dtype=dtypes.int64)

  @eager_def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
  def assign_add(a):
    v.assign_add(a)

  @eager_def_function.function(input_signature=[])
  def read_var():
    return v.value()

  server = rpc_ops.Server(address)

  def start_server():
    # Delay server start to test whether client creation also waits
    # till server is up.
    time.sleep(1)
    server.register("assign_add", assign_add)
    server.start()

  t = threading.Thread(target=start_server)
  t.start()

  # Create same "client1" again should succeed.
  client1_with_listed_methods = rpc_ops.Client(
      address, name="client1", list_registered_methods=True)

  # The listing client above was created successfully, so the delayed start
  # has been reached. Join the starter thread so it does not outlive the
  # test and so the server is fully started before the checks below.
  t.join()

  result_or = client1_with_listed_methods.assign_add(
      variables.Variable(2, dtype=dtypes.int64))
  self.assertAllEqual(result_or.is_ok(), True)

  result_or = client.call("assign_add",
                          [variables.Variable(2, dtype=dtypes.int64)])
  self.assertAllEqual(result_or.is_ok(), True)

  # Create client with registered methods
  client2_with_listed_methods = rpc_ops.Client(
      address=address, name="client2", list_registered_methods=True)

  result_or = client2_with_listed_methods.assign_add(
      variables.Variable(2, dtype=dtypes.int64))
  self.assertAllEqual(result_or.is_ok(), True)

  # Three successful increments of 2 each.
  self.assertAllEqual(v, 6)

  # Register new method after server started.
  with self.assertRaisesRegex(
      errors.FailedPreconditionError,
      "All methods must be registered before starting the server"):
    server.register("read_var", read_var)