class MARWILAgent(Agent):
    """MARWIL implementation in TensorFlow."""

    _agent_name = "MARWIL"
    _default_config = DEFAULT_CONFIG
    _policy_graph = MARWILPolicyGraph

    @override(Agent)
    def _init(self):
        self.local_evaluator = self.make_local_evaluator(
            self.env_creator, self._policy_graph)
        self.remote_evaluators = self.make_remote_evaluators(
            self.env_creator, self._policy_graph, self.config["num_workers"])
        self.optimizer = SyncBatchReplayOptimizer(
            self.local_evaluator, self.remote_evaluators, {
                "learning_starts": self.config["learning_starts"],
                "buffer_size": self.config["replay_buffer_size"],
                "train_batch_size": self.config["train_batch_size"],
            })

    @override(Agent)
    def _train(self):
        prev_steps = self.optimizer.num_steps_sampled
        fetches = self.optimizer.step()
        res = self.optimizer.collect_metrics(
            self.config["collect_metrics_timeout"])
        res.update(
            timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
            info=dict(fetches, **res.get("info", {})))
        return res
def _init(self, config, env_creator):
    self.workers = self._make_workers(env_creator, self._policy, config,
                                      config["num_workers"])
    self.optimizer = SyncBatchReplayOptimizer(
        self.workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["replay_buffer_size"],
        train_batch_size=config["train_batch_size"],
    )
def _init(self):
    self.local_evaluator = self.make_local_evaluator(
        self.env_creator, self._policy_graph)
    self.remote_evaluators = self.make_remote_evaluators(
        self.env_creator, self._policy_graph, self.config["num_workers"])
    self.optimizer = SyncBatchReplayOptimizer(
        self.local_evaluator, self.remote_evaluators, {
            "learning_starts": self.config["learning_starts"],
            "buffer_size": self.config["replay_buffer_size"],
            "train_batch_size": self.config["train_batch_size"],
        })
def make_optimizer(workers, config):
    return SyncBatchReplayOptimizer(
        workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["replay_buffer_size"],
        train_batch_size=config["train_batch_size"],
    )
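# Hedged usage sketch (not part of the snippets above): make_optimizer can
# stand in for the inline SyncBatchReplayOptimizer construction inside a
# Trainer `_init` of the style shown here. The surrounding Trainer machinery
# (`self._make_workers`, `self._policy`) is assumed unchanged.
def _init(self, config, env_creator):
    self.workers = self._make_workers(env_creator, self._policy, config,
                                      config["num_workers"])
    # Delegate optimizer construction to the shared helper above.
    self.optimizer = make_optimizer(self.workers, config)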
class MARWILTrainer(Trainer):
    """MARWIL implementation in TensorFlow."""

    _name = "MARWIL"
    _default_config = DEFAULT_CONFIG
    _policy = MARWILPolicy

    @override(Trainer)
    def _init(self, config, env_creator):
        self.workers = self._make_workers(env_creator, self._policy, config,
                                          config["num_workers"])
        self.optimizer = SyncBatchReplayOptimizer(
            self.workers,
            learning_starts=config["learning_starts"],
            buffer_size=config["replay_buffer_size"],
            train_batch_size=config["train_batch_size"],
        )

    @override(Trainer)
    def _train(self):
        prev_steps = self.optimizer.num_steps_sampled
        fetches = self.optimizer.step()
        res = self.collect_metrics()
        res.update(
            timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
            info=dict(fetches, **res.get("info", {})))
        return res
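# Minimal training-loop sketch (assumptions, not confirmed by the code above):
# it assumes an RLlib release where this class is importable as
# `ray.rllib.agents.marwil.MARWILTrainer` and where MARWIL reads offline
# experiences from an `"input"` path in its config. The data directory and
# environment name are placeholders.
import ray
from ray.rllib.agents.marwil import MARWILTrainer  # assumed import path

ray.init()
trainer = MARWILTrainer(
    env="CartPole-v0",
    config={
        "input": "/tmp/cartpole-out",  # hypothetical offline data directory
        "num_workers": 1,
        "learning_starts": 0,
        "train_batch_size": 2000,
    })

for _ in range(10):
    result = trainer.train()  # one call to the _train() shown above
    print(result["timesteps_this_iter"])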