def ClipFraction(dist_inputs, actions, old_log_probs):
  """Fraction of probability ratios clipped by the PPO epsilon threshold."""
  probs_ratio = rl_layers.ProbsRatio(
      dist_inputs, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob)
  return jnp.mean(jnp.abs(probs_ratio - 1) > self._epsilon)

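# For reference, a self-contained sketch of the same diagnostic without the
# Trax layer machinery. Assumptions: `new_log_probs` and `old_log_probs` are
# log-probabilities of the taken actions, and `clip_fraction_sketch` is a
# hypothetical name; the log-space ratio below presumably mirrors what
# rl_layers.ProbsRatio computes from distribution parameters.
import jax.numpy as jnp

def clip_fraction_sketch(new_log_probs, old_log_probs, epsilon=0.1):
  # r = pi_new(a|s) / pi_old(a|s), computed in log space for stability.
  probs_ratio = jnp.exp(new_log_probs - old_log_probs)
  # Fraction of time steps where the ratio leaves the PPO clipping band
  # [1 - epsilon, 1 + epsilon].
  return jnp.mean(jnp.abs(probs_ratio - 1) > epsilon)
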
def forward(self, inputs):
  gamma, beta, epsilon_l = self.weights

  epsilon = self._init_epsilon
  if epsilon_l is not base.EMPTY_WEIGHTS:
    epsilon += jnp.abs(epsilon_l[0])

  # Omit B and C
  axis = tuple(range(1, len(jnp.shape(inputs)) - 1))
  # (B, 1, 1, C)
  nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True)
  # (B, W, H, C)
  xhat = inputs / jnp.sqrt(nu2 + epsilon)
  return gamma * xhat + beta

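# The forward above implements Filter Response Normalization-style scaling:
# activations are divided by their root mean square over the spatial axes,
# with a (possibly learnable) epsilon for stability. A minimal standalone
# sketch on an NHWC batch; the function name and concrete shapes are
# illustrative, not from the original layer.
import jax.numpy as jnp

def filter_response_norm_sketch(x, gamma, beta, epsilon=1e-6):
  # Mean square over the spatial axes (W, H), keeping shape (B, 1, 1, C).
  nu2 = jnp.mean(x**2, axis=(1, 2), keepdims=True)
  xhat = x / jnp.sqrt(nu2 + epsilon)
  return gamma * xhat + beta

# Example: per-channel gamma/beta broadcast over a (B, W, H, C) batch.
y = filter_response_norm_sketch(
    jnp.ones((2, 4, 4, 8)), gamma=jnp.ones(8), beta=jnp.zeros(8))
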
def _aggregate_values(self, values, aggregate, act_log_probs):
  # Normalize the Q-values before aggregation, so it can adapt to the scale
  # of the returns. This does not affect mean and max aggregation.
  scale = 1
  epsilon = 1e-5
  if self._q_value_normalization == 'std':
    scale = jnp.std(values) + epsilon
  elif self._q_value_normalization == 'abs':
    scale = jnp.mean(jnp.abs(values - jnp.mean(values))) + epsilon
  values /= scale

  temp = self._q_value_temperature
  if self._q_value:
    assert values.shape[:2] == (
        self._value_batch_size, self._q_value_n_samples)
    if aggregate == 'max':
      # max_a Q(s, a)
      values = jnp.max(values, axis=1)
    elif aggregate == 'softmax':
      # sum_a (Q(s, a) * w(s, a))
      # where w(s, .) = softmax (Q(s, .) / T)
      weights = tl.Softmax(axis=1)(values / temp)
      values = jnp.sum(values * weights, axis=1)
    elif aggregate == 'logsumexp':
      # log(mean_a exp(Q(s, a) / T)) * T
      n = values.shape[1]
      values = (fastmath.logsumexp(values / temp, axis=1) - jnp.log(n)) * temp
    else:
      assert aggregate == 'mean'
      # mean_a Q(s, a)
      if self._sample_all_discrete_actions:
        values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
      else:
        values = jnp.mean(values, axis=1)

  # Re-scale the Q-values after aggregation.
  values *= scale

  return np.array(values)  # Move the values to CPU.

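# Sanity check for the 'logsumexp' branch: (logsumexp(Q/T) - log(n)) * T is
# exactly T * log(mean_a exp(Q(s, a) / T)), which approaches mean aggregation
# for large T and max aggregation for small T. A quick standalone
# illustration, using jax.scipy directly rather than the trainer's fastmath
# wrapper:
import jax.numpy as jnp
from jax.scipy.special import logsumexp

q = jnp.array([[1.0, 2.0, 3.0]])
n = q.shape[1]
for temp in (100.0, 1.0, 0.01):
  print(temp, (logsumexp(q / temp, axis=1) - jnp.log(n)) * temp)
# ~2.0 at T=100 (mean), ~2.31 at T=1, ~3.0 at T=0.01 (max)
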
def favor(query, key, value):
  query_prime = relu(query) + numerical_stabilizer
  key_prime = relu(key) + numerical_stabilizer
  prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1])
  t_slice_shape = (key.shape[0], key.shape[-1])
  init_prefix_sum_value_numerator = np.zeros(prefix_sum_tensor_shape)
  init_prefix_sum_value_denominator = np.zeros(t_slice_shape)

  # The scans run over the time axis, so move it to the front.
  w = favor_numerator(init_prefix_sum_value_numerator, precision,
                      np.moveaxis(query_prime, 1, 0),
                      np.moveaxis(key_prime, 1, 0),
                      np.moveaxis(value, 1, 0))
  r = favor_denominator(init_prefix_sum_value_denominator, precision,
                        np.moveaxis(query_prime, 1, 0),
                        np.moveaxis(key_prime, 1, 0))
  w = np.moveaxis(w, 0, 1)
  r = np.moveaxis(r, 0, 1)
  # Keep the denominator away from zero: entries whose magnitude is at most
  # the stabilizer are shifted before taking the reciprocal.
  r = r + 2 * numerical_stabilizer * (np.abs(r) <= numerical_stabilizer)
  r = np.reciprocal(r)
  r = np.expand_dims(r, len(r.shape))
  renormalized_attention = w * r
  return renormalized_attention

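# favor_numerator and favor_denominator are not shown above; they scan over
# the time axis accumulating prefix sums of k'v^T and k'. A compact causal
# sketch of the same linear-attention computation with explicit cumulative
# sums (the function name and the (B, L, D) layout are assumptions, and the
# materialized cumsum trades the scan's memory savings for brevity):
import numpy as np

def favor_sketch(query, key, value, stabilizer=1e-6):
  q = np.maximum(query, 0) + stabilizer   # relu + numerical stabilizer
  k = np.maximum(key, 0) + stabilizer
  # Running sums over time replace the explicit L x L attention matrix:
  # num[t] = sum_{s<=t} k[s] v[s]^T, den[t] = sum_{s<=t} k[s].
  num = np.cumsum(np.einsum('bld,ble->blde', k, value), axis=1)
  den = np.cumsum(k, axis=1)
  w = np.einsum('bld,blde->ble', q, num)
  r = np.einsum('bld,bld->bl', q, den)   # strictly positive here
  return w / r[..., None]
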
def loss(values, targets, weights):
  return jnp.sum(jnp.abs(values - targets) * weights) / jnp.sum(weights)

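# Example: the weights act as a mask, so padded positions drop out of this
# weighted L1 loss entirely.
import jax.numpy as jnp

values = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.5, 2.0, 0.0])
weights = jnp.array([1.0, 1.0, 0.0])  # mask out the padded last position
print(loss(values, targets, weights))
# -> (0.5 * 1 + 0.0 * 1 + 3.0 * 0) / 2 = 0.25
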
def SaturationCost(x, limit=0.9):
  # Penalize activations whose magnitude exceeds the limit; values inside
  # [-limit, limit] incur no cost.
  return jnp.maximum(0, jnp.abs(x) - limit)

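# Example (assuming the maximum form above): values inside [-0.9, 0.9] are
# free, while magnitude beyond the limit is charged linearly.
import jax.numpy as jnp

print(SaturationCost(jnp.array([0.5, -1.2, 0.95])))
# -> approx. [0.0, 0.3, 0.05]
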
def test_lsh_and_pure_lsh_self_attention_equivalence(self):
  # Given the same weight matrices and random numbers, do these produce
  # the same output?
  with fastmath.use_backend(fastmath.Backend.JAX):
    n_heads = 4
    d_head = 4
    d_model = n_heads * d_head
    pure_lsh_layer = efficient_attention.PureLSHSelfAttention(
        n_heads=n_heads,
        d_qk=d_head,
        d_v=d_head,
        causal=True,
        masked=False,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=4,
        n_buckets=8,
        use_reference_code=False,
        attention_dropout=0.0,
        use_python_loop=True,
        bias=False,
        mode='train')
    lsh_layer = efficient_attention.LSHSelfAttention(
        n_heads=n_heads,
        d_qk=d_head,
        d_v=d_head,
        causal=True,
        masked=False,
        chunk_len=8,
        n_chunks_before=1,
        n_chunks_after=0,
        n_hashes=4,
        n_buckets=8,
        use_reference_code=False,
        attention_dropout=0.0,
        use_python_loop=True,
        mode='train')

    batch, seqlen = 3, 32
    input_shape = (batch, seqlen, d_model)

    x = jax.random.uniform(jax.random.PRNGKey(0), input_shape,
                           dtype=jnp.float32)
    lsh_layer_input = x

    call_rng = jax.random.PRNGKey(42)

    lsh_layer_weights, lsh_layer_state = lsh_layer.init(
        shapes.signature(lsh_layer_input))
    lsh_layer.rng = call_rng
    lsh_layer_output = lsh_layer(lsh_layer_input)

    # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head),
    # (n_heads, d_head, d_model)
    # Abbreviated as - hmn, hmn, hnm
    w_qk, w_v, w_o = lsh_layer_weights

    qk = jnp.einsum('blm,hmn->bhln', x, w_qk)
    qk = qk.reshape((-1, qk.shape[2], qk.shape[3]))

    v = jnp.einsum('blm,hmn->bhln', x, w_v)
    v = v.reshape((-1, v.shape[2], v.shape[3]))

    pure_lsh_layer_input = (qk, v)
    _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input))
    pure_lsh_layer.rng = call_rng
    pure_lsh_layer.state = lsh_layer_state
    pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input)

    # b*h,l,n
    pure_lsh_layer_output = pure_lsh_layer_output.reshape(
        (batch, -1) + pure_lsh_layer_output.shape[1:])
    pure_lsh_layer_output_projected = (
        jnp.einsum('bhld,hdm->blm', pure_lsh_layer_output, w_o))

    diff = pure_lsh_layer_output_projected - lsh_layer_output
    avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff))

    self.assertLess(avg_diff, 1e-5)