示例#1
0
    def forward(self, obs: Tensor, frozen: bool = False) -> Tensor:
        """Map an observation to an affine feedback action.

        The observation is split into its state and time parts. The gains
        ``(K, k)`` for that time step are fetched from ``self._gains_at``.
        The action is then computed as ``K @ state + k``.

        Args:
            obs: Observation tensor. It is named as a vector with
                ``nt.vector`` and unpacked into ``(state, time)`` by
                ``unpack_obs``.
            frozen: When True, both gains are detached, so no gradient flows
                back through them.

        Returns:
            The action vector. It is replaced by zeros wherever ``time``
            equals ``self.horizon``.
        """
        state, time = unpack_obs(nt.vector(obs))

        # noinspection PyTypeChecker
        gain, offset = self._gains_at(nt.vector_to_scalar(time))
        if frozen:
            # Cut the autograd graph at the gains.
            gain = gain.detach()
            offset = offset.detach()

        affine = gain @ nt.vector_to_matrix(state)
        affine = affine + nt.vector_to_matrix(offset)
        action = nt.matrix_to_vector(affine)

        # Observations at the final time step receive a null action.
        at_horizon = time.eq(self.horizon)
        return nt.where(at_horizon, torch.zeros_like(action), action)
示例#2
0
    def forward(self, obs: Tensor, action: Tensor):
        """Compute the parameters of the next-state distribution.

        The state and action are concatenated along the named ``"R"``
        dimension. That vector is passed through the time-indexed affine map
        ``F @ [state; action] + f`` returned by ``self._transition_factors``.
        When ``time`` equals ``self.horizon``, the state is returned unchanged
        as the location (absorbing terminal state).

        Args:
            obs: Observation tensor, unpacked into ``(state, time)`` by
                ``unpack_obs`` after being named as a vector.
            action: Action tensor, named as a vector with ``nt.vector``.

        Returns:
            A dict with keys ``"loc"``, ``"scale_tril"`` and ``"time"``.
            ``scale_tril`` is passed through as given by
            ``self._transition_factors``.
        """
        obs = nt.vector(obs)
        action = nt.vector(action)
        state, time = unpack_obs(obs)

        # Time-step-specific dynamics parameters.
        F, f, scale_tril = self._transition_factors(nt.vector_to_scalar(time))

        # Affine transition of the joint state-action vector.
        state_action = torch.cat([state, action], dim="R")
        mean = F @ nt.vector_to_matrix(state_action) + nt.vector_to_matrix(f)
        trans_loc = nt.matrix_to_vector(mean)

        # Absorbing states: once the horizon is reached, stay in place.
        loc = nt.where(time.eq(self.horizon), state, trans_loc)
        return {"loc": loc, "scale_tril": scale_tril, "time": time}
示例#3
0
def refine_cost_ouput(cost: QuadCost) -> QuadCost:
    """Attach named dimensions to a quadratic cost.

    The first component is tagged with ``nt.matrix``. The second is
    converted with ``nt.matrix_to_vector``. Both are then passed through
    ``nt.horizon`` (presumably adding the horizon dimension; confirm
    against ``nt``).

    Args:
        cost: Pair unpacked positionally as ``(C, c)``.

    Returns:
        A new ``QuadCost`` built from the refined tensors.
    """
    quad_term, linear_term = cost
    quad_term = nt.matrix(quad_term)
    linear_term = nt.matrix_to_vector(linear_term)
    quad_term, linear_term = nt.horizon(quad_term, linear_term)
    return QuadCost(quad_term, linear_term)
示例#4
0
def refine_quadratic_output(quadratic: Quadratic):
    """Attach named dimensions to a quadratic ``(A, b, c)`` triple.

    ``A`` is tagged with ``nt.matrix``. ``b`` is converted with
    ``nt.matrix_to_vector`` and ``c`` with ``nt.matrix_to_scalar``. All three
    then go through ``nt.horizon``.

    Args:
        quadratic: Triple unpacked positionally as ``(A, b, c)``.

    Returns:
        Tuple ``(A, b, c)`` of the refined tensors.
    """
    mat, vec, scalar = quadratic
    mat = nt.matrix(mat)
    vec = nt.matrix_to_vector(vec)
    scalar = nt.matrix_to_scalar(scalar)
    mat, vec, scalar = nt.horizon(mat, vec, scalar)
    return mat, vec, scalar
示例#5
0
def refine_linear_output(linear: Linear):
    """Attach named dimensions to a linear ``(K, k)`` pair.

    ``K`` is tagged with ``nt.matrix`` and ``k`` is converted with
    ``nt.matrix_to_vector``. The pair is then passed through ``nt.horizon``.

    Args:
        linear: Pair unpacked positionally as ``(K, k)``.

    Returns:
        Tuple ``(K, k)`` of the refined tensors.
    """
    gain, bias = linear
    gain = nt.matrix(gain)
    bias = nt.matrix_to_vector(bias)
    gain, bias = nt.horizon(gain, bias)
    return gain, bias