Exemple #1
0
def test_pubsub_full():
    """Run one trainer and ``WORKER_CNT`` workers against a fresh perwez server.

    The test finishes once every worker coroutine has completed; the trainer
    coroutine is then cancelled. Cleanup (task cancellation and server
    shutdown) runs even when a worker raises, so a failing run does not leave
    a stray server process or pending tasks behind.
    """
    server_proc, config = perwez.start_server()
    loop = asyncio.get_event_loop()
    tasks = []
    try:
        trainer_task = loop.create_task(_trainer(config["url"]))
        tasks.append(trainer_task)
        worker_tasks = []
        for idx in range(WORKER_CNT):
            wt = loop.create_task(_worker(idx, config["url"]))
            worker_tasks.append(wt)
            tasks.append(wt)
        loop.run_until_complete(asyncio.gather(*worker_tasks))
    finally:
        # Cancel whatever is still running (normally just the trainer; also
        # the remaining workers if one of them failed) ...
        pending = [task for task in tasks if not task.done()]
        for task in pending:
            task.cancel()
        # ... and give the loop a chance to actually deliver the cancellation,
        # otherwise the tasks are destroyed while still pending.
        if pending:
            loop.run_until_complete(
                asyncio.gather(*pending, return_exceptions=True))
        server_proc.terminate()
        server_proc.join()
Exemple #2
0
def main():
    """Start a perwez server attached to a parent node plus worker processes.

    Spawns ``WORKER_CNT`` worker processes, registers ``_exit`` to run on
    interpreter exit with every spawned process, then blocks until all of
    them have finished.
    """
    parent_url = f"http://{PARENT_IP}:12333"
    server_proc, _ = perwez.start_server(name="default", parent_url=parent_url)

    children = [server_proc]
    for worker_idx in range(WORKER_CNT):
        worker = mp.Process(target=_worker, args=(worker_idx, ))
        worker.start()
        children.append(worker)

    # Cleanup hook receives the same list object that is joined below.
    atexit.register(_exit, children)

    for child in children:
        child.join()
Exemple #3
0
def _push_until_blocked(client, topic, data, limit):
    """Push up to ``limit`` items without blocking.

    Stops at the first ``zmq.ZMQError`` and returns the number of pushes that
    were accepted before that.
    """
    accepted = 0
    for _ in range(limit):
        try:
            client.push(topic, data, ipc=False, noblock=True)
        except zmq.ZMQError:
            break
        accepted += 1
    return accepted


def test_zmq_push_block():
    """Exercise non-blocking push behaviour around the high-water mark.

    Steps:
      1. A non-blocking push with no subscriber must raise ``zmq.ZMQError``.
      2. After one subscriber joins, ``DEFAULT_HWM`` blocking pushes are made,
         then non-blocking pushes continue until one is rejected.
      3. The first subscriber must receive the pushed payload.
      4. A second subscriber joins, more non-blocking pushes are attempted,
         and that subscriber must receive the payload as well.

    Clients and the server process are released even if an assertion fails.
    """
    proc, _ = perwez.start_server("default")
    clients = []
    try:
        topic = "asd"
        for _ in range(3):
            clients.append(perwez.connect("default"))
        c0, c1, c2 = clients
        data = secrets.token_bytes(size)

        print("EAGAIN when no client")
        try:
            c0.push(topic, data, ipc=False, noblock=True)
        except zmq.ZMQError:
            pass
        else:
            # Explicit raise: a bare `assert False` disappears under `-O`.
            raise AssertionError(
                "non-blocking push succeeded although nobody is subscribed")

        print("c1 join")
        w1 = c1.subscribe(topic, False)
        for _ in range(DEFAULT_HWM):
            c0.push(topic, data, ipc=False)
        print(f"{DEFAULT_HWM} items pushed")
        print("hwm block: push more items")
        queued = _push_until_blocked(c0, topic, data, DEFAULT_HWM * 3)
        print(f"{queued}/{DEFAULT_HWM * 3} more items in queue")

        res = w1.get(timeout=3)
        assert res == data

        print("c2 join")
        w2 = c2.subscribe(topic, False)
        # Fixed pause before pushing again, presumably so the new subscriber
        # is connected -- TODO confirm.
        time.sleep(1)
        print("hwm block: push more items")
        queued = _push_until_blocked(c0, topic, data, DEFAULT_HWM * 3)
        print(f"{queued}/{DEFAULT_HWM * 3} more items in queue")
        res = w2.get(timeout=15)
        assert res == data
        print("finished")
    finally:
        for client in clients:
            client.close()
        proc.terminate()
        proc.join()
Exemple #4
0
def main():
    """Launch the perwez server, a Reverb replay server, workers and the trainer.

    Configuration is read from ``config.yaml`` next to this file. Workers run
    as daemon subprocesses; the trainer runs in the calling (main) process.
    Everything that was started is shut down when the trainer returns or
    raises, and also when any of the setup steps fails.
    """
    mp.set_start_method("spawn", force=True)
    config_path = path.join(path.dirname(__file__), "config.yaml")
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # init perwez
    pwz_proc, pwz_config = perwez.start_server()
    reverb_server = None
    worker_processes = []
    try:
        # init reverb_server
        # Keep a reference for the whole run: an unreferenced Server object
        # can be finalized right away, which is expected to shut the server
        # down (NOTE(review): confirm against the installed Reverb version).
        reverb_server = reverb.Server(
            tables=[
                reverb.Table(
                    name=TABLE_NAME,
                    sampler=reverb.selectors.Prioritized(0.6),
                    remover=reverb.selectors.Fifo(),
                    max_size=config["replay_buffer"]["capacity"],
                    rate_limiter=reverb.rate_limiters.MinSize(1000),
                )
            ],
            port=PORT,
        )

        # worker subprocesses
        num_workers = config["common"]["num_workers"]
        for idx in range(num_workers):
            p = mp.Process(
                name=f"apex-worker-{idx}",
                target=worker_main,
                args=(pwz_config["url"], config, idx),
                daemon=True,
            )
            p.start()
            # Only started processes are tracked, so cleanup never touches a
            # process object that failed to start.
            worker_processes.append(p)

        # trainer process should be the main process
        trainer_main_tf_dataset(pwz_config["url"], config)
    finally:
        print("exiting...")
        for p in worker_processes:
            p.terminate()
            p.join()
        if reverb_server is not None:
            reverb_server.stop()
        pwz_proc.terminate()
        pwz_proc.join()
Exemple #5
0
def test_pubsub_full():
    """Run one trainer process and ``WORKER_CNT`` worker processes.

    The test waits for every worker to finish, then terminates the trainer,
    the workers and the perwez server. Only processes that were actually
    started are cleaned up, so a failure during startup surfaces as the
    original exception instead of a ``NameError``/``AttributeError`` raised
    from the cleanup code.
    """
    # Started outside the ``try``: if this fails there is nothing to clean up.
    server_proc, config = perwez.start_server()
    started = []  # trainer first, then workers -- same order as cleanup
    try:
        trainer_proc = mp.Process(target=_trainer, args=(config["url"], ))
        trainer_proc.start()
        started.append(trainer_proc)

        workers = []
        for idx in range(WORKER_CNT):
            proc = mp.Process(target=_worker, args=(idx, config["url"]))
            proc.start()
            started.append(proc)
            workers.append(proc)
        for p in workers:
            p.join()
    finally:
        for p in started:
            p.terminate()
            p.join()
        server_proc.terminate()
        server_proc.join()
Exemple #6
0
def main(perwez_port, rb_ports):
    """Start perwez, one replay buffer per port, and the trainer process.

    Args:
        perwez_port: port passed to ``perwez.start_server``.
        rb_ports: ports for the replay buffers; the configured capacity is
            divided evenly (integer division) between them.

    The function blocks until the trainer process ends. SIGINT and SIGTERM
    trigger the same shutdown routine that runs after the trainer finishes;
    the module-level ``EXITED`` flag makes that routine run only once.
    """
    mp.set_start_method("spawn", force=True)
    with open(path.join(path.dirname(__file__), "config.yaml"), "r") as cfg_file:
        config = yaml.safe_load(cfg_file)
    perwez_proc, perwez_config = perwez.start_server(port=perwez_port)

    rb_procs, rb_addrs = [], []
    for rb_port in rb_ports:
        shard_capacity = config["replay_buffer"]["capacity"] // len(rb_ports)
        buffer_proc, buffer_addr = reth_buffer.start_per(
            capacity=shard_capacity,
            alpha=config["replay_buffer"]["alpha"],
            beta=config["replay_buffer"]["beta"],
            batch_size=config["common"]["batch_size"],
            port=rb_port,
        )
        rb_procs.append(buffer_proc)
        rb_addrs.append(buffer_addr)

    trainer_args = (config, perwez_config["url"], rb_addrs)
    trainer_proc = mp.Process(target=trainer_main, args=trainer_args)

    def graceful_exit(*_):
        # Used both as a signal handler (receives signum/frame) and as a
        # plain call at the end of main, hence the catch-all signature.
        global EXITED
        if EXITED:
            return
        EXITED = True
        print("exiting...")
        for buffer_child in rb_procs:
            buffer_child.terminate()
            buffer_child.join()
        for child in (trainer_proc, perwez_proc):
            child.terminate()
            child.join()

    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, graceful_exit)

    trainer_proc.start()
    trainer_proc.join()
    graceful_exit()
Exemple #7
0
def main():
    """Launch perwez, a prioritized replay buffer, workers and the trainer.

    Configuration is read from ``config.yaml`` next to this file. Workers run
    in subprocesses; the trainer runs in the calling (main) process. All
    started processes are terminated when the trainer returns or raises, and
    also when starting the buffer or a worker fails part-way through.
    """
    mp.set_start_method("spawn", force=True)
    config_path = path.join(path.dirname(__file__), "config.yaml")
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # init perwez
    pwz_proc, pwz_config = perwez.start_server()
    buffer_proc = None
    worker_processes = []
    try:
        # init reth_buffer
        buffer_proc, rb_addr = reth_buffer.start_per(
            num_sampler_procs=1,
            capacity=config["replay_buffer"]["capacity"],
            batch_size=config["common"]["batch_size"],
            alpha=config["replay_buffer"]["alpha"],
            beta=config["replay_buffer"]["beta"],
        )

        # worker subprocesses
        num_workers = config["common"]["num_workers"]
        for idx in range(num_workers):
            p = mp.Process(
                name=f"apex-worker-{idx}",
                target=worker_main,
                args=(config, pwz_config["url"], rb_addr, idx),
            )
            p.start()
            # Track only processes that really started so cleanup is safe.
            worker_processes.append(p)

        # trainer process should be the main process
        trainer_main(config, pwz_config["url"], rb_addr)
    finally:
        print("exiting...")
        for p in worker_processes:
            p.terminate()
            p.join()
        pwz_proc.terminate()
        pwz_proc.join()
        if buffer_proc is not None:
            buffer_proc.terminate()
            buffer_proc.join()
Exemple #8
0
def _weight_producer():
    """Publish random 50 MiB payloads on the "weight" topic forever.

    Connects to the perwez server named "default" and, in an endless loop,
    publishes a freshly generated random byte string, pausing 0.333 s
    between publishes. Never returns; the caller is expected to terminate
    the process running it.
    """
    payload_size = 50 * 1024 * 1024
    pause_seconds = 0.333
    client = perwez.connect("default")

    while True:
        client.publish("weight", secrets.token_bytes(payload_size), ipc=False)
        time.sleep(pause_seconds)


def _exit(proc_list):
    for proc in proc_list:
        proc.terminate()
        proc.join()


if __name__ == "__main__":
    # Script entry point: start a perwez server on port 12333, then one
    # process per background job, and wait for all of them.
    pwz_proc, _ = perwez.start_server(port=12333)
    processes = [pwz_proc]

    for job in (_data_fetcher, _weight_producer):
        child = mp.Process(target=job)
        child.start()
        processes.append(child)

    # Make sure every spawned process is terminated on interpreter exit.
    atexit.register(_exit, processes)

    for running in processes:
        running.join()
Exemple #9
0
def _run_push_round(topic, n_producers, n_consumers, expected,
                    consumers_first):
    """Run one producer/consumer round and always clean up its processes.

    Starts ``n_producers`` ``_producer`` processes and ``n_consumers``
    ``_consumer`` processes (consumers first when ``consumers_first`` is
    true), waits until the shared semaphore has been released ``expected``
    times, then terminates and joins every process that was started.
    The semaphore is handed to the consumers only; presumably each consumer
    releases it once per received item -- TODO confirm against ``_consumer``.
    """
    sem = mp.Semaphore(0)
    producers = [(_producer, (topic, i)) for i in range(n_producers)]
    consumers = [(_consumer, (topic, i, sem)) for i in range(n_consumers)]
    launch_order = (consumers + producers if consumers_first
                    else producers + consumers)

    processes = []
    try:
        for target, args in launch_order:
            t = mp.Process(target=target, args=args)
            t.start()
            processes.append(t)
        for _ in range(expected):
            sem.acquire()
    finally:
        for t in processes:
            t.terminate()
            t.join()


def test_zmq_push_mp():
    """Check push/pull delivery across processes in three configurations.

    Rounds: three consumers started before one producer, one producer
    started before three consumers, and four producers with five consumers.
    Each round waits for a fixed number of semaphore releases (20, 20 and
    4 * 20). The perwez server is shut down even if a round fails.
    """
    mp.set_start_method("spawn", force=True)
    time.sleep(1)
    proc, _ = perwez.start_server("default")
    topic = "asd"
    try:
        # consumer first
        print("consumer first")
        _run_push_round(topic, n_producers=1, n_consumers=3, expected=20,
                        consumers_first=True)

        # producer first
        print("producer first")
        _run_push_round(topic, n_producers=1, n_consumers=3, expected=20,
                        consumers_first=False)

        # N to N
        print("N to N")
        _run_push_round(topic, n_producers=4, n_consumers=5,
                        expected=4 * 20, consumers_first=False)
    finally:
        proc.terminate()
        proc.join()