def trainer_main_tf_dataset(perwez_url, config):
    """Train from a Reverb table fed through a parallel ``tf.data`` pipeline.

    Sixteen ``reverb.ReplayDataset`` readers on ``localhost:PORT`` /
    ``TABLE_NAME`` are interleaved, batched to ``config["common"]["batch_size"]``
    and prefetched. Each training step:

    1. turns the sampled probabilities into normalised weights;
    2. calls ``trainer.step``;
    3. writes the returned per-sample loss back as the new priority of each
       sampled key.

    Every ``config["common"]["sync_weights_interval"]`` steps, the trainer's
    serialized weights are broadcast on the perwez ``"weight"`` topic.

    Args:
        perwez_url: address of the perwez endpoint used for the weight socket.
        config: nested config mapping. ``config["common"]`` must provide
            ``batch_size`` and ``sync_weights_interval``. The whole mapping is
            also passed to ``get_trainer``.

    The loop has no exit path. It runs until an exception propagates.
    """
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    # init reverb
    reverb_client = reverb.Client(f"localhost:{PORT}")

    def _reader(_shard):
        """Build one batched Reverb reader; the interleave element is unused."""
        per_worker = config["common"]["batch_size"]
        obs_shape = tf.TensorShape((4, 84, 84))
        scalar = tf.TensorShape([])
        replay = reverb.ReplayDataset(
            f"localhost:{PORT}",
            TABLE_NAME,
            max_in_flight_samples_per_worker=per_worker,
            dtypes=(tf.float32, tf.int64, tf.float32, tf.float32, tf.float32),
            shapes=(obs_shape, scalar, scalar, obs_shape, scalar),
        )
        return replay.batch(per_worker, drop_remainder=True)

    parallel_readers = 16
    prefetch_batches = 4
    pipeline = (
        tf.data.Dataset.range(parallel_readers)
        .interleave(
            map_func=_reader,
            cycle_length=parallel_readers,
            num_parallel_calls=parallel_readers,
            deterministic=False,
        )
        .prefetch(prefetch_batches)
    )
    batches = pipeline.as_numpy_iterator()

    trainer = get_trainer(config)
    sync_every = config["common"]["sync_weights_interval"]

    step = 0
    while True:
        step += 1
        info, data = next(batches)
        sampled_keys = info.key
        probs = info.probability
        # (p / p_min) ** -0.4 == (p_min / p) ** 0.4, so every weight is in
        # (0, 1] and the least likely sample in the batch gets weight 1.
        sample_weights = (probs / probs.min()) ** -0.4
        loss = trainer.step(data, weights=sample_weights)
        new_priorities = dict(zip(np.asarray(sampled_keys), np.asarray(loss)))
        reverb_client.mutate_priorities(TABLE_NAME, updates=new_priorities)
        if step % sync_every == 0:
            weight_send.send(trainer.save_weights().getbuffer())
def trainer_main(config, perwez_url, rb_addr):
    """Train from a single reth_buffer replay buffer.

    Batches ``(data, indices, weights)`` come from ``TorchCudaLoader(rb_addr)``.
    After each ``trainer.step``, the returned loss is written back as the new
    priority of the sampled indices. Every
    ``config["common"]["send_weights_interval"]`` steps, the serialized
    weights are broadcast on the perwez ``"weight"`` topic.

    Training stops in three cases:

    - ``trainer.cur_time`` exceeds ``3600 * 10`` (presumably seconds, i.e.
      10 h; confirm the units of ``cur_time``);
    - the loader is exhausted;
    - a ``KeyboardInterrupt`` arrives, which is swallowed silently.

    Args:
        config: nested config mapping passed to ``get_trainer``.
            ``config["common"]["send_weights_interval"]`` must be set.
        perwez_url: address of the perwez endpoint used for the weight socket.
        rb_addr: address of the reth_buffer server.
    """
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    rb_client = reth_buffer.Client(rb_addr)
    loader = TorchCudaLoader(rb_addr)
    trainer = get_trainer(config)
    interval = config["common"]["send_weights_interval"]
    time_budget = 3600 * 10  # compared against trainer.cur_time

    try:
        for step, (batch, keys, sample_weights) in enumerate(loader, start=1):
            if trainer.cur_time > time_budget:
                return
            loss = trainer.step(batch, weights=sample_weights)
            rb_client.update_priorities(np.asarray(keys), np.asarray(loss))
            if step % interval != 0:
                continue
            weight_send.send(trainer.save_weights().getbuffer())
    except KeyboardInterrupt:
        # Ctrl-C is treated as a normal, silent shutdown.
        pass
def trainer_main_np_client(perwez_url, config):
    """Train from a Reverb table using the plain Python sampling client.

    Each step does the following:

    1. Draws ``config["common"]["batch_size"]`` items from ``TABLE_NAME`` with
       ``reverb_client.sample``.
    2. Converts them with ``_reverb_samples_to_ndarray``.
    3. Normalises the returned weights.
    4. Trains, then writes the per-sample loss back as the new priority of
       each sampled key.

    Every ``config["common"]["sync_weights_interval"]`` steps, the serialized
    weights are broadcast on the perwez ``"weight"`` topic.

    Args:
        perwez_url: address of the perwez endpoint used for the weight socket.
        config: nested config mapping. ``config["common"]`` must provide
            ``batch_size`` and ``sync_weights_interval``. The whole mapping is
            also passed to ``get_trainer``.

    The loop has no exit path. It runs until an exception propagates.
    """
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    reverb_client = reverb.Client(f"localhost:{PORT}")
    trainer = get_trainer(config)
    sync_every = config["common"]["sync_weights_interval"]

    step = 0
    while True:
        step += 1
        drawn = list(reverb_client.sample(TABLE_NAME, config["common"]["batch_size"]))
        data, keys, raw_weights = _reverb_samples_to_ndarray(drawn)
        # Assumes raw_weights are per-sample sampling probabilities (confirm
        # against _reverb_samples_to_ndarray). Dividing by the minimum and
        # raising to -0.4 maps them into (0, 1], with the smallest mapped to 1.
        sample_weights = (raw_weights / raw_weights.min()) ** -0.4
        loss = trainer.step(data, weights=sample_weights)
        priorities = {key: prio for key, prio in zip(np.asarray(keys), np.asarray(loss))}
        reverb_client.mutate_priorities(TABLE_NAME, updates=priorities)
        if not step % sync_every:
            weight_send.send(trainer.save_weights().getbuffer())
def trainer_main(config, perwez_url, rb_addrs, *, time_limit=3600 * 40):
    """Train round-robin over several reth_buffer replay buffers.

    One ``reth_buffer.Client`` and one ``TorchCudaLoader`` (``buffer_size=4``,
    ``num_procs=2``) are created per address. Step ``ts`` (starting at 1)
    samples from buffer ``ts % len(rb_addrs)``, trains on the batch, and
    writes the per-sample loss back as that buffer's new priorities. Every
    ``config["common"]["send_weights_interval"]`` steps, the trainer saves its
    weights into an in-memory stream, which is broadcast on the perwez
    ``"weights"`` topic.

    Args:
        config: nested config mapping passed to ``get_trainer``.
            ``config["common"]["send_weights_interval"]`` must be set.
        perwez_url: address of the perwez endpoint used for the weight socket.
        rb_addrs: iterable of reth_buffer server addresses. Any iterable is
            accepted, including a one-shot iterator.
        time_limit: training stops once ``trainer.cur_time`` exceeds this
            value. The default ``3600 * 40`` is presumably 40 h in seconds;
            confirm the units of ``cur_time``.

    Returns:
        None, once ``time_limit`` is exceeded.

    Raises:
        ValueError: if ``rb_addrs`` is empty. This is checked before any
            socket, client or trainer is created.
    """
    # Materialise once: the addresses are iterated twice below, and a
    # one-shot iterator would otherwise leave the loader list empty.
    rb_addrs = list(rb_addrs)
    if not rb_addrs:
        raise ValueError("trainer_main requires at least one replay-buffer address")

    weights_send = perwez.SendSocket(perwez_url, "weights", broadcast=True)
    rb_clients = [reth_buffer.Client(addr) for addr in rb_addrs]
    rb_loaders = [
        reth_buffer.TorchCudaLoader(addr, buffer_size=4, num_procs=2)
        for addr in rb_addrs
    ]
    trainer = get_trainer(config)
    send_weights_interval = config["common"]["send_weights_interval"]
    num_buffers = len(rb_addrs)

    ts = 0
    while True:
        ts += 1
        if trainer.cur_time > time_limit:
            return
        # Round-robin: the client and loader at the same index share an address.
        idx = ts % num_buffers
        data, indices, weights = rb_loaders[idx].sample()
        loss = trainer.step(data, weights=weights)
        rb_clients[idx].update_priorities(np.asarray(indices), np.asarray(loss))
        if ts % send_weights_interval == 0:
            stream = io.BytesIO()
            trainer.save_weights(stream)
            weights_send.send(stream.getbuffer())
# Rollout script: steps the worker with `act_solver`, collects each finished
# episode, converts its rewards with calculate_discount_rewards_with_dones,
# and appends the result to `data_buffer`.
from reth.algorithm.util import calculate_discount_rewards_with_dones
from reth.buffer import DynamicSizeBuffer
from reth.presets.config import get_solver, get_worker, get_trainer

MAX_TS = 100000  # number of outer-loop iterations (not environment steps)
GAMMA = 0.99  # discount factor passed to calculate_discount_rewards_with_dones
UPDATE_INTERVAL = 2000  # worker steps taken per outer-loop iteration

if __name__ == "__main__":
    # NOTE(review): `path` is not imported in the visible code; presumably
    # `from os import path` exists elsewhere in this file — confirm.
    config_path = path.join(path.dirname(__file__), "config.yaml")
    # Two solvers built from the same config: `solver` goes to the trainer,
    # `act_solver` to the worker for acting.
    solver = get_solver(config_path)
    act_solver = get_solver(config_path)  # shared solver: the acting copy used by the worker
    worker = get_worker(config_path, solver=act_solver)
    trainer = get_trainer(config_path, solver=solver)
    episode_buffer = DynamicSizeBuffer(64)  # transitions of the current episode
    data_buffer = DynamicSizeBuffer(64)  # finished episodes with converted rewards
    # Presumably copies `solver`'s weights into `act_solver` so both start equal — confirm.
    act_solver.sync_weights(solver)
    for _ in range(MAX_TS):
        for _ in range(UPDATE_INTERVAL):
            a, logprob = act_solver.act(worker.s0)
            s0, a, r, s1, done = worker.step(a)  # s1 is unused here
            episode_buffer.append((s0, a, r, logprob, done))
            if done:
                # Rebinds the per-step names to whole-episode columns;
                # presumably `.data` returns one column per tuple field — confirm.
                s0, a, r, logprob, done = episode_buffer.data
                r = calculate_discount_rewards_with_dones(r, done, GAMMA)
                data_buffer.append_batch((s0, a, r, logprob))
                episode_buffer.clear()
        # NOTE(review): `data_buffer` and `trainer` are not consumed in the
        # visible code; the training / weight-sync step presumably follows
        # beyond this chunk — confirm.