# Example #1: legacy Agent-based API
class MARWILAgent(Agent):
    """MARWIL implementation in TensorFlow.

    Sets up one local evaluator plus ``num_workers`` remote evaluators
    and trains them through a synchronous batch-replay optimizer.
    """

    _agent_name = "MARWIL"
    _default_config = DEFAULT_CONFIG
    _policy_graph = MARWILPolicyGraph

    @override(Agent)
    def _init(self):
        # Build the evaluators first; the optimizer wires them together.
        self.local_evaluator = self.make_local_evaluator(
            self.env_creator, self._policy_graph)
        self.remote_evaluators = self.make_remote_evaluators(
            self.env_creator, self._policy_graph, self.config["num_workers"])
        # Replay settings are forwarded from the agent config as a dict
        # (this older optimizer API takes a single config mapping).
        replay_settings = {
            "learning_starts": self.config["learning_starts"],
            "buffer_size": self.config["replay_buffer_size"],
            "train_batch_size": self.config["train_batch_size"],
        }
        self.optimizer = SyncBatchReplayOptimizer(
            self.local_evaluator, self.remote_evaluators, replay_settings)

    @override(Agent)
    def _train(self):
        # Snapshot the sample counter so we can report per-iteration steps.
        steps_before = self.optimizer.num_steps_sampled
        fetches = self.optimizer.step()
        res = self.optimizer.collect_metrics(
            self.config["collect_metrics_timeout"])
        # Merge optimizer fetches into any metrics-provided info dict.
        merged_info = dict(fetches, **res.get("info", {}))
        res.update(
            timesteps_this_iter=self.optimizer.num_steps_sampled - steps_before,
            info=merged_info)
        return res
# Example #2: newer Trainer-based API
class MARWILTrainer(Trainer):
    """MARWIL implementation in TensorFlow.

    Creates a worker set for rollout collection and trains it through a
    synchronous batch-replay optimizer.
    """

    _name = "MARWIL"
    _default_config = DEFAULT_CONFIG
    _policy = MARWILPolicy

    @override(Trainer)
    def _init(self, config, env_creator):
        # One local worker plus config["num_workers"] remote workers.
        self.workers = self._make_workers(env_creator, self._policy, config,
                                          config["num_workers"])
        # Replay settings are passed as keyword arguments in this API.
        self.optimizer = SyncBatchReplayOptimizer(
            self.workers,
            learning_starts=config["learning_starts"],
            buffer_size=config["replay_buffer_size"],
            train_batch_size=config["train_batch_size"])

    @override(Trainer)
    def _train(self):
        # Snapshot the sample counter so we can report per-iteration steps.
        steps_before = self.optimizer.num_steps_sampled
        fetches = self.optimizer.step()
        res = self.collect_metrics()
        # Merge optimizer fetches into any metrics-provided info dict.
        merged_info = dict(fetches, **res.get("info", {}))
        res.update(
            timesteps_this_iter=self.optimizer.num_steps_sampled - steps_before,
            info=merged_info)
        return res