def Max(axis=-1, keepdims=False):
  """Returns a layer that applies max along one tensor axis.

  Args:
    axis: Axis along which values are grouped for computing maximum.
    keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
        axis; else, remove that axis.
  """
  return Fn('Max', lambda x: jnp.max(x, axis, keepdims=keepdims))
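# A minimal usage sketch of the layer above, assuming the Trax library is
# installed and that this `Max` is exposed as `trax.layers.Max`; the input
# values are illustrative only.
import numpy as np
from trax import layers as tl

layer = tl.Max(axis=-1)
x = np.array([[1., 5., 2.],
              [7., 0., 3.]])
y = layer(x)  # -> array([5., 7.]): the max over the last axis.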
def _aggregate_values(self, values, aggregate_max, act_log_probs):
  if self._q_value:
    if aggregate_max:
      # max_a Q(s, a)
      values = jnp.max(values, axis=1)
    elif self._sample_all_discrete_actions:
      # Exact expectation over all discrete actions, weighted by the
      # policy probabilities exp(act_log_probs).
      values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
    else:
      # Monte Carlo estimate of the expectation over sampled actions.
      values = jnp.mean(values, axis=1)
  return np.array(values)  # Move the values to CPU.
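# A standalone sketch of the three aggregation modes above, using plain
# jax.numpy; the shapes and numbers are illustrative, not the class's API.
import jax.numpy as jnp

values = jnp.array([[1., 3., 2.]])                 # (batch=1, n_samples=3)
act_log_probs = jnp.log(jnp.array([[.2, .5, .3]]))

v_max = jnp.max(values, axis=1)                              # [3.0]
v_expect = jnp.sum(values * jnp.exp(act_log_probs), axis=1)  # [2.3]
v_mean = jnp.mean(values, axis=1)                            # [2.0]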
def value_batches_stream(self):
  """Use the RLTask self._task to create inputs to the value model."""
  max_slice_length = self._max_slice_length
  min_slice_length = 1
  for np_trajectory in self._task.trajectory_batch_stream(
      self._value_batch_size,
      max_slice_length=max_slice_length,
      min_slice_length=min_slice_length,
      margin=0,
      epochs=self._replay_epochs,
  ):
    values_target = self._run_value_model(
        np_trajectory.observations, use_eval_model=True)
    if self._double_dqn:
      # Double DQN: the online network picks the action, the target
      # network supplies that action's value.
      values = self._run_value_model(
          np_trajectory.observations, use_eval_model=False
      )
      index_max = np.argmax(values, axis=-1)
      ind_0, ind_1 = np.indices(index_max.shape)
      values_max = values_target[ind_0, ind_1, index_max]
    else:
      values_max = np.array(jnp.max(values_target, axis=-1))
    # The advantage_estimator returns
    #   gamma^n_steps * values_max(s_{i + n_steps}) + discounted_rewards
    #   - values_max(s_i),
    # hence we have to add values_max(s_i) in order to get n-step returns:
    #   gamma^n_steps * values_max(s_{i + n_steps}) + discounted_rewards.
    # Note that in DQN the tensor values_max[:, :-self._margin]
    # is the same as values_max[:, :-1].
    n_step_returns = values_max[:, :-self._margin] + self._advantage_estimator(
        rewards=np_trajectory.rewards,
        returns=np_trajectory.returns,
        values=values_max,
        dones=np_trajectory.dones,
    )
    length = n_step_returns.shape[1]
    target_returns = n_step_returns[:, :length]
    inputs = np_trajectory.observations[:, :length, :]
    yield (
        # Inputs are observations, with shape (batch, length, obs).
        inputs,
        # The max indices will be needed to compute the loss.
        np_trajectory.actions[:, :length],  # index_max,
        # Targets: computed returns; we expect shapes such as
        # (batch, length, num_actions).
        target_returns / self._value_network_scale,
        # TODO(henrykm): mask has the shape (batch, max_slice_length),
        # i.e. it is missing the action dimension; the preferred format
        # would be np_trajectory.mask[:, :length, :] but for now we pass:
        np.ones(shape=target_returns.shape),
    )
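# Numeric sanity check of the n-step-return identity in the comments above:
# adding values_max(s_i) to the advantage recovers
#   gamma^n_steps * values_max(s_{i + n_steps}) + discounted_rewards.
# Names and numbers are illustrative; this does not call the class's
# _advantage_estimator.
import numpy as np

gamma, n = 0.99, 3
rewards = np.array([1.0, 0.5, 2.0])   # r_i, ..., r_{i+n-1}
v_s_i, v_s_i_n = 4.0, 5.0             # values_max at s_i and s_{i+n}

discounted = sum(gamma**k * r for k, r in enumerate(rewards))
advantage = gamma**n * v_s_i_n + discounted - v_s_i
n_step_return = v_s_i + advantage
assert np.isclose(n_step_return, gamma**n * v_s_i_n + discounted)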
def value_batches_stream(self):
  """Use the RLTask self._task to create inputs to the value model."""
  max_slice_length = self._max_slice_length
  min_slice_length = 1
  for np_trajectory in self._task.trajectory_batch_stream(
      self._value_batch_size,
      max_slice_length=max_slice_length,
      min_slice_length=min_slice_length,
      margin=0,
      epochs=self._replay_epochs,
  ):
    values = self._run_value_model(np_trajectory.observations)
    values_max = np.array(jnp.max(values, axis=-1))
    adv = self._advantage_estimator(
        rewards=np_trajectory.rewards,
        returns=np_trajectory.returns,
        values=values_max,
        dones=np_trajectory.dones,
    )
    length = adv.shape[1]
    values = values[:, :length, :]
    # Add the advantage at each (batch, time, argmax-action) position;
    # np.indices builds the broadcast-compatible batch and time index grids
    # (matching the double-DQN indexing above).
    ind_0, ind_1 = np.indices(values.shape[:2])
    indices_max = (ind_0, ind_1, np.argmax(values, axis=-1))
    # TODO(henrykm): change it to fastmath instead of jax.ops
    target_returns = jax.ops.index_add(values, indices_max, adv)
    inputs = np_trajectory.observations[:, :length, :]
    yield (
        # Inputs are observations, with shape (batch, length, obs).
        inputs,
        # Targets: computed returns; we expect shapes such as
        # (batch, length, num_actions).
        target_returns / self._value_network_scale,
        # TODO(henrykm): mask has the shape (batch, max_slice_length),
        # i.e. it is missing the action dimension; the preferred format
        # would be np_trajectory.mask[:, :length, :] but for now we pass:
        np.ones(shape=target_returns.shape))
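# A sketch addressing the TODO above: jax.ops.index_add was removed in newer
# JAX releases, and the functional-update API `x.at[idx].add(y)` is the usual
# replacement. Shapes here are illustrative.
import jax.numpy as jnp
import numpy as np

values = jnp.zeros((2, 3, 4))   # (batch, length, num_actions)
adv = jnp.ones((2, 3))          # (batch, length)
ind_0, ind_1 = np.indices(values.shape[:2])
indices_max = (ind_0, ind_1, jnp.argmax(values, axis=-1))
target_returns = values.at[indices_max].add(adv)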
def _aggregate_values(self, values, aggregate, act_log_probs):
  # Normalize the Q-values before aggregation, so it can adapt to the scale
  # of the returns. This does not affect mean and max aggregation.
  scale = 1
  epsilon = 1e-5
  if self._q_value_normalization == 'std':
    scale = jnp.std(values) + epsilon
  elif self._q_value_normalization == 'abs':
    scale = jnp.mean(jnp.abs(values - jnp.mean(values))) + epsilon
  values /= scale

  temp = self._q_value_temperature
  if self._q_value:
    assert values.shape[:2] == (
        self._value_batch_size, self._q_value_n_samples)
    if aggregate == 'max':
      # max_a Q(s, a)
      values = jnp.max(values, axis=1)
    elif aggregate == 'softmax':
      # sum_a (Q(s, a) * w(s, a))
      # where w(s, .) = softmax (Q(s, .) / T)
      weights = tl.Softmax(axis=1)(values / temp)
      values = jnp.sum(values * weights, axis=1)
    elif aggregate == 'logsumexp':
      # log(mean_a exp(Q(s, a) / T)) * T
      n = values.shape[1]
      values = (fastmath.logsumexp(values / temp, axis=1) - jnp.log(n)) * temp
    else:
      assert aggregate == 'mean'
      # mean_a Q(s, a)
      if self._sample_all_discrete_actions:
        values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
      else:
        values = jnp.mean(values, axis=1)

  # Re-scale the Q-values after aggregation.
  values *= scale
  return np.array(values)  # Move the values to CPU.
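# A standalone sketch of the 'softmax' and 'logsumexp' aggregations above,
# using jax.scipy.special.logsumexp in place of trax's fastmath wrapper and
# an explicit softmax in place of tl.Softmax; the shapes and the temperature
# value are illustrative.
import jax.numpy as jnp
from jax.scipy.special import logsumexp

values = jnp.array([[1., 3., 2.]])   # (batch=1, n_samples=3)
temp = 0.5
n = values.shape[1]

# sum_a Q(s, a) * softmax(Q(s, .) / T)_a
weights = jnp.exp(values / temp
                  - logsumexp(values / temp, axis=1, keepdims=True))
v_softmax = jnp.sum(values * weights, axis=1)

# log(mean_a exp(Q(s, a) / T)) * T; as T -> 0 both approach max_a Q(s, a).
v_logsumexp = (logsumexp(values / temp, axis=1) - jnp.log(n)) * temp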