Example #1
def make_gaussinit(
    state_size: int,
    n_batch: Optional[int] = None,
    sample_covariance: bool = False,
    rng: RNG = None,
) -> GaussInit:
    """Generate parameters for Gaussian initial state distribution.

    Args:
        state_size: size of state vector
        n_batch: batch size, if any
        sample_covariance: whether to sample a random SPD matrix for the
            Gaussian covariance or use the identity matrix.
        rng: random number generator, seed, or None
    """
    # pylint:disable=invalid-name
    vec_shape = (state_size, )
    batch_shape = () if n_batch is None else (n_batch, )

    mu = torch.zeros(batch_shape + vec_shape)
    if sample_covariance:
        sig = as_float_tensor(
            make_spd_matrix(state_size, sample_shape=batch_shape, rng=rng))
    else:
        sig = torch.eye(state_size)

    return GaussInit(
        mu=utils.expand_and_refine(nt.vector(mu), 1, n_batch=n_batch),
        sig=utils.expand_and_refine(nt.matrix(sig), 2, n_batch=n_batch),
    )
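A minimal usage sketch (hypothetical values; assumes GaussInit is a (mu, sig) named tuple and that expand_and_refine prepends the batch dimension):

init = make_gaussinit(state_size=3, n_batch=8, sample_covariance=True, rng=42)
# Expected shapes (by assumption): batched mean and batched SPD covariance
assert init.mu.shape == (8, 3)
assert init.sig.shape == (8, 3, 3)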
Example #2
def policy(n_state: int, n_ctrl: int, horizon: int) -> lqr.Linear:
    K = torch.empty(horizon, n_ctrl, n_state)
    k = torch.empty(horizon, n_ctrl)
    nn.init.xavier_uniform_(K)
    nn.init.constant_(k, 0)
    K, k = nt.horizon(nt.matrix(K), nt.vector(k))
    return K, k
Example #3
    def forward(self, obs: Tensor, action: Tensor):
        # pylint:disable=missing-function-docstring
        obs, action = nt.vector(obs), nt.vector(action)
        state, time = unpack_obs(obs)

        # Get parameters for each timestep
        index = nt.vector_to_scalar(time)
        F, f, scale_tril = self._transition_factors(index)

        # Compute the loc for normal transitions
        tau = nt.vector_to_matrix(torch.cat([state, action], dim="R"))
        trans_loc = nt.matrix_to_vector(F @ tau + nt.vector_to_matrix(f))

        # Treat absorbing states if necessary
        terminal = time.eq(self.horizon)
        loc = nt.where(terminal, state, trans_loc)
        return {"loc": loc, "scale_tril": scale_tril, "time": time}
Example #4
    def __init__(self, n_state: int):
        super().__init__()
        self.n_state = n_state

        self.loc = nn.Parameter(Tensor(n_state))
        self.scale_tril = CholeskyFactor((n_state, n_state))
        self.register_buffer("time",
                             -nt.vector(torch.ones(1, dtype=torch.int)))
        self.reset_parameters()

        self.dist = TVMultivariateNormal()
Example #5
    def step(self, action: Act) -> Tuple[Obs, Rew, Done, Info]:
        state = self._curr_state
        action = torch.as_tensor(action, dtype=torch.float32)
        action = nt.vector(action)

        reward = self.module.reward(state, action)
        next_state, _ = self.module.trans.sample(self.module.trans(state, action))

        self._curr_state = next_state
        done = next_state[-1].long() == self.horizon
        return self._get_obs(), reward.item(), done.item(), {}
Example #6
    def _gains_at(
        self, index: Union[IntTensor, LongTensor, None] = None
    ) -> tuple[Tensor, Tensor]:
        K, k = nt.horizon(nt.matrix(self.K), nt.vector(self.k))
        if index is not None:
            index = torch.clamp(index, max=self.horizon - 1)
            # Assumes index is a named scalar tensor
            # noinspection PyTypeChecker
            K, k = (nt.index_by(x, dim="H", index=index) for x in (K, k))
        return K, k
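Clamping keeps out-of-range queries valid: any index at or past the horizon resolves to the last set of gains. A plain-torch sketch with hypothetical sizes:

import torch

horizon, n_ctrl, n_state = 4, 2, 3
K = torch.randn(horizon, n_ctrl, n_state)
index = torch.tensor(10)                   # query past the end of the horizon
idx = torch.clamp(index, max=horizon - 1)  # -> 3, the final timestep
K_t = K[idx]                               # gains reused at and beyond the last step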
Example #7
def last_obs(
    n_state: int,
    horizon: int,
    batch_shape: tuple[int, ...],
    batch_names: tuple[str, ...],
) -> Tensor:
    state = nt.vector(torch.randn(batch_shape + (n_state, ))).refine_names(
        *batch_names, ...)
    dummy, _ = nt.split(state, [1, n_state - 1], dim="R")
    time = torch.full_like(dummy, fill_value=horizon).int()
    return pack_obs(state, time).requires_grad_()
Example #8
def obs(
    n_state: int,
    horizon: int,
    batch_shape: tuple[int, ...],
    batch_names: tuple[str, ...],
) -> Tensor:
    state = nt.vector(torch.randn(batch_shape + (n_state, ))).refine_names(
        *batch_names, ...)
    dummy, _ = nt.split(state, [1, n_state - 1], dim="R")
    time = torch.randint_like(nt.unnamed(dummy), low=0, high=horizon)
    time = time.refine_names(*dummy.names).int()
    return pack_obs(state, time).requires_grad_()
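In both fixtures above, nt.split carves a single component off the state purely to obtain a tensor with the right batch shape and dimension names for the time index. A plain-torch sketch of the same trick (assuming the last dimension is the split axis):

import torch

state = torch.randn(8, 3)
dummy = state[..., :1]                             # template with shape (8, 1)
time = torch.full_like(dummy, fill_value=5).int()  # horizon-valued timestep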
Example #9
def make_quadcost(
    state_size: int,
    ctrl_size: int,
    horizon: int,
    stationary: bool = True,
    n_batch: Optional[int] = None,
    linear: bool = False,
    cross_terms: bool = False,
    rng: RNG = None,
) -> QuadCost:
    """Generate quadratic cost parameters.

    Args:
        state_size: size of state vector
        ctrl_size: size of control vector
        horizon: length of the horizon
        stationary: whether dynamics vary with time
        n_batch: batch size, if any
        linear: whether to include a linear term in addition to the quadratic
        cross_terms: whether to include state-ctrl cross terms in the quadratic
            (C_sa and C_as)
        rng: random number generator, seed, or None
    """
    # pylint:disable=too-many-arguments,too-many-locals
    rng = np.random.default_rng(rng)
    n_tau = state_size + ctrl_size

    kwargs = dict(horizon=horizon,
                  stationary=stationary,
                  n_batch=n_batch,
                  rng=rng)

    C = utils.random_spd_matrix(n_tau, **kwargs)
    C_s, C_a = nt.split(C, [state_size, ctrl_size], dim="C")
    C_ss, C_sa = nt.split(C_s, [state_size, ctrl_size], dim="R")
    C_as, C_aa = nt.split(C_a, [state_size, ctrl_size], dim="R")

    if not cross_terms:
        C_sa, C_as = torch.zeros_like(C_sa), torch.zeros_like(C_as)

    C_s = torch.cat((C_ss, C_sa), dim="R")
    C_a = torch.cat((C_as, C_aa), dim="R")
    C = torch.cat((C_s, C_a), dim="C")

    if linear:
        c = utils.random_normal_vector(n_tau, **kwargs)
    else:
        c = utils.expand_and_refine(nt.vector(torch.zeros(n_tau)),
                                    1,
                                    horizon=horizon,
                                    n_batch=n_batch)
    return QuadCost(C, c)
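Zeroing the off-diagonal blocks of an SPD matrix leaves it block-diagonal (and still SPD). A plain-torch sketch of the split-zero-reassemble step, with hypothetical sizes and unnamed dims:

import torch

state_size, ctrl_size = 2, 1
C = torch.eye(state_size + ctrl_size) + 0.1              # small SPD example
C_s, C_a = C.split([state_size, ctrl_size], dim=-1)      # column blocks
C_ss, C_sa = C_s.split([state_size, ctrl_size], dim=-2)  # row blocks
C_as, C_aa = C_a.split([state_size, ctrl_size], dim=-2)
C_sa, C_as = torch.zeros_like(C_sa), torch.zeros_like(C_as)
C = torch.cat(
    (torch.cat((C_ss, C_sa), dim=-2), torch.cat((C_as, C_aa), dim=-2)),
    dim=-1,
)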
Example #10
    def _transition_factors(
        self, index: Optional[IntTensor] = None
    ) -> tuple[Tensor, Tensor, Tensor]:
        F, f, L = nt.horizon(nt.matrix(self.F), nt.vector(self.f),
                             self.scale_tril())
        if index is not None:
            if self.stationary:
                idx = torch.zeros_like(index)
            else:
                # Timesteps after termination use last parameters
                idx = torch.clamp(index, max=self.horizon - 1).int()
            F, f, L = (nt.index_by(x, dim="H", index=idx) for x in (F, f, L))
        return F, f, L
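The stationary branch collapses every timestep onto the single shared parameter set, while the non-stationary branch clamps to the last set as in _gains_at above. A plain-torch sketch of the two branches:

import torch

horizon, stationary = 5, True
index = torch.tensor([0, 3, 7])
# Stationary models share one parameter set; others reuse the final one
idx = torch.zeros_like(index) if stationary else torch.clamp(index, max=horizon - 1)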
Example #11
    def forward(self, obs: Tensor, act: Tensor) -> Tensor:
        obs, act = (nt.vector(x) for x in (obs, act))
        state, time = unpack_obs(obs)
        tau = nt.vector_to_matrix(torch.cat([state, act], dim="R"))
        time = nt.vector_to_scalar(time)

        C, c = self._index_parameters(time)
        c = nt.vector_to_matrix(c)

        cost = nt.transpose(tau) @ C @ tau / 2 + nt.transpose(c) @ tau
        reward = nt.matrix_to_scalar(cost.neg())
        return nt.where(time.eq(self.horizon), torch.zeros_like(reward),
                        reward)
Example #12
    def forward(self, obs: Tensor, frozen: bool = False) -> Tensor:
        """Compute the action vector for the observed state."""
        obs = nt.vector(obs)
        state, time = unpack_obs(obs)

        # noinspection PyTypeChecker
        K, k = self._gains_at(nt.vector_to_scalar(time))
        if frozen:
            K, k = K.detach(), k.detach()

        ctrl = K @ nt.vector_to_matrix(state) + nt.vector_to_matrix(k)
        ctrl = nt.matrix_to_vector(ctrl)
        # Return zeroed actions if in terminal state
        terminal = time.eq(self.horizon)
        return nt.where(terminal, torch.zeros_like(ctrl), ctrl)
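The control is the time-varying affine map u = K_t s + k_t, computed through matrix views of the vectors. A plain-torch sketch for a single timestep (hypothetical sizes):

import torch

n_state, n_ctrl = 3, 2
K_t = torch.randn(n_ctrl, n_state)
k_t = torch.randn(n_ctrl)
state = torch.randn(n_state)
ctrl = K_t @ state.unsqueeze(-1) + k_t.unsqueeze(-1)  # column-matrix view
ctrl = ctrl.squeeze(-1)                               # back to a vector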
Example #13
def index_quadratic_parameters(
    quad: nn.Parameter,
    linear: nn.Parameter,
    const: nn.Parameter,
    index: IntTensor,
    max_idx: int,
) -> tuple[Tensor, Tensor, Tensor]:
    # pylint:disable=missing-function-docstring
    quad, linear, const = nt.horizon(nt.matrix(quad), nt.vector(linear),
                                     nt.scalar(const))

    index = torch.clamp(index, max=max_idx)
    quad, linear, const = map(lambda x: nt.index_by(x, dim="H", index=index),
                              (quad, linear, const))
    return quad, linear, const
Example #14
    def test_call(self, module: QuadraticReward, obs: Tensor, act: Tensor):
        val = module(obs, act)
        assert torch.is_tensor(val)
        assert torch.isfinite(val).all()

        val.sum().backward()
        assert obs.grad is not None and act.grad is not None

        s_grad, t_grad = unpack_obs(nt.vector(obs.grad))
        assert not nt.allclose(s_grad, torch.zeros_like(s_grad))
        assert torch.isfinite(s_grad).all()
        assert nt.allclose(t_grad, torch.zeros_like(t_grad))

        assert not nt.allclose(act.grad, torch.zeros_like(act))
        assert torch.isfinite(act.grad).all()
Example #15
def refine_lqr(dynamics: LinDynamics,
               cost: QuadCost) -> Tuple[LinDynamics, QuadCost]:
    """Add dimension names to LQR parameters.

    Args:
        dynamics: transition matrix and vector
        cost: quadratic cost matrix and vector

    Returns:
        A tuple with named dynamics and cost parameters
    """
    F, f = dynamics
    C, c = cost
    F, C = nt.matrix(F, C)
    f, c = nt.vector(f, c)
    F, f, C, c = nt.horizon(F, f, C, c)
    return LinDynamics(F, f), QuadCost(C, c)
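A usage sketch with hypothetical shapes, assuming F, f, C, and c carry a leading horizon dimension that nt.horizon names "H":

horizon, n_state, n_ctrl = 10, 3, 2
n_tau = n_state + n_ctrl
F, f = torch.randn(horizon, n_state, n_tau), torch.randn(horizon, n_state)
C, c = torch.randn(horizon, n_tau, n_tau), torch.randn(horizon, n_tau)
dynamics, cost = refine_lqr(LinDynamics(F, f), QuadCost(C, c))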
Example #16
    def vector_step(
        self, actions: List[EnvActionType]
    ) -> Tuple[List[EnvObsType], List[float], List[bool], List[EnvInfoDict]]:
        states = self._curr_states
        actions = np.vstack(actions).astype(self.action_space.dtype)
        actions = torch.from_numpy(actions)
        actions = nt.vector(actions)

        rewards = self.module.reward(states, actions)
        next_states, _ = self.module.trans.sample(self.module.trans(states, actions))
        dones = next_states[..., -1].long() == self.horizon
        self._curr_states = next_states

        obs = self._get_obs(self._curr_states)
        rewards = rewards.numpy().tolist()
        dones = dones.numpy().tolist()
        infos = [{} for _ in range(self.num_envs)]
        return obs, rewards, dones, infos
Example #17
def random_normal_vector(
    size: int,
    horizon: int,
    stationary: bool = False,
    n_batch: Optional[int] = None,
    rng: RNG = None,
) -> Tensor:
    # pylint:disable=missing-function-docstring
    rng = np.random.default_rng(rng)

    vec_shape = (size,)
    shape = (
        minimal_sample_shape(horizon, stationary=stationary, n_batch=n_batch)
        + vec_shape
    )
    vec = nt.vector(as_float_tensor(rng.normal(size=shape)))
    vec = expand_and_refine(vec, 1, horizon=horizon, n_batch=n_batch)
    return vec
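A usage sketch; the exact layout comes from minimal_sample_shape, and the assumption here is horizon first, then batch:

vec = random_normal_vector(size=3, horizon=10, stationary=False, n_batch=8, rng=0)
assert vec.shape == (10, 8, 3)  # (horizon, batch, size), by assumption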
Example #18
    def linear(self, n_state: int, n_ctrl: int, horizon: int) -> lqr.Linear:
        K = torch.randn(horizon, n_ctrl, n_state)
        k = torch.randn(horizon, n_ctrl)
        K, k = nt.horizon(nt.matrix(K), nt.vector(k))
        return K, k
Example #19
def policy(n_state: int, n_ctrl: int, horizon: int) -> Linear:
    K = torch.rand((horizon, n_ctrl, n_state))
    k = torch.rand((horizon, n_ctrl))
    K, k = nt.horizon(nt.matrix(K), nt.vector(k))
    return K, k
Example #20
    def _refined_parameters(self) -> tuple[Tensor, Tensor]:
        C, c = nt.horizon(nt.matrix(self.C), nt.vector(self.c))
        return C, c
Example #21
    def standard_form(self) -> lqr.GaussInit:
        # pylint:disable=missing-function-docstring
        loc = nt.vector(self.loc)
        scale_tril = self.scale_tril()
        sigma = scale_tril @ nt.transpose(scale_tril)
        return lqr.GaussInit(loc, sigma)
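standard_form rebuilds the covariance from its Cholesky factor, sigma = L @ L.T; a plain-torch check of that identity:

import torch

L = torch.linalg.cholesky(2.0 * torch.eye(3))  # lower-triangular factor
sigma = L @ L.mT                               # recovers the covariance
assert torch.allclose(sigma, 2.0 * torch.eye(3))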
Example #22
    def obs(self, state: Tensor, horizon: int,
            batch_shape: tuple[int, ...]) -> Tensor:
        time = torch.randint(low=0, high=horizon, size=batch_shape + (1,))
        return pack_obs(state, nt.vector(time)).requires_grad_(True)
Example #23
def obs(state: Tensor, batch_shape: tuple[int, ...],
        batch_names: tuple[str, ...]) -> Tensor:
    time = nt.vector(torch.zeros(batch_shape + (1, )).int())
    return pack_obs(state, time).refine_names(*batch_names,
                                              ...).requires_grad_(True)
Example #24
    def gains(self) -> lqr.Linear:
        """Return current parameters as linear parameters."""
        K, k = nt.horizon(nt.matrix(self.K), nt.vector(self.k))
        K.grad, k.grad = self.K.grad, self.k.grad
        return K, k
Example #25
    def last_obs(
        self, state: Tensor, horizon: int, batch_shape: tuple[int, ...]
    ) -> Tensor:
        time = torch.full(batch_shape + (1,), fill_value=horizon, dtype=torch.int)
        return pack_obs(state, nt.vector(time)).requires_grad_(True)
Example #26
    def act(self, n_ctrl: int, batch_shape: tuple[int, ...]) -> Tensor:
        return nt.vector(torch.randn(batch_shape + (n_ctrl,))).requires_grad_(True)
Example #27
def state(dim: int, batch_shape: tuple[int, ...]) -> Tensor:
    return nt.vector(torch.randn(batch_shape + (dim, )))
Example #28
    def forward(self) -> DistParams:
        # pylint:disable=missing-function-docstring
        loc = nt.vector(self.loc)
        return {"loc": loc, "scale_tril": self.scale_tril(), "time": self.time}
Example #29
def init(dim: int) -> lqr.GaussInit:
    return lqr.GaussInit(nt.vector(torch.randn(dim)),
                         nt.matrix(torch.eye(dim)))
Example #30
    def log_prob(self, value: Tensor) -> Tensor:
        value = nt.vector(value)
        params = self()
        return self.dist.log_prob(value, params)