예제 #1
0
파일: run.py 프로젝트: sosp2021/reth
def worker_main(config, perwez_url, rb_addr, idx):
    """Run a rollout worker that collects batches and uploads them to the buffer.

    Each iteration optionally loads fresh weights from the ``"weight"`` perwez
    topic, collects one batch with ``worker.step_batch``, computes the solver
    loss for that batch and appends ``[s0, a, r, s1, done]`` together with the
    loss to the ``reth_buffer`` client.

    Args:
        config: Nested mapping. ``config["common"]`` must provide
            ``batch_size``, ``num_workers`` and ``recv_weights_interval``;
            the whole object is forwarded to ``get_solver`` / ``get_worker``.
        perwez_url: URL handed to ``perwez.RecvSocket``.
        rb_addr: Address handed to ``reth_buffer.Client``.
        idx: Index of this worker; drives the exploration rate and whether
            a logger is attached.

    The loop runs until interrupted; ``KeyboardInterrupt`` is swallowed so the
    process exits quietly.
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True)

    common_cfg = config["common"]
    batch_size = common_cfg["batch_size"]
    num_workers = common_cfg["num_workers"]
    recv_weights_interval = common_cfg["recv_weights_interval"]

    # Per-worker exploration schedule: 0.4 ** (1 + 7 * idx / (N - 1)).
    # With a single worker N - 1 is zero, so clamp the denominator to avoid
    # a ZeroDivisionError; that lone worker then gets eps == 0.4.
    eps = 0.4 ** (1 + (idx / max(num_workers - 1, 1)) * 7)
    solver = get_solver(config, device="cpu")
    # Only the workers with the highest indices get a logger
    # (original note: "aligned with ray").
    log_flag = idx >= num_workers + (-num_workers // 3)
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}") if log_flag else None,
    )

    step = 0
    prev_recv = 0
    try:
        while True:
            step += 1

            # Load weights at most once per `recv_weights_interval`, measured
            # as (steps since last load) * batch_size.
            interval_flag = (step - prev_recv) * batch_size >= recv_weights_interval
            if interval_flag and not weight_recv.empty():
                worker.load_weights(io.BytesIO(weight_recv.recv()))
                # Remember when we last loaded; without this the throttle
                # would only ever apply to the very first load.
                prev_recv = step

            # Collect a batch and score it with the worker's solver.
            data = worker.step_batch(batch_size)
            loss = worker.solver.calc_loss(data)

            # Normalise dtypes before upload.
            s0, a, r, s1, done = data
            s0 = np.asarray(s0, dtype="f4")
            a = np.asarray(a, dtype="i8")
            r = np.asarray(r, dtype="f4")
            s1 = np.asarray(s1, dtype="f4")
            done = np.asarray(done, dtype="f4")
            loss = np.asarray(loss, dtype="f4")

            # Upload the transitions; the loss is passed as the second
            # argument (presumably the priority -- confirm in reth_buffer).
            rb_client.append([s0, a, r, s1, done], loss)
    except KeyboardInterrupt:
        # Deliberate silent exit on Ctrl-C.
        pass
예제 #2
0
파일: worker.py 프로젝트: sosp2021/reth
def worker_main(config, idx, size, perwez_url, rb_addr):
    """Run a single-step rollout worker that uploads n-step batches.

    Each iteration optionally loads fresh weights from the ``"local-weights"``
    perwez topic, takes one ``worker.step()``, feeds the transition through an
    ``NStepAdder`` and collects the rows it emits in a fixed-capacity
    ``NumpyBuffer``. When the buffer is full, its contents are scored with the
    solver and appended to the ``reth_buffer`` client, then the buffer is
    cleared.

    Args:
        config: Nested mapping. Reads ``config["common"]`` keys
            ``rollout_batch_size`` and ``recv_weights_interval`` and
            ``config["solver"]`` keys ``gamma`` and ``n_step``; the whole
            object is forwarded to ``get_solver`` / ``get_worker``.
        idx: Index of this worker; drives the exploration rate and the
            logger name.
        size: Total number of workers used to spread the exploration rates.
        perwez_url: URL handed to ``perwez.RecvSocket``.
        rb_addr: Address handed to ``reth_buffer.Client``.

    The loop never returns on its own.
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url, "local-weights", broadcast=True)

    batch_size = config["common"]["rollout_batch_size"]
    recv_weights_interval = config["common"]["recv_weights_interval"]

    # Per-worker exploration schedule: 0.4 ** (1 + 7 * idx / (size - 1)).
    # With a single worker size - 1 is zero, so clamp the denominator to
    # avoid a ZeroDivisionError; that lone worker then gets eps == 0.4.
    eps = 0.4 ** (1 + (idx / max(size - 1, 1)) * 7)
    solver = get_solver(config, device="cpu")
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}"),
    )

    prev_load = 0

    adder = NStepAdder(config["solver"]["gamma"], config["solver"]["n_step"])
    buffer = NumpyBuffer(batch_size, circular=False)
    while True:
        # Load weights only after more than `recv_weights_interval` worker
        # steps have passed since the previous load.
        steps_since_load = worker.cur_step - prev_load
        if steps_since_load > recv_weights_interval and not weight_recv.empty():
            worker.load_weights(io.BytesIO(weight_recv.recv()))
            prev_load = worker.cur_step

        # Take one step and normalise dtypes.
        s0, a, r, s1, done = worker.step()
        s0 = np.asarray(s0, dtype="f4")
        a = np.asarray(a, dtype="i8")
        r = np.asarray(r, dtype="f4")
        s1 = np.asarray(s1, dtype="f4")
        done = np.asarray(done, dtype="f4")

        # The adder may return None, in which case there is nothing to
        # store for this step.
        row = adder.push(s0, a, r, s1, done)
        if row is None:
            continue
        buffer.append(row)

        # Upload once a full batch has accumulated, then start over.
        if buffer.size == buffer.capacity:
            loss = worker.solver.calc_loss(buffer.data)
            loss = np.asarray(loss, dtype="f4")
            rb_client.append(buffer.data, loss, compress=True)
            buffer.clear()
예제 #3
0
def worker_main(perwez_url, config, idx):
    """Run a rollout worker that writes transitions to a local Reverb server.

    Each iteration loads fresh weights from the ``"weight"`` perwez topic if
    any are waiting, collects one batch with ``worker.step_batch``, computes
    the solver loss and writes every transition of the batch to Reverb as a
    single-timestep item whose priority is the matching loss entry.

    Args:
        perwez_url: URL handed to ``perwez.RecvSocket``.
        config: Nested mapping. ``config["common"]`` must provide
            ``batch_size`` and ``num_workers``; the whole object is forwarded
            to ``get_solver`` / ``get_worker``.
        idx: Index of this worker; drives the exploration rate and whether
            a logger is attached.

    The loop never returns on its own.
    """
    reverb_client = reverb.Client(f"localhost:{PORT}")
    reverb_writer = reverb_client.writer(1)
    weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True)

    batch_size = config["common"]["batch_size"]
    num_workers = config["common"]["num_workers"]

    # Per-worker exploration schedule: 0.4 ** (1 + 7 * idx / (N - 1)).
    # With a single worker N - 1 is zero, so clamp the denominator to avoid
    # a ZeroDivisionError; that lone worker then gets eps == 0.4.
    eps = 0.4 ** (1 + (idx / max(num_workers - 1, 1)) * 7)
    solver = get_solver(config, device="cpu")
    # Only the workers with the highest indices get a logger
    # (original note: "aligned with ray").
    log_flag = idx >= num_workers + (-num_workers // 3)
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}") if log_flag else None,
    )

    while True:
        # Load weights whenever a new set is waiting.
        if not weight_recv.empty():
            worker.load_weights(io.BytesIO(weight_recv.recv()))

        # Collect a batch and score it with the worker's solver.
        data = worker.step_batch(batch_size)
        loss = worker.solver.calc_loss(data)

        # Normalise dtypes before upload.
        s0, a, r, s1, done = data
        s0 = np.asarray(s0, dtype="f4")
        a = np.asarray(a, dtype="i8")
        r = np.asarray(r, dtype="f4")
        s1 = np.asarray(s1, dtype="f4")
        done = np.asarray(done, dtype="f4")
        loss = np.asarray(loss, dtype="f4")

        # Upload one Reverb item per transition, prioritised by its loss.
        for i in range(len(s0)):
            reverb_writer.append([s0[i], a[i], r[i], s1[i], done[i]])
            reverb_writer.create_item(
                table=TABLE_NAME, num_timesteps=1, priority=loss[i]
            )
예제 #4
0
import os.path as path

from reth.algorithm.util import calculate_discount_rewards_with_dones
from reth.buffer import DynamicSizeBuffer
from reth.presets.config import get_solver, get_worker, get_trainer

# Number of outer-loop iterations of the script below.
MAX_TS = 100000
# Discount factor passed to calculate_discount_rewards_with_dones.
GAMMA = 0.99
# Number of environment steps taken per outer-loop iteration.
UPDATE_INTERVAL = 2000

# Script entry point: roll out episodes with an acting solver and accumulate
# discounted-return training data.
# NOTE(review): `trainer` and `data_buffer` are never consumed in the lines
# visible here -- the update step presumably follows; confirm against the
# full file.
if __name__ == "__main__":
    # config.yaml is expected to sit next to this script.
    config_path = path.join(path.dirname(__file__), "config.yaml")
    solver = get_solver(config_path)
    act_solver = get_solver(config_path)
    # Two separate solver instances are built from the same config: the worker
    # acts with `act_solver`, while the trainer is given `solver`.
    worker = get_worker(config_path, solver=act_solver)
    trainer = get_trainer(config_path, solver=solver)

    # `episode_buffer` holds the transitions of the episode in progress;
    # `data_buffer` accumulates finished episodes.
    episode_buffer = DynamicSizeBuffer(64)
    data_buffer = DynamicSizeBuffer(64)
    # Presumably copies the weights of `solver` into `act_solver` so both
    # start out identical -- confirm in the solver implementation.
    act_solver.sync_weights(solver)
    for _ in range(MAX_TS):
        for _ in range(UPDATE_INTERVAL):
            # Choose an action for the worker's current state, then apply it.
            a, logprob = act_solver.act(worker.s0)
            s0, a, r, s1, done = worker.step(a)
            episode_buffer.append((s0, a, r, logprob, done))
            if done:
                # Episode finished: replace the rewards with the output of
                # calculate_discount_rewards_with_dones and move the episode
                # into the training data (`s1` and `done` are not stored).
                s0, a, r, logprob, done = episode_buffer.data
                r = calculate_discount_rewards_with_dones(r, done, GAMMA)
                data_buffer.append_batch((s0, a, r, logprob))
                episode_buffer.clear()