Example #1
class MARWILAgent(Agent):
    """MARWIL implementation in TensorFlow."""

    _agent_name = "MARWIL"
    _default_config = DEFAULT_CONFIG
    _policy_graph = MARWILPolicyGraph

    @override(Agent)
    def _init(self):
        self.local_evaluator = self.make_local_evaluator(
            self.env_creator, self._policy_graph)
        self.remote_evaluators = self.make_remote_evaluators(
            self.env_creator, self._policy_graph, self.config["num_workers"])
        self.optimizer = SyncBatchReplayOptimizer(
            self.local_evaluator, self.remote_evaluators, {
                "learning_starts": self.config["learning_starts"],
                "buffer_size": self.config["replay_buffer_size"],
                "train_batch_size": self.config["train_batch_size"],
            })

    @override(Agent)
    def _train(self):
        prev_steps = self.optimizer.num_steps_sampled
        # One optimizer round: sample from the workers, store the samples in the
        # replay buffer, and train on replayed batches once "learning_starts"
        # samples have been collected.
        fetches = self.optimizer.step()
        res = self.optimizer.collect_metrics(
            self.config["collect_metrics_timeout"])
        res.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
                   prev_steps,
                   info=dict(fetches, **res.get("info", {})))
        return res
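
Note: the example above uses the older evaluator-based RLlib API, where SyncBatchReplayOptimizer takes the local evaluator, the list of remote evaluators, and a dict of options. Examples #2, #4 and #5 below show the later form of the same call, in which a single WorkerSet is passed together with keyword arguments (learning_starts, buffer_size, train_batch_size).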
Example #2
File: marwil.py  Project: songhappy/ray
def _init(self, config, env_creator):
    self.workers = self._make_workers(env_creator, self._policy, config,
                                      config["num_workers"])
    self.optimizer = SyncBatchReplayOptimizer(
        self.workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["replay_buffer_size"],
        train_batch_size=config["train_batch_size"],
    )
Example #3
def _init(self):
    self.local_evaluator = self.make_local_evaluator(
        self.env_creator, self._policy_graph)
    self.remote_evaluators = self.make_remote_evaluators(
        self.env_creator, self._policy_graph, self.config["num_workers"])
    self.optimizer = SyncBatchReplayOptimizer(
        self.local_evaluator, self.remote_evaluators, {
            "learning_starts": self.config["learning_starts"],
            "buffer_size": self.config["replay_buffer_size"],
            "train_batch_size": self.config["train_batch_size"],
        })
Example #4
def make_optimizer(workers, config):
    return SyncBatchReplayOptimizer(
        workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["replay_buffer_size"],
        train_batch_size=config["train_batch_size"],
    )
Example #5
File: marwil.py  Project: songhappy/ray
class MARWILTrainer(Trainer):
    """MARWIL implementation in TensorFlow."""

    _name = "MARWIL"
    _default_config = DEFAULT_CONFIG
    _policy = MARWILPolicy

    @override(Trainer)
    def _init(self, config, env_creator):
        self.workers = self._make_workers(env_creator, self._policy, config,
                                          config["num_workers"])
        self.optimizer = SyncBatchReplayOptimizer(
            self.workers,
            learning_starts=config["learning_starts"],
            buffer_size=config["replay_buffer_size"],
            train_batch_size=config["train_batch_size"],
        )

    @override(Trainer)
    def _train(self):
        prev_steps = self.optimizer.num_steps_sampled
        fetches = self.optimizer.step()
        res = self.collect_metrics()
        res.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
                   prev_steps,
                   info=dict(fetches, **res.get("info", {})))
        return res
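
A minimal usage sketch for the Trainer-based variant in Example #5, assuming the Ray release these snippets come from. The import path, the environment name, the offline-data "input" path, and the numeric values are illustrative assumptions, not taken from the source above; only the config keys themselves (num_workers, learning_starts, replay_buffer_size, train_batch_size) appear in the examples.

import ray
from ray.rllib.agents.marwil import MARWILTrainer  # assumed export; module paths differ across Ray versions

ray.init()
trainer = MARWILTrainer(
    env="CartPole-v0",
    config={
        # MARWIL trains from previously logged experiences; placeholder path.
        "input": "/tmp/cartpole-out",
        "num_workers": 2,
        # Replay settings consumed by SyncBatchReplayOptimizer in _init above;
        # the values here are illustrative only.
        "learning_starts": 1000,
        "replay_buffer_size": 10000,
        "train_batch_size": 2000,
    },
)
for _ in range(10):
    result = trainer.train()  # invokes _train(): optimizer.step() plus metric collection
    print(result["timesteps_this_iter"])

Each call to train() reports the number of new samples drawn during that round via timesteps_this_iter, exactly as computed in _train() above.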