def test_query_model(self, num_clients=2, address="localhost:23457"):
    def run_client():
        client = postman.Client(address)
        client.connect(10)
        model = client.query_state_dict()
        self.assertEqual(model["fc.weight"].size, 100)
        self.assertEqual(model["fc.bias"].size, 10)

    model = Model()
    model_queue = buffer.ModelQueue(model)

    def query_state_dict():
        model_id, cur_model = model_queue.get_model()
        # Add a leading batch dimension of 1, matching batch_size=1 below.
        model_nest = nest.map(lambda t: t.unsqueeze(0), cur_model.state_dict())
        model_queue.release_model(model_id)
        return model_nest

    server = postman.Server(address)
    server.bind("query_state_dict", query_state_dict, batch_size=1)
    server.run()

    client_processes = [mp.Process(target=run_client) for _ in range(num_clients)]
    for p in client_processes:
        p.start()
    for p in client_processes:
        p.join()
    server.stop()
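# A client that wants to use the fetched parameters locally has to undo the
# batch dimension that query_state_dict() added on the server side. A minimal
# sketch, reusing only names from the test above (Model, nest.map, the
# "localhost:23457" address); load_remote_state_dict itself is hypothetical:

def load_remote_state_dict(address="localhost:23457"):
    client = postman.Client(address)
    client.connect(10)
    state_dict = client.query_state_dict()
    # Strip the leading batch dimension of 1 added by the server.
    state_dict = nest.map(lambda t: torch.as_tensor(t).squeeze(0), state_dict)
    model = Model()
    model.load_state_dict(state_dict)
    return model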
def test_set_batch_size(self):
    address = "127.0.0.1"
    init_batch_size = 3
    final_batch_size = 2

    def run_client(port):
        client = postman.Client("%s:%i" % (address, port))
        client.connect(10)
        client.foo(torch.Tensor(init_batch_size, 2, 2))
        client.foo(torch.Tensor(final_batch_size, 2, 2))

    try:
        server = postman.Server("%s:0" % address)
        q = postman.ComputationQueue(batch_size=init_batch_size)
        server.bind_queue_batched("foo", q)
        server.run()

        client_proc = mp.Process(target=run_client, args=(server.port(),))
        client_proc.start()

        with q.get(wait_till_full=True) as batch:
            batch.set_outputs(batch.get_inputs()[0])

        q.set_batch_size(final_batch_size)

        with q.get(wait_till_full=True) as batch:
            batch.set_outputs(batch.get_inputs()[0])
    finally:
        q.close()
        server.stop()
        client_proc.join()
def test_abba(self, address="127.0.0.1:12346"):
    event = threading.Event()

    def a(x):
        event.wait()  # Block until the test releases us.
        return x + 1

    def b(x):
        return x + 2

    server = postman.Server(address)
    server.bind("a", a, batch_size=1)
    server.bind("b", b, batch_size=1)
    server.run()

    try:
        client = postman.AsyncClient(address)
        streams = client.connect(10)

        # Issue both calls while "a" is still blocked; "b" must not be
        # held up behind it.
        a_future = streams.a(torch.zeros(()))
        b_future = streams.b(torch.zeros(()))
        event.set()

        a_result = a_future.get()
        b_result = b_future.get()

        np.testing.assert_array_equal(a_result, torch.full((), 1))
        np.testing.assert_array_equal(b_result, torch.full((), 2))
    finally:
        streams.close()
        server.stop()
def test_rpc_jit(self, num_clients=2, address="127.0.0.1:12346"):
    def run_client(client_id):
        client = postman.Client(address)
        client.connect(10)
        arg = np.full((1, 2), client_id, dtype=np.float32)
        batched_arg = np.full((2,), client_id, dtype=np.float32)
        function_result = client.function(arg)
        batched_function_result = client.batched_function(batched_arg)
        np.testing.assert_array_equal(function_result, np.full((1, 2), client_id))
        np.testing.assert_array_equal(batched_function_result, np.full((2,), client_id))

    clients = [mp.Process(target=run_client, args=(i,)) for i in range(num_clients)]

    # Identity module: the weight is the 2x2 identity matrix and there is no bias.
    linear = torch.nn.Linear(2, 2, bias=False)
    linear.weight.data = torch.diagflat(torch.ones(2))
    module = torch.jit.script(linear)

    server = postman.Server(address)
    server.bind("function", module)
    server.bind("batched_function", module, batch_size=num_clients)
    server.run()

    for p in clients:
        p.start()
    for p in clients:
        p.join()
    server.stop()
def test_bind_unix_domain_socket(self):
    server = postman.Server("unix:/tmp/test.sock")
    server.run()
    try:
        self.assertNotEqual(server.port(), 0)
    finally:
        server.stop()
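# Presumably clients reach such a server with the same "unix:" address string.
# A sketch, assuming postman.Client accepts unix-domain addresses the way
# postman.Server does (not verified by the test above):

def connect_over_unix_socket():
    client = postman.Client("unix:/tmp/test.sock")
    client.connect(10)
    return client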
def test_none_return(self):
    def get_nothing():
        # TODO(heiner): Add check on return shape.
        return torch.arange(2).reshape(1, 2)

    def return_nothing(t):
        return None

    def nothing():
        return

    server = postman.Server("127.0.0.1:0")
    server.bind("get_nothing", get_nothing, batch_size=1)
    server.bind("return_nothing", return_nothing, batch_size=1)
    server.bind("nothing", nothing, batch_size=1)
    server.run()

    client = postman.Client("127.0.0.1:%i" % server.port())
    client.connect(10)
    try:
        value = client.get_nothing()
        np.testing.assert_array_equal(value, np.arange(2))

        value = client.return_nothing(torch.tensor(10))
        # For now, "None" responses are empty tuples.
        self.assertEqual(value, ())
        self.assertEqual(client.nothing(), ())
    finally:
        server.stop()
def main(): server = postman.Server("%s:%d" % ("localhost", 12345)) model = Model() replay_buffer = buffer.NestPrioritizedReplay(1000, 0, 0.6, 0.4, True) model_queue = buffer.ModelQueue(model) def add_replay(content, priority): print(content) print(priority) replay_buffer.add_batch_async(content, priority[0]) def query_state_dict(): # print(agent.state_dict()) model_id, cur_model = model_queue.get_model() model_nest = nest.map(lambda t: t.unsqueeze(0), cur_model.state_dict()) model_queue.release_model(model_id) return model_nest server.bind("query_state_dict", query_state_dict, batch_size=1) server.bind("add_replay", add_replay, batch_size=1) server.run() try: while True: time.sleep(1) print("current replay buffer size is %d" % replay_buffer.size()) except KeyboardInterrupt: server.stop() server.wait()
def test_bind_port_zero(self):
    server = postman.Server("127.0.0.1:0")
    server.run()
    try:
        # An ephemeral port should have been assigned.
        self.assertNotEqual(server.port(), 0)
    finally:
        server.stop()
def test_rpc_python(self, num_clients=2, address="127.0.0.1:12346"):
    def run_client():
        client = postman.Client(address)
        client.connect(10)
        client.py_function(
            torch.zeros((1, 2)), torch.arange(10), (torch.empty(2, 3), torch.ones((1, 2)))
        )
        client.batched_function(torch.zeros((1, 2)))

    client_processes = [mp.Process(target=run_client) for _ in range(num_clients)]

    calls = collections.defaultdict(int)

    def py_function(a, b, c):
        calls["py_function"] += 1
        # The server prepends a batch dimension of size 1 to every argument.
        np.testing.assert_array_equal(a.numpy(), np.zeros((1, 1, 2)))
        np.testing.assert_array_equal(b.numpy(), np.arange(10).reshape((1, 10)))
        c0, c1 = c
        # torch.empty() is uninitialized, so only check the shape.
        self.assertSequenceEqual(list(c0.shape), (1, 2, 3))
        np.testing.assert_array_equal(c1.numpy(), np.ones((1, 1, 2)))
        return torch.ones(1, 1)

    def batched_function(a):
        calls["batched_function"] += 1
        self.assertEqual(a.shape[0], num_clients)
        return torch.ones(a.shape)

    server = postman.Server(address)
    server.bind("py_function", py_function, batch_size=1)
    server.bind(
        "batched_function", batched_function, batch_size=num_clients, wait_till_full=True
    )
    server.run()

    for p in client_processes:
        p.start()
    for p in client_processes:
        p.join()
    server.stop()

    self.assertEqual(calls["py_function"], num_clients)
    self.assertEqual(calls["batched_function"], 1)
def main(): server = postman.Server("localhost:12345") server.bind("pyfunc", pyfunc, batch_size=1) server.bind("identity", identity, batch_size=1) server.bind("batched_identity", identity, batch_size=2, wait_till_full=True) server.run() try: while True: time.sleep(1) except KeyboardInterrupt: server.stop() server.wait()
def run_server(port, batch_size, port_q=None, **kwargs):
    def set_seed(seed):
        seed = seed.item()
        logging.info(f"Set server seed to {seed}")
        torch.manual_seed(seed)
        np.random.seed(seed)

    logging.basicConfig(format="%(asctime)s [%(levelname)s]: %(message)s", level=logging.INFO)
    max_port = port + 10
    try:
        logging.info(f"Starting server port={port} batch={batch_size}")
        eval_queue = postman.ComputationQueue(batch_size)
        for p in range(port, max_port):
            server = postman.Server(f"127.0.0.1:{p}")
            server.bind(
                "set_batch_size", lambda x: eval_queue.set_batch_size(x.item()), batch_size=1
            )
            server.bind("set_seed", set_seed, batch_size=1)
            server.bind_queue_batched("evaluate", eval_queue)
            try:
                server.run()
                break  # Port is good.
            except RuntimeError:
                continue  # Try a different port.
        else:
            raise RuntimeError(f"Couldn't start server on ports {port}:{max_port}")

        bound_port = server.port()
        assert bound_port != 0
        logging.info(f"Started server on port={bound_port} pid={os.getpid()}")
        if port_q is not None:
            port_q.put(bound_port)  # Send the bound port to the parent process.

        server_handler(eval_queue, **kwargs)  # FIXME: try multiple threads?
    except Exception as e:
        logging.exception("Caught exception in the server (%s)", e)
        raise
    finally:
        eval_queue.close()
        server.stop()
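# A client for this server, using only the methods bound above (set_seed,
# set_batch_size, evaluate). The evaluate payload shape is hypothetical; the
# real shape depends on what server_handler expects:

def run_eval_client(port):
    client = postman.Client(f"127.0.0.1:{port}")
    client.connect(10)
    client.set_seed(torch.tensor(123))
    client.set_batch_size(torch.tensor(4))
    return client.evaluate(torch.zeros(1, 2))  # hypothetical input shape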
def main():
    # TODO: Re-add TorchScript modules. Example code:
    # https://github.com/fairinternal/torchbeast/blob/4e34d2b6493ea2f2d364e8cd7c5eb9596b9dcb6d/torchbeast/server.cc#L185
    # module = torch.jit.script(torch.nn.Linear(2, 3))

    server = postman.Server("localhost:12345")
    # server.bind("mymodule", module)
    server.bind("pyfunc", pyfunc, batch_size=1)
    server.bind("identity", identity, batch_size=1)
    server.bind("batched_identity", identity, batch_size=2, wait_till_full=True)
    # server.bind("batched_myfunc", module, batch_size=2)

    # Alternative: bind a ComputationQueue instead of a function directly.
    queue = postman.ComputationQueue(batch_size=2)
    server.bind_queue("batched_identity2", queue)

    def read_queue():
        try:
            while True:
                with queue.get(wait_till_full=False) as batch:
                    batch.set_outputs(identity(*batch.get_inputs()))
        except StopIteration:  # Raised when the queue is closed.
            return

    thread = threading.Thread(target=read_queue)
    thread.start()

    server.run()
    try:
        while True:
            time.sleep(1)  # Could also handle signals here.
    except KeyboardInterrupt:
        queue.close()
        server.stop()
        server.wait()
        thread.join()
def test_simple(self, address="127.0.0.1:12346"):
    def function(a, b, c):
        return a + 1, b + 2, c + 3

    server = postman.Server(address)
    server.bind("function", function, batch_size=1)
    server.run()

    try:
        client = postman.AsyncClient(address)
        streams = client.connect(10)

        inputs = (torch.zeros(1), torch.ones(2), torch.arange(10))
        future = streams.function(*inputs)
        result = future.get()
        for x, y in zip(result, function(*inputs)):
            np.testing.assert_array_equal(x, y)
    finally:
        streams.close()
        server.stop()
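# The same round trip with the blocking client, for comparison: postman.Client
# (used in the other tests in this file) returns results directly instead of
# futures. A sketch against the server bound above:

def call_function_blocking(address="127.0.0.1:12346"):
    client = postman.Client(address)
    client.connect(10)
    return client.function(torch.zeros(1), torch.ones(2), torch.arange(10))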
def test_add_replay(self, num_clients=2, address="localhost:23456"):
    def run_client():
        client = postman.Client(address)
        client.connect(10)

        local_replay_buffer = buffer.NestPrioritizedReplay(1000, 0, 0.6, 0.4, True)
        data = {"a": torch.Tensor(10)}
        # For testing only; in practice a long-running C++ actor would add
        # to this buffer.
        local_replay_buffer.add_one(data, 1)
        local_replay_buffer.add_one(data, 2)
        size, batch, priority = local_replay_buffer.get_new_content()
        client.add_replay(batch, priority)

    client_processes = [mp.Process(target=run_client) for _ in range(num_clients)]

    replay_buffer = buffer.NestPrioritizedReplay(1000, 0, 0.6, 0.4, True)

    def add_replay(content, priority):
        replay_buffer.add_batch_async(content, priority[0])

    server = postman.Server(address)
    server.bind("add_replay", add_replay, batch_size=1)
    server.run()

    for p in client_processes:
        p.start()
    for p in client_processes:
        p.join()
    server.stop()

    self.assertEqual(replay_buffer.size(), 2 * num_clients)