Example #1
0
 def ClipFraction(dist_inputs, actions, old_log_probs):
     """Fraction of PPO probability ratios that fall outside the clip range.

     The ratio is obtained from ``rl_layers.ProbsRatio`` using
     ``old_log_probs`` and ``self._policy_dist.log_prob`` (``self`` is
     captured from the enclosing scope). The result is the mean of the
     indicator ``|ratio - 1| > self._epsilon``, i.e. the share of elements
     whose ratio lies outside ``[1 - epsilon, 1 + epsilon]``.
     """
     ratio = rl_layers.ProbsRatio(
         dist_inputs, actions, old_log_probs,
         log_prob_fun=self._policy_dist.log_prob)
     # Strict inequality: ratios exactly on the boundary count as inside.
     outside_clip_range = jnp.abs(ratio - 1) > self._epsilon
     return jnp.mean(outside_clip_range)
Example #2
0
    def forward(self, inputs):
        """Normalize by the mean square over inner axes, then apply gamma/beta.

        Every axis except the first and the last is reduced when computing the
        mean of squares, so for a 4-D ``(B, W, H, C)`` input the statistic has
        shape ``(B, 1, 1, C)`` and broadcasts back over ``inputs``.
        """
        gamma, beta, epsilon_l = self.weights

        # Start from the fixed epsilon; if a learned epsilon weight exists,
        # add its absolute value so the total never drops below the init.
        epsilon = self._init_epsilon
        if epsilon_l is not base.EMPTY_WEIGHTS:
            epsilon += jnp.abs(epsilon_l[0])

        ndim = len(jnp.shape(inputs))
        inner_axes = tuple(range(1, ndim - 1))  # skip first and last axes
        mean_square = jnp.mean(inputs**2, axis=inner_axes, keepdims=True)
        normalized = inputs / jnp.sqrt(mean_square + epsilon)

        return gamma * normalized + beta
Example #3
0
    def _aggregate_values(self, values, aggregate, act_log_probs):
        """Aggregate (sampled) Q-values over the action-sample axis.

        Values are divided by a normalization scale chosen by
        ``self._q_value_normalization`` ('std', 'abs', or any other value for
        no normalization), optionally aggregated, and then multiplied back by
        the same scale.

        Args:
          values: Array of values. When ``self._q_value`` is set, its leading
            dimensions must be ``(self._value_batch_size,
            self._q_value_n_samples)``; axis 1 is aggregated over.
          aggregate: Aggregation mode, one of 'max', 'softmax', 'logsumexp'
            or 'mean'. Only consulted when ``self._q_value`` is set.
          act_log_probs: Action log-probabilities. Only used by 'mean' when
            ``self._sample_all_discrete_actions`` is set, where
            ``exp(act_log_probs)`` weights the values.

        Returns:
          A NumPy array (copied to host memory). When ``self._q_value`` is set,
          it holds the aggregated values. Otherwise it holds ``values`` divided
          and then re-multiplied by the scale.

        Raises:
          ValueError: if ``self._q_value`` is set and ``aggregate`` is not a
            supported mode.
        """
        # Normalize the Q-values before aggregation, so it can adapt to the
        # scale of the returns. Dividing by a positive scale and multiplying
        # back afterwards does not affect mean and max aggregation; softmax
        # and logsumexp do depend on it, since they operate on values / temp.
        scale = 1
        epsilon = 1e-5  # Keeps the scale strictly positive (no div by zero).
        if self._q_value_normalization == 'std':
            scale = jnp.std(values) + epsilon
        elif self._q_value_normalization == 'abs':
            # Mean absolute deviation from the mean.
            scale = jnp.mean(jnp.abs(values - jnp.mean(values))) + epsilon
        values /= scale

        temp = self._q_value_temperature
        if self._q_value:
            # Internal layout invariant: (batch, n_samples, ...).
            assert values.shape[:2] == (self._value_batch_size,
                                        self._q_value_n_samples)
            if aggregate == 'max':
                # max_a Q(s, a)
                values = jnp.max(values, axis=1)
            elif aggregate == 'softmax':
                # sum_a (Q(s, a) * w(s, a))
                # where w(s, .) = softmax (Q(s, .) / T)
                weights = tl.Softmax(axis=1)(values / temp)
                values = jnp.sum(values * weights, axis=1)
            elif aggregate == 'logsumexp':
                # log(mean_a exp(Q(s, a) / T)) * T
                n = values.shape[1]
                values = (fastmath.logsumexp(values / temp, axis=1) -
                          jnp.log(n)) * temp
            elif aggregate == 'mean':
                # mean_a Q(s, a)
                if self._sample_all_discrete_actions:
                    values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
                else:
                    values = jnp.mean(values, axis=1)
            else:
                # An explicit error instead of `assert`, which `python -O`
                # strips and would let unknown modes silently act as 'mean'.
                raise ValueError(
                    f'Unknown aggregate {aggregate!r}; expected one of '
                    "'max', 'softmax', 'logsumexp' or 'mean'.")

        # Re-scale the Q-values after aggregation.
        values *= scale
        return np.array(values)  # Move the values to CPU.
Example #4
0
    def favor(query, key, value):
        """Prefix-sum (FAVOR-style) attention over positive feature maps.

        Relies on the enclosing scope for ``relu``, ``numerical_stabilizer``,
        ``precision``, ``favor_numerator`` and ``favor_denominator``. Axis 1 of
        the inputs is moved to the front before the numerator/denominator
        helpers are called and moved back afterwards. NOTE(review): this
        presumably means the helpers scan over the sequence axis; confirm.
        """
        query_prime = relu(query) + numerical_stabilizer
        key_prime = relu(key) + numerical_stabilizer

        batch_dim, feature_dim = key.shape[0], key.shape[-1]
        numerator_init = np.zeros((batch_dim, feature_dim, value.shape[-1]))
        denominator_init = np.zeros((batch_dim, feature_dim))

        # Swap axes 0 and 1 once and reuse the results for both helpers.
        q_swapped = np.moveaxis(query_prime, 1, 0)
        k_swapped = np.moveaxis(key_prime, 1, 0)
        v_swapped = np.moveaxis(value, 1, 0)

        numerator = favor_numerator(numerator_init, precision,
                                    q_swapped, k_swapped, v_swapped)
        denominator = favor_denominator(denominator_init, precision,
                                        q_swapped, k_swapped)
        numerator = np.moveaxis(numerator, 0, 1)
        denominator = np.moveaxis(denominator, 0, 1)

        # Nudge near-zero denominators away from zero before inverting.
        near_zero = np.abs(denominator) <= numerical_stabilizer
        denominator = denominator + 2 * numerical_stabilizer * near_zero
        # Trailing singleton axis so the inverse broadcasts over numerator.
        inv_denominator = np.expand_dims(np.reciprocal(denominator), -1)
        return numerator * inv_denominator
Example #5
0
 def loss(values, targets, weights):
   """Weighted mean absolute error: sum(w * |v - t|) / sum(w)."""
   abs_error = jnp.abs(values - targets)
   total_weight = jnp.sum(weights)
   return jnp.sum(abs_error * weights) / total_weight
Example #6
0
def SaturationCost(x, limit=0.9):
  """Elementwise ``min(0, |x| - limit)``.

  The result is never positive: it equals ``|x| - limit`` where
  ``|x| < limit`` and is 0 wherever ``|x| >= limit``.

  NOTE(review): a penalty on saturation is more commonly
  ``max(0, |x| - limit)``; confirm the intended sign with callers.
  """
  margin = jnp.abs(x) - limit
  return jnp.minimum(0, margin)
Example #7
0
    def test_lsh_and_pure_lsh_self_attention_equivalence(self):
        """LSHSelfAttention and PureLSHSelfAttention agree given shared weights.

        The full LSH layer is run on random inputs. Its Q/K and V projections
        are then applied by hand and fed to the pure layer, which is given
        the same rng and the LSH layer's state. The pure layer's output is
        projected with the LSH layer's output weights and compared.
        """
        with fastmath.use_backend(fastmath.Backend.JAX):
            n_heads, d_head = 4, 4
            d_model = n_heads * d_head
            # Hyper-parameters common to both layers; only the pure layer
            # additionally receives bias=False.
            shared_kwargs = dict(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=True,
                masked=False,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=4,
                n_buckets=8,
                use_reference_code=False,
                attention_dropout=0.0,
                use_python_loop=True,
                mode='train',
            )
            pure_lsh_layer = efficient_attention.PureLSHSelfAttention(
                bias=False, **shared_kwargs)
            lsh_layer = efficient_attention.LSHSelfAttention(**shared_kwargs)

            batch, seqlen = 3, 32
            x = jax.random.uniform(jax.random.PRNGKey(0),
                                   (batch, seqlen, d_model),
                                   dtype=jnp.float32)
            call_rng = jax.random.PRNGKey(42)

            lsh_weights, lsh_state = lsh_layer.init(shapes.signature(x))
            lsh_layer.rng = call_rng
            lsh_output = lsh_layer(x)

            # Weight shapes (einsum letters): w_qk and w_v are
            # (n_heads, d_model, d_head) 'hmn'; w_o is (n_heads, d_head,
            # d_model) 'hnm'.
            w_qk, w_v, w_o = lsh_weights

            def project_heads(w):
                # (b, l, m) x (h, m, n) -> (b, h, l, n) -> (b*h, l, n)
                per_head = jnp.einsum('blm,hmn->bhln', x, w)
                return per_head.reshape((-1,) + per_head.shape[2:])

            qk = project_heads(w_qk)
            v = project_heads(w_v)

            pure_input = (qk, v)
            _, _ = pure_lsh_layer.init(shapes.signature(pure_input))
            pure_lsh_layer.rng = call_rng
            pure_lsh_layer.state = lsh_state
            pure_output = pure_lsh_layer(pure_input)

            # (b*h, l, n) -> (b, h, l, n), then merge heads into d_model.
            pure_output = pure_output.reshape(
                (batch, -1) + pure_output.shape[1:])
            pure_projected = jnp.einsum('bhld,hdm->blm', pure_output, w_o)

            diff = pure_projected - lsh_output
            avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff))

            self.assertLess(avg_diff, 1e-5)