def test_absorving(self, module: StochasticModel, last_obs: Tensor, act: Tensor):
    """Check sampling and log-prob results when conditioning on ``last_obs``.

    The sample must reproduce the state and time of ``last_obs``; its
    gradient w.r.t. ``last_obs`` is one on the state part and zero on the
    time part, and zero w.r.t. ``act``. The log-prob must be zero and carry
    zero gradients to both inputs.
    """
    dist_params = module(last_obs, act)
    sample, logp = module.rsample(dist_params)

    # The sample mirrors the conditioning observation
    assert sample.shape == last_obs.shape
    assert sample.names == last_obs.names
    state, time = unpack_obs(last_obs)
    next_state, next_time = unpack_obs(sample)
    assert nt.allclose(state, next_state)
    assert time.eq(next_time).all()

    # Gradients of the sample: identity on the state, nothing on time/action
    assert sample.grad_fn is not None
    sample.sum().backward(retain_graph=True)
    assert last_obs.grad is not None
    grad_parts = [torch.ones_like(state), torch.zeros_like(time)]
    expected_grad = torch.cat(grad_parts, dim="R")
    assert nt.allclose(last_obs.grad, expected_grad)
    assert nt.allclose(act.grad, torch.zeros(()))

    # Reset gradients before checking the log-prob
    last_obs.grad, act.grad = None, None
    batch_dims = [
        (size, name)
        for size, name in zip(last_obs.shape, last_obs.names)
        if name != "R"
    ]
    assert logp.shape == tuple(size for size, _ in batch_dims)
    assert logp.names == tuple(name for _, name in batch_dims)
    assert nt.allclose(logp, torch.zeros(()))

    logp.sum().backward()
    assert nt.allclose(last_obs.grad, torch.zeros(()))
    assert nt.allclose(act.grad, torch.zeros(()))
def test_log_prob(
    self, module: StochasticModel, obs: Tensor, act: Tensor, new_obs: Tensor
):
    """Check that ``module.log_prob`` is finite, well-shaped and differentiable.

    Gradients must reach ``obs``, ``act`` and every module parameter, and
    none of them may be all-zero.
    """
    dist_params = module(obs, act)
    log_prob = module.log_prob(new_obs, dist_params)
    _, cur_time = unpack_obs(obs)
    _, next_time = unpack_obs(new_obs)
    cur_time, next_time = nt.vector_to_scalar(cur_time, next_time)

    # Value checks
    assert torch.is_tensor(log_prob)
    assert torch.isfinite(log_prob).all()
    assert log_prob.shape == cur_time.shape
    assert cur_time.shape == next_time.shape
    assert log_prob.names == cur_time.names
    assert cur_time.names == next_time.names

    # Gradient checks on the inputs
    assert log_prob.grad_fn is not None
    log_prob.sum().backward()
    assert obs.grad is not None
    assert act.grad is not None
    assert not nt.allclose(obs.grad, torch.zeros(()))
    assert not nt.allclose(act.grad, torch.zeros(()))

    # Gradient checks on the parameters
    param_grads = [param.grad for param in module.parameters()]
    assert all(grad is not None for grad in param_grads)
    assert not any(
        torch.allclose(grad, torch.zeros(())) for grad in param_grads
    )
def forward(self, obs: Tensor, act: Tensor) -> Tensor:
    """Return the negated quadratic cost of the state-action pair.

    The result is replaced by zeros wherever the observation's time equals
    ``self.horizon``.
    """
    obs, act = nt.vector(obs), nt.vector(act)
    state, time = unpack_obs(obs)
    state_action = torch.cat([state, act], dim="R")
    tau = nt.vector_to_matrix(state_action)
    time = nt.vector_to_scalar(time)

    # Cost coefficients selected for the current time
    C, c = self._index_parameters(time)
    c = nt.vector_to_matrix(c)

    quadratic_term = nt.transpose(tau) @ C @ tau / 2
    linear_term = nt.transpose(c) @ tau
    cost = quadratic_term + linear_term
    reward = nt.matrix_to_scalar(cost.neg())

    is_terminal = time.eq(self.horizon)
    return nt.where(is_terminal, torch.zeros_like(reward), reward)
def forward(self, obs: Tensor) -> Tensor:
    """Return the negated quadratic cost of the observed state.

    Coefficients are looked up from ``self.quad``, ``self.linear`` and
    ``self.const`` using the observation's time, with ``self.horizon`` as the
    maximum index.
    """
    state, time = unpack_obs(obs)
    time = nt.vector_to_scalar(time)
    quad, linear, const = index_quadratic_parameters(
        self.quad, self.linear, self.const, time, max_idx=self.horizon
    )

    state = nt.vector_to_matrix(state)
    quadratic_term = nt.transpose(state) @ quad @ state / 2
    linear_term = nt.transpose(nt.vector_to_matrix(linear)) @ state
    constant_term = nt.scalar_to_matrix(const)

    cost = nt.matrix_to_scalar(quadratic_term + linear_term + constant_term)
    return cost.neg()
def forward(self, obs: Tensor, frozen: bool = False) -> Tensor:
    """Compute the action vector for the observed state.

    Args:
        obs: Observation holding the state and the time.
        frozen: If true, the gains are detached before use.

    Returns:
        The affine control ``K @ state + k``, replaced by zeros wherever the
        time equals ``self.horizon``.
    """
    obs = nt.vector(obs)
    state, time = unpack_obs(obs)

    # noinspection PyTypeChecker
    K, k = self._gains_at(nt.vector_to_scalar(time))
    if frozen:
        K = K.detach()
        k = k.detach()

    ctrl = nt.matrix_to_vector(
        K @ nt.vector_to_matrix(state) + nt.vector_to_matrix(k)
    )

    # Terminal states get a zero action
    return nt.where(time.eq(self.horizon), torch.zeros_like(ctrl), ctrl)
def test_call(self, module: QuadraticReward, obs: Tensor, act: Tensor):
    """Check that calling ``module`` gives finite values and useful gradients.

    The gradient w.r.t. the state part of ``obs`` and w.r.t. ``act`` must be
    finite and non-zero; the gradient w.r.t. the time part must be zero.
    """
    reward = module(obs, act)
    assert torch.is_tensor(reward)
    assert torch.isfinite(reward).all()

    reward.sum().backward()
    assert obs.grad is not None
    assert act.grad is not None

    # Observation gradient, split into its state and time parts
    state_grad, time_grad = unpack_obs(nt.vector(obs.grad))
    assert not nt.allclose(state_grad, torch.zeros_like(state_grad))
    assert torch.isfinite(state_grad).all()
    assert nt.allclose(time_grad, torch.zeros_like(time_grad))

    # Action gradient
    assert not nt.allclose(act.grad, torch.zeros_like(act))
    assert torch.isfinite(act.grad).all()
def test_rsample(self, module: StochasticModel, obs: Tensor, act: Tensor):
    """Check that ``module.rsample`` advances time by one and is differentiable.

    Both the sample and its log-prob must propagate gradients to ``obs`` and
    ``act``; the log-prob has the dimensions of ``obs`` minus ``"R"``.
    """
    dist_params = module(obs, act)
    sample, logp = module.rsample(dist_params)

    # Sample layout and time increment
    assert sample.shape == obs.shape
    assert sample.names == obs.names
    _, cur_time = unpack_obs(obs)
    _, next_time = unpack_obs(sample)
    assert cur_time.eq(next_time - 1).all()

    # Gradients through the sample
    assert sample.grad_fn is not None
    sample.sum().backward(retain_graph=True)
    assert obs.grad is not None
    assert act.grad is not None

    # Log-prob layout: every dimension of obs except "R"
    batch_dims = [
        (size, name) for size, name in zip(obs.shape, obs.names) if name != "R"
    ]
    assert logp.shape == tuple(size for size, _ in batch_dims)
    assert logp.names == tuple(name for _, name in batch_dims)

    # Gradients through the log-prob, starting from a clean slate
    obs.grad, act.grad = None, None
    assert logp.grad_fn is not None
    logp.sum().backward()
    assert obs.grad is not None
    assert act.grad is not None
def forward(self, obs: Tensor, action: Tensor):
    """Compute the transition distribution parameters for ``(obs, action)``.

    Returns:
        A dict with keys ``"loc"``, ``"scale_tril"`` and ``"time"``. ``"loc"``
        is the affine prediction ``F @ [state, action] + f``, except where the
        time equals ``self.horizon``, in which case it is the current state.
    """
    obs = nt.vector(obs)
    action = nt.vector(action)
    state, time = unpack_obs(obs)

    # Per-timestep transition factors
    F, f, scale_tril = self._transition_factors(nt.vector_to_scalar(time))

    # Mean of a regular (non-terminal) transition
    state_action = torch.cat([state, action], dim="R")
    tau = nt.vector_to_matrix(state_action)
    trans_loc = nt.matrix_to_vector(F @ tau + nt.vector_to_matrix(f))

    # Absorbing states keep the current state as the mean
    loc = nt.where(time.eq(self.horizon), state, trans_loc)
    return {"loc": loc, "scale_tril": scale_tril, "time": time}
def forward(self, obs: Tensor, action: Tensor) -> Tensor:
    """Return the negated quadratic cost of the state-action pair.

    Coefficients are looked up by the observation's time with
    ``self.horizon - 1`` as the maximum index. The result is replaced by
    zeros wherever the time equals ``self.horizon``.
    """
    state, time = unpack_obs(obs)
    time = nt.vector_to_scalar(time)
    # noinspection PyTypeChecker
    quad, linear, const = index_quadratic_parameters(
        self.quad, self.linear, self.const, time, max_idx=self.horizon - 1
    )

    state_action = torch.cat([state, action], dim="R")
    vec = nt.vector_to_matrix(state_action)
    quadratic_term = nt.transpose(vec) @ quad @ vec / 2
    linear_term = nt.transpose(nt.vector_to_matrix(linear)) @ vec
    constant_term = nt.scalar_to_matrix(const)

    cost = nt.matrix_to_scalar(quadratic_term + linear_term + constant_term)
    val = cost.neg()
    is_terminal = time.eq(self.horizon)
    return nt.where(is_terminal, torch.zeros_like(val), val)
def _logp(
    self, loc: Tensor, scale_tril: Tensor, time: Tensor, value: Tensor
) -> Tensor:
    """Return the log-probability of ``value`` under the given parameters.

    When ``self.horizon`` is set, entries whose current time equals the
    horizon use the absorbing-state log-prob; the others use the regular
    transition log-prob.
    """
    # Broadcast the parameters against the unpacked value
    next_state, next_time = unpack_obs(value)
    loc, next_state = torch.broadcast_tensors(loc, next_state)
    time, next_time = torch.broadcast_tensors(time, next_time)
    time, next_time = nt.vector_to_scalar(time, next_time)

    # Regular state transition
    trans_logp = self._trans_logp(loc, scale_tril, time, next_state, next_time)

    if self.horizon:
        # Absorbing transitions apply where the current timestep is the horizon
        absorving_logp = self._absorving_logp(loc, time, next_state, next_time)
        return nt.where(time.eq(self.horizon), absorving_logp, trans_logp)
    return trans_logp
def cdf(self, next_obs: Tensor, params: TensorDict) -> Tensor:
    """Evaluate the wrapped distribution's CDF on the residual of ``next_obs``.

    The residual is ``next_obs`` with ``params["state"]`` subtracted from its
    state part; the time part is passed through unchanged.
    """
    next_state, time = unpack_obs(next_obs)
    delta = next_state - params["state"]
    return self.dist.cdf(pack_obs(delta, time), params)
def rsample(
    self, params: TensorDict, sample_shape: list[int] = ()
) -> SampleLogp:
    """Draw a reparameterized residual sample and turn it into an observation.

    The sampled state residual is added to ``params["state"]``; the sampled
    time and the log-prob from ``self.dist`` are returned unchanged.
    """
    residual, log_prob = self.dist.rsample(params, sample_shape)
    delta, time = unpack_obs(residual)
    next_state = params["state"] + delta
    return pack_obs(next_state, time), log_prob
def forward(self, obs: Tensor, action: Tensor) -> TensorDict:
    """Return ``self.params(obs, action)`` with the current state attached.

    The state part of ``obs`` is stored under the ``"state"`` key of the
    returned dict.
    """
    out = self.params(obs, action)
    cur_state, _ = unpack_obs(obs)
    out["state"] = cur_state
    return out
def forward(self, obs: Tensor, action: Tensor) -> TensorDict:
    """Normalize the state part of ``obs`` and forward to the wrapped model.

    The state is stripped of names and reshaped to ``(-1, self.n_state)``
    for ``self.normalizer``, then restored to its original shape and names.
    """
    state, time = unpack_obs(obs)

    flat_state = nt.unnamed(state).reshape(-1, self.n_state)
    normalized = self.normalizer(flat_state)
    normalized = normalized.reshape_as(state).refine_names(*state.names)

    return self._model(pack_obs(normalized, time), action)
def deterministic(self, params: TensorDict) -> SampleLogp:
    """Return the deterministic residual output shifted by the current state.

    The state residual from ``self.dist.deterministic`` is added to
    ``params["state"]``; the time and log-prob are returned unchanged.
    """
    residual, log_prob = self.dist.deterministic(params)
    delta, time = unpack_obs(residual)
    next_obs = pack_obs(params["state"] + delta, time)
    return next_obs, log_prob
def reproduce(self, next_obs, params: TensorDict) -> SampleLogp:
    """Run ``self.dist.reproduce`` in residual space and map the result back.

    ``next_obs`` is converted to a residual by subtracting ``params["state"]``
    from its state part; the reproduced residual is shifted back by the same
    state before being returned with its log-prob.
    """
    next_state, time = unpack_obs(next_obs)
    residual = pack_obs(next_state - params["state"], time)

    new_residual, new_log_prob = self.dist.reproduce(residual, params)
    new_delta, new_time = unpack_obs(new_residual)

    reproduced = pack_obs(params["state"] + new_delta, new_time)
    return reproduced, new_log_prob
def new_obs(obs: Tensor) -> Tensor:
    """Build a successor observation with a random state and time plus one.

    The state is drawn with ``torch.randn_like`` and the packed result is
    marked as requiring gradients.
    """
    state, time = unpack_obs(obs)
    random_state = torch.randn_like(state)
    successor = pack_obs(random_state, time + 1)
    return successor.requires_grad_()
def check_shapes(self, loc: Tensor, scale_tril: Tensor, time: Tensor, obs: Tensor):
    """Assert that the distribution parameters match the layout of ``obs``.

    ``loc`` must have the state's shape, ``scale_tril`` the state's shape
    with its last dimension repeated, and ``time`` the shape of the time part
    of ``obs``.
    """
    state, obs_time = unpack_obs(obs)
    state_shape = state.shape
    assert loc.shape == state_shape
    assert scale_tril.shape == state_shape + state_shape[-1:]
    assert time.shape == obs_time.shape
def icdf(self, prob, params: TensorDict) -> Tensor:
    """Evaluate the wrapped distribution's inverse CDF and shift by the state.

    The state residual returned by ``self.dist.icdf`` is added to
    ``params["state"]``; the time part is passed through unchanged.
    """
    delta, time = unpack_obs(self.dist.icdf(prob, params))
    next_state = params["state"] + delta
    return pack_obs(next_state, time)
def mix_obs(obs: Tensor, last_obs: Tensor) -> Tensor:
    """Randomly mix entries of ``obs`` and ``last_obs``.

    A uniform random mask, shaped like the time part of ``obs``, picks ``obs``
    where it is below 0.5 and ``last_obs`` elsewhere. The result retains its
    gradient.
    """
    _, time = unpack_obs(obs)
    pick_obs = torch.rand_like(time.float()) < 0.5
    mixed = nt.where(pick_obs, obs, last_obs)
    mixed.retain_grad()
    return mixed