예제 #1
0
def _worker(idx, server_url):
    """Roll out random actions and ship full batches until 10 weights arrive.

    Connects a broadcast receive socket on topic ``"weight"`` and a
    non-broadcast send socket on topic ``"data"``, then steps an environment
    built from ``ENV_NAME`` with randomly sampled actions. Transitions are
    stored as ``(s0, r, done, s1)`` tuples (the action is not stored); each
    time the buffer reaches capacity its contents are sent and it is cleared.

    Args:
        idx: Worker index, used only in the progress message.
        server_url: Address handed to both perwez sockets.
    """
    weight_recv = perwez.RecvSocket(server_url, "weight", broadcast=True)
    data_send = perwez.SendSocket(server_url, "data", broadcast=False)

    env = reth.env.make(ENV_NAME)
    buffer = reth.buffer.NumpyBuffer(WORKER_BATCH_SIZE, circular=False)

    w_cnt, d_cnt = 0, 0
    obs = env.reset().astype("f4")
    # The loop ends once ten weight messages have been consumed.
    while w_cnt < 10:
        # Consume at most one pending weight message per environment step.
        weights_pending = not weight_recv.empty()
        if weights_pending:
            weight_recv.recv()
            w_cnt += 1

        next_obs, reward, done, _ = env.step(env.action_space.sample())
        next_obs = next_obs.astype("f4")
        buffer.append((obs, reward, done, next_obs))

        buffer_full = buffer.size == buffer.capacity
        if buffer_full:
            data_send.send(buffer.data)
            d_cnt += 1
            buffer.clear()
            print(f"worker{idx}: recv weights {w_cnt}, send data {d_cnt}")

        # Start a new episode when the current one finished.
        obs = env.reset().astype("f4") if done else next_obs
예제 #2
0
파일: worker.py 프로젝트: sosp2021/reth
def weights_proxy(perwez_url):
    """Forward every message from topic ``"weights"`` to ``"local-weights"``.

    Both sockets are opened in broadcast mode against ``perwez_url``. The
    function loops forever and never returns.

    Args:
        perwez_url: Address handed to both perwez sockets.
    """
    upstream = perwez.RecvSocket(perwez_url, "weights", broadcast=True)
    downstream = perwez.SendSocket(
        perwez_url, "local-weights", broadcast=True
    )
    while True:
        # Relay each received payload unchanged.
        downstream.send(upstream.recv())
예제 #3
0
def _trainer(server_url):
    """Consume data batches and publish a dummy weight blob every 5th batch.

    Receives from the non-broadcast ``"data"`` topic and sends on the
    broadcast ``"weight"`` topic. The "weights" are 50 MiB of random bytes,
    so this exercises transport only. Loops forever.

    Args:
        server_url: Address handed to both perwez sockets.
    """
    data_recv = perwez.RecvSocket(server_url, "data", broadcast=False)
    weight_send = perwez.SendSocket(server_url, "weight", broadcast=True)

    cnt = 0
    while True:
        batch = data_recv.recv()
        shapes = [item.shape for item in batch]
        cnt += 1
        if cnt % 5 != 0:
            continue

        # Every fifth batch: report progress and broadcast a fresh payload.
        print(f"recv data: {cnt}, shape: {shapes}")
        weight_send.send(secrets.token_bytes(50 * 1024 * 1024))
        time.sleep(0.333)
예제 #4
0
파일: run.py 프로젝트: sosp2021/reth
def worker_main(config, perwez_url, rb_addr, idx):
    """Run one rollout worker that feeds batches into the replay buffer.

    Each iteration optionally loads the newest weights from the broadcast
    ``"weight"`` topic, generates ``batch_size`` transitions with
    ``worker.step_batch``, computes their loss with the worker's solver and
    appends the transitions together with the loss to the replay buffer.
    The loop runs until interrupted with ``KeyboardInterrupt``, which is
    swallowed so the process exits quietly.

    Args:
        config: Nested mapping; reads ``config["common"]["batch_size"]``,
            ``["num_workers"]`` and ``["recv_weights_interval"]``, and is
            also passed to ``get_solver`` and ``get_worker``.
        perwez_url: Address of the perwez server for the weight socket.
        rb_addr: Address handed to ``reth_buffer.Client``.
        idx: Index of this worker; selects its exploration rate and whether
            it gets a logger.
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True)

    batch_size = config["common"]["batch_size"]
    num_workers = config["common"]["num_workers"]
    # Per-worker exploration rate 0.4 ** (1 + 7 * idx / (N - 1)). The
    # denominator is clamped so a single-worker run yields 0.4 instead of
    # raising ZeroDivisionError.
    eps = 0.4**(1 + (idx / max(num_workers - 1, 1)) * 7)
    solver = get_solver(config, device="cpu")
    # Only the highest ceil(num_workers / 3) indices log (aligned with ray).
    log_flag = idx >= num_workers + (-num_workers // 3)
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}") if log_flag else None,
    )
    recv_weights_interval = config["common"]["recv_weights_interval"]

    step = 0
    prev_recv = 0
    try:
        while True:
            step += 1

            # load weights: at most once per `recv_weights_interval`
            # generated transitions, and only if a message is waiting.
            interval_flag = (step -
                             prev_recv) * batch_size >= recv_weights_interval
            if interval_flag and not weight_recv.empty():
                worker.load_weights(io.BytesIO(weight_recv.recv()))
                # Restart the interval; without this the throttle only
                # applied before the first load.
                prev_recv = step

            # step
            data = worker.step_batch(batch_size)
            loss = worker.solver.calc_loss(data)
            # format
            s0, a, r, s1, done = data
            s0 = np.asarray(s0, dtype="f4")
            a = np.asarray(a, dtype="i8")
            r = np.asarray(r, dtype="f4")
            s1 = np.asarray(s1, dtype="f4")
            done = np.asarray(done, dtype="f4")
            loss = np.asarray(loss, dtype="f4")
            # upload (loss is presumably used as the priority -- confirm
            # against reth_buffer.Client.append)
            rb_client.append([s0, a, r, s1, done], loss)
    except KeyboardInterrupt:
        # silent exit
        pass
예제 #5
0
파일: worker.py 프로젝트: sosp2021/reth
def worker_main(config, idx, size, perwez_url, rb_addr):
    """Run one rollout worker that uploads n-step batches to the replay buffer.

    The worker steps its environment one transition at a time, pushes each
    transition through an ``NStepAdder`` and collects the rows it emits in a
    fixed-size buffer. When the buffer is full, the loss of the buffered
    rows is computed and rows plus loss are appended (compressed) to the
    replay buffer. Weights are refreshed from the broadcast
    ``"local-weights"`` topic. Loops forever.

    Args:
        config: Nested mapping; reads
            ``config["common"]["rollout_batch_size"]``,
            ``config["common"]["recv_weights_interval"]``,
            ``config["solver"]["gamma"]`` and ``config["solver"]["n_step"]``,
            and is also passed to ``get_solver`` and ``get_worker``.
        idx: Index of this worker; selects its exploration rate and names
            its logger.
        size: Total number of workers.
        perwez_url: Address of the perwez server for the weight socket.
        rb_addr: Address handed to ``reth_buffer.Client``.
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url,
                                    "local-weights",
                                    broadcast=True)

    batch_size = config["common"]["rollout_batch_size"]
    # Per-worker exploration rate 0.4 ** (1 + 7 * idx / (size - 1)). The
    # denominator is clamped so a single-worker run yields 0.4 instead of
    # raising ZeroDivisionError.
    eps = 0.4**(1 + (idx / max(size - 1, 1)) * 7)
    solver = get_solver(config, device="cpu")
    worker = get_worker(config,
                        exploration=eps,
                        solver=solver,
                        logger=getLogger(f"worker{idx}"))

    recv_weights_interval = config["common"]["recv_weights_interval"]
    prev_load = 0

    adder = NStepAdder(config["solver"]["gamma"], config["solver"]["n_step"])
    buffer = NumpyBuffer(batch_size, circular=False)
    while True:
        # load weights: only after more than `recv_weights_interval` worker
        # steps since the last load, and only if a message is waiting.
        interval_elapsed = (worker.cur_step - prev_load) > recv_weights_interval
        if interval_elapsed and not weight_recv.empty():
            worker.load_weights(io.BytesIO(weight_recv.recv()))
            prev_load = worker.cur_step

        # step
        s0, a, r, s1, done = worker.step()
        s0 = np.asarray(s0, dtype="f4")
        a = np.asarray(a, dtype="i8")
        r = np.asarray(r, dtype="f4")
        s1 = np.asarray(s1, dtype="f4")
        done = np.asarray(done, dtype="f4")
        # adder: push() returns None when it has no row to emit yet.
        row = adder.push(s0, a, r, s1, done)
        if row is None:
            continue
        buffer.append(row)
        if buffer.size == buffer.capacity:
            loss = worker.solver.calc_loss(buffer.data)
            loss = np.asarray(loss, dtype="f4")
            rb_client.append(buffer.data, loss, compress=True)
            buffer.clear()
예제 #6
0
def worker_main(perwez_url, config, idx):
    """Run one rollout worker that writes transitions into a Reverb table.

    Each iteration loads new weights from the broadcast ``"weight"`` topic
    if any are waiting, generates ``batch_size`` transitions with
    ``worker.step_batch``, computes their loss with the worker's solver and
    writes every transition to ``TABLE_NAME`` as a single-timestep item
    whose priority is that transition's loss. Loops forever.

    Args:
        perwez_url: Address of the perwez server for the weight socket.
        config: Nested mapping; reads ``config["common"]["batch_size"]`` and
            ``config["common"]["num_workers"]``, and is also passed to
            ``get_solver`` and ``get_worker``.
        idx: Index of this worker; selects its exploration rate and whether
            it gets a logger.
    """
    reverb_client = reverb.Client(f"localhost:{PORT}")
    reverb_writer = reverb_client.writer(1)
    weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True)

    batch_size = config["common"]["batch_size"]
    num_workers = config["common"]["num_workers"]
    # Per-worker exploration rate 0.4 ** (1 + 7 * idx / (N - 1)). The
    # denominator is clamped so a single-worker run yields 0.4 instead of
    # raising ZeroDivisionError.
    eps = 0.4 ** (1 + (idx / max(num_workers - 1, 1)) * 7)
    solver = get_solver(config, device="cpu")
    # Only the highest ceil(num_workers / 3) indices log (aligned with ray).
    log_flag = idx >= num_workers + (-num_workers // 3)
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}") if log_flag else None,
    )

    while True:
        # load weights
        if not weight_recv.empty():
            worker.load_weights(io.BytesIO(weight_recv.recv()))

        # step
        data = worker.step_batch(batch_size)
        loss = worker.solver.calc_loss(data)
        # format
        s0, a, r, s1, done = data
        s0 = np.asarray(s0, dtype="f4")
        a = np.asarray(a, dtype="i8")
        r = np.asarray(r, dtype="f4")
        s1 = np.asarray(s1, dtype="f4")
        done = np.asarray(done, dtype="f4")
        loss = np.asarray(loss, dtype="f4")
        # upload: one Reverb item per transition, prioritised by its loss.
        for s0_i, a_i, r_i, s1_i, done_i, loss_i in zip(
            s0, a, r, s1, done, loss
        ):
            reverb_writer.append([s0_i, a_i, r_i, s1_i, done_i])
            reverb_writer.create_item(
                table=TABLE_NAME, num_timesteps=1, priority=loss_i
            )