def _worker(idx, server_url):
    """Benchmark actor: run a random-action env loop, shipping transition
    batches to the trainer over perwez until 10 weight broadcasts arrive.

    Args:
        idx: worker index, used only for log output.
        server_url: perwez broker URL for the "weight"/"data" channels.
    """
    recv_sock = perwez.RecvSocket(server_url, "weight", broadcast=True)
    send_sock = perwez.SendSocket(server_url, "data", broadcast=False)
    env = reth.env.make(ENV_NAME)
    batch = reth.buffer.NumpyBuffer(WORKER_BATCH_SIZE, circular=False)
    n_weights = 0
    n_batches = 0
    obs = env.reset().astype("f4")
    while n_weights < 10:
        # Drain at most one pending weight broadcast per env step.
        if not recv_sock.empty():
            recv_sock.recv()
            n_weights += 1
        next_obs, reward, done, _ = env.step(env.action_space.sample())
        next_obs = next_obs.astype("f4")
        batch.append((obs, reward, done, next_obs))
        # Ship a full batch upstream, then start refilling from empty.
        if batch.size == batch.capacity:
            send_sock.send(batch.data)
            n_batches += 1
            batch.clear()
            print(f"worker{idx}: recv weights {n_weights}, send data {n_batches}")
        obs = env.reset().astype("f4") if done else next_obs
def weights_proxy(perwez_url):
    """Relay every payload from the "weights" broadcast channel onto the
    "local-weights" broadcast channel, forever (blocking on each recv)."""
    upstream = perwez.RecvSocket(perwez_url, "weights", broadcast=True)
    downstream = perwez.SendSocket(perwez_url, "local-weights", broadcast=True)
    while True:
        downstream.send(upstream.recv())
def _trainer(server_url):
    """Benchmark trainer: consume data batches and answer each one with a
    dummy 50 MiB weight broadcast, throttled to roughly 3 sends per second."""
    data_recv = perwez.RecvSocket(server_url, "data", broadcast=False)
    weight_send = perwez.SendSocket(server_url, "weight", broadcast=True)
    n_recv = 0
    while True:
        payload = data_recv.recv()
        n_recv += 1
        # Log every fifth batch to keep output readable.
        if n_recv % 5 == 0:
            shapes = [arr.shape for arr in payload]
            print(f"recv data: {n_recv}, shape: {shapes}")
        # Random bytes stand in for a real serialized model checkpoint.
        weight_send.send(secrets.token_bytes(50 * 1024 * 1024))
        time.sleep(0.333)
def worker_main(config, perwez_url, rb_addr, idx):
    """Actor loop: collect transition batches, compute per-sample losses as
    priorities, and push them to the reth_buffer replay server.

    Args:
        config: nested config dict; reads ``common.batch_size``,
            ``common.num_workers`` and ``common.recv_weights_interval``.
        perwez_url: perwez broker URL for the "weight" broadcast channel.
        rb_addr: address of the reth_buffer replay-buffer server.
        idx: zero-based worker index (drives exploration schedule and logging).
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True)
    batch_size = config["common"]["batch_size"]
    num_workers = config["common"]["num_workers"]
    # Ape-X style per-worker exploration: eps = 0.4 ** (1 + 7 * i / (N - 1)).
    # BUGFIX: guard N == 1, which previously raised ZeroDivisionError.
    rank = idx / (num_workers - 1) if num_workers > 1 else 0.0
    eps = 0.4 ** (1 + rank * 7)
    solver = get_solver(config, device="cpu")
    log_flag = idx >= num_workers + (-num_workers // 3)  # aligned with ray
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}") if log_flag else None,
    )
    recv_weights_interval = config["common"]["recv_weights_interval"]
    step = 0
    prev_recv = 0
    try:
        while True:
            step += 1
            # load weights, at most once per recv_weights_interval env frames
            interval_flag = (step - prev_recv) * batch_size >= recv_weights_interval
            if interval_flag and not weight_recv.empty():
                worker.load_weights(io.BytesIO(weight_recv.recv()))
                # BUGFIX: record when we last loaded; without this the
                # throttle became permanently true after the first interval
                # and weights were reloaded on every subsequent step.
                prev_recv = step
            # step
            data = worker.step_batch(batch_size)
            loss = worker.solver.calc_loss(data)
            # format
            s0, a, r, s1, done = data
            s0 = np.asarray(s0, dtype="f4")
            a = np.asarray(a, dtype="i8")
            r = np.asarray(r, dtype="f4")
            s1 = np.asarray(s1, dtype="f4")
            done = np.asarray(done, dtype="f4")
            loss = np.asarray(loss, dtype="f4")
            # upload (loss acts as the per-sample priority)
            rb_client.append([s0, a, r, s1, done], loss)
    except KeyboardInterrupt:
        # deliberate: Ctrl-C is the normal shutdown path for actor processes
        pass
def worker_main(config, idx, size, perwez_url, rb_addr):
    """Actor loop with n-step returns: roll out single env steps, fold them
    through an NStepAdder, and upload full batches (with loss priorities,
    compressed) to the reth_buffer replay server.

    Args:
        config: nested config dict; reads ``common.rollout_batch_size``,
            ``common.recv_weights_interval``, ``solver.gamma``, ``solver.n_step``.
        idx: zero-based worker index (drives the exploration schedule).
        size: total number of workers.
        perwez_url: perwez broker URL for the "local-weights" channel.
        rb_addr: address of the reth_buffer replay-buffer server.
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url, "local-weights", broadcast=True)
    batch_size = config["common"]["rollout_batch_size"]
    # Ape-X style per-worker exploration: eps = 0.4 ** (1 + 7 * i / (N - 1)).
    # BUGFIX: guard size == 1, which previously raised ZeroDivisionError.
    rank = idx / (size - 1) if size > 1 else 0.0
    eps = 0.4 ** (1 + rank * 7)
    solver = get_solver(config, device="cpu")
    worker = get_worker(
        config, exploration=eps, solver=solver, logger=getLogger(f"worker{idx}")
    )
    recv_weights_interval = config["common"]["recv_weights_interval"]
    prev_load = 0
    adder = NStepAdder(config["solver"]["gamma"], config["solver"]["n_step"])
    buffer = NumpyBuffer(batch_size, circular=False)
    while True:
        # load weights, at most once per recv_weights_interval env steps
        if (worker.cur_step - prev_load) > recv_weights_interval and not weight_recv.empty():
            worker.load_weights(io.BytesIO(weight_recv.recv()))
            prev_load = worker.cur_step
        # step
        s0, a, r, s1, done = worker.step()
        s0 = np.asarray(s0, dtype="f4")
        a = np.asarray(a, dtype="i8")
        r = np.asarray(r, dtype="f4")
        s1 = np.asarray(s1, dtype="f4")
        done = np.asarray(done, dtype="f4")
        # adder: returns None until an n-step transition is ready
        row = adder.push(s0, a, r, s1, done)
        if row is None:
            continue
        buffer.append(row)
        if buffer.size == buffer.capacity:
            loss = worker.solver.calc_loss(buffer.data)
            loss = np.asarray(loss, dtype="f4")
            rb_client.append(buffer.data, loss, compress=True)
            buffer.clear()
def worker_main(perwez_url, config, idx):
    """Actor loop backed by Reverb: collect transition batches, compute
    per-sample losses, and insert each timestep into the Reverb table with
    its loss as the priority.

    Args:
        perwez_url: perwez broker URL for the "weight" broadcast channel.
        config: nested config dict; reads ``common.batch_size`` and
            ``common.num_workers``.
        idx: zero-based worker index (drives exploration schedule and logging).
    """
    reverb_client = reverb.Client(f"localhost:{PORT}")
    reverb_writer = reverb_client.writer(1)
    weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True)
    batch_size = config["common"]["batch_size"]
    num_workers = config["common"]["num_workers"]
    # Ape-X style per-worker exploration: eps = 0.4 ** (1 + 7 * i / (N - 1)).
    # BUGFIX: guard N == 1, which previously raised ZeroDivisionError.
    rank = idx / (num_workers - 1) if num_workers > 1 else 0.0
    eps = 0.4 ** (1 + rank * 7)
    solver = get_solver(config, device="cpu")
    log_flag = idx >= num_workers + (-num_workers // 3)  # aligned with ray
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}") if log_flag else None,
    )
    while True:
        # load weights whenever a broadcast is pending
        if not weight_recv.empty():
            worker.load_weights(io.BytesIO(weight_recv.recv()))
        # step
        data = worker.step_batch(batch_size)
        loss = worker.solver.calc_loss(data)
        # format
        s0, a, r, s1, done = data
        s0 = np.asarray(s0, dtype="f4")
        a = np.asarray(a, dtype="i8")
        r = np.asarray(r, dtype="f4")
        s1 = np.asarray(s1, dtype="f4")
        done = np.asarray(done, dtype="f4")
        loss = np.asarray(loss, dtype="f4")
        # upload one item per timestep, prioritized by its loss
        for i, _ in enumerate(s0):
            reverb_writer.append([s0[i], a[i], r[i], s1[i], done[i]])
            reverb_writer.create_item(
                table=TABLE_NAME, num_timesteps=1, priority=loss[i]
            )