def policy_svg(policy: TVLinearPolicy, value: Tensor) -> lqr.Linear:
    """Compute the policy's stochastic value gradient from the estimated return.

    Clears any stale gradients on the policy, backpropagates the scalar
    ``value``, and returns clones of the resulting gradients of the policy's
    standard-form parameters.
    """
    # pylint:disable=invalid-name
    policy.zero_grad(set_to_none=True)
    value.backward()
    gain, offset = policy.standard_form()
    # NOTE(review): assumes the tensors returned by `standard_form` accumulate
    # `.grad` after backward (leaf parameters or grad-retaining views) —
    # confirm against TVLinearPolicy.
    return gain.grad.clone(), offset.grad.clone()
def test_stabilize_(self, module: TVLinearPolicy, dynamics: lqr.LinSDynamics, seed: int):
    """Stabilizing twice with the same RNG seed must yield identical parameters."""
    module.stabilize_(dynamics, rng=seed)
    gain_snapshot = module.K.clone()
    offset_snapshot = module.k.clone()
    # Re-run with the same seed: the result must be deterministic.
    module.stabilize_(dynamics, rng=seed)
    assert torch.allclose(module.K, gain_snapshot)
    assert torch.allclose(module.k, offset_snapshot)
def make_modules(self):
    """Build the LQG model, a stabilized policy, and a Q-value, storing them on self."""
    with nt.suppress_named_tensor_warning():
        dynamics, cost, init = self.generator()

    lqg = LQGModule.from_existing(dynamics, cost, init)

    policy = TVLinearPolicy(lqg.n_state, lqg.n_ctrl, lqg.horizon)
    # Stabilize so rollouts don't diverge over the horizon.
    policy.stabilize_(dynamics, rng=self.generator.rng)

    qvalue = QuadQValue(lqg.n_state + lqg.n_ctrl, lqg.horizon)

    self.lqg = lqg
    self.policy = policy
    self.qvalue = qvalue
    self.rollout = MonteCarloSVG(policy, lqg)
def test_frozen(self, module: TVLinearPolicy, obs: Tensor, n_ctrl: int):
    """Frozen actions backprop into the observation but not into the parameters.

    Fix: drop the redundant `list(...)` wrapper inside `all(...)` — `all`
    consumes the generator directly (flake8-comprehensions C419).
    """
    act = module.frozen(obs)
    assert torch.is_tensor(act)
    assert torch.isfinite(act).all()
    assert act.names == obs.names
    # noinspection PyArgumentList
    assert act.size("R") == n_ctrl

    module.zero_grad(set_to_none=True)
    act.sum().backward()
    assert obs.grad is not None
    assert not torch.allclose(obs.grad, torch.zeros(()))
    assert torch.isfinite(obs.grad).all()
    # `frozen` should detach the parameters, so none of them receive grads.
    grads = [p.grad for p in module.parameters()]
    assert all(g is None for g in grads)
def test_mixed_call(self, module: TVLinearPolicy, mix_obs: Tensor, n_ctrl: int):
    """Calling on mixed (mid-trajectory) observations yields finite, differentiable acts.

    Fix: drop the redundant `list(...)` wrappers inside `all(...)` — `all`
    consumes generators directly (flake8-comprehensions C419).
    """
    act = module(mix_obs)
    assert torch.is_tensor(act)
    assert torch.isfinite(act).all()
    assert act.names == mix_obs.names
    assert act.size("R") == n_ctrl

    act.sum().backward()
    assert mix_obs.grad is not None
    assert torch.isfinite(mix_obs.grad).all()
    # Every parameter should receive a finite gradient.
    grads = [p.grad for p in module.parameters()]
    assert all(g is not None for g in grads)
    assert all(torch.isfinite(g).all() for g in grads)
def test_terminal_call(self, module: TVLinearPolicy, last_obs: Tensor, n_ctrl: int):
    """Terminal observations produce zero actions and zero gradients.

    Fix: drop the redundant `list(...)` wrappers inside `all(...)` — `all`
    consumes generators directly (flake8-comprehensions C419).
    """
    act = module(last_obs)
    assert nt.allclose(act, torch.zeros(()))
    assert torch.is_tensor(act)
    assert torch.isfinite(act).all()
    assert act.names == last_obs.names
    assert act.size("R") == n_ctrl

    act.sum().backward()
    assert last_obs.grad is not None
    assert torch.allclose(last_obs.grad, torch.zeros(()))
    # The terminal action is constant zero, so all parameter grads are zero.
    grads = [p.grad for p in module.parameters()]
    assert all(g is not None for g in grads)
    assert all(torch.allclose(g, torch.zeros(())) for g in grads)
def module(self, n_state: int, n_ctrl: int, horizon: int) -> TVLinearPolicy:
    """Fresh time-varying linear policy for the given LQR dimensions."""
    fresh_policy = TVLinearPolicy(n_state, n_ctrl, horizon)
    return fresh_policy
def test_standard_form(self, module: TVLinearPolicy):
    """Summing the standard form and backpropagating yields all-ones grads on K and k."""
    gain, offset = module.standard_form()
    total = gain.sum() + offset.sum()
    total.backward()
    assert torch.allclose(module.K.grad, torch.ones_like(module.K.grad))
    assert torch.allclose(module.k.grad, torch.ones_like(module.k.grad))
def policy(n_state: int, n_ctrl: int, horizon: int) -> TVLinearPolicy:
    """Create a time-varying linear policy fixture for the given dimensions."""
    new_policy = TVLinearPolicy(n_state, n_ctrl, horizon)
    return new_policy