コード例 #1
0
def test_basic():
    """Exercise the append / sample / update_priorities round trip.

    Starts a buffer server, appends 1000 rows in chunks of 100, checks that
    every sampled row matches the index the loader reports for it, and then
    pushes a fresh set of priorities. The server process is terminated and
    joined even when an assertion fails.
    """
    server, address = reth_buffer.start_server(CAPACITY, BATCH_SIZE)
    try:
        writer = reth_buffer.Client(address)
        reader = reth_buffer.NumpyLoader(address)

        # 1. append -- the last column stores each row's own position so the
        # sampled indices can be verified below
        total, chunk = 1000, 100
        columns = [
            np.random.rand(total, 4, 84),
            np.random.rand(total),
            np.arange(total),
        ]
        priorities = np.random.rand(total) + 1
        for lo in range(0, total, chunk):
            hi = lo + chunk
            writer.append([col[lo:hi] for col in columns], priorities[lo:hi])

        # 2. sample
        for _ in range(10):
            sampled, indices, _weights = reader.sample()
            for pos, idx in enumerate(indices):
                assert sampled[2][pos] == idx
                assert sampled[1][pos] == columns[1][idx]

        # 3. update priorities
        priorities = np.random.rand(total) + 10
        writer.update_priorities(np.arange(total), priorities)
    finally:
        server.terminate()
        server.join()
コード例 #2
0
def _trainer(addr):
    """Sample ten batches from the buffer at *addr* and re-prioritize them.

    Each iteration prints its step number, draws one sample, asserts that
    the sample has five components, sleeps 0.1 s, and then writes random
    priorities back for the sampled indices.
    """
    client = reth_buffer.Client(addr)
    loader = reth_buffer.NumpyLoader(addr)
    for step in range(10):
        print(step, flush=True)
        data, indices, _weights = loader.sample()
        assert len(data) == 5
        time.sleep(0.1)
        new_priorities = np.random.rand(len(indices))
        client.update_priorities(indices, new_priorities)
コード例 #3
0
def _worker(addr):
    """Append ten random five-array batches to the buffer at *addr*.

    Every iteration prints its step number, generates BATCH_SIZE rows of
    random (s0, a, r, s1, done) arrays plus per-row weights, appends them
    through the client, and sleeps 0.1 s.
    """
    client = reth_buffer.Client(addr)
    frame_shape = (BATCH_SIZE, 4, 84, 84)
    for step in range(10):
        print("worker", step, flush=True)
        # Draw order is kept fixed so a seeded run produces the same data.
        s0 = np.random.rand(*frame_shape)
        s1 = np.random.rand(*frame_shape)
        a = np.random.randint(0, 8, BATCH_SIZE)
        r = np.random.rand(BATCH_SIZE)
        done = np.random.rand(BATCH_SIZE)
        weights = np.random.rand(BATCH_SIZE) + 1
        payload = [s0, a, r, s1, done]
        client.append(payload, weights)
        time.sleep(0.1)
コード例 #4
0
ファイル: run.py プロジェクト: sosp2021/reth
def worker_main(config, perwez_url, rb_addr, idx):
    """Run one rollout worker until interrupted.

    The worker repeatedly steps its environment for ``batch_size`` steps,
    computes the solver loss for that batch, and appends the batch to the
    replay buffer with the loss passed as the second ``append`` argument.
    New model weights are pulled from the "weight" topic at most once per
    ``recv_weights_interval`` environment steps.

    Args:
        config: Nested mapping; reads ``config["common"]["batch_size"]``,
            ``["num_workers"]`` and ``["recv_weights_interval"]`` and is
            forwarded to ``get_solver`` / ``get_worker``.
        perwez_url: Address passed to ``perwez.RecvSocket``.
        rb_addr: Address passed to ``reth_buffer.Client``.
        idx: Index of this worker; selects the exploration rate and whether
            a logger is attached.

    A ``KeyboardInterrupt`` ends the loop silently.
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True)

    batch_size = config["common"]["batch_size"]
    num_workers = config["common"]["num_workers"]
    # Per-worker exploration rate 0.4 ** (1 + 7 * idx / (num_workers - 1)).
    # With a single worker the denominator would be zero; clamp it to 1 so
    # that worker gets eps == 0.4 instead of raising ZeroDivisionError.
    eps = 0.4**(1 + (idx / max(num_workers - 1, 1)) * 7)
    solver = get_solver(config, device="cpu")
    log_flag = idx >= num_workers + (-num_workers // 3)  # aligned with ray
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}") if log_flag else None,
    )
    recv_weights_interval = config["common"]["recv_weights_interval"]

    step = 0
    prev_recv = 0
    try:
        while True:
            step += 1

            # load weights
            steps_since_recv = (step - prev_recv) * batch_size
            interval_flag = steps_since_recv >= recv_weights_interval
            if interval_flag and not weight_recv.empty():
                worker.load_weights(io.BytesIO(weight_recv.recv()))
                # Restart the interval; without this the flag stays true
                # forever after the first interval has elapsed.
                prev_recv = step

            # step
            data = worker.step_batch(batch_size)
            loss = worker.solver.calc_loss(data)
            # format
            s0, a, r, s1, done = data
            s0 = np.asarray(s0, dtype="f4")
            a = np.asarray(a, dtype="i8")
            r = np.asarray(r, dtype="f4")
            s1 = np.asarray(s1, dtype="f4")
            done = np.asarray(done, dtype="f4")
            loss = np.asarray(loss, dtype="f4")
            # upload
            rb_client.append([s0, a, r, s1, done], loss)
    except KeyboardInterrupt:
        # silent exit
        pass
コード例 #5
0
ファイル: worker.py プロジェクト: sosp2021/reth
def worker_main(config, idx, size, perwez_url, rb_addr):
    """Run one rollout worker that uploads n-step rows in fixed-size batches.

    The worker steps its environment one transition at a time, feeds each
    transition through an ``NStepAdder``, collects the resulting rows in a
    local ``NumpyBuffer`` and, once that buffer is full, computes the solver
    loss for it and appends the batch to the replay buffer with compression
    enabled. The loop never returns on its own.

    Args:
        config: Nested mapping; reads
            ``config["common"]["rollout_batch_size"]``,
            ``config["common"]["recv_weights_interval"]``,
            ``config["solver"]["gamma"]`` and ``config["solver"]["n_step"]``
            and is forwarded to ``get_solver`` / ``get_worker``.
        idx: Index of this worker among ``size`` workers.
        size: Total number of workers.
        perwez_url: Address passed to ``perwez.RecvSocket``.
        rb_addr: Address passed to ``reth_buffer.Client``.
    """
    rb_client = reth_buffer.Client(rb_addr)
    weight_recv = perwez.RecvSocket(perwez_url,
                                    "local-weights",
                                    broadcast=True)

    batch_size = config["common"]["rollout_batch_size"]
    # Per-worker exploration rate 0.4 ** (1 + 7 * idx / (size - 1)).
    # With a single worker the denominator would be zero; clamp it to 1 so
    # that worker gets eps == 0.4 instead of raising ZeroDivisionError.
    eps = 0.4**(1 + (idx / max(size - 1, 1)) * 7)
    solver = get_solver(config, device="cpu")
    worker = get_worker(config,
                        exploration=eps,
                        solver=solver,
                        logger=getLogger(f"worker{idx}"))

    recv_weights_interval = config["common"]["recv_weights_interval"]
    prev_load = 0

    adder = NStepAdder(config["solver"]["gamma"], config["solver"]["n_step"])
    buffer = NumpyBuffer(batch_size, circular=False)
    while True:
        # load weights, at most once per recv_weights_interval worker steps
        interval_elapsed = (worker.cur_step - prev_load) > recv_weights_interval
        if interval_elapsed and not weight_recv.empty():
            worker.load_weights(io.BytesIO(weight_recv.recv()))
            prev_load = worker.cur_step

        # step
        s0, a, r, s1, done = worker.step()
        s0 = np.asarray(s0, dtype="f4")
        a = np.asarray(a, dtype="i8")
        r = np.asarray(r, dtype="f4")
        s1 = np.asarray(s1, dtype="f4")
        done = np.asarray(done, dtype="f4")
        # adder
        row = adder.push(s0, a, r, s1, done)
        if row is None:
            # No row produced for this step (presumably the n-step window
            # is still filling -- confirm against NStepAdder).
            continue
        buffer.append(row)
        if buffer.size == buffer.capacity:
            # Local batch is full: score it, upload it, start a new one.
            loss = worker.solver.calc_loss(buffer.data)
            loss = np.asarray(loss, dtype="f4")
            rb_client.append(buffer.data, loss, compress=True)
            buffer.clear()
コード例 #6
0
def test_reth_buffer(data):
    """Time TEST_CNT samples drawn from a reth_buffer server.

    Starts a server, appends *data* with uniform weights of length CAPACITY,
    draws one warm-up sample, then prints TEST_CNT and the elapsed seconds
    for TEST_CNT further samples.

    Args:
        data: Batch passed unchanged to ``Client.append``.

    The server process is terminated and joined even if appending or
    sampling raises, so a failed run does not leave it behind.
    """
    import reth_buffer

    print("TEST RETH_BUFFER")
    print("initializing...")
    buffer_process, addr = reth_buffer.start_server(CAPACITY, BATCH_SIZE)
    try:
        client = reth_buffer.Client(addr)
        client.append(data, np.ones(CAPACITY))
        # loader = reth_buffer.TorchCudaLoader(addr)
        loader = reth_buffer.NumpyLoader(addr)
        # init: first sample is excluded from the timing
        loader.sample()
        print("ready")
        t0 = time.perf_counter()
        for _ in range(TEST_CNT):
            loader.sample()
        t1 = time.perf_counter()
        print(TEST_CNT, t1 - t0)
    finally:
        buffer_process.terminate()
        buffer_process.join()
コード例 #7
0
ファイル: run.py プロジェクト: sosp2021/reth
def trainer_main(config, perwez_url, rb_addr):
    """Train from a single replay buffer and broadcast weights periodically.

    Iterates over batches from a ``TorchCudaLoader``, runs one trainer step
    per batch, writes the resulting loss back as the new priorities for the
    sampled indices, and sends the saved weights on the "weight" topic every
    ``config["common"]["send_weights_interval"]`` iterations. Returns once
    ``trainer.cur_time`` exceeds the time limit; a ``KeyboardInterrupt``
    ends the loop silently.
    """
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    rb_client = reth_buffer.Client(rb_addr)
    loader = TorchCudaLoader(rb_addr)
    trainer = get_trainer(config)
    send_weights_interval = config["common"]["send_weights_interval"]
    # presumably seconds (i.e. ten hours) -- confirm against trainer.cur_time
    time_limit = 3600 * 10

    try:
        for ts, (data, indices, weights) in enumerate(loader, start=1):
            if trainer.cur_time > time_limit:
                return
            loss = trainer.step(data, weights=weights)
            rb_client.update_priorities(np.asarray(indices), np.asarray(loss))

            if ts % send_weights_interval != 0:
                continue
            snapshot = trainer.save_weights()
            weight_send.send(snapshot.getbuffer())
    except KeyboardInterrupt:
        # silent exit
        pass
コード例 #8
0
ファイル: trainer.py プロジェクト: sosp2021/reth
def trainer_main(config, perwez_url, rb_addrs):
    """Train from several replay buffers in round-robin order.

    One client and one ``TorchCudaLoader`` are created per address in
    *rb_addrs*. Each iteration picks the buffer at ``ts % len(rb_addrs)``,
    samples from it, runs a trainer step, and writes the loss back to that
    same buffer as the new priorities. Weights are serialized into an
    in-memory stream and sent on the "weights" topic every
    ``config["common"]["send_weights_interval"]`` iterations. Returns once
    ``trainer.cur_time`` exceeds the time limit.
    """
    weights_send = perwez.SendSocket(perwez_url, "weights", broadcast=True)

    # All clients are created before any loader.
    rb_clients = []
    for addr in rb_addrs:
        rb_clients.append(reth_buffer.Client(addr))
    rb_loaders = []
    for addr in rb_addrs:
        rb_loaders.append(
            reth_buffer.TorchCudaLoader(addr, buffer_size=4, num_procs=2))

    trainer = get_trainer(config)
    send_weights_interval = config["common"]["send_weights_interval"]
    # presumably seconds (i.e. forty hours) -- confirm against trainer.cur_time
    time_limit = 3600 * 40
    num_buffers = len(rb_clients)

    ts = 0
    while True:
        ts += 1
        if trainer.cur_time > time_limit:
            return

        which = ts % num_buffers
        data, indices, weights = rb_loaders[which].sample()
        loss = trainer.step(data, weights=weights)
        new_priorities = np.asarray(loss)
        rb_clients[which].update_priorities(np.asarray(indices),
                                            new_priorities)

        if ts % send_weights_interval != 0:
            continue
        stream = io.BytesIO()
        trainer.save_weights(stream)
        weights_send.send(stream.getbuffer())