Example #1
def categorical_q_learning(
    q_atoms_tm1: Array,
    q_logits_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_atoms_t: Array,
    q_logits_t: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Implements Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemere, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Categorical Q-learning loss (i.e. temporal difference error).
  """
    chex.assert_rank([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t
    ], [1, 2, 0, 0, 0, 1, 2])
    chex.assert_type([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t
    ], [float, float, int, float, float, float, float])

    # Scale and shift time-t distribution atoms by discount and reward.
    target_z = r_t + discount_t * q_atoms_t

    # Convert logits to distribution, then find greedy action in state s_t.
    q_t_probs = jax.nn.softmax(q_logits_t)
    q_t_mean = jnp.sum(q_t_probs * q_atoms_t[jnp.newaxis, :], axis=1)
    pi_t = jnp.argmax(q_t_mean)

    # Compute distribution for greedy action.
    p_target_z = q_t_probs[pi_t]

    # Project using the Cramer distance and maybe stop gradient flow to targets.
    target = categorical_l2_project(target_z, p_target_z, q_atoms_tm1)
    target = jax.lax.select(stop_target_gradients,
                            jax.lax.stop_gradient(target), target)

    # Compute loss (i.e. temporal difference error).
    logit_qa_tm1 = q_logits_tm1[a_tm1]
    return distributions.categorical_cross_entropy(labels=target,
                                                   logits=logit_qa_tm1)
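
A minimal per-transition usage sketch for the function above, assuming the surrounding module context (so chex, categorical_l2_project and distributions resolve); the atom grid, shapes and values are purely illustrative:

import jax.numpy as jnp

num_actions, num_atoms = 4, 51
atoms = jnp.linspace(-10., 10., num_atoms)          # shared support of the return distribution
logits_tm1 = jnp.zeros((num_actions, num_atoms))    # per-action logits at time t-1
logits_t = jnp.zeros((num_actions, num_atoms))      # per-action logits at time t

loss = categorical_q_learning(
    q_atoms_tm1=atoms,
    q_logits_tm1=logits_tm1,
    a_tm1=jnp.asarray(2, dtype=jnp.int32),          # action taken at time t-1
    r_t=jnp.asarray(1.0),                           # observed reward
    discount_t=jnp.asarray(0.99),
    q_atoms_t=atoms,
    q_logits_t=logits_t)                            # scalar cross-entropy loss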
Example #2
def categorical_double_q_learning(
    q_atoms_tm1: Array,
    q_logits_tm1: Array,
    a_tm1: Numeric,
    r_t: Numeric,
    discount_t: Numeric,
    q_atoms_t: Array,
    q_logits_t: Array,
    q_t_selector: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Implements double Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemere, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf)
  and "Double Q-learning" by van Hasselt.
  (https://papers.nips.cc/paper/3964-double-q-learning.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.
    q_t_selector: selector Q-values at time t.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Categorical double Q-learning loss (i.e. temporal difference error).
  """
    chex.assert_rank([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t, q_t_selector
    ], [1, 2, 0, 0, 0, 1, 2, 1])
    chex.assert_type([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t, q_t_selector
    ], [float, float, int, float, float, float, float, float])

    # Scale and shift time-t distribution atoms by discount and reward.
    target_z = r_t + discount_t * q_atoms_t

    # Select logits for greedy action in state s_t and convert to distribution.
    p_target_z = jax.nn.softmax(q_logits_t[q_t_selector.argmax()])

    # Project using the Cramer distance and maybe stop gradient flow to targets.
    target = categorical_l2_project(target_z, p_target_z, q_atoms_tm1)
    target = jax.lax.select(stop_target_gradients,
                            jax.lax.stop_gradient(target), target)

    # Compute loss (i.e. temporal difference error).
    logit_qa_tm1 = q_logits_tm1[a_tm1]
    return distributions.categorical_cross_entropy(labels=target,
                                                   logits=logit_qa_tm1)
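
A sketch of how the double Q-learning variant might be called, under the same assumptions; here the selector values are taken, as one plausible choice, to be the online network's expected Q-values at time t, while q_logits_t comes from the target network:

import jax
import jax.numpy as jnp

num_actions, num_atoms = 4, 51
atoms = jnp.linspace(-10., 10., num_atoms)
logits_tm1 = jnp.zeros((num_actions, num_atoms))        # online-network logits at time t-1
logits_t = jnp.zeros((num_actions, num_atoms))          # target-network logits at time t
online_logits_t = jnp.zeros((num_actions, num_atoms))   # online-network logits at time t

# Selector Q-values: expected values under the online network at time t.
q_t_selector = jnp.sum(jax.nn.softmax(online_logits_t) * atoms, axis=-1)

loss = categorical_double_q_learning(
    q_atoms_tm1=atoms,
    q_logits_tm1=logits_tm1,
    a_tm1=jnp.asarray(0, dtype=jnp.int32),
    r_t=jnp.asarray(0.5),
    discount_t=jnp.asarray(0.99),
    q_atoms_t=atoms,
    q_logits_t=logits_t,
    q_t_selector=q_t_selector)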
Example #3
def categorical_q_learning(
    q_atoms_tm1: ArrayLike,
    q_logits_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    q_atoms_t: ArrayLike,
    q_logits_t: ArrayLike,
) -> ArrayOrScalar:
    """Implements Q-learning for categorical Q distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemere, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    q_atoms_tm1: atoms of Q distribution at time t-1.
    q_logits_tm1: logits of Q distribution at time t-1.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    q_atoms_t: atoms of Q distribution at time t.
    q_logits_t: logits of Q distribution at time t.

  Returns:
    Categorical Q-learning loss (i.e. temporal difference error).
  """
    base.rank_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t
    ], [1, 2, 0, 0, 0, 1, 2])
    base.type_assert([
        q_atoms_tm1, q_logits_tm1, a_tm1, r_t, discount_t, q_atoms_t,
        q_logits_t
    ], [float, float, int, float, float, float, float])

    # Scale and shift time-t distribution atoms by discount and reward.
    target_z = r_t + discount_t * q_atoms_t

    # Convert logits to distribution, then find greedy action in state s_t.
    q_t_probs = jax.nn.softmax(q_logits_t)
    q_t_mean = jnp.sum(q_t_probs * q_atoms_t[jnp.newaxis, :], axis=1)
    pi_t = jnp.argmax(q_t_mean)

    # Compute distribution for greedy action.
    p_target_z = q_t_probs[pi_t]

    # Project using the Cramer distance.
    target = jax.lax.stop_gradient(
        _categorical_l2_project(target_z, p_target_z, q_atoms_tm1))

    # Compute loss (i.e. temporal difference error).
    logit_qa_tm1 = q_logits_tm1[a_tm1]
    return distributions.categorical_cross_entropy(labels=target,
                                                   logits=logit_qa_tm1)
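
These per-transition losses are usually vectorised over a batch with jax.vmap; a sketch under the same assumptions about the module-level helpers (base, _categorical_l2_project, distributions), with the atom grid shared across the batch:

import jax
import jax.numpy as jnp

batch_size, num_actions, num_atoms = 32, 4, 51
atoms = jnp.linspace(-10., 10., num_atoms)

batched_loss = jax.vmap(
    categorical_q_learning,
    in_axes=(None, 0, 0, 0, 0, None, 0))                # atoms are not batched

losses = batched_loss(
    atoms,                                              # q_atoms_tm1
    jnp.zeros((batch_size, num_actions, num_atoms)),    # q_logits_tm1
    jnp.zeros((batch_size,), dtype=jnp.int32),          # a_tm1
    jnp.zeros((batch_size,)),                           # r_t
    jnp.full((batch_size,), 0.99),                      # discount_t
    atoms,                                              # q_atoms_t
    jnp.zeros((batch_size, num_actions, num_atoms)))    # q_logits_t
# losses has shape (batch_size,).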
Example #4
def categorical_td_learning(
    v_atoms_tm1: Array,
    v_logits_tm1: Array,
    r_t: Numeric,
    discount_t: Numeric,
    v_atoms_t: Array,
    v_logits_t: Array,
    stop_target_gradients: bool = True,
) -> Numeric:
    """Implements TD-learning for categorical value distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemere, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    v_atoms_tm1: atoms of V distribution at time t-1.
    v_logits_tm1: logits of V distribution at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_atoms_t: atoms of V distribution at time t.
    v_logits_t: logits of V distribution at time t.
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Categorical Q learning loss (i.e. temporal difference error).
  """
    chex.assert_rank(
        [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
        [1, 1, 0, 0, 1, 1])
    chex.assert_type(
        [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
        [float, float, float, float, float, float])

    # Scale and shift time-t distribution atoms by discount and reward.
    target_z = r_t + discount_t * v_atoms_t

    # Convert logits to distribution.
    v_t_probs = jax.nn.softmax(v_logits_t)

    # Project using the Cramer distance and maybe stop gradient flow to targets.
    target = categorical_l2_project(target_z, v_t_probs, v_atoms_tm1)
    target = jax.lax.select(stop_target_gradients,
                            jax.lax.stop_gradient(target), target)

    # Compute loss (i.e. temporal difference error).
    return distributions.categorical_cross_entropy(labels=target,
                                                   logits=v_logits_tm1)
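
A minimal usage sketch for the state-value (TD) variant, again with illustrative shapes and assuming the same module context:

import jax.numpy as jnp

num_atoms = 51
atoms = jnp.linspace(-10., 10., num_atoms)          # support of the value distribution

loss = categorical_td_learning(
    v_atoms_tm1=atoms,
    v_logits_tm1=jnp.zeros((num_atoms,)),           # V-distribution logits at time t-1
    r_t=jnp.asarray(1.0),
    discount_t=jnp.asarray(0.99),
    v_atoms_t=atoms,
    v_logits_t=jnp.zeros((num_atoms,)))             # V-distribution logits at time t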
Example #5
def categorical_td_learning(
    v_atoms_tm1: ArrayLike,
    v_logits_tm1: ArrayLike,
    r_t: ArrayLike,
    discount_t: ArrayLike,
    v_atoms_t: ArrayLike,
    v_logits_t: ArrayLike
) -> ArrayLike:
  """Implements TD-learning for categorical value distributions.

  See "A Distributional Perspective on Reinforcement Learning", by
    Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

  Args:
    v_atoms_tm1: atoms of V distribution at time t-1.
    v_logits_tm1: logits of V distribution at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    v_atoms_t: atoms of V distribution at time t.
    v_logits_t: logits of V distribution at time t.

  Returns:
    Categorical TD-learning loss (i.e. temporal difference error).
  """
  base.rank_assert(
      [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
      [1, 1, 0, 0, 1, 1])
  base.type_assert(
      [v_atoms_tm1, v_logits_tm1, r_t, discount_t, v_atoms_t, v_logits_t],
      [float, float, float, float, float, float])

  # Scale and shift time-t distribution atoms by discount and reward.
  target_z = r_t + discount_t * v_atoms_t

  # Convert logits to distribution.
  v_t_probs = jax.nn.softmax(v_logits_t)

  # Project using the Cramer distance.
  target = jax.lax.stop_gradient(
      _categorical_l2_project(target_z, v_t_probs, v_atoms_tm1))

  # Compute loss (i.e. temporal difference error).
  return distributions.categorical_cross_entropy(
      labels=target, logits=v_logits_tm1)
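
As with the Q-learning variants, this per-transition TD loss is typically mapped over a batch with jax.vmap; a sketch under the same assumptions, sharing the atom grid across the batch:

import jax
import jax.numpy as jnp

batch_size, num_atoms = 32, 51
atoms = jnp.linspace(-10., 10., num_atoms)

batched_loss = jax.vmap(
    categorical_td_learning,
    in_axes=(None, 0, 0, 0, None, 0))       # atoms are not batched

losses = batched_loss(
    atoms,                                  # v_atoms_tm1
    jnp.zeros((batch_size, num_atoms)),     # v_logits_tm1
    jnp.ones((batch_size,)),                # r_t
    jnp.full((batch_size,), 0.99),          # discount_t
    atoms,                                  # v_atoms_t
    jnp.zeros((batch_size, num_atoms)))     # v_logits_t
# losses has shape (batch_size,).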