Esempio n. 1
0
def critic_loss(
    batch: Batch,
    critic: Critic,
    gamma: float = 0.99,
) -> Tensor:
    """Computes loss for critic networks.

    The loss is the mean squared error between the critic's value estimates
    for the (state, action) pairs in the batch and the discounted returns
    computed by ``utils.discount_values``.

    Parameters
    ----------
    batch: (Batch) Experience sampled for training.  Unpacked as
        ``(states, actions, _, rewards, dones, _)``.
    critic: (base.Critic) Critic network to optimize.  A ``DQNCritic`` is
        called with states only and indexed by the taken actions; any other
        critic is called with ``(states, actions)``.
    gamma: (float, optional) Discount factor.  Range: (0, 1).  Default: 0.99

    Returns
    -------
    Tensor:  Critic loss
    """
    states, actions, _, rewards, dones, _ = batch
    returns = torch.zeros_like(rewards)
    # NOTE(review): the final entry of ``returns`` is left at zero instead of
    # taking the last discounted return -- confirm this is intentional.
    returns[:-1] = utils.discount_values(rewards, dones, gamma)[:-1]

    if isinstance(critic, DQNCritic):
        # Select the Q-value of the action actually taken in each transition.
        values = critic(states)[range(len(actions)), actions.long()]
    else:
        values = critic(states, actions)

    # Reshape the target to match the predictions so the subtraction is
    # element-wise.  Broadcasting an (N,) prediction against an (N, 1) target
    # would silently yield an (N, N) matrix of pairwise differences.
    # A genuine size mismatch now raises instead.
    return (values - returns.reshape_as(values)).pow(2).mean()
Esempio n. 2
0
def actor_loss(
    batch: Batch,
    actor: Actor,
    gamma: float = 0.99,
) -> Tensor:
    """Computes loss for the actor network.

    The loss is the negative mean of the taken actions' log-probabilities,
    each weighted by its standardized discounted return (a REINFORCE-style
    policy-gradient objective).

    Parameters
    ----------
    batch: (Batch) Experience sampled for training.  Unpacked as
        ``(states, actions, rewards, dones)``.
    actor: (base.Actor) Actor (policy) network to optimize.
    gamma: (float, optional) Discount factor.  Range: (0, 1).  Default: 0.99

    Returns
    -------
    Tensor:  Actor loss
    """
    states, actions, rewards, dones = batch
    values = utils.discount_values(rewards, dones, gamma).to(rewards.device)
    # Standardize the returns.  The epsilon keeps a zero standard deviation
    # (e.g. identical returns across the batch) from producing inf/NaN.  A
    # single-element batch is only centred, because its std() is NaN.
    values = values - values.mean()
    if values.numel() > 1:
        values = values / (values.std() + 1e-8)

    _, logprobs = actor(states, actions)
    if logprobs.ndim > 1:
        # Presumably one column per discrete action: keep only the
        # log-probabilities of the actions actually taken.
        logprobs = logprobs[range(len(actions)), actions.long()]

    return -(logprobs * values).mean()
Esempio n. 3
0
def actor_loss(
    batch: Batch,
    actor: Actor,
    critic: Critic,
    clip_ratio: float = 0.2,
    gamma: float = 0.99,
    lam: float = 0.97,
) -> Tuple[Tensor, float]:
    """Computes loss for the actor network, as well as the approximate
    KL-divergence (used for early stopping of each training update).

    The loss is the clipped surrogate objective.  The probability ratio
    between the new and old policies is clamped to
    ``[1 - clip_ratio, 1 + clip_ratio]``, and the element-wise minimum of
    the clipped and unclipped ratio-weighted advantages is averaged and
    negated.  Advantages come from GAE-Lambda over the critic's values.

    Parameters
    ----------
    batch: (Batch) Experience sampled for training.  Unpacked as
        ``(states, actions, old_logprobs, rewards, dones, next_states)``.
    actor: (base.Actor) Actor (policy) network to optimize.
    critic: (base.Critic) Critic network used to estimate state(-action)
        values.  It is evaluated without gradients and is not optimized here.
    clip_ratio: (float, optional) Hyperparameter for clipping in the policy
        objective.  Scales how much the policy is allowed to change per
        training update.  Default: 0.2.
    gamma: (float, optional) Discount factor.  Range: (0, 1).  Default: 0.99
    lam: (float, optional) Hyperparameter for GAE-Lambda calculation.
        Range: (0, 1).  Default: 0.97

    Returns
    -------
    (Tensor, float):  Actor loss, approximate KL divergence (the mean of old
        minus new log-probabilities of the taken actions)
    """
    states, actions, old_logprobs, rewards, dones, next_states = batch
    with torch.no_grad():
        if isinstance(critic, DQNCritic):
            vals = critic(states)[range(len(actions)), actions.long()]
            next_act = actor(next_states)[0]
            next_vals = critic(next_states)[range(len(next_act)),
                                            next_act.long()]
        else:
            vals = critic(states, actions)
            next_act = actor(next_states)[0]
            next_vals = critic(next_states, next_act)

    # GAE-Lambda advantages.  At terminal transitions there is no next state
    # to bootstrap from, so the TD error is ``r - V(s)``.  The baseline is
    # still subtracted, exactly as for non-terminal steps.
    deltas = rewards + gamma * next_vals - vals
    deltas = torch.where(dones > 1e-6, rewards - vals, deltas)
    adv = utils.discount_values(deltas, dones, gamma * lam).to(deltas.device)
    # Standardize.  The epsilon guards against a zero std, and a single
    # sample is only centred, because its std() is NaN.
    adv = adv - adv.mean()
    if adv.numel() > 1:
        adv = adv / (adv.std() + 1e-8)

    _, logp = actor(states, actions)
    log_ratio = logp - old_logprobs
    if log_ratio.ndim > 1:
        # Presumably one column per discrete action: keep only the taken
        # actions.  Selecting here, before computing the ratio and the KL,
        # makes both statistics use the same entries.
        log_ratio = log_ratio[range(len(actions)), actions.long()]
    ratio = torch.exp(log_ratio)
    clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
    loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()
    # mean(old_logp - new_logp) over the taken actions.
    approx_kl = (-log_ratio).mean().item()

    return loss_pi, approx_kl
Esempio n. 4
0
File: a2c.py Progetto: fkodom/metis
def actor_loss(
    batch: Batch,
    actor: Actor,
    critic: Critic,
    gamma: float = 0.99,
    lam: float = 0.97,
) -> Tensor:
    """Computes loss for the actor network.

    The loss is the negative mean of the taken actions' log-probabilities,
    each weighted by its standardized GAE-Lambda advantage.  The advantages
    are computed from the critic's value estimates.

    Parameters
    ----------
    batch: (Batch) Experience sampled for training.  Unpacked as
        ``(states, actions, old_logprobs, rewards, dones, next_states)``.
        ``old_logprobs`` is not used by this loss.
    actor: (base.Actor) Actor (policy) network to optimize.
    critic: (base.Critic) Critic network used to estimate state(-action)
        values.  It is evaluated without gradients and is not optimized here.
    gamma: (float, optional) Discount factor.  Range: (0, 1).  Default: 0.99
    lam: (float, optional) Hyperparameter for GAE-Lambda calculation.
        Range: (0, 1).  Default: 0.97

    Returns
    -------
    Tensor:  Actor loss
    """
    states, actions, old_logprobs, rewards, dones, next_states = batch
    with torch.no_grad():
        if isinstance(critic, DQNCritic):
            values = critic(states)[range(len(actions)), actions.long()]
            next_act = actor(next_states)[0]
            next_values = critic(next_states)[
                range(len(next_act)), next_act.long()
            ]
        else:
            values = critic(states, actions)
            next_act = actor(next_states)[0]
            next_values = critic(next_states, next_act)

    # GAE-Lambda advantages.  At terminal transitions there is no next state
    # to bootstrap from, so the TD error is ``r - V(s)``.  The baseline is
    # still subtracted, exactly as for non-terminal steps.
    deltas = rewards + gamma * next_values - values
    deltas = torch.where(dones > 1e-6, rewards - values, deltas)
    advantages = utils.discount_values(deltas, dones, gamma * lam).to(deltas.device)
    # Standardize.  The epsilon guards against a zero std, and a single
    # sample is only centred, because its std() is NaN.
    advantages = advantages - advantages.mean()
    if advantages.numel() > 1:
        advantages = advantages / (advantages.std() + 1e-8)

    _, logprobs = actor(states, actions)
    if logprobs.ndim > 1:
        # Presumably one column per discrete action: keep only the
        # log-probabilities of the actions actually taken.
        logprobs = logprobs[range(len(actions)), actions.long()]

    return -(logprobs * advantages).mean()