Example #1
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        transitions: types.Transition = batch.data

        # Forward pass.
        q_tm1 = network.apply(params, transitions.observation)
        q_t = network.apply(target_params, transitions.next_observation)

        d_t = (transitions.discount * self.discount).astype(jnp.float32)

        # Compute Q-learning TD-error.
        batch_error = jax.vmap(rlax.q_learning)
        td_error = batch_error(q_tm1, transitions.action, transitions.reward,
                               d_t, q_t)
        # The squared TD error is the per-transition loss.
        td_error = 0.5 * jnp.square(td_error)

        # Regularizer: the Q-value of the action actually taken.
        def select(qtm1, action):
            return qtm1[action]

        q_regularizer = jax.vmap(select)(q_tm1, transitions.action)

        loss = self.regularizer_coeff * jnp.mean(q_regularizer) + jnp.mean(
            td_error)
        extra = learning_lib.LossExtra(metrics={})
        return loss, extra
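The batching trick above relies on jax.vmap lifting the per-transition rlax.q_learning primitive to a whole batch. A minimal standalone sketch with toy values (not part of the example; the array contents are made up) looks like this:

import jax
import jax.numpy as jnp
import rlax

q_tm1 = jnp.array([[1.0, 2.0], [0.5, 0.0]])  # [B, A] online Q-values.
a_tm1 = jnp.array([0, 1])                    # [B] actions taken.
r_t = jnp.array([1.0, 0.0])                  # [B] rewards.
d_t = jnp.array([0.99, 0.0])                 # [B] discounts (0 at episode end).
q_t = jnp.array([[1.5, 0.5], [0.0, 0.0]])    # [B, A] target-network Q-values.

# One TD error per transition, then the same reduction as in the loss above.
td_error = jax.vmap(rlax.q_learning)(q_tm1, a_tm1, r_t, d_t, q_t)
loss = jnp.mean(0.5 * jnp.square(td_error))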
Example #2
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        transitions: types.Transition = batch.data

        # Forward pass.
        q_tm1 = network.apply(params, transitions.observation)
        q_t = network.apply(target_params, transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute Q-learning TD-error.
        batch_error = jax.vmap(rlax.q_learning)
        td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t)
        batch_loss = jnp.square(td_error)

        loss = jnp.mean(batch_loss)
        extra = learning_lib.LossExtra(metrics={})
        return loss, extra
Example #3
def default_behavior_policy(network: networks_lib.FeedForwardNetwork,
                            epsilon: float, params: networks_lib.Params,
                            key: networks_lib.PRNGKey,
                            observation: networks_lib.Observation):
    """Returns an action for the given observation."""
    action_values = network.apply(params, observation)
    actions = rlax.epsilon_greedy(epsilon).sample(key, action_values)
    return actions.astype(jnp.int32)
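For reference, the epsilon-greedy sampling step used above can be exercised on its own with toy action values (a standalone sketch; the numbers are illustrative):

import jax
import jax.numpy as jnp
import rlax

key = jax.random.PRNGKey(0)
action_values = jnp.array([0.1, 2.0, -1.0])  # Q-values for three actions.
# With probability epsilon pick uniformly at random, otherwise the argmax.
action = rlax.epsilon_greedy(epsilon=0.05).sample(key, action_values)
action = action.astype(jnp.int32)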
Example #4
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        transitions: types.Transition = batch.data

        # Forward pass.
        q_online_s = network.apply(params, transitions.observation)
        action_one_hot = jax.nn.one_hot(transitions.action,
                                        q_online_s.shape[-1])
        q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1)
        q_target_s = network.apply(target_params, transitions.observation)
        q_target_next = network.apply(target_params,
                                      transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Munchausen term: tau * log_pi(a|s)
        munchausen_term = self.entropy_temperature * jax.nn.log_softmax(
            q_target_s / self.entropy_temperature, axis=-1)
        munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1)
        munchausen_term_a = jnp.clip(munchausen_term_a,
                                     a_min=self.clip_value_min,
                                     a_max=0.)

        # Soft Bellman operator applied to q
        next_v = self.entropy_temperature * jax.nn.logsumexp(
            q_target_next / self.entropy_temperature, axis=-1)
        target_q = jax.lax.stop_gradient(r_t + self.munchausen_coefficient *
                                         munchausen_term_a + d_t * next_v)

        batch_loss = rlax.huber_loss(target_q - q_online_sa,
                                     self.huber_loss_parameter)
        loss = jnp.mean(batch_loss)

        extra = learning_lib.LossExtra(metrics={})
        return loss, extra
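To make the construction above easier to follow, here is the Munchausen target written out for a single transition, using the same jax.nn primitives on toy numbers (a standalone sketch; all values are illustrative):

import jax
import jax.numpy as jnp

tau = 0.03            # entropy_temperature
alpha = 0.9           # munchausen_coefficient
clip_value_min = -1.0

q_target_s = jnp.array([1.0, 2.0, 0.5])     # Target-network Q(s, .).
q_target_next = jnp.array([1.5, 0.0, 0.2])  # Target-network Q(s', .).
a = 1                                       # Action taken at s.
r_t, d_t = 0.5, 0.99

# Munchausen bonus: tau * log_pi(a|s), clipped into [clip_value_min, 0].
log_pi = jax.nn.log_softmax(q_target_s / tau)
munchausen_term_a = jnp.clip(tau * log_pi[a], clip_value_min, 0.)

# Soft Bellman backup: the next-state value is tau * logsumexp(Q(s', .) / tau).
next_v = tau * jax.nn.logsumexp(q_target_next / tau)
target_q = r_t + alpha * munchausen_term_a + d_t * next_v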
Example #5
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        transitions: types.Transition = batch.data
        keys, probs, *_ = batch.info

        # Forward pass.
        _, logits_tm1, atoms_tm1 = network.apply(params,
                                                 transitions.observation)
        _, logits_t, atoms_t = network.apply(target_params,
                                             transitions.next_observation)
        q_t_selector, _, _ = network.apply(params,
                                           transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute categorical double Q-learning loss.
        batch_loss_fn = jax.vmap(rlax.categorical_double_q_learning,
                                 in_axes=(None, 0, 0, 0, 0, None, 0, 0))
        batch_loss = batch_loss_fn(atoms_tm1, logits_tm1, transitions.action,
                                   r_t, d_t, atoms_t, logits_t, q_t_selector)

        # Importance weighting.
        importance_weights = (1. / probs).astype(jnp.float32)
        importance_weights **= self.importance_sampling_exponent
        importance_weights /= jnp.max(importance_weights)

        # Reweight.
        loss = jnp.mean(importance_weights * batch_loss)  # []
        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(batch_loss).astype(jnp.float64))
        extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
        return loss, extra
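The prioritized-replay correction used above (and in the next example) is independent of the particular loss; here is a standalone sketch with toy sampling probabilities:

import jax.numpy as jnp

probs = jnp.array([0.5, 0.1, 0.05])      # Sampling probabilities from replay.
batch_loss = jnp.array([1.0, 2.0, 0.5])  # Per-transition losses.
importance_sampling_exponent = 0.6

# Weight each sample by (1 / P(i))**beta, normalized so the largest weight is 1.
importance_weights = (1. / probs) ** importance_sampling_exponent
importance_weights /= jnp.max(importance_weights)
loss = jnp.mean(importance_weights * batch_loss)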
Example #6
  def __call__(
      self,
      network: networks_lib.FeedForwardNetwork,
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      batch: reverb.ReplaySample,
      key: jnp.DeviceArray,
  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
    """Calculate a loss on a single batch of data."""
    del key
    transitions: types.Transition = batch.data
    keys, probs, *_ = batch.info

    # Forward pass.
    q_tm1 = network.apply(params, transitions.observation)
    q_t_value = network.apply(target_params, transitions.next_observation)
    q_t_selector = network.apply(params, transitions.next_observation)

    # Cast and clip rewards.
    d_t = (transitions.discount * self.discount).astype(jnp.float32)
    r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                   self.max_abs_reward).astype(jnp.float32)

    # Compute double Q-learning n-step TD-error.
    batch_error = jax.vmap(rlax.double_q_learning)
    td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value,
                           q_t_selector)
    batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

    # Importance weighting.
    importance_weights = (1. / probs).astype(jnp.float32)
    importance_weights **= self.importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)

    # Reweight.
    loss = jnp.mean(importance_weights * batch_loss)  # []
    reverb_update = learning_lib.ReverbUpdate(
        keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
    extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
    return loss, extra
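For a single transition, the double Q-learning target above uses the online network to select the next action and the target network to evaluate it. A standalone sketch of that primitive on toy values (illustrative only):

import jax.numpy as jnp
import rlax

q_tm1 = jnp.array([1.0, 2.0])         # Online Q(s, .).
q_t_value = jnp.array([0.5, 1.5])     # Target Q(s', .): evaluates the action.
q_t_selector = jnp.array([2.0, 0.0])  # Online Q(s', .): selects the action.

td_error = rlax.double_q_learning(q_tm1, jnp.asarray(1), jnp.asarray(0.7),
                                  jnp.asarray(0.99), q_t_value, q_t_selector)
loss = rlax.huber_loss(td_error, 1.0)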
Example #7
  def __call__(
      self,
      network: networks_lib.FeedForwardNetwork,
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      batch: reverb.ReplaySample,
      key: networks_lib.PRNGKey,
  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
    """Calculate a loss on a single batch of data."""
    del key
    transitions: types.Transition = batch.data
    dist_q_tm1 = network.apply(params, transitions.observation)['q_dist']
    dist_q_target_t = network.apply(
        target_params, transitions.next_observation)['q_dist']
    # Swap distribution and action dimension, since
    # rlax.quantile_q_learning expects it that way.
    dist_q_tm1 = jnp.swapaxes(dist_q_tm1, 1, 2)
    dist_q_target_t = jnp.swapaxes(dist_q_target_t, 1, 2)
    quantiles = ((jnp.arange(self.num_atoms, dtype=jnp.float32) + 0.5) /
                 self.num_atoms)
    batch_quantile_q_learning = jax.vmap(
        rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None))
    losses = batch_quantile_q_learning(
        dist_q_tm1,
        quantiles,
        transitions.action,
        transitions.reward,
        transitions.discount,
        dist_q_target_t,  # No double Q-learning here.
        dist_q_target_t,
        self.huber_param,
    )
    loss = jnp.mean(losses)
    extra = learning_lib.LossExtra(metrics={'mean_loss': loss})
    return loss, extra
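The quantiles array built above is just the midpoints of N equally sized probability bins, tau_i = (i + 0.5) / N. For num_atoms = 4 (an illustrative value):

import jax.numpy as jnp

num_atoms = 4
quantiles = (jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) / num_atoms
# quantiles == [0.125, 0.375, 0.625, 0.875]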
Example #8
    def __init__(self,
                 spec: specs.EnvironmentSpec,
                 networks: networks_lib.FeedForwardNetwork,
                 rng: networks_lib.PRNGKey,
                 config: ars_config.ARSConfig,
                 iterator: Iterator[reverb.ReplaySample],
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):

        self._config = config
        self._lock = threading.Lock()

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Iterator over replay samples.
        self._iterator = iterator

        if self._config.normalize_observations:
            normalizer_params = running_statistics.init_state(
                spec.observations)
            self._normalizer_update_fn = running_statistics.update
        else:
            normalizer_params = ()
            self._normalizer_update_fn = lambda a, b: a

        rng1, rng2, tmp = jax.random.split(rng, 3)
        # Create initial state.
        self._training_state = TrainingState(
            key=rng1,
            policy_params=networks.init(tmp),
            normalizer_params=normalizer_params,
            training_iteration=0)
        self._evaluation_state = EvaluationState(
            key=rng2,
            evaluation_queue=collections.deque(),
            received_results={},
            noises=[])

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
Example #9
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 obs_spec: specs.Array,
                 optimizer: optax.GradientTransformation,
                 random_key: networks_lib.PRNGKey,
                 dataset: tf.data.Dataset,
                 loss_fn: LossFn = _sparse_categorical_cross_entropy,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the learner."""
        def loss(params: networks_lib.Params,
                 sample: reverb.ReplaySample) -> jnp.ndarray:
            # Pull out the data needed for updates.
            o_tm1, a_tm1, r_t, d_t, o_t = sample.data
            del r_t, d_t, o_t
            logits = network.apply(params, o_tm1)
            return jnp.mean(loss_fn(a_tm1, logits))

        def sgd_step(
            state: TrainingState, sample: reverb.ReplaySample
        ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]:
            """Do a step of SGD."""
            grad_fn = jax.value_and_grad(loss)
            loss_value, gradients = grad_fn(state.params, sample)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            steps = state.steps + 1

            new_state = TrainingState(params=new_params,
                                      opt_state=new_opt_state,
                                      steps=steps)

            # Compute the global norm of the gradients for logging.
            global_gradient_norm = optax.global_norm(gradients)
            fetches = {
                'loss': loss_value,
                'gradient_norm': global_gradient_norm
            }

            return new_state, fetches

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Get an iterator over the dataset.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
        # TODO(b/155086959): Fix type stubs and remove.

        # Initialise parameters and optimiser state.
        initial_params = network.init(
            random_key, utils.add_batch_dim(utils.zeros_like(obs_spec)))
        initial_opt_state = optimizer.init(initial_params)

        self._state = TrainingState(params=initial_params,
                                    opt_state=initial_opt_state,
                                    steps=0)

        self._sgd_step = jax.jit(sgd_step)
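The sgd_step defined above follows the standard JAX/optax pattern: differentiate the loss with jax.value_and_grad, transform the gradients with the optimizer, and apply the updates. A minimal standalone sketch with a toy quadratic loss (names and values are illustrative):

import jax
import jax.numpy as jnp
import optax

def loss(params, x):
  # Toy loss: fit w so that w * x is close to 1.
  return jnp.mean((params['w'] * x - 1.0) ** 2)

params = {'w': jnp.array(0.0)}
optimizer = optax.adam(1e-2)
opt_state = optimizer.init(params)

loss_value, gradients = jax.value_and_grad(loss)(params, jnp.array([1.0, 2.0]))
updates, opt_state = optimizer.update(gradients, opt_state)
params = optax.apply_updates(params, updates)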
Example #10
    def __init__(self,
                 unroll: networks_lib.FeedForwardNetwork,
                 initial_state: networks_lib.FeedForwardNetwork,
                 batch_size: int,
                 random_key: networks_lib.PRNGKey,
                 burn_in_length: int,
                 discount: float,
                 importance_sampling_exponent: float,
                 max_priority_weight: float,
                 target_update_period: int,
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optax.GradientTransformation,
                 bootstrap_n: int = 5,
                 tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR,
                 clip_rewards: bool = False,
                 max_abs_reward: float = 1.,
                 use_core_state: bool = True,
                 prefetch_size: int = 2,
                 replay_client: Optional[reverb.Client] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the learner."""

        random_key, key_initial_1, key_initial_2 = jax.random.split(
            random_key, 3)
        initial_state_params = initial_state.init(key_initial_1, batch_size)
        initial_state = initial_state.apply(initial_state_params,
                                            key_initial_2, batch_size)

        def loss(
            params: networks_lib.Params, target_params: networks_lib.Params,
            key_grad: networks_lib.PRNGKey, sample: reverb.ReplaySample
        ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
            """Computes mean transformed N-step loss for a batch of sequences."""

            # Convert sample data to sequence-major format [T, B, ...].
            data = utils.batch_to_sequence(sample.data)

            # Get core state & warm it up on observations for a burn-in period.
            if use_core_state:
                # Replay core state.
                online_state = jax.tree_map(lambda x: x[0],
                                            data.extras['core_state'])
            else:
                online_state = initial_state
            target_state = online_state

            # Maybe burn the core state in.
            if burn_in_length:
                burn_obs = jax.tree_map(lambda x: x[:burn_in_length],
                                        data.observation)
                key_grad, key1, key2 = jax.random.split(key_grad, 3)
                _, online_state = unroll.apply(params, key1, burn_obs,
                                               online_state)
                _, target_state = unroll.apply(target_params, key2, burn_obs,
                                               target_state)

            # Only get data to learn on from after the end of the burn in period.
            data = jax.tree_map(lambda seq: seq[burn_in_length:], data)

            # Unroll on sequences to get online and target Q-Values.
            key1, key2 = jax.random.split(key_grad)
            online_q, _ = unroll.apply(params, key1, data.observation,
                                       online_state)
            target_q, _ = unroll.apply(target_params, key2, data.observation,
                                       target_state)

            # Get value-selector actions from online Q-values for double Q-learning.
            selector_actions = jnp.argmax(online_q, axis=-1)
            # Preprocess discounts & rewards.
            discounts = (data.discount * discount).astype(online_q.dtype)
            rewards = data.reward
            if clip_rewards:
                rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
            rewards = rewards.astype(online_q.dtype)

            # Get N-step transformed TD error and loss.
            batch_td_error_fn = jax.vmap(functools.partial(
                rlax.transformed_n_step_q_learning,
                n=bootstrap_n,
                tx_pair=tx_pair),
                                         in_axes=1,
                                         out_axes=1)
            # TODO(b/183945808): when this bug is fixed, truncations of actions,
            # rewards, and discounts will no longer be necessary.
            batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1],
                                               target_q[1:],
                                               selector_actions[1:],
                                               rewards[:-1], discounts[:-1])
            batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

            # Importance weighting.
            probs = sample.info.probability
            importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)
            mean_loss = jnp.mean(importance_weights * batch_loss)

            # Calculate priorities as a mixture of max and mean sequence errors.
            abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
            max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
            mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error,
                                                                 axis=0)
            priorities = (max_priority + mean_priority)

            return mean_loss, priorities

        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
            """Performs an update step, averaging over pmap replicas."""

            # Compute loss and gradients.
            grad_fn = jax.value_and_grad(loss, has_aux=True)
            key, key_grad = jax.random.split(state.random_key)
            (loss_value,
             priorities), gradients = grad_fn(state.params,
                                              state.target_params, key_grad,
                                              samples)

            # Average gradients over pmap replicas before optimizer update.
            gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

            # Apply optimizer updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 self._target_update_period)

            new_state = TrainingState(params=new_params,
                                      target_params=target_params,
                                      opt_state=new_opt_state,
                                      steps=steps,
                                      random_key=key)
            return new_state, priorities, {'loss': loss_value}

        def update_priorities(keys_and_priorities: Tuple[jnp.ndarray,
                                                         jnp.ndarray]):
            keys, priorities = keys_and_priorities
            keys, priorities = tree.map_structure(
                # Fetch array and combine device and batch dimensions.
                lambda x: utils.fetch_devicearray(x).reshape(
                    (-1, ) + x.shape[2:]),
                (keys, priorities))
            replay_client.mutate_priorities(  # pytype: disable=attribute-error
                table=adders.DEFAULT_PRIORITY_TABLE,
                updates=dict(zip(keys, priorities)))

        # Internalise components, hyperparameters, logger, counter, and methods.
        self._replay_client = replay_client
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            time_delta=1.)

        self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)

        # Initialise and internalise training state (parameters/optimiser state).
        random_key, key_init = jax.random.split(random_key)
        initial_params = unroll.init(key_init, initial_state)
        opt_state = optimizer.init(initial_params)

        state = TrainingState(params=initial_params,
                              target_params=initial_params,
                              opt_state=opt_state,
                              steps=jnp.array(0),
                              random_key=random_key)
        # Replicate parameters.
        self._state = utils.replicate_in_all_devices(state)

        # Shard multiple inputs with on-device prefetching.
        # We split samples into two outputs: the keys, which need to be kept
        # on-host since int64 arrays are not supported on TPUs, and the entire
        # sample separately so it can be sent to the sgd_step method.
        def split_sample(
                sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
            return utils.PrefetchingSplit(host=sample.info.key, device=sample)

        self._prefetched_iterator = utils.sharded_prefetch(
            iterator,
            buffer_size=prefetch_size,
            num_threads=jax.local_device_count(),
            split_fn=split_sample)
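The replay priorities returned by the loss above are a convex mixture of the maximum and the mean absolute TD error over the sequence dimension. A standalone sketch on a toy [T, B] error array (values are illustrative):

import jax.numpy as jnp

max_priority_weight = 0.9
abs_td_error = jnp.array([[0.1, 2.0],
                          [0.3, 0.5],
                          [0.2, 1.0]])  # [T, B] absolute TD errors.

max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0)
priorities = max_priority + mean_priority  # One priority per sequence in the batch.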
Example #11
  def __init__(self,
               policy_network: networks_lib.FeedForwardNetwork,
               critic_network: networks_lib.FeedForwardNetwork,
               random_key: networks_lib.PRNGKey,
               discount: float,
               target_update_period: int,
               iterator: Iterator[reverb.ReplaySample],
               policy_optimizer: Optional[optax.GradientTransformation] = None,
               critic_optimizer: Optional[optax.GradientTransformation] = None,
               clipping: bool = True,
               counter: Optional[counting.Counter] = None,
               logger: Optional[loggers.Logger] = None,
               jit: bool = True,
               num_sgd_steps_per_step: int = 1):

    def critic_mean(
        critic_params: networks_lib.Params,
        observation: types.NestedArray,
        action: types.NestedArray,
    ) -> jnp.ndarray:
      # We add a batch dimension to make sure the batch concat in
      # critic_network works correctly.
      observation = utils.add_batch_dim(observation)
      action = utils.add_batch_dim(action)
      # Computes the mean action-value estimate.
      logits, atoms = critic_network.apply(critic_params, observation, action)
      logits = utils.squeeze_batch_dim(logits)
      probabilities = jax.nn.softmax(logits)
      return jnp.sum(probabilities * atoms, axis=-1)

    def policy_loss(
        policy_params: networks_lib.Params,
        critic_params: networks_lib.Params,
        o_t: types.NestedArray,
    ) -> jnp.ndarray:
      # Computes the deterministic policy gradient (DPG) loss.
      dpg_a_t = policy_network.apply(policy_params, o_t)
      grad_critic = jax.vmap(
          jax.grad(critic_mean, argnums=2), in_axes=(None, 0, 0))
      dq_da = grad_critic(critic_params, o_t, dpg_a_t)
      dqda_clipping = 1. if clipping else None
      batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0, None))
      loss = batch_dpg_learning(dpg_a_t, dq_da, dqda_clipping)
      return jnp.mean(loss)

    def critic_loss(
        critic_params: networks_lib.Params,
        state: TrainingState,
        transition: types.Transition,
    ):
      # Computes the distributional critic loss.
      q_tm1, atoms_tm1 = critic_network.apply(critic_params,
                                              transition.observation,
                                              transition.action)
      a = policy_network.apply(state.target_policy_params,
                               transition.next_observation)
      q_t, atoms_t = critic_network.apply(state.target_critic_params,
                                          transition.next_observation, a)
      batch_td_learning = jax.vmap(
          rlax.categorical_td_learning, in_axes=(None, 0, 0, 0, None, 0))
      loss = batch_td_learning(atoms_tm1, q_tm1, transition.reward,
                               discount * transition.discount, atoms_t, q_t)
      return jnp.mean(loss)

    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

      # TODO(jaslanides): Use a shared forward pass for efficiency.
      policy_loss_and_grad = jax.value_and_grad(policy_loss)
      critic_loss_and_grad = jax.value_and_grad(critic_loss)

      # Compute losses and their gradients.
      policy_loss_value, policy_gradients = policy_loss_and_grad(
          state.policy_params, state.critic_params,
          transitions.next_observation)
      critic_loss_value, critic_gradients = critic_loss_and_grad(
          state.critic_params, state, transitions)

      # Get optimizer updates and state.
      policy_updates, policy_opt_state = policy_optimizer.update(  # pytype: disable=attribute-error
          policy_gradients, state.policy_opt_state)
      critic_updates, critic_opt_state = critic_optimizer.update(  # pytype: disable=attribute-error
          critic_gradients, state.critic_opt_state)

      # Apply optimizer updates to parameters.
      policy_params = optax.apply_updates(state.policy_params, policy_updates)
      critic_params = optax.apply_updates(state.critic_params, critic_updates)

      steps = state.steps + 1

      # Periodically update target networks.
      target_policy_params, target_critic_params = optax.periodic_update(
          (policy_params, critic_params),
          (state.target_policy_params, state.target_critic_params), steps,
          self._target_update_period)

      new_state = TrainingState(
          policy_params=policy_params,
          critic_params=critic_params,
          target_policy_params=target_policy_params,
          target_critic_params=target_critic_params,
          policy_opt_state=policy_opt_state,
          critic_opt_state=critic_opt_state,
          steps=steps,
      )

      metrics = {
          'policy_loss': policy_loss_value,
          'critic_loss': critic_loss_value,
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray,
        steps_key=self._counter.get_steps_key())

    # Necessary to track when to update target networks.
    self._target_update_period = target_update_period

    # Create prefetching dataset iterator.
    self._iterator = iterator

    # Maybe use the JIT compiler.
    sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step)
    self._sgd_step = jax.jit(sgd_step) if jit else sgd_step

    # Create the network parameters and copy into the target network parameters.
    key_policy, key_critic = jax.random.split(random_key)
    initial_policy_params = policy_network.init(key_policy)
    initial_critic_params = critic_network.init(key_critic)
    initial_target_policy_params = initial_policy_params
    initial_target_critic_params = initial_critic_params

    # Create optimizers if they aren't given.
    critic_optimizer = critic_optimizer or optax.adam(1e-4)
    policy_optimizer = policy_optimizer or optax.adam(1e-4)

    # Initialize optimizers.
    initial_policy_opt_state = policy_optimizer.init(initial_policy_params)  # pytype: disable=attribute-error
    initial_critic_opt_state = critic_optimizer.init(initial_critic_params)  # pytype: disable=attribute-error

    # Create initial state.
    self._state = TrainingState(
        policy_params=initial_policy_params,
        target_policy_params=initial_target_policy_params,
        critic_params=initial_critic_params,
        target_critic_params=initial_target_critic_params,
        policy_opt_state=initial_policy_opt_state,
        critic_opt_state=initial_critic_opt_state,
        steps=0,
    )

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
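The target networks above are refreshed with optax.periodic_update, which copies the online parameters every target_update_period steps and otherwise leaves the targets unchanged. A standalone sketch with toy parameters (names and values are illustrative):

import jax.numpy as jnp
import optax

target_update_period = 100
online_params = {'w': jnp.array(1.5)}
target_params = {'w': jnp.array(0.0)}
steps = jnp.array(200)  # A multiple of the period, so the targets are refreshed.

target_params = optax.periodic_update(online_params, target_params, steps,
                                      target_update_period)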
Example #12
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 random_key: networks_lib.PRNGKey,
                 loss_fn: losses.Loss,
                 optimizer: optax.GradientTransformation,
                 prefetching_iterator: Iterator[types.Transition],
                 num_sgd_steps_per_step: int,
                 loss_has_aux: bool = False,
                 logger: Optional[loggers.Logger] = None,
                 counter: Optional[counting.Counter] = None):
        """Behavior Cloning Learner.

    Args:
      network: Networks with signature for apply: (params, obs, is_training,
        key) -> jnp.ndarray and for init: (rng, is_training) -> params
      random_key: RNG key.
      loss_fn: BC loss to use.
      optimizer: Optax optimizer.
      prefetching_iterator: A sharded prefetching iterator as output by
        `acme.jax.utils.sharded_prefetch`. Please see the documentation for
        `sharded_prefetch` for more details.
      num_sgd_steps_per_step: Number of gradient updates per step.
      loss_has_aux: Whether the loss function returns auxiliary metrics as a
        second argument.
      logger: Logger.
      counter: Counter.
    """
        def sgd_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            loss_and_grad = jax.value_and_grad(loss_fn,
                                               argnums=1,
                                               has_aux=loss_has_aux)

            # Compute losses and their gradients.
            key, key_input = jax.random.split(state.key)
            loss_result, gradients = loss_and_grad(network.apply,
                                                   state.policy_params,
                                                   key_input, transitions)

            # Combine the gradient across all devices (by taking their mean).
            gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME)

            # Compute and combine metrics across all devices.
            metrics = _create_loss_metrics(loss_has_aux, loss_result,
                                           gradients)
            metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME)

            policy_update, optimizer_state = optimizer.update(
                gradients, state.optimizer_state, state.policy_params)
            policy_params = optax.apply_updates(state.policy_params,
                                                policy_update)

            new_state = TrainingState(
                optimizer_state=optimizer_state,
                policy_params=policy_params,
                key=key,
                steps=state.steps + 1,
            )

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter(prefix='learner')
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Split the input batch into `num_sgd_steps_per_step` minibatches in
        # order to achieve better performance on accelerators.
        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  num_sgd_steps_per_step)
        self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)

        random_key, init_key = jax.random.split(random_key)
        policy_params = network.init(init_key)
        optimizer_state = optimizer.init(policy_params)

        # Create initial state.
        state = TrainingState(
            optimizer_state=optimizer_state,
            policy_params=policy_params,
            key=random_key,
            steps=0,
        )
        self._state = utils.replicate_in_all_devices(state)

        self._timestamp = None

        self._prefetching_iterator = prefetching_iterator
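The sgd_step above runs under jax.pmap, so each device computes gradients on its own shard and jax.lax.pmean averages them across devices before the optimizer update. A minimal standalone sketch of that pattern (the axis name and toy data are illustrative):

import jax
import jax.numpy as jnp

_PMAP_AXIS_NAME = 'devices'

def device_mean(x):
  # Average the per-device value across all devices in the pmap.
  return jax.lax.pmean(x, axis_name=_PMAP_AXIS_NAME)

num_devices = jax.local_device_count()
per_device_values = jnp.arange(num_devices, dtype=jnp.float32)
averaged = jax.pmap(device_mean, axis_name=_PMAP_AXIS_NAME)(per_device_values)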
Example #13
  def __init__(self,
               network: networks_lib.FeedForwardNetwork,
               random_key: networks_lib.PRNGKey,
               loss_fn: losses.Loss,
               optimizer: optax.GradientTransformation,
               demonstrations: Iterator[types.Transition],
               num_sgd_steps_per_step: int,
               logger: Optional[loggers.Logger] = None,
               counter: Optional[counting.Counter] = None):
    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

      loss_and_grad = jax.value_and_grad(loss_fn, argnums=1)

      # Compute losses and their gradients.
      key, key_input = jax.random.split(state.key)
      loss_value, gradients = loss_and_grad(network.apply, state.policy_params,
                                            key_input, transitions)

      policy_update, optimizer_state = optimizer.update(gradients,
                                                        state.optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, policy_update)

      new_state = TrainingState(
          optimizer_state=optimizer_state,
          policy_params=policy_params,
          key=key,
          steps=state.steps + 1,
      )
      metrics = {
          'loss': loss_value,
          'gradient_norm': optax.global_norm(gradients)
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter(prefix='learner')
    self._logger = logger or loggers.make_default_logger(
        'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

    # Iterator on demonstration transitions.
    self._demonstrations = demonstrations

    # Split the input batch into `num_sgd_steps_per_step` minibatches in order
    # to achieve better performance on accelerators.
    self._sgd_step = jax.jit(utils.process_multiple_batches(
        sgd_step, num_sgd_steps_per_step))

    random_key, init_key = jax.random.split(random_key)
    policy_params = network.init(init_key)
    optimizer_state = optimizer.init(policy_params)

    # Create initial state.
    self._state = TrainingState(
        optimizer_state=optimizer_state,
        policy_params=policy_params,
        key=random_key,
        steps=0,
    )

    self._timestamp = None