def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data

  # Forward pass.
  q_tm1 = network.apply(params, transitions.observation)
  q_t = network.apply(target_params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute Q-learning TD-error.
  batch_error = jax.vmap(rlax.q_learning)
  td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t)
  batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)
  loss = jnp.mean(batch_loss)
  extra = learning_lib.LossExtra(metrics={})
  return loss, extra
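# --- Illustrative sketch (not part of the loss class above) ---
# What jax.vmap(rlax.q_learning) computes per transition:
#   td = r_t + d_t * max_a q_t[a] - q_tm1[a_tm1]
# All values below are made up for illustration.
import jax.numpy as jnp
import rlax

q_tm1 = jnp.array([1.0, 2.0])   # online Q-values at the current observation
q_t = jnp.array([0.0, 3.0])     # target Q-values at the next observation
a_tm1 = jnp.asarray(1)          # action taken
r_t = jnp.asarray(1.0)          # (clipped) reward
d_t = jnp.asarray(0.9)          # discount
td = rlax.q_learning(q_tm1, a_tm1, r_t, d_t, q_t)
# td == 1.0 + 0.9 * 3.0 - 2.0 == 1.7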
def loss(params: hk.Params,
         target_params: hk.Params,
         sample: reverb.ReplaySample):
  o_tm1, a_tm1, r_t, d_t, o_t = sample.data
  keys, probs = sample.info[:2]

  # Forward pass.
  q_tm1 = network.apply(params, o_tm1)
  q_t_value = network.apply(target_params, o_t)
  q_t_selector = network.apply(params, o_t)

  # Cast and clip rewards.
  d_t = (d_t * discount).astype(jnp.float32)
  r_t = jnp.clip(r_t, -max_abs_reward, max_abs_reward).astype(jnp.float32)

  # Compute double Q-learning n-step TD-error.
  batch_error = jax.vmap(rlax.double_q_learning)
  td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
  batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  mean_loss = jnp.mean(importance_weights * batch_loss)  # []

  priorities = jnp.abs(td_error).astype(jnp.float64)
  return mean_loss, (keys, priorities)
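# --- Illustrative sketch (values assumed) ---
# The prioritized-replay importance weights computed above are
#   w_i = (1 / p_i) ** beta, normalized by max_j w_j so the largest is 1.
import jax.numpy as jnp

probs = jnp.array([0.5, 0.25, 0.125])  # sampling probabilities from replay
beta = 0.6                             # importance_sampling_exponent
w = (1.0 / probs) ** beta
w /= jnp.max(w)                        # approx [0.435, 0.660, 1.0]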
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data

  # Forward pass.
  q_online_s = network.apply(params, transitions.observation)
  action_one_hot = jax.nn.one_hot(transitions.action, q_online_s.shape[-1])
  q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1)
  q_target_s = network.apply(target_params, transitions.observation)
  q_target_next = network.apply(target_params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Munchausen term: tau * log_pi(a|s).
  munchausen_term = self.entropy_temperature * jax.nn.log_softmax(
      q_target_s / self.entropy_temperature, axis=-1)
  munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1)
  munchausen_term_a = jnp.clip(munchausen_term_a,
                               a_min=self.clip_value_min,
                               a_max=0.)

  # Soft Bellman operator applied to Q.
  next_v = self.entropy_temperature * jax.nn.logsumexp(
      q_target_next / self.entropy_temperature, axis=-1)
  target_q = jax.lax.stop_gradient(
      r_t + self.munchausen_coefficient * munchausen_term_a + d_t * next_v)

  batch_loss = rlax.huber_loss(target_q - q_online_sa,
                               self.huber_loss_parameter)
  loss = jnp.mean(batch_loss)
  extra = learning_lib.LossExtra(metrics={})
  return loss, extra
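# --- Sanity sketch (not from the source) ---
# The soft value next_v = tau * logsumexp(q / tau) used above approaches
# max(q) as tau -> 0, so the Munchausen target anneals toward the
# standard DQN target.
import jax
import jax.numpy as jnp

q = jnp.array([1.0, 2.0, 0.5])
for tau in (1.0, 0.1, 0.01):
  soft_v = tau * jax.nn.logsumexp(q / tau)
  print(tau, float(soft_v))  # -> ~2.46, ~2.0000, ~2.0 (max(q) == 2.0)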
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  transitions: types.Transition = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  if self.stochastic_network:
    q_tm1 = network.apply(params, key, transitions.observation)
    q_t_value = network.apply(target_params, key,
                              transitions.next_observation)
    q_t_selector = network.apply(params, key, transitions.next_observation)
  else:
    q_tm1 = network.apply(params, transitions.observation)
    q_t_value = network.apply(target_params, transitions.next_observation)
    q_t_selector = network.apply(params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute double Q-learning n-step TD-error.
  batch_error = jax.vmap(rlax.double_q_learning)
  td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value,
                         q_t_selector)
  batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []

  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
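# --- Hypothetical wiring (client address and table name assumed) ---
# A learner holding the ReverbUpdate built above can push the fresh
# |td_error| priorities back to Reverb so future sampling tracks them.
import reverb

client = reverb.Client('localhost:8000')
client.mutate_priorities(
    table='priority_table',
    updates=dict(zip(reverb_update.keys.tolist(),
                     reverb_update.priorities.tolist())))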
def __call__(
    self,
    network: hk.Transformed,
    params: hk.Params,
    target_params: hk.Params,
    batch: reverb.ReplaySample,
    key: jnp.DeviceArray,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  o_tm1, a_tm1, r_t, d_t, o_t = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  q_tm1 = network.apply(params, o_tm1)
  q_t_value = network.apply(target_params, o_t)
  q_t_selector = network.apply(params, o_t)

  # Cast and clip rewards.
  d_t = (d_t * self.discount).astype(jnp.float32)
  r_t = jnp.clip(r_t, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute double Q-learning n-step TD-error.
  batch_error = jax.vmap(rlax.double_q_learning)
  td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
  batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []

  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
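# --- Shape sketch (self-contained) ---
# jax.vmap lifts rlax.double_q_learning from the per-example signature
# ([A], [], [], [], [A], [A]) -> [] to a batched one, which is how
# batch_error above turns a batch into a [B] vector of TD errors.
import jax
import jax.numpy as jnp
import rlax

B, A = 32, 4
batched = jax.vmap(rlax.double_q_learning)
td = batched(
    jnp.zeros((B, A)),            # q_tm1
    jnp.zeros((B,), jnp.int32),   # a_tm1
    jnp.zeros((B,)),              # r_t
    jnp.ones((B,)),               # d_t
    jnp.zeros((B, A)),            # q_t_value
    jnp.zeros((B, A)),            # q_t_selector
)
assert td.shape == (B,)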
def __init__(self,
             player_id,
             state_representation_size,
             num_actions,
             hidden_layers_sizes=128,
             replay_buffer_capacity=10000,
             batch_size=128,
             replay_buffer_class=ReplayBuffer,
             learning_rate=0.01,
             update_target_network_every=1000,
             learn_every=10,
             discount_factor=1.0,
             min_buffer_size_to_learn=1000,
             epsilon_start=1.0,
             epsilon_end=0.1,
             epsilon_decay_duration=int(1e6),
             optimizer_str="sgd",
             loss_str="mse",
             huber_loss_parameter=1.0):
  """Initialize the DQN agent."""

  # This call to locals() is used to store every argument used to initialize
  # the class instance, so it can be copied with no hyperparameter change.
  self._kwargs = locals()

  self.player_id = player_id
  self._num_actions = num_actions
  if isinstance(hidden_layers_sizes, int):
    hidden_layers_sizes = [hidden_layers_sizes]
  self._layer_sizes = hidden_layers_sizes
  self._batch_size = batch_size
  self._update_target_network_every = update_target_network_every
  self._learn_every = learn_every
  self._min_buffer_size_to_learn = min_buffer_size_to_learn
  self._discount_factor = discount_factor
  self.huber_loss_parameter = huber_loss_parameter
  self._epsilon_start = epsilon_start
  self._epsilon_end = epsilon_end
  self._epsilon_decay_duration = epsilon_decay_duration

  # TODO(author6) Allow for optional replay buffer config.
  if not isinstance(replay_buffer_capacity, int):
    raise ValueError("Replay buffer capacity not an integer.")
  self._replay_buffer = replay_buffer_class(replay_buffer_capacity)
  self._prev_timestep = None
  self._prev_action = None

  # Step counter to keep track of learning, eps decay and target network.
  self._step_counter = 0

  # Keep track of the last training loss achieved in an update step.
  self._last_loss_value = None

  # Create the Q-network instances.
  def network(x):
    mlp = hk.nets.MLP(self._layer_sizes + [num_actions])
    return mlp(x)

  self.hk_network = hk.without_apply_rng(hk.transform(network))
  self.hk_network_apply = jax.jit(self.hk_network.apply)

  rng = jax.random.PRNGKey(42)
  x = jnp.ones([1, state_representation_size])
  self.params_q_network = self.hk_network.init(rng, x)
  self.params_target_q_network = self.hk_network.init(rng, x)

  if loss_str == "mse":
    self.loss_func = lambda x: jnp.mean(x**2)
  elif loss_str == "huber":
    # pylint: disable=g-long-lambda
    self.loss_func = lambda x: jnp.mean(
        rlax.huber_loss(x, self.huber_loss_parameter))
  else:
    raise ValueError("Not implemented, choose from 'mse', 'huber'.")

  if optimizer_str == "adam":
    # Negative scale so optax.apply_updates performs gradient *descent*,
    # consistent with the optax.sgd branch below.
    opt_init, opt_update = optax.chain(
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        optax.scale(-learning_rate))
  elif optimizer_str == "sgd":
    opt_init, opt_update = optax.sgd(learning_rate)
  else:
    raise ValueError("Not implemented, choose from 'adam' and 'sgd'.")

  self._opt_update_fn = self._get_update_func(opt_update)
  self._opt_state = opt_init(self.params_q_network)
  self._loss_and_grad = jax.value_and_grad(self._loss, has_aux=False)
  self._jit_update = jax.jit(self.get_update())
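# --- Minimal construction sketch ---
# Assumes this __init__ belongs to an OpenSpiel-style `DQN` class; the
# sizes below are illustrative, not taken from the source.
agent = DQN(
    player_id=0,
    state_representation_size=11,   # e.g. flattened info-state length
    num_actions=3,
    hidden_layers_sizes=[64, 64],
    replay_buffer_capacity=int(1e5),
    batch_size=32,
    optimizer_str="adam",
    loss_str="huber")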
def loss_fn(target, q_val):
  """Mean Huber loss between the bootstrap target and the current Q-value."""
  return huber_loss(target - q_val).mean()
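# --- Assumed helper (matching rlax.huber_loss semantics, delta = 1) ---
# loss_fn above expects a huber_loss function in scope; a minimal
# elementwise definition would be:
import jax.numpy as jnp

def huber_loss(x, delta: float = 1.0):
  # Quadratic for |x| <= delta, linear beyond, continuous at the knee.
  abs_x = jnp.abs(x)
  return jnp.where(abs_x <= delta,
                   0.5 * x ** 2,
                   delta * (abs_x - 0.5 * delta))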