コード例 #1
0
ファイル: ars.py プロジェクト: RuofanKong/ray
    def _init(self, config, env_creator):
        """Build the local policy, optimizer, noise table and remote workers.

        Steps, in order:
          1. Validate ``config``.
          2. Create a local environment (``worker_index=0``) purely to read
             its observation/action spaces.
          3. Construct the policy class chosen by ``get_policy_class`` and an
             SGD optimizer over it.
          4. Copy the rollout bookkeeping settings onto ``self``.
          5. Create the shared noise table once, fetch it locally via
             ``ray.get``, and hand its object reference to every worker.
          6. Spawn ``config["num_workers"]`` ``Worker`` actors with indices
             starting at 1 (index 0 is the local context above).
          7. Reset episode counters and record the start time.

        Args:
            config: Trainer configuration dict. It is read for the keys
                "env_config", "sgd_stepsize", "rollouts_used",
                "num_rollouts", "report_length", "noise_size" and
                "num_workers".
            env_creator: Callable taking an ``EnvContext`` and returning an
                environment.
        """
        validate_config(config)

        # Local env (index 0) exists only so the policy can see its spaces.
        # NOTE(review): it is never closed here.
        local_ctx = EnvContext(config["env_config"] or {}, worker_index=0)
        probe_env = env_creator(local_ctx)

        policy_cls = get_policy_class(config)
        self._policy_class = policy_cls
        self.policy = policy_cls(
            probe_env.observation_space, probe_env.action_space, config)
        self.optimizer = optimizers.SGD(self.policy, config["sgd_stepsize"])

        # Attribute names mirror the config keys one-to-one.
        for setting in ("rollouts_used", "num_rollouts", "report_length"):
            setattr(self, setting, config[setting])

        logger.info("Creating shared noise table.")
        noise_ref = create_shared_noise.remote(config["noise_size"])
        self.noise = SharedNoiseTable(ray.get(noise_ref))

        # Workers receive the object reference, not the materialized table.
        logger.info("Creating actors.")
        self.workers = [
            Worker.remote(config, env_creator, noise_ref, worker_index)
            for worker_index in range(1, config["num_workers"] + 1)
        ]

        # Progress counters and wall-clock start for reporting.
        self.episodes_so_far = 0
        self.reward_list = []
        self.tstart = time.time()
コード例 #2
0
    def _init(self, config, env_creator):
        super()._init(config, env_creator)
        validate_config(config)
        env_context = EnvContext(config["env_config"] or {}, worker_index=0)
        env = env_creator(env_context)

        policy_cls = get_ars_frac_policy_class(config)
        self.policy = policy_cls(env.observation_space, env.action_space, config)