Example #1
    def __init__(
        self,
        networks: impala_networks.IMPALANetworks,
        iterator: Iterator[reverb.ReplaySample],
        optimizer: optax.GradientTransformation,
        random_key: networks_lib.PRNGKey,
        discount: float = 0.99,
        entropy_cost: float = 0.,
        baseline_cost: float = 1.,
        max_abs_reward: float = np.inf,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        devices: Optional[Sequence[jax.xla.Device]] = None,
        prefetch_size: int = 2,
        num_prefetch_threads: Optional[int] = None,
    ):
        local_devices = jax.local_devices()
        process_id = jax.process_index()
        logging.info('Learner process id: %s. Devices passed: %s', process_id,
                     devices)
        logging.info('Learner process id: %s. Local devices from JAX API: %s',
                     process_id, local_devices)
        self._devices = devices or local_devices
        self._local_devices = [d for d in self._devices if d in local_devices]

        loss_fn = losses.impala_loss(networks.unroll_fn,
                                     discount=discount,
                                     max_abs_reward=max_abs_reward,
                                     baseline_cost=baseline_cost,
                                     entropy_cost=entropy_cost)

        @jax.jit
        def sgd_step(
            state: TrainingState, sample: reverb.ReplaySample
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
            """Computes an SGD step, returning new state and metrics for logging."""

            # Compute gradients.
            grad_fn = jax.value_and_grad(loss_fn)
            loss_value, gradients = grad_fn(state.params, sample)

            # Average gradients over pmap replicas before optimizer update.
            gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

            # Apply updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            metrics = {
                'loss': loss_value,
                'param_norm': optax.global_norm(new_params),
                'param_updates_norm': optax.global_norm(updates),
            }

            new_state = TrainingState(params=new_params,
                                      opt_state=new_opt_state)

            return new_state, metrics

        def make_initial_state(key: jnp.ndarray) -> TrainingState:
            """Initialises the training state (parameters and optimiser state)."""
            key, key_initial_state = jax.random.split(key)
            # Note: parameters do not depend on the batch size, so initial_state below
            # does not need a batch dimension.
            params = networks.initial_state_init_fn(key_initial_state)
            # TODO(jferret): as it stands, we do not yet support
            # training the initial state params.
            initial_state = networks.initial_state_fn(params)

            initial_params = networks.unroll_init_fn(key, initial_state)
            initial_opt_state = optimizer.init(initial_params)
            return TrainingState(params=initial_params,
                                 opt_state=initial_opt_state)

        # Initialise training state (parameters and optimiser state).
        state = make_initial_state(random_key)
        self._state = utils.replicate_in_all_devices(state,
                                                     self._local_devices)

        if num_prefetch_threads is None:
            num_prefetch_threads = len(self._local_devices)
        self._prefetched_iterator = utils.sharded_prefetch(
            iterator,
            buffer_size=prefetch_size,
            devices=self._local_devices,
            num_threads=num_prefetch_threads,
        )

        self._sgd_step = jax.pmap(sgd_step,
                                  axis_name=_PMAP_AXIS_NAME,
                                  devices=self._devices)

        # Set up logging/counting.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')
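
Both the `sgd_step` above and the variants below rely on a small `TrainingState` container and on a `step()` method that drives the update loop, neither of which is shown on this page. The sketch below reconstructs them from the attributes the snippet actually uses (`params`, `opt_state`, `self._prefetched_iterator`, `self._sgd_step`); the field names are grounded in the snippet, but the `step()` body is an assumption about how a multi-device learner would typically consume the prefetched, sharded batches.

from typing import NamedTuple

import haiku as hk
import optax

from acme.jax import utils


class TrainingState(NamedTuple):
    """Parameters and optimiser state, matching the fields used in sgd_step."""
    params: hk.Params
    opt_state: optax.OptState


# Sketch of the step() method that would accompany the __init__ above
# (shown outside its class for brevity).
def step(self):
    """Fetches one sharded batch and applies the pmapped SGD update."""
    samples = next(self._prefetched_iterator)

    # self._state stays replicated across the local devices; the pmapped
    # sgd_step returns per-replica metrics.
    self._state, metrics = self._sgd_step(self._state, samples)

    # Log the copy from the first device. Gradients were pmean-ed but the
    # metrics themselves were not, so this is a per-replica estimate.
    metrics = utils.get_from_first_device(metrics)

    counts = self._counter.increment(steps=1)
    self._logger.write({**metrics, **counts})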
Example #2
  def __init__(
      self,
      obs_spec: specs.Array,
      unroll_fn: networks_lib.PolicyValueRNN,
      initial_state_fn: Callable[[], hk.LSTMState],
      iterator: Iterator[reverb.ReplaySample],
      optimizer: optax.GradientTransformation,
      random_key: networks_lib.PRNGKey,
      discount: float = 0.99,
      entropy_cost: float = 0.,
      baseline_cost: float = 1.,
      max_abs_reward: float = np.inf,
      counter: Optional[counting.Counter] = None,
      logger: Optional[loggers.Logger] = None,
      devices: Optional[Sequence[jax.xla.Device]] = None,
      prefetch_size: int = 2,
      num_prefetch_threads: Optional[int] = None,
  ):

    self._devices = devices or jax.local_devices()

    # Transform into pure functions.
    unroll_fn = hk.without_apply_rng(hk.transform(unroll_fn))
    initial_state_fn = hk.without_apply_rng(hk.transform(initial_state_fn))

    loss_fn = losses.impala_loss(
        unroll_fn,
        discount=discount,
        max_abs_reward=max_abs_reward,
        baseline_cost=baseline_cost,
        entropy_cost=entropy_cost)

    @jax.jit
    def sgd_step(
        state: TrainingState, sample: reverb.ReplaySample
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      """Computes an SGD step, returning new state and metrics for logging."""

      # Compute gradients.
      grad_fn = jax.value_and_grad(loss_fn)
      loss_value, gradients = grad_fn(state.params, sample)

      # Average gradients over pmap replicas before optimizer update.
      gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

      # Apply updates.
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      metrics = {
          'loss': loss_value,
      }

      new_state = TrainingState(params=new_params, opt_state=new_opt_state)

      return new_state, metrics

    def make_initial_state(key: jnp.ndarray) -> TrainingState:
      """Initialises the training state (parameters and optimiser state)."""
      dummy_obs = utils.zeros_like(obs_spec)
      dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
      initial_state = initial_state_fn.apply(None)
      initial_params = unroll_fn.init(key, dummy_obs, initial_state)
      initial_opt_state = optimizer.init(initial_params)
      return TrainingState(params=initial_params, opt_state=initial_opt_state)

    # Initialise training state (parameters and optimiser state).
    state = make_initial_state(random_key)
    self._state = utils.replicate_in_all_devices(state, self._devices)

    if num_prefetch_threads is None:
      num_prefetch_threads = len(self._devices)
    self._prefetched_iterator = utils.sharded_prefetch(
        iterator,
        buffer_size=prefetch_size,
        devices=self._devices,
        num_threads=num_prefetch_threads,
    )

    self._sgd_step = jax.pmap(
        sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

    # Set up logging/counting.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')
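
Unlike Example #1, this variant receives untransformed Haiku functions and applies `hk.transform` itself. The sketch below shows one way a compatible `unroll_fn` / `initial_state_fn` pair could be written; the MLP torso, the layer sizes, and the `(logits, values)` output convention are assumptions chosen to match the `unroll_fn.init(key, dummy_obs, initial_state)` call above, not the networks shipped with Acme.

import haiku as hk
import jax.numpy as jnp

NUM_ACTIONS = 6    # assumed size of the discrete action space
HIDDEN_SIZE = 256  # assumed LSTM hidden size


def initial_state_fn() -> hk.LSTMState:
  """Zero LSTM state; the learner wraps this with hk.transform."""
  return hk.LSTM(HIDDEN_SIZE).initial_state(None)


def unroll_fn(observations: jnp.ndarray, state: hk.LSTMState):
  """Unrolls an MLP torso plus LSTM over a [T, ...] observation sequence,
  returning ((policy logits, values), final recurrent state)."""
  embeddings = hk.nets.MLP([128, 128])(observations.astype(jnp.float32))
  outputs, new_state = hk.dynamic_unroll(
      hk.LSTM(HIDDEN_SIZE), embeddings, state)
  logits = hk.Linear(NUM_ACTIONS)(outputs)
  values = jnp.squeeze(hk.Linear(1)(outputs), axis=-1)
  return (logits, values), new_state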
Example #3
    def __init__(
        self,
        obs_spec: specs.Array,
        unroll_fn: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], hk.LSTMState],
        iterator: Iterator[reverb.ReplaySample],
        optimizer: optax.GradientTransformation,
        rng: hk.PRNGSequence,
        discount: float = 0.99,
        entropy_cost: float = 0.,
        baseline_cost: float = 1.,
        max_abs_reward: float = np.inf,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
    ):

        # Transform into pure functions.
        unroll_fn = hk.without_apply_rng(hk.transform(unroll_fn))
        initial_state_fn = hk.without_apply_rng(
            hk.transform(initial_state_fn))

        loss_fn = losses.impala_loss(unroll_fn,
                                     discount=discount,
                                     max_abs_reward=max_abs_reward,
                                     baseline_cost=baseline_cost,
                                     entropy_cost=entropy_cost)

        @jax.jit
        def sgd_step(
            state: TrainingState, sample: reverb.ReplaySample
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
            """Computes an SGD step, returning new state and metrics for logging."""

            # Compute gradients.
            grad_fn = jax.value_and_grad(loss_fn)
            loss_value, gradients = grad_fn(state.params, sample)

            # Apply updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            metrics = {
                'loss': loss_value,
            }

            new_state = TrainingState(params=new_params,
                                      opt_state=new_opt_state)

            return new_state, metrics

        def make_initial_state(key: jnp.ndarray) -> TrainingState:
            """Initialises the training state (parameters and optimiser state)."""
            dummy_obs = utils.zeros_like(obs_spec)
            dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
            initial_state = initial_state_fn.apply(None)
            initial_params = unroll_fn.init(key, dummy_obs, initial_state)
            initial_opt_state = optimizer.init(initial_params)
            return TrainingState(params=initial_params,
                                 opt_state=initial_opt_state)

        # Initialise training state (parameters and optimiser state).
        self._state = make_initial_state(next(rng))

        # Internalise iterator.
        self._iterator = iterator
        self._sgd_step = sgd_step

        # Set up logging/counting.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')
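
Example #3 is the single-device variant: the jitted `sgd_step` is kept as-is and the replay iterator is consumed directly. A hypothetical construction of this learner is sketched below; `IMPALALearner` stands in for the class whose `__init__` is shown above, and the optimiser settings and entropy cost are assumed values, not taken from this page.

import haiku as hk
import optax


def make_learner(obs_spec, unroll_fn, initial_state_fn, iterator, seed=0):
    """Builds the learner above with a typical optax optimiser (assumed values)."""
    optimizer = optax.chain(
        optax.clip_by_global_norm(40.0),  # assumed gradient-clipping threshold
        optax.adam(1e-4),                 # assumed learning rate
    )
    # IMPALALearner is a placeholder name for the class defined above.
    return IMPALALearner(
        obs_spec=obs_spec,
        unroll_fn=unroll_fn,
        initial_state_fn=initial_state_fn,
        iterator=iterator,
        optimizer=optimizer,
        rng=hk.PRNGSequence(seed),
        discount=0.99,
        entropy_cost=0.01,
    )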