Beispiel #1
0
    def test_absorving(self, module: StochasticModel, last_obs: Tensor,
                       act: Tensor):
        """Check `rsample` outputs and gradients when fed `last_obs`.

        The sampled observation must reproduce the state and time of the
        input. Its summed gradient is one w.r.t. the state entries and zero
        w.r.t. the time entries and the action. The returned log-probability
        must be zero, with zero gradients for both inputs.
        """
        dist_params = module(last_obs, act)
        next_obs, log_prob = module.rsample(dist_params)

        # The sample mirrors the input observation
        assert next_obs.shape == last_obs.shape
        assert next_obs.names == last_obs.names
        in_state, in_time = unpack_obs(last_obs)
        out_state, out_time = unpack_obs(next_obs)
        assert nt.allclose(in_state, out_state)
        assert in_time.eq(out_time).all()

        # Gradient of the summed sample: ones for the state, zeros for time
        assert next_obs.grad_fn is not None
        next_obs.sum().backward(retain_graph=True)
        assert last_obs.grad is not None
        state_grad = torch.ones_like(in_state)
        time_grad = torch.zeros_like(in_time)
        expected_grad = torch.cat([state_grad, time_grad], dim="R")
        assert nt.allclose(last_obs.grad, expected_grad)
        assert nt.allclose(act.grad, torch.zeros(()))

        # Clear accumulated gradients before differentiating the log-prob
        last_obs.grad = None
        act.grad = None
        expected_shape = tuple(
            size for size, name in zip(last_obs.shape, last_obs.names)
            if name != "R")
        assert log_prob.shape == expected_shape
        expected_names = tuple(name for name in last_obs.names if name != "R")
        assert log_prob.names == expected_names
        assert nt.allclose(log_prob, torch.zeros(()))
        log_prob.sum().backward()
        assert nt.allclose(last_obs.grad, torch.zeros(()))
        assert nt.allclose(act.grad, torch.zeros(()))
Beispiel #2
0
    def test_log_prob(self, module: StochasticModel, obs: Tensor, act: Tensor,
                      new_obs: Tensor):
        """Check `log_prob` output and its gradients.

        The log-probability of `new_obs` under the parameters produced from
        `(obs, act)` must be a finite tensor whose shape and names match the
        scalar-converted time components of both observations. Backpropagating
        its sum must yield non-zero gradients for `obs`, `act` and every
        module parameter.
        """
        params = module(obs, act)
        log_prob = module.log_prob(new_obs, params)
        _, time = unpack_obs(obs)
        _, time_ = unpack_obs(new_obs)
        time, time_ = nt.vector_to_scalar(time, time_)

        assert torch.is_tensor(log_prob)
        assert torch.isfinite(log_prob).all()
        assert log_prob.shape == time.shape == time_.shape
        assert log_prob.names == time.names == time_.names

        assert log_prob.grad_fn is not None
        log_prob.sum().backward()
        assert obs.grad is not None
        assert act.grad is not None
        assert not nt.allclose(obs.grad, torch.zeros(()))
        assert not nt.allclose(act.grad, torch.zeros(()))
        # Materialize once: the gradients are inspected twice below
        grads = [p.grad for p in module.parameters()]
        # Generators are enough for `all`; no intermediate lists needed
        assert all(g is not None for g in grads)
        assert all(not torch.allclose(g, torch.zeros(())) for g in grads)
Beispiel #3
0
    def forward(self, obs: Tensor, act: Tensor) -> Tensor:
        """Return the reward for an observation-action pair.

        The state and action are concatenated along the "R" dimension into
        `tau`. The reward is the negation of
        `transpose(tau) @ C @ tau / 2 + transpose(c) @ tau`, where `C` and `c`
        are looked up for the observation's time. Entries whose time equals
        `self.horizon` get a reward of zero.
        """
        obs = nt.vector(obs)
        act = nt.vector(act)
        state, time = unpack_obs(obs)
        state_action = torch.cat([state, act], dim="R")
        tau = nt.vector_to_matrix(state_action)
        time = nt.vector_to_scalar(time)

        C, c = self._index_parameters(time)
        c = nt.vector_to_matrix(c)

        quadratic_term = nt.transpose(tau) @ C @ tau / 2
        linear_term = nt.transpose(c) @ tau
        reward = nt.matrix_to_scalar((quadratic_term + linear_term).neg())
        # No reward is given once the horizon has been reached
        at_horizon = time.eq(self.horizon)
        return nt.where(at_horizon, torch.zeros_like(reward), reward)
Beispiel #4
0
 def forward(self, obs: Tensor) -> Tensor:
     """Return the negated quadratic cost of the observed state.

     Quadratic, linear and constant coefficients are selected for the
     observation's time via `index_quadratic_parameters`, with
     `max_idx=self.horizon`.
     """
     state, time = unpack_obs(obs)
     timestep = nt.vector_to_scalar(time)
     quad, linear, const = index_quadratic_parameters(
         self.quad, self.linear, self.const, timestep, max_idx=self.horizon)

     state_col = nt.vector_to_matrix(state)
     quad_term = nt.transpose(state_col) @ quad @ state_col / 2
     linear_term = nt.transpose(nt.vector_to_matrix(linear)) @ state_col
     const_term = nt.scalar_to_matrix(const)
     cost = nt.matrix_to_scalar(quad_term + linear_term + const_term)
     return cost.neg()
Beispiel #5
0
    def forward(self, obs: Tensor, frozen: bool = False) -> Tensor:
        """Compute the action vector for the observed state.

        The action is `K @ state + k`, with the gains `K` and `k` looked up
        for the observation's time. When `frozen` is true the gains are
        detached before use.
        """
        state, time = unpack_obs(nt.vector(obs))

        # noinspection PyTypeChecker
        K, k = self._gains_at(nt.vector_to_scalar(time))
        if frozen:
            K = K.detach()
            k = k.detach()

        feedback = K @ nt.vector_to_matrix(state)
        action = nt.matrix_to_vector(feedback + nt.vector_to_matrix(k))
        # Terminal observations (time equal to the horizon) get zeroed actions
        at_horizon = time.eq(self.horizon)
        return nt.where(at_horizon, torch.zeros_like(action), action)
Beispiel #6
0
    def test_call(self, module: QuadraticReward, obs: Tensor, act: Tensor):
        """Check the module output is finite and differentiable.

        After backpropagating the summed output, the gradient w.r.t. the
        state and the action must be finite and non-zero, while the gradient
        w.r.t. the time component of the observation must be zero.
        """
        reward = module(obs, act)
        assert torch.is_tensor(reward)
        assert torch.isfinite(reward).all()

        reward.sum().backward()
        assert obs.grad is not None
        assert act.grad is not None

        # Split the observation gradient into its state and time parts
        state_grad, time_grad = unpack_obs(nt.vector(obs.grad))
        assert not nt.allclose(state_grad, torch.zeros_like(state_grad))
        assert torch.isfinite(state_grad).all()
        assert nt.allclose(time_grad, torch.zeros_like(time_grad))

        act_grad = act.grad
        assert not nt.allclose(act_grad, torch.zeros_like(act))
        assert torch.isfinite(act_grad).all()
Beispiel #7
0
    def test_rsample(self, module: StochasticModel, obs: Tensor, act: Tensor):
        """Check `rsample` outputs and that gradients reach the inputs.

        The sample must share the shape and names of `obs` and have its time
        advanced by one. Both the sample and its log-probability must
        backpropagate to `obs` and `act`. The log-probability has the shape
        and names of `obs` with the "R" dimension removed.
        """
        dist_params = module(obs, act)
        next_obs, log_prob = module.rsample(dist_params)

        assert next_obs.shape == obs.shape
        assert next_obs.names == obs.names
        _, cur_time = unpack_obs(obs)
        _, next_time = unpack_obs(next_obs)
        assert cur_time.eq(next_time - 1).all()

        # Gradients flow from the sample to both inputs
        assert next_obs.grad_fn is not None
        next_obs.sum().backward(retain_graph=True)
        assert obs.grad is not None
        assert act.grad is not None

        expected_shape = tuple(
            size for size, name in zip(obs.shape, obs.names) if name != "R")
        assert log_prob.shape == expected_shape
        expected_names = tuple(name for name in obs.names if name != "R")
        assert log_prob.names == expected_names

        # Clear gradients so the log-prob backward pass is checked alone
        obs.grad = None
        act.grad = None
        assert log_prob.grad_fn is not None
        log_prob.sum().backward()
        assert obs.grad is not None
        assert act.grad is not None
Beispiel #8
0
    def forward(self, obs: Tensor, action: Tensor):
        """Compute the parameters of the next-state distribution.

        Returns a dict with keys "loc", "scale_tril" and "time". For regular
        transitions "loc" is `F @ tau + f`, where `tau` concatenates state
        and action along the "R" dimension. Where the time equals
        `self.horizon`, "loc" is the current state instead.
        """
        obs = nt.vector(obs)
        action = nt.vector(action)
        state, time = unpack_obs(obs)

        # Transition factors for each observation's timestep
        F, f, scale_tril = self._transition_factors(nt.vector_to_scalar(time))

        # Location of a regular (non-terminal) transition
        state_action = torch.cat([state, action], dim="R")
        tau = nt.vector_to_matrix(state_action)
        next_loc = nt.matrix_to_vector(F @ tau + nt.vector_to_matrix(f))

        # At the horizon the location stays at the current state
        at_horizon = time.eq(self.horizon)
        loc = nt.where(at_horizon, state, next_loc)
        return {"loc": loc, "scale_tril": scale_tril, "time": time}
Beispiel #9
0
 def forward(self, obs: Tensor, action: Tensor) -> Tensor:
     """Return the negated quadratic cost of a state-action pair.

     Coefficients are selected for the observation's time via
     `index_quadratic_parameters`, with `max_idx=self.horizon - 1`. Entries
     whose time equals `self.horizon` evaluate to zero.
     """
     state, time = unpack_obs(obs)
     timestep = nt.vector_to_scalar(time)
     # noinspection PyTypeChecker
     quad, linear, const = index_quadratic_parameters(
         self.quad,
         self.linear,
         self.const,
         timestep,
         max_idx=self.horizon - 1,
     )

     state_action = torch.cat([state, action], dim="R")
     vec = nt.vector_to_matrix(state_action)
     quad_term = nt.transpose(vec) @ quad @ vec / 2
     linear_term = nt.transpose(nt.vector_to_matrix(linear)) @ vec
     const_term = nt.scalar_to_matrix(const)
     cost = nt.matrix_to_scalar(quad_term + linear_term + const_term)

     value = cost.neg()
     at_horizon = timestep.eq(self.horizon)
     return nt.where(at_horizon, torch.zeros_like(value), value)
Beispiel #10
0
    def _logp(
        self, loc: Tensor, scale_tril: Tensor, time: Tensor, value: Tensor
    ) -> Tensor:
        """Return the log-probability of `value` under the given parameters.

        The result of `_trans_logp` is returned as is when `self.horizon` is
        falsy. Otherwise, entries whose current time equals the horizon take
        the result of `_absorving_logp` instead.
        """
        # Broadcast the parameters against the unpacked observation
        next_state, next_time = unpack_obs(value)
        loc, next_state = torch.broadcast_tensors(loc, next_state)
        time, next_time = torch.broadcast_tensors(time, next_time)
        time, next_time = nt.vector_to_scalar(time, next_time)

        # Log-probability of a regular state transition
        transition_logp = self._trans_logp(
            loc, scale_tril, time, next_state, next_time)
        if not self.horizon:
            return transition_logp

        # With a horizon set, the current timestep being the horizon means
        # the transition starts from an absorving state
        terminal_logp = self._absorving_logp(loc, time, next_state, next_time)
        at_horizon = time.eq(self.horizon)
        return nt.where(at_horizon, terminal_logp, transition_logp)
Beispiel #11
0
 def cdf(self, next_obs: Tensor, params: TensorDict) -> Tensor:
     """Evaluate the wrapped distribution's CDF at the residual of `next_obs`.

     The residual keeps the time of `next_obs` and replaces its state with
     the difference to `params["state"]`.
     """
     target_state, time = unpack_obs(next_obs)
     delta = target_state - params["state"]
     return self.dist.cdf(pack_obs(delta, time), params)
Beispiel #12
0
 def rsample(
     self, params: TensorDict, sample_shape: list[int] = ()) -> SampleLogp:
     """Sample a next observation by adding a sampled residual to the state.

     The residual and its log-probability come from `self.dist.rsample`.
     The state part of the residual is added to `params["state"]`; the time
     part is kept unchanged.
     """
     sampled, log_prob = self.dist.rsample(params, sample_shape)
     state_delta, time = unpack_obs(sampled)
     next_state = params["state"] + state_delta
     return pack_obs(next_state, time), log_prob
Beispiel #13
0
 def forward(self, obs: Tensor, action: Tensor) -> TensorDict:
     """Return `self.params(obs, action)` with the state added as "state"."""
     dist_params = self.params(obs, action)
     # Only the state part of the observation is stored; time is dropped
     dist_params["state"], _ = unpack_obs(obs)
     return dist_params
Beispiel #14
0
 def forward(self, obs: Tensor, action: Tensor) -> TensorDict:
     """Normalize the state part of `obs` and forward to the wrapped model.

     The state is stripped of names and reshaped to `(-1, self.n_state)` for
     the normalizer, then restored to its original shape and names. The time
     part of the observation is passed through unchanged.
     """
     state, time = unpack_obs(obs)
     flat_state = nt.unnamed(state).reshape(-1, self.n_state)
     normalized = self.normalizer(flat_state).reshape_as(state)
     normalized = normalized.refine_names(*state.names)
     return self._model(pack_obs(normalized, time), action)
Beispiel #15
0
 def deterministic(self, params: TensorDict) -> SampleLogp:
     """Return the deterministic next observation and its log-probability.

     The state part of the residual from `self.dist.deterministic` is added
     to `params["state"]`; the time part is kept unchanged.
     """
     output, log_prob = self.dist.deterministic(params)
     state_delta, time = unpack_obs(output)
     next_obs = pack_obs(params["state"] + state_delta, time)
     return next_obs, log_prob
Beispiel #16
0
 def reproduce(self, next_obs, params: TensorDict) -> SampleLogp:
     """Reproduce `next_obs` through the wrapped residual distribution.

     `next_obs` is converted to a residual w.r.t. `params["state"]`, passed
     to `self.dist.reproduce`, and the result is converted back to an
     observation by adding `params["state"]` to its state part.
     """
     # Observation -> residual
     target_state, target_time = unpack_obs(next_obs)
     target_residual = pack_obs(target_state - params["state"], target_time)

     reproduced, log_prob = self.dist.reproduce(target_residual, params)

     # Residual -> observation
     state_delta, time = unpack_obs(reproduced)
     return pack_obs(params["state"] + state_delta, time), log_prob
Beispiel #17
0
def new_obs(obs: Tensor) -> Tensor:
    """Build a gradient-tracking observation one timestep after `obs`.

    The state is drawn from a standard normal with the same shape as the
    state of `obs`; the time is the time of `obs` plus one.
    """
    cur_state, cur_time = unpack_obs(obs)
    random_state = torch.randn_like(cur_state)
    successor = pack_obs(random_state, cur_time + 1)
    return successor.requires_grad_()
Beispiel #18
0
 def check_shapes(self, loc: Tensor, scale_tril: Tensor, time: Tensor,
                  obs: Tensor):
     """Assert the distribution parameters are shaped consistently with `obs`.

     `loc` must match the state's shape, `scale_tril` must match the state's
     shape with its last dimension repeated once more, and `time` must match
     the shape of the observation's time component.
     """
     state, obs_time = unpack_obs(obs)
     state_shape = state.shape
     assert loc.shape == state_shape
     assert scale_tril.shape == state_shape + state_shape[-1:]
     assert time.shape == obs_time.shape
Beispiel #19
0
 def icdf(self, prob, params: TensorDict) -> Tensor:
     """Evaluate the inverse CDF and convert the residual to an observation.

     The state part of the result of `self.dist.icdf` is added to
     `params["state"]`; the time part is kept unchanged.
     """
     quantile = self.dist.icdf(prob, params)
     state_delta, time = unpack_obs(quantile)
     next_state = params["state"] + state_delta
     return pack_obs(next_state, time)
Beispiel #20
0
def mix_obs(obs: Tensor, last_obs: Tensor) -> Tensor:
    """Randomly combine `obs` and `last_obs` into a single tensor.

    A uniform random mask shaped like the time component of `obs` selects
    `obs` where the draw is below 0.5 and `last_obs` elsewhere. The result
    retains its gradient so it can be inspected after a backward pass.
    """
    _, time = unpack_obs(obs)
    take_obs = torch.rand_like(time.float()) < 0.5
    mixed = nt.where(take_obs, obs, last_obs)
    mixed.retain_grad()
    return mixed