Exemple #1
0
def test_stationary_pred_equality(
    lqr_control: NamedLQRControl,
    lqg_prediction: NamedLQGPrediction,
    stationary_deterministic_dynamics: LinDynamics,
    stationary_cost: QuadCost,
    horizon: int,
    n_state: int,
):
    """Check LQG prediction with an all-zeros ``W`` reproduces LQR values.

    The policy returned by ``lqr_control`` is evaluated with
    ``lqg_prediction`` on the same ``F``/``f`` dynamics extended with a
    ``(horizon, n_state, n_state)`` zero ``W``. Each of the three value
    tensors must be numerically close to its LQR counterpart.
    """
    dynamics = stationary_deterministic_dynamics
    policy, _, value_lqr = lqr_control(dynamics, stationary_cost)
    V_lqr, v_lqr, c_lqr = nt.unnamed(*value_lqr)

    zero_w_dynamics = LinSDynamics(
        F=dynamics.F,
        f=dynamics.f,
        W=torch.zeros(horizon, n_state, n_state),
    )
    _, value_lqg = lqg_prediction(policy, zero_w_dynamics, stationary_cost)
    V_lqg, v_lqg, c_lqg = nt.unnamed(*value_lqg)

    assert torch.allclose(V_lqr, V_lqg)
    assert torch.allclose(v_lqr, v_lqg)
    assert torch.allclose(c_lqr, c_lqg)
Exemple #2
0
def test_stationary_ctrl_equality(
    lqr_control: NamedLQRControl,
    lqg_control: NamedLQGControl,
    stationary_deterministic_dynamics: LinSDynamics,
    stationary_cost: QuadCost,
    horizon: int,
    n_state: int,
):
    """Check LQG control with an all-zeros ``W`` yields the LQR policy.

    Both solvers are run on the same ``F``/``f`` dynamics; the LQG solver
    additionally receives a ``(horizon, n_state, n_state)`` zero ``W``. The
    resulting ``K`` and ``k`` tensors must match in shape and value.
    """
    # NOTE(review): only ``F`` and ``f`` are read from this argument, so
    # ``LinDynamics`` may be the intended annotation — confirm.
    dynamics = stationary_deterministic_dynamics
    policy_lqr, _, _ = lqr_control(dynamics, stationary_cost)

    zero_w_dynamics = LinSDynamics(
        F=dynamics.F,
        f=dynamics.f,
        W=torch.zeros(horizon, n_state, n_state),
    )
    policy_lqg, _, _ = lqg_control(zero_w_dynamics, stationary_cost)

    K_lqr, k_lqr = nt.unnamed(*policy_lqr)
    K_lqg, k_lqg = nt.unnamed(*policy_lqg)

    assert K_lqr.shape == K_lqg.shape
    assert k_lqr.shape == k_lqg.shape
    assert torch.allclose(k_lqr, k_lqg)
    assert torch.allclose(K_lqr, K_lqg)
Exemple #3
0
def check_cost(
    cost: QuadCost,
    n_state: int,
    n_ctrl: int,
    horizon: int,
    stationary: bool,
    linear: bool,
):
    """Assert that ``cost`` is a well-formed quadratic cost.

    Args:
        cost: pair ``(C, c)`` of cost tensors
        n_state: state size; ``n_state + n_ctrl`` is the expected row count
            of ``C`` and ``c`` and the column count of ``C``
        n_ctrl: control size
        horizon: expected length of the ``"H"`` dimension
        stationary: whether every step along ``"H"`` must equal the first
            (checked only when ``horizon > 1``)
        linear: when False, ``c`` must be all zeros; when True, ``c`` is
            also subject to the stationarity check

    Raises:
        AssertionError: if any check fails
    """
    assert_all_tensor(*cost)
    n_tau = n_state + n_ctrl

    C, c = cost
    assert_horizon_len(C, horizon)
    assert_horizon_len(c, horizon)
    assert_row_size(C, n_tau)
    assert_row_size(c, n_tau)
    assert_col_size(C, n_tau)

    # eigvalsh only reads the lower triangle, so the PSD check below is only
    # meaningful if C is symmetric (same check as for dynamics covariances).
    assert nt.allclose(C, nt.transpose(C))
    # Only eigenvalues are needed; skip computing eigenvectors.
    eigval = torch.linalg.eigvalsh(nt.unnamed(C))
    assert eigval.ge(0).all()
    assert linear or nt.allclose(c, torch.zeros_like(c))

    if horizon > 1:
        assert stationary == nt.allclose(C, C.select("H", 0))
        assert not linear or stationary == nt.allclose(c, c.select("H", 0))
Exemple #4
0
 def factorize_(self, matrix: Tensor) -> CholeskyFactor:
     """Set parameters to factorize a symmetric positive definite matrix.

     ``matrix`` is decomposed with ``disassemble_cholesky`` (using
     ``self.beta``). The unnamed results are copied in place into
     ``self.ltril`` and ``self.pre_diag`` through ``.data``, so the copy is
     not recorded by autograd.

     Returns:
         This module, allowing call chaining.
     """
     parts = disassemble_cholesky(matrix, beta=self.beta)
     new_ltril, new_pre_diag = nt.unnamed(*parts)
     targets = ((self.ltril, new_ltril), (self.pre_diag, new_pre_diag))
     for param, value in targets:
         param.data.copy_(value)
     return self
Exemple #5
0
 def _trans_logp(
     loc: Tensor,
     scale_tril: Tensor,
     cur_time: IntTensor,
     state: Tensor,
     time: IntTensor,
 ) -> Tensor:
     """Log-density of ``state`` under a Gaussian transition.

     Args:
         loc: mean of the multivariate normal
         scale_tril: lower-triangular scale factor of the normal
         cur_time: time step the transition starts from
         state: state whose log-probability is evaluated
         time: time step attached to ``state``

     Returns:
         Log-probability of ``state`` where ``time == cur_time + 1`` and NaN
         elsewhere, refined with the names of ``time``.
     """
     loc, scale_tril = nt.unnamed(loc, scale_tril)
     dist = torch.distributions.MultivariateNormal(loc=loc, scale_tril=scale_tril)
     trans_logp: Tensor = dist.log_prob(nt.unnamed(state))
     # Build the NaN filler with the log-probs' dtype and device; the torch
     # defaults (float32, CPU) would clash with float64 or CUDA inputs.
     undefined = torch.full(
         time.shape,
         fill_value=float("nan"),
         dtype=trans_logp.dtype,
         device=trans_logp.device,
     )
     trans_logp = nt.where(
         # Logp only defined at next timestep
         time.eq(cur_time + 1),
         trans_logp,
         undefined,
     )
     # We assume time is a named scalar tensor
     return trans_logp.refine_names(*time.names)
Exemple #6
0
def obs(
    n_state: int,
    horizon: int,
    batch_shape: tuple[int, ...],
    batch_names: tuple[str, ...],
) -> Tensor:
    """Random packed observation with named batch dimensions.

    Samples a standard-normal state of shape ``batch_shape + (n_state,)``
    named ``batch_names``. It also samples an int time tensor, shaped and
    named like the state's first ``"R"`` column, with values in
    ``[0, horizon)``. Both are combined with ``pack_obs``, and the result is
    marked as requiring grad.
    """
    raw_state = torch.randn(batch_shape + (n_state, ))
    state = nt.vector(raw_state).refine_names(*batch_names, ...)
    # The first state column only serves as a shape/name template for time.
    template, _ = nt.split(state, [1, n_state - 1], dim="R")
    time = (
        torch.randint_like(nt.unnamed(template), low=0, high=horizon)
        .refine_names(*template.names)
        .int()
    )
    return pack_obs(state, time).requires_grad_()
Exemple #7
0
def vector_to_tensors(vector: Tensor,
                      tensors: Iterable[Tensor]) -> Iterable[Tensor]:
    """Split and reshape vector into tensors matching others' shapes.

    Consecutive slices of the unnamed ``vector`` are viewed with the shape of
    each reference tensor, in iteration order, and given that tensor's
    dimension names. Elements past the total size of ``tensors`` are ignored.

    Args:
        vector: flat tensor holding the concatenated elements
        tensors: reference tensors supplying shapes and names

    Returns:
        A list with one tensor per element of ``tensors``.
    """
    flat = nt.unnamed(vector)
    pieces = []
    start = 0
    for ref in tensors:
        stop = start + ref.numel()
        pieces.append(flat[start:stop].view_as(ref).refine_names(*ref.names))
        start = stop
    return pieces
Exemple #8
0
def linear_feedback_norm(linear: lqr.Linear) -> Tensor:
    """Norm of the parameters of a linear (affine) function.

    Uses PyTorch's default norms: Frobenius over the last two dimensions of
    the weight matrix and L2 over the last dimension of the bias. The
    per-step norms are then combined with a final L2 norm over dimension 0,
    which makes the result equal to the norm of the flattened parameters.

    Args:
        linear: tuple of affine function parameters (weight matrix and bias
        column vector)

    Returns:
        Norm of the affine function's parameters
    """
    # pylint:disable=invalid-name
    K, k = linear
    per_step_norms = torch.cat(
        (
            torch.linalg.norm(nt.unnamed(K), dim=(-2, -1)),
            torch.linalg.norm(nt.unnamed(k), dim=-1),
        ),
        dim=0,
    )
    # Following PyTorch's clip_grad_norm_: reduce the per-step norms along
    # dimension 0 (the horizon).
    return torch.linalg.norm(per_step_norms, dim=0)
Exemple #9
0
def check_dynamics_covariance(W: Tensor, n_state: int, horizon: int,
                              stationary: bool, sample_covariance: bool):
    """Assert that ``W`` is a well-formed covariance sequence.

    Args:
        W: covariance tensor with an ``"H"`` dimension
        n_state: expected row and column size of ``W``
        horizon: expected length of the ``"H"`` dimension
        stationary: whether every step along ``"H"`` must equal the first;
            checked only when ``sample_covariance`` is True and
            ``horizon > 1``
        sample_covariance: when False, ``W`` must equal the identity matrix;
            when True, it must not

    Raises:
        AssertionError: if any check fails
    """
    assert_horizon_len(W, horizon)
    assert_row_size(W, n_state)
    assert_col_size(W, n_state)

    # Symmetry is required because eigvalsh only reads the lower triangle.
    assert nt.allclose(W, nt.transpose(W))
    # Only eigenvalues are needed for the positive-definiteness check.
    eigval = torch.linalg.eigvalsh(nt.unnamed(W))
    assert eigval.gt(0).all()

    assert sample_covariance != nt.allclose(W, nt.matrix(torch.eye(n_state)))

    # noinspection PyTypeChecker
    assert (horizon == 1 or not sample_covariance
            or stationary == nt.allclose(W, W.select("H", 0)))
Exemple #10
0
 def test_factorize_(
     self,
     module: SPDMatrix,
     size: int,
     sample_shape: tuple[int, ...],
     use_sample_shape: bool,
     seed: int,
 ):
     """factorize_ followed by a forward pass recovers the input matrix.

     On failure, the assertion message lists the mismatching entries of the
     expected and reconstructed matrices.
     """
     # pylint:disable=too-many-arguments
     shape = sample_shape if use_sample_shape else ()
     expected = make_spd_matrix(size, sample_shape=shape, rng=seed)
     module.factorize_(nt.matrix(expected))
     actual = nt.unnamed(module())
     expected, actual = torch.broadcast_tensors(expected, actual)
     close = torch.isclose(expected, actual, atol=1e-6)
     assert close.all(), (expected[~close].tolist(), actual[~close].tolist())
Exemple #11
0
 def test_factorize_(
     self,
     module: CholeskyFactor,
     size: int,
     sample_shape: tuple[int, ...],
     use_sample_shape: bool,
     seed: int,
 ):
     """factorize_ makes the module output the Cholesky factor of the input.

     The module's forward output is compared elementwise with
     ``nt.cholesky`` of the same matrix. On failure, the assertion message
     lists the mismatching entries.
     """
     # pylint:disable=too-many-arguments
     shape = sample_shape if use_sample_shape else ()
     matrix = make_spd_matrix(size, sample_shape=shape, rng=seed)
     module.factorize_(nt.matrix(matrix))
     produced = nt.unnamed(module())
     reference = nt.cholesky(matrix)
     reference, produced = torch.broadcast_tensors(reference, produced)
     close = torch.isclose(reference, produced, atol=1e-6)
     assert close.all(), (
         reference[~close].tolist(),
         produced[~close].tolist(),
     )
Exemple #12
0
 def reset_at(self, index: int) -> EnvObsType:
     """Resample the initial state stored at position ``index``.

     The unnamed sample from ``self.module.init`` is written to
     ``self._curr_states[index]``. The sample is returned as a NumPy array
     cast to ``self.observation_space.dtype``.
     """
     sample, _ = self.module.init.sample()
     self._curr_states[index] = nt.unnamed(sample)
     as_array = sample.numpy()
     return as_array.astype(self.observation_space.dtype)
Exemple #13
0
 def _transition(loc: Tensor, scale_tril: Tensor, time: IntTensor) -> Tensor:
     """Sample a next state and pack it with the incremented time.

     Draws a reparameterized sample (``rsample``) from the multivariate
     normal defined by ``loc`` and ``scale_tril``. The sample is given the
     names of ``time`` and packed together with ``time + 1``.
     """
     mean, tril = nt.unnamed(loc, scale_tril)
     normal = torch.distributions.MultivariateNormal(loc=mean, scale_tril=tril)
     next_state = normal.rsample().refine_names(*time.names)
     return pack_obs(next_state, time + 1)
Exemple #14
0
 def forward(self, obs: Tensor, action: Tensor) -> TensorDict:
     """Normalize the state part of ``obs``, then delegate to ``self._model``.

     The unnamed state is flattened to ``(-1, self.n_state)`` for
     ``self.normalizer``, then restored to its original shape and names. It
     is repacked with the unchanged time and passed along with ``action``.
     """
     state, time = unpack_obs(obs)
     flat = nt.unnamed(state).reshape(-1, self.n_state)
     normalized = self.normalizer(flat).reshape_as(state)
     state = normalized.refine_names(*state.names)
     return self._model(pack_obs(state, time), action)
Exemple #15
0
def tensors_to_vector(tensors: Iterable[Tensor]) -> Tensor:
    """Reshape and combine tensors into a vector representation.

    Each tensor is stripped of names and flattened. The pieces are
    concatenated in iteration order and the result is wrapped with
    ``nt.vector``.
    """
    flat_parts = [nt.unnamed(t).reshape(-1) for t in tensors]
    return nt.vector(torch.cat(flat_parts))
Exemple #16
0
def obs(n_state: int, horizon: int, batch_shape: tuple[int, ...]) -> Tensor:
    """Random packed observation for the given batch shape.

    The state is standard normal with shape ``batch_shape + (n_state,)``.
    The time tensor is shaped like the state's first column and holds
    integer values in ``[0, horizon)``. Because ``randint_like`` keeps its
    input's dtype, those values are stored in the state's floating dtype.
    The packed result requires grad.
    """
    state = nt.vector(torch.randn(batch_shape + (n_state, )))
    first_column = nt.unnamed(state[..., :1])
    time = nt.vector(torch.randint_like(first_column, low=0, high=horizon))
    # noinspection PyTypeChecker
    packed = pack_obs(state, time)
    return packed.requires_grad_(True)