Example 1
def test_gym_env(self):
    """
    Tests that the Gym wrapper and the GymEnv function work
    """
    env = VectorEnv("Pendulum-v0", env_type="gym")
    env.reset()
    env.step(env.sample())
    env.close()
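
Note: all of the snippets on this page omit their imports. Below is a minimal sketch of the common ones, assuming the GenRL library these tests appear to come from; GenRL moved its module layout between releases (newer versions expose genrl.agents, genrl.trainers and genrl.utils), so the exact paths are an assumption:

import shutil

# Assumed layout (older GenRL releases); newer releases moved agents to
# genrl.agents, trainers to genrl.trainers and helpers to genrl.utils.
from genrl import A2C, DDPG, DQN, PPO1, SAC, TD3, VPG
from genrl.deep.common import OffPolicyTrainer, OnPolicyTrainer
from genrl.deep.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from genrl.deep.common.utils import get_env_properties
from genrl.environments import VectorEnv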
Example 2
def main(args):
    ALGOS = {
        "sac": SAC,
        "a2c": A2C,
        "ppo": PPO1,
        "ddpg": DDPG,
        "td3": TD3,
        "vpg": VPG,
        "dqn": DQN,
    }

    algo = ALGOS[args.algo.lower()]
    env = VectorEnv(args.env,
                    n_envs=args.n_envs,
                    parallel=not args.serial,
                    env_type=args.env_type)

    logger = get_logger(args.log)
    trainer = None

    if args.algo.lower() in ["ppo", "vpg", "a2c"]:
        agent = algo(
            args.arch, env,
            rollout_size=args.rollout_size)  # , batch_size=args.batch_size)
        trainer = OnPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            log_interval=args.log_interval,
        )

    else:
        agent = algo(args.arch,
                     env,
                     replay_size=args.replay_size,
                     batch_size=args.batch_size)
        trainer = OffPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            warmup_steps=args.warmup_steps,
            log_interval=args.log_interval,
        )

    trainer.train()
    trainer.evaluate()
    env.render()
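
The main() above reads many attributes off args, but the project's real CLI is not shown. A hypothetical argparse wiring that supplies every attribute the function touches (flag names and defaults are assumptions inferred from the usage, not GenRL's actual CLI):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--algo", default="ppo", help="key into the ALGOS dict")
parser.add_argument("--env", default="CartPole-v0")
parser.add_argument("--env-type", default="gym")
parser.add_argument("--n-envs", type=int, default=2)
parser.add_argument("--serial", action="store_true", help="disable parallel envs")
parser.add_argument("--arch", default="mlp")
parser.add_argument("--log", default="stdout")
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--render", action="store_true")
parser.add_argument("--log-interval", type=int, default=10)
parser.add_argument("--rollout-size", type=int, default=2048)
parser.add_argument("--replay-size", type=int, default=100000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--warmup-steps", type=int, default=500)

main(parser.parse_args())  # argparse maps --env-type to args.env_type, etc.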
Example 3
def main(args):
    env = VectorEnv(
        args.env, n_envs=args.n_envs, parallel=not args.serial, env_type=args.env_type
    )

    input_dim, action_dim, discrete, action_lim = get_env_properties(env, "mlp")

    network = MlpActorCritic(
        input_dim,
        action_dim,
        (1, 1),  # policy network layers
        (1, 1),  # value network layers
        "V",  # type of value function
        discrete,
        action_lim=action_lim,
        activation="relu",
    )
    
    generic_agent = A2C(network, env, rollout_size=args.rollout_size)

    agent_parameter_choices = {
        "gamma": [12, 121],
        # 'clip_param': [0.2, 0.3],
        # 'lr_policy': [0.001, 0.002],
        # 'lr_value': [0.001, 0.002]
    }

    generate(
        args.generations,
        args.population,
        agent_parameter_choices,
        env,
        generic_agent,
        args,
    )
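
The generate() call above evidently runs an evolutionary search over agent_parameter_choices for a number of generations with a fixed population. GenRL's actual implementation is not shown here; purely as an illustration of the technique, a generic select-and-resample loop looks like this (evolve and evaluate are hypothetical names):

import random

def evolve(generations, population, choices, evaluate):
    # Start from random draws out of the per-parameter candidate lists.
    pop = [{k: random.choice(v) for k, v in choices.items()}
           for _ in range(population)]
    for _ in range(generations):
        # Keep the better half, re-sample the rest.
        pop.sort(key=evaluate, reverse=True)
        elite = pop[: max(1, population // 2)]
        fresh = [{k: random.choice(v) for k, v in choices.items()}
                 for _ in range(population - len(elite))]
        pop = elite + fresh
    return max(pop, key=evaluate)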
Example 4
def test_ddpg():
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG("mlp", env, noise=NormalActionNoise, layers=[1, 1])

    trainer = OffPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example 5
def test_off_policy_trainer():
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG("mlp", env, replay_size=100)
    trainer = OffPolicyTrainer(algo, env, ["stdout"], epochs=1, evaluate_episodes=2)
    assert trainer.off_policy
    trainer.train()
    trainer.evaluate()
Example 6
def test_sac_shared():
    env = VectorEnv("Pendulum-v0", 2)
    algo = SAC(
        "mlp",
        env,
        batch_size=5,
        shared_layers=[1, 1],
        policy_layers=[1, 1],
        value_layers=[1, 1],
        max_timesteps=100,
    )

    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=5,
        max_ep_len=500,
        warmup_steps=10,
        start_update=10,
        max_timesteps=100,
    )
    trainer.train()
    shutil.rmtree("./logs")
Example 7
def test_on_policy_trainer():
    env = VectorEnv("CartPole-v1", 2)
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(algo, env, ["stdout"], epochs=1, evaluate_episodes=2)
    assert not trainer.off_policy
    trainer.train()
    trainer.evaluate()
Example 8
def test_sac():
    env = VectorEnv("Pendulum-v0", 2)
    algo = SAC("mlp", env, layers=[1, 1])

    trainer = OffPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example 9
def test_ppo1(self):
    env = VectorEnv("CartPole-v0")
    algo = PPO1("mlp", env, rollout_size=128)
    trainer = OnPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1
    )
    trainer.train()
    shutil.rmtree("./logs")
Example 10
def test_ppo1_cnn(self):
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = PPO1("cnn", env, rollout_size=128)
    trainer = OnPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1
    )
    trainer.train()
    shutil.rmtree("./logs")
Example 11
    def test_get_env_properties(self):
        """
        test getting environment properties
        """
        env = VectorEnv("CartPole-v0", 1)

        state_dim, action_dim, discrete, _ = get_env_properties(env)
        assert state_dim == 4
        assert action_dim == 2
        assert discrete is True

        env = VectorEnv("Pendulum-v0", 1)

        state_dim, action_dim, discrete, action_lim = get_env_properties(env)
        assert state_dim == 3
        assert action_dim == 1
        assert discrete is False
        assert action_lim == 2.0
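
The expected values in this test line up with the raw Gym spaces, which is easy to confirm with plain gym (pre-v0.26 API):

import gym

cartpole = gym.make("CartPole-v0")
print(cartpole.observation_space.shape)  # (4,)  -> state_dim == 4
print(cartpole.action_space.n)           # 2     -> action_dim == 2, discrete

pendulum = gym.make("Pendulum-v0")
print(pendulum.observation_space.shape)  # (3,)  -> state_dim == 3
print(pendulum.action_space.shape)       # (1,)  -> action_dim == 1, continuous
print(pendulum.action_space.high)        # [2.]  -> action_lim == 2.0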
Example 12
    def test_vpg(self):
        env = VectorEnv("CartPole-v0", 2)
        algo = VPG("mlp", env, layers=[1, 1])

        trainer = OnPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
        )
        trainer.train()
        trainer.evaluate()
        shutil.rmtree("./logs")
Example 13
    def test_atari_env(self):
        """
        Tests that the Atari wrappers and the AtariEnv function work
        """
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = DQN("cnn", env)

        trainer = OffPolicyTrainer(algo, env, epochs=1, steps_per_epoch=200)
        trainer.train()
        shutil.rmtree("./logs")
Example 14
def test_vpg_cnn():
    env = VectorEnv("Pong-v0", 1, env_type="atari")
    algo = VPG("cnn", env)
    trainer = OnPolicyTrainer(algo,
                              env,
                              log_mode=["csv"],
                              logdir="./logs",
                              epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example 15
def test_vpg():
    env = VectorEnv("CartPole-v0", 1)
    algo = VPG("mlp", env)
    trainer = OnPolicyTrainer(algo,
                              env,
                              log_mode=["csv"],
                              logdir="./logs",
                              epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example 16
    def test_td3(self):
        env = VectorEnv("Pendulum-v0", 2)
        algo = TD3("mlp", env, noise=OrnsteinUhlenbeckActionNoise, layers=[1, 1])

        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
        )
        trainer.train()
        trainer.evaluate()
        shutil.rmtree("./logs")
Example 17
def test_a2c_shared():
    env = VectorEnv("CartPole-v0", 1)
    algo = A2C("mlp", env, shared_layers=(32, 32), rollout_size=128)
    trainer = OnPolicyTrainer(algo,
                              env,
                              log_mode=["csv"],
                              logdir="./logs",
                              epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example 18
def test_a2c_discrete(self):
    env = VectorEnv("CartPole-v0", 1)
    algo = A2C("mlp", env, rollout_size=128)
    trainer = OnPolicyTrainer(algo,
                              env,
                              log_mode=["csv"],
                              logdir="./logs",
                              epochs=1)
    trainer.train()
    trainer.evaluate()
    shutil.rmtree("./logs")
Example 19
def test_on_policy_trainer():
    env = VectorEnv("CartPole-v1", 2)
    algo = PPO1("mlp", env, rollout_size=128)
    trainer = OnPolicyTrainer(algo,
                              env, ["stdout"],
                              epochs=2,
                              evaluate_episodes=2,
                              max_timesteps=300)
    assert not trainer.off_policy
    trainer.train()
    trainer.evaluate()
Example 20
def test_dqn():
    env = VectorEnv("CartPole-v0", 2)
    # Vanilla DQN
    algo = DQN("mlp", env)

    trainer = OffPolicyTrainer(algo,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

    # Double DQN with prioritized replay buffer
    # algo1 = DQN("mlp", env, double_dqn=True, prioritized_replay=True)

    # trainer = OffPolicyTrainer(algo1, env, log_mode=["csv"], logdir="./logs", epochs=1)
    # trainer.train()
    # shutil.rmtree("./logs")

    # Noisy DQN
    algo2 = DQN("mlp", env, noisy_dqn=True)

    trainer = OffPolicyTrainer(algo2,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

    # Dueling DDQN
    algo3 = DQN("mlp", env, dueling_dqn=True, double_dqn=True)

    trainer = OffPolicyTrainer(algo3,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

    # Categorical DQN
    algo4 = DQN("mlp", env, categorical_dqn=True)

    trainer = OffPolicyTrainer(algo4,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
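
The variant flags above compose with each other, so the whole family can be driven from one table. A compact restatement of the configurations this test exercises, reusing env from above (the prioritized-replay combination stays commented out, as in the original):

variants = [
    dict(),                                   # vanilla DQN
    dict(noisy_dqn=True),                     # Noisy DQN
    dict(dueling_dqn=True, double_dqn=True),  # Dueling Double DQN
    dict(categorical_dqn=True),               # Categorical DQN
    # dict(double_dqn=True, prioritized_replay=True),  # commented out above
]
for kwargs in variants:
    algo = DQN("mlp", env, **kwargs)
    trainer = OffPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")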
Example 21
    def test_load_params(self):
        """
        test loading algorithm parameters
        """
        env = VectorEnv("CartPole-v0", 1)
        algo = PPO1(
            "mlp",
            env,
            epochs=1,
            load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
        )

        rmtree("logs")
Example 22
def test_double_dqn_cnn():
    env = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

    # Double DQN with prioritized replay buffer
    algo = DQN("cnn", env, double_dqn=True, prioritized_replay=True)

    trainer = OffPolicyTrainer(algo,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1,
                               steps_per_epoch=200)
    trainer.train()
    shutil.rmtree("./logs")
Example 23
def test_categorical_dqn_cnn():
    env = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

    # Categorical DQN
    algo = DQN("cnn", env, categorical_dqn=True)

    trainer = OffPolicyTrainer(algo,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1,
                               steps_per_epoch=200)
    trainer.train()
    shutil.rmtree("./logs")
Example 24
    def test_save_params(self):
        """
        test saving algorithm state dict
        """
        env = VectorEnv("CartPole-v0", 1)
        algo = PPO1("mlp", env)
        trainer = OnPolicyTrainer(algo,
                                  env, ["stdout"],
                                  save_model="test_ckpt",
                                  save_interval=1,
                                  epochs=1)
        trainer.train()

        assert len(os.listdir("test_ckpt/PPO1_CartPole-v0")) != 0
Example 25
    def test_load_params(self):
        """
        test loading algorithm parameters
        """
        env = VectorEnv("CartPole-v0", 1)
        algo = PPO1("mlp", env)
        trainer = OnPolicyTrainer(
            algo,
            env,
            epochs=0,
            load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt")
        trainer.train()

        rmtree("logs")
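
These two checkpoint tests are a pair: test_save_params writes checkpoints under save_model="test_ckpt", and test_load_params reads back the file the first run produced. A quick sanity check of that contract, with the path copied verbatim from the snippets:

import os

# Run test_save_params first; test_load_params then expects this file.
assert os.path.exists("test_ckpt/PPO1_CartPole-v0/0-log-0.pt")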
Example 26
    def test_custom_ppo1(self):
        env = VectorEnv("CartPole-v0", 1)
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n
        actorcritic = custom_actorcritic(state_dim, action_dim)

        algo = PPO1(actorcritic, env)

        trainer = OnPolicyTrainer(algo,
                                  env,
                                  log_mode=["csv"],
                                  logdir="./logs",
                                  epochs=1)
        trainer.train()
        shutil.rmtree("./logs")
Example 27
    def test_custom_vpg(self):
        env = VectorEnv("CartPole-v0", 1)
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n
        policy = custom_policy(state_dim, action_dim)

        algo = VPG(policy, env)

        trainer = OnPolicyTrainer(algo,
                                  env,
                                  log_mode=["csv"],
                                  logdir="./logs",
                                  epochs=1)
        trainer.train()
        shutil.rmtree("./logs")
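
Both custom-network tests call helpers (custom_actorcritic, custom_policy) that are defined elsewhere in the test suite and not shown here. Purely as a hypothetical sketch of the shape such a helper might take, here is a plain torch MLP; the real helpers presumably extend GenRL's own policy and actor-critic base classes, which are not visible in these snippets:

import torch.nn as nn

def custom_policy(state_dim, action_dim, hidden=32):
    # Hypothetical stand-in: a small MLP from states to action logits.
    # The real test helper likely extends a GenRL policy class instead.
    return nn.Sequential(
        nn.Linear(state_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, action_dim),
    )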
Example 28
def test_off_policy_trainer(self):
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG("mlp", env, replay_size=100)
    trainer = OffPolicyTrainer(
        algo,
        env,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_ep_len=300,
        max_timesteps=300,
    )
    assert trainer.off_policy
    trainer.train()
    trainer.evaluate()
Example 29
def test_double_dqn(self):
    env = VectorEnv("CartPole-v0")
    algo = DoubleDQN("mlp", env, batch_size=5, replay_size=100)
    assert isinstance(algo.model, MlpValue)
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
    )
    trainer.train()
    shutil.rmtree("./logs")