def __init__(
    self,
    preprocessor: processors.Processor,
    sample_network_input: IqnInputs,
    network: parts.Network,
    optimizer: optax.GradientTransformation,
    transition_accumulator: Any,
    replay: replay_lib.TransitionReplay,
    batch_size: int,
    exploration_epsilon: Callable[[int], float],
    min_replay_capacity_fraction: float,
    learn_period: int,
    target_network_update_period: int,
    huber_param: float,
    tau_samples_policy: int,
    tau_samples_s_tm1: int,
    tau_samples_s_t: int,
    rng_key: parts.PRNGKey,
):
  self._preprocessor = preprocessor
  self._replay = replay
  self._transition_accumulator = transition_accumulator
  self._batch_size = batch_size
  self._exploration_epsilon = exploration_epsilon
  self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
  self._learn_period = learn_period
  self._target_network_update_period = target_network_update_period

  # Initialize network parameters and optimizer.
  self._rng_key, network_rng_key = jax.random.split(rng_key)
  self._online_params = network.init(
      network_rng_key,
      jax.tree_map(lambda x: x[None, ...], sample_network_input))
  self._target_params = self._online_params
  self._opt_state = optimizer.init(self._online_params)

  # Other agent state: last action, frame count, etc.
  self._action = None
  self._frame_t = -1  # Current frame index.

  # Define jitted loss, update, and policy functions here instead of as
  # class methods, to emphasize that these are meant to be pure functions
  # and should not access the agent object's state via `self`.

  def loss_fn(online_params, target_params, transitions, rng_key):
    """Calculates loss given network parameters and transitions."""
    # Sample tau values for q_tm1, q_t_selector, q_t.
    batch_size = self._batch_size
    rng_key, *sample_keys = jax.random.split(rng_key, 4)
    tau_tm1 = _sample_tau(sample_keys[0], (batch_size, tau_samples_s_tm1))
    tau_t_selector = _sample_tau(sample_keys[1],
                                 (batch_size, tau_samples_policy))
    tau_t = _sample_tau(sample_keys[2], (batch_size, tau_samples_s_t))

    # Compute Q value distributions.
    _, *apply_keys = jax.random.split(rng_key, 4)
    dist_q_tm1 = network.apply(online_params, apply_keys[0],
                               IqnInputs(transitions.s_tm1, tau_tm1)).q_dist
    dist_q_t_selector = network.apply(
        target_params, apply_keys[1],
        IqnInputs(transitions.s_t, tau_t_selector)).q_dist
    dist_q_target_t = network.apply(target_params, apply_keys[2],
                                    IqnInputs(transitions.s_t, tau_t)).q_dist
    losses = _batch_quantile_q_learning(
        dist_q_tm1,
        tau_tm1,
        transitions.a_tm1,
        transitions.r_t,
        transitions.discount_t,
        dist_q_t_selector,
        dist_q_target_t,
        huber_param,
    )
    assert losses.shape == (self._batch_size,)
    loss = jnp.mean(losses)
    return loss

  def update(rng_key, opt_state, online_params, target_params, transitions):
    """Computes learning update from batch of replay transitions."""
    rng_key, update_key = jax.random.split(rng_key)
    d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                        transitions, update_key)
    updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
    new_online_params = optax.apply_updates(online_params, updates)
    return rng_key, new_opt_state, new_online_params

  self._update = jax.jit(update)

  def select_action(rng_key, network_params, s_t, exploration_epsilon):
    """Samples action from eps-greedy policy wrt Q-values at given state."""
    rng_key, sample_key, apply_key, policy_key = jax.random.split(rng_key, 4)
    tau_t = _sample_tau(sample_key, (1, tau_samples_policy))
    q_t = network.apply(network_params, apply_key,
                        IqnInputs(s_t[None, ...], tau_t)).q_values[0]
    a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
    return rng_key, a_t

  self._select_action = jax.jit(select_action)
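# Note: `_sample_tau` and `_batch_quantile_q_learning` are module-level
# helpers defined outside this excerpt. The sketch below shows one plausible
# definition, assuming quantile levels are drawn uniformly from [0, 1) and the
# batched loss is `rlax.quantile_q_learning` vmapped over the batch dimension
# with the scalar Huber parameter shared; the actual definitions may differ.
import jax
import rlax


def _sample_tau(rng_key, shape):
  """Samples a tensor of quantile levels uniformly from [0, 1)."""
  return jax.random.uniform(rng_key, shape=shape)


# Batch over every argument except the final Huber parameter.
_batch_quantile_q_learning = jax.vmap(
    rlax.quantile_q_learning, in_axes=(0, 0, 0, 0, 0, 0, 0, None))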
def __init__(
    self,
    preprocessor: processors.Processor,
    sample_network_input: jnp.ndarray,
    network: parts.Network,
    optimizer: optax.GradientTransformation,
    transition_accumulator: Any,
    replay: replay_lib.TransitionReplay,
    shaping_function,
    mask_probability: float,
    num_heads: int,
    batch_size: int,
    exploration_epsilon: Callable[[int], float],
    min_replay_capacity_fraction: float,
    learn_period: int,
    target_network_update_period: int,
    grad_error_bound: float,
    rng_key: parts.PRNGKey,
):
  self._preprocessor = preprocessor
  self._replay = replay
  self._transition_accumulator = transition_accumulator
  self._mask_probabilities = jnp.array(
      [mask_probability, 1 - mask_probability])
  self._num_heads = num_heads
  self._batch_size = batch_size
  self._exploration_epsilon = exploration_epsilon
  self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
  self._learn_period = learn_period
  self._target_network_update_period = target_network_update_period

  # Initialize network parameters and optimizer.
  self._rng_key, network_rng_key = jax.random.split(rng_key)
  self._online_params = network.init(network_rng_key,
                                     sample_network_input[None, ...])
  self._target_params = self._online_params
  self._opt_state = optimizer.init(self._online_params)

  # Other agent state: last action, frame count, etc.
  self._action = None
  self._frame_t = -1  # Current frame index.

  # Define jitted loss, update, and policy functions here instead of as
  # class methods, to emphasize that these are meant to be pure functions
  # and should not access the agent object's state via `self`.

  def loss_fn(online_params, target_params, transitions, rng_key):
    """Calculates loss given network parameters and transitions."""
    _, online_key, target_key, shaping_key = jax.random.split(rng_key, 4)
    q_tm1 = network.apply(online_params, online_key,
                          transitions.s_tm1).multi_head_output
    q_target_t = network.apply(target_params, target_key,
                               transitions.s_t).multi_head_output

    # Broadcast the per-head mask over actions:
    # batch x num_heads -> batch x num_heads x num_actions.
    mask = jnp.einsum('ij,k->ijk', transitions.mask_t,
                      jnp.ones(q_tm1.shape[-1]))
    masked_q = jnp.multiply(mask, q_tm1)
    masked_q_target = jnp.multiply(mask, q_target_t)

    flattened_q = jnp.reshape(q_tm1, (-1, q_tm1.shape[-1]))
    flattened_q_target = jnp.reshape(q_target_t, (-1, q_target_t.shape[-1]))

    # Compute shaping function F(s, a, s').
    shaped_rewards = shaping_function(q_target_t, transitions, shaping_key)

    # Repeat each transition once per head so the flattened Q-values line up.
    repeated_actions = jnp.repeat(transitions.a_tm1, num_heads)
    repeated_rewards = jnp.repeat(shaped_rewards, num_heads)
    repeated_discounts = jnp.repeat(transitions.discount_t, num_heads)

    td_errors = _batch_q_learning(
        flattened_q,
        repeated_actions,
        repeated_rewards,
        repeated_discounts,
        flattened_q_target,
    )
    td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                   grad_error_bound)
    losses = rlax.l2_loss(td_errors)
    assert losses.shape == (self._batch_size * num_heads,)
    loss = jnp.mean(losses)
    return loss

  def update(rng_key, opt_state, online_params, target_params, transitions):
    """Computes learning update from batch of replay transitions."""
    rng_key, update_key = jax.random.split(rng_key)
    d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                        transitions, update_key)
    updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
    new_online_params = optax.apply_updates(online_params, updates)
    return rng_key, new_opt_state, new_online_params

  self._update = jax.jit(update)

  def select_action(rng_key, network_params, s_t, exploration_epsilon):
    """Samples action from eps-greedy policy wrt Q-values at given state."""
    rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
    q_t = network.apply(network_params, apply_key,
                        s_t[None, ...]).random_head_q_value[0]
    a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
    return rng_key, a_t

  self._select_action = jax.jit(select_action)
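# Note: `_batch_q_learning` is defined outside this excerpt. A minimal sketch,
# assuming it is simply `rlax.q_learning` vmapped over the leading
# (batch * num_heads) dimension of the flattened Q-values and the repeated
# actions, rewards and discounts; the actual definition may differ.
import jax
import rlax

_batch_q_learning = jax.vmap(rlax.q_learning)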
def __init__(
    self,
    preprocessor: processors.Processor,
    sample_network_input: jnp.ndarray,
    network: parts.Network,
    support: jnp.ndarray,
    optimizer: optax.GradientTransformation,
    transition_accumulator: Any,
    replay: replay_lib.TransitionReplay,
    batch_size: int,
    exploration_epsilon: Callable[[int], float],
    min_replay_capacity_fraction: float,
    learn_period: int,
    target_network_update_period: int,
    rng_key: parts.PRNGKey,
):
  self._preprocessor = preprocessor
  self._replay = replay
  self._transition_accumulator = transition_accumulator
  self._batch_size = batch_size
  self._exploration_epsilon = exploration_epsilon
  self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
  self._learn_period = learn_period
  self._target_network_update_period = target_network_update_period

  # Initialize network parameters and optimizer.
  self._rng_key, network_rng_key = jax.random.split(rng_key)
  self._online_params = network.init(network_rng_key,
                                     sample_network_input[None, ...])
  self._target_params = self._online_params
  self._opt_state = optimizer.init(self._online_params)

  # Other agent state: last action, frame count, etc.
  self._action = None
  self._frame_t = -1  # Current frame index.
  self._statistics = {'state_value': np.nan}

  # Define jitted loss, update, and policy functions here instead of as
  # class methods, to emphasize that these are meant to be pure functions
  # and should not access the agent object's state via `self`.

  def loss_fn(online_params, target_params, transitions, rng_key):
    """Calculates loss given network parameters and transitions."""
    _, online_key, target_key = jax.random.split(rng_key, 3)
    logits_q_tm1 = network.apply(online_params, online_key,
                                 transitions.s_tm1).q_logits
    logits_target_q_t = network.apply(target_params, target_key,
                                      transitions.s_t).q_logits
    losses = _batch_categorical_q_learning(
        support,
        logits_q_tm1,
        transitions.a_tm1,
        transitions.r_t,
        transitions.discount_t,
        support,
        logits_target_q_t,
    )
    chex.assert_shape(losses, (self._batch_size,))
    loss = jnp.mean(losses)
    return loss

  def update(rng_key, opt_state, online_params, target_params, transitions):
    """Computes learning update from batch of replay transitions."""
    rng_key, update_key = jax.random.split(rng_key)
    d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                        transitions, update_key)
    updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
    new_online_params = optax.apply_updates(online_params, updates)
    return rng_key, new_opt_state, new_online_params

  self._update = jax.jit(update)

  def select_action(rng_key, network_params, s_t, exploration_epsilon):
    """Samples action from eps-greedy policy wrt Q-values at given state."""
    rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
    q_t = network.apply(network_params, apply_key,
                        s_t[None, ...]).q_values[0]
    a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
    v_t = jnp.max(q_t, axis=-1)
    return rng_key, a_t, v_t

  self._select_action = jax.jit(select_action)
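# Note: `_batch_categorical_q_learning` is defined outside this excerpt. A
# minimal sketch, assuming it wraps `rlax.categorical_q_learning` with
# `jax.vmap`, sharing the two atom-support arguments across the batch; the
# actual definition may differ.
import jax
import rlax

_batch_categorical_q_learning = jax.vmap(
    rlax.categorical_q_learning, in_axes=(None, 0, 0, 0, 0, None, 0))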
def __init__(
    self,
    preprocessor: processors.Processor,
    sample_network_input: jnp.ndarray,
    network: parts.Network,
    optimizer: optax.GradientTransformation,
    transition_accumulator: replay_lib.TransitionAccumulator,
    replay: replay_lib.PrioritizedTransitionReplay,
    batch_size: int,
    exploration_epsilon: Callable[[int], float],
    min_replay_capacity_fraction: float,
    learn_period: int,
    target_network_update_period: int,
    grad_error_bound: float,
    rng_key: parts.PRNGKey,
):
  self._preprocessor = preprocessor
  self._replay = replay
  self._transition_accumulator = transition_accumulator
  self._batch_size = batch_size
  self._exploration_epsilon = exploration_epsilon
  self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
  self._learn_period = learn_period
  self._target_network_update_period = target_network_update_period

  # Initialize network parameters and optimizer.
  self._rng_key, network_rng_key = jax.random.split(rng_key)
  self._online_params = network.init(network_rng_key,
                                     sample_network_input[None, ...])
  self._target_params = self._online_params
  self._opt_state = optimizer.init(self._online_params)

  # Other agent state: last action, frame count, etc.
  self._action = None
  self._frame_t = -1  # Current frame index.
  self._max_seen_priority = 1.

  # Define jitted loss, update, and policy functions here instead of as
  # class methods, to emphasize that these are meant to be pure functions
  # and should not access the agent object's state via `self`.

  def loss_fn(online_params, target_params, transitions, weights, rng_key):
    """Calculates loss given network parameters and transitions."""
    _, *apply_keys = jax.random.split(rng_key, 4)
    q_tm1 = network.apply(online_params, apply_keys[0],
                          transitions.s_tm1).q_values
    q_t = network.apply(online_params, apply_keys[1],
                        transitions.s_t).q_values
    q_target_t = network.apply(target_params, apply_keys[2],
                               transitions.s_t).q_values
    td_errors = _batch_double_q_learning(
        q_tm1,
        transitions.a_tm1,
        transitions.r_t,
        transitions.discount_t,
        q_target_t,
        q_t,
    )
    td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                   grad_error_bound)
    losses = rlax.l2_loss(td_errors)
    assert losses.shape == (self._batch_size,) == weights.shape
    # This is not the same as using a huber loss and multiplying by weights.
    loss = jnp.mean(losses * weights)
    return loss, td_errors

  def update(rng_key, opt_state, online_params, target_params, transitions,
             weights):
    """Computes learning update from batch of replay transitions."""
    rng_key, update_key = jax.random.split(rng_key)
    d_loss_d_params, td_errors = jax.grad(loss_fn, has_aux=True)(
        online_params, target_params, transitions, weights, update_key)
    updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
    new_online_params = optax.apply_updates(online_params, updates)
    return rng_key, new_opt_state, new_online_params, td_errors

  self._update = jax.jit(update)

  def select_action(rng_key, network_params, s_t, exploration_epsilon):
    """Samples action from eps-greedy policy wrt Q-values at given state."""
    rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
    q_t = network.apply(network_params, apply_key,
                        s_t[None, ...]).q_values[0]
    a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
    return rng_key, a_t

  self._select_action = jax.jit(select_action)
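# Note: `_batch_double_q_learning` is defined outside this excerpt. A minimal
# sketch, assuming it is `rlax.double_q_learning` vmapped over the batch, with
# the target network's values at s_t as the evaluation values and the online
# network's values at s_t as the action selector (matching the argument order
# used in `loss_fn` above); the actual definition may differ.
import jax
import rlax

_batch_double_q_learning = jax.vmap(rlax.double_q_learning)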