def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data

  # Forward pass.
  q_tm1 = network.apply(params, transitions.observation)
  q_t = network.apply(target_params, transitions.next_observation)
  d_t = (transitions.discount * self.discount).astype(jnp.float32)

  # Compute Q-learning TD-error.
  batch_error = jax.vmap(rlax.q_learning)
  td_error = batch_error(q_tm1, transitions.action, transitions.reward, d_t,
                         q_t)
  td_error = 0.5 * jnp.square(td_error)

  def select(qtm1, action):
    return qtm1[action]

  q_regularizer = jax.vmap(select)(q_tm1, transitions.action)

  loss = self.regularizer_coeff * jnp.mean(q_regularizer) + jnp.mean(td_error)
  extra = learning_lib.LossExtra(metrics={})
  return loss, extra
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data

  # Forward pass.
  q_tm1 = network.apply(params, transitions.observation)
  q_t = network.apply(target_params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute Q-learning TD-error.
  batch_error = jax.vmap(rlax.q_learning)
  td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t)
  batch_loss = jnp.square(td_error)

  loss = jnp.mean(batch_loss)
  extra = learning_lib.LossExtra(metrics={})
  return loss, extra
def default_behavior_policy(network: networks_lib.FeedForwardNetwork,
                            epsilon: float,
                            params: networks_lib.Params,
                            key: networks_lib.PRNGKey,
                            observation: networks_lib.Observation):
  """Returns an action for the given observation."""
  action_values = network.apply(params, observation)
  actions = rlax.epsilon_greedy(epsilon).sample(key, action_values)
  return actions.astype(jnp.int32)
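# Illustrative usage sketch (not from the library). `_ToyNetwork` below is a
# hypothetical stand-in for `networks_lib.FeedForwardNetwork`, assumed here to
# be an (init, apply) pair whose `apply` maps (params, observation) to
# per-action Q-values.
import collections

import jax
import jax.numpy as jnp

_ToyNetwork = collections.namedtuple('_ToyNetwork', ['init', 'apply'])


def _make_toy_q_network(obs_dim: int, num_actions: int) -> _ToyNetwork:
  """Builds a linear Q-network purely for demonstration purposes."""

  def init(key):
    w = 0.01 * jax.random.normal(key, (obs_dim, num_actions))
    return {'w': w, 'b': jnp.zeros((num_actions,))}

  def apply(params, observation):
    return observation @ params['w'] + params['b']

  return _ToyNetwork(init=init, apply=apply)


key = jax.random.PRNGKey(0)
network = _make_toy_q_network(obs_dim=4, num_actions=3)
params = network.init(key)
action = default_behavior_policy(
    network, epsilon=0.05, params=params, key=key,
    observation=jnp.ones((4,)))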
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data

  # Forward pass.
  q_online_s = network.apply(params, transitions.observation)
  action_one_hot = jax.nn.one_hot(transitions.action, q_online_s.shape[-1])
  q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1)
  q_target_s = network.apply(target_params, transitions.observation)
  q_target_next = network.apply(target_params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Munchausen term: tau * log_pi(a|s).
  munchausen_term = self.entropy_temperature * jax.nn.log_softmax(
      q_target_s / self.entropy_temperature, axis=-1)
  munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1)
  munchausen_term_a = jnp.clip(munchausen_term_a,
                               a_min=self.clip_value_min,
                               a_max=0.)

  # Soft Bellman operator applied to q.
  next_v = self.entropy_temperature * jax.nn.logsumexp(
      q_target_next / self.entropy_temperature, axis=-1)
  target_q = jax.lax.stop_gradient(
      r_t + self.munchausen_coefficient * munchausen_term_a + d_t * next_v)

  batch_loss = rlax.huber_loss(target_q - q_online_sa,
                               self.huber_loss_parameter)
  loss = jnp.mean(batch_loss)

  extra = learning_lib.LossExtra(metrics={})
  return loss, extra
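# Reference note (added, not part of the original source): the stop-gradient
# target computed above is the Munchausen-DQN target
#   target = r + alpha * clip(tau * log pi(a|s), clip_value_min, 0)
#              + discount * tau * logsumexp(q_target(s') / tau),
# where pi = softmax(q_target(s) / tau), alpha is `munchausen_coefficient` and
# tau is `entropy_temperature`.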
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  _, logits_tm1, atoms_tm1 = network.apply(params, transitions.observation)
  _, logits_t, atoms_t = network.apply(target_params,
                                       transitions.next_observation)
  q_t_selector, _, _ = network.apply(params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute categorical double Q-learning loss.
  batch_loss_fn = jax.vmap(rlax.categorical_double_q_learning,
                           in_axes=(None, 0, 0, 0, 0, None, 0, 0))
  batch_loss = batch_loss_fn(atoms_tm1, logits_tm1, transitions.action, r_t,
                             d_t, atoms_t, logits_t, q_t_selector)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []
  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(batch_loss).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: jnp.DeviceArray,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  q_tm1 = network.apply(params, transitions.observation)
  q_t_value = network.apply(target_params, transitions.next_observation)
  q_t_selector = network.apply(params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute double Q-learning n-step TD-error.
  batch_error = jax.vmap(rlax.double_q_learning)
  td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value,
                         q_t_selector)
  batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []
  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
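# Minimal numeric sketch (illustrative only) of the prioritized-replay
# correction used by both prioritized losses above: each sample is reweighted
# by w_i = (1 / p_i) ** beta, normalized so the largest weight in the batch is
# one. The values below are hypothetical.
import jax.numpy as jnp

probs = jnp.array([0.5, 0.25, 0.125])  # Hypothetical sampling probabilities.
beta = 0.6                             # Plays the role of importance_sampling_exponent.
weights = (1. / probs) ** beta
weights /= jnp.max(weights)            # ~[0.44, 0.66, 1.0]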
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data

  dist_q_tm1 = network.apply(params, transitions.observation)['q_dist']
  dist_q_target_t = network.apply(target_params,
                                  transitions.next_observation)['q_dist']

  # Swap distribution and action dimension, since
  # rlax.quantile_q_learning expects it that way.
  dist_q_tm1 = jnp.swapaxes(dist_q_tm1, 1, 2)
  dist_q_target_t = jnp.swapaxes(dist_q_target_t, 1, 2)
  quantiles = (
      (jnp.arange(self.num_atoms, dtype=jnp.float32) + 0.5) / self.num_atoms)
  batch_quantile_q_learning = jax.vmap(
      rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None))
  losses = batch_quantile_q_learning(
      dist_q_tm1,
      quantiles,
      transitions.action,
      transitions.reward,
      transitions.discount,
      dist_q_target_t,  # No double Q-learning here.
      dist_q_target_t,
      self.huber_param,
  )
  loss = jnp.mean(losses)
  extra = learning_lib.LossExtra(metrics={'mean_loss': loss})
  return loss, extra
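# Reference note (added): `quantiles` above are the fixed midpoint levels
# tau_i = (i + 0.5) / num_atoms for i = 0, ..., num_atoms - 1, i.e. the
# quantile levels at which the quantile-regression loss evaluates the return
# distribution.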
def __init__(self,
             unroll: networks_lib.FeedForwardNetwork,
             initial_state: networks_lib.FeedForwardNetwork,
             batch_size: int,
             random_key: networks_lib.PRNGKey,
             burn_in_length: int,
             discount: float,
             importance_sampling_exponent: float,
             max_priority_weight: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optax.GradientTransformation,
             bootstrap_n: int = 5,
             tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR,
             clip_rewards: bool = False,
             max_abs_reward: float = 1.,
             use_core_state: bool = True,
             prefetch_size: int = 2,
             replay_client: Optional[reverb.Client] = None,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initializes the learner."""

  random_key, key_initial_1, key_initial_2 = jax.random.split(random_key, 3)
  initial_state_params = initial_state.init(key_initial_1, batch_size)
  initial_state = initial_state.apply(initial_state_params, key_initial_2,
                                      batch_size)

  def loss(
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      key_grad: networks_lib.PRNGKey,
      sample: reverb.ReplaySample
  ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
    """Computes mean transformed N-step loss for a batch of sequences."""

    # Convert sample data to sequence-major format [T, B, ...].
    data = utils.batch_to_sequence(sample.data)

    # Get core state & warm it up on observations for a burn-in period.
    if use_core_state:
      # Replay core state.
      online_state = jax.tree_map(lambda x: x[0], data.extras['core_state'])
    else:
      online_state = initial_state
    target_state = online_state

    # Maybe burn the core state in.
    if burn_in_length:
      burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation)
      key_grad, key1, key2 = jax.random.split(key_grad, 3)
      _, online_state = unroll.apply(params, key1, burn_obs, online_state)
      _, target_state = unroll.apply(target_params, key2, burn_obs,
                                     target_state)

    # Only get data to learn on from after the end of the burn-in period.
    data = jax.tree_map(lambda seq: seq[burn_in_length:], data)

    # Unroll on sequences to get online and target Q-Values.
    key1, key2 = jax.random.split(key_grad)
    online_q, _ = unroll.apply(params, key1, data.observation, online_state)
    target_q, _ = unroll.apply(target_params, key2, data.observation,
                               target_state)

    # Get value-selector actions from online Q-values for double Q-learning.
    selector_actions = jnp.argmax(online_q, axis=-1)
    # Preprocess discounts & rewards.
    discounts = (data.discount * discount).astype(online_q.dtype)
    rewards = data.reward
    if clip_rewards:
      rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
    rewards = rewards.astype(online_q.dtype)

    # Get N-step transformed TD error and loss.
    batch_td_error_fn = jax.vmap(
        functools.partial(
            rlax.transformed_n_step_q_learning,
            n=bootstrap_n,
            tx_pair=tx_pair),
        in_axes=1,
        out_axes=1)
    # TODO(b/183945808): when this bug is fixed, truncations of actions,
    # rewards, and discounts will no longer be necessary.
    batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1],
                                       target_q[1:], selector_actions[1:],
                                       rewards[:-1], discounts[:-1])
    batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

    # Importance weighting.
    probs = sample.info.probability
    importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
    importance_weights **= importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)
    mean_loss = jnp.mean(importance_weights * batch_loss)

    # Calculate priorities as a mixture of max and mean sequence errors.
    abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
    max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
    mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0)
    priorities = (max_priority + mean_priority)

    return mean_loss, priorities

  def sgd_step(
      state: TrainingState,
      samples: reverb.ReplaySample
  ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Performs an update step, averaging over pmap replicas."""

    # Compute loss and gradients.
    grad_fn = jax.value_and_grad(loss, has_aux=True)
    key, key_grad = jax.random.split(state.random_key)
    (loss_value, priorities), gradients = grad_fn(state.params,
                                                  state.target_params,
                                                  key_grad, samples)

    # Average gradients over pmap replicas before optimizer update.
    gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

    # Apply optimizer updates.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    # Periodically update target networks.
    steps = state.steps + 1
    target_params = rlax.periodic_update(new_params, state.target_params,
                                         steps, self._target_update_period)

    new_state = TrainingState(
        params=new_params,
        target_params=target_params,
        opt_state=new_opt_state,
        steps=steps,
        random_key=key)
    return new_state, priorities, {'loss': loss_value}

  def update_priorities(keys_and_priorities: Tuple[jnp.ndarray, jnp.ndarray]):
    keys, priorities = keys_and_priorities
    keys, priorities = tree.map_structure(
        # Fetch array and combine device and batch dimensions.
        lambda x: utils.fetch_devicearray(x).reshape((-1,) + x.shape[2:]),
        (keys, priorities))
    replay_client.mutate_priorities(  # pytype: disable=attribute-error
        table=adders.DEFAULT_PRIORITY_TABLE,
        updates=dict(zip(keys, priorities)))

  # Internalise components, hyperparameters, logger, counter, and methods.
  self._replay_client = replay_client
  self._target_update_period = target_update_period
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      time_delta=1.)

  self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)

  # Initialise and internalise training state (parameters/optimiser state).
  random_key, key_init = jax.random.split(random_key)
  initial_params = unroll.init(key_init, initial_state)
  opt_state = optimizer.init(initial_params)

  state = TrainingState(
      params=initial_params,
      target_params=initial_params,
      opt_state=opt_state,
      steps=jnp.array(0),
      random_key=random_key)
  # Replicate parameters.
  self._state = utils.replicate_in_all_devices(state)

  # Shard multiple inputs with on-device prefetching.
  # We split samples in two outputs, the keys which need to be kept on-host
  # since int64 arrays are not supported in TPUs, and the entire sample
  # separately so it can be sent to the sgd_step method.
  def split_sample(sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
    return utils.PrefetchingSplit(host=sample.info.key, device=sample)

  self._prefetched_iterator = utils.sharded_prefetch(
      iterator,
      buffer_size=prefetch_size,
      num_threads=jax.local_device_count(),
      split_fn=split_sample)