def __init__(
    self,
    networks: impala_networks.IMPALANetworks,
    iterator: Iterator[reverb.ReplaySample],
    optimizer: optax.GradientTransformation,
    random_key: networks_lib.PRNGKey,
    discount: float = 0.99,
    entropy_cost: float = 0.,
    baseline_cost: float = 1.,
    max_abs_reward: float = np.inf,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    devices: Optional[Sequence[jax.xla.Device]] = None,
    prefetch_size: int = 2,
    num_prefetch_threads: Optional[int] = None,
):
  local_devices = jax.local_devices()
  process_id = jax.process_index()
  logging.info('Learner process id: %s. Devices passed: %s', process_id,
               devices)
  logging.info('Learner process id: %s. Local devices from JAX API: %s',
               process_id, local_devices)
  self._devices = devices or local_devices
  self._local_devices = [d for d in self._devices if d in local_devices]

  loss_fn = losses.impala_loss(
      networks.unroll_fn,
      discount=discount,
      max_abs_reward=max_abs_reward,
      baseline_cost=baseline_cost,
      entropy_cost=entropy_cost)

  @jax.jit
  def sgd_step(
      state: TrainingState, sample: reverb.ReplaySample
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    """Computes an SGD step, returning new state and metrics for logging."""

    # Compute gradients.
    grad_fn = jax.value_and_grad(loss_fn)
    loss_value, gradients = grad_fn(state.params, sample)

    # Average gradients over pmap replicas before optimizer update.
    gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

    # Apply updates.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    metrics = {
        'loss': loss_value,
        'param_norm': optax.global_norm(new_params),
        'param_updates_norm': optax.global_norm(updates),
    }

    new_state = TrainingState(params=new_params, opt_state=new_opt_state)

    return new_state, metrics

  def make_initial_state(key: jnp.ndarray) -> TrainingState:
    """Initialises the training state (parameters and optimiser state)."""
    key, key_initial_state = jax.random.split(key)
    # Note: parameters do not depend on the batch size, so initial_state below
    # does not need a batch dimension.
    params = networks.initial_state_init_fn(key_initial_state)
    # TODO(jferret): as it stands, we do not yet support
    # training the initial state params.
    initial_state = networks.initial_state_fn(params)

    initial_params = networks.unroll_init_fn(key, initial_state)
    initial_opt_state = optimizer.init(initial_params)
    return TrainingState(params=initial_params, opt_state=initial_opt_state)

  # Initialise training state (parameters and optimiser state).
  state = make_initial_state(random_key)
  self._state = utils.replicate_in_all_devices(state, self._local_devices)

  if num_prefetch_threads is None:
    num_prefetch_threads = len(self._local_devices)
  self._prefetched_iterator = utils.sharded_prefetch(
      iterator,
      buffer_size=prefetch_size,
      devices=self._local_devices,
      num_threads=num_prefetch_threads,
  )

  self._sgd_step = jax.pmap(
      sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

  # Set up logging/counting.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')
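
# --- Hedged usage sketch (not part of the original class) -------------------
# One way a `step()` method might consume the pmapped `sgd_step` and the
# sharded prefetching iterator built in the constructor above. The flow
# (next batch -> pmapped update -> un-replicate metrics -> count/log) follows
# from how `jax.pmap` and `utils.sharded_prefetch` are wired up in __init__,
# but the exact metric handling below is an assumption, not the library's
# definitive implementation.
def step(self):
  """Runs one SGD step on a batch of prefetched, device-sharded samples."""
  samples = next(self._prefetched_iterator)

  # Both `self._state` and `samples` already live on the local devices, so the
  # pmapped update runs on every replica in lock-step.
  self._state, metrics = self._sgd_step(self._state, samples)

  # Metrics come back with a leading device axis; keep the copy from the first
  # device for logging.
  metrics = jax.tree_util.tree_map(lambda x: x[0], metrics)

  counts = self._counter.increment(steps=1)
  self._logger.write({**metrics, **counts})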
def __init__(
    self,
    obs_spec: specs.Array,
    unroll_fn: networks_lib.PolicyValueRNN,
    initial_state_fn: Callable[[], hk.LSTMState],
    iterator: Iterator[reverb.ReplaySample],
    optimizer: optax.GradientTransformation,
    random_key: networks_lib.PRNGKey,
    discount: float = 0.99,
    entropy_cost: float = 0.,
    baseline_cost: float = 1.,
    max_abs_reward: float = np.inf,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    devices: Optional[Sequence[jax.xla.Device]] = None,
    prefetch_size: int = 2,
    num_prefetch_threads: Optional[int] = None,
):
  self._devices = devices or jax.local_devices()

  # Transform into pure functions.
  unroll_fn = hk.without_apply_rng(hk.transform(unroll_fn, apply_rng=True))
  initial_state_fn = hk.without_apply_rng(
      hk.transform(initial_state_fn, apply_rng=True))

  loss_fn = losses.impala_loss(
      unroll_fn,
      discount=discount,
      max_abs_reward=max_abs_reward,
      baseline_cost=baseline_cost,
      entropy_cost=entropy_cost)

  @jax.jit
  def sgd_step(
      state: TrainingState, sample: reverb.ReplaySample
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    """Computes an SGD step, returning new state and metrics for logging."""

    # Compute gradients.
    grad_fn = jax.value_and_grad(loss_fn)
    loss_value, gradients = grad_fn(state.params, sample)

    # Average gradients over pmap replicas before optimizer update.
    gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

    # Apply updates.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    metrics = {
        'loss': loss_value,
    }

    new_state = TrainingState(params=new_params, opt_state=new_opt_state)

    return new_state, metrics

  def make_initial_state(key: jnp.ndarray) -> TrainingState:
    """Initialises the training state (parameters and optimiser state)."""
    dummy_obs = utils.zeros_like(obs_spec)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    initial_state = initial_state_fn.apply(None)
    initial_params = unroll_fn.init(key, dummy_obs, initial_state)
    initial_opt_state = optimizer.init(initial_params)
    return TrainingState(params=initial_params, opt_state=initial_opt_state)

  # Initialise training state (parameters and optimiser state).
  state = make_initial_state(random_key)
  self._state = utils.replicate_in_all_devices(state, self._devices)

  if num_prefetch_threads is None:
    num_prefetch_threads = len(self._devices)
  self._prefetched_iterator = utils.sharded_prefetch(
      iterator,
      buffer_size=prefetch_size,
      devices=self._devices,
      num_threads=num_prefetch_threads,
  )

  self._sgd_step = jax.pmap(
      sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

  # Set up logging/counting.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')
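
# --- Hedged sketch (illustrative only) ---------------------------------------
# A minimal `unroll_fn` / `initial_state_fn` pair with the shape the constructor
# above expects: `initial_state_fn()` returns an hk.LSTMState, and
# `unroll_fn(observations, state)` returns ((logits, values), new_state) for a
# sequence of observations. The architecture, layer sizes and `num_actions` are
# assumptions (flat float observations of shape [T, obs_dim] are assumed), and
# the sketch reuses this module's existing `hk` / `jnp` aliases; it is not the
# network the library ships with.
num_actions = 6  # Illustrative action count.

def initial_state_fn() -> hk.LSTMState:
  # Unbatched recurrent state, matching `initial_state_fn.apply(None)` above.
  return hk.LSTM(128).initial_state(batch_size=None)

def unroll_fn(observations, state):
  torso = hk.nets.MLP([64, 64])
  core = hk.LSTM(128)
  policy_head = hk.Linear(num_actions)
  value_head = hk.Linear(1)

  embeddings = torso(observations)                          # [T, obs_dim] -> [T, 64]
  core_outputs, new_state = hk.static_unroll(core, embeddings, state)
  logits = policy_head(core_outputs)                        # [T, num_actions]
  values = jnp.squeeze(value_head(core_outputs), axis=-1)   # [T]
  return (logits, values), new_state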
def __init__(
    self,
    obs_spec: specs.Array,
    unroll_fn: networks.PolicyValueRNN,
    initial_state_fn: Callable[[], hk.LSTMState],
    iterator: Iterator[reverb.ReplaySample],
    optimizer: optax.GradientTransformation,
    rng: hk.PRNGSequence,
    discount: float = 0.99,
    entropy_cost: float = 0.,
    baseline_cost: float = 1.,
    max_abs_reward: float = np.inf,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
):
  # Transform into pure functions.
  unroll_fn = hk.without_apply_rng(hk.transform(unroll_fn, apply_rng=True))
  initial_state_fn = hk.without_apply_rng(
      hk.transform(initial_state_fn, apply_rng=True))

  loss_fn = losses.impala_loss(
      unroll_fn,
      discount=discount,
      max_abs_reward=max_abs_reward,
      baseline_cost=baseline_cost,
      entropy_cost=entropy_cost)

  @jax.jit
  def sgd_step(
      state: TrainingState, sample: reverb.ReplaySample
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    """Computes an SGD step, returning new state and metrics for logging."""

    # Compute gradients.
    grad_fn = jax.value_and_grad(loss_fn)
    loss_value, gradients = grad_fn(state.params, sample)

    # Apply updates.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    metrics = {
        'loss': loss_value,
    }

    new_state = TrainingState(params=new_params, opt_state=new_opt_state)

    return new_state, metrics

  def make_initial_state(key: jnp.ndarray) -> TrainingState:
    """Initialises the training state (parameters and optimiser state)."""
    dummy_obs = utils.zeros_like(obs_spec)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    initial_state = initial_state_fn.apply(None)
    initial_params = unroll_fn.init(key, dummy_obs, initial_state)
    initial_opt_state = optimizer.init(initial_params)
    return TrainingState(params=initial_params, opt_state=initial_opt_state)

  # Initialise training state (parameters and optimiser state).
  self._state = make_initial_state(next(rng))

  # Internalise iterator.
  self._iterator = iterator
  self._sgd_step = sgd_step

  # Set up logging/counting.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')
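
# --- Hedged usage sketch (illustrative only) ---------------------------------
# How this single-device learner might be constructed and driven by hand.
# `IMPALALearner`, `environment_spec`, `make_unroll_fn`, `make_initial_state_fn`,
# `replay_sample_iterator` and `num_learner_steps` are placeholders assumed to
# exist in the surrounding project; only the keyword arguments mirror the
# signature above, and the optimiser choice is an example, not a recommendation.
learner = IMPALALearner(
    obs_spec=environment_spec.observations,
    unroll_fn=make_unroll_fn(environment_spec.actions.num_values),
    initial_state_fn=make_initial_state_fn(),
    iterator=replay_sample_iterator,
    optimizer=optax.chain(
        optax.clip_by_global_norm(40.0),
        optax.adam(1e-4),
    ),
    rng=hk.PRNGSequence(42),
    discount=0.99,
    entropy_cost=0.01,
)

# This version keeps a plain (non-prefetched) iterator and a jitted,
# single-device `sgd_step`, so training is just a loop over learner steps
# (assuming the class also exposes a `step()` method that pulls from
# `self._iterator`).
for _ in range(num_learner_steps):
  learner.step()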