Beispiel #1
0
 def __init__(self, env_name, policy_cls, actor_id, batch_size, logdir):
     """Build the environment, policy and rollout thread, then start.

     Args:
         env_name: Name handed to ``create_env`` to build the environment.
         policy_cls: Policy class; instantiated with the observation-space
             shape and the action space of the created environment.
         actor_id: Identifier stored on ``self.id``.
         batch_size: Forwarded unchanged to ``RunnerThread``.
         logdir: Directory path stored on ``self.logdir``.
     """
     self.id = actor_id
     self.logdir = logdir

     environment = create_env(env_name)
     self.env = environment

     # TODO(rliaw): should change this to be just env.observation_space
     observation_shape = environment.observation_space.shape
     self.policy = policy_cls(observation_shape, environment.action_space)

     self.runner = RunnerThread(environment, self.policy, batch_size)

     # Every attribute above is assigned before start() is invoked.
     self.start()
Beispiel #2
0
 def __init__(self, env_name, config):
     """Create the local LSTM policy and one remote ``Runner`` per worker.

     Args:
         env_name: Name handed to ``create_env`` and to every
             ``Runner.remote`` call.
         config: Mapping passed to ``Algorithm.__init__``; its
             ``"num_workers"`` entry sets how many runners are launched.
     """
     Algorithm.__init__(self, env_name, config)
     self.env = create_env(env_name)

     observation_shape = self.env.observation_space.shape
     action_count = self.env.action_space.n
     # The trailing 0 is passed to LSTMPolicy in the slot the runner class
     # fills with its actor id.
     self.policy = LSTMPolicy(observation_shape, action_count, 0)

     remote_runners = []
     for worker_index in range(config["num_workers"]):
         remote_runners.append(Runner.remote(env_name, worker_index))
     self.agents = remote_runners

     self.parameters = self.policy.get_weights()
     self.iteration = 0
Beispiel #3
0
 def __init__(self, env_name, actor_id, logdir="/tmp/ray/a3c/", start=True,
              batch_size=20):
     """Create an LSTM-policy rollout runner, optionally starting it.

     Args:
         env_name: Name handed to ``create_env`` to build the environment.
         actor_id: Identifier stored on ``self.id`` and forwarded to
             ``LSTMPolicy``.
         logdir: Directory path stored on ``self.logdir``.
         start: When true, ``self.start()`` is called at the end of
             construction.
         batch_size: Forwarded to ``RunnerThread``. Defaults to 20, the
             value that used to be hard-coded here, so existing callers
             behave the same.
     """
     env = create_env(env_name)
     self.id = actor_id
     num_actions = env.action_space.n
     self.policy = LSTMPolicy(env.observation_space.shape, num_actions,
                              actor_id)
     # Previously a literal 20; now configurable, matching the runner
     # variant that already accepts batch_size.
     self.runner = RunnerThread(env, self.policy, batch_size)
     self.env = env
     self.logdir = logdir
     if start:
         self.start()
Beispiel #4
0
Datei: a3c.py Projekt: xgong/ray
 def __init__(self, env_name, config, upload_dir=None):
     """Set up the A3C driver: local LSTM policy plus remote runners.

     Args:
         env_name: Name handed to ``create_env`` and to every
             ``Runner.remote`` call.
         config: Mapping of settings. A copy tagged with ``"alg": "A3C"``
             is passed to ``Algorithm.__init__``; the caller's object is
             left untouched. ``"num_workers"`` sets the runner count.
         upload_dir: Forwarded to ``Algorithm.__init__``.
     """
     # Shallow-copy before tagging: the previous config.update(...) mutated
     # the caller's dict, e.g. a shared default config reused across runs.
     config = dict(config, alg="A3C")
     Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
     self.env = create_env(env_name)
     self.policy = LSTMPolicy(self.env.observation_space.shape,
                              self.env.action_space.n, 0)
     # self.logdir is not assigned in this method; presumably it is set by
     # Algorithm.__init__ — verify.
     self.agents = [
         Runner.remote(env_name, i, self.logdir)
         for i in range(config["num_workers"])
     ]
     self.parameters = self.policy.get_weights()
     self.iteration = 0
Beispiel #5
0
 def __init__(self,
              env_name,
              config,
              policy_cls=SharedModelLSTM,
              upload_dir=None):
     """Set up the A3C agent: a local policy plus one remote runner per worker.

     Args:
         env_name: Name handed to ``create_env`` and to every
             ``Runner.remote`` call.
         config: Mapping of settings. A copy tagged with ``"alg": "A3C"``
             is passed to ``Agent.__init__``; the caller's object is left
             untouched. ``"batch_size"`` and ``"num_workers"`` are read here.
         policy_cls: Policy class, instantiated locally with the
             observation-space shape and action space, and forwarded to
             each runner.
         upload_dir: Forwarded to ``Agent.__init__``.
     """
     # Shallow-copy before tagging: the previous config.update(...) mutated
     # the caller's dict, e.g. a shared default config reused across runs.
     config = dict(config, alg="A3C")
     Agent.__init__(self, env_name, config, upload_dir=upload_dir)
     self.env = create_env(env_name)
     self.policy = policy_cls(self.env.observation_space.shape,
                              self.env.action_space)
     # self.logdir is not assigned in this method; presumably it is set by
     # Agent.__init__ — verify.
     self.agents = [
         Runner.remote(env_name, policy_cls, i, config["batch_size"],
                       self.logdir) for i in range(config["num_workers"])
     ]
     self.parameters = self.policy.get_weights()