Example 1
    def test_query_model(self, num_clients=2, address="localhost:23457"):
        def run_client():
            client = postman.Client(address)
            client.connect(10)
            model = client.query_state_dict()
            self.assertEqual(model["fc.weight"].size, 100)
            self.assertEqual(model["fc.bias"].size, 10)

        model = Model()
        model_queue = buffer.ModelQueue(model)

        def query_state_dict():
            # Check out the current model, add a leading batch dimension to
            # each tensor in its state dict (replies are batched), then
            # release the model and return the batched state dict.
            model_id, cur_model = model_queue.get_model()
            model_nest = nest.map(lambda t: t.unsqueeze(0), cur_model.state_dict())
            model_queue.release_model(model_id)
            return model_nest

        server = postman.Server(address)
        server.bind("query_state_dict", query_state_dict, batch_size=1)
        server.run()

        client_processes = [mp.Process(target=run_client) for _ in range(num_clients)]

        for p in client_processes:
            p.start()

        for p in client_processes:
            p.join()

        server.stop()
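The Model class used above is defined elsewhere. Given the asserted element counts (100 for fc.weight, 10 for fc.bias), a minimal stand-in consistent with the test could look like this; the real definition may differ:

class Model(torch.nn.Module):
    # Hypothetical stand-in: a single linear layer whose parameter counts
    # match the test's assertions (10 * 10 = 100 weights, 10 biases).
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)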
Example 2
    def test_set_batch_size(self):
        address = "127.0.0.1"

        init_batch_size = 3
        final_batch_size = 2

        def run_client(port):
            client = postman.Client("%s:%i" % (address, port))
            client.connect(10)
            client.foo(torch.empty(init_batch_size, 2, 2))
            client.foo(torch.empty(final_batch_size, 2, 2))

        try:
            server = postman.Server("%s:0" % address)
            q = postman.ComputationQueue(batch_size=init_batch_size)
            server.bind_queue_batched("foo", q)
            server.run()

            client_proc = mp.Process(target=run_client, args=(server.port(),))
            client_proc.start()

            with q.get(wait_till_full=True) as batch:
                batch.set_outputs(batch.get_inputs()[0])

            # Shrink the queue's batch size between the two requests.
            q.set_batch_size(final_batch_size)

            with q.get(wait_till_full=True) as batch:
                batch.set_outputs(batch.get_inputs()[0])
        finally:
            q.close()
            server.stop()
            client_proc.join()
Example 3
    def test_abba(self, address="127.0.0.1:12346"):
        event = threading.Event()

        def a(x):
            event.wait()
            return x + 1

        def b(x):
            return x + 2

        server = postman.Server(address)
        server.bind("a", a, batch_size=1)
        server.bind("b", b, batch_size=1)
        server.run()

        try:
            client = postman.AsyncClient(address)
            streams = client.connect(10)

            a_future = streams.a(torch.zeros(()))  # blocks server-side until the event is set
            b_future = streams.b(torch.zeros(()))  # must complete even while "a" is blocked

            event.set()  # unblock "a"

            a_result = a_future.get()
            b_result = b_future.get()

            np.testing.assert_array_equal(a_result, torch.full((), 1))
            np.testing.assert_array_equal(b_result, torch.full((), 2))
        finally:
            streams.close()
            server.stop()
Example 4
    def test_rpc_jit(self, num_clients=2, address="127.0.0.1:12346"):
        def run_client(client_id):
            client = postman.Client(address)
            client.connect(10)
            arg = np.full((1, 2), client_id, dtype=np.float32)
            batched_arg = np.full((2,), client_id, dtype=np.float32)

            function_result = client.function(arg)
            batched_function_result = client.batched_function(batched_arg)

            np.testing.assert_array_equal(function_result, np.full((1, 2), client_id))
            np.testing.assert_array_equal(batched_function_result, np.full((2,), client_id))

        clients = [mp.Process(target=run_client, args=(i,)) for i in range(num_clients)]

        linear = torch.nn.Linear(2, 2, bias=False)
        linear.weight.data = torch.diagflat(torch.ones(2))
        module = torch.jit.script(linear)
        server = postman.Server(address)

        server.bind("function", module)
        server.bind("batched_function", module, batch_size=num_clients)

        server.run()

        for p in clients:
            p.start()

        for p in clients:
            p.join()

        server.stop()
Example 5
    def test_bind_unix_domain_socket(self):
        server = postman.Server("unix:/tmp/test.sock")
        server.run()
        try:
            self.assertNotEqual(server.port(), 0)
        finally:
            server.stop()
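Assuming postman.Client accepts unix: addresses the same way the server does (not shown in these examples), a client could target the socket like this; the bound method name foo is hypothetical:

def query_over_uds():
    # Hypothetical sketch: connect over the unix domain socket and call
    # a function the server is assumed to have bound as "foo".
    client = postman.Client("unix:/tmp/test.sock")
    client.connect(10)  # wait up to 10 seconds for the connection
    return client.foo(torch.zeros(1, 2))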
Example 6
    def test_none_return(self):
        def get_nothing():
            # TODO(heiner): Add check on return shape.
            return torch.arange(2).reshape(1, 2)

        def return_nothing(t):
            return None

        def nothing():
            return

        server = postman.Server("127.0.0.1:0")
        server.bind("get_nothing", get_nothing, batch_size=1)
        server.bind("return_nothing", return_nothing, batch_size=1)
        server.bind("nothing", nothing, batch_size=1)
        server.run()

        client = postman.Client("127.0.0.1:%i" % server.port())
        client.connect(10)
        try:
            value = client.get_nothing()
            np.testing.assert_array_equal(value, np.arange(2))
            value = client.return_nothing(torch.tensor(10))

            # For now, "None" responses are empty tuples.
            self.assertEqual(value, ())
            self.assertEqual(client.nothing(), ())

        finally:
            server.stop()
Example 7
def main():
    server = postman.Server("%s:%d" % ("localhost", 12345))
    model = Model()
    replay_buffer = buffer.NestPrioritizedReplay(1000, 0, 0.6, 0.4, True)

    model_queue = buffer.ModelQueue(model)

    def add_replay(content, priority):
        # Log incoming replay content for debugging, then hand it to the buffer.
        print(content)
        print(priority)
        replay_buffer.add_batch_async(content, priority[0])

    def query_state_dict():
        # Check out the current model, add a leading batch dimension to each
        # tensor in its state dict (replies are batched), then release it.
        model_id, cur_model = model_queue.get_model()
        model_nest = nest.map(lambda t: t.unsqueeze(0), cur_model.state_dict())
        model_queue.release_model(model_id)
        return model_nest

    server.bind("query_state_dict", query_state_dict, batch_size=1)
    server.bind("add_replay", add_replay, batch_size=1)
    server.run()

    try:
        while True:
            time.sleep(1)
            print("current replay buffer size is %d" % replay_buffer.size())
    except KeyboardInterrupt:
        server.stop()
        server.wait()
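This main() is the server half of a parameter-server/replay setup. Combining the client call patterns from Examples 1 and 14, an actor-side counterpart might look roughly like the sketch below; run_actor, the step count, and the data layout are assumptions:

def run_actor(address="localhost:12345", steps=100):
    # Hypothetical actor loop: pull the latest weights, generate some
    # data locally, and ship new replay content back to the server.
    client = postman.Client(address)
    client.connect(10)
    local_replay = buffer.NestPrioritizedReplay(1000, 0, 0.6, 0.4, True)

    for _ in range(steps):
        state_dict = client.query_state_dict()  # batched copy of the weights
        local_replay.add_one({"a": torch.zeros(10)}, 1)
        size, batch, priority = local_replay.get_new_content()
        client.add_replay(batch, priority)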
Example 8
    def test_bind_port_zero(self):
        server = postman.Server("127.0.0.1:0")
        server.run()
        try:
            # ephemeral port should be assigned
            self.assertNotEqual(server.port(), 0)
        finally:
            server.stop()
Example 9
    def test_rpc_python(self, num_clients=2, address="127.0.0.1:12346"):
        def run_client():
            client = postman.Client(address)
            client.connect(10)
            client.py_function(
                torch.zeros((1, 2)), torch.arange(10), (torch.empty(2, 3), torch.ones((1, 2)))
            )
            client.batched_function(torch.zeros((1, 2)))

        client_processes = [mp.Process(target=run_client) for _ in range(num_clients)]

        calls = collections.defaultdict(int)

        def py_function(a, b, c):
            calls["py_function"] += 1
            np.testing.assert_array_equal(a.numpy(), np.zeros((1, 1, 2)))
            np.testing.assert_array_equal(b.numpy(), np.arange(10).reshape((1, 10)))

            c0, c1 = c
            self.assertSequenceEqual(list(c0.shape), (1, 2, 3))
            np.testing.assert_array_equal(c1.numpy(), np.ones((1, 1, 2)))

            return torch.ones(1, 1)

        def batched_function(a):
            calls["batched_function"] += 1
            self.assertEqual(a.shape[0], 2)
            return torch.ones(a.shape)

        server = postman.Server(address)
        server.bind("py_function", py_function, batch_size=1)
        server.bind(
            "batched_function", batched_function, batch_size=num_clients, wait_till_full=True
        )
        server.run()

        for p in client_processes:
            p.start()

        for p in client_processes:
            p.join()

        server.stop()

        self.assertEqual(calls["py_function"], num_clients)
        self.assertEqual(calls["batched_function"], 1)
Example 10
def main():
    server = postman.Server("localhost:12345")

    server.bind("pyfunc", pyfunc, batch_size=1)
    server.bind("identity", identity, batch_size=1)
    server.bind("batched_identity",
                identity,
                batch_size=2,
                wait_till_full=True)

    server.run()

    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        server.stop()
        server.wait()
Example 11
def run_server(port, batch_size, port_q=None, **kwargs):
    def set_seed(seed):
        seed = seed.item()
        logging.info(f"Set server seed to {seed}")
        torch.manual_seed(seed)
        np.random.seed(seed)

    logging.basicConfig(format="%(asctime)s [%(levelname)s]: %(message)s",
                        level=logging.INFO)
    max_port = port + 10
    try:
        logging.info(f"Starting server port={port} batch={batch_size}")
        eval_queue = postman.ComputationQueue(batch_size)
        for p in range(port, max_port):
            server = postman.Server(f"127.0.0.1:{p}")
            server.bind("set_batch_size",
                        lambda x: eval_queue.set_batch_size(x.item()),
                        batch_size=1)
            server.bind("set_seed", set_seed, batch_size=1)
            server.bind_queue_batched("evaluate", eval_queue)
            try:
                server.run()
                break  # port is good
            except RuntimeError:
                continue  # try a different port
        else:
            raise RuntimeError(
                f"Couldn't start server on ports {port}:{max_port}")

        bound_port = server.port()
        assert bound_port != 0

        logging.info(f"Started server on port={bound_port} pid={os.getpid()}")
        if port_q is not None:
            port_q.put(bound_port)  # send port to parent proc

        server_handler(eval_queue, **kwargs)  # FIXME: try multiple threads?
    except Exception as e:
        logging.exception("Caught exception in the server (%s)", e)
        raise
    finally:
        eval_queue.close()
        server.stop()
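A client driving this server might exercise the three bound endpoints as sketched below; the tensor passed to evaluate is a placeholder, since the expected shape depends on server_handler:

def run_eval_client(port, seed=123, batch_size=4):
    # Hypothetical driver for the server above.
    client = postman.Client(f"127.0.0.1:{port}")
    client.connect(10)
    client.set_seed(torch.tensor(seed))  # set_seed calls .item() server-side
    client.set_batch_size(torch.tensor(batch_size))
    # Placeholder input; the real shape depends on server_handler.
    return client.evaluate(torch.zeros(1, 2))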
Example 12
def main():
    # TODO: Re-add TorchScript modules. Example code:
    # https://github.com/fairinternal/torchbeast/blob/4e34d2b6493ea2f2d364e8cd7c5eb9596b9dcb6d/torchbeast/server.cc#L185

    # module = torch.jit.script(torch.nn.Linear(2, 3))

    server = postman.Server("localhost:12345")

    # s.bind("mymodule", module)
    server.bind("pyfunc", pyfunc, batch_size=1)
    server.bind("identity", identity, batch_size=1)
    server.bind("batched_identity",
                identity,
                batch_size=2,
                wait_till_full=True)

    # server.bind("batched_myfunc", module, batch_size=2)

    # Alternative: Binding "ComputationQueue"s instead of functions directly:
    queue = postman.ComputationQueue(batch_size=2)
    server.bind_queue("batched_identity2", queue)

    def read_queue():
        try:
            while True:
                with queue.get(wait_till_full=False) as batch:
                    batch.set_outputs(identity(*batch.get_inputs()))
        except StopIteration:
            return

    thread = threading.Thread(target=read_queue)
    thread.start()

    server.run()

    try:
        while True:
            time.sleep(1)  # Could also handle signals here.
    except KeyboardInterrupt:
        queue.close()
        server.stop()
        server.wait()
        thread.join()
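Both this example and Example 10 bind helpers named pyfunc and identity that are not shown. Since identity is also used as batch.set_outputs(identity(*batch.get_inputs())), a plausible minimal definition is the following; the body of pyfunc is a guess:

def identity(*args):
    # Echo the inputs back unchanged; usable both as a direct binding and
    # as the handler for a ComputationQueue batch.
    return args

def pyfunc(t):
    # Hypothetical example function: return the input incremented by one.
    return t + 1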
Example 13
    def test_simple(self, address="127.0.0.1:12346"):
        def function(a, b, c):
            return a + 1, b + 2, c + 3

        server = postman.Server(address)
        server.bind("function", function, batch_size=1)
        server.run()

        try:
            client = postman.AsyncClient(address)
            streams = client.connect(10)

            inputs = (torch.zeros(1), torch.ones(2), torch.arange(10))

            future = streams.function(*inputs)

            result = future.get()
            for x, y in zip(result, function(*inputs)):
                np.testing.assert_array_equal(x, y)

        finally:
            streams.close()
            server.stop()
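For comparison, the same call through the blocking client used in the other tests would return the result directly instead of a future; a minimal sketch:

def call_sync(address="127.0.0.1:12346"):
    # Synchronous counterpart to the AsyncClient call above.
    client = postman.Client(address)
    client.connect(10)
    return client.function(torch.zeros(1), torch.ones(2), torch.arange(10))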
Example 14
    def test_add_replay(self, num_clients=2, address="localhost:23456"):
        def run_client():
            client = postman.Client(address)
            client.connect(10)
            local_replay_buffer = buffer.NestPrioritizedReplay(1000, 0, 0.6, 0.4, True)

            data = {}
            data["a"] = torch.empty(10)

            # For testing; in practice a long-running C++ replay buffer
            # would be adding data continually.
            local_replay_buffer.add_one(data, 1)
            local_replay_buffer.add_one(data, 2)

            size, batch, priority = local_replay_buffer.get_new_content()
            client.add_replay(batch, priority)

        client_processes = [mp.Process(target=run_client) for _ in range(num_clients)]

        replay_buffer = buffer.NestPrioritizedReplay(1000, 0, 0.6, 0.4, True)

        def add_replay(content, priority):
            replay_buffer.add_batch_async(content, priority[0])

        server = postman.Server(address)
        server.bind("add_replay", add_replay, batch_size=1)
        server.run()

        for p in client_processes:
            p.start()

        for p in client_processes:
            p.join()

        server.stop()

        self.assertEqual(replay_buffer.size(), 2 * num_clients)