示例#1
0
文件: test_zmq.py 项目: sosp2021/reth
def test_zmq_push_block():
    """Check that ``push`` applies high-water-mark (HWM) back-pressure.

    Scenario:
      1. With no subscriber, a non-blocking push must raise ``zmq.ZMQError``.
      2. After ``c1`` subscribes, ``DEFAULT_HWM`` blocking pushes succeed.
         Non-blocking pushes then continue until one is refused, and ``c1``
         must receive the payload.
      3. After ``c2`` also subscribes, the non-blocking fill is repeated and
         ``c2`` must receive the payload.

    The clients and the server process are released in ``finally`` so that a
    failing assertion does not leave a running server behind.
    """

    def _push_until_blocked(client, topic, data, limit):
        """Non-blocking push of *data* up to *limit* times; return how many succeeded."""
        for pushed in range(limit):
            try:
                client.push(topic, data, ipc=False, noblock=True)
            except zmq.ZMQError:
                return pushed
        # Every attempt succeeded: report the full count, not limit - 1.
        return limit

    proc, _ = perwez.start_server("default")
    clients = []
    try:
        topic = "asd"
        for _ in range(3):
            # Append one at a time so a failing connect still lets the
            # already-opened clients be closed in ``finally``.
            clients.append(perwez.connect("default"))
        c0, c1, c2 = clients
        # ``size`` is a module-level name defined outside this view.
        data = secrets.token_bytes(size)

        print("EAGAIN when no client")
        try:
            c0.push(topic, data, ipc=False, noblock=True)
        except zmq.ZMQError:
            pass
        else:
            # Explicit raise: a bare ``assert False`` disappears under ``-O``.
            raise AssertionError("non-blocking push without a subscriber should fail")

        print("c1 join")
        w1 = c1.subscribe(topic, False)
        for _ in range(DEFAULT_HWM):
            c0.push(topic, data, ipc=False)
        print(f"{DEFAULT_HWM} items pushed")
        print("hwm block: push more items")
        extra = _push_until_blocked(c0, topic, data, DEFAULT_HWM * 3)
        print(f"{extra}/{DEFAULT_HWM * 3} more items in queue")

        res = w1.get(timeout=3)
        assert res == data

        print("c2 join")
        w2 = c2.subscribe(topic, False)
        # Give the new subscription time to propagate before refilling.
        time.sleep(1)
        print("hwm block: push more items")
        extra = _push_until_blocked(c0, topic, data, DEFAULT_HWM * 3)
        print(f"{extra}/{DEFAULT_HWM * 3} more items in queue")
        res = w2.get(timeout=15)
        assert res == data
        print("finished")
    finally:
        try:
            for client in clients:
                client.close()
        finally:
            # Always stop the server, even if closing a client raised.
            proc.terminate()
            proc.join()
示例#2
0
文件: trainer.py 项目: sosp2021/reth
def _weight_producer():
    """Publish a fresh random payload on the "weight" topic, forever.

    On each iteration, 50 MiB of random bytes are drawn from ``secrets`` and
    published via a perwez client (``ipc=False``). The loop then sleeps for
    0.333 s. The function never returns and never closes its connection.
    """
    payload_bytes = 50 * 1024 * 1024
    period_s = 0.333
    client = perwez.connect("default")
    while 1:
        client.publish("weight", secrets.token_bytes(payload_bytes), ipc=False)
        time.sleep(period_s)
示例#3
0
文件: worker.py 项目: sosp2021/reth
def _worker(worker_idx):
    """Step an environment with random actions and push filled buffers, forever.

    Args:
        worker_idx: identifier included in this worker's progress log line.

    Each step:
      * If a message is pending on the "weight" subscription, it is consumed
        (its content is discarded) and ``w_cnt`` is incremented.
      * A random action is sampled, and the transition
        ``(s0, r, done, s1)`` is appended to a ``NumpyBuffer`` of
        ``WORKER_BATCH_SIZE`` entries. Observations are cast to float32.
      * When the buffer is full, its data is pushed on the "data" topic
        (compressed per ``DATA_COMPRESSION``), the buffer is cleared, and a
        progress line is printed.
      * The environment is reset when an episode ends.
    """
    pwz = perwez.connect("default")
    # Second argument True here, vs False for the "data" topic elsewhere;
    # its meaning is defined by perwez — TODO confirm.
    weight_watcher = pwz.subscribe("weight", True)

    env = reth.env.make(ENV_NAME)
    buffer = reth.buffer.NumpyBuffer(WORKER_BATCH_SIZE, circular=False)

    w_cnt = 0  # weight messages received
    d_cnt = 0  # data batches sent
    s0 = env.reset().astype("f4")

    while True:
        # Call get() only when the watcher reports a pending item, so the
        # env loop does not wait on weight updates.
        if not weight_watcher.empty():
            weight_watcher.get()
            w_cnt += 1
        s1, r, done, _ = env.step(env.action_space.sample())
        s1 = s1.astype("f4")
        buffer.append((s0, r, done, s1))

        if buffer.size == buffer.capacity:
            pwz.push("data",
                     buffer.data,
                     ipc=False,
                     compression=DATA_COMPRESSION)
            d_cnt += 1
            buffer.clear()
            # The former `w_cnt % 1 == 0` guard was always true; log every batch.
            print(
                f"worker{worker_idx}: recv weights {w_cnt}, send data {d_cnt}"
            )

        if done:
            s0 = env.reset().astype("f4")
        else:
            s0 = s1
示例#4
0
文件: test_zmq.py 项目: sosp2021/reth
def _consumer(topic, idx, sem):
    """Print each message received on *topic* and release *sem* once per message.

    Args:
        topic: perwez topic to subscribe to.
        idx: identifier printed before each message (distinguishes consumers).
        sem: object with a ``release()`` method, called once per received
            message (presumably a semaphore the caller waits on to count
            deliveries — confirm against the caller).

    Loops until ``w.get()`` raises or the process is interrupted. In either
    case the connection is closed with ``linger=0``.
    """
    c = perwez.connect("default")
    try:
        w = c.subscribe(topic, False)
        while True:
            res = w.get()
            # res is indexed like a sequence; presumably res[0] is the label
            # part of a multi-part message — verify against the producer.
            print(idx, res[0])
            sem.release()
    finally:
        # Previously placed after the infinite loop, where it could never
        # run; in ``finally`` it is reached on any exception.
        c.close(linger=0)
示例#5
0
文件: trainer.py 项目: sosp2021/reth
def _data_fetcher():
    """Receive items on the "data" topic forever and log their element shapes.

    Every element of each received item must have a ``.shape`` attribute, so
    items are presumably sequences of NumPy arrays (TODO confirm against the
    sender). A running count and the list of shapes are printed for every
    item. The function never returns.
    """
    pwz = perwez.connect("default")
    data_watcher = pwz.subscribe("data", False)
    cnt = 0
    while True:
        latest_res = data_watcher.get()
        shapes = [x.shape for x in latest_res]
        cnt += 1
        # The former `cnt % 1 == 0` guard was always true; log every item.
        print(f"recv data: {cnt}, shape: {shapes}")
示例#6
0
文件: test_zmq.py 项目: sosp2021/reth
def _producer(topic, idx):
    """Push 20 two-part messages to *topic*, then close the client.

    Each message is ``[b"<idx>-<seq>", payload]``. ``payload`` is one random
    buffer of ``size`` bytes, generated once and reused for all 20 messages.
    ``size`` is a module-level name defined outside this view.
    """
    payload = secrets.token_bytes(size)
    client = perwez.connect("default")
    for seq in range(20):
        label = f"{idx}-{seq}".encode()
        client.push(topic, [label, payload], ipc=False)
    # linger=-1 presumably maps to ZMQ_LINGER=-1, i.e. wait indefinitely for
    # pending messages to be delivered — confirm against perwez.
    client.close(linger=-1)