Пример #1
0
    def _gen_sample(self, loc: Tensor, scale_tril: Tensor, time: IntTensor) -> Tensor:
        """Produce the next observation, holding absorbing states in place.

        A candidate observation is computed by ``self._transition``. If
        ``self.horizon`` is falsy the candidate is returned as-is. Otherwise,
        wherever the current timestep equals the horizon, the candidate is
        replaced by ``pack_obs(loc, time)``.
        """
        sampled = self._transition(loc, scale_tril, time)
        if self.horizon:
            # Timesteps at the horizon are absorbing: the process does not
            # advance, so the observation is re-packed from loc and the
            # unchanged time (presumably loc equals the current state there).
            at_horizon = time.eq(self.horizon)
            return nt.where(at_horizon, pack_obs(loc, time), sampled)
        return sampled
Пример #2
0
 def _absorving_logp(
     cur_state: Tensor, cur_time: IntTensor, state: Tensor, time: IntTensor
 ) -> Tensor:
     """Log-density of the absorbing (point-mass) transition.

     The current ``(cur_state, cur_time)`` and queried ``(state, time)`` pairs
     are packed into observations. Where every entry along ``"R"`` is close,
     the log-density is 0 (all mass sits on that point). Elsewhere the value
     is NaN, marking it as undefined. The output is float32 and is built from
     tensors shaped like ``time``.
     """
     # Both time arguments are expected to be named scalar tensors. Lift them
     # to vectors so they can be concatenated with the states.
     here = pack_obs(cur_state, nt.scalar_to_vector(cur_time))
     there = pack_obs(state, nt.scalar_to_vector(time))
     same_point = nt.reduce_all(nt.isclose(here, there), dim="R")
     log_mass = torch.zeros_like(time, dtype=torch.float)
     undefined = torch.full_like(time, fill_value=float("nan"), dtype=torch.float)
     return nt.where(same_point, log_mass, undefined)
Пример #3
0
    def forward(self, obs: Tensor, act: Tensor) -> Tensor:
        """Quadratic reward of an observation-action pair.

        With ``tau = [state; act]`` and the time-indexed parameters
        ``(C, c)`` from ``self._index_parameters``, the reward is
        ``-(tau^T C tau / 2 + c^T tau)``. Wherever the timestep equals
        ``self.horizon``, the reward is zero.
        """
        obs = nt.vector(obs)
        act = nt.vector(act)
        state, time = unpack_obs(obs)
        tau = nt.vector_to_matrix(torch.cat([state, act], dim="R"))
        time = nt.vector_to_scalar(time)

        quad, linear = self._index_parameters(time)
        linear = nt.vector_to_matrix(linear)

        tau_t = nt.transpose(tau)
        cost = tau_t @ quad @ tau / 2 + nt.transpose(linear) @ tau
        reward = nt.matrix_to_scalar(cost.neg())
        terminal = time.eq(self.horizon)
        return nt.where(terminal, torch.zeros_like(reward), reward)
Пример #4
0
    def forward(self, obs: Tensor, frozen: bool = False) -> Tensor:
        """Compute the action vector for the observed state.

        The action is the affine feedback ``K @ state + k`` using the gains
        returned by ``self._gains_at`` for the observation's timestep. With
        ``frozen=True`` the gains are detached, so no gradient reaches them.
        Observations whose timestep equals ``self.horizon`` get an all-zero
        action.
        """
        state, time = unpack_obs(nt.vector(obs))

        # noinspection PyTypeChecker
        gain, offset = self._gains_at(nt.vector_to_scalar(time))
        if frozen:
            gain = gain.detach()
            offset = offset.detach()

        affine = gain @ nt.vector_to_matrix(state) + nt.vector_to_matrix(offset)
        action = nt.matrix_to_vector(affine)
        is_terminal = time.eq(self.horizon)
        return nt.where(is_terminal, torch.zeros_like(action), action)
Пример #5
0
 def forward(self, obs: Tensor, action: Tensor) -> Tensor:
     """Negated quadratic cost of an observation-action pair.

     The parameters ``(quad, linear, const)`` are looked up for the
     observation's timestep by ``index_quadratic_parameters``, with
     ``max_idx=self.horizon - 1``. With ``vec = [state; action]``, the result
     is ``-(vec^T quad vec / 2 + linear^T vec + const)``. Wherever the
     timestep equals ``self.horizon``, the result is zero.
     """
     state, time = unpack_obs(obs)
     time = nt.vector_to_scalar(time)
     # noinspection PyTypeChecker
     quad, linear, const = index_quadratic_parameters(
         self.quad, self.linear, self.const, time, max_idx=self.horizon - 1
     )

     vec = nt.vector_to_matrix(torch.cat([state, action], dim="R"))
     quadratic_term = nt.transpose(vec) @ quad @ vec / 2
     linear_term = nt.transpose(nt.vector_to_matrix(linear)) @ vec
     cost = nt.matrix_to_scalar(
         quadratic_term + linear_term + nt.scalar_to_matrix(const)
     )
     val = cost.neg()
     terminal = time.eq(self.horizon)
     return nt.where(terminal, torch.zeros_like(val), val)
Пример #6
0
    def forward(self, obs: Tensor, action: Tensor):
        """Parameters of the next-state distribution.

        Returns a dict with keys ``"loc"``, ``"scale_tril"`` and ``"time"``.
        ``loc`` is the affine transition ``F @ [state; action] + f``, using the
        factors from ``self._transition_factors`` for the observation's
        timestep. Where the timestep equals ``self.horizon``, the state is
        absorbing and ``loc`` is the current state. ``time`` is the time
        component unpacked from ``obs``.
        """
        obs = nt.vector(obs)
        action = nt.vector(action)
        state, time = unpack_obs(obs)

        # Parameters for each timestep in the batch
        dyn_mat, dyn_bias, scale_tril = self._transition_factors(
            nt.vector_to_scalar(time)
        )

        tau = nt.vector_to_matrix(torch.cat([state, action], dim="R"))
        mean = nt.matrix_to_vector(dyn_mat @ tau + nt.vector_to_matrix(dyn_bias))

        # Absorbing states: at the horizon the next state equals the current one.
        loc = nt.where(time.eq(self.horizon), state, mean)
        return dict(loc=loc, scale_tril=scale_tril, time=time)
Пример #7
0
 def _trans_logp(
     loc: Tensor,
     scale_tril: Tensor,
     cur_time: IntTensor,
     state: Tensor,
     time: IntTensor,
 ) -> Tensor:
     """Log-density of a Gaussian state transition.

     Evaluates ``state`` under ``MultivariateNormal(loc, scale_tril)``. The
     density is only defined where ``time == cur_time + 1``; every other entry
     is NaN. Names are stripped for the distribution computation, and the
     result is refined with ``time.names``.
     """
     loc, scale_tril = nt.unnamed(loc, scale_tril)
     dist = torch.distributions.MultivariateNormal(loc=loc, scale_tril=scale_tril)
     trans_logp: Tensor = dist.log_prob(nt.unnamed(state))
     # Build the NaN fill on the same device/dtype as the log-probs. A bare
     # torch.full() lands on the default device with the default dtype, so the
     # select below fails for CUDA inputs (and for float64 inputs on torch
     # versions without type promotion in `where`).
     undefined = torch.full(
         time.shape,
         fill_value=float("nan"),
         dtype=trans_logp.dtype,
         device=trans_logp.device,
     )
     trans_logp = nt.where(
         # Logp only defined at next timestep
         time.eq(cur_time + 1),
         trans_logp,
         undefined,
     )
     # We assume time is a named scalar tensor
     return trans_logp.refine_names(*time.names)
Пример #8
0
    def _logp(
        self, loc: Tensor, scale_tril: Tensor, time: Tensor, value: Tensor
    ) -> Tensor:
        """Log-density of ``value`` given the transition parameters.

        ``value`` is an observation, unpacked into a next state and next time.
        The Gaussian transition log-probability is always computed. When
        ``self.horizon`` is set, entries whose current ``time`` equals the
        horizon use the absorbing-state log-probability instead.
        """
        next_state, next_time = unpack_obs(value)
        loc, next_state = torch.broadcast_tensors(loc, next_state)
        time, next_time = torch.broadcast_tensors(time, next_time)
        # The helpers below expect scalar (not vector) time tensors.
        time, next_time = nt.vector_to_scalar(time, next_time)

        gaussian = self._trans_logp(loc, scale_tril, time, next_state, next_time)
        if not self.horizon:
            return gaussian

        # We're in an absorbing state if the current timestep is the horizon.
        absorbing = self._absorving_logp(loc, time, next_state, next_time)
        return nt.where(time.eq(self.horizon), absorbing, gaussian)
Пример #9
0
def mix_obs(obs: Tensor, last_obs: Tensor) -> Tensor:
    """Randomly select entries from ``obs`` or ``last_obs``.

    A uniform draw, shaped like the time component of ``obs``, selects
    ``obs`` where it is below 0.5 and ``last_obs`` otherwise. The mixed
    tensor calls ``retain_grad()`` so its gradient is kept after backward.
    torch raises here if the result does not require grad.
    """
    time = unpack_obs(obs)[1]
    coin = torch.rand_like(time.float())
    picked = nt.where(coin < 0.5, obs, last_obs)
    picked.retain_grad()
    return picked