def main(args):
    """Run evolutionary hyperparameter search for a template A2C agent.

    Builds a vectorized environment from the CLI arguments, constructs an
    MLP actor-critic network and a generic A2C agent, then hands the
    hyperparameter grid to ``generate`` for the evolutionary search.

    Args:
        args: Parsed CLI namespace providing ``env``, ``n_envs``, ``serial``,
            ``env_type``, ``rollout_size``, ``generations`` and ``population``.
    """
    env = VectorEnv(
        args.env,
        n_envs=args.n_envs,
        parallel=not args.serial,
        env_type=args.env_type,
    )
    input_dim, action_dim, discrete, action_lim = get_env_properties(env, "mlp")
    network = MlpActorCritic(
        input_dim,
        action_dim,
        (1, 1),  # policy hidden-layer sizes
        (1, 1),  # value hidden-layer sizes
        "V",  # type of value function
        discrete,
        action_lim=action_lim,
        activation="relu",
    )
    generic_agent = A2C(network, env, rollout_size=args.rollout_size)

    # gamma is a discount factor and must lie in (0, 1]; the previous
    # values (12, 121) would make the discounted return diverge.
    agent_parameter_choices = {
        "gamma": [0.95, 0.99],
        # 'clip_param': [0.2, 0.3],
        # 'lr_policy': [0.001, 0.002],
        # 'lr_value': [0.001, 0.002]
    }

    generate(
        args.generations,
        args.population,
        agent_parameter_choices,
        env,
        generic_agent,
        args,
    )
def test_a2c():
    """Smoke-test A2C with an MLP policy on CartPole for a single epoch."""
    vec_env = VectorEnv("CartPole-v0", 1)
    agent = A2C("mlp", vec_env, rollout_size=128)
    runner = OnPolicyTrainer(
        agent, vec_env, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_a2c_shared_discrete(self):
    """Smoke-test A2C with shared actor-critic layers on discrete CartPole."""
    vec_env = VectorEnv("CartPole-v0", 1)
    agent = A2C("mlp", vec_env, shared_layers=(32, 32), rollout_size=128)
    runner = OnPolicyTrainer(
        agent, vec_env, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_a2c_cnn():
    """Smoke-test A2C with a CNN policy on Atari Pong for a single epoch."""
    vec_env = VectorEnv("Pong-v0", 1, env_type="atari")
    agent = A2C("cnn", vec_env, rollout_size=128)
    runner = OnPolicyTrainer(
        agent, vec_env, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_a2c_continuous(self):
    """Smoke-test A2C on continuous-action Pendulum, including evaluation."""
    vec_env = VectorEnv("Pendulum-v0", 1)
    agent = A2C("mlp", vec_env, rollout_size=128)
    runner = OnPolicyTrainer(
        agent, vec_env, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    runner.evaluate()
    shutil.rmtree("./logs")