def test_factory_rollout_buffer_from_state_wrong_type():
    # Assign
    buffer = RolloutBuffer(batch_size=5, buffer_size=20)
    state = buffer.get_state()
    state.type = "WrongType"

    # Act
    with pytest.raises(ValueError):
        BufferFactory.from_state(state=state)
Example No. 2
def test_rollout_buffer_length():
    # Assign
    buffer_size = 10
    buffer = RolloutBuffer(batch_size=5, buffer_size=buffer_size)

    # Act
    for (state, action, reward, next_state, done) in generate_sample_SARS(buffer_size+1):
        buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done)

    # Assert
    assert len(buffer) == buffer_size
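The tests in this listing rely on a generate_sample_SARS helper that is not shown here. A minimal sketch of what such a generator might look like (the state/action sizes and the random content are assumptions for illustration, not taken from the original source):

import random

def generate_sample_SARS(num_samples, state_size=4, action_size=2):
    # Yield (state, action, reward, next_state, done) tuples with random content.
    state = [random.random() for _ in range(state_size)]
    for _ in range(num_samples):
        next_state = [random.random() for _ in range(state_size)]
        action = [random.random() for _ in range(action_size)]
        reward = random.random()
        done = random.random() < 0.1
        yield (state, action, reward, next_state, done)
        state = next_state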
Example No. 3
def test_rollout_buffer_get_state_without_data():
    # Assign
    buffer = RolloutBuffer(batch_size=5, buffer_size=20)

    # Act
    state = buffer.get_state()

    # Assert
    assert state.type == RolloutBuffer.type
    assert state.buffer_size == 20
    assert state.batch_size == 5
    assert state.data is None
Example No. 4
def test_factory_rollout_buffer_from_state_without_data():
    # Assign
    buffer = RolloutBuffer(batch_size=5, buffer_size=20)
    state = buffer.get_state()

    # Act
    new_buffer = BufferFactory.from_state(state=state)

    # Assert
    assert new_buffer == buffer
    assert new_buffer.buffer_size == state.buffer_size
    assert new_buffer.batch_size == state.batch_size
    assert len(new_buffer.data) == 0
Example No. 5
def test_rollout_buffer_from_state_with_data():
    # Assign
    buffer = RolloutBuffer(batch_size=5, buffer_size=20)
    buffer = populate_buffer(buffer, 30)
    state = buffer.get_state()

    # Act
    new_buffer = RolloutBuffer.from_state(state=state)

    # Assert
    assert new_buffer == buffer
    assert new_buffer.buffer_size == state.buffer_size
    assert new_buffer.batch_size == state.batch_size
    assert new_buffer.data == state.data
    assert len(buffer.data) == state.buffer_size
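populate_buffer is likewise not shown in the listing. A plausible sketch, reusing the hypothetical generate_sample_SARS above, is simply:

def populate_buffer(buffer, num_samples):
    # Fill the buffer with randomly generated SARS transitions and return it.
    for (state, action, reward, next_state, done) in generate_sample_SARS(num_samples):
        buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done)
    return buffer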
Example No. 6
def test_rollout_buffer_sample_batch_equal_buffer():
    # Assign
    buffer_size = batch_size = 20
    buffer = RolloutBuffer(batch_size=batch_size, buffer_size=buffer_size)

    # Act
    for (state, action, reward, next_state, done) in generate_sample_SARS(buffer_size+1):
        buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done)

    # Assert
    num_samples = 0
    for samples in buffer.sample():
        num_samples += 1
        for value in samples.values():
            assert len(value) == batch_size
    assert num_samples == 1
Example No. 7
def test_rollout_buffer_get_state_with_data():
    # Assign
    buffer = RolloutBuffer(batch_size=5, buffer_size=20)
    sample_experience = Experience(state=[0, 1], action=[0], reward=0)
    for _ in range(25):
        buffer.add(**sample_experience.data)

    # Act
    state = buffer.get_state()

    # Assert
    assert state.type == RolloutBuffer.type
    assert state.buffer_size == 20
    assert state.batch_size == 5
    assert state.data is not None
    assert len(state.data) == 20

    for data in state.data:
        assert data == sample_experience
Example No. 8
    def from_state(state: BufferState) -> BufferBase:
        if state.type == ReplayBuffer.type:
            return ReplayBuffer.from_state(state)
        elif state.type == PERBuffer.type:
            return PERBuffer.from_state(state)
        elif state.type == NStepBuffer.type:
            return NStepBuffer.from_state(state)
        elif state.type == RolloutBuffer.type:
            return RolloutBuffer.from_state(state)
        else:
            raise ValueError(f"Buffer state contains unsupported buffer type: '{state.type}'")
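The factory simply dispatches on state.type, so persisting and restoring a buffer is a round trip through its state object. A minimal usage sketch, following the API exercised by the tests above:

buffer = RolloutBuffer(batch_size=5, buffer_size=20)
state = buffer.get_state()

# Recreate an equivalent buffer from the serialized state.
new_buffer = BufferFactory.from_state(state=state)
assert isinstance(new_buffer, RolloutBuffer)
assert new_buffer == buffer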
Example No. 9
def test_rollout_buffer_traverse_buffer_twice():
    # Assign
    batch_size = 10
    buffer_size = 30
    buffer = RolloutBuffer(batch_size=batch_size, buffer_size=buffer_size)

    # Act
    reward = -1
    for (state, action, _, next_state, done) in generate_sample_SARS(buffer_size):
        reward += 1
        buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done)

    # Assert
    num_samples = 0

    # First pass
    for idx, samples in enumerate(buffer.sample()):
        num_samples += 1
        rewards = samples['reward']
        assert rewards == list(range(idx*10, (idx+1)*10))

    # Second pass
    for idx, samples in enumerate(buffer.sample()):
        num_samples += 1
        rewards = samples['reward']
        assert rewards == list(range(idx*10, (idx+1)*10))

    assert num_samples == 6  # 2 * ceil(buffer_size / batch_size)
Example No. 10
def test_rollout_buffer_size_not_multiple_of_minibatch():
    # Assign
    batch_size = 10
    buffer_size = 55
    buffer = RolloutBuffer(batch_size=batch_size, buffer_size=buffer_size)

    # Act
    reward = -1
    for (state, action, _, next_state, done) in generate_sample_SARS(buffer_size):
        reward += 1
        buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done)

    # Assert
    num_samples = 0
    for idx, samples in enumerate(buffer.sample()):
        num_samples += 1
        rewards = samples['reward']
        if idx != 5:
            assert len(rewards) == batch_size
            assert rewards == list(range(idx*10, (idx+1)*10))
        else:
            assert len(rewards) == 5
            assert rewards == [50, 51, 52, 53, 54]
    assert num_samples == 6  # ceil(buffer_size / batch_size)
Example No. 11
def test_rollout_buffer_clear_buffer():
    # Assign
    batch_size = 10
    buffer_size = 30
    buffer = RolloutBuffer(batch_size=batch_size, buffer_size=buffer_size)

    # Act
    reward = -1
    for (state, action, _, next_state, done) in generate_sample_SARS(buffer_size):
        reward += 1
        buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done)

    # Assert
    for idx, samples in enumerate(buffer.sample()):
        rewards = samples['reward']
        assert rewards == list(range(idx*10, (idx+1)*10))

    buffer.clear()
    assert len(buffer) == 0
Example No. 13
class PPOAgent(AgentBase):
    """
    Proximal Policy Optimization (PPO) [1] is an on-policy policy-gradient method
    that can be considered an implementation-wise simplified version of
    Trust Region Policy Optimization (TRPO).

    [1] "Proximal Policy Optimization Algorithms" (2017) by J. Schulman, F. Wolski,
        P. Dhariwal, A. Radford, O. Klimov. https://arxiv.org/abs/1707.06347
    """

    name = "PPO"

    def __init__(self, state_size: int, action_size: int, **kwargs):
        """
        Parameters:
            state_size: Number of input dimensions.
            action_size: Number of output dimensions.
            hidden_layers: (default: (100, 100)) Tuple defining hidden dimensions in fully connected nets.
            is_discrete: (default: False) Whether to return discrete actions.
            using_kl_div: (default: False) Whether to add a KL-divergence penalty to the loss.
            using_gae: (default: True) Whether to use Generalized Advantage Estimation (GAE).
            gae_lambda: (default: 0.96) Value of lambda in GAE.
            actor_lr: (default: 0.0003) Learning rate for the actor (policy).
            critic_lr: (default: 0.001) Learning rate for the critic (value function).
            actor_betas: (default: (0.9, 0.999)) Adam's betas for the actor optimizer.
            critic_betas: (default: (0.9, 0.999)) Adam's betas for the critic optimizer.
            gamma: (default: 0.99) Discount value.
            ppo_ratio_clip: (default: 0.25) Policy ratio clipping value.
            num_epochs: (default: 1) Number of times to learn from the collected samples.
            rollout_length: (default: 48) Number of actions to take before an update.
            batch_size: (default: rollout_length) Number of samples used in learning.
            actor_number_updates: (default: 10) Number of times policy losses are propagated.
            critic_number_updates: (default: 10) Number of times value losses are propagated.
            entropy_weight: (default: 0.5) Weight of the entropy term in the loss.
            value_loss_weight: (default: 1.0) Weight of the value loss term in the loss.

        """
        super().__init__(**kwargs)

        self.device = self._register_param(
            kwargs, "device", DEVICE)  # Default device is CUDA if available

        self.state_size = state_size
        self.action_size = action_size
        self.hidden_layers = to_numbers_seq(
            self._register_param(kwargs, "hidden_layers", (100, 100)))
        self.iteration = 0

        self.is_discrete = bool(
            self._register_param(kwargs, "is_discrete", False))
        self.using_gae = bool(self._register_param(kwargs, "using_gae", True))
        self.gae_lambda = float(
            self._register_param(kwargs, "gae_lambda", 0.96))

        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.actor_betas: Tuple[float, float] = to_numbers_seq(
            self._register_param(kwargs, 'actor_betas', (0.9, 0.999)))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 1e-3))
        self.critic_betas: Tuple[float, float] = to_numbers_seq(
            self._register_param(kwargs, 'critic_betas', (0.9, 0.999)))
        self.gamma = float(self._register_param(kwargs, "gamma", 0.99))
        self.ppo_ratio_clip = float(
            self._register_param(kwargs, "ppo_ratio_clip", 0.25))

        self.using_kl_div = bool(
            self._register_param(kwargs, "using_kl_div", False))
        self.kl_beta = float(self._register_param(kwargs, 'kl_beta', 0.1))
        self.target_kl = float(self._register_param(kwargs, "target_kl", 0.01))
        self.kl_div = float('inf')

        self.num_workers = int(self._register_param(kwargs, "num_workers", 1))
        self.num_epochs = int(self._register_param(kwargs, "num_epochs", 1))
        self.rollout_length = int(
            self._register_param(kwargs, "rollout_length",
                                 48))  # "Much less than the episode length"
        self.batch_size = int(
            self._register_param(kwargs, "batch_size", self.rollout_length))
        self.actor_number_updates = int(
            self._register_param(kwargs, "actor_number_updates", 10))
        self.critic_number_updates = int(
            self._register_param(kwargs, "critic_number_updates", 10))
        self.entropy_weight = float(
            self._register_param(kwargs, "entropy_weight", 0.5))
        self.value_loss_weight = float(
            self._register_param(kwargs, "value_loss_weight", 1.0))

        self.local_memory_buffer = {}

        self.action_scale = float(
            self._register_param(kwargs, "action_scale", 1))
        self.action_min = float(self._register_param(kwargs, "action_min", -1))
        self.action_max = float(self._register_param(kwargs, "action_max", 1))
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 100.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 100.0))

        if kwargs.get("simple_policy", False):
            self.policy = MultivariateGaussianPolicySimple(
                self.action_size, **kwargs)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size,
                                                     device=self.device)

        self.buffer = RolloutBuffer(batch_size=self.batch_size,
                                    buffer_size=self.rollout_length)
        self.actor = ActorBody(state_size,
                               self.policy.param_dim * action_size,
                               gate_out=torch.tanh,
                               hidden_layers=self.hidden_layers,
                               device=self.device)
        self.critic = ActorBody(state_size,
                                1,
                                gate_out=None,
                                hidden_layers=self.hidden_layers,
                                device=self.device)
        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.critic_params = list(self.critic.parameters())

        self.actor_opt = optim.Adam(self.actor_params,
                                    lr=self.actor_lr,
                                    betas=self.actor_betas)
        self.critic_opt = optim.Adam(self.critic_params,
                                     lr=self.critic_lr,
                                     betas=self.critic_betas)
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')
        self._metrics: Dict[str, float] = {}

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def __eq__(self, o: object) -> bool:
        return super().__eq__(o) \
            and self._config == o._config \
            and self.buffer == o.buffer \
            and self.get_network_state() == o.get_network_state()  # TODO @dawid: Currently net isn't compared properly

    def __clear_memory(self):
        self.buffer.clear()

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.):
        actions = []
        logprobs = []
        values = []
        state = to_tensor(state).view(self.num_workers,
                                      self.state_size).float().to(self.device)
        for worker in range(self.num_workers):
            actor_est = self.actor.act(state[worker].unsqueeze(0))
            assert not torch.any(torch.isnan(actor_est))

            dist = self.policy(actor_est)
            action = dist.sample()
            value = self.critic.act(
                state[worker].unsqueeze(0))  # Shape: (1, 1)
            logprob = self.policy.log_prob(dist, action)  # Shape: (1,)
            values.append(value)
            logprobs.append(logprob)

            if self.is_discrete:  # *Technically* it's the max of Softmax but that's monotonic.
                action = int(torch.argmax(action))
            else:
                action = torch.clamp(action * self.action_scale,
                                     self.action_min, self.action_max)
                action = action.cpu().numpy().flatten().tolist()
            actions.append(action)

        self.local_memory_buffer['value'] = torch.cat(values)
        self.local_memory_buffer['logprob'] = torch.stack(logprobs)
        assert len(actions) == self.num_workers
        return actions if self.num_workers > 1 else actions[0]

    def step(self, state, action, reward, next_state, done, **kwargs):
        self.iteration += 1

        self.buffer.add(
            state=torch.tensor(state).reshape(self.num_workers,
                                              self.state_size).float(),
            action=torch.tensor(action).reshape(self.num_workers,
                                                self.action_size).float(),
            reward=torch.tensor(reward).reshape(self.num_workers, 1),
            done=torch.tensor(done).reshape(self.num_workers, 1),
            logprob=self.local_memory_buffer['logprob'].reshape(
                self.num_workers, 1),
            value=self.local_memory_buffer['value'].reshape(
                self.num_workers, 1),
        )

        if self.iteration % self.rollout_length == 0:
            self.train()
            self.__clear_memory()

    def train(self):
        """
        Main loop that initiates the training.
        """
        experiences = self.buffer.all_samples()
        rewards = to_tensor(experiences['reward']).to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        values = to_tensor(experiences['value']).to(self.device)
        logprobs = to_tensor(experiences['logprob']).to(self.device)
        assert rewards.shape == dones.shape == values.shape == logprobs.shape
        assert states.shape == (
            self.rollout_length, self.num_workers,
            self.state_size), f"Wrong states shape: {states.shape}"
        assert actions.shape == (
            self.rollout_length, self.num_workers,
            self.action_size), f"Wrong action shape: {actions.shape}"

        with torch.no_grad():
            if self.using_gae:
                next_value = self.critic.act(states[-1])
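                # compute_gae is expected to follow the standard GAE recurrence:
                #     delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
                #     A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}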
                advantages = compute_gae(rewards, dones, values, next_value,
                                         self.gamma, self.gae_lambda)
                advantages = normalize(advantages)
                returns = advantages + values
                # returns = normalize(advantages + values)
                assert advantages.shape == returns.shape == values.shape
            else:
                returns = revert_norm_returns(rewards, dones, self.gamma)
                returns = returns.float()
                advantages = normalize(returns - values)
                assert advantages.shape == returns.shape == values.shape

        for _ in range(self.num_epochs):
            idx = 0
            self.kl_div = 0
            while idx < self.rollout_length:
                _states = states[idx:idx + self.batch_size].view(
                    -1, self.state_size).detach()
                _actions = actions[idx:idx + self.batch_size].view(
                    -1, self.action_size).detach()
                _logprobs = logprobs[idx:idx + self.batch_size].view(
                    -1, 1).detach()
                _returns = returns[idx:idx + self.batch_size].view(-1,
                                                                   1).detach()
                _advantages = advantages[idx:idx + self.batch_size].view(
                    -1, 1).detach()
                idx += self.batch_size
                self.learn(
                    (_states, _actions, _logprobs, _returns, _advantages))

            self.kl_div = abs(
                self.kl_div) / (self.actor_number_updates *
                                self.rollout_length / self.batch_size)
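            # Adaptive KL penalty: scale kl_beta so that the observed KL divergence
            # tracks target_kl (the PPO paper's adaptive-KL variant uses a factor
            # of 1.5; this code uses 1.75).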
            if self.kl_div > self.target_kl * 1.75:
                self.kl_beta = min(2 * self.kl_beta, 1e2)  # Max 100
            if self.kl_div < self.target_kl / 1.75:
                self.kl_beta = max(0.5 * self.kl_beta, 1e-6)  # Min 0.000001
            self._metrics['policy/kl_beta'] = self.kl_beta

    def compute_policy_loss(self, samples):
        states, actions, old_log_probs, _, advantages = samples

        actor_est = self.actor(states)
        dist = self.policy(actor_est)

        entropy = dist.entropy()
        new_log_probs = self.policy.log_prob(dist, actions).view(-1, 1)
        assert new_log_probs.shape == old_log_probs.shape

        r_theta = (new_log_probs - old_log_probs).exp()
        r_theta_clip = torch.clamp(r_theta, 1.0 - self.ppo_ratio_clip,
                                   1.0 + self.ppo_ratio_clip)
        assert r_theta.shape == r_theta_clip.shape

        # KL = E[log(P/Q)] = sum_{P}( P * log(P/Q) ) -- \approx --> avg_{P}( log(P) - log(Q) )
        approx_kl_div = (old_log_probs - new_log_probs).mean().item()
        if self.using_kl_div:
            # Ratio threshold for updates is 1.75 (although it should be configurable)
            policy_loss = -torch.mean(
                r_theta * advantages) + self.kl_beta * approx_kl_div
        else:
            joint_theta_adv = torch.stack(
                (r_theta * advantages, r_theta_clip * advantages))
            assert joint_theta_adv.shape[0] == 2
            policy_loss = -torch.amin(joint_theta_adv, dim=0).mean()
        entropy_loss = -self.entropy_weight * entropy.mean()

        loss = policy_loss + entropy_loss
        self._metrics['policy/kl_div'] = approx_kl_div
        self._metrics['policy/policy_ratio'] = float(r_theta.mean())
        self._metrics['policy/policy_ratio_clip_mean'] = float(
            r_theta_clip.mean())
        return loss, approx_kl_div

    def compute_value_loss(self, samples):
        states, _, _, returns, _ = samples
        values = self.critic(states)
        self._metrics['value/value_mean'] = float(values.mean())
        self._metrics['value/value_std'] = float(values.std())
        return F.mse_loss(values, returns)

    def learn(self, samples):
        self._loss_actor = 0.

        for _ in range(self.actor_number_updates):
            self.actor_opt.zero_grad()
            loss_actor, kl_div = self.compute_policy_loss(samples)
            self.kl_div += kl_div
            if kl_div > 1.5 * self.target_kl:
                # Early break
                # print(f"Iter: {i:02} Early break")
                break
            loss_actor.backward()
            nn.utils.clip_grad_norm_(self.actor_params,
                                     self.max_grad_norm_actor)
            self.actor_opt.step()
            self._loss_actor = loss_actor.item()

        for _ in range(self.critic_number_updates):
            self.critic_opt.zero_grad()
            loss_critic = self.compute_value_loss(samples)
            loss_critic.backward()
            nn.utils.clip_grad_norm_(self.critic_params,
                                     self.max_grad_norm_critic)
            self.critic_opt.step()
            self._loss_critic = float(loss_critic.item())

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        for metric_name, metric_value in self._metrics.items():
            data_logger.log_value(metric_name, metric_value, step)

        policy_params = {
            str(i): v
            for i, v in enumerate(
                itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        if full_log:
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"actor/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"actor/layer_bias_{idx}",
                                                 layer.bias, step)

            for idx, layer in enumerate(self.critic.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic/layer_bias_{idx}",
                                                 layer.bias, step)

    def get_state(self) -> AgentState:
        return AgentState(model=self.name,
                          state_space=self.state_size,
                          action_space=self.action_size,
                          config=self._config,
                          buffer=copy.deepcopy(self.buffer.get_state()),
                          network=copy.deepcopy(self.get_network_state()))

    def get_network_state(self) -> NetworkState:
        return NetworkState(net=dict(
            policy=self.policy.state_dict(),
            actor=self.actor.state_dict(),
            critic=self.critic.state_dict(),
        ))

    def set_buffer(self, buffer_state: BufferState) -> None:
        self.buffer = BufferFactory.from_state(buffer_state)

    def set_network(self, network_state: NetworkState) -> None:
        self.policy.load_state_dict(network_state.net['policy'])
        self.actor.load_state_dict(network_state.net['actor'])
        self.critic.load_state_dict(network_state.net['critic'])

    @staticmethod
    def from_state(state: AgentState) -> AgentBase:
        agent = PPOAgent(state.state_space, state.action_space, **state.config)
        if state.network is not None:
            agent.set_network(state.network)
        if state.buffer is not None:
            agent.set_buffer(state.buffer)
        return agent

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        # save_state() serializes an AgentState, so restore from its fields.
        agent_state: AgentState = torch.load(path)
        self._config = agent_state.config
        self.__dict__.update(**self._config)

        if agent_state.network is not None:
            self.set_network(agent_state.network)
        if agent_state.buffer is not None:
            self.set_buffer(agent_state.buffer)
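A minimal end-to-end sketch of driving this agent (the environment, the classic Gym 4-tuple step API, and the episode handling below are illustrative assumptions, not part of the original source):

import gym  # assumed dependency for this sketch

env = gym.make("Pendulum-v1")
state_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]
agent = PPOAgent(state_size, action_size, rollout_length=48)

state = env.reset()
for _ in range(10_000):
    action = agent.act(state)
    next_state, reward, done, _ = env.step(action)
    agent.step(state, action, reward, next_state, done)
    state = env.reset() if done else next_state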