def apply(n, dx, dz):
    """Declare an attention-style symbolic model and return three relations.

    The function declares a real input ``x`` of shape ``(n, dx)``, three
    real ``(dx, dz)`` matrices ``W^Q``/``W^K``/``W^V``, and two real lookup
    tables ``w^K``/``w^V`` of shape ``(2k + 1, dz)``.  The tables are
    indexed with ``k + clip(j - i, -k, k)`` to define ``a^K`` and ``a^V``.
    NOTE(review): this looks like self-attention with clipped relative
    position terms -- confirm against the accompanying proof.

    Args:
        n: first dimension of ``x``; also the bound used in every
            ``LAMBDA``/``Sum`` limit here.
        dx: second dimension of ``x`` and first dimension of the ``W^*``
            matrices.
        dz: second dimension of the ``W^*`` matrices and of the tables.

    Returns:
        A 3-tuple of relations:

        1. ``k + clip(j - i, -k, k)`` is contained in the integer
           interval ``Interval(0, 2k)``;
        2. ``z[i] == Sum[j:n](α[i, j] * (x[j] @ W^V + a^V[i, j]))``;
        3. ``e[i, j] == (x[i] @ W^Q @ (x[j] @ W^K + a^K[i, j])) / sqrt(dz)``.
    """
    # Input and the three projection matrices.
    inputs = Symbol.x(shape=(n, dx), real=True)
    proj_q = Symbol("W^Q", shape=(dx, dz), real=True)
    proj_k = Symbol("W^K", shape=(dx, dz), real=True)
    proj_v = Symbol("W^V", shape=(dx, dz), real=True)

    queries = Symbol.Q(definition=inputs @ proj_q)
    keys = Symbol.K(definition=inputs @ proj_k)

    # Integer index symbols used as bound variables below.
    row = Symbol.i(integer=True)
    col = Symbol.j(integer=True)

    # Lookup tables with 2k + 1 rows each.
    radius = Symbol.k(integer=True, positive=True)
    table_rows = 2 * radius + 1
    key_table = Symbol("w^K", shape=(table_rows, dz), real=True)
    value_table = Symbol("w^V", shape=(table_rows, dz), real=True)

    # Table row selected for a pair (i, j): j - i clipped to [-k, k],
    # then shifted by k.
    slot = radius + clip(col - row, -radius, radius)

    rel_keys = Symbol("a^K", definition=LAMBDA[col:n, row:n](key_table[slot]))
    rel_values = Symbol("a^V", definition=LAMBDA[col:n, row:n](value_table[slot]))

    scale = sqrt(dz)
    rel_scores = LAMBDA[row:n](queries[row] @ rel_keys[row].T)
    scores = Symbol.e(definition=(queries @ keys.T + rel_scores) / scale)
    weights = Symbol.α(definition=softmax(scores))

    rel_output = LAMBDA[row:n](weights[row] @ rel_values[row])
    output = Symbol.z(
        shape=(n, dz),
        definition=weights @ (inputs @ proj_v) + rel_output,
    )

    # The three relations handed back to the caller.
    slot_in_range = Contains(slot, Interval(0, 2 * radius, integer=True))
    output_row = Equality(
        output[row],
        Sum[col:n](weights[row, col] * (inputs[col] @ proj_v + rel_values[row, col])),
    )
    score_entry = Equality(
        scores[row, col],
        (inputs[row] @ proj_q @ (inputs[col] @ proj_k + rel_keys[row, col])) / scale,
    )
    return slot_in_range, output_row, score_entry
Esempio n. 2
0
def apply(n, d):
    """Declare ``S = softmax(Q @ K.T / sqrt(d)) @ V`` and state its first row.

    Args:
        n: first dimension of the real symbols ``Q``, ``K`` and ``V``.
        d: second dimension of ``Q``, ``K`` and ``V``; also the argument
            of the square root that divides the score matrix.

    Returns:
        The relation ``S[0] == softmax(Q[0] @ K.T / sqrt(d)) @ V``.
    """
    common = dict(shape=(n, d), real=True)
    queries = Symbol.Q(**common)
    keys = Symbol.K(**common)
    values = Symbol.V(**common)

    scale = sympy.sqrt(d)

    attended = Symbol.S(
        shape=(n, d),
        definition=softmax(queries @ keys.T / scale) @ values,
    )

    first_row = softmax(queries[0] @ keys.T / scale) @ values
    return Equality(attended[0], first_row)
Esempio n. 3
0
def predefined_symbols(n):
    """Build the shared symbols ``Q``, ``w`` and ``x`` for a given ``n``.

    * ``x`` is a nonnegative integer symbol of shape ``(oo,)``.
    * ``Q`` is defined, for ``t`` bounded by ``n + 1``, as the condition
      set over ``x[:n + 1]`` requiring that the set comprehension of
      ``x[:n + 1]`` equals the integer interval ``Interval(0, n)`` and
      that ``x[n] == t``.  NOTE(review): presumably the arrangements of
      ``0..n`` whose last entry is ``t`` -- confirm.
    * ``w`` is defined, for ``i`` and ``j`` each bounded by ``n + 1``, as
      ``Swap(n + 1, i, j)``.

    Args:
        n: size parameter; every bound in this function is ``n + 1``.

    Returns:
        The tuple ``(Q, w, x)``.
    """
    size = n + 1

    seq = Symbol.x(shape=(oo, ), integer=True, nonnegative=True)
    last = Symbol.t(integer=True)

    # Conditions placed on the leading n + 1 entries of x.
    head = seq[:size]
    covers_range = Equality(head.set_comprehension(), Interval(0, n, integer=True))
    ends_with_last = Equality(seq[n], last)

    family = Symbol.Q(
        definition=LAMBDA[last:size](conditionset(head, covers_range & ends_with_last)),
    )

    col = Symbol.j(integer=True)
    row = Symbol.i(integer=True)
    swaps = Symbol.w(definition=LAMBDA[col:size, row:size](Swap(size, row, col)))

    return family, swaps, seq