Example no. 1
0
    def train(self):
        """
        Run one full training pass over the collected rollout.

        Drains every sample from the rollout buffer, converts the fields to
        tensors on ``self.device``, computes advantages and returns (either
        via GAE or via normalized discounted returns), then performs
        ``self.num_epochs`` epochs of mini-batch updates through
        ``self.learn``.  After each epoch the accumulated KL divergence is
        averaged and used to adapt the KL penalty coefficient
        ``self.kl_beta``.
        """
        # Pull the entire rollout at once; each field is expected to be
        # keyed as below (reward/done/state/action/value/logprob).
        experiences = self.buffer.all_samples()
        rewards = to_tensor(experiences['reward']).to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        values = to_tensor(experiences['value']).to(self.device)
        logprobs = to_tensor(experiences['logprob']).to(self.device)
        # NOTE(review): `assert` is stripped under `python -O`; if these
        # shape checks are load-bearing, consider raising ValueError instead.
        assert rewards.shape == dones.shape == values.shape == logprobs.shape
        assert states.shape == (
            self.rollout_length, self.num_workers,
            self.state_size), f"Wrong states shape: {states.shape}"
        assert actions.shape == (
            self.rollout_length, self.num_workers,
            self.action_size), f"Wrong action shape: {actions.shape}"

        # Advantage/return targets are fixed for the whole update, so they
        # are computed once outside the autograd graph.
        with torch.no_grad():
            if self.using_gae:
                # Bootstrap value for the state after the rollout's last step.
                next_value = self.critic.act(states[-1])
                advantages = compute_gae(rewards, dones, values, next_value,
                                         self.gamma, self.gae_lambda)
                advantages = normalize(advantages)
                returns = advantages + values
                # returns = normalize(advantages + values)
                assert advantages.shape == returns.shape == values.shape
            else:
                # Plain discounted returns (helper presumably also
                # normalizes them — TODO confirm against its definition).
                returns = revert_norm_returns(rewards, dones, self.gamma)
                returns = returns.float()
                advantages = normalize(returns - values)
                assert advantages.shape == returns.shape == values.shape

        for _ in range(self.num_epochs):
            idx = 0
            # Reset the KL accumulator each epoch; `self.learn` is presumed
            # to add its per-update KL estimate into self.kl_div — verify.
            self.kl_div = 0
            while idx < self.rollout_length:
                # Slice along the time axis, then flatten (time, workers)
                # into a single batch dimension for the learner.
                _states = states[idx:idx + self.batch_size].view(
                    -1, self.state_size).detach()
                _actions = actions[idx:idx + self.batch_size].view(
                    -1, self.action_size).detach()
                _logprobs = logprobs[idx:idx + self.batch_size].view(
                    -1, 1).detach()
                _returns = returns[idx:idx + self.batch_size].view(-1,
                                                                   1).detach()
                _advantages = advantages[idx:idx + self.batch_size].view(
                    -1, 1).detach()
                idx += self.batch_size
                self.learn(
                    (_states, _actions, _logprobs, _returns, _advantages))

            # Average accumulated KL over all actor updates this epoch
            # (rollout_length / batch_size mini-batches, each doing
            # actor_number_updates updates).
            self.kl_div = abs(
                self.kl_div) / (self.actor_number_updates *
                                self.rollout_length / self.batch_size)
            # Adaptive KL penalty: grow beta when KL overshoots the target,
            # shrink it when KL undershoots (factor 1.75 here; the PPO
            # paper uses 1.5 — presumably tuned for this project).
            if self.kl_div > self.target_kl * 1.75:
                self.kl_beta = min(2 * self.kl_beta, 1e2)  # Max 100
            if self.kl_div < self.target_kl / 1.75:
                self.kl_beta = max(0.5 * self.kl_beta, 1e-6)  # Min 0.000001
            self._metrics['policy/kl_beta'] = self.kl_beta
def test_normalize_2d():
    """normalize() standardizes a (26, 2) tensor column-wise by default."""
    ramp = torch.arange(0, 26).float()
    flat = torch.full(ramp.shape, 10)
    data = torch.stack((ramp, flat)).T  # Shape: (26, 2)

    # Column 0 is a plain standardization; column 1 is constant -> zeros.
    want = torch.stack((
        (ramp - ramp.mean()) / ramp.std(),
        torch.zeros(flat.shape),
    )).T  # Shape: (26, 2)

    out_default = normalize(data)
    out_dim0 = normalize(data, dim=0)

    # Default dim should match an explicit dim=0, and shapes are preserved.
    assert (26, 2) == data.shape == out_default.shape == out_dim0.shape
    assert bool((out_default == want).all())
    assert bool((out_default == out_dim0).all())
def test_normalize_1d():
    """normalize() on a 1-D ramp equals the manual (x - mean) / std."""
    ramp = torch.arange(0, 26).float()
    want = (ramp - ramp.mean()) / ramp.std()

    result = normalize(ramp)

    assert result.shape == ramp.shape
    assert bool((result == want).all())
def test_normalize_2d_not_default_dim():
    """normalize(t, 1) standardizes each row of a (2, 26) tensor."""
    ramp = torch.arange(0, 26).float()
    const_row = torch.full(ramp.shape, 10)
    data = torch.stack((ramp, const_row))  # Shape: (2, 26)

    # Row 0 is a plain standardization; row 1 is constant -> zeros.
    want = torch.stack((
        (ramp - ramp.mean()) / ramp.std(),
        torch.zeros(const_row.shape),
    ))  # Shape: (2, 26)

    result = normalize(data, 1)

    assert result.shape == data.shape == (2, 26)
    assert bool((result == want).all())