def optimize_policy(self, itr, samples_data):
    """Perform one step of policy and Q-function optimization.

    Args:
        itr (int): Iteration number.
        samples_data (dict): Processed batch data.

    Returns:
        torch.Tensor: Loss of the Q-value predicted by the Q-network.
        torch.Tensor: y targets (clipped Bellman targets).
        torch.Tensor: Q-value predicted by the Q-network.
        torch.Tensor: Loss of the action predicted by the policy network.

    """
    transitions = tu.dict_np_to_torch(samples_data)
    observations = transitions['observation']
    rewards = transitions['reward']
    actions = transitions['action']
    next_observations = transitions['next_observation']
    terminals = transitions['terminal']

    rewards = rewards.reshape(-1, 1)
    terminals = terminals.reshape(-1, 1)

    next_inputs = next_observations
    inputs = observations
    with torch.no_grad():
        next_actions = self._target_policy(next_inputs)
        target_qvals = self._target_qf(next_inputs, next_actions)

    clip_range = (-self._clip_return,
                  0. if self._clip_pos_returns else self._clip_return)

    # Clipped Bellman target: r + (1 - done) * gamma * Q'(s', pi'(s'))
    y_target = rewards + (1.0 - terminals) * self.discount * target_qvals
    y_target = torch.clamp(y_target, clip_range[0], clip_range[1])

    # optimize critic
    qval = self.qf(inputs, actions)
    qf_loss = torch.nn.MSELoss()
    qval_loss = qf_loss(qval, y_target)
    self._qf_optimizer.zero_grad()
    qval_loss.backward()
    self._qf_optimizer.step()

    # optimize actor
    actions = self.policy(inputs)
    action_loss = -1 * self.qf(inputs, actions).mean()
    self._policy_optimizer.zero_grad()
    action_loss.backward()
    self._policy_optimizer.step()

    # update target networks
    self.update_target()
    return (qval_loss.detach(), y_target, qval.detach(), action_loss.detach())
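# `optimize_policy` above calls `self.update_target()`, which is not shown in
# this excerpt. Below is a minimal sketch of what a soft (Polyak-averaged)
# target-network update typically looks like; the method body and the
# smoothing coefficient `self._tau` are assumptions for illustration, not the
# library's confirmed implementation.
def update_target(self):
    """Soft-update the target Q-function and target policy."""
    for t_param, param in zip(self._target_qf.parameters(),
                              self.qf.parameters()):
        t_param.data.copy_(t_param.data * (1.0 - self._tau) +
                           param.data * self._tau)
    for t_param, param in zip(self._target_policy.parameters(),
                              self.policy.parameters()):
        t_param.data.copy_(t_param.data * (1.0 - self._tau) +
                           param.data * self._tau)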
def _learn_once(self, itr=None, paths=None):
    """Sample a mini-batch and optimize policy, Q-functions, and discriminator.

    Args:
        itr (int): Iteration number. Unused.
        paths (list[dict]): Collected paths. Unused.

    Returns:
        tuple: Policy loss, first and second Q-function losses, and
            discriminator loss; None until the replay buffer is prefilled.

    """
    del itr
    del paths
    if self._buffer_prefilled:
        samples = self.replay_buffer.sample_transitions(
            self.buffer_batch_size)
        samples = tu.dict_np_to_torch(samples)
        policy_loss, qf1_loss, qf2_loss = self.optimize_policy(0, samples)
        self._update_targets()
        discriminator_loss = self.optimize_discriminator(samples)
        return policy_loss, qf1_loss, qf2_loss, discriminator_loss
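# `_learn_once` gates optimization on `self._buffer_prefilled`, which is not
# defined in this excerpt. A plausible sketch, assuming it mirrors the explicit
# size check used by `train_once` below (the property body is an assumption):
@property
def _buffer_prefilled(self):
    """bool: Whether enough transitions are stored to start optimizing."""
    return self.replay_buffer.n_transitions_stored >= self.min_buffer_size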
def train_once(self, itr=None, paths=None):
    """Complete one training iteration of SAC.

    Args:
        itr (int): Iteration number. This argument is deprecated.
        paths (list[dict]): A list of collected paths. This argument is
            deprecated.

    Returns:
        torch.Tensor: Loss from the actor/policy network after optimization.
        torch.Tensor: Loss from the first Q-function after optimization.
        torch.Tensor: Loss from the second Q-function after optimization.

    """
    del itr
    del paths
    if self.replay_buffer.n_transitions_stored >= self.min_buffer_size:
        samples = self.replay_buffer.sample(self.buffer_batch_size)
        samples = tu.dict_np_to_torch(samples)
        policy_loss, qf1_loss, qf2_loss = self.optimize_policy(0, samples)
        self._update_targets()
        return policy_loss, qf1_loss, qf2_loss
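# For context, a hypothetical outer loop that could drive `train_once`:
# collect paths with a sampler, push them into the replay buffer, then
# optimize. `runner.obtain_samples` and `replay_buffer.add_path` are assumed
# names for illustration, not confirmed library APIs.
def training_loop(runner, algo, n_epochs):
    """Sketch of an epoch loop alternating sampling and optimization."""
    for epoch in range(n_epochs):
        paths = runner.obtain_samples(epoch)
        for path in paths:
            algo.replay_buffer.add_path(path)
        losses = algo.train_once()
        # train_once returns None until the buffer holds min_buffer_size
        # transitions, so guard before unpacking.
        if losses is not None:
            policy_loss, qf1_loss, qf2_loss = losses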
def test_dict_np_to_torch():
    """Test conversion of a dict of numpy arrays to torch tensors."""
    dic = {'a': np.zeros(1), 'b': np.ones(1)}
    tu.dict_np_to_torch(dic)
    for tensor in dic.values():
        assert isinstance(tensor, torch.Tensor)
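# The test above ignores the return value of `tu.dict_np_to_torch` and checks
# the original dict, while the training code also uses the returned value. A
# minimal sketch consistent with both usages, assuming in-place conversion
# that also returns the same dict (this body is an assumption, not the
# library's confirmed implementation):
import numpy as np
import torch


def dict_np_to_torch(array_dict):
    """Convert numpy array values of a dict to torch tensors, in place."""
    for key, value in array_dict.items():
        if isinstance(value, np.ndarray):
            array_dict[key] = torch.from_numpy(value).float()
    return array_dict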