def forward(self, obs: Tensor, frozen: bool = False) -> Tensor:
    """Compute the action vector for the observed state.

    The action is the affine map ``K @ state + k``. The gains ``(K, k)`` are
    looked up for the observation's timestep via ``self._gains_at``.

    Args:
        obs: observation tensor; ``unpack_obs`` splits it into state and time.
        frozen: if True, the gains are ``detach``-ed from the autograd graph,
            so no gradient flows back into them through this call.

    Returns:
        The action vector. Entries whose time equals ``self.horizon`` are
        replaced by zeros.
    """
    state, time = unpack_obs(nt.vector(obs))

    # noinspection PyTypeChecker
    gain, offset = self._gains_at(nt.vector_to_scalar(time))
    if frozen:
        gain = gain.detach()
        offset = offset.detach()

    # Do the affine map in matrix (column) form, then convert back to a vector.
    state_col = nt.vector_to_matrix(state)
    offset_col = nt.vector_to_matrix(offset)
    action = nt.matrix_to_vector(gain @ state_col + offset_col)

    # No meaningful action exists once the horizon is reached; emit zeros there.
    at_horizon = time.eq(self.horizon)
    return nt.where(at_horizon, torch.zeros_like(action), action)
def forward(self, obs: Tensor, action: Tensor):
    """Compute the parameters of the next-state distribution.

    The mean is the affine map ``F @ [state; action] + f``. Here ``(F, f)``
    and the Cholesky-style factor ``scale_tril`` come from
    ``self._transition_factors`` for the observation's timestep. When time
    equals ``self.horizon`` the state is absorbing: the mean is the current
    state itself. ``scale_tril`` is returned as produced, including at
    terminal steps.

    Args:
        obs: observation tensor; ``unpack_obs`` splits it into state and time.
        action: action tensor.

    Returns:
        Dict with keys ``"loc"``, ``"scale_tril"`` and ``"time"``.
    """
    obs = nt.vector(obs)
    action = nt.vector(action)
    state, time = unpack_obs(obs)

    # Per-timestep transition parameters
    dynamics, bias, scale_tril = self._transition_factors(nt.vector_to_scalar(time))

    # Stack state and action along the named dim "R" and apply the affine map
    state_action = nt.vector_to_matrix(torch.cat([state, action], dim="R"))
    next_mean = nt.matrix_to_vector(dynamics @ state_action + nt.vector_to_matrix(bias))

    # Absorbing states: at the horizon the mean stays at the current state
    at_horizon = time.eq(self.horizon)
    mean = nt.where(at_horizon, state, next_mean)

    return {"loc": mean, "scale_tril": scale_tril, "time": time}
def refine_cost_ouput(cost: QuadCost) -> QuadCost:
    """Convert a raw quadratic-cost output ``(C, c)`` into a refined ``QuadCost``.

    ``C`` is passed through ``nt.matrix``. ``c`` is passed through
    ``nt.matrix_to_vector`` (presumably it arrives as a column matrix;
    confirm against callers). Both tensors are then passed together to
    ``nt.horizon``, which presumably tags the horizon dimension; verify in
    ``nt``.

    Args:
        cost: pair ``(C, c)``.

    Returns:
        A new ``QuadCost`` built from the refined tensors.
    """
    C, c = cost
    C, c = nt.horizon(nt.matrix(C), nt.matrix_to_vector(c))
    return QuadCost(C, c)


# Correctly spelled alias, consistent with ``refine_quadratic_output`` and
# ``refine_linear_output``. The misspelled name above is kept so that
# existing callers continue to work.
refine_cost_output = refine_cost_ouput
def refine_quadratic_output(quadratic: Quadratic):
    """Refine a raw quadratic-function output ``(A, b, c)``.

    The transforms are applied in order:
    - ``A`` goes through ``nt.matrix``;
    - ``b`` goes through ``nt.matrix_to_vector``;
    - ``c`` goes through ``nt.matrix_to_scalar``.
    The three results are then passed together to ``nt.horizon``.

    Returns:
        Tuple ``(A, b, c)`` of the refined tensors.
    """
    raw_quad, raw_linear, raw_const = quadratic
    quad_mat = nt.matrix(raw_quad)
    linear_vec = nt.matrix_to_vector(raw_linear)
    const_scalar = nt.matrix_to_scalar(raw_const)
    quad_mat, linear_vec, const_scalar = nt.horizon(quad_mat, linear_vec, const_scalar)
    return quad_mat, linear_vec, const_scalar
def refine_linear_output(linear: Linear):
    """Refine a raw linear-function output ``(K, k)``.

    ``K`` goes through ``nt.matrix`` and ``k`` through
    ``nt.matrix_to_vector``. Both are then passed together to ``nt.horizon``.

    Returns:
        Tuple ``(K, k)`` of the refined tensors.
    """
    raw_gain, raw_offset = linear
    gain_mat = nt.matrix(raw_gain)
    offset_vec = nt.matrix_to_vector(raw_offset)
    gain_mat, offset_vec = nt.horizon(gain_mat, offset_vec)
    return gain_mat, offset_vec