Code example #1
File: sparsity.py  Project: viragumathe5/trax
  def favor(query, key, value, mask):
    # Non-negative feature map for queries and keys; the small
    # numerical_stabilizer keeps the later normalization well defined.
    query_prime = relu(query) + numerical_stabilizer
    key_prime = relu(key) + numerical_stabilizer
    # Broadcast the padding mask from (batch, length) to
    # (batch * n_heads, length, 1) and zero out the key features of padded
    # positions so they contribute nothing to the attention sums.
    mask_batch_1_length = jnp.reshape(
        mask, [key.shape[0] // n_heads, 1, key.shape[1]]).astype(jnp.float32)
    mask_heads = mask_batch_1_length + jnp.zeros((1, n_heads, 1))
    key_prime *= jnp.reshape(mask_heads, [key.shape[0], key.shape[1], 1])

    # Numerator and denominator of the FAVOR attention; the helpers expect
    # the sequence axis first, so it is moved from position 1 to position 0.
    w = bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0),
                                jnp.moveaxis(key_prime, 1, 0),
                                jnp.moveaxis(value, 1, 0))
    r = bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0),
                                  jnp.moveaxis(key_prime, 1, 0))
    # Move the sequence axis back and normalize the numerator:
    # attention = numerator / denominator.
    w = jnp.moveaxis(w, 0, 1)
    r = jnp.moveaxis(r, 0, 1)
    r = jnp.reciprocal(r)
    r = jnp.expand_dims(r, len(r.shape))
    renormalized_attention = w * r
    return renormalized_attention, mask
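
This snippet closes over names defined elsewhere in sparsity.py (relu, numerical_stabilizer, n_heads, bidirectional_numerator, bidirectional_denominator). As orientation only, a minimal sketch of what the two bidirectional helpers compute, assuming inputs with the sequence axis first and shapes (length, batch, features) for the query/key features and (length, batch, d_value) for the values, could look like this; it is not the project's actual implementation:

import jax.numpy as jnp

def bidirectional_numerator(query_prime, key_prime, value):
  # Contract keys with values over the sequence axis, then contract the
  # result with each query: linear in length instead of the quadratic
  # cost of explicit softmax attention.
  kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value)
  return jnp.einsum('lbm,bmd->lbd', query_prime, kvs)

def bidirectional_denominator(query_prime, key_prime):
  # Per-query normalizer: each query's features dotted with the sum of
  # all key features along the sequence axis.
  key_sums = jnp.sum(key_prime, axis=0)
  return jnp.einsum('lbm,bm->lb', query_prime, key_sums)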
Code example #2
File: sparsity.py  Project: piotrekp1/trax
    def favor(query, key, value):
        # Non-negative feature map for queries and keys; the stabilizer
        # avoids division by zero in the normalization step.
        query_prime = relu(query) + numerical_stabilizer
        key_prime = relu(key) + numerical_stabilizer
        # Initial prefix sums carried through the causal scan:
        # a (batch, features, d_value) tensor for the numerator and a
        # (batch, features) tensor for the denominator.
        prefix_sum_tensor_shape = (key.shape[0], key.shape[-1],
                                   value.shape[-1])
        t_slice_shape = (key.shape[0], key.shape[-1])
        init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape)
        init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape)

        # Causal numerator and denominator; the helpers scan over the
        # sequence axis, so it is moved to the front first.
        w = favor_numerator(init_prefix_sum_value_numerator, precision,
                            jnp.moveaxis(query_prime, 1, 0),
                            jnp.moveaxis(key_prime, 1, 0),
                            jnp.moveaxis(value, 1, 0))
        r = favor_denominator(init_prefix_sum_value_denominator, precision,
                              jnp.moveaxis(query_prime, 1, 0),
                              jnp.moveaxis(key_prime, 1, 0))
        # Restore the original axis order and normalize the numerator.
        w = jnp.moveaxis(w, 0, 1)
        r = jnp.moveaxis(r, 0, 1)
        r = jnp.reciprocal(r)
        r = jnp.expand_dims(r, len(r.shape))
        renormalized_attention = w * r
        return renormalized_attention
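
Here favor_numerator and favor_denominator also come from the enclosing module and, in Trax, carry custom gradients. Ignoring that and the precision argument, a minimal sketch of the causal prefix-sum computation they perform, under the same shape assumptions as above, could look like:

import jax
import jax.numpy as jnp

def favor_numerator(init_prefix_sum, precision, query_prime, key_prime, value):
  # Causal attention numerator: at step t, accumulate the outer product of
  # k_t and v_t into the running prefix sum and contract it with q_t.
  del precision  # the sketch ignores the einsum precision setting
  def step(prefix_sum, qkv):
    q, k, v = qkv
    prefix_sum = prefix_sum + jnp.einsum('bm,bd->bmd', k, v)
    return prefix_sum, jnp.einsum('bm,bmd->bd', q, prefix_sum)
  _, ws = jax.lax.scan(step, init_prefix_sum, (query_prime, key_prime, value))
  return ws

def favor_denominator(init_prefix_sum, precision, query_prime, key_prime):
  # Causal normalizer: running sum of key features dotted with each query.
  del precision
  def step(prefix_sum, qk):
    q, k = qk
    prefix_sum = prefix_sum + k
    return prefix_sum, jnp.einsum('bm,bm->b', q, prefix_sum)
  _, rs = jax.lax.scan(step, init_prefix_sum, (query_prime, key_prime))
  return rs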