def test_basic():
    """Exercise append, sample and update_priorities against a live server.

    A buffer server is started with ``CAPACITY`` and ``BATCH_SIZE``, 1000 rows
    are appended in chunks of 100, ten batches are sampled and checked against
    the source data, and finally all 1000 priorities are overwritten. The
    server process is always terminated and joined, even if a check fails.
    """
    total_rows = 1000
    chunk_rows = 100

    rb_proc, addr = reth_buffer.start_server(CAPACITY, BATCH_SIZE)
    try:
        client = reth_buffer.Client(addr)
        loader = reth_buffer.NumpyLoader(addr)

        # 1. append
        # The third column is simply the row number, which lets the sampling
        # check below map every sampled row back to its source row.
        columns = [
            np.random.rand(total_rows, 4, 84),
            np.random.rand(total_rows),
            np.arange(total_rows),
        ]
        priorities = np.random.rand(total_rows) + 1
        for lo in range(0, total_rows, chunk_rows):
            hi = lo + chunk_rows
            chunk = [column[lo:hi] for column in columns]
            client.append(chunk, priorities[lo:hi])

        # 2. sample
        for _ in range(10):
            sampled, indices, _weights = loader.sample()
            for pos, src in enumerate(indices):
                # Returned index must agree with the stored row number ...
                assert sampled[2][pos] == src
                # ... and the scalar column must match the original row.
                assert sampled[1][pos] == columns[1][src]

        # 3. update priorities
        priorities = np.random.rand(total_rows) + 10
        client.update_priorities(np.arange(total_rows), priorities)
    finally:
        rb_proc.terminate()
        rb_proc.join()
def _trainer(addr):
    """Act as a dummy trainer against the buffer server at ``addr``.

    Runs ten rounds: print the round number, sample a batch, assert that it
    has five fields, sleep 0.1 s, then write random priorities back for the
    sampled indices.
    """
    client = reth_buffer.Client(addr)
    loader = reth_buffer.NumpyLoader(addr)

    for round_no in range(10):
        print(round_no, flush=True)
        data, indices, _weights = loader.sample()
        # Five fields per sample is what this test expects from the buffer.
        assert len(data) == 5
        time.sleep(0.1)
        new_priorities = np.random.rand(len(indices))
        client.update_priorities(indices, new_priorities)
def _worker(addr):
    """Act as a dummy rollout worker feeding the buffer server at ``addr``.

    Runs ten rounds: print progress, build one random batch of ``BATCH_SIZE``
    rows with fields ``[s0, a, r, s1, done]``, append it with random
    priorities in ``[1, 2)``, then sleep 0.1 s.
    """
    client = reth_buffer.Client(addr)
    frame_shape = (BATCH_SIZE, 4, 84, 84)

    for round_no in range(10):
        print("worker", round_no, flush=True)
        # Draw order (s0, s1, a, r, done, priorities) is kept stable so a
        # seeded RNG yields the same stream as before.
        s0 = np.random.rand(*frame_shape)
        s1 = np.random.rand(*frame_shape)
        a = np.random.randint(0, 8, BATCH_SIZE)
        r = np.random.rand(BATCH_SIZE)
        done = np.random.rand(BATCH_SIZE)
        priorities = np.random.rand(BATCH_SIZE) + 1

        rows = [s0, a, r, s1, done]
        client.append(rows, priorities)
        time.sleep(0.1)
def worker_main(config, perwez_url, rb_addr, idx):
    """Run one rollout worker that feeds batches into the replay buffer.

    The worker loops forever: it optionally loads fresh weights from the
    ``"weight"`` broadcast topic, steps the environment for ``batch_size``
    transitions, computes the per-sample loss and appends the batch to the
    replay buffer using that loss as the priority. ``KeyboardInterrupt``
    ends the loop silently.

    Args:
        config: Nested mapping; reads ``config["common"]["batch_size"]``,
            ``["num_workers"]`` and ``["recv_weights_interval"]`` and is
            forwarded to ``get_solver`` / ``get_worker``.
        perwez_url: Address passed to ``perwez.RecvSocket``.
        rb_addr: Address passed to ``reth_buffer.Client``.
        idx: Index of this worker; used for its exploration rate, its logger
            name and to decide whether it logs at all.
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True)

    batch_size = config["common"]["batch_size"]
    num_workers = config["common"]["num_workers"]

    # Per-worker exploration rate: 0.4 ** (1 + 7 * idx / (num_workers - 1)).
    # With a single worker the fraction would be a division by zero, so fall
    # back to the base exponent (eps == 0.4).
    ratio = idx / (num_workers - 1) if num_workers > 1 else 0.0
    eps = 0.4**(1 + ratio * 7)

    solver = get_solver(config, device="cpu")
    log_flag = idx >= num_workers + (-num_workers // 3)  # aligned with ray
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}") if log_flag else None,
    )
    recv_weights_interval = config["common"]["recv_weights_interval"]

    step = 0
    prev_recv = 0
    try:
        while True:
            step += 1

            # load weights
            interval_flag = (step - prev_recv) * batch_size >= recv_weights_interval
            if interval_flag and not weight_recv.empty():
                worker.load_weights(io.BytesIO(weight_recv.recv()))
                # Remember when we last reloaded; without this the interval
                # check stays true forever and weights are reloaded on every
                # step that has a message queued.
                prev_recv = step

            # step
            data = worker.step_batch(batch_size)
            loss = worker.solver.calc_loss(data)

            # format
            s0, a, r, s1, done = data
            s0 = np.asarray(s0, dtype="f4")
            a = np.asarray(a, dtype="i8")
            r = np.asarray(r, dtype="f4")
            s1 = np.asarray(s1, dtype="f4")
            done = np.asarray(done, dtype="f4")
            loss = np.asarray(loss, dtype="f4")

            # upload (the loss doubles as the priority for each row)
            rb_client.append([s0, a, r, s1, done], loss)
    except KeyboardInterrupt:
        # silent exit
        pass
def worker_main(config, idx, size, perwez_url, rb_addr):
    """Run one n-step rollout worker that feeds the replay buffer.

    The worker loops forever: it optionally loads fresh weights from the
    ``"local-weights"`` broadcast topic, takes a single environment step,
    pushes the transition through an ``NStepAdder`` and collects the rows it
    emits in a local ``NumpyBuffer``. Whenever that local buffer is full, the
    per-row loss is computed and the rows are appended (compressed) to the
    replay buffer with the loss as priority, after which the local buffer is
    cleared.

    Args:
        config: Nested mapping; reads ``config["common"]["rollout_batch_size"]``
            and ``["recv_weights_interval"]`` plus ``config["solver"]["gamma"]``
            and ``["n_step"]``; also forwarded to ``get_solver`` / ``get_worker``.
        idx: Index of this worker; used for its exploration rate and logger name.
        size: Total number of workers; used to spread exploration rates.
        perwez_url: Address passed to ``perwez.RecvSocket``.
        rb_addr: Address passed to ``reth_buffer.Client``.
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url, "local-weights", broadcast=True)

    batch_size = config["common"]["rollout_batch_size"]

    # Per-worker exploration rate: 0.4 ** (1 + 7 * idx / (size - 1)).
    # With a single worker the fraction would be a division by zero, so fall
    # back to the base exponent (eps == 0.4).
    ratio = idx / (size - 1) if size > 1 else 0.0
    eps = 0.4**(1 + ratio * 7)

    solver = get_solver(config, device="cpu")
    worker = get_worker(config,
                        exploration=eps,
                        solver=solver,
                        logger=getLogger(f"worker{idx}"))
    recv_weights_interval = config["common"]["recv_weights_interval"]
    prev_load = 0

    adder = NStepAdder(config["solver"]["gamma"], config["solver"]["n_step"])
    buffer = NumpyBuffer(batch_size, circular=False)

    while True:
        # load weights
        if (worker.cur_step -
                prev_load) > recv_weights_interval and not weight_recv.empty():
            worker.load_weights(io.BytesIO(weight_recv.recv()))
            prev_load = worker.cur_step

        # step
        s0, a, r, s1, done = worker.step()
        s0 = np.asarray(s0, dtype="f4")
        a = np.asarray(a, dtype="i8")
        r = np.asarray(r, dtype="f4")
        s1 = np.asarray(s1, dtype="f4")
        done = np.asarray(done, dtype="f4")

        # adder: push() returns None when it has no row to emit yet
        row = adder.push(s0, a, r, s1, done)
        if row is None:
            continue
        buffer.append(row)

        # Flush a full local batch to the replay buffer.
        if buffer.size == buffer.capacity:
            loss = worker.solver.calc_loss(buffer.data)
            loss = np.asarray(loss, dtype="f4")
            rb_client.append(buffer.data, loss, compress=True)
            buffer.clear()
def test_reth_buffer(data):
    """Time ``TEST_CNT`` samples from a reth_buffer server filled with ``data``.

    Starts a server with ``CAPACITY`` and ``BATCH_SIZE``, appends ``data``
    with a priority of 1 for each of ``CAPACITY`` rows, performs one warm-up
    sample, then times ``TEST_CNT`` calls to ``loader.sample()`` and prints
    the count and elapsed time. The server process is terminated and joined
    even when appending or sampling raises, so a failed run does not leave it
    behind.

    Args:
        data: Batch passed unchanged to ``Client.append``; its priorities
            array has ``CAPACITY`` entries, so it presumably holds
            ``CAPACITY`` rows — confirm against the caller.
    """
    import reth_buffer

    print("TEST RETH_BUFFER")
    print("initializing...")
    buffer_process, addr = reth_buffer.start_server(CAPACITY, BATCH_SIZE)
    try:
        client = reth_buffer.Client(addr)
        client.append(data, np.ones(CAPACITY))
        # NOTE: reth_buffer.TorchCudaLoader(addr) can be swapped in here to
        # benchmark the CUDA loader instead.
        loader = reth_buffer.NumpyLoader(addr)
        # init: one untimed sample so the timed loop excludes first-call setup
        loader.sample()
        print("ready")

        t0 = time.perf_counter()
        for _ in range(TEST_CNT):
            loader.sample()
        t1 = time.perf_counter()
        print(TEST_CNT, t1 - t0)
    finally:
        # Always stop the server, including on failure paths.
        buffer_process.terminate()
        buffer_process.join()
def trainer_main(config, perwez_url, rb_addr):
    """Train from a single replay buffer and periodically broadcast weights.

    For every batch yielded by the loader: run one trainer step, write the
    resulting loss back as new priorities for the sampled indices, and every
    ``send_weights_interval`` batches publish the trainer's weights on the
    ``"weight"`` topic. The loop ends once ``trainer.cur_time`` exceeds
    ``3600 * 10`` (presumably seconds, i.e. ten hours — confirm) or when the
    loader is exhausted; ``KeyboardInterrupt`` ends it silently.

    Args:
        config: Nested mapping; reads
            ``config["common"]["send_weights_interval"]`` and is forwarded to
            ``get_trainer``.
        perwez_url: Address passed to ``perwez.SendSocket``.
        rb_addr: Address of the replay buffer server.
    """
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    rb_client = reth_buffer.Client(rb_addr)
    loader = TorchCudaLoader(rb_addr)
    trainer = get_trainer(config)
    send_weights_interval = config["common"]["send_weights_interval"]
    time_budget = 3600 * 10

    try:
        # Batches are counted from 1 so the first broadcast happens after
        # exactly `send_weights_interval` batches.
        for ts, (data, indices, weights) in enumerate(loader, start=1):
            if trainer.cur_time > time_budget:
                break

            loss = trainer.step(data, weights=weights)
            rb_client.update_priorities(np.asarray(indices), np.asarray(loss))

            broadcast_due = ts % send_weights_interval == 0
            if broadcast_due:
                weight_send.send(trainer.save_weights().getbuffer())
    except KeyboardInterrupt:
        # silent exit
        pass
def trainer_main(config, perwez_url, rb_addrs):
    """Train from several replay buffers in turn and broadcast weights.

    One client and one loader are created per address in ``rb_addrs``. Each
    iteration picks the buffer at ``ts % len(rb_addrs)``, samples a batch
    from it, runs one trainer step and writes the loss back to that same
    buffer as new priorities. Every ``send_weights_interval`` iterations the
    weights are serialised into an in-memory stream and published on the
    ``"weights"`` topic. The loop ends once ``trainer.cur_time`` exceeds
    ``3600 * 40`` (presumably seconds, i.e. forty hours — confirm).

    Args:
        config: Nested mapping; reads
            ``config["common"]["send_weights_interval"]`` and is forwarded to
            ``get_trainer``.
        perwez_url: Address passed to ``perwez.SendSocket``.
        rb_addrs: Sequence of replay buffer server addresses.
    """
    weights_send = perwez.SendSocket(perwez_url, "weights", broadcast=True)

    # All clients are created before any loader, matching the original order.
    rb_clients = [reth_buffer.Client(addr) for addr in rb_addrs]
    rb_loaders = [
        reth_buffer.TorchCudaLoader(addr, buffer_size=4, num_procs=2)
        for addr in rb_addrs
    ]

    trainer = get_trainer(config)
    send_weights_interval = config["common"]["send_weights_interval"]
    time_budget = 3600 * 40

    ts = 0
    while not trainer.cur_time > time_budget:
        ts += 1

        # Round-robin over the buffers; client and loader share the index.
        which = ts % len(rb_clients)
        loader = rb_loaders[which]
        client = rb_clients[which]

        data, indices, weights = loader.sample()
        loss = trainer.step(data, weights=weights)
        client.update_priorities(np.asarray(indices), np.asarray(loss))

        if ts % send_weights_interval != 0:
            continue
        stream = io.BytesIO()
        trainer.save_weights(stream)
        weights_send.send(stream.getbuffer())