def test_zmq_push_block():
    """Check push behaviour around the high-water mark (HWM).

    Steps:
      1. With no subscriber, a non-blocking push must raise ``zmq.ZMQError``.
      2. After ``c1`` subscribes:
         - ``DEFAULT_HWM`` blocking pushes are sent;
         - up to ``DEFAULT_HWM * 3`` non-blocking pushes are attempted, stopping
           at the first one refused with ``zmq.ZMQError``;
         - ``c1`` must then receive the payload within 3 s.
      3. After ``c2`` subscribes, the non-blocking pushes are repeated and ``c2``
         must receive the payload within 15 s.

    The server process and all clients are released in a ``finally`` block,
    so a failing assertion does not leak the server process or sockets.
    """

    def push_until_refused(client, topic, data, limit):
        """Push up to *limit* items without blocking; return how many were accepted."""
        accepted = 0
        for _ in range(limit):
            try:
                client.push(topic, data, ipc=False, noblock=True)
            except zmq.ZMQError:
                break
            accepted += 1
        return accepted

    proc, _ = perwez.start_server("default")
    clients = []
    try:
        topic = "asd"
        for _ in range(3):
            clients.append(perwez.connect("default"))
        c0, c1, c2 = clients
        data = secrets.token_bytes(size)
        extra = DEFAULT_HWM * 3

        print("EAGAIN when no client")
        try:
            c0.push(topic, data, ipc=False, noblock=True)
        except zmq.ZMQError:
            pass
        else:
            # Explicit raise instead of `assert False`, which `python -O` strips.
            raise AssertionError("non-blocking push without subscribers should fail")

        print("c1 join")
        w1 = c1.subscribe(topic, False)
        for _ in range(DEFAULT_HWM):
            c0.push(topic, data, ipc=False)
        print(f"{DEFAULT_HWM} items pushed")
        print("hwm block: push more items")
        # Count accepted pushes explicitly; the old loop index under-reported by
        # one whenever no push was refused.
        pushed = push_until_refused(c0, topic, data, extra)
        print(f"{pushed}/{extra} more items in queue")
        res = w1.get(timeout=3)
        assert res == data

        print("c2 join")
        w2 = c2.subscribe(topic, False)
        time.sleep(1)  # presumably lets the new subscription settle before pushing
        print("hwm block: push more items")
        pushed = push_until_refused(c0, topic, data, extra)
        print(f"{pushed}/{extra} more items in queue")
        res = w2.get(timeout=15)
        assert res == data
        print("finished")
    finally:
        # Always release sockets and the server, even when an assertion fails.
        for client in clients:
            client.close()
        proc.terminate()
        proc.join()
def _weight_producer():
    """Publish a freshly generated random 50 MiB blob on "weight" in an endless loop.

    Each iteration generates new random bytes, publishes them with ``ipc=False``,
    then sleeps 0.333 s. The function never returns.
    """
    payload_bytes = 50 << 20  # 50 * 1024 * 1024
    publisher = perwez.connect("default")
    while True:
        blob = secrets.token_bytes(payload_bytes)
        publisher.publish("weight", blob, ipc=False)
        time.sleep(0.333)
def _worker(worker_idx):
    """Roll out random actions forever, pushing each full buffer on the "data" topic.

    Steps:
      - Connect to the "default" perwez server and subscribe to "weight".
      - Step a ``reth`` environment (``ENV_NAME``) with randomly sampled actions.
      - Append each transition ``(s0, r, done, s1)`` to a non-circular
        ``NumpyBuffer`` of capacity ``WORKER_BATCH_SIZE``. States are cast with
        ``astype("f4")``.
      - Once the buffer is full, push its ``data`` on "data" (with
        ``DATA_COMPRESSION``) and clear it.

    The function never returns.

    Args:
        worker_idx: identifier used only in the progress log line.
    """
    pwz = perwez.connect("default")
    weight_watcher = pwz.subscribe("weight", True)
    env = reth.env.make(ENV_NAME)
    buffer = reth.buffer.NumpyBuffer(WORKER_BATCH_SIZE, circular=False)
    w_cnt = 0  # weight messages consumed so far
    d_cnt = 0  # data batches pushed so far
    s0 = env.reset().astype("f4")
    while True:
        # Fetch a weight message only when one is reported pending; the payload
        # is discarded and only counted.
        if not weight_watcher.empty():
            weight_watcher.get()
            w_cnt += 1
        s1, r, done, _ = env.step(env.action_space.sample())
        s1 = s1.astype("f4")
        buffer.append((s0, r, done, s1))
        if buffer.size == buffer.capacity:
            pwz.push("data", buffer.data, ipc=False, compression=DATA_COMPRESSION)
            d_cnt += 1
            buffer.clear()
            # NOTE(review): the collapsed original does not make this statement's
            # nesting unambiguous; it is placed in the buffer-flush branch here —
            # confirm against the original file. `w_cnt % 1 == 0` is always true,
            # so this logs on every flush (presumably a tunable log interval).
            if w_cnt % 1 == 0:
                print(
                    f"worker{worker_idx}: recv weights {w_cnt}, send data {d_cnt}"
                )
        # Start a new episode when the current one ends; otherwise advance state.
        if done:
            s0 = env.reset().astype("f4")
        else:
            s0 = s1
def _consumer(topic, idx, sem, max_items=None):
    """Receive messages from *topic*, printing each and releasing *sem* once per message.

    Args:
        topic: topic name to subscribe to on the "default" server.
        idx: identifier printed in front of each message's first element.
        sem: semaphore-like object; ``release()`` is called after each message.
        max_items: stop after this many messages. ``None`` (default) loops
            forever, matching the previous behaviour.

    The client is closed with ``linger=0`` in a ``finally`` block. Previously
    the close sat after an infinite loop and was unreachable.
    """
    c = perwez.connect("default")
    try:
        w = c.subscribe(topic, False)
        received = 0
        while max_items is None or received < max_items:
            res = w.get()
            print(idx, res[0])
            sem.release()
            received += 1
    finally:
        c.close(linger=0)
def _data_fetcher():
    """Consume batches from the "data" topic forever, logging a running count.

    Each log line also shows the ``.shape`` of every element in the received
    batch. The function never returns.
    """
    client = perwez.connect("default")
    watcher = client.subscribe("data", False)
    log_every = 1  # log interval in batches; 1 means every batch is logged
    cnt = 0
    while True:
        batch = watcher.get()
        shapes = [item.shape for item in batch]
        cnt += 1
        if cnt % log_every == 0:
            print(f"recv data: {cnt}, shape: {shapes}")
def _producer(topic, idx, n_items=20):
    """Push *n_items* two-part messages ``[b"<idx>-<i>", data]`` on *topic*.

    One random payload of ``size`` bytes is generated up front and reused for
    every message.

    Args:
        topic: topic to push to on the "default" server.
        idx: producer identifier embedded in each message's first part.
        n_items: number of messages to push (defaults to the previous
            hard-coded 20).

    The client is closed with ``linger=-1`` in a ``finally`` block, so it is
    released even if a push raises.
    """
    data = secrets.token_bytes(size)
    c = perwez.connect("default")
    try:
        for i in range(n_items):
            c.push(topic, [f"{idx}-{i}".encode(), data], ipc=False)
    finally:
        c.close(linger=-1)