def update_q_functions(self, batch, writer, imp_ws1=None, imp_ws2=None,
                       fast_batch=None):
    states, actions, rewards, next_states, dones, *_ = batch

    # Calculate current and target Q values.
    curr_qs1, curr_qs2 = self.calc_current_qs(states, actions)
    target_qs = self.calc_target_qs(rewards, next_states, dones)

    # Update Q functions.
    q_loss, mean_q1, mean_q2, unweighted_q_loss = \
        self.calc_q_loss(curr_qs1, curr_qs2, target_qs, imp_ws1, imp_ws2)
    update_params(self._q_optim, q_loss)

    # TODO: compute Q loss for online batch.

    if self._learning_steps % self._log_interval == 0:
        writer.add_scalar('loss/Q', unweighted_q_loss.detach().item(),
                          self._learning_steps)
        writer.add_scalar('stats/mean_Q1', mean_q1, self._learning_steps)
        writer.add_scalar('stats/mean_Q2', mean_q2, self._learning_steps)

    # Return these values for the DisCor algorithm.
    return curr_qs1.detach(), curr_qs2.detach(), target_qs
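# NOTE: calc_q_loss is not defined in this file. The sketch below is only an
# assumption of how the per-sample importance weights (imp_ws1, imp_ws2) could
# enter the double-Q regression loss, with the unweighted loss returned
# separately for logging; it is not the definitive implementation.
def calc_q_loss(self, curr_qs1, curr_qs2, target_qs, imp_ws1=None, imp_ws2=None):
    # Squared TD errors per sample, shape (batch_size, 1).
    td1 = (curr_qs1 - target_qs).pow(2)
    td2 = (curr_qs2 - target_qs).pow(2)

    # Unweighted loss is only logged to TensorBoard.
    unweighted_q_loss = td1.mean() + td2.mean()

    # Weighted loss is what gets optimized; fall back to the plain mean when
    # no importance weights are provided.
    q_loss1 = (td1 * imp_ws1).mean() if imp_ws1 is not None else td1.mean()
    q_loss2 = (td2 * imp_ws2).mean() if imp_ws2 is not None else td2.mean()
    q_loss = q_loss1 + q_loss2

    mean_q1 = curr_qs1.detach().mean().item()
    mean_q2 = curr_qs2.detach().mean().item()
    return q_loss, mean_q1, mean_q2, unweighted_q_loss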
def calc_update_d_pi_iw(self, slow_obs, slow_act, fast_obs, fast_act,
                        target_obs=None, target_act=None):
    slow_samples = torch.cat((slow_obs, slow_act), dim=1)
    fast_samples = torch.cat((fast_obs, fast_act), dim=1)
    zeros = torch.zeros(slow_samples.size(0)).to(device=self._device)
    ones = torch.ones(fast_samples.size(0)).to(device=self._device)

    slow_preds = self._prob_classifier(slow_samples)
    fast_preds = self._prob_classifier(fast_samples)

    # Train the classifier to separate slow (uniform) samples from fast
    # (recent) samples.
    loss = F.binary_cross_entropy(torch.sigmoid(slow_preds), zeros) + \
        F.binary_cross_entropy(torch.sigmoid(fast_preds), ones)
    update_params(self._prob_optim, loss)

    # Optionally compute the ratio on data other than what the classifier was
    # just trained on.
    if target_obs is None:
        target_obs = slow_obs
    if target_act is None:
        target_act = slow_act

    target_samples = torch.cat((target_obs, target_act), dim=1)
    target_preds = self._prob_classifier(target_samples)
    importance_weights = torch.sigmoid(
        target_preds / self.prob_temperature).detach()
    importance_weights = importance_weights / torch.sum(importance_weights)
    return importance_weights, loss
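# NOTE: the architecture of self._prob_classifier is not shown in this file.
# A minimal sketch, assuming a small MLP that maps a concatenated
# (state, action) vector to a single unnormalized logit, which
# calc_update_d_pi_iw squashes with a sigmoid; layer sizes are placeholders.
import torch.nn as nn

class ProbClassifier(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1))

    def forward(self, x):
        # sigmoid(logit) estimates the probability that (s, a) came from the
        # fast (recent) buffer rather than the slow (uniform) one.
        return self.net(x)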
def update_q_functions_and_error_models(self, batch, writer):
    states, actions, rewards, next_states, dones = batch

    # Calculate importance weights.
    imp_ws1, imp_ws2 = self.calc_importance_weights(next_states, dones)

    # Update Q functions.
    curr_qs1, curr_qs2, target_qs = \
        self.update_q_functions(batch, writer, imp_ws1, imp_ws2)

    # Calculate current and target errors.
    curr_errs1, curr_errs2 = self.calc_current_errors(states, actions)
    target_errs1, target_errs2 = self.calc_target_errors(
        next_states, dones, curr_qs1, curr_qs2, target_qs)

    # Update error models.
    err_loss = self.calc_error_loss(curr_errs1, curr_errs2, target_errs1,
                                    target_errs2)
    update_params(self._error_optim, err_loss)

    if self._learning_steps % self._log_interval == 0:
        writer.add_scalar('loss/error', err_loss.detach().item(),
                          self._learning_steps)
        writer.add_scalar('stats/tau1', self._tau1.item(),
                          self._learning_steps)
        writer.add_scalar('stats/tau2', self._tau2.item(),
                          self._learning_steps)
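# NOTE: calc_importance_weights is not shown in this file. A minimal sketch of
# the DisCor weighting it is assumed to implement: sample next actions from
# the current policy, query the target error networks for the estimated
# cumulative Bellman error at (s', a'), and turn -gamma * error / tau into
# self-normalized softmax weights. Attribute names such as self._policy_net
# and self._error_net_target are placeholders, not the repo's actual names.
def calc_importance_weights(self, next_states, dones):
    with torch.no_grad():
        next_actions, _, _ = self._policy_net(next_states)
        next_errs1, next_errs2 = self._error_net_target(next_states, next_actions)

    # Smaller estimated target error -> larger weight; terminal transitions
    # are masked out of the exponent.
    x1 = -(1.0 - dones) * self._gamma * next_errs1 / self._tau1
    x2 = -(1.0 - dones) * self._gamma * next_errs2 / self._tau2
    imp_ws1 = F.softmax(x1, dim=0)
    imp_ws2 = F.softmax(x2, dim=0)
    return imp_ws1, imp_ws2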
def update_q_functions(self, batch, writer, imp_ws1=None, imp_ws2=None,
                       fast_batch=None, err_preds=None):
    states, actions, rewards, next_states, dones = \
        batch["states"], batch["actions"], batch["rewards"], \
        batch["next_states"], batch["dones"]

    # Calculate current and target Q values.
    curr_qs1, curr_qs2 = self.calc_current_qs(states, actions)
    target_qs = self.calc_target_qs(rewards, next_states, dones)

    # Update Q functions.
    q_loss, mean_q1, mean_q2, unweighted_q_loss = \
        self.calc_q_loss(curr_qs1, curr_qs2, target_qs, imp_ws1, imp_ws2)
    update_params(self._q_optim, q_loss)

    # TODO: compute Q loss for online batch.

    if self._learning_steps % self._log_interval == 0:
        writer.add_scalar('loss/Q', unweighted_q_loss.detach().item(),
                          self._learning_steps)
        writer.add_scalar('stats/mean_Q1', mean_q1, self._learning_steps)
        writer.add_scalar('stats/mean_Q2', mean_q2, self._learning_steps)

    if self._eval_tper and self._learning_steps % self._eval_tper_interval == 0:
        steps = batch["steps"]
        sim_states = batch["sim_states"]
        done_cnts = batch["done_cnts"]
        self.eval_Q(states[:128], actions[:128], steps[:128],
                    sim_states[:128], curr_qs1[:128], done_cnts[:128],
                    err_preds[:128] if err_preds is not None else None)

    # Return these values for the DisCor algorithm.
    return curr_qs1.detach(), curr_qs2.detach(), target_qs
def update_policy_and_entropy(self, batch, writer):
    states = batch["states"]

    # Update policy.
    policy_loss, entropies = self.calc_policy_loss(states)
    update_params(self._policy_optim, policy_loss)

    # Update the entropy coefficient.
    entropy_loss = self.calc_entropy_loss(entropies)
    update_params(self._alpha_optim, entropy_loss)
    self._alpha = self._log_alpha.detach().exp()

    if self._learning_steps % self._log_interval == 0:
        writer.add_scalar('loss/policy', policy_loss.detach().item(),
                          self._learning_steps)
        writer.add_scalar('loss/entropy', entropy_loss.detach().item(),
                          self._learning_steps)
        writer.add_scalar('stats/alpha', self._alpha.item(),
                          self._learning_steps)
        writer.add_scalar('stats/entropy',
                          entropies.detach().mean().item(),
                          self._learning_steps)
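# NOTE: calc_policy_loss and calc_entropy_loss are not shown above. A minimal
# sketch of the standard SAC objectives they are assumed to compute; names
# such as self._policy_net, self._online_q_net and self._target_entropy are
# placeholders, and the policy is assumed to return entropies = -log_prob.
def calc_policy_loss(self, states):
    # Reparameterized actions and their entropies under the current policy.
    sampled_actions, entropies, _ = self._policy_net(states)
    qs1, qs2 = self._online_q_net(states, sampled_actions)
    # Maximize min(Q1, Q2) + alpha * entropy.
    policy_loss = torch.mean(-torch.min(qs1, qs2) - self._alpha * entropies)
    return policy_loss, entropies

def calc_entropy_loss(self, entropies):
    # Tune alpha so that the policy entropy tracks the target entropy.
    return -torch.mean(
        self._log_alpha * (self._target_entropy - entropies).detach())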
def update_q_functions_and_error_models(self, batch, writer):
    uniform_batch = batch["uniform"]
    if self.lfiw:
        fast_batch = batch["fast"]
        fast_states, fast_actions = fast_batch["states"], fast_batch["actions"]
    else:
        fast_batch = None

    # train_batch = batch["prior"] if self.tper else batch["uniform"]
    train_batch = batch["uniform"]

    # Transitions used to update the Q networks.
    states, actions, next_states, dones = \
        train_batch["states"], train_batch["actions"], \
        train_batch["next_states"], train_batch["dones"]
    # (s, a) pairs used to update the LFIW classifier.
    slow_states, slow_actions = uniform_batch["states"], uniform_batch["actions"]

    # Calculate importance weights, starting from uniform weights.
    batch_size = states.shape[0]
    weights1 = torch.ones((batch_size, 1)).to(device=self._device)
    weights2 = torch.ones((batch_size, 1)).to(device=self._device)

    if self.discor:
        discor_weights = self.calc_importance_weights(next_states, dones)
        weights1 *= discor_weights[0]
        weights2 *= discor_weights[1]

    # Calculate LFIW weights and update the prob_classifier.
    if self.lfiw:
        lfiw_weights, prob_loss = self.calc_update_d_pi_iw(
            slow_states, slow_actions, fast_states, fast_actions,
            states, actions)
        weights1 *= lfiw_weights
        weights2 *= lfiw_weights

    # Calculate weights for temporal priority.
    if self.tper:
        steps = train_batch["steps"]
        done_cnts = train_batch["done_cnts"]
        tper_weights = self.calc_tper_weights(steps, done_cnts)
        weights1 *= tper_weights
        weights2 *= tper_weights

    # Update Q functions; pass in curr_errs1 for evaluating DisCor.
    curr_errs1, curr_errs2 = None, None
    if self.discor:
        curr_errs1, curr_errs2 = self.calc_current_errors(states, actions)
    curr_qs1, curr_qs2, target_qs = \
        self.update_q_functions(train_batch, writer, weights1, weights2,
                                fast_batch, curr_errs1)

    # Calculate current and target errors, then update the error models.
    if self.discor:
        target_errs1, target_errs2 = self.calc_target_errors(
            next_states, dones, curr_qs1, curr_qs2, target_qs)
        err_loss = self.calc_error_loss(curr_errs1, curr_errs2,
                                        target_errs1, target_errs2)
        update_params(self._error_optim, err_loss)

    if self._learning_steps % self._log_interval == 0:
        if self.discor:
            writer.add_scalar('loss/error', err_loss.detach().item(),
                              self._learning_steps)
            writer.add_scalar('stats/tau1', self._tau1.item(),
                              self._learning_steps)
            writer.add_scalar('stats/tau2', self._tau2.item(),
                              self._learning_steps)
        if self.lfiw:
            writer.add_scalar('loss/prob_loss', prob_loss.detach().item(),
                              self._learning_steps)