Example #1
def cli_td3_train(environment, num_epochs, num_episodes, num_envs, num_evals,
                  num_cpus, gamma, tau, batch_size, replay_buffer,
                  reward_scale, policy_delay, random_steps, action_noise,
                  obs_normalizer, obs_clip, render, load, save, seed):
    """Trains a TD3 agent on an OpenAI's gym environment."""
    trainer = pyrl.trainer.AgentTrainer(agent_cls=pyrl.agents.TD3,
                                        env_name=environment,
                                        seed=seed,
                                        num_envs=num_envs,
                                        num_cpus=num_cpus,
                                        root_log_dir=os.path.join(save, "log"))
    pyrl.cli.util.initialize_seed(seed)
    trainer.env.seed(seed)

    if load:
        _LOG.info("Loading agent")
        trainer.initialize_agent(agent_path=load)
    else:
        _LOG.info("Initializing new agent")
        trainer.initialize_agent(
            agent_kwargs=dict(gamma=gamma,
                              tau=tau,
                              batch_size=batch_size,
                              reward_scale=reward_scale,
                              replay_buffer_size=replay_buffer,
                              policy_delay=policy_delay,
                              random_steps=random_steps,
                              actor_lr=1e-3,
                              critic_lr=1e-3,
                              observation_normalizer=obs_normalizer,
                              observation_clip=obs_clip,
                              action_noise=action_noise))

    _LOG.info("Agent Data")
    _LOG.info("  = Train steps: %d", trainer.agent.num_train_steps)
    _LOG.info("  = Replay buffer: %d", len(trainer.agent.replay_buffer))
    _LOG.info("    = Max. Size: %d", trainer.agent.replay_buffer.max_size)

    _LOG.debug("Actor network\n%s", str(trainer.agent.actor))
    _LOG.debug("Critic 1 network\n%s", str(trainer.agent.critic_1))
    _LOG.debug("Critic 2 network\n%s", str(trainer.agent.critic_2))

    _LOG.info("Action space: %s", str(trainer.env.action_space))
    _LOG.info("Observation space: %s", str(trainer.env.observation_space))

    if render:  # Some environments must be rendered
        trainer.env.render()  # before running

    with trainer:
        _run_train(trainer, num_epochs, num_episodes, num_evals, save)

    sys.exit(0)
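
The parameter list maps one-to-one onto the command-line options, so the same entry point can also be exercised directly from Python. The call below is only a sketch: the keyword names come from the signature above, but the concrete values (environment id, noise level, buffer size, output path) are illustrative assumptions, and the function calls sys.exit(0) once training finishes.

# Sketch of a direct call; keyword names are real, values are placeholders.
cli_td3_train(environment="Pendulum-v1",
              num_epochs=50, num_episodes=20, num_envs=1, num_evals=10,
              num_cpus=1, gamma=0.99, tau=0.005, batch_size=128,
              replay_buffer=1000000, reward_scale=1.0, policy_delay=2,
              random_steps=1000, action_noise=0.1,
              obs_normalizer=None, obs_clip=5.0,
              render=False, load=None, save="out/td3", seed=1234)
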
Example #2
File: her_sac.py  Project: jponf/pyrl
def cli_her_sac_train(environment, num_epochs, num_cycles, num_episodes,
                      num_envs, num_evals, num_cpus, demo_path, reward_scale,
                      replay_buffer, random_steps, replay_k, q_filter,
                      obs_normalizer, obs_clip, render, load, save, seed):
    """Trains a HER + SAC agent on an OpenAI's gym environment."""
    trainer = pyrl.trainer.AgentTrainer(agent_cls=pyrl.agents.HerSAC,
                                        env_name=environment,
                                        seed=seed,
                                        num_envs=num_envs,
                                        num_cpus=num_cpus,
                                        root_log_dir=os.path.join(save, "log"))

    pyrl.cli.util.initialize_seed(seed)
    trainer.env.seed(seed)

    if load:
        _LOG.info("Save path already exists, loading previously trained agent")
        trainer.initialize_agent(agent_path=load, demo_path=demo_path)
    else:
        _LOG.info("Initializing new agent")
        env = trainer.env
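        # gamma = 1 - 1/T gives an effective discount horizon of T = max_episode_steps.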
        agent_kwargs = dict(gamma=1.0 - 1.0 / env.spec.max_episode_steps,
                            tau=0.005,
                            batch_size=128,
                            reward_scale=reward_scale,
                            replay_buffer_episodes=int(
                                math.ceil(replay_buffer /
                                          env.spec.max_episode_steps)),
                            replay_buffer_steps=env.spec.max_episode_steps,
                            random_steps=random_steps,
                            replay_k=replay_k,
                            demo_batch_size=128,
                            q_filter=q_filter,
                            actor_lr=1e-3,
                            critic_lr=1e-3,
                            observation_normalizer=obs_normalizer,
                            observation_clip=obs_clip)
        trainer.initialize_agent(agent_kwargs=agent_kwargs,
                                 demo_path=demo_path)

    agent = trainer.agent
    _LOG.info("Agent Data")
    _LOG.info("  = Train steps: %d", trainer.agent.num_train_steps)
    _LOG.info("  = Replay buffer")
    _LOG.info("    = Episodes: %d", agent.replay_buffer.num_episodes)
    _LOG.info("        = Max: %d", agent.replay_buffer.max_episodes)
    _LOG.info("    = Steps: %d", agent.replay_buffer.count_steps())
    _LOG.info("        = Max: %d", agent.replay_buffer.max_steps)

    _LOG.debug("Actor network\n%s", str(agent.actor))
    _LOG.debug("Critic 1 network\n%s", str(agent.critic_1))
    _LOG.debug("Critic 2 network\n%s", str(agent.critic_2))

    _LOG.info("Action space: %s", str(trainer.env.action_space))
    _LOG.info("Observation space: %s", str(trainer.env.observation_space))

    if render:  # Some environments must be rendered
        trainer.env.render()  # before running

    with trainer:
        _run_train(trainer, num_epochs, num_cycles, num_episodes, num_evals,
                   save)

    sys.exit(0)
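
Unlike the TD3 example, the HER variants derive gamma and the replay-buffer capacity from the environment's episode length instead of taking them as options. The snippet below is a standalone reproduction of that arithmetic for a hypothetical task with max_episode_steps = 50 and a requested capacity of one million transitions:

import math

max_episode_steps = 50        # hypothetical env.spec.max_episode_steps
replay_buffer = 1000000       # requested capacity, in transitions

gamma = 1.0 - 1.0 / max_episode_steps                          # 0.98
episodes = int(math.ceil(replay_buffer / max_episode_steps))   # 20000 episodes
print(gamma, episodes)        # each stored episode holds up to 50 steps
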
Example #3
def cli_her_ddpg_train(environment, num_epochs, num_cycles, num_episodes,
                       num_envs, num_evals, num_cpus, demo_path, eps_greedy,
                       reward_scale, replay_buffer, replay_k, q_filter,
                       action_noise, obs_normalizer, obs_clip, render, load,
                       save, seed):
    """Trains a HER + DDPG agent on an OpenAI's gym environment."""
    trainer = pyrl.trainer.AgentTrainer(agent_cls=pyrl.agents.HerDDPG,
                                        env_name=environment,
                                        seed=seed,
                                        num_envs=num_envs,
                                        num_cpus=num_cpus,
                                        root_log_dir=os.path.join(save, "log"))

    pyrl.cli.util.initialize_seed(seed)
    trainer.env.seed(seed)

    if load:
        _LOG.info("Loading previously trained agent")
        trainer.initialize_agent(agent_path=load, demo_path=demo_path)
    else:
        _LOG.info("Initializing new agent")
        env = trainer.env
        agent_kwargs = dict(eps_greedy=eps_greedy,
                            gamma=1.0 - 1.0 / env.spec.max_episode_steps,
                            tau=0.005,
                            batch_size=128,
                            reward_scale=reward_scale,
                            replay_buffer_episodes=int(
                                math.ceil(replay_buffer /
                                          env.spec.max_episode_steps)),
                            replay_buffer_steps=env.spec.max_episode_steps,
                            replay_k=replay_k,
                            demo_batch_size=128,
                            q_filter=q_filter,
                            actor_lr=3e-4,
                            critic_lr=3e-4,
                            observation_normalizer=obs_normalizer,
                            observation_clip=obs_clip,
                            action_noise=action_noise)
        trainer.initialize_agent(agent_kwargs=agent_kwargs,
                                 demo_path=demo_path)

    agent = trainer.agent
    _LOG.info("Agent Data")
    _LOG.info("  = Train steps: %d", trainer.agent.num_train_steps)
    _LOG.info("  = Replay buffer")
    _LOG.info("    = Episodes: %d", agent.replay_buffer.num_episodes)
    _LOG.info("        = Max: %d", agent.replay_buffer.max_episodes)
    _LOG.info("    = Steps: %d", agent.replay_buffer.count_steps())
    _LOG.info("        = Max: %d", agent.replay_buffer.max_steps)

    _LOG.debug("Actor network\n%s", str(agent.actor))
    _LOG.debug("Critic network\n%s", str(agent.critic))

    _LOG.info("Action space: %s", str(trainer.env.action_space))
    _LOG.info("Observation space: %s", str(trainer.env.observation_space))

    with trainer:
        _run_train(trainer, num_epochs, num_cycles, num_episodes, num_evals,
                   save)

    sys.exit(0)
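
All three entry points end by delegating to a private _run_train helper whose body is not shown here; for the HER agents it receives the epoch, cycle, episode and evaluation counts plus the save path. The following is only a generic sketch of how such a nested loop is commonly organized, not pyrl's implementation: run_episode, evaluate and save_agent are invented placeholder names for whatever the real trainer exposes.

def _run_train_sketch(trainer, num_epochs, num_cycles, num_episodes,
                      num_evals, save_path):
    """Generic epoch/cycle/episode loop; NOT pyrl's _run_train."""
    for epoch in range(num_epochs):
        for _ in range(num_cycles):
            for _ in range(num_episodes):
                trainer.run_episode()                 # hypothetical API
        returns = [trainer.evaluate() for _ in range(num_evals)]  # hypothetical
        trainer.save_agent(save_path)                 # hypothetical
        print("epoch", epoch, "mean eval return", sum(returns) / len(returns))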