Example #1
  def __init__(
      self,
      obs_spec: specs.Array,
      unroll_fn: networks_lib.PolicyValueRNN,
      initial_state_fn: Callable[[], hk.LSTMState],
      iterator: Iterator[reverb.ReplaySample],
      optimizer: optax.GradientTransformation,
      random_key: networks_lib.PRNGKey,
      discount: float = 0.99,
      entropy_cost: float = 0.,
      baseline_cost: float = 1.,
      max_abs_reward: float = np.inf,
      counter: Optional[counting.Counter] = None,
      logger: Optional[loggers.Logger] = None,
      devices: Optional[Sequence[jax.xla.Device]] = None,
      prefetch_size: int = 2,
      num_prefetch_threads: Optional[int] = None,
  ):

    self._devices = devices or jax.local_devices()

    # Transform into pure functions.
    unroll_fn = hk.without_apply_rng(hk.transform(unroll_fn))
    initial_state_fn = hk.without_apply_rng(hk.transform(initial_state_fn))

    loss_fn = losses.impala_loss(
        unroll_fn,
        discount=discount,
        max_abs_reward=max_abs_reward,
        baseline_cost=baseline_cost,
        entropy_cost=entropy_cost)

    @jax.jit
    def sgd_step(
        state: TrainingState, sample: reverb.ReplaySample
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      """Computes an SGD step, returning new state and metrics for logging."""

      # Compute gradients.
      grad_fn = jax.value_and_grad(loss_fn)
      loss_value, gradients = grad_fn(state.params, sample)

      # Average gradients over pmap replicas before optimizer update.
      gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

      # Apply updates.
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      metrics = {
          'loss': loss_value,
      }

      new_state = TrainingState(params=new_params, opt_state=new_opt_state)

      return new_state, metrics

    def make_initial_state(key: jnp.ndarray) -> TrainingState:
      """Initialises the training state (parameters and optimiser state)."""
      dummy_obs = utils.zeros_like(obs_spec)
      dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
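      # The initial-state network is parameter-free (apply is called with
      # params=None below), so no separate init step is needed for it.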
      initial_state = initial_state_fn.apply(None)
      initial_params = unroll_fn.init(key, dummy_obs, initial_state)
      initial_opt_state = optimizer.init(initial_params)
      return TrainingState(params=initial_params, opt_state=initial_opt_state)

    # Initialise training state (parameters and optimiser state).
    state = make_initial_state(random_key)
    self._state = utils.replicate_in_all_devices(state, self._devices)

    if num_prefetch_threads is None:
      num_prefetch_threads = len(self._devices)
    self._prefetched_iterator = utils.sharded_prefetch(
        iterator,
        buffer_size=prefetch_size,
        devices=self._devices,
        num_threads=num_prefetch_threads,
    )

    self._sgd_step = jax.pmap(
        sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

    # Set up logging/counting.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')
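
For reference, a learner like this is typically driven by a `step` method that pulls the next sharded batch from the prefetched iterator and applies the pmapped update. The following is only a minimal sketch, not part of the example above: it assumes `import time` at module level, that `acme.jax.utils` provides `get_from_first_device` for picking one replica of the per-device metrics, and that the class keeps the attributes set up in `__init__`.

  def step(self):
    """Sketch: one SGD step over a prefetched, device-sharded batch."""
    samples = next(self._prefetched_iterator)

    # Run the pmapped update; state and samples are already placed on-device.
    start = time.time()
    self._state, metrics = self._sgd_step(self._state, samples)

    # Metrics are replicated per device; log the first replica only.
    metrics = utils.get_from_first_device(metrics)

    counts = self._counter.increment(steps=1, time_elapsed=time.time() - start)
    self._logger.write({**metrics, **counts})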
Example #2
  def restore(self, state: TrainingState):
    self._state = utils.replicate_in_all_devices(state, self._devices)
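
A natural counterpart to `restore` is a `save` that returns a single, de-replicated copy of the training state. This is a sketch under the assumption that `acme.jax.utils` exposes `get_from_first_device`:

  def save(self) -> TrainingState:
    # All replicas hold the same state, so the first copy is sufficient.
    return utils.get_from_first_device(self._state)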
Example #3
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 random_key: networks_lib.PRNGKey,
                 loss_fn: losses.Loss,
                 optimizer: optax.GradientTransformation,
                 prefetching_iterator: Iterator[types.Transition],
                 num_sgd_steps_per_step: int,
                 loss_has_aux: bool = False,
                 logger: Optional[loggers.Logger] = None,
                 counter: Optional[counting.Counter] = None):
        """Behavior Cloning Learner.

    Args:
      network: Networks with signature for apply: (params, obs, is_training,
        key) -> jnp.ndarray and for init: (rng, is_training) -> params
      random_key: RNG key.
      loss_fn: BC loss to use.
      optimizer: Optax optimizer.
      prefetching_iterator: A sharded prefetching iterator as outputted from
        `acme.jax.utils.sharded_prefetch`. Please see the documentation for
        `sharded_prefetch` for more details.
      num_sgd_steps_per_step: Number of gradient updates per step.
      loss_has_aux: Whether the loss function returns auxiliary metrics as a
        second argument.
      logger: Logger.
      counter: Counter.
    """
        def sgd_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            # Differentiate with respect to the parameters, i.e. the second
            # positional argument of loss_fn (argnums=1); has_aux lets the loss
            # optionally return auxiliary metrics alongside its value.
            loss_and_grad = jax.value_and_grad(loss_fn,
                                               argnums=1,
                                               has_aux=loss_has_aux)

            # Compute losses and their gradients.
            key, key_input = jax.random.split(state.key)
            loss_result, gradients = loss_and_grad(network.apply,
                                                   state.policy_params,
                                                   key_input, transitions)

            # Combine the gradient across all devices (by taking their mean).
            gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME)

            # Compute and combine metrics across all devices.
            metrics = _create_loss_metrics(loss_has_aux, loss_result,
                                           gradients)
            metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME)

            policy_update, optimizer_state = optimizer.update(
                gradients, state.optimizer_state, state.policy_params)
            policy_params = optax.apply_updates(state.policy_params,
                                                policy_update)

            new_state = TrainingState(
                optimizer_state=optimizer_state,
                policy_params=policy_params,
                key=key,
                steps=state.steps + 1,
            )

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter(prefix='learner')
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Split the input batch to `num_sgd_steps_per_step` minibatches in order
        # to achieve better performance on accelerators.
        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  num_sgd_steps_per_step)
        self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)

        random_key, init_key = jax.random.split(random_key)
        policy_params = network.init(init_key)
        optimizer_state = optimizer.init(policy_params)

        # Create initial state.
        state = TrainingState(
            optimizer_state=optimizer_state,
            policy_params=policy_params,
            key=random_key,
            steps=0,
        )
        self._state = utils.replicate_in_all_devices(state)

        self._timestamp = None

        self._prefetching_iterator = prefetching_iterator
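
For context, a `step` method for this learner would drain `self._prefetching_iterator` and call the pmapped `self._sgd_step`. The sketch below is illustrative only; it assumes `import time`, `utils.get_from_first_device` from `acme.jax.utils`, and the attributes created in `__init__` above.

    def step(self):
        """Sketch: one learner step over num_sgd_steps_per_step folded minibatches."""
        # The prefetching iterator already yields device-sharded transition batches.
        transitions = next(self._prefetching_iterator)

        self._state, metrics = self._sgd_step(self._state, transitions)
        # Metrics were pmean-ed inside sgd_step, so any single replica will do.
        metrics = utils.get_from_first_device(metrics)

        # Wall-clock bookkeeping via the self._timestamp attribute set in __init__.
        timestamp = time.time()
        elapsed = timestamp - self._timestamp if self._timestamp else 0
        self._timestamp = timestamp

        counts = self._counter.increment(steps=1, walltime=elapsed)
        self._logger.write({**metrics, **counts})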
Example #4
    def __init__(self,
                 unroll: networks_lib.FeedForwardNetwork,
                 initial_state: networks_lib.FeedForwardNetwork,
                 batch_size: int,
                 random_key: networks_lib.PRNGKey,
                 burn_in_length: int,
                 discount: float,
                 importance_sampling_exponent: float,
                 max_priority_weight: float,
                 target_update_period: int,
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optax.GradientTransformation,
                 bootstrap_n: int = 5,
                 tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR,
                 clip_rewards: bool = False,
                 max_abs_reward: float = 1.,
                 use_core_state: bool = True,
                 prefetch_size: int = 2,
                 replay_client: Optional[reverb.Client] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the learner."""

        random_key, key_initial_1, key_initial_2 = jax.random.split(
            random_key, 3)
        initial_state_params = initial_state.init(key_initial_1, batch_size)
        initial_state = initial_state.apply(initial_state_params,
                                            key_initial_2, batch_size)

        def loss(
            params: networks_lib.Params, target_params: networks_lib.Params,
            key_grad: networks_lib.PRNGKey, sample: reverb.ReplaySample
        ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
            """Computes mean transformed N-step loss for a batch of sequences."""

            # Convert sample data to sequence-major format [T, B, ...].
            data = utils.batch_to_sequence(sample.data)

            # Get core state & warm it up on observations for a burn-in period.
            if use_core_state:
                # Replay core state.
                online_state = jax.tree_util.tree_map(
                    lambda x: x[0], data.extras['core_state'])
            else:
                online_state = initial_state
            target_state = online_state

            # Maybe burn the core state in.
            if burn_in_length:
                burn_obs = jax.tree_util.tree_map(
                    lambda x: x[:burn_in_length], data.observation)
                key_grad, key1, key2 = jax.random.split(key_grad, 3)
                _, online_state = unroll.apply(params, key1, burn_obs,
                                               online_state)
                _, target_state = unroll.apply(target_params, key2, burn_obs,
                                               target_state)

            # Only get data to learn on from after the end of the burn in period.
            data = jax.tree_util.tree_map(lambda seq: seq[burn_in_length:],
                                          data)

            # Unroll on sequences to get online and target Q-Values.
            key1, key2 = jax.random.split(key_grad)
            online_q, _ = unroll.apply(params, key1, data.observation,
                                       online_state)
            target_q, _ = unroll.apply(target_params, key2, data.observation,
                                       target_state)

            # Get value-selector actions from online Q-values for double Q-learning.
            selector_actions = jnp.argmax(online_q, axis=-1)
            # Preprocess discounts & rewards.
            discounts = (data.discount * discount).astype(online_q.dtype)
            rewards = data.reward
            if clip_rewards:
                rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
            rewards = rewards.astype(online_q.dtype)

            # Get N-step transformed TD error and loss.
            batch_td_error_fn = jax.vmap(
                functools.partial(
                    rlax.transformed_n_step_q_learning,
                    n=bootstrap_n,
                    tx_pair=tx_pair),
                in_axes=1,
                out_axes=1)
            # TODO(b/183945808): when this bug is fixed, truncations of actions,
            # rewards, and discounts will no longer be necessary.
            batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1],
                                               target_q[1:],
                                               selector_actions[1:],
                                               rewards[:-1], discounts[:-1])
            batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

            # Importance weighting.
            probs = sample.info.probability
            importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)
            mean_loss = jnp.mean(importance_weights * batch_loss)

            # Calculate priorities as a mixture of max and mean sequence errors.
            abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
            max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
            mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error,
                                                                 axis=0)
            priorities = (max_priority + mean_priority)

            return mean_loss, priorities

        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
            """Performs an update step, averaging over pmap replicas."""

            # Compute loss and gradients.
            grad_fn = jax.value_and_grad(loss, has_aux=True)
            key, key_grad = jax.random.split(state.random_key)
            (loss_value,
             priorities), gradients = grad_fn(state.params,
                                              state.target_params, key_grad,
                                              samples)

            # Average gradients over pmap replicas before optimizer update.
            gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

            # Apply optimizer updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 self._target_update_period)

            new_state = TrainingState(params=new_params,
                                      target_params=target_params,
                                      opt_state=new_opt_state,
                                      steps=steps,
                                      random_key=key)
            return new_state, priorities, {'loss': loss_value}

        def update_priorities(keys_and_priorities: Tuple[jnp.ndarray,
                                                         jnp.ndarray]):
            keys, priorities = keys_and_priorities
            keys, priorities = tree.map_structure(
                # Fetch array and combine device and batch dimensions.
                lambda x: utils.fetch_devicearray(x).reshape(
                    (-1, ) + x.shape[2:]),
                (keys, priorities))
            replay_client.mutate_priorities(  # pytype: disable=attribute-error
                table=adders.DEFAULT_PRIORITY_TABLE,
                updates=dict(zip(keys, priorities)))

        # Internalise components, hyperparameters, logger, counter, and methods.
        self._replay_client = replay_client
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            time_delta=1.)

        self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)

        # Initialise and internalise training state (parameters/optimiser state).
        random_key, key_init = jax.random.split(random_key)
        initial_params = unroll.init(key_init, initial_state)
        opt_state = optimizer.init(initial_params)

        state = TrainingState(params=initial_params,
                              target_params=initial_params,
                              opt_state=opt_state,
                              steps=jnp.array(0),
                              random_key=random_key)
        # Replicate parameters.
        self._state = utils.replicate_in_all_devices(state)

        # Shard multiple inputs with on-device prefetching.
        # We split samples in two outputs, the keys which need to be kept on-host
        # since int64 arrays are not supported in TPUs, and the entire sample
        # separately so it can be sent to the sgd_step method.
        def split_sample(
                sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
            return utils.PrefetchingSplit(host=sample.info.key, device=sample)

        self._prefetched_iterator = utils.sharded_prefetch(
            iterator,
            buffer_size=prefetch_size,
            num_threads=jax.local_device_count(),
            split_fn=split_sample)
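
The `split_sample` helper keeps replay keys on host (int64 is unsupported on TPU) while the full sample is prefetched on-device; a consuming `step` method might look like the sketch below. This is not part of the original code: it assumes `import time`, `utils.get_from_first_device`, and that `async_utils.AsyncExecutor` exposes a `put` method for enqueueing work.

    def step(self):
        """Sketch: one update step plus an asynchronous priority update."""
        prefetching_split = next(self._prefetched_iterator)
        # Host part: replay keys; device part: the full, sharded ReplaySample.
        keys = prefetching_split.host
        samples: reverb.ReplaySample = prefetching_split.device

        start = time.time()
        self._state, priorities, metrics = self._sgd_step(self._state, samples)
        metrics = utils.get_from_first_device(metrics)

        counts = self._counter.increment(steps=1, time_elapsed=time.time() - start)

        # Push new priorities back to replay without blocking the learner.
        if self._replay_client:
            self._async_priority_updater.put((keys, priorities))

        self._logger.write({**metrics, **counts})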
Example #5
    def __init__(
        self,
        networks: impala_networks.IMPALANetworks,
        iterator: Iterator[reverb.ReplaySample],
        optimizer: optax.GradientTransformation,
        random_key: networks_lib.PRNGKey,
        discount: float = 0.99,
        entropy_cost: float = 0.,
        baseline_cost: float = 1.,
        max_abs_reward: float = np.inf,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        devices: Optional[Sequence[jax.xla.Device]] = None,
        prefetch_size: int = 2,
        num_prefetch_threads: Optional[int] = None,
    ):
        local_devices = jax.local_devices()
        process_id = jax.process_index()
        logging.info('Learner process id: %s. Devices passed: %s', process_id,
                     devices)
        logging.info('Learner process id: %s. Local devices from JAX API: %s',
                     process_id, local_devices)
        self._devices = devices or local_devices
        self._local_devices = [d for d in self._devices if d in local_devices]

        loss_fn = losses.impala_loss(networks.unroll_fn,
                                     discount=discount,
                                     max_abs_reward=max_abs_reward,
                                     baseline_cost=baseline_cost,
                                     entropy_cost=entropy_cost)

        @jax.jit
        def sgd_step(
            state: TrainingState, sample: reverb.ReplaySample
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
            """Computes an SGD step, returning new state and metrics for logging."""

            # Compute gradients.
            grad_fn = jax.value_and_grad(loss_fn)
            loss_value, gradients = grad_fn(state.params, sample)

            # Average gradients over pmap replicas before optimizer update.
            gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

            # Apply updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            metrics = {
                'loss': loss_value,
                'param_norm': optax.global_norm(new_params),
                'param_updates_norm': optax.global_norm(updates),
            }

            new_state = TrainingState(params=new_params,
                                      opt_state=new_opt_state)

            return new_state, metrics

        def make_initial_state(key: jnp.ndarray) -> TrainingState:
            """Initialises the training state (parameters and optimiser state)."""
            key, key_initial_state = jax.random.split(key)
            # Note: parameters do not depend on the batch size, so initial_state below
            # does not need a batch dimension.
            params = networks.initial_state_init_fn(key_initial_state)
            # TODO(jferret): as it stands, we do not yet support
            # training the initial state params.
            initial_state = networks.initial_state_fn(params)

            initial_params = networks.unroll_init_fn(key, initial_state)
            initial_opt_state = optimizer.init(initial_params)
            return TrainingState(params=initial_params,
                                 opt_state=initial_opt_state)

        # Initialise training state (parameters and optimiser state).
        state = make_initial_state(random_key)
        self._state = utils.replicate_in_all_devices(state,
                                                     self._local_devices)

        if num_prefetch_threads is None:
            num_prefetch_threads = len(self._local_devices)
        self._prefetched_iterator = utils.sharded_prefetch(
            iterator,
            buffer_size=prefetch_size,
            devices=self._local_devices,
            num_threads=num_prefetch_threads,
        )

        self._sgd_step = jax.pmap(sgd_step,
                                  axis_name=_PMAP_AXIS_NAME,
                                  devices=self._devices)

        # Set up logging/counting.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')
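
Actors typically read parameters from a learner like this through a variable-source accessor; a minimal sketch of one is below. It is an assumption, not part of the example: it presumes `List` and `Sequence` are imported from `typing` and that `acme.jax.utils` provides `get_from_first_device`.

    def get_variables(self, names: Sequence[str]) -> List[networks_lib.Params]:
        """Sketch: expose current network parameters, de-replicated."""
        # Parameters are kept in sync across replicas, so one copy suffices.
        return [utils.get_from_first_device(self._state.params)]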