def forward(self, *inputs: Tensorable):
    """Concatenate the first two inputs on the last axis and run the parent forward.

    Only ``inputs[0]`` and ``inputs[1]`` are used; any further inputs are ignored.
    Raises IndexError when fewer than two inputs are given.
    """
    # NOTE(review): this is a parameter *tensor*, not a torch.device. It is handed
    # to check_tensor as its device argument. Presumably check_tensor forwards it to
    # Tensor.to(other), which matches both device and dtype of the module — confirm
    # before replacing it with `.device`.
    reference = next(self.parameters())
    first = check_tensor(inputs[0], reference)
    second = check_tensor(inputs[1], reference)
    joint = torch.cat([first, second], dim=-1)
    return super().forward(joint)
def critic(self, obs: Arrayable, action: Tensorable, target: bool = False) -> Tensor:
    """Evaluate the Q-value of ``action`` in ``obs``.

    Args:
        obs: observations passed straight to the Q-function.
        action: actions. For a single-input Q-function they are converted to
            integer indices and used to select a column of its output.
        target: if True, use ``self._target_q_function`` instead of
            ``self._q_function``.

    Returns:
        The selected Q-values with the trailing singleton value axis removed.
        Only the last axis is squeezed, so a batch of size 1 keeps its batch
        dimension instead of collapsing to a 0-d tensor.
    """
    q_function = self._target_q_function if target else self._q_function
    if len(q_function.input_shape()) == 2:
        # Two-input network: called jointly on (obs, action).
        return q_function(obs, action).squeeze(-1)
    # Single-input network: the output is indexed along dim 1 by the action index.
    q_all = q_function(obs)
    action = check_tensor(action, self._device).long()
    return q_all.gather(1, action.view(-1, 1)).squeeze(-1)
def critic(self, obs: Arrayable, action: Tensorable, target: bool = False) -> Tensor:
    """Evaluate the advantage function for ``action`` in ``obs``.

    Args:
        obs: observations passed straight to the advantage function.
        action: actions. For a single-input advantage function they are
            converted to integer indices and used to select a column of its output.
        target: if True, use ``self._target_adv_function`` instead of
            ``self._adv_function``.

    Returns:
        The selected advantages with the trailing singleton value axis removed.
        Only the last axis is squeezed, so a batch of size 1 keeps its batch
        dimension instead of collapsing to a 0-d tensor.
    """
    func = self._target_adv_function if target else self._adv_function
    if len(func.input_shape()) == 2:
        # Two-input network: called jointly on (obs, action).
        return func(obs, action).squeeze(-1)
    # Single-input network: the output is indexed along dim 1 by the action index.
    adv_all = func(obs)
    action = check_tensor(action, self._device).long()
    return adv_all.gather(1, action.view(-1, 1)).squeeze(-1)
def __init__(self, obs: Tensorable, actions: Tensorable, rewards: Tensorable,
             done: Tensorable, time_limit: Tensorable) -> None:
    """Store a batch of sequences and check that their leading shapes agree.

    The first two axes of every input are taken to be (batch, length). Those
    sizes come from ``obs`` and are stored as ``batch_size`` and ``length``.
    ``rewards``, ``done`` and ``time_limit`` must be exactly 2-D.
    ``obs`` and ``actions`` may carry extra trailing axes.

    Raises:
        AssertionError: if a rank or leading size does not match (these checks
            are skipped under ``python -O``, like any assert).
    """
    self.obs = check_tensor(obs)
    self.actions = check_tensor(actions)
    self.rewards = check_tensor(rewards)
    self.done = check_tensor(done)
    self.time_limit = check_tensor(time_limit)

    # Check ranks *before* indexing shape[1], so malformed inputs fail on a
    # descriptive assertion rather than an IndexError.
    for name, value in (("obs", self.obs), ("actions", self.actions)):
        assert value.dim() >= 2, \
            f"{name} must have at least 2 dims (batch, length), got shape {tuple(value.shape)}"
    for name, value in (("rewards", self.rewards), ("done", self.done),
                        ("time_limit", self.time_limit)):
        assert value.dim() == 2, \
            f"{name} must be 2-D (batch, length), got shape {tuple(value.shape)}"

    self.batch_size = self.obs.shape[0]
    self.length = self.obs.shape[1]
    expected = (self.batch_size, self.length)
    for name, value in (("actions", self.actions), ("rewards", self.rewards),
                        ("done", self.done), ("time_limit", self.time_limit)):
        assert tuple(value.shape[:2]) == expected, \
            f"{name} leading shape {tuple(value.shape[:2])} does not match obs {expected}"
def forward(self, *inputs: Tensorable) -> torch.Tensor:
    """Normalize ``inputs[0]`` with the running statistics, then fold it into them.

    The output uses the mean and std from *before* this call's update. The std
    is sqrt(max(E[x^2] - E[x]^2, 1e-2)), so it never drops below 0.1.
    The input is flattened to rows of shape ``self._mean.size()``. Each row
    counts as one sample when the running mean, mean of squares and count are
    updated, so inputs with extra leading axes or a single unbatched sample are
    weighted correctly.
    """
    device = self._mean.device  # type: ignore
    t_input = check_tensor(inputs[0], device)
    mean_shape = tuple(self._mean.size())
    std = torch.sqrt(
        torch.clamp(self._squared_mean - self._mean ** 2, min=1e-2))  # type: ignore
    output = (t_input - self._mean) / std  # type: ignore

    with torch.no_grad():
        # reshape (not view) so non-contiguous inputs are accepted too.
        flat = t_input.reshape(-1, *mean_shape)
        # Weight by the number of rows actually averaged, not by size(0).
        n_new = flat.size(0)
        if n_new > 0:  # an empty batch would inject NaN into the statistics
            total = self._count + n_new
            self._mean = (
                self._mean * self._count + n_new * flat.mean(dim=0)) / total  # type: ignore
            self._squared_mean = (
                self._squared_mean * self._count
                + n_new * (flat ** 2).mean(dim=0)) / total  # type: ignore
            self._count += n_new
    return output
def _init_noise(self, template: Tensorable):
    """Reset ``self.noise`` to fresh Gaussian noise shaped like ``template``.

    The samples are drawn with ``torch.randn`` and then moved to ``self._device``.
    They are scaled by ``self._sigma / sqrt(2 * self._theta)`` — presumably the
    stationary std of an Ornstein-Uhlenbeck process with these parameters (confirm).
    """
    shape = check_tensor(template).size()
    scale = self._sigma / np.sqrt(2 * self._theta)
    gaussian = torch.randn(shape, requires_grad=False).to(self._device)
    self.noise = scale * gaussian
def forward(self, *inputs: Tensorable):
    """Convert ``inputs[0]`` and feed it to ``self._core``; other inputs are ignored."""
    # NOTE(review): a parameter tensor (not a torch.device) is passed to check_tensor
    # as its device argument. Presumably it ends up in Tensor.to(other), which
    # matches device and dtype — confirm against check_tensor.
    reference = next(self.parameters())
    first = check_tensor(inputs[0], reference)
    return self._core(first)
def forward(self, *inputs: Tensorable):
    """Convert all inputs, pass the first through ``self._bn``, then call ``self._model``.

    The remaining inputs are forwarded to ``self._model`` unchanged, after
    conversion and in their original order.
    """
    # NOTE(review): a parameter tensor (not a torch.device) is passed to check_tensor
    # as its device argument. Presumably it ends up in Tensor.to(other) — confirm.
    reference = next(self.parameters())
    converted = tuple(check_tensor(inp, reference) for inp in inputs)
    normalized = self._bn(converted[0])
    return self._model(normalized, *converted[1:])