def _train_policy(self, obs, actions, rewards, advantages):
    r"""Train the policy.

    Args:
        obs (torch.Tensor): Observation from the environment with shape
            :math:`(N, O*)`.
        actions (torch.Tensor): Actions fed to the environment with shape
            :math:`(N, A*)`.
        rewards (torch.Tensor): Acquired rewards with shape :math:`(N, )`.
        advantages (torch.Tensor): Advantage value at each step with shape
            :math:`(N, )`.

    Returns:
        torch.Tensor: Calculated mean scalar value of policy loss (float).

    """
    # pylint: disable=protected-access
    zero_optim_grads(self._policy_optimizer._optimizer)
    loss = self._compute_loss_with_adv(obs, actions, rewards, advantages)
    loss.backward()
    self._policy_optimizer.step(
        f_loss=lambda: self._compute_loss_with_adv(obs, actions, rewards,
                                                   advantages),
        f_constraint=lambda: self._compute_kl_constraint(obs))

    return loss
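# Illustrative sketch (not part of the original code): _compute_loss_with_adv
# above is assumed to return an advantage-weighted surrogate objective. A
# minimal standalone version of that idea, using a hypothetical `log_probs`
# tensor, might look like the following; the real method may add likelihood
# ratios, entropy bonuses, or clipping.
import torch


def surrogate_loss_sketch(log_probs, advantages):
    """Negative advantage-weighted log-likelihood, averaged over the batch."""
    return -(log_probs * advantages).mean()


# Example usage with random data:
# loss = surrogate_loss_sketch(torch.randn(32), torch.randn(32))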
def _train_value_function(self, paths):
    """Train the value function.

    Args:
        paths (list[dict]): A list of collected paths.

    Returns:
        torch.Tensor: Calculated mean scalar value of value function loss
            (float).

    """
    # MAML resets a value function to its initial state before training.
    self._value_function.load_state_dict(self._initial_vf_state)

    obs = np.concatenate([path['observations'] for path in paths], axis=0)
    returns = np.concatenate([path['returns'] for path in paths])
    obs = np_to_torch(obs)
    returns = np_to_torch(returns.astype(np.float32))

    vf_loss = self._value_function.compute_loss(obs, returns)
    # pylint: disable=protected-access
    zero_optim_grads(self._inner_algo._vf_optimizer._optimizer)
    vf_loss.backward()
    # pylint: disable=protected-access
    self._inner_algo._vf_optimizer.step()

    return vf_loss
def _optimize_qf(self, timesteps):
    """Perform algorithm optimization.

    Args:
        timesteps (TimeStepBatch): Processed batch data.

    Returns:
        qval_loss: Loss of Q-value predicted by the Q-network.
        ys: TD targets (y values).
        qval: Q-value predicted by the Q-network.

    """
    observations = np_to_torch(timesteps.observations)
    rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
    rewards *= self._reward_scale
    actions = np_to_torch(timesteps.actions)
    next_observations = np_to_torch(timesteps.next_observations)
    terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)

    next_inputs = next_observations
    inputs = observations
    with torch.no_grad():
        if self._double_q:
            # Use the online qf to select the greedy actions
            selected_actions = torch.argmax(self._qf(next_inputs), axis=1)
            # Use the target qf to get Q values for those actions
            selected_actions = selected_actions.long().unsqueeze(1)
            best_qvals = torch.gather(self._target_qf(next_inputs),
                                      dim=1,
                                      index=selected_actions)
        else:
            target_qvals = self._target_qf(next_inputs)
            best_qvals, _ = torch.max(target_qvals, 1)
            best_qvals = best_qvals.unsqueeze(1)

    rewards_clipped = rewards
    if self._clip_reward is not None:
        rewards_clipped = torch.clamp(rewards, -1 * self._clip_reward,
                                      self._clip_reward)
    y_target = (rewards_clipped +
                (1.0 - terminals) * self._discount * best_qvals)
    y_target = y_target.squeeze(1)

    # optimize qf
    qvals = self._qf(inputs)
    selected_qs = torch.sum(qvals * actions, axis=1)
    qval_loss = F.smooth_l1_loss(selected_qs, y_target)

    zero_optim_grads(self._qf_optimizer)
    qval_loss.backward()

    # optionally clip the gradients
    if self._clip_grad is not None:
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                       self._clip_grad)
    self._qf_optimizer.step()

    return (qval_loss.detach(), y_target, selected_qs.detach())
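# Illustrative sketch (not part of the original code): the double-Q branch
# above decouples action selection (online network) from action evaluation
# (target network). A standalone version of that target computation, with
# hypothetical `qf` and `target_qf` callables returning (N, num_actions)
# tensors, is shown here for reference.
import torch


def double_q_target_sketch(qf, target_qf, next_obs, rewards, terminals,
                           discount):
    """Compute a Double-DQN style TD target for (N, 1) rewards/terminals."""
    with torch.no_grad():
        # Online network picks the greedy action...
        greedy = torch.argmax(qf(next_obs), dim=1, keepdim=True)
        # ...target network evaluates it.
        best_qvals = torch.gather(target_qf(next_obs), dim=1, index=greedy)
    return rewards + (1.0 - terminals) * discount * best_qvals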
def optimize_policy(self, samples_data): """Perform algorithm optimizing. Args: samples_data (dict): Processed batch data. Returns: action_loss: Loss of action predicted by the policy network. qval_loss: Loss of Q-value predicted by the Q-network. ys: y_s. qval: Q-value predicted by the Q-network. """ transitions = as_torch_dict(samples_data) observations = transitions['observations'] rewards = transitions['rewards'].reshape(-1, 1) actions = transitions['actions'] next_observations = transitions['next_observations'] terminals = transitions['terminals'].reshape(-1, 1) next_inputs = next_observations inputs = observations with torch.no_grad(): next_actions = self._target_policy(next_inputs) target_qvals = self._target_qf(next_inputs, next_actions) clip_range = (-self._clip_return, 0. if self._clip_pos_returns else self._clip_return) y_target = rewards + (1.0 - terminals) * self._discount * target_qvals y_target = torch.clamp(y_target, clip_range[0], clip_range[1]) # optimize critic qval = self._qf(inputs, actions) qf_loss = torch.nn.MSELoss() qval_loss = qf_loss(qval, y_target) zero_optim_grads(self._qf_optimizer) qval_loss.backward() self._qf_optimizer.step() # optimize actor actions = self.policy(inputs) action_loss = -1 * self._qf(inputs, actions).mean() zero_optim_grads(self._policy_optimizer) action_loss.backward() self._policy_optimizer.step() # update target networks self.update_target() return (qval_loss.detach(), y_target, qval.detach(), action_loss.detach())
def _train_once(self, trainer, all_samples, all_params):
    """Train the algorithm once.

    Args:
        trainer (Trainer): The experiment runner.
        all_samples (list[list[_MAMLEpisodeBatch]]): A two dimensional
            list of _MAMLEpisodeBatch of size
            [meta_batch_size * (num_grad_updates + 1)].
        all_params (list[dict]): A list of named parameter dictionaries.
            Each dictionary contains key value pair of names (str) and
            parameters (torch.Tensor).

    Returns:
        float: Average return.

    """
    itr = trainer.step_itr
    old_theta = dict(self._policy.named_parameters())

    kl_before = self._compute_kl_constraint(all_samples,
                                            all_params,
                                            set_grad=False)

    meta_objective = self._compute_meta_loss(all_samples, all_params)

    zero_optim_grads(self._meta_optimizer)
    meta_objective.backward()

    self._meta_optimize(all_samples, all_params)

    # Log
    loss_after = self._compute_meta_loss(all_samples,
                                         all_params,
                                         set_grad=False)
    kl_after = self._compute_kl_constraint(all_samples,
                                           all_params,
                                           set_grad=False)

    with torch.no_grad():
        policy_entropy = self._compute_policy_entropy(
            [task_samples[0] for task_samples in all_samples])
        average_return = self._log_performance(
            itr, all_samples, meta_objective.item(), loss_after.item(),
            kl_before.item(), kl_after.item(),
            policy_entropy.mean().item())

    if self._meta_evaluator and itr % self._evaluate_every_n_epochs == 0:
        self._meta_evaluator.evaluate(self)

    update_module_params(self._old_policy, old_theta)

    return average_return
def optimize_policy(self, samples_data): """Optimize the policy q_functions, and temperature coefficient. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: loss from actor/policy network after optimization. torch.Tensor: loss from 1st q-function after optimization. torch.Tensor: loss from 2nd q-function after optimization. """ obs = samples_data['observation'] qf1_loss, qf2_loss = self._critic_objective(samples_data) zero_optim_grads(self._qf1_optimizer) qf1_loss.backward() self._qf1_optimizer.step() zero_optim_grads(self._qf2_optimizer) qf2_loss.backward() self._qf2_optimizer.step() action_dists = self.policy(obs)[0] new_actions_pre_tanh, new_actions = ( action_dists.rsample_with_pre_tanh_value()) log_pi_new_actions = action_dists.log_prob( value=new_actions, pre_tanh_value=new_actions_pre_tanh) policy_loss = self._actor_objective(samples_data, new_actions, log_pi_new_actions) policy_loss += self._caps_regularization_objective( action_dists, samples_data) zero_optim_grads(self._policy_optimizer) policy_loss.backward() self._policy_optimizer.step() if self._use_automatic_entropy_tuning: alpha_loss = self._temperature_objective(log_pi_new_actions, samples_data) zero_optim_grads(self._alpha_optimizer) alpha_loss.backward() self._alpha_optimizer.step() return policy_loss, qf1_loss, qf2_loss
def _train_value_function(self, obs, returns):
    r"""Train the value function.

    Args:
        obs (torch.Tensor): Observation from the environment with shape
            :math:`(N, O*)`.
        returns (torch.Tensor): Acquired returns with shape :math:`(N, )`.

    Returns:
        torch.Tensor: Calculated mean scalar value of value function loss
            (float).

    """
    # pylint: disable=protected-access
    zero_optim_grads(self._vf_optimizer._optimizer)
    loss = self._value_function.compute_loss(obs, returns)
    loss.backward()
    self._vf_optimizer.step()

    return loss
def _optimize_policy(self, samples_data, grad_step_timer):
    """Perform algorithm optimization.

    Args:
        samples_data (dict): Processed batch data.
        grad_step_timer (int): Number of gradient steps taken so far, used
            to delay the actor and target-network updates.

    Returns:
        float: Loss predicted by the q networks (critic networks).
        float: Q value (min) predicted by one of the target q networks.
        float: Q value (min) predicted by one of the current q networks.
        float: Loss predicted by the policy (action network).

    """
    rewards = samples_data['rewards'].to(global_device()).reshape(-1, 1)
    terminals = samples_data['terminals'].to(global_device()).reshape(
        -1, 1)
    actions = samples_data['actions'].to(global_device())
    observations = samples_data['observations'].to(global_device())
    next_observations = samples_data['next_observations'].to(
        global_device())

    next_inputs = next_observations
    inputs = observations
    with torch.no_grad():
        # Select action according to policy and add clipped noise
        noise = (torch.randn_like(actions) * self._policy_noise).clamp(
            -self._policy_noise_clip, self._policy_noise_clip)
        next_actions = (self._target_policy(next_inputs) + noise).clamp(
            -self._max_action, self._max_action)

        # Compute the target Q value
        target_Q1 = self._target_qf_1(next_inputs, next_actions)
        target_Q2 = self._target_qf_2(next_inputs, next_actions)
        target_q = torch.min(target_Q1, target_Q2)
        target_Q = rewards * self._reward_scaling + (
            1. - terminals) * self._discount * target_q

    # Get current Q values
    current_Q1 = self._qf_1(inputs, actions)
    current_Q2 = self._qf_2(inputs, actions)
    current_Q = torch.min(current_Q1, current_Q2)

    # Compute critic loss
    critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
        current_Q2, target_Q)

    # Optimize critic
    zero_optim_grads(self._qf_optimizer_1)
    zero_optim_grads(self._qf_optimizer_2)
    critic_loss.backward()
    self._qf_optimizer_1.step()
    self._qf_optimizer_2.step()

    # Delay policy updates
    if grad_step_timer % self._update_actor_interval == 0:
        # Compute actor loss
        actions = self.policy(inputs)
        self._actor_loss = -self._qf_1(inputs, actions).mean()

        # Optimize actor
        zero_optim_grads(self._policy_optimizer)
        self._actor_loss.backward()
        self._policy_optimizer.step()

        # update target networks
        self._update_network_parameters()

    return (critic_loss.detach(), target_Q, current_Q.detach(),
            self._actor_loss.detach())
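# Illustrative sketch (not part of the original code): the torch.no_grad()
# block above combines TD3's target policy smoothing with clipped double-Q
# learning. A standalone version of that target computation, with hypothetical
# callables and default hyperparameters, is shown here for reference.
import torch


def td3_target_sketch(target_policy, target_qf_1, target_qf_2, next_obs,
                      rewards, terminals, discount, policy_noise=0.2,
                      noise_clip=0.5, max_action=1.0):
    """Compute a smoothed, clipped double-Q TD target."""
    with torch.no_grad():
        next_actions = target_policy(next_obs)
        # Smooth the target action with clipped Gaussian noise.
        noise = (torch.randn_like(next_actions) * policy_noise).clamp(
            -noise_clip, noise_clip)
        next_actions = (next_actions + noise).clamp(-max_action, max_action)
        # Take the pessimistic minimum of the two target critics.
        target_q = torch.min(target_qf_1(next_obs, next_actions),
                             target_qf_2(next_obs, next_actions))
    return rewards + (1.0 - terminals) * discount * target_q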
def _optimize_policy(self, indices):
    """Perform algorithm optimization.

    Args:
        indices (list): Tasks used for training.

    """
    num_tasks = len(indices)
    context = self._sample_context(indices)
    # clear context and reset belief of policy
    self._policy.reset_belief(num_tasks=num_tasks)

    # data shape is (task, batch, feat)
    obs, actions, rewards, next_obs, terms = self._sample_data(indices)
    policy_outputs, task_z = self._policy(obs, context)
    new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

    # flatten out the task dimension
    t, b, _ = obs.size()
    obs = obs.view(t * b, -1)
    actions = actions.view(t * b, -1)
    next_obs = next_obs.view(t * b, -1)

    # optimize qf and encoder networks
    q1_pred = self._qf1(torch.cat([obs, actions], dim=1), task_z)
    q2_pred = self._qf2(torch.cat([obs, actions], dim=1), task_z)
    v_pred = self._vf(obs, task_z.detach())

    with torch.no_grad():
        target_v_values = self.target_vf(next_obs, task_z)

    # KL constraint on z if probabilistic
    zero_optim_grads(self.context_optimizer)
    if self._use_information_bottleneck:
        kl_div = self._policy.compute_kl_div()
        kl_loss = self._kl_lambda * kl_div
        kl_loss.backward(retain_graph=True)

    zero_optim_grads(self.qf1_optimizer)
    zero_optim_grads(self.qf2_optimizer)

    rewards_flat = rewards.view(self._batch_size * num_tasks, -1)
    rewards_flat = rewards_flat * self._reward_scale
    terms_flat = terms.view(self._batch_size * num_tasks, -1)
    q_target = rewards_flat + (
        1. - terms_flat) * self._discount * target_v_values
    qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean(
        (q2_pred - q_target)**2)
    qf_loss.backward()

    self.qf1_optimizer.step()
    self.qf2_optimizer.step()
    self.context_optimizer.step()

    # compute min Q on the new actions
    q1 = self._qf1(torch.cat([obs, new_actions], dim=1), task_z.detach())
    q2 = self._qf2(torch.cat([obs, new_actions], dim=1), task_z.detach())
    min_q = torch.min(q1, q2)

    # optimize vf
    v_target = min_q - log_pi
    vf_loss = self.vf_criterion(v_pred, v_target.detach())
    zero_optim_grads(self.vf_optimizer)
    vf_loss.backward()
    self.vf_optimizer.step()
    self._update_target_network()

    # optimize policy
    log_policy_target = min_q
    policy_loss = (log_pi - log_policy_target).mean()

    mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean()
    std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean()
    pre_tanh_value = policy_outputs[-1]
    pre_activation_reg_loss = self._policy_pre_activation_coeff * (
        (pre_tanh_value**2).sum(dim=1).mean())
    policy_reg_loss = (mean_reg_loss + std_reg_loss +
                       pre_activation_reg_loss)
    policy_loss = policy_loss + policy_reg_loss

    zero_optim_grads(self._policy_optimizer)
    policy_loss.backward()
    self._policy_optimizer.step()
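# Illustrative sketch (not part of the original code): compute_kl_div() above
# is assumed to penalize the divergence between the inferred latent-context
# posterior and a unit Gaussian prior (the information bottleneck on z). A
# minimal standalone version with hypothetical mean/std tensors:
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence


def context_kl_sketch(z_means, z_stds):
    """Sum KL(q(z|c) || N(0, 1)) over latent dims, averaged over tasks."""
    prior = Normal(torch.zeros_like(z_means), torch.ones_like(z_stds))
    posterior = Normal(z_means, z_stds)
    return kl_divergence(posterior, prior).sum(dim=-1).mean()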