Example #1
0
 def __init__(self, env_creator, policy_cls, actor_id, batch_size,
              preprocess_config, logdir):
     """Create the wrapped env, the policy and the runner thread, then start.

     Args:
         env_creator: Passed to ``create_and_wrap`` together with
             ``preprocess_config`` to build the environment.
         policy_cls: Policy class; instantiated with the env's observation
             shape and action space.
         actor_id: Stored as ``self.id``.
         batch_size: Forwarded to ``RunnerThread``.
         preprocess_config: Forwarded to ``create_and_wrap``.
         logdir: Stored as ``self.logdir``.
     """
     wrapped_env = create_and_wrap(env_creator, preprocess_config)
     self.id = actor_id
     # TODO(rliaw): should change this to be just env.observation_space
     obs_shape = wrapped_env.observation_space.shape
     self.policy = policy_cls(obs_shape, wrapped_env.action_space)
     self.runner = RunnerThread(wrapped_env, self.policy, batch_size)
     self.env = wrapped_env
     self.logdir = logdir
     # Kick off sampling only once every attribute above is in place.
     self.start()
Example #2
0
    def __init__(self, env_creator, config, logdir):
        """Build env, policy, filters and an async sampler from ``config``.

        Args:
            env_creator: Passed to ``create_and_wrap`` with
                ``config["model"]`` to build the environment.
            config: Dict read for the keys ``"model"``,
                ``"observation_filter"``, ``"reward_filter"`` and
                ``"batch_size"``; also passed to ``get_policy_cls``.
            logdir: Stored as ``self.logdir``.
        """
        env = create_and_wrap(env_creator, config["model"])
        self.env = env
        policy_cls = get_policy_cls(config)
        obs_shape = env.observation_space.shape
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(obs_shape, env.action_space)
        # The observation filter only lives on inside the sampler; the
        # reward filter (scalar shape ``()``) is kept on the instance.
        obs_filter = get_filter(config["observation_filter"], obs_shape)
        self.rew_filter = get_filter(config["reward_filter"], ())

        self.sampler = AsyncSampler(
            env, self.policy, config["batch_size"], obs_filter)
        self.logdir = logdir
Example #3
0
 def _init(self):
     """Set up the local env, policy, filters and the remote runners.

     Reads ``self.env_creator``, ``self.config`` and ``self.logdir``.
     Sets ``self.env``, ``self.policy``, ``self.obs_filter``,
     ``self.rew_filter``, ``self.agents`` and ``self.parameters`` (the
     local policy's initial weights).
     """
     config = self.config
     self.env = create_and_wrap(self.env_creator, config["model"])
     policy_cls = get_policy_cls(config)
     obs_shape = self.env.observation_space.shape
     self.policy = policy_cls(obs_shape, self.env.action_space)
     self.obs_filter = get_filter(config["observation_filter"], obs_shape)
     # Rewards are filtered as scalars, hence the empty shape.
     self.rew_filter = get_filter(config["reward_filter"], ())
     # One remote runner per configured worker; each receives the same
     # env creator, config and logdir. The loop index is not needed.
     self.agents = [
         RemoteRunner.remote(self.env_creator, config, self.logdir)
         for _ in range(config["num_workers"])
     ]
     self.parameters = self.policy.get_weights()
Example #4
0
 def _init(self):
     """Set up the local env and policy and launch the remote runners.

     Uses ``SharedModelLSTM`` when ``config["use_lstm"]`` is truthy and
     ``SharedModel`` otherwise. Each remote runner receives its index in
     ``range(config["num_workers"])`` as its actor id. The local policy's
     initial weights are stored on ``self.parameters``.
     """
     self.env = create_and_wrap(self.env_creator, self.config["model"])
     policy_cls = SharedModelLSTM if self.config["use_lstm"] else SharedModel
     env = self.env
     self.policy = policy_cls(env.observation_space.shape, env.action_space)
     # batch_size / model are looked up per worker (not hoisted) so that a
     # zero-worker config never touches those keys.
     self.agents = [
         RemoteRunner.remote(
             self.env_creator,
             policy_cls,
             worker_id,
             self.config["batch_size"],
             self.config["model"],
             self.logdir,
         )
         for worker_id in range(self.config["num_workers"])
     ]
     self.parameters = self.policy.get_weights()