Ejemplo n.º 1
0
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, learning_lib.LossExtra]:
        """Calculate a regularized Q-learning loss on a single batch of data.

        The loss is ``mean(0.5 * td_error**2)`` plus ``self.regularizer_coeff``
        times the mean online Q-value of the actions taken in the batch.

        Args:
          network: Q-network; ``network.apply(params, obs)`` yields Q-values.
          params: Online network parameters (evaluated at the observations).
          target_params: Target network parameters (evaluated at the next
            observations for bootstrapping).
          batch: Replay sample whose ``data`` is a ``types.Transition``.
          key: Unused; accepted for interface compatibility.

        Returns:
          A ``(loss, extra)`` pair: the scalar loss and a ``LossExtra`` with
          empty metrics.
        """
        del key
        transitions: types.Transition = batch.data

        # Forward pass: online Q-values at s_tm1, target Q-values at s_t.
        q_tm1 = network.apply(params, transitions.observation)
        q_t = network.apply(target_params, transitions.next_observation)

        d_t = (transitions.discount * self.discount).astype(jnp.float32)

        # Per-transition Q-learning TD-errors, turned into a squared loss.
        batch_error = jax.vmap(rlax.q_learning)
        td_error = batch_error(q_tm1, transitions.action, transitions.reward,
                               d_t, q_t)
        td_loss = 0.5 * jnp.square(td_error)

        def select(qtm1, action):
            return qtm1[action]

        # Online Q(s_tm1, a_tm1) for each transition; its mean is added to the
        # loss scaled by regularizer_coeff (penalizing large Q-values when the
        # coefficient is positive).
        q_regularizer = jax.vmap(select)(q_tm1, transitions.action)

        loss = self.regularizer_coeff * jnp.mean(q_regularizer) + jnp.mean(
            td_loss)
        extra = learning_lib.LossExtra(metrics={})
        return loss, extra
Ejemplo n.º 2
0
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, learning_lib.LossExtra]:
        """Calculate a squared Q-learning TD loss on a single batch of data.

        Rewards are clipped to ``[-self.max_abs_reward, self.max_abs_reward]``
        and discounts are scaled by ``self.discount`` before computing the
        per-transition Q-learning TD-error. The loss is the mean squared error.

        Args:
          network: Q-network; ``network.apply(params, obs)`` yields Q-values.
          params: Online network parameters (evaluated at the observations).
          target_params: Target network parameters (evaluated at the next
            observations for bootstrapping).
          batch: Replay sample whose ``data`` is a ``types.Transition``.
          key: Unused; accepted for interface compatibility.

        Returns:
          A ``(loss, extra)`` pair: the scalar loss and a ``LossExtra`` with
          empty metrics.
        """
        del key
        transitions: types.Transition = batch.data

        # Forward pass: online Q-values at s_tm1, target Q-values at s_t.
        q_tm1 = network.apply(params, transitions.observation)
        q_t = network.apply(target_params, transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute Q-learning TD-error, vectorized over the batch.
        batch_error = jax.vmap(rlax.q_learning)
        td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t)
        batch_loss = jnp.square(td_error)

        loss = jnp.mean(batch_loss)
        extra = learning_lib.LossExtra(metrics={})
        return loss, extra
Ejemplo n.º 3
0
def default_behavior_policy(network: networks_lib.FeedForwardNetwork,
                            epsilon: float, params: networks_lib.Params,
                            key: networks_lib.PRNGKey,
                            observation: networks_lib.Observation):
    """Samples an epsilon-greedy action for ``observation``.

    Args:
      network: Q-network; ``network.apply(params, observation)`` yields the
        action values the policy acts on.
      epsilon: Exploration rate passed to ``rlax.epsilon_greedy``.
      params: Network parameters.
      key: Random key used to sample the action.
      observation: Observation to act on.

    Returns:
      The sampled action(s), cast to ``int32``.
    """
    explorer = rlax.epsilon_greedy(epsilon)
    q_values = network.apply(params, observation)
    sampled = explorer.sample(key, q_values)
    return sampled.astype(jnp.int32)
Ejemplo n.º 4
0
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, learning_lib.LossExtra]:
        """Calculate a Munchausen Q-learning loss on a single batch of data.

        The regression target is::

            r_t + munchausen_coefficient * clip(tau * log_pi(a|s), min, 0)
                + d_t * tau * logsumexp(q_target(s') / tau)

        where ``tau = self.entropy_temperature`` and
        ``pi = softmax(q_target(s) / tau)``. The target is wrapped in
        ``stop_gradient``, and the loss is the mean Huber loss between the
        target and the online Q-value of the taken action.

        Args:
          network: Q-network; ``network.apply(params, obs)`` yields Q-values.
          params: Online network parameters.
          target_params: Target network parameters, used both for the
            Munchausen term (at s) and the soft next-state value (at s').
          batch: Replay sample whose ``data`` is a ``types.Transition``.
          key: Unused; accepted for interface compatibility.

        Returns:
          A ``(loss, extra)`` pair: the scalar loss and a ``LossExtra`` with
          empty metrics.
        """
        del key
        transitions: types.Transition = batch.data

        # Forward pass.
        q_online_s = network.apply(params, transitions.observation)
        action_one_hot = jax.nn.one_hot(transitions.action,
                                        q_online_s.shape[-1])
        # Online Q-value of the taken action.
        q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1)
        q_target_s = network.apply(target_params, transitions.observation)
        q_target_next = network.apply(target_params,
                                      transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Munchausen term: tau * log_pi(a|s), with pi = softmax(q_target / tau).
        munchausen_term = self.entropy_temperature * jax.nn.log_softmax(
            q_target_s / self.entropy_temperature, axis=-1)
        munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1)
        # Bounds passed positionally: the a_min/a_max keywords are deprecated
        # in recent JAX, while positional bounds work in every version.
        munchausen_term_a = jnp.clip(munchausen_term_a, self.clip_value_min,
                                     0.)

        # Soft Bellman operator applied to q: soft value of the next state.
        next_v = self.entropy_temperature * jax.nn.logsumexp(
            q_target_next / self.entropy_temperature, axis=-1)
        target_q = jax.lax.stop_gradient(r_t + self.munchausen_coefficient *
                                         munchausen_term_a + d_t * next_v)

        batch_loss = rlax.huber_loss(target_q - q_online_sa,
                                     self.huber_loss_parameter)
        loss = jnp.mean(batch_loss)

        extra = learning_lib.LossExtra(metrics={})
        return loss, extra
Ejemplo n.º 5
0
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.ndarray, learning_lib.LossExtra]:
        """Calculate a prioritized categorical double Q-learning loss.

        ``network.apply`` is unpacked as a triple whose first element is used
        as the double-Q action selector and whose second and third elements
        are logits and atoms. The per-transition categorical losses are
        importance-weighted (weights normalized by their max), and their
        absolute values are returned as new replay priorities.

        Args:
          network: Distributional Q-network returning a 3-tuple.
          params: Online network parameters (used at s_tm1 and as the
            selector at s_t).
          target_params: Target network parameters (used at s_t).
          batch: Replay sample; ``data`` is a ``types.Transition`` and
            ``info`` starts with ``(keys, probabilities, ...)``.
          key: Unused; accepted for interface compatibility.

        Returns:
          A ``(loss, extra)`` pair; ``extra.reverb_update`` carries the
          sample keys and their new priorities.
        """
        del key
        transitions: types.Transition = batch.data
        keys, probs, *_ = batch.info

        # Forward pass.
        _, logits_tm1, atoms_tm1 = network.apply(params,
                                                 transitions.observation)
        _, logits_t, atoms_t = network.apply(target_params,
                                             transitions.next_observation)
        q_t_selector, _, _ = network.apply(params,
                                           transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute categorical double Q-learning loss. The atoms are not
        # batched (in_axes=None) and are shared across all transitions.
        batch_loss_fn = jax.vmap(rlax.categorical_double_q_learning,
                                 in_axes=(None, 0, 0, 0, 0, None, 0, 0))
        batch_loss = batch_loss_fn(atoms_tm1, logits_tm1, transitions.action,
                                   r_t, d_t, atoms_t, logits_t, q_t_selector)

        # Importance weighting: (1 / p) ** exponent, normalized by the max.
        importance_weights = (1. / probs).astype(jnp.float32)
        importance_weights **= self.importance_sampling_exponent
        importance_weights /= jnp.max(importance_weights)

        # Reweight.
        loss = jnp.mean(importance_weights * batch_loss)  # []
        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(batch_loss).astype(jnp.float64))
        extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
        return loss, extra
Ejemplo n.º 6
0
  def __call__(
      self,
      network: networks_lib.FeedForwardNetwork,
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      batch: reverb.ReplaySample,
      key: networks_lib.PRNGKey,
  ) -> Tuple[jnp.ndarray, learning_lib.LossExtra]:
    """Calculate a prioritized double Q-learning loss on a batch.

    The per-transition double Q-learning TD-errors go through a Huber loss
    and are importance-weighted (weights normalized by their max). The
    absolute TD-errors are returned as new replay priorities.

    Args:
      network: Q-network; ``network.apply(params, obs)`` yields Q-values.
      params: Online network parameters (used at s_tm1 and as the action
        selector at s_t).
      target_params: Target network parameters (used to evaluate s_t).
      batch: Replay sample; ``data`` is a ``types.Transition`` and ``info``
        starts with ``(keys, probabilities, ...)``.
      key: Unused; accepted for interface compatibility.

    Returns:
      A ``(loss, extra)`` pair; ``extra.reverb_update`` carries the sample
      keys and their new priorities.
    """
    del key
    transitions: types.Transition = batch.data
    keys, probs, *_ = batch.info

    # Forward pass.
    q_tm1 = network.apply(params, transitions.observation)
    q_t_value = network.apply(target_params, transitions.next_observation)
    q_t_selector = network.apply(params, transitions.next_observation)

    # Cast and clip rewards.
    d_t = (transitions.discount * self.discount).astype(jnp.float32)
    r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                   self.max_abs_reward).astype(jnp.float32)

    # Compute double Q-learning n-step TD-error.
    batch_error = jax.vmap(rlax.double_q_learning)
    td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value,
                           q_t_selector)
    batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

    # Importance weighting: (1 / p) ** exponent, normalized by the max.
    importance_weights = (1. / probs).astype(jnp.float32)
    importance_weights **= self.importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)

    # Reweight.
    loss = jnp.mean(importance_weights * batch_loss)  # []
    reverb_update = learning_lib.ReverbUpdate(
        keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
    extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
    return loss, extra
Ejemplo n.º 7
0
 def __call__(
     self,
     network: networks_lib.FeedForwardNetwork,
     params: networks_lib.Params,
     target_params: networks_lib.Params,
     batch: reverb.ReplaySample,
     key: networks_lib.PRNGKey,
 ) -> Tuple[jnp.ndarray, learning_lib.LossExtra]:
     """Calculate a quantile-regression Q-learning loss on a batch.

     Uses ``self.num_atoms`` quantile midpoints ``(i + 0.5) / num_atoms``.
     The target distribution is also used as the action selector, so there
     is no double Q-learning. The loss is the mean of the per-transition
     quantile losses, and that mean is also reported as the ``'mean_loss'``
     metric.

     Args:
       network: Network whose ``apply`` output has a ``'q_dist'`` entry.
       params: Online network parameters (evaluated at s_tm1).
       target_params: Target network parameters (evaluated at s_t).
       batch: Replay sample whose ``data`` is a ``types.Transition``.
       key: Unused; accepted for interface compatibility.

     Returns:
       A ``(loss, extra)`` pair with ``extra.metrics == {'mean_loss': loss}``.
     """
     del key
     transitions: types.Transition = batch.data
     dist_q_tm1 = network.apply(params, transitions.observation)['q_dist']
     dist_q_target_t = network.apply(target_params,
                                     transitions.next_observation)['q_dist']
     # Swap distribution and action dimension, since
     # rlax.quantile_q_learning expects it that way (presumably
     # [B, actions, atoms] -> [B, atoms, actions]; TODO confirm).
     dist_q_tm1 = jnp.swapaxes(dist_q_tm1, 1, 2)
     dist_q_target_t = jnp.swapaxes(dist_q_target_t, 1, 2)
     quantiles = ((jnp.arange(self.num_atoms, dtype=jnp.float32) + 0.5) /
                  self.num_atoms)
     # Quantiles and the Huber parameter are shared across the batch.
     batch_quantile_q_learning = jax.vmap(rlax.quantile_q_learning,
                                          in_axes=(0, None, 0, 0, 0, 0, 0,
                                                   None))
     # NOTE(review): unlike the other losses here, transitions.discount is
     # used as-is (no self.discount scaling, no reward clipping); confirm
     # the discount factor is already folded in upstream.
     losses = batch_quantile_q_learning(
         dist_q_tm1,
         quantiles,
         transitions.action,
         transitions.reward,
         transitions.discount,
         dist_q_target_t,  # No double Q-learning here.
         dist_q_target_t,
         self.huber_param,
     )
     loss = jnp.mean(losses)
     extra = learning_lib.LossExtra(metrics={'mean_loss': loss})
     return loss, extra
Ejemplo n.º 8
0
    def __init__(self,
                 unroll: networks_lib.FeedForwardNetwork,
                 initial_state: networks_lib.FeedForwardNetwork,
                 batch_size: int,
                 random_key: networks_lib.PRNGKey,
                 burn_in_length: int,
                 discount: float,
                 importance_sampling_exponent: float,
                 max_priority_weight: float,
                 target_update_period: int,
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optax.GradientTransformation,
                 bootstrap_n: int = 5,
                 tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR,
                 clip_rewards: bool = False,
                 max_abs_reward: float = 1.,
                 use_core_state: bool = True,
                 prefetch_size: int = 2,
                 replay_client: Optional[reverb.Client] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the learner.

        Builds the sequence loss and a pmapped SGD step, an asynchronous
        priority updater, the replicated training state, and a sharded
        prefetching iterator over replay samples.

        Args:
          unroll: Recurrent network; ``apply(params, key, obs, state)`` returns
            ``(q_values, new_state)`` for a sequence of observations.
          initial_state: Network building the initial recurrent state via
            ``init(key, batch_size)`` and ``apply(params, key, batch_size)``.
          batch_size: Batch size used to build the initial recurrent state.
          random_key: Seed key for initialization and the training RNG.
          burn_in_length: Leading timesteps used only to warm up the
            recurrent state; the loss is computed on the remaining steps.
          discount: Factor multiplied into the per-step discounts.
          importance_sampling_exponent: Exponent applied to the 1/probability
            importance weights.
          max_priority_weight: Priorities are
            ``w * max|td| + (1 - w) * mean|td|`` over each sequence.
          target_update_period: Steps between target-network updates.
          iterator: Source of replay samples.
          optimizer: Optax gradient transformation.
          bootstrap_n: Number of steps in the transformed n-step target.
          tx_pair: Value transform pair for the n-step loss.
          clip_rewards: Whether to clip rewards to
            ``[-max_abs_reward, max_abs_reward]``.
          max_abs_reward: Reward clipping bound (used only if clip_rewards).
          use_core_state: If True, start from the core state stored in the
            replayed data; otherwise from the freshly built initial state.
          prefetch_size: Buffer size of the sharded prefetching iterator.
          replay_client: Reverb client used to write back priorities.
          counter: Optional counter; a fresh one is created if None.
          logger: Optional logger; a default asynchronous one if None.
        """

        random_key, key_initial_1, key_initial_2 = jax.random.split(
            random_key, 3)
        initial_state_params = initial_state.init(key_initial_1, batch_size)
        initial_state = initial_state.apply(initial_state_params,
                                            key_initial_2, batch_size)

        def loss(
            params: networks_lib.Params, target_params: networks_lib.Params,
            key_grad: networks_lib.PRNGKey, sample: reverb.ReplaySample
        ) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Computes mean transformed N-step loss for a batch of sequences.

            Returns:
              A ``(mean_loss, priorities)`` pair: the importance-weighted
              scalar loss and the new per-sequence replay priorities.
            """

            # Convert sample data to sequence-major format [T, B, ...].
            data = utils.batch_to_sequence(sample.data)

            # Get core state & warm it up on observations for a burn-in period.
            if use_core_state:
                # Replay the core state stored at the first timestep.
                online_state = jax.tree_util.tree_map(
                    lambda x: x[0], data.extras['core_state'])
            else:
                online_state = initial_state
            target_state = online_state

            # Maybe burn the core state in.
            if burn_in_length:
                burn_obs = jax.tree_util.tree_map(
                    lambda x: x[:burn_in_length], data.observation)
                key_grad, key1, key2 = jax.random.split(key_grad, 3)
                _, online_state = unroll.apply(params, key1, burn_obs,
                                               online_state)
                _, target_state = unroll.apply(target_params, key2, burn_obs,
                                               target_state)

            # Only get data to learn on from after the end of the burn in period.
            data = jax.tree_util.tree_map(lambda seq: seq[burn_in_length:],
                                          data)

            # Unroll on sequences to get online and target Q-Values.
            key1, key2 = jax.random.split(key_grad)
            online_q, _ = unroll.apply(params, key1, data.observation,
                                       online_state)
            target_q, _ = unroll.apply(target_params, key2, data.observation,
                                       target_state)

            # Get value-selector actions from online Q-values for double Q-learning.
            selector_actions = jnp.argmax(online_q, axis=-1)
            # Preprocess discounts & rewards.
            discounts = (data.discount * discount).astype(online_q.dtype)
            rewards = data.reward
            if clip_rewards:
                rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
            rewards = rewards.astype(online_q.dtype)

            # Get N-step transformed TD error and loss, vectorized over the
            # batch axis (axis 1 in the [T, B, ...] layout).
            batch_td_error_fn = jax.vmap(functools.partial(
                rlax.transformed_n_step_q_learning,
                n=bootstrap_n,
                tx_pair=tx_pair),
                                         in_axes=1,
                                         out_axes=1)
            # TODO(b/183945808): when this bug is fixed, truncations of actions,
            # rewards, and discounts will no longer be necessary.
            batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1],
                                               target_q[1:],
                                               selector_actions[1:],
                                               rewards[:-1], discounts[:-1])
            # Sum squared errors over time (axis 0) -> one loss per sequence.
            batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

            # Importance weighting; the epsilon guards against division by 0.
            probs = sample.info.probability
            importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)
            mean_loss = jnp.mean(importance_weights * batch_loss)

            # Calculate priorities as a mixture of max and mean sequence errors.
            abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
            max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
            mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error,
                                                                 axis=0)
            priorities = (max_priority + mean_priority)

            return mean_loss, priorities

        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
            """Performs an update step, averaging over pmap replicas."""

            # Compute loss and gradients.
            grad_fn = jax.value_and_grad(loss, has_aux=True)
            key, key_grad = jax.random.split(state.random_key)
            (loss_value,
             priorities), gradients = grad_fn(state.params,
                                              state.target_params, key_grad,
                                              samples)

            # Average gradients over pmap replicas before optimizer update.
            gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

            # Apply optimizer updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks. self._target_update_period
            # is assigned further below; it is only read when sgd_step is
            # traced, which happens after __init__ has finished.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 self._target_update_period)

            new_state = TrainingState(params=new_params,
                                      target_params=target_params,
                                      opt_state=new_opt_state,
                                      steps=steps,
                                      random_key=key)
            return new_state, priorities, {'loss': loss_value}

        def update_priorities(keys_and_priorities: Tuple[jnp.ndarray,
                                                         jnp.ndarray]):
            """Writes new priorities for the sampled keys back to replay."""
            keys, priorities = keys_and_priorities
            keys, priorities = tree.map_structure(
                # Fetch array and combine device and batch dimensions.
                lambda x: utils.fetch_devicearray(x).reshape(
                    (-1, ) + x.shape[2:]),
                (keys, priorities))
            replay_client.mutate_priorities(  # pytype: disable=attribute-error
                table=adders.DEFAULT_PRIORITY_TABLE,
                updates=dict(zip(keys, priorities)))

        # Internalise components, hyperparameters, logger, counter, and methods.
        self._replay_client = replay_client
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            time_delta=1.)

        self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)

        # Initialise and internalise training state (parameters/optimiser state).
        random_key, key_init = jax.random.split(random_key)
        initial_params = unroll.init(key_init, initial_state)
        opt_state = optimizer.init(initial_params)

        state = TrainingState(params=initial_params,
                              target_params=initial_params,
                              opt_state=opt_state,
                              steps=jnp.array(0),
                              random_key=random_key)
        # Replicate parameters.
        self._state = utils.replicate_in_all_devices(state)

        # Shard multiple inputs with on-device prefetching.
        # We split samples in two outputs, the keys which need to be kept on-host
        # since int64 arrays are not supported in TPUs, and the entire sample
        # separately so it can be sent to the sgd_step method.
        def split_sample(
                sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
            return utils.PrefetchingSplit(host=sample.info.key, device=sample)

        self._prefetched_iterator = utils.sharded_prefetch(
            iterator,
            buffer_size=prefetch_size,
            num_threads=jax.local_device_count(),
            split_fn=split_sample)