示例#1
0
def main(args):
    """Set up an environment and a template A2C agent, then call ``generate``.

    Args:
        args: Object exposing the attributes read below: ``env``, ``n_envs``,
            ``serial``, ``env_type``, ``rollout_size``, ``generations`` and
            ``population``. It is also forwarded unchanged to ``generate``.
    """
    vec_env = VectorEnv(
        args.env,
        n_envs=args.n_envs,
        parallel=not args.serial,  # --serial switches parallel stepping off
        env_type=args.env_type,
    )

    obs_dim, act_dim, is_discrete, act_limit = get_env_properties(vec_env, "mlp")

    # Both layer-spec arguments of the actor-critic receive the same sizes.
    layer_sizes = (1, 1)
    actor_critic = MlpActorCritic(
        obs_dim,
        act_dim,
        layer_sizes,
        layer_sizes,
        "V",  # type of value function
        is_discrete,
        action_lim=act_limit,
        activation="relu",
    )

    template_agent = A2C(actor_critic, vec_env, rollout_size=args.rollout_size)

    # Candidate values per agent parameter, handed to ``generate``.
    # NOTE(review): if "gamma" is a discount factor, 12 and 121 fall outside
    # the usual [0, 1] range -- confirm these candidates are intended.
    search_space = {"gamma": [12, 121]}

    generate(
        args.generations,
        args.population,
        search_space,
        vec_env,
        template_agent,
        args,
    )
示例#2
0
def test_a2c():
    """Smoke-test A2C built with ``"mlp"`` on ``CartPole-v0`` for one epoch.

    Trains through ``OnPolicyTrainer`` with CSV logging into ``./logs`` and
    removes that directory afterwards. The test passes if nothing raises.
    """
    env = VectorEnv("CartPole-v0", 1)
    algo = A2C("mlp", env, rollout_size=128)
    try:
        trainer = OnPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            epochs=1,
        )
        trainer.train()
    finally:
        # Clean up even when training fails so a stale ./logs directory does
        # not leak into other tests. ignore_errors keeps a cleanup problem
        # (e.g. the directory was never created) from masking the real error.
        shutil.rmtree("./logs", ignore_errors=True)
示例#3
0
 def test_a2c_shared_discrete(self):
     """Smoke-test A2C with ``shared_layers=(32, 32)`` on ``CartPole-v0``.

     Trains for one epoch through ``OnPolicyTrainer`` with CSV logging into
     ``./logs`` and removes that directory afterwards. The test passes if
     nothing raises.
     """
     env = VectorEnv("CartPole-v0", 1)
     algo = A2C("mlp", env, shared_layers=(32, 32), rollout_size=128)
     try:
         trainer = OnPolicyTrainer(
             algo,
             env,
             log_mode=["csv"],
             logdir="./logs",
             epochs=1,
         )
         trainer.train()
     finally:
         # Clean up even when training fails so a stale ./logs directory does
         # not leak into other tests. ignore_errors keeps a cleanup problem
         # (e.g. the directory was never created) from masking the real error.
         shutil.rmtree("./logs", ignore_errors=True)
示例#4
0
def test_a2c_cnn():
    """Smoke-test A2C built with ``"cnn"`` on ``Pong-v0`` (``env_type="atari"``).

    Trains for one epoch through ``OnPolicyTrainer`` with CSV logging into
    ``./logs`` and removes that directory afterwards. The test passes if
    nothing raises.
    """
    env = VectorEnv("Pong-v0", 1, env_type="atari")
    algo = A2C("cnn", env, rollout_size=128)
    try:
        trainer = OnPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            epochs=1,
        )
        trainer.train()
    finally:
        # Clean up even when training fails so a stale ./logs directory does
        # not leak into other tests. ignore_errors keeps a cleanup problem
        # (e.g. the directory was never created) from masking the real error.
        shutil.rmtree("./logs", ignore_errors=True)
示例#5
0
 def test_a2c_continuous(self):
     """Smoke-test A2C built with ``"mlp"`` on ``Pendulum-v0``.

     Trains for one epoch through ``OnPolicyTrainer`` with CSV logging into
     ``./logs``, then calls ``trainer.evaluate()``, and removes the log
     directory afterwards. The test passes if nothing raises.
     """
     env = VectorEnv("Pendulum-v0", 1)
     algo = A2C("mlp", env, rollout_size=128)
     try:
         trainer = OnPolicyTrainer(
             algo,
             env,
             log_mode=["csv"],
             logdir="./logs",
             epochs=1,
         )
         trainer.train()
         trainer.evaluate()
     finally:
         # Clean up even when training or evaluation fails so a stale ./logs
         # directory does not leak into other tests. ignore_errors keeps a
         # cleanup problem (e.g. the directory was never created) from
         # masking the real error.
         shutil.rmtree("./logs", ignore_errors=True)