def _gen_sample(self, loc: Tensor, scale_tril: Tensor, time: IntTensor) -> Tensor:
    """Draw the next observation, keeping it fixed once the horizon is reached.

    The candidate sample always comes from ``self._transition``. When
    ``self.horizon`` is falsy (unset or zero) that sample is returned
    untouched. Otherwise, entries whose ``time`` equals the horizon are
    replaced by ``pack_obs(loc, time)``.
    """
    sampled = self._transition(loc, scale_tril, time)
    if self.horizon:
        # The horizon timestep is treated as absorbing: the observation is
        # rebuilt from the current loc and time instead of the fresh sample.
        at_horizon = time.eq(self.horizon)
        held_obs = pack_obs(loc, time)
        return nt.where(at_horizon, held_obs, sampled)
    return sampled
def _absorving_logp(
    cur_state: Tensor, cur_time: IntTensor, state: Tensor, time: IntTensor
) -> Tensor:
    """Log-probability of a transition out of an absorbing state.

    Returns 0.0 where the packed (state, time) observation is elementwise
    close to the packed current observation along the "R" dimension, and
    NaN everywhere else. The result has the layout of ``time`` with a
    float dtype.
    """
    # Both time arguments are assumed to be named scalar tensors.
    source_obs = pack_obs(cur_state, nt.scalar_to_vector(cur_time))
    target_obs = pack_obs(state, nt.scalar_to_vector(time))

    # All probability mass sits on staying in exactly the same observation.
    stayed_put = nt.reduce_all(nt.isclose(source_obs, target_obs), dim="R")
    logp_if_same = torch.zeros_like(time, dtype=torch.float)
    logp_otherwise = torch.full_like(
        time, fill_value=float("nan"), dtype=torch.float
    )
    return nt.where(stayed_put, logp_if_same, logp_otherwise)
def forward(self, obs: Tensor, act: Tensor) -> Tensor:
    """Return the negated quadratic cost of (obs, act), zeroed at the horizon.

    With ``tau`` the concatenation of state and action along "R" and
    ``C, c`` the parameters selected for the observation's timestep, the
    cost is ``tau^T C tau / 2 + c^T tau``. The reward is its negation,
    replaced by zeros where the timestep equals ``self.horizon``.
    """
    obs = nt.vector(obs)
    act = nt.vector(act)
    state, time = unpack_obs(obs)

    tau = nt.vector_to_matrix(torch.cat([state, act], dim="R"))
    time = nt.vector_to_scalar(time)

    # Parameters are looked up per timestep.
    C, c = self._index_parameters(time)
    c = nt.vector_to_matrix(c)

    quadratic_part = nt.transpose(tau) @ C @ tau / 2
    linear_part = nt.transpose(c) @ tau
    cost = quadratic_part + linear_part

    reward = nt.matrix_to_scalar(cost.neg())
    at_horizon = time.eq(self.horizon)
    return nt.where(at_horizon, torch.zeros_like(reward), reward)
def forward(self, obs: Tensor, frozen: bool = False) -> Tensor:
    """Return the affine-feedback action ``K @ state + k`` for an observation.

    The gains ``K, k`` are fetched for the observation's timestep through
    ``self._gains_at``. When ``frozen`` is true both gains are detached
    before use. Actions are replaced by zeros where the timestep equals
    ``self.horizon``.
    """
    obs = nt.vector(obs)
    state, time = unpack_obs(obs)

    # noinspection PyTypeChecker
    gain, offset = self._gains_at(nt.vector_to_scalar(time))
    if frozen:
        gain = gain.detach()
        offset = offset.detach()

    action = nt.matrix_to_vector(
        gain @ nt.vector_to_matrix(state) + nt.vector_to_matrix(offset)
    )

    # Terminal states get an all-zero action.
    return nt.where(time.eq(self.horizon), torch.zeros_like(action), action)
def forward(self, obs: Tensor, action: Tensor) -> Tensor:
    """Return the negated time-indexed quadratic of (state, action).

    The quadratic, linear and constant coefficients are selected for the
    observation's timestep, with the index bounded by ``self.horizon - 1``.
    With ``vec`` the concatenation of state and action along "R", the
    result is ``-(vec^T Q vec / 2 + l^T vec + const)``, replaced by zeros
    where the timestep equals ``self.horizon``.
    """
    state, time = unpack_obs(obs)
    time = nt.vector_to_scalar(time)

    # noinspection PyTypeChecker
    quad, linear, const = index_quadratic_parameters(
        self.quad,
        self.linear,
        self.const,
        time,
        max_idx=self.horizon - 1,
    )

    vec = nt.vector_to_matrix(torch.cat([state, action], dim="R"))
    second_order = nt.transpose(vec) @ quad @ vec / 2
    first_order = nt.transpose(nt.vector_to_matrix(linear)) @ vec
    zeroth_order = nt.scalar_to_matrix(const)

    cost = nt.matrix_to_scalar(second_order + first_order + zeroth_order)
    val = cost.neg()
    at_horizon = time.eq(self.horizon)
    return nt.where(at_horizon, torch.zeros_like(val), val)
def forward(self, obs: Tensor, action: Tensor):
    """Compute the parameters of the next-state distribution.

    Returns a dict with keys "loc", "scale_tril" and "time". "loc" is
    ``F @ [state; action] + f`` using the factors selected for the
    observation's timestep, except where the timestep equals
    ``self.horizon``, where the current state is used instead. "time" is
    the time component of the input observation, passed through unchanged.
    """
    obs = nt.vector(obs)
    action = nt.vector(action)
    state, time = unpack_obs(obs)

    # Transition factors are looked up per timestep.
    step = nt.vector_to_scalar(time)
    F, f, scale_tril = self._transition_factors(step)

    # Mean of a regular (non-terminal) transition.
    joint = nt.vector_to_matrix(torch.cat([state, action], dim="R"))
    regular_loc = nt.matrix_to_vector(F @ joint + nt.vector_to_matrix(f))

    # At the horizon the state is absorbing, so the mean stays where it is.
    at_horizon = time.eq(self.horizon)
    params = {
        "loc": nt.where(at_horizon, state, regular_loc),
        "scale_tril": scale_tril,
        "time": time,
    }
    return params
def _trans_logp(
    loc: Tensor,
    scale_tril: Tensor,
    cur_time: IntTensor,
    state: Tensor,
    time: IntTensor,
) -> Tensor:
    """Log-density of ``state`` under the Gaussian transition distribution.

    The density is that of a multivariate normal with mean ``loc`` and
    lower-triangular scale factor ``scale_tril``. It is only kept where
    ``time == cur_time + 1``; every other entry is NaN.

    Args:
        loc: Mean of the transition distribution.
        scale_tril: Lower-triangular scale factor of the distribution.
        cur_time: Timestep of the current observation.
        state: Candidate next state to evaluate.
        time: Timestep attached to ``state``; assumed to be a named scalar
            tensor, since its names are used to label the result.

    Returns:
        The log-probabilities, refined to carry the names of ``time``.
    """
    # torch.distributions is fed unnamed tensors here, so names are dropped.
    loc, scale_tril = nt.unnamed(loc, scale_tril)
    dist = torch.distributions.MultivariateNormal(loc=loc, scale_tril=scale_tril)
    trans_logp: Tensor = dist.log_prob(nt.unnamed(state))

    # Build the NaN filler with the dtype and device of the computed
    # log-probabilities. A bare torch.full would use the default dtype and
    # device and clash with non-default (e.g. CUDA or float64) inputs.
    undefined = torch.full(
        time.shape,
        fill_value=float("nan"),
        dtype=trans_logp.dtype,
        device=trans_logp.device,
    )

    # Logp is only defined for a transition into the very next timestep.
    trans_logp = nt.where(time.eq(cur_time + 1), trans_logp, undefined)

    # We assume time is a named scalar tensor.
    return trans_logp.refine_names(*time.names)
def _logp(
    self, loc: Tensor, scale_tril: Tensor, time: Tensor, value: Tensor
) -> Tensor:
    """Log-probability of the observation ``value`` under the transition.

    ``value`` is split into a next state and next time, which are
    broadcast against ``loc`` and ``time``. The regular transition
    log-probability is always computed. If ``self.horizon`` is truthy,
    entries whose current time equals the horizon use the absorbing-state
    log-probability instead.
    """
    # Bring all inputs to compatible shapes.
    next_state, next_time = unpack_obs(value)
    loc, next_state = torch.broadcast_tensors(loc, next_state)
    time, next_time = torch.broadcast_tensors(time, next_time)
    time, next_time = nt.vector_to_scalar(time, next_time)

    # Regular state transition.
    regular = self._trans_logp(loc, scale_tril, time, next_state, next_time)
    if self.horizon:
        # The current timestep being the horizon means the state is absorbing.
        absorbing = self._absorving_logp(loc, time, next_state, next_time)
        return nt.where(time.eq(self.horizon), absorbing, regular)
    return regular
def mix_obs(obs: Tensor, last_obs: Tensor) -> Tensor:
    """Randomly pick between ``obs`` and ``last_obs`` with a fair coin.

    One uniform random number is drawn per entry of the time component of
    ``obs``. Where it falls below 0.5 the value from ``obs`` is used,
    otherwise the one from ``last_obs``. ``retain_grad`` is called on the
    result before it is returned.
    """
    _, time = unpack_obs(obs)
    coin = torch.rand_like(time.float())
    pick_current = coin < 0.5
    blended = nt.where(pick_current, obs, last_obs)
    # Keep the gradient of this intermediate tensor available after backward.
    blended.retain_grad()
    return blended