Example #1
0
def get_apex_dqn_car_trainer():
    """Build an Ape-X DQN trainer for the ``StoppingCar`` environment.

    Registers ``TorchCustomModel`` in the model catalog under the name
    ``"my_model"`` and then instantiates ``dqn.ApexTrainer`` with a fixed
    configuration.

    Returns:
        tuple: ``(trainer, config)`` - the constructed trainer and the
        configuration dict it was built from.
    """
    ModelCatalog.register_custom_model("my_model", TorchCustomModel)

    # Model section: the custom model registered above plus its
    # fully-connected layer settings.
    model_settings = dict(
        custom_model="my_model",
        fcnet_hiddens=[64, 64],
        fcnet_activation="relu",
    )
    replay_settings = dict(num_replay_buffer_shards=1)

    config = dict(
        env=StoppingCar,
        model=model_settings,
        n_step=1,
        lr=0.0005,
        grad_clip=2500,
        batch_mode="complete_episodes",
        num_workers=7,  # parallelism
        num_envs_per_worker=10,
        train_batch_size=512,
        hiddens=[32],
        framework="torch",
        optimizer=replay_settings,
        horizon=1000,
    )

    apex_trainer = dqn.ApexTrainer(config=config)
    return apex_trainer, config
Example #2
0
def get_rl_agent(agent_name, config, env_to_agent):
    """Instantiate the RLlib trainer that corresponds to ``agent_name``.

    The RLlib module backing each trainer is imported lazily, only for the
    branch that is actually selected.

    Args:
        agent_name: One of the agent-name constants compared against below
            (``A2C``, ``A3C``, ``BC``, ``DQN``, ``APEX_DQN``, ``IMPALA``,
            ``MARWIL``, ``PG``, ``PPO``, ``APPO``, ``SAC``, ``LIN_UCB``,
            ``LIN_TS``).
        config: Passed to the trainer as ``config``.
        env_to_agent: Passed to the trainer as ``env``.

    Returns:
        The constructed trainer instance.

    Raises:
        Exception: If ``agent_name`` matches none of the known constants.
    """
    # Each branch only picks the trainer class; construction happens once at
    # the end so the call signature lives in a single place.
    if agent_name == A2C:
        import ray.rllib.agents.a3c as a3c_module
        trainer_cls = a3c_module.A2CTrainer
    elif agent_name == A3C:
        import ray.rllib.agents.a3c as a3c_module
        trainer_cls = a3c_module.A3CTrainer
    elif agent_name == BC:
        import ray.rllib.agents.marwil as marwil_module
        trainer_cls = marwil_module.BCTrainer
    elif agent_name == DQN:
        import ray.rllib.agents.dqn as dqn_module
        trainer_cls = dqn_module.DQNTrainer
    elif agent_name == APEX_DQN:
        import ray.rllib.agents.dqn as dqn_module
        trainer_cls = dqn_module.ApexTrainer
    elif agent_name == IMPALA:
        import ray.rllib.agents.impala as impala_module
        trainer_cls = impala_module.ImpalaTrainer
    elif agent_name == MARWIL:
        import ray.rllib.agents.marwil as marwil_module
        trainer_cls = marwil_module.MARWILTrainer
    elif agent_name == PG:
        import ray.rllib.agents.pg as pg_module
        trainer_cls = pg_module.PGTrainer
    elif agent_name == PPO:
        import ray.rllib.agents.ppo as ppo_module
        trainer_cls = ppo_module.PPOTrainer
    elif agent_name == APPO:
        import ray.rllib.agents.ppo as ppo_module
        trainer_cls = ppo_module.APPOTrainer
    elif agent_name == SAC:
        import ray.rllib.agents.sac as sac_module
        trainer_cls = sac_module.SACTrainer
    elif agent_name == LIN_UCB:
        import ray.rllib.contrib.bandits.agents.lin_ucb as lin_ucb_module
        trainer_cls = lin_ucb_module.LinUCBTrainer
    elif agent_name == LIN_TS:
        import ray.rllib.contrib.bandits.agents.lin_ts as lin_ts_module
        trainer_cls = lin_ts_module.LinTSTrainer
    else:
        raise Exception("Not valid agent name")

    return trainer_cls(config=config, env=env_to_agent)
def get_rllib_agent(agent_name, env_name, env, env_to_agent):
    """Instantiate the RLlib trainer that corresponds to ``agent_name``.

    The trainer config comes from ``get_config(env_name, env, 1)`` when
    ``is_rllib_agent(agent_name)`` is true, and is an empty dict otherwise.
    The RLlib module backing each trainer is imported lazily, only for the
    branch that is actually selected.

    Args:
        agent_name: One of the ``RLLIB_*`` agent-name constants compared
            against below.
        env_name: Forwarded to ``get_config``.
        env: Forwarded to ``get_config``.
        env_to_agent: Passed to the trainer as ``env``.

    Returns:
        The constructed trainer instance.

    Raises:
        ValueError: If ``agent_name`` matches none of the ``RLLIB_*``
            constants. (Previously this fell through to ``return agent``
            with ``agent`` unbound and surfaced as ``UnboundLocalError``.)
    """
    config = get_config(env_name, env, 1) if is_rllib_agent(agent_name) else {}
    if agent_name == RLLIB_A2C:
        import ray.rllib.agents.a3c as a2c
        agent = a2c.A2CTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_A3C:
        import ray.rllib.agents.a3c as a3c
        agent = a3c.A3CTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_BC:
        import ray.rllib.agents.marwil as bc
        agent = bc.BCTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.DQNTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_APEX_DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.ApexTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_IMPALA:
        import ray.rllib.agents.impala as impala
        agent = impala.ImpalaTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_MARWIL:
        import ray.rllib.agents.marwil as marwil
        agent = marwil.MARWILTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_PG:
        import ray.rllib.agents.pg as pg
        agent = pg.PGTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_PPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.PPOTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_APPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.APPOTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_SAC:
        import ray.rllib.agents.sac as sac
        agent = sac.SACTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_LIN_UCB:
        import ray.rllib.contrib.bandits.agents.lin_ucb as lin_ucb
        agent = lin_ucb.LinUCBTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_LIN_TS:
        import ray.rllib.contrib.bandits.agents.lin_ts as lin_ts
        agent = lin_ts.LinTSTrainer(config=config, env=env_to_agent)
    else:
        # Fail with a clear message instead of an UnboundLocalError on the
        # return statement below.
        raise ValueError(
            "Not valid agent name: {!r}".format(agent_name))
    return agent
Example #4
0
    # Tail of a loop whose header lies outside this excerpt: echo the matched
    # checkpoint number and collect it as an int.
    print(m[0])
    checkpoint_numbers.append(int(m[0]))

# Pick the highest-numbered checkpoint and build its path under
# path_to_results. max() raises ValueError if no checkpoint was collected.
mc = max(checkpoint_numbers)
checkpoint_path = path_to_results+"/"+"checkpoint_{}/checkpoint-{}".format(mc,mc)
print("found {} checkpoints".format(len(checkpoint_numbers)))
print("restoring "+checkpoint_path)

# ============================================================== #
# evaluation {{{
# ============================================================== #
#ray.init()
ray.init(temp_dir=tmpdir+"/ray")  # you may need to change the temp directory in case it runs on a cluster or shared machine

# Rebuild the same trainer type that produced the checkpoint: the Ape-X
# trainer when the config names "AsyncReplayOptimizer", plain DQN otherwise.
if config["optimizer_class"] == "AsyncReplayOptimizer":
    trainer = dqn.ApexTrainer(config=config, env=CodeEnv)
else:
    trainer = dqn.DQNTrainer(config=config, env=CodeEnv)
trainer.restore(checkpoint_path)
# Separate environment instance driven directly by this script.
env = CodeEnv(env_config)
n = env.n  # NOTE(review): presumably the code length - confirm against CodeEnv

# One zero-initialised integer counter per entry of dB_range.
# NOTE(review): names suggest bit errors, codeword errors, and total
# codewords / bits simulated - confirm against the loop body.
dB_len = len(dB_range)
BitErr = np.zeros([dB_len], dtype=int)
CwErr = np.zeros([dB_len], dtype=int)
totCw = np.zeros([dB_len], dtype=int)
totBit = np.zeros([dB_len], dtype=int)

# Sweep over every value in dB_range, configuring the environment for each
# one (the rest of the loop body is outside this excerpt).
for i in range(dB_len):
    print("\n--------\nSimulating EbNo = {} dB".format(dB_range[i]))
    env.set_EbNo_dB(dB_range[i])
Example #5
0
    def restart_ray():
        """Train the Mario Ape-X DQN agent forever, rebuilding Ray on failure.

        Every cycle:

        1. looks up the newest checkpoint under ``checkpoint_all``,
        2. registers ``MarioEnv`` under ``c.ENV_NAME`` and initialises Ray,
        3. builds a fresh ``dqn.ApexTrainer`` (restoring the checkpoint when
           one exists and ``c.LOAD_CHECKPOINT`` is set),
        4. optionally starts the evaluation thread (``c.EVALUATE``),
        5. trains in a loop, saving a checkpoint whenever
           ``save_counter % save_interval == 1``.

        When training fails, the cycle is torn down (``_ray_error`` set, the
        evaluation thread joined, Ray shut down) and a new cycle begins.

        The restart is an explicit loop rather than a recursive call from
        ``finally``: recursion added one stack frame per restart (eventually
        hitting ``RecursionError``) and also swallowed ``KeyboardInterrupt``
        and ``SystemExit``, so the process could not be stopped cleanly.
        Those two now propagate to the caller after cleanup; any other
        ``Exception`` still triggers a restart, as before.

        Errors raised while setting a cycle up (before training starts)
        propagate to the caller, as they did originally.
        """
        global _ray_error
        global save_interval
        global save_counter

        while True:
            _ray_error = False

            # Resume from the newest checkpoint, if there is one.
            largest = level_state.find_latest_checkpoint(checkpoint_all)
            latest_checkpoint = None
            if largest > -1:
                latest_checkpoint = os.path.join(checkpoint_all,
                                                 f"checkpoint_{largest}",
                                                 f"checkpoint-{largest}")

            register_env(c.ENV_NAME, lambda config: MarioEnv(config))

            ray_init(ignore_reinit_error=True)

            trainer = dqn.ApexTrainer(
                env=c.ENV_NAME,
                config={
                    "num_gpus": 1,
                    "num_workers": 1,
                    "eager": False,
                    "model": {
                        "conv_filters": [
                            [c.OBS_FRAMES, c.OBS_SIZE, c.OBS_SIZE]
                        ]
                    },
                    "env_config": dict(actions=COMPLEX_MOVEMENT,
                                       window=False,
                                       fps=60_000),
                })
            if latest_checkpoint and c.LOAD_CHECKPOINT:
                trainer.restore(latest_checkpoint)
                print("Loaded Mario checkpoint:", latest_checkpoint)

            if c.EVALUATE:
                eval_thread = Thread(target=test, args=(trainer, ))
                eval_thread.daemon = True
                eval_thread.start()

            try:
                while True:
                    trainer.train()
                    # NOTE(review): with save_interval == 1 this condition is
                    # never true (x % 1 is always 0), so nothing would ever
                    # be saved - confirm the intended interval semantics.
                    if save_counter % save_interval == 1:
                        checkpoint = trainer.save(checkpoint_all)
                        print("Saved Mario checkpoint:", checkpoint)

                    save_counter += 1

            except RayOutOfMemoryError:
                # Memory diagnostics (e.g. objgraph) can be hooked in here.
                print("Ray out of memory!")
            except Exception as err:
                # Any other training failure also leads to a restart, but is
                # reported instead of being swallowed silently.
                print("Ray training failed, restarting:", repr(err))
            finally:
                # Runs for every exit path, including KeyboardInterrupt and
                # SystemExit, which continue propagating afterwards.
                # NOTE(review): _ray_error is presumably polled by the
                # evaluation thread as its stop signal - confirm in `test`.
                _ray_error = True
                if c.EVALUATE:
                    eval_thread.join()
                ray_shutdown()