Example #1
def trainer_main_tf_dataset(perwez_url, config):
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    # init reverb
    reverb_client = reverb.Client(f"localhost:{PORT}")
    # reverb dataset
    def _make_dataset(_):
        dataset = reverb.ReplayDataset(
            f"localhost:{PORT}",
            TABLE_NAME,
            max_in_flight_samples_per_worker=config["common"]["batch_size"],
            dtypes=(tf.float32, tf.int64, tf.float32, tf.float32, tf.float32),
            shapes=(
                tf.TensorShape((4, 84, 84)),
                tf.TensorShape([]),
                tf.TensorShape([]),
                tf.TensorShape((4, 84, 84)),
                tf.TensorShape([]),
            ),
        )
        dataset = dataset.batch(config["common"]["batch_size"], drop_remainder=True)
        return dataset

    num_parallel_calls = 16
    prefetch_size = 4
    dataset = tf.data.Dataset.range(num_parallel_calls)
    dataset = dataset.interleave(
        map_func=_make_dataset,
        cycle_length=num_parallel_calls,
        num_parallel_calls=num_parallel_calls,
        deterministic=False,
    )
    dataset = dataset.prefetch(prefetch_size)
    numpy_iter = dataset.as_numpy_iterator()

    trainer = get_trainer(config)
    sync_weights_interval = config["common"]["sync_weights_interval"]
    ts = 0
    while True:
        ts += 1

        info, data = next(numpy_iter)
        indices = info.key
        weights = info.probability
        # importance-sampling correction (beta = 0.4), normalized by the
        # minimum sampled probability so the largest weight is 1.0
        weights = (weights / weights.min()) ** (-0.4)

        loss = trainer.step(data, weights=weights)
        # write the per-sample losses back as updated priorities
        reverb_client.mutate_priorities(
            TABLE_NAME, updates=dict(zip(np.asarray(indices), np.asarray(loss)))
        )

        if ts % sync_weights_interval == 0:
            weight_send.send(trainer.save_weights().getbuffer())
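
Examples #1 and #3 assume a Reverb server is already running at f"localhost:{PORT}" with a prioritized table named TABLE_NAME; neither constant is defined in the snippets. A minimal sketch of such a server, with assumed values for the port, table name, capacity, priority exponent, and rate limiter, might look like the following.

# Sketch only: PORT, TABLE_NAME and every hyperparameter below are assumed,
# not taken from the sosp2021/reth project.
import reverb

PORT = 12888
TABLE_NAME = "replay_buffer"

def start_reverb_server():
    return reverb.Server(
        tables=[
            reverb.Table(
                name=TABLE_NAME,
                sampler=reverb.selectors.Prioritized(priority_exponent=0.6),
                remover=reverb.selectors.Fifo(),
                max_size=1_000_000,
                rate_limiter=reverb.rate_limiters.MinSize(1_000),
            )
        ],
        port=PORT,
    )
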
Example #2
File: run.py  Project: sosp2021/reth
def trainer_main(config, perwez_url, rb_addr):
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    rb_client = reth_buffer.Client(rb_addr)
    loader = TorchCudaLoader(rb_addr)
    trainer = get_trainer(config)
    send_weights_interval = config["common"]["send_weights_interval"]

    try:
        ts = 0
        for data, indices, weights in loader:
            ts += 1
            # stop after 10 hours of wall-clock training time
            if trainer.cur_time > 3600 * 10:
                return
            loss = trainer.step(data, weights=weights)
            rb_client.update_priorities(np.asarray(indices), np.asarray(loss))

            if ts % send_weights_interval == 0:
                weight_send.send(trainer.save_weights().getbuffer())
    except KeyboardInterrupt:
        # silent exit
        pass
Example #3
def trainer_main_np_client(perwez_url, config):
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    # init reverb
    reverb_client = reverb.Client(f"localhost:{PORT}")

    trainer = get_trainer(config)
    sync_weights_interval = config["common"]["sync_weights_interval"]
    ts = 0
    while True:
        ts += 1

        samples = reverb_client.sample(TABLE_NAME, config["common"]["batch_size"])
        samples = list(samples)
        data, indices, weights = _reverb_samples_to_ndarray(samples)
        weights = (weights / weights.min()) ** (-0.4)

        loss = trainer.step(data, weights=weights)
        reverb_client.mutate_priorities(
            TABLE_NAME, updates=dict(zip(np.asarray(indices), np.asarray(loss)))
        )

        if ts % sync_weights_interval == 0:
            weight_send.send(trainer.save_weights().getbuffer())
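
The helper _reverb_samples_to_ndarray is defined elsewhere in the project and not shown here. A rough, hypothetical sketch of what it might do with the client's samples (reverb.Client.sample yields lists of ReplaySample items whose .info carries the key and sampling probability) is given below; it is an assumption, not the project's code.

import numpy as np

def _reverb_samples_to_ndarray(samples):
    # Hypothetical reconstruction; single-step insertions yield a
    # one-element list of ReplaySample per sampled item.
    first = [s[0] for s in samples]
    indices = np.asarray([s.info.key for s in first])
    weights = np.asarray([s.info.probability for s in first])
    # Stack each field of the stored transition (s0, a, r, s1, done)
    # into arrays with a leading batch dimension.
    data = tuple(np.stack(col) for col in zip(*(s.data for s in first)))
    return data, indices, weights
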
Example #4
File: trainer.py  Project: sosp2021/reth
def trainer_main(config, perwez_url, rb_addrs):
    weights_send = perwez.SendSocket(perwez_url, "weights", broadcast=True)
    rb_clients = [reth_buffer.Client(addr) for addr in rb_addrs]
    rb_loaders = [
        reth_buffer.TorchCudaLoader(addr, buffer_size=4, num_procs=2)
        for addr in rb_addrs
    ]
    trainer = get_trainer(config)
    send_weights_interval = config["common"]["send_weights_interval"]
    ts = 0
    while True:
        ts += 1
        # stop after 40 hours of wall-clock training time
        if trainer.cur_time > 3600 * 40:
            return
        idx = ts % len(rb_clients)
        data, indices, weights = rb_loaders[idx].sample()
        loss = trainer.step(data, weights=weights)
        rb_clients[idx].update_priorities(np.asarray(indices), np.asarray(loss))

        if ts % send_weights_interval == 0:
            stream = io.BytesIO()
            trainer.save_weights(stream)
            weights_send.send(stream.getbuffer())
Example #5
from os import path

from reth.algorithm.util import calculate_discount_rewards_with_dones
from reth.buffer import DynamicSizeBuffer
from reth.presets.config import get_solver, get_worker, get_trainer

MAX_TS = 100000
GAMMA = 0.99
UPDATE_INTERVAL = 2000

if __name__ == "__main__":
    config_path = path.join(path.dirname(__file__), "config.yaml")
    solver = get_solver(config_path)
    act_solver = get_solver(config_path)
    # separate acting solver; its weights are synced from the training solver
    worker = get_worker(config_path, solver=act_solver)
    trainer = get_trainer(config_path, solver=solver)

    episode_buffer = DynamicSizeBuffer(64)
    data_buffer = DynamicSizeBuffer(64)
    act_solver.sync_weights(solver)
    for _ in range(MAX_TS):
        for _ in range(UPDATE_INTERVAL):
            a, logprob = act_solver.act(worker.s0)
            s0, a, r, s1, done = worker.step(a)
            episode_buffer.append((s0, a, r, logprob, done))
            if done:
                s0, a, r, logprob, done = episode_buffer.data
                r = calculate_discount_rewards_with_dones(r, done, GAMMA)
                data_buffer.append_batch((s0, a, r, logprob))
                episode_buffer.clear()
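
calculate_discount_rewards_with_dones is imported from reth.algorithm.util and not shown here. For reference, an illustrative version of the standard computation (discounted returns accumulated backwards and reset at episode boundaries) could look like this; the library's actual implementation may differ.

import numpy as np

def discounted_returns_sketch(rewards, dones, gamma):
    # Illustrative only, not the reth implementation.
    returns = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # A terminal step cuts the accumulated return.
        running = rewards[t] + gamma * running * (1.0 - float(dones[t]))
        returns[t] = running
    return returns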