# Example #1
# 0
 def __init__(self, name, max_step=None):
     """Wrap a PettingZoo environment selected by name.

     Args:
         name: one of "simple_spread", "waterworld", "multiwalker".
         max_step: optional episode step cap; None means no explicit cap here.

     Raises:
         ValueError: if `name` is not one of the supported environments.
     """
     if name == "simple_spread":
         self.env = simple_spread_v2.env()
     elif name == "waterworld":
         self.env = waterworld_v3.env()
     elif name == "multiwalker":
         self.env = multiwalker_v6.env()
     else:
         # BUG FIX: the original line was `assert AssertionError, "wrong env name."`,
         # which asserts a truthy class object and therefore never fails —
         # an invalid name would fall through and crash later on the unset
         # `self.env`. Raise explicitly instead (asserts are also stripped
         # under `python -O`).
         raise ValueError("wrong env name.")
     self.max_step = max_step
     self.curr_step = 0
     self.name = name
     # possible_agents is the full roster of agents the env can ever contain.
     self.agents = self.env.possible_agents
     self.env.reset()
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from pettingzoo.sisl import waterworld_v3

# Based on code from github.com/parametersharingmadrl/parametersharingmadrl

# NOTE(review): this `tune.run(...)` call is truncated in this chunk — the
# config dict and its closing braces/parenthesis continue past the visible
# lines, so the code below is left byte-identical.
if __name__ == "__main__":
    # RDQN - Rainbow DQN
    # ADQN - Apex DQN

    # Register the PettingZoo waterworld env with RLlib under the id "waterworld".
    register_env("waterworld", lambda _: PettingZooEnv(waterworld_v3.env()))

    # Launch an APEX-DDPG training run, checkpointing every 10 iterations
    # and stopping after 60000 episodes in total.
    tune.run(
        "APEX_DDPG",
        stop={"episodes_total": 60000},
        checkpoint_freq=10,
        config={
            # Environment specific.
            "env": "waterworld",
            # General
            "num_gpus": 1,
            "num_workers": 2,
            "num_envs_per_worker": 8,
            "learning_starts": 1000,
            "buffer_size": int(1e5),
            "compress_observations": True,
            "rollout_fragment_length": 20,
            "train_batch_size": 512,
            "gamma": 0.99,
            "n_step": 3,
# Example #3
# 0
 def env_creator(args):
     """Create a fresh waterworld env wrapped for RLlib (config arg unused)."""
     raw_env = waterworld_v3.env()
     return PettingZooEnv(raw_env)