def calculate_quantile_loss(self, state_embeddings, tau_hats,
                            current_sa_quantile_hats, actions, rewards,
                            next_states, dones, weights):
    assert not tau_hats.requires_grad

    with torch.no_grad():
        # NOTE: Current and target quantiles share the same proposed
        # fractions to reduce computations. (i.e. next_tau_hats = tau_hats)

        # Calculate Q values of next states.
        if self.double_q_learning:
            # Sample the noise of online network to decorrelate between
            # the action selection and the quantile calculation.
            self.online_net.sample_noise()
            next_q = self.online_net.calculate_q(states=next_states)
        else:
            next_state_embeddings = \
                self.target_net.calculate_state_embeddings(next_states)
            next_q = self.target_net.calculate_q(
                state_embeddings=next_state_embeddings,
                fraction_net=self.online_net.fraction_net)

        # Calculate greedy actions.
        next_actions = torch.argmax(next_q, dim=1, keepdim=True)
        assert next_actions.shape == (self.batch_size, 1)

        # Calculate features of next states.
        if self.double_q_learning:
            next_state_embeddings = \
                self.target_net.calculate_state_embeddings(next_states)

        # Calculate quantile values of next states and actions at tau_hats.
        next_sa_quantile_hats = evaluate_quantile_at_action(
            self.target_net.calculate_quantiles(
                taus=tau_hats, state_embeddings=next_state_embeddings),
            next_actions).transpose(1, 2)
        assert next_sa_quantile_hats.shape == (self.batch_size, 1, self.N)

        # Calculate target quantile values.
        target_sa_quantile_hats = rewards[..., None] + (
            1.0 - dones[..., None]) * self.gamma_n * next_sa_quantile_hats
        assert target_sa_quantile_hats.shape == (
            self.batch_size, 1, self.N)

    # NOTE: TD errors are computed outside torch.no_grad() so that gradients
    # flow into current_sa_quantile_hats. Broadcasting (batch_size, 1, N)
    # against (batch_size, N, 1) yields all pairwise errors.
    td_errors = target_sa_quantile_hats - current_sa_quantile_hats
    assert td_errors.shape == (self.batch_size, self.N, self.N)

    quantile_huber_loss = calculate_quantile_huber_loss(
        td_errors, tau_hats, weights, self.kappa)

    return quantile_huber_loss, next_q.detach().mean().item(), \
        td_errors.detach().abs()
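
# `calculate_quantile_huber_loss` is a helper defined elsewhere in this
# repository. A minimal, self-contained sketch of the quantile Huber loss it
# is assumed to compute follows:
#     rho^kappa_tau(delta) = |tau - 1{delta < 0}| * L_kappa(delta) / kappa,
# summed over current quantiles and averaged over target quantiles, with
# optional per-transition importance-sampling weights. The function name and
# the exact reduction are illustrative assumptions, not the repository's
# implementation.
import torch  # mirrors the module-level import the methods above rely on


def _quantile_huber_loss_sketch(td_errors, taus, weights=None, kappa=1.0):
    # td_errors: (batch_size, N, N_dash), taus: (batch_size, N).
    assert not taus.requires_grad
    batch_size, N, N_dash = td_errors.shape

    # Huber loss L_kappa(delta), element-wise.
    element_wise_huber_loss = torch.where(
        td_errors.abs() <= kappa,
        0.5 * td_errors.pow(2),
        kappa * (td_errors.abs() - 0.5 * kappa))

    # Quantile regression weighting |tau - 1{delta < 0}|.
    element_wise_quantile_huber_loss = torch.abs(
        taus.view(batch_size, N, 1) - (td_errors.detach() < 0).float()
    ) * element_wise_huber_loss / kappa

    # Sum over current quantiles (dim 1), average over target quantiles.
    batch_loss = element_wise_quantile_huber_loss.sum(dim=1).mean(
        dim=1, keepdim=True)
    assert batch_loss.shape == (batch_size, 1)

    # Apply importance-sampling weights (e.g. from PER) per transition.
    if weights is not None:
        return (batch_loss * weights).mean()
    return batch_loss.mean()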
def calculate_loss(self, states, actions, rewards, next_states, dones,
                   weights):
    # Calculate quantile values of current states and actions at taus.
    current_sa_quantiles = evaluate_quantile_at_action(
        self.online_net(states=states), actions)
    assert current_sa_quantiles.shape == (self.batch_size, self.N, 1)

    with torch.no_grad():
        # Calculate Q values of next states.
        if self.double_q_learning:
            # Sample the noise of online network to decorrelate between
            # the action selection and the quantile calculation.
            self.online_net.sample_noise()
            next_q = self.online_net.calculate_q(states=next_states)
        else:
            next_q = self.target_net.calculate_q(states=next_states)

        # Calculate greedy actions.
        next_actions = torch.argmax(next_q, dim=1, keepdim=True)
        assert next_actions.shape == (self.batch_size, 1)

        # Calculate quantile values of next states and actions at tau_hats.
        next_sa_quantiles = evaluate_quantile_at_action(
            self.target_net(states=next_states),
            next_actions).transpose(1, 2)
        assert next_sa_quantiles.shape == (self.batch_size, 1, self.N)

        # Calculate target quantile values.
        target_sa_quantiles = rewards[..., None] + (
            1.0 - dones[..., None]) * self.gamma_n * next_sa_quantiles
        assert target_sa_quantiles.shape == (self.batch_size, 1, self.N)

    td_errors = target_sa_quantiles - current_sa_quantiles
    assert td_errors.shape == (self.batch_size, self.N, self.N)

    quantile_huber_loss = calculate_quantile_huber_loss(
        td_errors, self.tau_hats, weights, self.kappa)

    return quantile_huber_loss, next_q.detach().mean().item(), \
        td_errors.detach().abs()
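
# `evaluate_quantile_at_action` is imported from the repository's utilities.
# The sketch below shows the gather operation it is assumed to perform:
# selecting, for each transition, the quantile values of the chosen action.
# The name `_evaluate_quantile_at_action_sketch` is illustrative only; only
# tensor methods are used, so no extra imports are needed.
def _evaluate_quantile_at_action_sketch(s_quantiles, actions):
    # s_quantiles: (batch_size, N, num_actions), actions: (batch_size, 1) long.
    assert s_quantiles.shape[0] == actions.shape[0]
    batch_size, N = s_quantiles.shape[:2]

    # Expand actions to (batch_size, N, 1) so they can index the action dim.
    action_index = actions[..., None].expand(batch_size, N, 1)

    # Gather the quantile values of the chosen action: (batch_size, N, 1).
    return s_quantiles.gather(dim=2, index=action_index)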
def calculate_fraction_loss(self, state_embeddings, sa_quantile_hats, taus,
                            actions, weights):
    assert not state_embeddings.requires_grad
    assert not sa_quantile_hats.requires_grad

    batch_size = state_embeddings.shape[0]

    with torch.no_grad():
        sa_quantiles = evaluate_quantile_at_action(
            self.online_net.calculate_quantiles(
                taus=taus[:, 1:-1], state_embeddings=state_embeddings),
            actions)
        assert sa_quantiles.shape == (batch_size, self.N - 1, 1)

    # NOTE: Proposition 1 in the paper requires F^{-1} to be non-decreasing.
    # I relax this requirement and calculate gradients of taus even when
    # F^{-1} is not non-decreasing.

    values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
    signs_1 = sa_quantiles > torch.cat([
        sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
    assert values_1.shape == signs_1.shape

    values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
    signs_2 = sa_quantiles < torch.cat([
        sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1)
    assert values_2.shape == signs_2.shape

    gradient_of_taus = (
        torch.where(signs_1, values_1, -values_1)
        + torch.where(signs_2, values_2, -values_2)
    ).view(batch_size, self.N - 1)
    assert not gradient_of_taus.requires_grad
    assert gradient_of_taus.shape == taus[:, 1:-1].shape

    # Gradients of the network parameters and the corresponding loss are
    # calculated using the chain rule.
    if weights is not None:
        fraction_loss = ((
            (gradient_of_taus * taus[:, 1:-1]).sum(dim=1, keepdim=True)
        ) * weights).mean()
    else:
        fraction_loss = \
            (gradient_of_taus * taus[:, 1:-1]).sum(dim=1).mean()

    return fraction_loss
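
# Under the monotonicity assumption that the relaxation above drops (i.e.
# both `signs_1` and `signs_2` are True everywhere), `gradient_of_taus`
# reduces to the closed form from Proposition 1 of the FQF paper:
#     dW_1/dtau_i = 2 * F^{-1}(tau_i) - F^{-1}(tau_hat_i) - F^{-1}(tau_hat_{i-1}),
# for i = 1, ..., N - 1. A minimal sketch of that unrelaxed form follows; the
# function name is illustrative and the shapes match the tensors above.
def _fraction_gradient_sketch(sa_quantiles, sa_quantile_hats):
    # sa_quantiles:     (batch_size, N - 1, 1), F^{-1} at tau_1, ..., tau_{N-1}.
    # sa_quantile_hats: (batch_size, N, 1),     F^{-1} at tau_hat_0, ..., tau_hat_{N-1}.
    batch_size = sa_quantiles.shape[0]
    # 2 * F^{-1}(tau_i) - F^{-1}(tau_hat_{i-1}) - F^{-1}(tau_hat_i), i = 1..N-1.
    gradient = (2 * sa_quantiles
                - sa_quantile_hats[:, :-1]
                - sa_quantile_hats[:, 1:])
    return gradient.view(batch_size, -1)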
def learn(self):
    self.learning_steps += 1
    self.online_net.sample_noise()
    self.target_net.sample_noise()

    if self.use_per:
        (states, actions, rewards, next_states, dones), weights = \
            self.memory.sample(self.batch_size)
    else:
        states, actions, rewards, next_states, dones = \
            self.memory.sample(self.batch_size)
        weights = None

    # Calculate embeddings of current states.
    state_embeddings = self.online_net.calculate_state_embeddings(states)

    # Calculate fractions of current states and entropies.
    taus, tau_hats, entropies = \
        self.online_net.calculate_fractions(
            state_embeddings=state_embeddings)

    # Calculate quantile values of current states and actions at tau_hats.
    current_sa_quantile_hats = evaluate_quantile_at_action(
        self.online_net.calculate_quantiles(
            tau_hats, state_embeddings=state_embeddings),
        actions)
    assert current_sa_quantile_hats.shape == (self.batch_size, self.N, 1)

    # NOTE: Detach state_embeddings not to update convolution layers. Also,
    # detach current_sa_quantile_hats because I calculate gradients of taus
    # explicitly, not by backpropagation.
    fraction_loss = self.calculate_fraction_loss(
        state_embeddings.detach(), current_sa_quantile_hats.detach(),
        taus, actions, weights)

    quantile_loss, mean_q, errors = self.calculate_quantile_loss(
        state_embeddings, tau_hats, current_sa_quantile_hats, actions,
        rewards, next_states, dones, weights)

    entropy_loss = -self.ent_coef * entropies.mean()

    update_params(
        self.fraction_optim, fraction_loss + entropy_loss,
        networks=[self.online_net.fraction_net],
        retain_graph=True, grad_cliping=self.grad_cliping)
    update_params(
        self.quantile_optim, quantile_loss + entropy_loss,
        networks=[
            self.online_net.dqn_net, self.online_net.cosine_net,
            self.online_net.quantile_net],
        retain_graph=False, grad_cliping=self.grad_cliping)

    if self.use_per:
        self.memory.update_priority(errors)

    if self.learning_steps % self.log_interval == 0:
        self.writer.add_scalar(
            'loss/fraction_loss', fraction_loss.detach().item(),
            4 * self.steps)
        self.writer.add_scalar(
            'loss/quantile_loss', quantile_loss.detach().item(),
            4 * self.steps)
        if self.ent_coef > 0.0:
            self.writer.add_scalar(
                'loss/entropy_loss', entropy_loss.detach().item(),
                4 * self.steps)

        self.writer.add_scalar('stats/mean_Q', mean_q, 4 * self.steps)
        self.writer.add_scalar(
            'stats/mean_entropy_of_value_distribution',
            entropies.mean().detach().item(), 4 * self.steps)
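
# `update_params` is another utility imported from elsewhere in this
# repository. The sketch below shows the optimizer step it is assumed to
# perform: zero the gradients, backpropagate (optionally retaining the graph
# so the second backward in learn() works), optionally clip gradient norms of
# the listed sub-networks, then step. The body and the choice of
# clip_grad_norm_ are assumptions; the keyword `grad_cliping` is spelled as
# in the calls above.
import torch  # mirrors the module-level import the methods above rely on


def _update_params_sketch(optim, loss, networks, retain_graph=False,
                          grad_cliping=None):
    optim.zero_grad()
    loss.backward(retain_graph=retain_graph)
    # Clip gradient norms of the listed sub-networks, if requested.
    if grad_cliping:
        for net in networks:
            torch.nn.utils.clip_grad_norm_(net.parameters(), grad_cliping)
    optim.step()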