def __init__(self, to: base.Logger):
  """Initializes the logger.

  Args:
    to: A `Logger` object to which the current object will forward its
      results when `write` is called.
  """
  self._to = to
  self._async_worker = async_utils.AsyncExecutor(self._to.write, queue_size=5)
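
# A minimal, self-contained sketch (not Acme code) of the asynchronous
# forwarding pattern used above: `async_utils.AsyncExecutor` is assumed to run
# the wrapped `write` on a background thread behind a bounded queue. The class
# and names below are illustrative only.
import queue
import threading


class _TinyAsyncExecutor:
  """Runs `fn(item)` on a background thread fed by a bounded queue."""

  def __init__(self, fn, queue_size: int = 5):
    self._fn = fn
    self._queue = queue.Queue(maxsize=queue_size)
    threading.Thread(target=self._run, daemon=True).start()

  def _run(self):
    while True:
      self._fn(self._queue.get())

  def put(self, item):
    # Blocks when the queue is full, providing backpressure to the caller.
    self._queue.put(item)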
def __init__(self,
             network: networks.QNetwork,
             obs_spec: specs.Array,
             discount: float,
             importance_sampling_exponent: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optix.InitUpdate,
             rng: hk.PRNGSequence,
             max_abs_reward: float = 1.,
             huber_loss_parameter: float = 1.,
             replay_client: reverb.Client = None,
             counter: counting.Counter = None,
             logger: loggers.Logger = None):
  """Initializes the learner."""

  # Transform network into a pure function.
  network = hk.transform(network)

  def loss(params: hk.Params, target_params: hk.Params,
           sample: reverb.ReplaySample):
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data
    keys, probs = sample.info[:2]

    # Forward pass.
    q_tm1 = network.apply(params, o_tm1)
    q_t_value = network.apply(target_params, o_t)
    q_t_selector = network.apply(params, o_t)

    # Cast and clip rewards.
    d_t = (d_t * discount).astype(jnp.float32)
    r_t = jnp.clip(r_t, -max_abs_reward, max_abs_reward).astype(jnp.float32)

    # Compute double Q-learning n-step TD-error.
    batch_error = jax.vmap(rlax.double_q_learning)
    td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
    batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

    # Importance weighting.
    importance_weights = (1. / probs).astype(jnp.float32)
    importance_weights **= importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)

    # Reweight.
    mean_loss = jnp.mean(importance_weights * batch_loss)  # []

    priorities = jnp.abs(td_error).astype(jnp.float64)

    return mean_loss, (keys, priorities)

  def sgd_step(
      state: TrainingState,
      samples: reverb.ReplaySample
  ) -> Tuple[TrainingState, LearnerOutputs]:
    grad_fn = jax.grad(loss, has_aux=True)
    gradients, (keys, priorities) = grad_fn(state.params, state.target_params,
                                            samples)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)

    new_state = TrainingState(
        params=new_params,
        target_params=state.target_params,
        opt_state=new_opt_state,
        step=state.step + 1)

    outputs = LearnerOutputs(keys=keys, priorities=priorities)

    return new_state, outputs

  def update_priorities(outputs: LearnerOutputs):
    for key, priority in zip(outputs.keys, outputs.priorities):
      replay_client.mutate_priorities(
          table=adders.DEFAULT_PRIORITY_TABLE,
          updates={key: priority})

  # Internalise agent components (replay buffer, networks, optimizer).
  self._replay_client = replay_client
  self._iterator = utils.prefetch(iterator)

  # Internalise the hyperparameters.
  self._target_update_period = target_update_period

  # Internalise logging/counting objects.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Initialise parameters and optimiser state.
  initial_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_target_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_opt_state = optimizer.init(initial_params)

  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=initial_opt_state,
      step=0)

  self._forward = jax.jit(network.apply)
  self._sgd_step = jax.jit(sgd_step)
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)
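
# A standalone sketch (not learner code) of how the vmapped
# rlax.double_q_learning call above consumes a batch of transitions. The
# per-transition op is lifted to a batched one with jax.vmap; shapes and
# values below are illustrative only.
import jax
import jax.numpy as jnp
import rlax

batch_size, num_actions = 4, 3
q_tm1 = jnp.zeros((batch_size, num_actions))          # Online Q at t-1.
a_tm1 = jnp.zeros((batch_size,), dtype=jnp.int32)     # Actions taken at t-1.
r_t = jnp.ones((batch_size,))                         # Rewards.
d_t = jnp.full((batch_size,), 0.99)                   # Discounts.
q_t_value = jnp.zeros((batch_size, num_actions))      # Target Q at t.
q_t_selector = jnp.zeros((batch_size, num_actions))   # Online Q at t.

td_error = jax.vmap(rlax.double_q_learning)(
    q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)  # Shape [batch_size].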
def __init__(self,
             network: networks_lib.FeedForwardNetwork,
             loss_fn: LossFn,
             optimizer: optax.GradientTransformation,
             data_iterator: Iterator[reverb.ReplaySample],
             target_update_period: int,
             random_key: networks_lib.PRNGKey,
             replay_client: Optional[reverb.Client] = None,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None,
             num_sgd_steps_per_step: int = 1):
  """Initialize the SGD learner."""
  self.network = network

  # Internalize the loss_fn with network.
  self._loss = jax.jit(functools.partial(loss_fn, self.network))

  # SGD performs the loss, optimizer update and periodic target net update.
  def sgd_step(
      state: TrainingState,
      batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
    next_rng_key, rng_key = jax.random.split(state.rng_key)
    # Implements one SGD step of the loss and updates training state.
    (loss, extra), grads = jax.value_and_grad(self._loss, has_aux=True)(
        state.params, state.target_params, batch, rng_key)
    extra.metrics.update({'total_loss': loss})

    # Apply the optimizer updates.
    updates, new_opt_state = optimizer.update(grads, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    # Periodically update target networks.
    steps = state.steps + 1
    target_params = rlax.periodic_update(new_params, state.target_params,
                                         steps, target_update_period)

    new_training_state = TrainingState(new_params, target_params,
                                       new_opt_state, steps, next_rng_key)
    return new_training_state, extra

  def postprocess_aux(extra: LossExtra) -> LossExtra:
    reverb_update = jax.tree_map(
        lambda a: jnp.reshape(a, (-1, *a.shape[2:])), extra.reverb_update)
    return extra._replace(
        metrics=jax.tree_map(jnp.mean, extra.metrics),
        reverb_update=reverb_update)

  self._num_sgd_steps_per_step = num_sgd_steps_per_step
  sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step,
                                            postprocess_aux)
  self._sgd_step = jax.jit(sgd_step)

  # Internalise agent components.
  self._data_iterator = utils.prefetch(data_iterator)
  self._target_update_period = target_update_period
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None

  # Initialize the network parameters.
  key_params, key_target, key_state = jax.random.split(random_key, 3)
  initial_params = self.network.init(key_params)
  initial_target_params = self.network.init(key_target)
  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=optimizer.init(initial_params),
      steps=0,
      rng_key=key_state,
  )

  # Update replay priorities.
  def update_priorities(reverb_update: ReverbUpdate) -> None:
    if replay_client is None:
      return
    keys, priorities = tree.map_structure(
        utils.fetch_devicearray,
        (reverb_update.keys, reverb_update.priorities))
    replay_client.mutate_priorities(
        table=replay_table_name,
        updates=dict(zip(keys, priorities)))

  self._replay_client = replay_client
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)
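
# A toy illustration (not learner code) of the reshape in `postprocess_aux`
# above: when `num_sgd_steps_per_step > 1`, per-step outputs are stacked as
# [num_sgd_steps, batch, ...] and must be flattened back to a single batch
# dimension before priorities are sent to replay. Values are illustrative.
import jax.numpy as jnp

stacked = jnp.arange(2 * 3).reshape(2, 3)            # [num_sgd_steps=2, batch=3]
flat = jnp.reshape(stacked, (-1, *stacked.shape[2:]))  # Merge leading dims.
assert flat.shape == (6,)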
def __init__(self,
             unroll: networks_lib.FeedForwardNetwork,
             initial_state: networks_lib.FeedForwardNetwork,
             batch_size: int,
             random_key: networks_lib.PRNGKey,
             burn_in_length: int,
             discount: float,
             importance_sampling_exponent: float,
             max_priority_weight: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optax.GradientTransformation,
             bootstrap_n: int = 5,
             tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR,
             clip_rewards: bool = False,
             max_abs_reward: float = 1.,
             use_core_state: bool = True,
             prefetch_size: int = 2,
             replay_client: Optional[reverb.Client] = None,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initializes the learner."""

  random_key, key_initial_1, key_initial_2 = jax.random.split(random_key, 3)
  initial_state_params = initial_state.init(key_initial_1, batch_size)
  initial_state = initial_state.apply(initial_state_params, key_initial_2,
                                      batch_size)

  def loss(
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      key_grad: networks_lib.PRNGKey,
      sample: reverb.ReplaySample
  ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
    """Computes mean transformed N-step loss for a batch of sequences."""

    # Convert sample data to sequence-major format [T, B, ...].
    data = utils.batch_to_sequence(sample.data)

    # Get core state & warm it up on observations for a burn-in period.
    if use_core_state:
      # Replay core state.
      online_state = jax.tree_map(lambda x: x[0], data.extras['core_state'])
    else:
      online_state = initial_state
    target_state = online_state

    # Maybe burn the core state in.
    if burn_in_length:
      burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation)
      key_grad, key1, key2 = jax.random.split(key_grad, 3)
      _, online_state = unroll.apply(params, key1, burn_obs, online_state)
      _, target_state = unroll.apply(target_params, key2, burn_obs,
                                     target_state)

    # Only get data to learn on from after the end of the burn in period.
    data = jax.tree_map(lambda seq: seq[burn_in_length:], data)

    # Unroll on sequences to get online and target Q-Values.
    key1, key2 = jax.random.split(key_grad)
    online_q, _ = unroll.apply(params, key1, data.observation, online_state)
    target_q, _ = unroll.apply(target_params, key2, data.observation,
                               target_state)

    # Get value-selector actions from online Q-values for double Q-learning.
    selector_actions = jnp.argmax(online_q, axis=-1)
    # Preprocess discounts & rewards.
    discounts = (data.discount * discount).astype(online_q.dtype)
    rewards = data.reward
    if clip_rewards:
      rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
    rewards = rewards.astype(online_q.dtype)

    # Get N-step transformed TD error and loss.
    batch_td_error_fn = jax.vmap(
        functools.partial(
            rlax.transformed_n_step_q_learning,
            n=bootstrap_n,
            tx_pair=tx_pair),
        in_axes=1,
        out_axes=1)
    # TODO(b/183945808): when this bug is fixed, truncations of actions,
    # rewards, and discounts will no longer be necessary.
    batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1],
                                       target_q[1:], selector_actions[1:],
                                       rewards[:-1], discounts[:-1])
    batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

    # Importance weighting.
    probs = sample.info.probability
    importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
    importance_weights **= importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)
    mean_loss = jnp.mean(importance_weights * batch_loss)

    # Calculate priorities as a mixture of max and mean sequence errors.
    abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
    max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
    mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0)
    priorities = (max_priority + mean_priority)

    return mean_loss, priorities

  def sgd_step(
      state: TrainingState,
      samples: reverb.ReplaySample
  ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Performs an update step, averaging over pmap replicas."""

    # Compute loss and gradients.
    grad_fn = jax.value_and_grad(loss, has_aux=True)
    key, key_grad = jax.random.split(state.random_key)
    (loss_value, priorities), gradients = grad_fn(state.params,
                                                  state.target_params,
                                                  key_grad,
                                                  samples)

    # Average gradients over pmap replicas before optimizer update.
    gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

    # Apply optimizer updates.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    # Periodically update target networks.
    steps = state.steps + 1
    target_params = rlax.periodic_update(new_params, state.target_params,
                                         steps, self._target_update_period)

    new_state = TrainingState(
        params=new_params,
        target_params=target_params,
        opt_state=new_opt_state,
        steps=steps,
        random_key=key)
    return new_state, priorities, {'loss': loss_value}

  def update_priorities(keys_and_priorities: Tuple[jnp.ndarray, jnp.ndarray]):
    keys, priorities = keys_and_priorities
    keys, priorities = tree.map_structure(
        # Fetch array and combine device and batch dimensions.
        lambda x: utils.fetch_devicearray(x).reshape((-1,) + x.shape[2:]),
        (keys, priorities))
    replay_client.mutate_priorities(  # pytype: disable=attribute-error
        table=adders.DEFAULT_PRIORITY_TABLE,
        updates=dict(zip(keys, priorities)))

  # Internalise components, hyperparameters, logger, counter, and methods.
  self._replay_client = replay_client
  self._target_update_period = target_update_period
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      time_delta=1.)

  self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)

  # Initialise and internalise training state (parameters/optimiser state).
  random_key, key_init = jax.random.split(random_key)
  initial_params = unroll.init(key_init, initial_state)
  opt_state = optimizer.init(initial_params)
  state = TrainingState(
      params=initial_params,
      target_params=initial_params,
      opt_state=opt_state,
      steps=jnp.array(0),
      random_key=random_key)
  # Replicate parameters.
  self._state = utils.replicate_in_all_devices(state)

  # Shard multiple inputs with on-device prefetching.
  # We split samples in two outputs, the keys which need to be kept on-host
  # since int64 arrays are not supported in TPUs, and the entire sample
  # separately so it can be sent to the sgd_step method.
  def split_sample(sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
    return utils.PrefetchingSplit(host=sample.info.key, device=sample)

  self._prefetched_iterator = utils.sharded_prefetch(
      iterator,
      buffer_size=prefetch_size,
      num_threads=jax.local_device_count(),
      split_fn=split_sample)
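
# A toy illustration (not learner code) of the R2D2-style priority mixture
# computed above: per-sequence priorities blend the max and the mean absolute
# TD error over the time axis. Shapes and values are illustrative only.
import jax.numpy as jnp

max_priority_weight = 0.9
abs_td_error = jnp.array([[0.5, 2.0],
                          [1.5, 1.0]])                        # [T=2, B=2]
max_p = max_priority_weight * jnp.max(abs_td_error, axis=0)         # [B]
mean_p = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0)  # [B]
priorities = max_p + mean_p  # [0.9*1.5 + 0.1*1.0, 0.9*2.0 + 0.1*1.5]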
def __init__(self,
             network: networks_lib.FeedForwardNetwork,
             obs_spec: specs.Array,
             loss_fn: LossFn,
             optimizer: optax.GradientTransformation,
             data_iterator: Iterator[reverb.ReplaySample],
             target_update_period: int,
             random_key: networks_lib.PRNGKey,
             replay_client: Optional[reverb.Client] = None,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initialize the SGD learner."""
  self.network = network

  # Internalize the loss_fn with network.
  self._loss = jax.jit(functools.partial(loss_fn, self.network))

  # SGD performs the loss, optimizer update and periodic target net update.
  def sgd_step(
      state: TrainingState,
      batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
    next_rng_key, rng_key = jax.random.split(state.rng_key)
    # Implements one SGD step of the loss and updates training state.
    (loss, extra), grads = jax.value_and_grad(self._loss, has_aux=True)(
        state.params, state.target_params, batch, rng_key)
    extra.metrics.update({'total_loss': loss})

    # Apply the optimizer updates.
    updates, new_opt_state = optimizer.update(grads, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    # Periodically update target networks.
    steps = state.steps + 1
    target_params = rlax.periodic_update(new_params, state.target_params,
                                         steps, target_update_period)

    new_training_state = TrainingState(new_params, target_params,
                                       new_opt_state, steps, next_rng_key)
    return new_training_state, extra

  self._sgd_step = jax.jit(sgd_step)

  # Internalise agent components.
  self._data_iterator = utils.prefetch(data_iterator)
  self._target_update_period = target_update_period
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Initialize the network parameters.
  dummy_obs = utils.add_batch_dim(utils.zeros_like(obs_spec))
  key_params, key_target, key_state = jax.random.split(random_key, 3)
  initial_params = self.network.init(key_params, dummy_obs)
  initial_target_params = self.network.init(key_target, dummy_obs)
  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=optimizer.init(initial_params),
      steps=0,
      rng_key=key_state,
  )

  # Update replay priorities.
  def update_priorities(reverb_update: Optional[ReverbUpdate]) -> None:
    if reverb_update is None or replay_client is None:
      return
    else:
      replay_client.mutate_priorities(
          table=adders.DEFAULT_PRIORITY_TABLE,
          updates=dict(zip(reverb_update.keys, reverb_update.priorities)))

  self._replay_client = replay_client
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)
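
# A minimal, self-contained sketch (assuming Haiku and dm_env specs) of the
# "dummy observation" initialization pattern used above: build a zero
# observation matching the spec, add a batch dimension, and use it to create
# the network parameters. The spec and MLP below are illustrative, not
# learner code.
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from dm_env import specs

obs_spec = specs.Array(shape=(4,), dtype=np.float32)
dummy_obs = jnp.zeros((1,) + obs_spec.shape, obs_spec.dtype)  # Add batch dim.

network = hk.without_apply_rng(
    hk.transform(lambda x: hk.nets.MLP([32, 2])(x)))
params = network.init(jax.random.PRNGKey(0), dummy_obs)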