import ray.rllib.agents.dqn as dqn
from ray.rllib.models import ModelCatalog


def get_apex_dqn_car_trainer():
    # TorchCustomModel and StoppingCar are defined elsewhere in this project.
    ModelCatalog.register_custom_model("my_model", TorchCustomModel)
    config = {
        "env": StoppingCar,
        # "model": {"custom_model": "my_model", "fcnet_hiddens": [64, 64], "fcnet_activation": "relu"},  # model config
        "n_step": 1,
        "lr": 0.0005,
        "grad_clip": 2500,
        "batch_mode": "complete_episodes",
        "num_workers": 7,  # parallelism
        "num_envs_per_worker": 10,
        "train_batch_size": 512,
        "hiddens": [32],
        "framework": "torch",
        "optimizer": {"num_replay_buffer_shards": 1},
        "horizon": 1000,
    }
    trainer = dqn.ApexTrainer(config=config)
    return trainer, config
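# Hedged usage sketch (not from the original source): drive the factory above
# for a few iterations and save a checkpoint. Assumes Ray is installed with the
# legacy ray.rllib.agents API and that the custom env/model import correctly.
import ray

ray.init(ignore_reinit_error=True)
trainer, config = get_apex_dqn_car_trainer()
for i in range(10):
    result = trainer.train()  # one Ape-X DQN training iteration
    print(i, result["episode_reward_mean"])
checkpoint_path = trainer.save()  # returns the path of the new checkpoint
print("saved checkpoint at", checkpoint_path)
ray.shutdown()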
def get_rl_agent(agent_name, config, env_to_agent):
    if agent_name == A2C:
        import ray.rllib.agents.a3c as a2c
        agent = a2c.A2CTrainer(config=config, env=env_to_agent)
    elif agent_name == A3C:
        import ray.rllib.agents.a3c as a3c
        agent = a3c.A3CTrainer(config=config, env=env_to_agent)
    elif agent_name == BC:
        import ray.rllib.agents.marwil as bc
        agent = bc.BCTrainer(config=config, env=env_to_agent)
    elif agent_name == DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.DQNTrainer(config=config, env=env_to_agent)
    elif agent_name == APEX_DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.ApexTrainer(config=config, env=env_to_agent)
    elif agent_name == IMPALA:
        import ray.rllib.agents.impala as impala
        agent = impala.ImpalaTrainer(config=config, env=env_to_agent)
    elif agent_name == MARWIL:
        import ray.rllib.agents.marwil as marwil
        agent = marwil.MARWILTrainer(config=config, env=env_to_agent)
    elif agent_name == PG:
        import ray.rllib.agents.pg as pg
        agent = pg.PGTrainer(config=config, env=env_to_agent)
    elif agent_name == PPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.PPOTrainer(config=config, env=env_to_agent)
    elif agent_name == APPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.APPOTrainer(config=config, env=env_to_agent)
    elif agent_name == SAC:
        import ray.rllib.agents.sac as sac
        agent = sac.SACTrainer(config=config, env=env_to_agent)
    elif agent_name == LIN_UCB:
        import ray.rllib.contrib.bandits.agents.lin_ucb as lin_ucb
        agent = lin_ucb.LinUCBTrainer(config=config, env=env_to_agent)
    elif agent_name == LIN_TS:
        import ray.rllib.contrib.bandits.agents.lin_ts as lin_ts
        agent = lin_ts.LinTSTrainer(config=config, env=env_to_agent)
    else:
        raise ValueError(f"Invalid agent name: {agent_name}")
    return agent
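# Illustrative call of the dispatcher above (an assumption, not original code):
# the agent-name constants (A2C, PPO, ...) are defined elsewhere in the project,
# and any RLlib-compatible env id or env class can be passed through.
config = {"framework": "torch", "num_workers": 0}
agent = get_rl_agent(PPO, config, "CartPole-v0")
result = agent.train()
print("mean episode reward:", result["episode_reward_mean"])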
def get_rllib_agent(agent_name, env_name, env, env_to_agent):
    config = get_config(env_name, env, 1) if is_rllib_agent(agent_name) else {}
    if agent_name == RLLIB_A2C:
        import ray.rllib.agents.a3c as a2c
        agent = a2c.A2CTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_A3C:
        import ray.rllib.agents.a3c as a3c
        agent = a3c.A3CTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_BC:
        import ray.rllib.agents.marwil as bc
        agent = bc.BCTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.DQNTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_APEX_DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.ApexTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_IMPALA:
        import ray.rllib.agents.impala as impala
        agent = impala.ImpalaTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_MARWIL:
        import ray.rllib.agents.marwil as marwil
        agent = marwil.MARWILTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_PG:
        import ray.rllib.agents.pg as pg
        agent = pg.PGTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_PPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.PPOTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_APPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.APPOTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_SAC:
        import ray.rllib.agents.sac as sac
        agent = sac.SACTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_LIN_UCB:
        import ray.rllib.contrib.bandits.agents.lin_ucb as lin_ucb
        agent = lin_ucb.LinUCBTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_LIN_TS:
        import ray.rllib.contrib.bandits.agents.lin_ts as lin_ts
        agent = lin_ts.LinTSTrainer(config=config, env=env_to_agent)
    else:
        # Fail loudly instead of falling through with `agent` unbound.
        raise ValueError(f"Invalid agent name: {agent_name}")
    return agent
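# Sketch of rolling out an agent returned by get_rllib_agent (illustrative;
# `env` is assumed to be a standard Gym environment instance from the caller).
obs = env.reset()
done, episode_return = False, 0.0
while not done:
    action = agent.compute_action(obs)  # query the trained policy
    obs, reward, done, info = env.step(action)
    episode_return += reward
print("episode return:", episode_return)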
        # (fragment starts inside a loop that collects checkpoint numbers from
        # regex matches `m` over the results directory)
        print(m[0])
        checkpoint_numbers.append(int(m[0]))

mc = max(checkpoint_numbers)
checkpoint_path = path_to_results + "/" + "checkpoint_{}/checkpoint-{}".format(mc, mc)
print("found {} checkpoints".format(len(checkpoint_numbers)))
print("restoring " + checkpoint_path)

# ============================================================== #
# evaluation {{{
# ============================================================== #

# ray.init()
ray.init(temp_dir=tmpdir + "/ray")  # you may need to change the temp directory if this runs on a cluster or shared machine

if config["optimizer_class"] == "AsyncReplayOptimizer":
    trainer = dqn.ApexTrainer(config=config, env=CodeEnv)
else:
    trainer = dqn.DQNTrainer(config=config, env=CodeEnv)
trainer.restore(checkpoint_path)

env = CodeEnv(env_config)
n = env.n

dB_len = len(dB_range)
BitErr = np.zeros([dB_len], dtype=int)
CwErr = np.zeros([dB_len], dtype=int)
totCw = np.zeros([dB_len], dtype=int)
totBit = np.zeros([dB_len], dtype=int)

for i in range(dB_len):
    print("\n--------\nSimulating EbNo = {} dB".format(dB_range[i]))
    env.set_EbNo_dB(dB_range[i])
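# Hedged sketch (not part of the original fragment): once the simulation loop
# has filled the error counters above, per-Eb/No bit and codeword error rates
# could be reported roughly like this.
BER = BitErr / np.maximum(totBit, 1)  # bit error rate at each Eb/No point
CER = CwErr / np.maximum(totCw, 1)    # codeword error rate at each Eb/No point
for ebno, ber, cer in zip(dB_range, BER, CER):
    print("EbNo = {} dB: BER = {:.3e}, CER = {:.3e}".format(ebno, ber, cer))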
def restart_ray():
    global _ray_error
    global save_interval
    global save_counter

    _ray_error = False
    largest = level_state.find_latest_checkpoint(checkpoint_all)
    latest_checkpoint = None
    if largest > -1:
        latest_checkpoint = os.path.join(checkpoint_all, f"checkpoint_{str(largest)}", f"checkpoint-{str(largest)}")

    register_env(c.ENV_NAME, lambda config: MarioEnv(config))
    ray_init(ignore_reinit_error=True)
    trainer = dqn.ApexTrainer(
        env=c.ENV_NAME,
        config={
            "num_gpus": 1,
            "num_workers": 1,
            "eager": False,
            "model": {"conv_filters": [[c.OBS_FRAMES, c.OBS_SIZE, c.OBS_SIZE]]},
            "env_config": dict(actions=COMPLEX_MOVEMENT, window=False, fps=60_000),
            # "train_batch_size": 2048
        })

    if latest_checkpoint and c.LOAD_CHECKPOINT:
        trainer.restore(latest_checkpoint)
        print("Loaded Mario checkpoint:", latest_checkpoint)

    if c.EVALUATE:
        eval_thread = Thread(target=test, args=(trainer,))
        eval_thread.daemon = True
        eval_thread.start()

    try:
        while True:
            trainer.train()
            if save_counter % save_interval == 1:
                checkpoint = trainer.save(checkpoint_all)
                print("Saved Mario checkpoint:", checkpoint)
            save_counter += 1
    except RayOutOfMemoryError:
        print("Ray out of memory!")
        # print("********************* objgraph.show_most_common_types() ************************")
        # # Display most common types in console.
        # objgraph.show_most_common_types()
        # print("********************* objgraph.show_growth(limit=10) ************************")
        # # Display common type growth in console.
        # objgraph.show_growth(limit=10)
    finally:
        # print("Restarting ray...")
        _ray_error = True
        if c.EVALUATE:
            eval_thread.join()
        ray_shutdown()
        restart_ray()
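# Illustrative module-level entry point (an assumption, not from the source):
# the driver initialises the checkpointing globals used above and starts the
# self-restarting training loop, which resumes from the latest checkpoint
# whenever Ray runs out of memory.
if __name__ == "__main__":
    save_interval = 10  # save a checkpoint every 10 training iterations
    save_counter = 0
    restart_ray()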