Пример #1
0
 def _absorving_logp(
     cur_state: Tensor, cur_time: IntTensor, state: Tensor, time: IntTensor
 ) -> Tensor:
     """Log-probability of a point-mass (absorbing) transition.

     Returns 0 where the packed observations coincide exactly and NaN
     everywhere else, i.e. all probability mass sits on the same state.
     """
     # Times are assumed to be named scalar tensors; lift them to vectors
     packed_cur = pack_obs(cur_state, nt.scalar_to_vector(cur_time))
     packed = pack_obs(state, nt.scalar_to_vector(time))
     same_obs = nt.reduce_all(nt.isclose(packed_cur, packed), dim="R")
     zero = torch.zeros_like(time, dtype=torch.float)
     nan = torch.full_like(time, fill_value=float("nan"), dtype=torch.float)
     # Point mass only at the same state
     return nt.where(same_obs, zero, nan)
Пример #2
0
    def _gen_sample(self, loc: Tensor, scale_tril: Tensor, time: IntTensor) -> Tensor:
        """Sample the next observation, freezing absorbing states.

        When a finite horizon is set, observations whose timestep already
        equals the horizon are treated as absorbing and kept unchanged.
        """
        next_obs = self._transition(loc, scale_tril, time)
        if not self.horizon:
            # No finite horizon: nothing to filter
            return next_obs

        # A state is absorbing once the timestep reaches the horizon;
        # there we keep the current (loc, time) pair instead of transitioning.
        at_horizon = time.eq(self.horizon)
        frozen = pack_obs(loc, time)
        return nt.where(at_horizon, frozen, next_obs)
Пример #3
0
def last_obs(
    n_state: int,
    horizon: int,
    batch_shape: tuple[int, ...],
    batch_names: tuple[str, ...],
) -> Tensor:
    """Random named observation batch with every time index at the horizon."""
    raw = torch.randn(batch_shape + (n_state,))
    state = nt.vector(raw).refine_names(*batch_names, ...)
    # Borrow a size-1 slice along "R" purely as a shape template for time
    template, _ = nt.split(state, [1, n_state - 1], dim="R")
    time = torch.full_like(template, fill_value=horizon).int()
    return pack_obs(state, time).requires_grad_()
Пример #4
0
def obs(
    n_state: int,
    horizon: int,
    batch_shape: tuple[int, ...],
    batch_names: tuple[str, ...],
) -> Tensor:
    """Random named observation batch with times uniform in [0, horizon)."""
    state = nt.vector(torch.randn(batch_shape + (n_state,)))
    state = state.refine_names(*batch_names, ...)
    # Size-1 slice along "R" serves as the shape template for the time tensor
    template, _ = nt.split(state, [1, n_state - 1], dim="R")
    # randint_like rejects named tensors; strip the names, then restore them
    time = torch.randint_like(nt.unnamed(template), low=0, high=horizon)
    time = time.refine_names(*template.names).int()
    return pack_obs(state, time).requires_grad_()
Пример #5
0
 def last_obs(
     self, state: Tensor, horizon: int, batch_shape: tuple[int, ...]
 ) -> Tensor:
     """Observation fixture whose time index equals the horizon everywhere."""
     shape = batch_shape + (1,)
     time = torch.full(shape, fill_value=horizon, dtype=torch.int)
     packed = pack_obs(state, nt.vector(time))
     return packed.requires_grad_(True)
Пример #6
0
 def obs(self, state: Tensor, horizon: int, batch_shape: tuple[int, ...]) -> Tensor:
     """Observation fixture with times drawn uniformly from [0, horizon)."""
     raw_time = torch.randint(low=0, high=horizon, size=batch_shape + (1,))
     packed = pack_obs(state, nt.vector(raw_time))
     return packed.requires_grad_(True)
Пример #7
0
 def rsample(
     self, params: TensorDict, sample_shape: tuple[int, ...] = ()) -> SampleLogp:
     """Reparameterized sample of the next observation.

     The wrapped distribution models the state *residual* (delta); the
     sampled delta is added back onto the current state before re-packing.

     Args:
         params: distribution parameters; must contain the current "state".
         sample_shape: extra leading sample dimensions.

     Returns:
         Tuple of the next observation and its log-probability.
     """
     # Fixed annotation: the default `()` is a tuple, not a list, so the
     # parameter is annotated tuple[int, ...] to match the immutable default.
     residual, log_prob = self.dist.rsample(params, sample_shape)
     delta, time = unpack_obs(residual)
     next_obs = pack_obs(params["state"] + delta, time)
     return next_obs, log_prob
Пример #8
0
 def forward(self, obs: Tensor, action: Tensor) -> TensorDict:
     """Normalize the state portion of the observation, then run the model."""
     state, time = unpack_obs(obs)
     # The normalizer works on unnamed 2D input: flatten, normalize, restore
     flat = nt.unnamed(state).reshape(-1, self.n_state)
     normed = self.normalizer(flat).reshape_as(state).refine_names(*state.names)
     return self._model(pack_obs(normed, time), action)
Пример #9
0
 def deterministic(self, params: TensorDict) -> SampleLogp:
     """Deterministic next observation via the residual distribution."""
     residual, logp = self.dist.deterministic(params)
     # The distribution predicts a state delta; shift it by the current state
     state_delta, timestep = unpack_obs(residual)
     next_obs = pack_obs(params["state"] + state_delta, timestep)
     return next_obs, logp
Пример #10
0
 def pre_terminal(self, state: Tensor, horizon: int,
                  n_batch: int) -> Tensor:
     """Observation batch fixed one step before the horizon."""
     time = torch.full((n_batch, 1), fill_value=horizon - 1, dtype=torch.int)
     named_time = nt.vector(time).refine_names("B", ...)
     return pack_obs(state, named_time)
Пример #11
0
 def icdf(self, prob, params: TensorDict) -> Tensor:
     """Inverse CDF: map probabilities to next observations."""
     # Residual-space quantile, shifted back by the current state
     delta, timestep = unpack_obs(self.dist.icdf(prob, params))
     return pack_obs(params["state"] + delta, timestep)
Пример #12
0
 def cdf(self, next_obs: Tensor, params: TensorDict) -> Tensor:
     """CDF of the next observation under the residual distribution."""
     next_state, timestep = unpack_obs(next_obs)
     # Convert the absolute next state into a residual before querying
     delta = next_state - params["state"]
     return self.dist.cdf(pack_obs(delta, timestep), params)
Пример #13
0
def new_obs(obs: Tensor) -> Tensor:
    """Fresh random observation one timestep after ``obs``."""
    state, time = unpack_obs(obs)
    # Randomize the state while advancing the time index by one
    return pack_obs(torch.randn_like(state), time + 1).requires_grad_()
Пример #14
0
def last_obs(n_state: int, horizon: int,
             batch_shape: tuple[int, ...]) -> Tensor:
    """Random observation batch with every time index set to the horizon."""
    state = nt.vector(torch.randn(batch_shape + (n_state,)))
    # Reuse a single state column as the shape template for the time tensor
    time = nt.vector(torch.full_like(state[..., :1], fill_value=horizon))
    # noinspection PyTypeChecker
    return pack_obs(state, time).requires_grad_(True)
Пример #15
0
def obs(n_state: int, horizon: int, batch_shape: tuple[int, ...]) -> Tensor:
    """Random observation batch with times drawn uniformly from [0, horizon)."""
    state = nt.vector(torch.randn(batch_shape + (n_state,)))
    # randint_like rejects named tensors, so strip the names for the draw
    unnamed_slot = nt.unnamed(state[..., :1])
    time = nt.vector(torch.randint_like(unnamed_slot, low=0, high=horizon))
    # noinspection PyTypeChecker
    return pack_obs(state, time).requires_grad_(True)
Пример #16
0
 def reproduce(self, next_obs, params: TensorDict) -> SampleLogp:
     """Reproduce ``next_obs`` through the residual distribution."""
     next_state, timestep = unpack_obs(next_obs)
     # Work in residual space: subtract the current state first
     target_residual = pack_obs(next_state - params["state"], timestep)
     reproduced, logp = self.dist.reproduce(target_residual, params)
     # Shift the reproduced residual back into an absolute observation
     delta, new_timestep = unpack_obs(reproduced)
     return pack_obs(params["state"] + delta, new_timestep), logp
Пример #17
0
 def _transition(loc: Tensor, scale_tril: Tensor, time: IntTensor) -> Tensor:
     """Sample the next state from N(loc, scale_tril) and advance time by one."""
     # MultivariateNormal does not accept named tensors
     unnamed_loc, unnamed_tril = nt.unnamed(loc, scale_tril)
     dist = torch.distributions.MultivariateNormal(
         loc=unnamed_loc, scale_tril=unnamed_tril
     )
     next_state = dist.rsample().refine_names(*time.names)
     return pack_obs(next_state, time + 1)
Пример #18
0
 def obs(self, state: Tensor, horizon: int, n_batch: int) -> Tensor:
     """Observation batch with uniformly random times in [0, horizon)."""
     raw_time = torch.randint(low=0, high=horizon, size=(n_batch, 1))
     return pack_obs(state, nt.vector(raw_time).refine_names("B", ...))