# Standard-library and third-party imports used by the code below.
import os
from collections import deque
from typing import Union

import numpy as np
import torch
import wandb
from ray import tune
from torch import Tensor, nn
from torch.optim import Optimizer

# Project-local names (LQGGenerator, LQGModule, TVLinearPolicy, QuadQValue,
# MonteCarloSVG, AnalyticSVG, DPG, MAAC, NamedLQGControl, ExpectedValue,
# LQGPolicy, TorchLQGMixin, RNG, nt, calver) are assumed to be importable from
# the surrounding project and are not reproduced in this excerpt.


# Constructor of a separate helper module (its class statement is not shown in
# this excerpt) that pairs a Monte Carlo SVG estimator with the analytic
# ground truth.
def __init__(self, policy: LQGPolicy, env: TorchLQGMixin):
    super().__init__()
    self.actor = policy.module.actor
    self.model = policy.module.model
    self.mdp = env.module
    self.monte_carlo_svg = MonteCarloSVG(self.actor, self.model)
    self.analytic_svg = None
    # Call the analytic estimator once up front and store its result as the
    # gold standard.
    self.gold_standard = AnalyticSVG(self.actor, self.mdp)()
class HparamSearch(tune.Trainable):
    # pylint:disable=abstract-method,too-many-instance-attributes
    """Tune Trainable that optimizes a time-varying linear policy on a random
    LQG instance and tracks its suboptimality gap under different SVG
    estimators."""

    rng: RNG
    generator: LQGGenerator
    lqg: LQGModule
    policy: TVLinearPolicy
    rollout: MonteCarloSVG
    qvalue: QuadQValue
    optimizer: Optimizer
    estimator: Union[DPG, MAAC]
    n_step: int

    def setup(self, config: dict):
        self._init_wandb(config)
        self.rng = np.random.default_rng(self.run.config.seed)
        self.make_generator()
        self.make_modules()
        self.make_optimizer()
        self.make_estimator()
        self._init_stats()

    def _init_wandb(self, config: dict):
        # pylint:disable=attribute-defined-outside-init
        os.environ["WANDB_SILENT"] = "true"
        cwd = config.pop("wandb_dir")
        tags = config.pop("wandb_tags", [])
        self.run = wandb.init(
            dir=cwd,
            name="SuboptimalityGapSearch",
            config=config,
            project="LQG-SVG",
            entity="angelovtt",
            tags=[calver()] + list(tags),
            reinit=True,
            mode="online",
        )

    def make_generator(self):
        # Random LQG generator: stationary dynamics with passive eigenvalues
        # in (0.5, 1.5) and guaranteed controllability.
        self.generator = LQGGenerator(
            n_state=self.run.config.env_dim,
            n_ctrl=self.run.config.env_dim,
            horizon=20,
            stationary=True,
            passive_eigval_range=(0.5, 1.5),
            controllable=True,
            rng=self.rng,
        )

    def make_modules(self):
        # Sample one LQG instance and build the policy and Q-value modules
        # for it.
        with nt.suppress_named_tensor_warning():
            dynamics, cost, init = self.generator()
        lqg = LQGModule.from_existing(dynamics, cost, init)
        policy = TVLinearPolicy(lqg.n_state, lqg.n_ctrl, lqg.horizon)
        # Initialize the linear policy so that the closed loop is stable.
        policy.stabilize_(dynamics, rng=self.generator.rng)
        qvalue = QuadQValue(lqg.n_state + lqg.n_ctrl, lqg.horizon)
        self.lqg, self.policy, self.qvalue = lqg, policy, qvalue
        self.rollout = MonteCarloSVG(policy, lqg)

    def make_optimizer(self):
        name = self.run.config.optimizer
        if name == "Adam":
            cls = torch.optim.Adam
        elif name == "SGD":
            cls = torch.optim.SGD
        else:
            raise ValueError(f"Invalid optimizer type {name}")
        self.optimizer = cls(
            self.policy.parameters(), lr=self.run.config.learning_rate
        )

    def make_estimator(self):
        # pylint:disable=attribute-defined-outside-init
        ecls = DPG if self.run.config.estimator == "dpg" else MAAC
        self.estimator = ecls(
            self.policy, self.lqg.trans, self.lqg.reward, self.qvalue
        )
        self.n_step = self.run.config.K
        # Analytic SVG of the current policy, used as the ground-truth value
        # reference in get_stats.
        # noinspection PyAttributeOutsideInit
        self._golden_standard = AnalyticSVG(policy=self.policy, model=self.lqg)

    def _init_stats(self):
        # pylint:disable=attribute-defined-outside-init
        self._return_history = deque([0], maxlen=100)
        # Solve the sampled LQG in closed form to obtain the optimal value as
        # a fixed reference.
        dynamics, cost, init = self.lqg.standard_form()
        solver = NamedLQGControl(self.lqg.n_state, self.lqg.n_ctrl, self.lqg.horizon)
        _, _, vstar = solver(dynamics, cost)
        vstar = tuple(v.select("H", 0) for v in vstar)
        self._optimal_value = -ExpectedValue()(init, vstar).item()

    def step(self):
        obs = self.lqg_rollout()
        self.update_qvalue()

        # Negate the surrogate objective: the optimizer minimizes, so this
        # performs gradient ascent on the estimator's surrogate.
        self.optimizer.zero_grad()
        loss = self.estimator.surrogate(obs, n_steps=self.n_step).neg()
        loss.backward()
        info = self.postprocess_svg()
        self.optimizer.step()

        return self.get_stats(info)

    def lqg_rollout(self) -> Tensor:
        # Sample enough full-horizon trajectories to cover roughly B
        # transitions, then merge the horizon and batch dims into one batch.
        n_trajs = int(np.ceil(self.config["B"] / self.lqg.horizon))
        with torch.no_grad():
            obs, _, rew, _, _ = self.rollout.rsample_trajectory(
                torch.Size([n_trajs])
            )
        obs = obs.flatten(["H", "B1"], "B")
        rets = rew.sum("H").tolist()
        self._return_history.extend(rets)
        return obs

    def update_qvalue(self):
        # Recompute the quadratic Q-value of the current policy from the true
        # dynamics and cost.
        self.qvalue.match_policy_(
            self.policy.standard_form(),
            self.lqg.trans.standard_form(),
            self.lqg.reward.standard_form(),
        )

    def postprocess_svg(self):
        # Clip the policy gradient and report its norm.
        max_norm = self.run.config.clip_grad_norm
        grad_norm = nn.utils.clip_grad_norm_(
            self.policy.parameters(), max_norm=max_norm
        )
        return dict(grad_norm=grad_norm.item())

    def get_stats(self, info) -> dict:
        stats = {
            # Exact value of the current policy vs. the optimal value computed
            # in _init_stats: their difference is the suboptimality gap.
            "true_value": self._golden_standard.value().item(),
            "optimal_value": self._optimal_value,
            "episode_reward_mean": np.mean(self._return_history),
            "episode_reward_max": np.max(self._return_history),
            "episode_reward_min": np.min(self._return_history),
            **info,
        }
        self.run.log(stats)
        return stats

    def cleanup(self):
        self.run.finish()