def _value(self, t: TensorType) -> TensorType:
    """Returns the result of: initial_p * decay_rate ** (`t`/t_max)."""
    if self.framework == "torch" and torch and isinstance(t, torch.Tensor):
        t = t.float()
    return self.initial_p * \
        self.decay_rate ** (t / self.schedule_timesteps)
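Below is a minimal standalone sketch of the exponential decay this method computes; the `exp_value` helper and its default arguments are illustrative stand-ins, not part of the original class:

# Hypothetical standalone version of the exponential schedule above.
def exp_value(t, initial_p=1.0, decay_rate=0.1, schedule_timesteps=10000):
    # initial_p * decay_rate ** (t / t_max): moves from initial_p (at t=0)
    # toward initial_p * decay_rate (at t=schedule_timesteps).
    return initial_p * decay_rate ** (t / schedule_timesteps)

assert exp_value(0) == 1.0
assert abs(exp_value(5000) - 0.1 ** 0.5) < 1e-12
assert abs(exp_value(10000) - 0.1) < 1e-12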
def _value(self, t: TensorType) -> TensorType:
    """Returns the result of:
    final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power
    """
    if self.framework == "torch" and torch and isinstance(t, torch.Tensor):
        t = t.float()
    # Cap t at t_max so the schedule saturates at final_p.
    t = min(t, self.schedule_timesteps)
    return self.final_p + (self.initial_p - self.final_p) * (
        1.0 - (t / self.schedule_timesteps))**self.power
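As a standalone sketch (the `poly_value` helper and its defaults are illustrative): with `power=2.0` the value starts at `initial_p`, decays quadratically, and saturates at `final_p` once `t` reaches `schedule_timesteps`:

# Hypothetical standalone version of the polynomial schedule above.
def poly_value(t, initial_p=1.0, final_p=0.05, schedule_timesteps=10000,
               power=2.0):
    # Cap t so the value stays at final_p for all t >= t_max.
    t = min(t, schedule_timesteps)
    return final_p + (initial_p - final_p) * (
        1.0 - t / schedule_timesteps) ** power

assert abs(poly_value(0) - 1.0) < 1e-12
assert poly_value(10000) == 0.05    # saturated at final_p
assert poly_value(20000) == 0.05    # stays at final_p past t_max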
def representation_function(self, obs: TensorType) -> TensorType:
    # Convert NHWC image batches to the NCHW layout expected by conv nets.
    obs = obs.float().permute(0, 3, 1, 2)
    output = self.representation(obs)
    self.hidden = output
    # Keep a FIFO cache of the last `order` hidden states: the first call
    # fills it with copies, subsequent calls drop the oldest entry.
    if not self.cache:
        self.cache = [self.hidden] * self.order
    else:
        self.cache.append(self.hidden)
        self.cache.pop(0)
    return output
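To make the cache behavior concrete, here is a standalone sketch using plain Python lists in place of hidden-state tensors (`order=3` is an assumed example value):

# Hypothetical stand-in for the cache logic above.
cache, order = [], 3
for hidden in ["h0", "h1", "h2", "h3"]:
    if not cache:
        cache = [hidden] * order   # first call: fill with copies
    else:
        cache.append(hidden)       # later calls: FIFO of last `order` states
        cache.pop(0)

print(cache)  # ['h1', 'h2', 'h3']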
def __init__(
    self,
    q_t_selected: TensorType,
    q_logits_t_selected: TensorType,
    q_tp1_best: TensorType,
    q_probs_tp1_best: TensorType,
    importance_weights: TensorType,
    rewards: TensorType,
    done_mask: TensorType,
    gamma=0.99,
    n_step=1,
    num_atoms=1,
    v_min=-10.0,
    v_max=10.0,
):
    if num_atoms > 1:
        # Distributional Q-learning, which corresponds to a cross-entropy
        # loss on the projected target distribution.
        z = torch.arange(
            0, num_atoms, dtype=torch.float32).to(rewards.device)
        z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

        # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
        r_tau = torch.unsqueeze(
            rewards, -1) + gamma**n_step * torch.unsqueeze(
                1.0 - done_mask, -1) * torch.unsqueeze(z, 0)
        r_tau = torch.clamp(r_tau, v_min, v_max)
        b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
        lb = torch.floor(b)
        ub = torch.ceil(b)

        # Indispensable correction that most implementations miss: when b
        # happens to be an integer, lb == ub, so pr_j(s', a*) would be
        # discarded because (ub - b) == (b - lb) == 0.
        floor_equal_ceil = ((ub - lb) < 0.5).float()

        # (batch_size, num_atoms, num_atoms)
        l_project = F.one_hot(lb.long(), num_atoms)
        # (batch_size, num_atoms, num_atoms)
        u_project = F.one_hot(ub.long(), num_atoms)
        ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil)
        mu_delta = q_probs_tp1_best * (b - lb)
        ml_delta = torch.sum(
            l_project * torch.unsqueeze(ml_delta, -1), dim=1)
        mu_delta = torch.sum(
            u_project * torch.unsqueeze(mu_delta, -1), dim=1)
        m = ml_delta + mu_delta

        # The Rainbow paper argues that using this cross-entropy loss for
        # priorities is robust and insensitive to `prioritized_replay_alpha`.
        self.td_error = softmax_cross_entropy_with_logits(
            logits=q_logits_t_selected, labels=m.detach())
        self.loss = torch.mean(self.td_error * importance_weights)
        self.stats = {
            # TODO: better Q stats for dist dqn
        }
    else:
        q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

        # Compute RHS of the Bellman equation.
        q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

        # Compute the error (potentially clipped via the Huber loss).
        self.td_error = q_t_selected - q_t_selected_target.detach()
        self.loss = torch.mean(
            importance_weights.float() * huber_loss(self.td_error))
        self.stats = {
            "mean_q": torch.mean(q_t_selected),
            "min_q": torch.min(q_t_selected),
            "max_q": torch.max(q_t_selected),
        }
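A small sanity check of the categorical projection above (all settings below are illustrative, not from the original code): it verifies that the projected target distribution m still sums to 1 per batch row, including the lb == ub edge case handled by `floor_equal_ceil`:

import torch
import torch.nn.functional as F

# Illustrative settings: 5 atoms on the support [-10, 10].
num_atoms, v_min, v_max, gamma, n_step = 5, -10.0, 10.0, 0.99, 1
rewards = torch.tensor([1.0, -10.0])      # second row lands exactly on v_min
done_mask = torch.tensor([0.0, 1.0])
q_probs_tp1_best = torch.full((2, num_atoms), 1.0 / num_atoms)

z = torch.arange(0, num_atoms, dtype=torch.float32)
z = v_min + z * (v_max - v_min) / float(num_atoms - 1)
r_tau = rewards.unsqueeze(-1) + gamma ** n_step * (
    1.0 - done_mask).unsqueeze(-1) * z.unsqueeze(0)
r_tau = torch.clamp(r_tau, v_min, v_max)
b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
lb, ub = torch.floor(b), torch.ceil(b)
floor_equal_ceil = ((ub - lb) < 0.5).float()  # 1.0 where b is an integer

ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil)
mu_delta = q_probs_tp1_best * (b - lb)
m = (F.one_hot(lb.long(), num_atoms) * ml_delta.unsqueeze(-1)).sum(dim=1) \
    + (F.one_hot(ub.long(), num_atoms) * mu_delta.unsqueeze(-1)).sum(dim=1)

assert torch.allclose(m.sum(dim=1), torch.ones(2))  # valid distributions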
def representation_function(self, obs: TensorType) -> TensorType:
    obs = obs.float().permute(0, 3, 1, 2)
    output = self.representation(obs)
    self.hidden = output
    return output
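For reference, a quick illustration of the NHWC-to-NCHW conversion performed by `permute(0, 3, 1, 2)` above (the batch and image sizes are assumed):

import torch

# Hypothetical batch of 8 RGB observations in NHWC layout (96x96 frames).
obs = torch.zeros(8, 96, 96, 3, dtype=torch.uint8)
nchw = obs.float().permute(0, 3, 1, 2)  # channels-first, as conv layers expect
print(nchw.shape)  # torch.Size([8, 3, 96, 96])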