Example #1
    def _train_once(self, itr):
        """Perform one iteration of training.

        Args:
            itr (int): Iteration number.

        """
        for grad_step_timer in range(self._grad_steps_per_env_step):
            if (self._replay_buffer.n_transitions_stored >=
                    self._min_buffer_size):
                # Sample from buffer
                samples = self._replay_buffer.sample_transitions(
                    self._buffer_batch_size)
                samples = dict_np_to_torch(samples)

                # Optimize
                qf_loss, y, q, policy_loss = torch_to_np(
                    self._optimize_policy(samples, grad_step_timer))

                self._episode_policy_losses.append(policy_loss)
                self._episode_qf_losses.append(qf_loss)
                self._epoch_ys.append(y)
                self._epoch_qs.append(q)

        if itr % self._steps_per_epoch == 0:
            logger.log('Training finished')
            epoch = itr // self._steps_per_epoch

            if (self._replay_buffer.n_transitions_stored >=
                    self._min_buffer_size):
                tabular.record('Epoch', epoch)
                self._log_statistics()
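Example #1 feeds replay-buffer samples through dict_np_to_torch and converts the optimization results back with torch_to_np (the remaining examples use the same helpers). The helpers themselves are not shown; the sketch below is a minimal stand-in, assuming dict_np_to_torch converts NumPy array values to float tensors in place and torch_to_np does the reverse for a tuple of tensors. It is not the library's exact implementation.

import numpy as np
import torch


def dict_np_to_torch(array_dict):
    """Convert every NumPy array value of the dict to a torch.Tensor.

    The conversion is done in place; the dict is also returned, matching
    how the examples use the helper.
    """
    for key, value in array_dict.items():
        array_dict[key] = torch.from_numpy(value).float()
    return array_dict


def torch_to_np(tensors):
    """Convert a tuple of torch tensors to a tuple of NumPy arrays."""
    return tuple(t.detach().cpu().numpy() for t in tensors)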
Example #2
    def optimize_policy(self, samples_data):
        """Perform algorithm optimizing.

        Args:
            samples_data (dict): Processed batch data.

        Returns:
            qval_loss: Loss of the Q-value predicted by the Q-network.
            ys: Clipped TD targets (y values) used to train the Q-network.
            qval: Q-value predicted by the Q-network.
            action_loss: Loss of the action predicted by the policy network.

        """
        transitions = dict_np_to_torch(samples_data)

        observations = transitions['observations']
        rewards = transitions['rewards'].reshape(-1, 1)
        actions = transitions['actions']
        next_observations = transitions['next_observations']
        terminals = transitions['terminals'].reshape(-1, 1)

        next_inputs = next_observations
        inputs = observations
        with torch.no_grad():
            next_actions = self._target_policy(next_inputs)
            target_qvals = self._target_qf(next_inputs, next_actions)

        clip_range = (-self._clip_return,
                      0. if self._clip_pos_returns else self._clip_return)

        y_target = rewards + (1.0 - terminals) * self._discount * target_qvals
        y_target = torch.clamp(y_target, clip_range[0], clip_range[1])

        # optimize critic
        qval = self._qf(inputs, actions)
        qf_loss = torch.nn.MSELoss()
        qval_loss = qf_loss(qval, y_target)
        self._qf_optimizer.zero_grad()
        qval_loss.backward()
        self._qf_optimizer.step()

        # optimize actor
        actions = self.policy(inputs)
        action_loss = -1 * self._qf(inputs, actions).mean()
        self._policy_optimizer.zero_grad()
        action_loss.backward()
        self._policy_optimizer.step()

        # update target networks
        self.update_target()
        return (qval_loss.detach(), y_target, qval.detach(),
                action_loss.detach())
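Example #2 ends with self.update_target(), which is not shown. A common choice in DDPG-style algorithms is a Polyak (soft) update of the target networks; the sketch below assumes a mixing coefficient tau and illustrative names rather than the example's actual implementation.

import torch


def soft_update(target_net, source_net, tau):
    """Polyak-average source parameters into the target network in place.

    target <- (1 - tau) * target + tau * source
    """
    with torch.no_grad():
        for t_param, s_param in zip(target_net.parameters(),
                                    source_net.parameters()):
            t_param.mul_(1.0 - tau)
            t_param.add_(tau * s_param)

With a helper like this, update_target would amount to soft_update(self._target_qf, self._qf, tau) followed by soft_update(self._target_policy, self.policy, tau).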
Example #3
    def train_once(self, itr=None, paths=None):
        """Complete 1 training iteration of SAC.

        Args:
            itr (int): Iteration number. This argument is deprecated.
            paths (list[dict]): A list of collected paths.
                This argument is deprecated.

        Returns:
            torch.Tensor: Loss from the actor/policy network after
                optimization (None until the replay buffer holds at least
                self._min_buffer_size transitions).
            torch.Tensor: Loss from the 1st Q-function after optimization.
            torch.Tensor: Loss from the 2nd Q-function after optimization.

        """
        del itr
        del paths
        policy_loss, qf1_loss, qf2_loss = None, None, None
        if self.replay_buffer.n_transitions_stored >= self._min_buffer_size:
            samples = self.replay_buffer.sample_transitions(
                self._buffer_batch_size)
            samples = dict_np_to_torch(samples)
            policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples)
            self._update_targets()

        return policy_loss, qf1_loss, qf2_loss
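The optimize_policy call in Example #3 is where SAC's two critic losses come from. The sketch below shows only the standard clipped double-Q target computation, assuming the policy returns a torch.distributions-style object whose log_prob yields one value per batch element; the function name and signature are illustrative, not the example's actual method.

import torch
import torch.nn.functional as F


def sac_critic_losses(batch, policy, qf1, qf2, target_qf1, target_qf2,
                      alpha, discount):
    """Clipped double-Q critic losses in the spirit of SAC (a sketch)."""
    obs = batch['observations']
    actions = batch['actions']
    rewards = batch['rewards'].reshape(-1, 1)
    terminals = batch['terminals'].reshape(-1, 1)
    next_obs = batch['next_observations']

    with torch.no_grad():
        dist = policy(next_obs)
        next_actions = dist.rsample()
        log_pi = dist.log_prob(next_actions).reshape(-1, 1)
        # Clipped double Q: bootstrap from the smaller of the two target
        # critics and subtract the entropy term scaled by alpha.
        min_q = torch.min(target_qf1(next_obs, next_actions),
                          target_qf2(next_obs, next_actions))
        y = rewards + (1.0 - terminals) * discount * (min_q - alpha * log_pi)

    qf1_loss = F.mse_loss(qf1(obs, actions), y)
    qf2_loss = F.mse_loss(qf2(obs, actions), y)
    return qf1_loss, qf2_loss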
Example #4
def test_dict_np_to_torch():
    """Test if dict whose values are tensors can be converted to np arrays."""
    dic = {'a': np.zeros(1), 'b': np.ones(1)}
    dict_np_to_torch(dic)
    for tensor in dic.values():
        assert isinstance(tensor, torch.Tensor)
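Example #1 also uses the reverse helper, torch_to_np. A hypothetical companion test in the same style (assuming torch_to_np maps a tuple of tensors to a tuple of NumPy arrays, as in the sketch after Example #1) could check that direction:

import numpy as np
import torch


def test_torch_to_np():
    """Check that a tuple of torch tensors converts to NumPy arrays."""
    # torch_to_np is assumed to be importable alongside dict_np_to_torch.
    tensors = (torch.zeros(1), torch.ones(1))
    arrays = torch_to_np(tensors)
    for array in arrays:
        assert isinstance(array, np.ndarray)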