Example #1
def policy_svg(policy: TVLinearPolicy, value: Tensor) -> lqr.Linear:
    """Computes the policy SVG from the estimated return."""
    # pylint:disable=invalid-name
    # Clear stale gradients, then backprop the scalar return estimate
    policy.zero_grad(set_to_none=True)
    value.backward()
    # The gradients w.r.t. the gain matrices are the stochastic value gradient
    K, k = policy.standard_form()
    return K.grad.clone(), k.grad.clone()
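To make the gradient mechanics concrete, here is a minimal self-contained sketch using plain tensors as stand-ins for the policy's time-varying gains (everything below is illustrative; none of it is lqsvg API):

import torch

# Stand-ins for the gains: K is horizon x n_ctrl x n_state, k is horizon x n_ctrl
K = torch.randn(4, 2, 3, requires_grad=True)
k = torch.randn(4, 2, requires_grad=True)
value = (K.sum() + k.sum()) ** 2  # any differentiable scalar "return" estimate
value.backward()
svg = (K.grad.clone(), k.grad.clone())  # the same extraction policy_svg performs

In the real function, policy.zero_grad(set_to_none=True) guards against stale gradients accumulating across calls.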
Example #2
    def test_stabilize_(self, module: TVLinearPolicy,
                        dynamics: lqr.LinSDynamics, seed: int):
        # Stabilizing twice with the same seed must yield identical gains
        module.stabilize_(dynamics, rng=seed)
        K, k = module.K.clone(), module.k.clone()

        module.stabilize_(dynamics, rng=seed)
        assert torch.allclose(module.K, K)
        assert torch.allclose(module.k, k)
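For context, a sketch of the kind of pytest fixture that could supply seed; the parameter values are assumptions, not from the source:

import pytest

@pytest.fixture(params=(42, 69, 37))
def seed(request) -> int:
    # Runs the test once per RNG seed
    return request.param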
Example #3
    def make_modules(self):
        # Sample a random LQG instance, suppressing named-tensor warnings
        with nt.suppress_named_tensor_warning():
            dynamics, cost, init = self.generator()
        lqg = LQGModule.from_existing(dynamics, cost, init)
        # Initialize the policy so the closed-loop dynamics are stable
        policy = TVLinearPolicy(lqg.n_state, lqg.n_ctrl, lqg.horizon)
        policy.stabilize_(dynamics, rng=self.generator.rng)
        qvalue = QuadQValue(lqg.n_state + lqg.n_ctrl, lqg.horizon)
        self.lqg, self.policy, self.qvalue = lqg, policy, qvalue
        self.rollout = MonteCarloSVG(policy, lqg)
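Inside the same class, a hedged follow-up could then drive the SVG computation from Example #1; the value method name and signature on MonteCarloSVG are assumptions, not confirmed API:

    def test_policy_svg(self):
        self.make_modules()
        # Hypothetical: a differentiable Monte Carlo estimate of the return
        value = self.rollout.value(samples=100)  # method name/signature assumed
        K_grad, k_grad = policy_svg(self.policy, value)
        assert torch.isfinite(K_grad).all() and torch.isfinite(k_grad).all()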
Example #4
    def test_frozen(self, module: TVLinearPolicy, obs: Tensor, n_ctrl: int):
        act = module.frozen(obs)

        assert torch.is_tensor(act)
        assert torch.isfinite(act).all()
        assert act.names == obs.names
        # noinspection PyArgumentList
        assert act.size("R") == n_ctrl

        module.zero_grad(set_to_none=True)
        act.sum().backward()
        assert obs.grad is not None
        assert not torch.allclose(obs.grad, torch.zeros(()))
        assert torch.isfinite(obs.grad).all()
        grads = [p.grad for p in module.parameters()]
        assert all(g is None for g in grads)  # frozen: parameters get no grads
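The contract here is the usual frozen-parameters pattern: apply the module with its parameters detached so gradients reach the input only. A minimal stand-alone analogue (illustrative, not the library's implementation):

import torch

K = torch.randn(2, 3, requires_grad=True)
obs = torch.randn(3, requires_grad=True)
act = K.detach() @ obs  # "frozen" application: the parameter is detached
act.sum().backward()
assert obs.grad is not None  # the input receives a gradient...
assert K.grad is None        # ...the detached parameter does not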
Example #5
    def test_mixed_call(self, module: TVLinearPolicy, mix_obs: Tensor,
                        n_ctrl: int):
        act = module(mix_obs)

        assert torch.is_tensor(act)
        assert torch.isfinite(act).all()
        assert act.names == mix_obs.names
        assert act.size("R") == n_ctrl

        act.sum().backward()
        assert mix_obs.grad is not None
        assert torch.isfinite(mix_obs.grad).all()
        grads = [p.grad for p in module.parameters()]
        assert all(g is not None for g in grads)
        assert all(torch.isfinite(g).all() for g in grads)
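Presumably mix_obs batches non-terminal and terminal observations together, so every parameter still receives a finite gradient through the non-terminal rows. A toy analogue of that effect (purely illustrative):

import torch

W = torch.randn(2, 3, requires_grad=True)
batch = torch.randn(5, 3, requires_grad=True)
mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])  # 0 marks a "terminal" row
act = (batch @ W.T) * mask.unsqueeze(-1)  # terminal rows contribute zero action
act.sum().backward()
assert torch.isfinite(W.grad).all()  # non-terminal rows still reach W
assert torch.isfinite(batch.grad).all()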
Example #6
    def test_terminal_call(self, module: TVLinearPolicy, last_obs: Tensor,
                           n_ctrl: int):
        act = module(last_obs)

        assert nt.allclose(act, torch.zeros(()))
        assert torch.is_tensor(act)
        assert torch.isfinite(act).all()
        assert act.names == last_obs.names
        assert act.size("R") == n_ctrl

        act.sum().backward()
        assert last_obs.grad is not None
        assert torch.allclose(last_obs.grad, torch.zeros(()))
        grads = [p.grad for p in module.parameters()]
        assert all(g is not None for g in grads)
        assert all(torch.allclose(g, torch.zeros(())) for g in grads)
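A toy analogue of the terminal-state contract checked above: when the action is forced to zero, every gradient is zero as well (illustrative only):

import torch

W = torch.randn(2, 3, requires_grad=True)
obs = torch.randn(3, requires_grad=True)
act = (W @ obs) * 0.0  # terminal step: the action is identically zero
act.sum().backward()
assert torch.allclose(obs.grad, torch.zeros(()))
assert torch.allclose(W.grad, torch.zeros(()))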
Example #7
    def module(self, n_state: int, n_ctrl: int,
               horizon: int) -> TVLinearPolicy:
        return TVLinearPolicy(n_state, n_ctrl, horizon)
Example #8
    def test_standard_form(self, module: TVLinearPolicy):
        K, k = module.standard_form()

        (K.sum() + k.sum()).backward()
        assert torch.allclose(module.K.grad, torch.ones_like(module.K.grad))
        assert torch.allclose(module.k.grad, torch.ones_like(module.k.grad))
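Why ones: the derivative of a plain sum with respect to each entry is 1, so the expected gradient is a tensor of ones. The same identity in isolation:

import torch

K = torch.randn(4, 2, 3, requires_grad=True)
K.sum().backward()
assert torch.equal(K.grad, torch.ones_like(K))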
Example #9
def policy(n_state: int, n_ctrl: int, horizon: int) -> TVLinearPolicy:
    return TVLinearPolicy(n_state, n_ctrl, horizon)