Example #1
    def objective_func(self, params, state, hyperparams, rng, transition_batch,
                       Adv):
        rngs = hk.PRNGSequence(rng)

        # get distribution params from function approximator
        S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
        dist_params, state_new = self.pi.function(params, state, next(rngs), S,
                                                  True)

        # compute objective: q(s, a_greedy)
        S = self.q_targ.observation_preprocessor(next(rngs),
                                                 transition_batch.S)
        A = self.pi.proba_dist.mode(dist_params)
        log_pi = self.pi.proba_dist.log_proba(dist_params, A)
        params_q = hyperparams['q']['params']
        state_q = hyperparams['q']['function_state']
        Q, _ = self.q_targ.function_type1(params_q, state_q, next(rngs), S, A,
                                          True)

        # clip importance weights to reduce variance
        W = jnp.clip(transition_batch.W, 0.1, 10.)

        # the objective
        chex.assert_equal_shape([W, Q])
        chex.assert_rank([W, Q], 1)
        objective = W * Q

        return jnp.mean(objective), (dist_params, log_pi, state_new)
Example #2
def test_hist_params_transform_shape():

    X = rng.randn(100)

    X_u, _ = get_hist_params(X, support_extension=10, precision=50, alpha=1e-5)

    chex.assert_equal_shape([X_u, X])
Example #3
File: loss.py  Project: ksachdeva/optax
def cosine_distance(
    predictions: chex.Array,
    targets: chex.Array,
    epsilon: float = 0.,
) -> chex.Array:
    r"""Computes the cosine distance between targets and predictions.

  The cosine **distance**, implemented here, measures the **dissimilarity**
  of two vectors as the opposite of cosine **similarity**: `1 - cos(\theta)`.

  References:
    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

  Args:
    predictions: The predicted vector.
    targets: Ground truth target vector.
    epsilon: minimum norm for terms in the denominator of the cosine similarity.

  Returns:
    cosine distance values.
  """
    chex.assert_equal_shape([targets, predictions])
    chex.assert_type([targets, predictions], float)
    # cosine distance = 1 - cosine similarity.
    return 1. - cosine_similarity(predictions, targets, epsilon)
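As a quick sanity check of the convention in the docstring, here is a minimal, self-contained sketch (plain jax.numpy, made-up vectors) that computes 1 - cos(θ) directly:

import jax.numpy as jnp

a = jnp.array([[1., 0., 0.], [0., 3., 4.]])
b = jnp.array([[1., 0., 0.], [0., 0., 1.]])

# cosine similarity along the last axis, then distance = 1 - similarity
sim = jnp.sum(a * b, axis=-1) / (
    jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1))
dist = 1. - sim  # -> [0., 0.2]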
Example #4
def test_mixturegausscdf_bijector_shape(n_samples, n_features, n_components):

    params_rng, data_rng = jax.random.split(KEY, 2)

    x = jax.random.normal(data_rng, shape=(n_samples, n_features))

    # create layer
    init_func = MixtureGaussianCDF(n_components=n_components)

    # initialize layer parameters and transform functions
    params, forward_f, inverse_f = init_func(rng=params_rng,
                                             n_features=n_features)

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # checks
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples, ))

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x, x_approx])
Example #5
def test_composite_shape(n_samples, n_features):

    params_rng, data_rng = jax.random.split(KEY, 2)

    x = jax.random.normal(data_rng, shape=(n_samples, n_features))
    # create layer
    init_func = CompositeTransform(
        [MixtureLogisticCDF(n_components=5), Logit(), HouseHolder(n_reflections=2),]
    )

    # initialize layer parameters and transform functions
    params, forward_f, inverse_f = init_func(rng=params_rng, n_features=n_features)

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # checks
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples,))

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x_approx, x])
Example #6
    def sample(self, batch_size=32):
        r"""

        Get a batch of transitions to be used for bootstrapped updates.

        Parameters
        ----------
        batch_size : positive int, optional

            The desired batch size of the sample.

        Returns
        -------
        transitions : TransitionBatch

            A :class:`TransitionBatch <coax.reward_tracing.TransitionBatch>` object.

        """
        idx = self._sumtree.sample(n=batch_size)
        P = self._sumtree.values[idx] / self._sumtree.root_value  # prioritized, biased propensities
        W = onp.power(P * len(self), -self.beta)  # inverse propensity weights (β≈1)
        W /= W.max()  # for stability, ensure only down-weighting (see sec. 3.4 of arxiv:1511.05952)
        transition_batch = _concatenate_leaves(self._storage[idx])
        chex.assert_equal_shape([transition_batch.W, W])
        transition_batch.W *= W
        return transition_batch
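The importance-weight arithmetic in `sample` can be illustrated in isolation; the numbers below are made up and `beta` is just a plausible value, not the buffer's actual setting:

import numpy as onp

beta = 0.6                       # illustrative exponent; the real value lives on the buffer
P = onp.array([0.5, 0.3, 0.2])   # biased sampling propensities read from the sum-tree
N = 100                          # number of transitions currently stored

W = onp.power(P * N, -beta)      # inverse-propensity weights
W /= W.max()                     # normalize so weights only down-weight (arxiv:1511.05952, sec. 3.4)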
Example #7
    def add(self, transition_batch, Adv):
        r"""

        Add a transition to the experience replay buffer.

        Parameters
        ----------
        transition_batch : TransitionBatch

            A :class:`TransitionBatch <coax.reward_tracing.TransitionBatch>` object.

        Adv : ndarray

            A batch of advantages, used to construct the priorities :math:`p_i`.

        """
        if not isinstance(transition_batch, TransitionBatch):
            raise TypeError(
                f"transition_batch must be a TransitionBatch, got: {type(transition_batch)}"
            )

        transition_batch.idx = self._index + onp.arange(
            transition_batch.batch_size)
        idx = transition_batch.idx % self.capacity  # wrap around
        chex.assert_equal_shape([idx, Adv])
        self._storage[idx] = list(transition_batch.to_singles())
        self._sumtree.set_values(
            idx, onp.power(onp.abs(Adv) + self.epsilon, self.alpha))
        self._index += transition_batch.batch_size
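The priority computation itself is a one-liner; a standalone sketch with made-up advantages and hyperparameters:

import numpy as onp

alpha, epsilon = 0.6, 1e-4           # illustrative values, not the buffer's defaults
Adv = onp.array([2.0, -0.5, 0.0])

priorities = onp.power(onp.abs(Adv) + epsilon, alpha)
# larger |advantage| -> larger priority; epsilon keeps zero-advantage transitions sampleable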
Example #8
def kl_divergence_with_logits(p_logits=None, q_logits=None, temperature=1.):
    """Compute the KL between two categorical distributions from their logits.

  Args:
    p_logits: [..., dim] array with logits for the first distribution.
    q_logits: [..., dim] array with logits for the second distribution.
    temperature: the temperature for the softmax distribution, defaults to 1.

  Returns:
    an array of KL divergence terms taken over the last axis.
  """
    chex.assert_type([p_logits, q_logits], float)
    chex.assert_equal_shape([p_logits, q_logits])

    p_logits /= temperature
    q_logits /= temperature

    p = jax.nn.softmax(p_logits)

    log_p = jax.nn.log_softmax(p_logits)
    log_q = jax.nn.log_softmax(q_logits)
    kl = jnp.sum(p * (log_p - log_q), axis=-1)

    # KL divergence should be non-negative; clipping at zero helps with numerical stability
    loss = jax.nn.relu(kl)

    return loss
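A minimal check of the formula above, assuming temperature 1 and using only jax.nn primitives with made-up logits:

import jax
import jax.numpy as jnp

p_logits = jnp.array([2.0, 0.5, -1.0])
q_logits = jnp.array([0.1, 0.2, 0.3])

p = jax.nn.softmax(p_logits)
kl = jnp.sum(p * (jax.nn.log_softmax(p_logits) - jax.nn.log_softmax(q_logits)), axis=-1)
# kl >= 0 up to floating-point error, which is why the function above clamps with relu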
Example #9
    def objective_func(self, params, state, hyperparams, rng, transition_batch,
                       Adv):
        rngs = hk.PRNGSequence(rng)

        # get distribution params from function approximator
        S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
        dist_params, state_new = self.pi.function(params, state, next(rngs), S,
                                                  True)

        # compute probability ratios
        A = self.pi.proba_dist.preprocess_variate(next(rngs),
                                                  transition_batch.A)
        log_pi = self.pi.proba_dist.log_proba(dist_params, A)
        ratio = jnp.exp(log_pi - transition_batch.logP)  # π_new / π_old
        ratio_clip = jnp.clip(ratio, 1 - hyperparams['epsilon'],
                              1 + hyperparams['epsilon'])

        # clip importance weights to reduce variance
        W = jnp.clip(transition_batch.W, 0.1, 10.)

        # ppo-clip objective
        chex.assert_equal_shape([W, Adv, ratio, ratio_clip])
        chex.assert_rank([W, Adv, ratio, ratio_clip], 1)
        objective = W * jnp.minimum(Adv * ratio, Adv * ratio_clip)

        # also pass auxiliary data to avoid multiple forward passes
        return jnp.mean(objective), (dist_params, log_pi, state_new)
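The effect of the clipped surrogate is easiest to see on a toy batch; a sketch with made-up numbers (epsilon = 0.2 is illustrative):

import jax.numpy as jnp

epsilon = 0.2
Adv = jnp.array([1.0, -1.0])
ratio = jnp.array([1.5, 0.5])                     # π_new / π_old
ratio_clip = jnp.clip(ratio, 1 - epsilon, 1 + epsilon)

objective = jnp.minimum(Adv * ratio, Adv * ratio_clip)
# entry 0: positive advantage, gain capped at 1.2
# entry 1: negative advantage, the pessimistic clipped term (-0.8) is kept instead of -0.5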
Example #10
def test_logit_shape(n_samples, n_features):

    params_rng, data_rng = jax.random.split(KEY, 2)

    x = jax.random.uniform(data_rng, shape=(n_samples, n_features))

    # create layer
    init_func = Logit(eps=1e-5, temperature=1)

    # initialize layer parameters and transform functions
    params, forward_f, inverse_f = init_func(
        rng=params_rng,
        n_features=n_features,
    )

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # checks
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples, ))

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x_approx, x])
Example #11
def test_conv1x1ortho_shape(n_channels, hw, n_samples, n_reflections):

    params_rng, data_rng = jax.random.split(KEY, 2)

    x = jax.random.normal(data_rng, shape=(n_samples, hw[0], hw[1], n_channels))

    # create layer
    init_func = Conv1x1Householder(n_channels=n_channels, n_reflections=n_reflections)

    # initialize layer parameters and transform functions
    params, forward_f, inverse_f = init_func(rng=params_rng, n_features=n_channels)

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # print(z.shape, log_abs_det.shape)

    # checks
    chex.assert_equal_shape([z, x])
    chex.assert_shape(np.atleast_1d(log_abs_det), (n_samples,))

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x_approx, x])
    chex.assert_shape(np.atleast_1d(log_abs_det), (n_samples,))
Example #12
def n_step_bootstrapped_returns(r_t: Array, discount_t: Array, v_t: Array,
                                n: int) -> Array:
    """Computes strided n-step bootstrapped return targets over a sequence.

  The returns are computed in a backwards fashion according to the equation:

     Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (... * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ ))),

  Args:
    r_t: rewards at times [1, ..., T].
    discount_t: discounts at times [1, ..., T].
    v_t: state or state-action values to bootstrap from at time [1, ...., T]
    n: number of steps over which to accumulate reward before bootstrapping.

  Returns:
    estimated bootstrapped returns at times [1, ...., T]
  """
    chex.assert_rank([r_t, discount_t, v_t], 1)
    chex.assert_type([r_t, discount_t, v_t], float)
    chex.assert_equal_shape([r_t, discount_t, v_t])
    seq_len = r_t.shape[0]

    # Pad end of reward and discount sequences with 0 and 1 respectively.
    r_t = jnp.concatenate([r_t, jnp.zeros(n - 1)])
    discount_t = jnp.concatenate([discount_t, jnp.ones(n - 1)])

    # Shift bootstrap values by n and pad end of sequence with last value v_t[-1].
    pad_size = min(n - 1, seq_len)
    targets = jnp.concatenate([v_t[n - 1:], jnp.array([v_t[-1]] * pad_size)])

    # Work backwards to compute discounted, bootstrapped n-step returns.
    for i in reversed(range(n)):
        targets = r_t[i:i + seq_len] + discount_t[i:i + seq_len] * targets
    return targets
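For a concrete feel of the recursion, here is a tiny hand-worked two-step return with made-up numbers:

import numpy as np

r_t = np.array([1., 2., 3.])            # rewards at times 1..3
discount_t = np.array([0.9, 0.9, 0.9])
v_t = np.array([10., 20., 30.])         # bootstrap values at times 1..3

# two-step return from time 0: G_0 = r_1 + γ_1 * (r_2 + γ_2 * v_2)
g_0 = r_t[0] + discount_t[0] * (r_t[1] + discount_t[1] * v_t[1])  # 19.0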
Example #13
def test_reshape_shape_forward(filters):

    n_dims = (1, 28, 28, 1)
    new_shape = _get_new_shapes(28, 28, 1, filters)
    params_rng, data_rng = jax.random.split(KEY, 2)

    x = jax.random.uniform(data_rng, shape=n_dims)

    # create layer
    init_func = Squeeze(filter_shape=filters, collapse=None, return_outputs=True)

    # initialize layer parameters and transform functions
    z_, params, forward_f, inverse_f = init_func(rng=params_rng, shape=n_dims, inputs=x)

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # checks
    chex.assert_tree_all_close(z, z_)
    chex.assert_equal_shape([z, log_abs_det, z_])
    chex.assert_rank(z, 4)
    chex.assert_equal(z.shape[1:], new_shape)

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x_approx, x])
    chex.assert_tree_all_close(x_approx, x)
Example #14
def test_reshape_shape_collapse(filters, collapse):

    n_dims = (1, 28, 28, 1)
    params_rng, data_rng = jax.random.split(KEY, 2)

    x = jax.random.uniform(data_rng, shape=n_dims)

    # create layer
    init_func = Squeeze(filter_shape=filters, collapse=collapse)

    # initialize layer parameters and transform functions
    params, forward_f, inverse_f = init_func(rng=params_rng, shape=n_dims,)

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # checks
    chex.assert_equal_shape([z, log_abs_det])
    chex.assert_rank(z, 2)

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x_approx, x])
    chex.assert_tree_all_close(x_approx, x)
Example #15
    def update(self, idx, Adv):
        r"""

        Update the priority weights of transitions previously added to the buffer.

        Parameters
        ----------
        idx : 1d array of ints

            The identifiers of the transitions to be updated.

        Adv : ndarray

            The corresponding updated advantages.

        """
        idx = onp.asarray(idx, dtype='int32')
        Adv = onp.asarray(Adv, dtype='float32')
        chex.assert_equal_shape([idx, Adv])
        chex.assert_rank([idx, Adv], 1)

        idx_lookup = idx % self.capacity  # wrap around
        new_values = onp.where(
            _get_transition_batch_idx(
                self._storage[idx_lookup]) == idx,  # only update if ids match
            onp.power(onp.abs(Adv) + self.epsilon, self.alpha),
            self._sumtree.values[idx_lookup])
        self._sumtree.set_values(idx_lookup, new_values)
Example #16
    def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
        base = base_dist(mu, sigma)
        bijector = inverse.Inverse(tfb.Scale(2))
        dist = transformed.Transformed(base, bijector)

        def sample_and_log_prob_fn(seed, sample_shape):
            return dist.sample_and_log_prob(seed=seed,
                                            sample_shape=sample_shape)

        samples, log_prob = self.variant(sample_and_log_prob_fn,
                                         ignore_argnums=(1, ),
                                         static_argnums=(1, ))(self.seed,
                                                               sample_shape)
        expected_samples = bijector.forward(
            base.sample(seed=self.seed, sample_shape=sample_shape))

        tfp_bijector = tfb.Invert(tfb.Scale(2))
        tfp_dist = tfd.TransformedDistribution(conversion.to_tfp(base),
                                               tfp_bijector)
        tfp_samples = tfp_dist.sample(seed=self.seed,
                                      sample_shape=sample_shape)
        tfp_log_prob = tfp_dist.log_prob(samples)

        chex.assert_equal_shape([samples, tfp_samples])
        np.testing.assert_allclose(log_prob, tfp_log_prob, rtol=RTOL)
        np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
Example #17
  def _test_sample_and_log_prob(
      self,
      dist_args: Tuple[Any, ...] = (),
      dist_kwargs: Optional[Dict[str, Any]] = None,
      tfp_dist_args: Optional[Tuple[Any, ...]] = None,
      tfp_dist_kwargs: Optional[Dict[str, Any]] = None,
      sample_shape: Union[int, Tuple[int, ...]] = (),
      assertion_fn: Callable[[Any, Any], None] = np.testing.assert_allclose):
    """Tests sample and log prob."""
    if tfp_dist_args is None:
      tfp_dist_args = dist_args
    if tfp_dist_kwargs is None:
      tfp_dist_kwargs = dist_kwargs
    dist = self.distrax_cls(*dist_args, **dist_kwargs)
    sample_and_log_prob_fn = (
        lambda k: dist.sample_and_log_prob(seed=k, sample_shape=sample_shape))
    log_prob_fn = dist.log_prob
    if hasattr(self, 'variant'):
      sample_and_log_prob_fn = self.variant(sample_and_log_prob_fn)
      log_prob_fn = self.variant(dist.log_prob)
    samples, log_prob = sample_and_log_prob_fn(self.key)

    tfp_dist = self.tfp_cls(*tfp_dist_args, **tfp_dist_kwargs)
    tfp_samples = tfp_dist.sample(sample_shape=sample_shape,
                                  seed=self.key)
    tfp_log_prob = tfp_dist.log_prob(samples)

    chex.assert_equal_shape([samples, tfp_samples])
    assertion_fn(log_prob, tfp_log_prob)
    assertion_fn(log_prob, log_prob_fn(samples))
Example #18
File: loss.py  Project: ksachdeva/optax
def softmax_cross_entropy(
    logits: chex.Array,
    labels: chex.Array,
) -> chex.Array:
    """Computes the softmax cross entropy between sets of logits and labels.

  Measures the probability error in discrete classification tasks in which
  the classes are mutually exclusive (each entry is in exactly one class).
  For example, each CIFAR-10 image is labeled with one and only one label:
  an image can be a dog or a truck, but not both.

  References:
    [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

  Args:
    logits: unnormalized log probabilities.
    labels: a valid probability distribution (non-negative, sum to 1), e.g a
      one hot encoding of which class is the correct one for each input.

  Returns:
    the cross entropy loss.
  """
    chex.assert_equal_shape([logits, labels])
    chex.assert_type([logits, labels], float)
    return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
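A short usage sketch with one-hot labels (names and numbers are illustrative):

import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.5, 2.5, 0.0]])
labels = jax.nn.one_hot(jnp.array([0, 1]), num_classes=3)  # rows are valid distributions

loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
# loss[i] is -log of the predicted probability of example i's true class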
Example #19
File: loss.py  Project: ksachdeva/optax
def cosine_similarity(
    predictions: chex.Array,
    targets: chex.Array,
    epsilon: float = 0.,
) -> chex.Array:
    r"""Computes the cosine similarity between targets and predictions.

  The cosine **similarity** is a measure of similarity between vectors defined
  as the cosine of the angle between them, which is also the inner product of
  those vectors normalized to have unit norm.

  References:
    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

  Args:
    predictions: The predicted vector.
    targets: Ground truth target vector.
    epsilon: minimum norm for terms in the denominator of the cosine similarity.

  Returns:
    cosine similarity values.
  """
    chex.assert_equal_shape([targets, predictions])
    chex.assert_type([targets, predictions], float)
    # vectorize norm fn, to treat all dimensions except the last as batch dims.
    batched_norm_fn = jnp.vectorize(utils.safe_norm,
                                    signature='(k)->()',
                                    excluded={1})
    # normalise the last dimension of targets and predictions.
    unit_targets = targets / jnp.expand_dims(batched_norm_fn(targets, epsilon),
                                             axis=-1)
    unit_predictions = predictions / jnp.expand_dims(
        batched_norm_fn(predictions, epsilon), axis=-1)
    # return cosine similarity.
    return jnp.sum(unit_targets * unit_predictions, axis=-1)
Example #20
def general_off_policy_returns_from_action_values(
    q_t: Array,
    a_t: Array,
    r_t: Array,
    discount_t: Array,
    c_t: Array,
    pi_t: Array,
    stop_target_gradients: bool = False,
) -> Array:
    """Calculates targets for various off-policy correction algorithms.

  Given a window of experience of length `K`, generated by a behaviour policy μ,
  for each time-step `t` we can estimate the return `G_t` from that step
  onwards, under some target policy π, using the rewards in the trajectory, the
  actions selected by μ and the action-values under π, according to equation:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (E[q(aₜ₊₁)] - cₜ * q(aₜ₊₁) + cₜ * Gₜ₊₁),

  where, depending on the choice of `c_t`, the algorithm implements:

    Importance Sampling             c_t = π(x_t, a_t) / μ(x_t, a_t),
    Harutyunyan's et al. Q(lambda)  c_t = λ,
    Precup's et al. Tree-Backup     c_t = π(x_t, a_t),
    Munos' et al. Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_t: Q-values at times [1, ..., K - 1].
    a_t: action index at times [1, ..., K - 1].
    r_t: reward at times [1, ..., K - 1].
    discount_t: discount at times [1, ..., K - 1].
    c_t: importance weights at times [1, ..., K - 1].
    pi_t: target policy probs at times [1, ..., K - 1].
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Off-policy estimates of the generalized returns from states visited at times
    [0, ..., K - 1].
  """
    chex.assert_rank([q_t, a_t, r_t, discount_t, c_t, pi_t],
                     [2, 1, 1, 1, 1, 2])
    chex.assert_type([q_t, a_t, r_t, discount_t, c_t, pi_t],
                     [float, int, float, float, float, float])
    chex.assert_equal_shape(
        [q_t[..., 0], a_t, r_t, discount_t, c_t, pi_t[..., 0]])

    # Get the expected values and the values of actually selected actions.
    exp_q_t = (pi_t * q_t).sum(axis=-1)
    # The generalized returns are independent of Q-values and cs at the final
    # state.
    q_a_t = base.batched_index(q_t, a_t)[:-1]
    c_t = c_t[:-1]

    return general_off_policy_returns_from_q_and_v(q_a_t, exp_q_t, r_t,
                                                   discount_t, c_t,
                                                   stop_target_gradients)
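The table of `c_t` choices in the docstring translates directly into code; a sketch with made-up behaviour/target probabilities (λ = 0.95 is illustrative):

import jax.numpy as jnp

lambda_ = 0.95
pi_a = jnp.array([0.4, 0.7, 0.1])   # π(x_t, a_t) for the actions taken by μ
mu_a = jnp.array([0.5, 0.2, 0.3])   # μ(x_t, a_t)

c_importance_sampling = pi_a / mu_a
c_tree_backup = pi_a
c_retrace = lambda_ * jnp.minimum(1., pi_a / mu_a)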
Example #21
def n_step_bootstrapped_returns(
    r_t: Array,
    discount_t: Array,
    v_t: Array,
    n: int,
    lambda_t: Numeric = 1.,
    stop_target_gradients: bool = False,
) -> Array:
    """Computes strided n-step bootstrapped return targets over a sequence.

  The returns are computed according to the below equation iterated `n` times:

     Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].

  When lambda_t == 1. (default), this reduces to

     Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (... * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ ))).

  Args:
    r_t: rewards at times [1, ..., T].
    discount_t: discounts at times [1, ..., T].
    v_t: state or state-action values to bootstrap from at time [1, ...., T].
    n: number of steps over which to accumulate reward before bootstrapping.
    lambda_t: lambdas at times [1, ..., T]. Shape is [], or [T-1].
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    estimated bootstrapped returns at times [0, ...., T-1]
  """
    chex.assert_rank([r_t, discount_t, v_t, lambda_t], [1, 1, 1, {0, 1}])
    chex.assert_type([r_t, discount_t, v_t, lambda_t], float)
    chex.assert_equal_shape([r_t, discount_t, v_t])
    seq_len = r_t.shape[0]

    # Maybe change scalar lambda to an array.
    lambda_t = jnp.ones_like(discount_t) * lambda_t

    # Shift bootstrap values by n and pad end of sequence with last value v_t[-1].
    pad_size = min(n - 1, seq_len)
    targets = jnp.concatenate([v_t[n - 1:], jnp.array([v_t[-1]] * pad_size)])

    # Pad sequences. Shape is now (T + n - 1,).
    r_t = jnp.concatenate([r_t, jnp.zeros(n - 1)])
    discount_t = jnp.concatenate([discount_t, jnp.ones(n - 1)])
    lambda_t = jnp.concatenate([lambda_t, jnp.ones(n - 1)])
    v_t = jnp.concatenate([v_t, jnp.array([v_t[-1]] * (n - 1))])

    # Work backwards to compute n-step returns.
    for i in reversed(range(n)):
        r_ = r_t[i:i + seq_len]
        discount_ = discount_t[i:i + seq_len]
        lambda_ = lambda_t[i:i + seq_len]
        v_ = v_t[i:i + seq_len]
        targets = r_ + discount_ * ((1. - lambda_) * v_ + lambda_ * targets)

    return jax.lax.select(stop_target_gradients,
                          jax.lax.stop_gradient(targets), targets)
Example #22
def leaky_vtrace(v_tm1: Array,
                 v_t: Array,
                 r_t: Array,
                 discount_t: Array,
                 rho_tm1: Array,
                 alpha_: float = 1.0,
                 lambda_: Numeric = 1.0,
                 clip_rho_threshold: float = 1.0,
                 stop_target_gradients: bool = True):
    """Calculates Leaky V-Trace errors from importance weights.

  Leaky-Vtrace is a combination of Importance sampling and V-trace, where the
  degree of mixing is controlled by a scalar `alpha` (that may be meta-learnt).

  See "Self-Tuning Deep Reinforcement Learning"
  by Zahavy et al. (https://arxiv.org/abs/2002.12928)

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_tm1: importance weights at time t-1.
    alpha_: mixing parameter for Importance Sampling and V-trace.
    lambda_: mixing parameter; a scalar or a vector for timesteps t.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    Leaky V-Trace error.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_],
                     [1, 1, 1, 1, 1, {0, 1}])
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_],
                     [float, float, float, float, float, float])
    chex.assert_equal_shape([v_tm1, v_t, r_t, discount_t, rho_tm1])

    # Mix clipped and unclipped importance sampling ratios.
    c_tm1 = (
        (1 - alpha_) * rho_tm1 + alpha_ * jnp.minimum(1.0, rho_tm1)) * lambda_
    clipped_rhos_tm1 = ((1 - alpha_) * rho_tm1 +
                        alpha_ * jnp.minimum(clip_rho_threshold, rho_tm1))

    # Compute the temporal difference errors.
    td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)

    # Work backwards computing the td-errors.
    err = 0.0
    errors = []
    for i in reversed(range(v_t.shape[0])):
        err = td_errors[i] + discount_t[i] * c_tm1[i] * err
        errors.insert(0, err)

    # Return errors, maybe disabling gradient flow through bootstrap targets.
    return jax.lax.select(
        stop_target_gradients,
        jax.lax.stop_gradient(jnp.array(errors) + v_tm1) - v_tm1,
        jnp.array(errors))
Example #23
File: mpo_ops.py  Project: deepmind/rlax
def get_top_k_weights(
    top_k_fraction: float,
    restarting_weights: Array,
    scaled_advantages: Array,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
):
  """Get the weights for the top top_k_fraction of advantages.

  Args:
    top_k_fraction: The fraction of weights to use.
    restarting_weights: Restarting weights, shape E*, 0 means that this step is
      the start of a new episode and we ignore losses at this step because the
      agent cannot influence these.
    scaled_advantages: The advantages for each example (shape E*), scaled by
      temperature.
    axis_name: Optional axis name for `pmap`. If `None`, computations are
      performed locally on each device.
    use_stop_gradient: bool indicating whether or not to apply stop gradient.

  Returns:
    Weights for the top top_k_fraction of advantages
  """
  chex.assert_equal_shape([scaled_advantages, restarting_weights])
  chex.assert_type([scaled_advantages, restarting_weights], float)

  if not 0.0 < top_k_fraction <= 1.0:
    raise ValueError(
        f"`top_k_fraction` must be in (0, 1], got {top_k_fraction}")
  logging.info("[vmpo_e_step] top_k_fraction: %f", top_k_fraction)

  if top_k_fraction < 1.0:
    # Don't include the restarting samples in the determination of top-k.
    valid_scaled_advantages = scaled_advantages - (
        1.0 - restarting_weights) * _INFINITY
    # Determine the minimum top-k value across all devices,
    if axis_name:
      all_valid_scaled_advantages = jax.lax.all_gather(
          valid_scaled_advantages, axis_name=axis_name)
    else:
      all_valid_scaled_advantages = valid_scaled_advantages
    top_k = int(top_k_fraction * jnp.size(all_valid_scaled_advantages))
    if top_k == 0:
      raise ValueError(
          "top_k_fraction too low to get any valid scaled advantages.")
    # TODO(b/160450251): Use jnp.partition(all_valid_scaled_advantages, top_k)
    #   when this is implemented in jax.
    top_k_min = jnp.sort(jnp.reshape(all_valid_scaled_advantages, [-1]))[-top_k]
    # Fold the top-k into the restarting weights.
    top_k_weights = jnp.greater_equal(valid_scaled_advantages,
                                      top_k_min).astype(jnp.float32)
    top_k_weights = jax.lax.select(
        use_stop_gradient, jax.lax.stop_gradient(top_k_weights), top_k_weights)
    top_k_restarting_weights = restarting_weights * top_k_weights
  else:
    top_k_restarting_weights = restarting_weights

  return top_k_restarting_weights
Example #24
def importance_corrected_td_errors(
    r_t: Array,
    discount_t: Array,
    rho_tm1: Array,
    lambda_: Array,
    values: Array,
) -> Array:
    """Computes the multistep td errors with per decision importance sampling.

  Given a trajectory of length `T+1`, generated under some policy π, for each
  time-step `t` we can estimate a multistep temporal difference error δₜ(ρ,λ),
  by combining rewards, discounts, and state values, according to a mixing
  parameter `λ` and importance sampling ratios ρₜ = π(aₜ|sₜ) / μ(aₜ|sₜ):

    td-errorₜ = ρₜ δₜ(ρ,λ)
    δₜ(ρ,λ) = δₜ + ρₜ₊₁ λₜ₊₁ γₜ₊₁ δₜ₊₁(ρ,λ),

  where δₜ = rₜ₊₁ + γₜ₊₁ vₜ₊₁ - vₜ is the one step, temporal difference error
  for the agent's state value estimates. This is equivalent to computing
  the λ-return with λₜ = ρₜ (e.g. using the `lambda_returns` function from
  above), and then computing errors as  td-errorₜ = ρₜ(Gₜ - vₜ).

  See "A new Q(λ) with interim forward view and Monte Carlo equivalence"
  by Sutton et al. (http://proceedings.mlr.press/v32/sutton14.html).

  Args:
    r_t: sequence of rewards rₜ for timesteps t in [1, T].
    discount_t: sequence of discounts γₜ for timesteps t in [1, T].
    rho_tm1: sequence of importance ratios for all timesteps t in [0, T-1].
    lambda_: mixing parameter; a scalar, or a vector of per-timestep values for t in [1, T].
    values: sequence of state values under π for all timesteps t in [0, T].

  Returns:
    Off-policy estimates of the multistep td errors.
  """
    chex.assert_rank([r_t, discount_t, rho_tm1, values], [1, 1, 1, 1])
    chex.assert_type([r_t, discount_t, rho_tm1, values], float)
    chex.assert_equal_shape([r_t, discount_t, rho_tm1, values[1:]])

    v_tm1 = values[:-1]  # Predictions to compute errors for.
    v_t = values[1:]  # Values for bootstrapping.
    rho_t = jnp.concatenate((rho_tm1[1:], jnp.array([1.])))  # Unused dummy value.
    lambda_ = jnp.ones_like(discount_t) * lambda_  # If scalar, make into vector.

    # Compute the one step temporal difference errors.
    one_step_delta = r_t + discount_t * v_t - v_tm1

    # Work backwards to compute `delta_{T-1}`, ..., `delta_0`.
    delta, errors = 0.0, []
    for i in jnp.arange(one_step_delta.shape[0] - 1, -1, -1):
        delta = one_step_delta[i] + discount_t[i] * rho_t[i] * lambda_[i] * delta
        errors.insert(0, delta)

    return rho_tm1 * jnp.array(errors)
Example #25
def vtrace_td_error_and_advantage(
    v_tm1: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    rho_tm1: Array,
    lambda_: Numeric = 1.0,
    clip_rho_threshold: float = 1.0,
    clip_pg_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> VTraceOutput:
    """Calculates V-Trace errors and PG advantage from importance weights.

  This function computes the TD-errors and policy gradient advantage terms
  as used by the IMPALA distributed actor-critic agent.

  See "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor
  Learner Architectures" by Espeholt et al. (https://arxiv.org/abs/1802.01561)

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_tm1: importance weights at time t-1.
    lambda_: mixing parameter; a scalar or a vector for timesteps t.
    clip_rho_threshold: clip threshold for importance ratios.
    clip_pg_rho_threshold: clip threshold for policy gradient importance ratios.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    a tuple of V-Trace error, policy gradient advantage, and estimated Q-values.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_],
                     [1, 1, 1, 1, 1, {0, 1}])
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_],
                     [float, float, float, float, float, float])
    chex.assert_equal_shape([v_tm1, v_t, r_t, discount_t, rho_tm1])

    # If scalar make into vector.
    lambda_ = jnp.ones_like(discount_t) * lambda_

    errors = vtrace(v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_,
                    clip_rho_threshold, stop_target_gradients)
    targets_tm1 = errors + v_tm1
    q_bootstrap = jnp.concatenate([
        lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
        v_t[-1:],
    ],
                                  axis=0)
    q_estimate = r_t + discount_t * q_bootstrap
    clipped_pg_rho_tm1 = jnp.minimum(clip_pg_rho_threshold, rho_tm1)
    pg_advantages = clipped_pg_rho_tm1 * (q_estimate - v_tm1)
    return VTraceOutput(errors=errors,
                        pg_advantage=pg_advantages,
                        q_estimate=q_estimate)
Example #26
        def loss_func(params, target_params, state, target_state, rng,
                      transition_batch):
            rngs = hk.PRNGSequence(rng)
            S = self.q.observation_preprocessor(next(rngs), transition_batch.S)
            A = self.q.action_preprocessor(next(rngs), transition_batch.A)
            W = jnp.clip(transition_batch.W, 0.1, 10.)  # clip importance weights to reduce variance

            metrics = {}
            # regularization term
            if self.policy_regularizer is None:
                regularizer = 0.
            else:
                regularizer, regularizer_metrics = self.policy_regularizer.batch_eval(
                    target_params['reg'], target_params['reg_hparams'],
                    target_state['reg'], next(rngs), transition_batch)
                metrics.update({
                    f'{self.__class__.__name__}/{k}': v
                    for k, v in regularizer_metrics.items()
                })

            Q, state_new = self.q.function_type1(params, state, next(rngs), S,
                                                 A, True)
            G = self.target_func(target_params, target_state, next(rngs),
                                 transition_batch)
            # flip sign (typical example: regularizer = -beta * entropy)
            G -= regularizer
            loss = self.loss_function(G, Q, W)

            dLoss_dQ = jax.grad(self.loss_function, argnums=1)
            td_error = -Q.shape[0] * dLoss_dQ(G, Q)  # e.g. (G - Q) if loss function is MSE

            # target-network estimate (is this worth computing?)
            Q_targ_list = []
            qs = list(
                zip(self.q_targ_list, target_params['q_targ'],
                    target_state['q_targ']))
            for q, pm, st in qs:
                Q_targ, _ = q.function_type1(pm, st, next(rngs), S, A, False)
                assert Q_targ.ndim == 1, f"bad shape: {Q_targ.shape}"
                Q_targ_list.append(Q_targ)
            Q_targ_list = jnp.stack(Q_targ_list, axis=-1)
            assert Q_targ_list.ndim == 2, f"bad shape: {Q_targ_list.shape}"
            Q_targ = jnp.min(Q_targ_list, axis=-1)

            chex.assert_equal_shape([td_error, W, Q_targ])
            metrics.update({
                f'{self.__class__.__name__}/loss': loss,
                f'{self.__class__.__name__}/td_error': jnp.mean(W * td_error),
                f'{self.__class__.__name__}/td_error_targ': jnp.mean(-dLoss_dQ(Q, Q_targ, W)),
            })
            return loss, (td_error, state_new, metrics)
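The `-Q.shape[0] * dLoss_dQ(G, Q)` trick can be verified in isolation; a sketch assuming a halved mean-squared-error loss (the real `self.loss_function` also takes importance weights, omitted here for clarity):

import jax
import jax.numpy as jnp

def loss_function(G, Q):
    return 0.5 * jnp.mean((G - Q) ** 2)   # halved MSE, one common value-loss convention

G = jnp.array([1.0, 2.0, 3.0])
Q = jnp.array([0.5, 2.5, 2.0])

dLoss_dQ = jax.grad(loss_function, argnums=1)
td_error = -Q.shape[0] * dLoss_dQ(G, Q)   # recovers G - Q = [0.5, -0.5, 1.0]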
Example #27
def vtrace(
    v_tm1: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    rho_tm1: Array,
    lambda_: Numeric = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> Array:
    """Calculates V-Trace errors from importance weights.

  V-trace computes TD-errors from multistep trajectories by applying
  off-policy corrections based on clipped importance sampling ratios.

  See "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor
  Learner Architectures" by Espeholt et al. (https://arxiv.org/abs/1802.01561).

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_tm1: importance sampling ratios at time t-1.
    lambda_: mixing parameter; a scalar or a vector for timesteps t.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    V-Trace error.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_],
                     [1, 1, 1, 1, 1, {0, 1}])
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_tm1, lambda_],
                     [float, float, float, float, float, float])
    chex.assert_equal_shape([v_tm1, v_t, r_t, discount_t, rho_tm1])

    # Clip importance sampling ratios.
    c_tm1 = jnp.minimum(1.0, rho_tm1) * lambda_
    clipped_rhos_tm1 = jnp.minimum(clip_rho_threshold, rho_tm1)

    # Compute the temporal difference errors.
    td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)

    # Work backwards computing the td-errors.
    err = 0.0
    errors = []
    for i in reversed(range(v_t.shape[0])):
        err = td_errors[i] + discount_t[i] * c_tm1[i] * err
        errors.insert(0, err)

    # Return errors, maybe disabling gradient flow through bootstrap targets.
    return jax.lax.select(
        stop_target_gradients,
        jax.lax.stop_gradient(jnp.array(errors) + v_tm1) - v_tm1,
        jnp.array(errors))
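The backward recursion is short enough to trace by hand; a self-contained NumPy sketch with made-up values (ρ clipped at 1.0, λ = 1):

import numpy as np

v_tm1 = np.array([1.0, 2.0])
v_t = np.array([2.0, 3.0])
r_t = np.array([0.5, 1.0])
discount_t = np.array([0.9, 0.9])
rho_tm1 = np.array([2.0, 0.5])
lambda_ = 1.0

clipped_rho = np.minimum(1.0, rho_tm1)                # clip_rho_threshold = 1.0
c = np.minimum(1.0, rho_tm1) * lambda_
td = clipped_rho * (r_t + discount_t * v_t - v_tm1)   # corrected one-step TD errors

err_1 = td[1]                                         # last step: nothing left to accumulate
err_0 = td[0] + discount_t[0] * c[0] * err_1          # same recursion as the loop above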
Example #28
def unitwise_clip(g_norm: chex.Array,
                  max_norm: chex.Array,
                  grad: chex.Array,
                  div_eps: float = 1e-6) -> chex.Array:
  """Applies gradient clipping unit-wise."""
  # This little max(., div_eps) is distinct from the normal eps and just
  # prevents division by zero. It technically should be impossible to engage.
  clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps))
  chex.assert_equal_shape((g_norm, max_norm, grad, clipped_grad))
  return jnp.where(g_norm < max_norm, grad, clipped_grad)
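A hedged usage sketch of the decision rule above, treating each row as one unit (the real adaptive-gradient-clipping code picks the unit axes per layer type; the norms here are broadcast so all shapes match):

import jax.numpy as jnp

grad = jnp.array([[3.0, 4.0],
                  [0.03, 0.04]])
param = jnp.array([[1.0, 1.0],
                   [1.0, 1.0]])
clipping = 0.1

g_norm = jnp.broadcast_to(jnp.linalg.norm(grad, axis=-1, keepdims=True), grad.shape)
p_norm = jnp.broadcast_to(jnp.linalg.norm(param, axis=-1, keepdims=True), param.shape)
max_norm = clipping * jnp.maximum(p_norm, 1e-3)

clipped = grad * (max_norm / jnp.maximum(g_norm, 1e-6))
out = jnp.where(g_norm < max_norm, grad, clipped)   # row 0 is scaled down, row 1 passes through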
Example #29
def leaky_vtrace_td_error_and_advantage(
    v_tm1: chex.Array,
    v_t: chex.Array,
    r_t: chex.Array,
    discount_t: chex.Array,
    rho_tm1: chex.Array,
    alpha: float = 1.0,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    clip_pg_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> VTraceOutput:
    """Calculates V-Trace errors and PG advantage from importance weights.

  This function computes the TD-errors and policy gradient advantage terms
  as used by the IMPALA distributed actor-critic agent.

  See "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor
  Learner Architectures" by Espeholt et al. (https://arxiv.org/abs/1802.01561)

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_tm1: importance weights at time t-1.
    alpha: mixing the clipped importance sampling weights with unclipped ones.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance ratios.
    clip_pg_rho_threshold: clip threshold for policy gradient importance ratios.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    a tuple of V-Trace error, policy gradient advantage, and estimated Q-values.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_tm1], 1)
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_tm1], float)
    chex.assert_equal_shape([v_tm1, v_t, r_t, discount_t, rho_tm1])

    errors = leaky_vtrace(v_tm1, v_t, r_t, discount_t, rho_tm1, alpha, lambda_,
                          clip_rho_threshold, stop_target_gradients)
    targets_tm1 = errors + v_tm1
    q_bootstrap = jnp.concatenate([
        lambda_ * targets_tm1[1:] + (1 - lambda_) * v_tm1[1:],
        v_t[-1:],
    ],
                                  axis=0)
    q_estimate = r_t + discount_t * q_bootstrap
    clipped_pg_rho_tm1 = ((1 - alpha) * rho_tm1 +
                          alpha * jnp.minimum(clip_pg_rho_threshold, rho_tm1))
    pg_advantages = clipped_pg_rho_tm1 * (q_estimate - v_tm1)
    return VTraceOutput(errors=errors,
                        pg_advantage=pg_advantages,
                        q_estimate=q_estimate)
Example #30
def general_off_policy_returns_from_q_and_v(
    q_t: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    c_t: Array,
    stop_target_gradients: bool = False,
) -> Array:
    """Calculates targets for various off-policy evaluation algorithms.

  Given a window of experience of length `K+1`, generated by a behaviour policy
  μ, for each time-step `t` we can estimate the return `G_t` from that step
  onwards, under some target policy π, using the rewards in the trajectory, the
  values under π of states and actions selected by μ, according to equation:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (vₜ₊₁ - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁* Gₜ₊₁),

  where, depending on the choice of `c_t`, the algorithm implements:

    Importance Sampling             c_t = π(x_t, a_t) / μ(x_t, a_t),
    Harutyunyan's et al. Q(lambda)  c_t = λ,
    Precup's et al. Tree-Backup     c_t = π(x_t, a_t),
    Munos' et al. Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_t: Q-values under π of actions executed by μ at times [1, ..., K - 1].
    v_t: Values under π at times [1, ..., K].
    r_t: rewards at times [1, ..., K].
    discount_t: discounts at times [1, ..., K].
    c_t: weights at times [1, ..., K - 1].
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.

  Returns:
    Off-policy estimates of the generalized returns from states visited at times
    [0, ..., K - 1].
  """
    chex.assert_rank([q_t, v_t, r_t, discount_t, c_t], 1)
    chex.assert_type([q_t, v_t, r_t, discount_t, c_t], float)
    chex.assert_equal_shape([q_t, v_t[:-1], r_t[:-1], discount_t[:-1], c_t])

    # Work backwards to compute `G_K-1`, ..., `G_1`, `G_0`.
    g = r_t[-1] + discount_t[-1] * v_t[-1]  # G_K-1.
    returns = [g]
    for i in reversed(range(q_t.shape[0])):  # [K - 2, ..., 0]
        g = r_t[i] + discount_t[i] * (v_t[i] - c_t[i] * q_t[i] + c_t[i] * g)
        returns.insert(0, g)

    return jax.lax.select(stop_target_gradients,
                          jax.lax.stop_gradient(jnp.array(returns)),
                          jnp.array(returns))