def apply(n, dx, dz):
    """Build three symbolic statements about an attention-style construction.

    The construction defines ``Q = x @ W^Q`` and ``K = x @ W^K``, two
    tables ``a^K`` / ``a^V`` that index ``w^K`` / ``w^V`` by
    ``k + clip(j - i, -k, k)``, scores ``e`` scaled by ``sqrt(dz)``,
    ``α = softmax(e)`` and the output ``z``.

    Args:
        n: first dimension of ``x`` and the bound used in the ``LAMBDA`` /
            ``Sum`` limits.
        dx: second dimension of ``x`` and first dimension of the ``W``
            matrices.
        dz: second dimension of the ``W`` matrices and of ``w^K`` / ``w^V``.

    Returns:
        A 3-tuple ``(Contains, Equality, Equality)``:
        the clipped index lies in ``Interval(0, 2 * k, integer=True)``;
        ``z[i]`` equals ``Sum[j:n](α[i, j] * (x[j] @ W^V + a^V[i, j]))``;
        ``e[i, j]`` equals ``(x[i] @ W^Q @ (x[j] @ W^K + a^K[i, j])) / sqrt(dz)``.
    """
    # Input and the three projection matrices.
    x = Symbol.x(shape=(n, dx), real=True)
    W_Q = Symbol("W^Q", shape=(dx, dz), real=True)
    W_K = Symbol("W^K", shape=(dx, dz), real=True)
    W_V = Symbol("W^V", shape=(dx, dz), real=True)

    Q = Symbol.Q(definition=x @ W_Q)
    K = Symbol.K(definition=x @ W_K)

    # Index symbols; k is the clipping bound.
    i = Symbol.i(integer=True)
    j = Symbol.j(integer=True)
    k = Symbol.k(integer=True, positive=True)

    # Tables with 2k + 1 rows, addressed through the clipped offset below.
    w_K = Symbol("w^K", shape=(2 * k + 1, dz), real=True)
    w_V = Symbol("w^V", shape=(2 * k + 1, dz), real=True)

    # Shared row index into w^K / w^V; also the subject of the first
    # returned statement.
    offset = k + clip(j - i, -k, k)

    a_K = Symbol("a^K", definition=LAMBDA[j:n, i:n](w_K[offset]))
    a_V = Symbol("a^V", definition=LAMBDA[j:n, i:n](w_V[offset]))

    scale = sqrt(dz)

    e = Symbol.e(
        definition=(Q @ K.T + LAMBDA[i:n](Q[i] @ a_K[i].T)) / scale
    )
    alpha = Symbol.α(definition=softmax(e))
    z = Symbol.z(
        shape=(n, dz),
        definition=alpha @ (x @ W_V) + LAMBDA[i:n](alpha[i] @ a_V[i]),
    )

    index_in_range = Contains(offset, Interval(0, 2 * k, integer=True))
    z_row = Equality(
        z[i],
        Sum[j:n](alpha[i, j] * (x[j] @ W_V + a_V[i, j])),
    )
    e_entry = Equality(
        e[i, j],
        (x[i] @ W_Q @ (x[j] @ W_K + a_K[i, j])) / scale,
    )
    return index_in_range, z_row, e_entry
def apply(n, d):
    """State that row 0 of a softmax-attention product can be taken row-wise.

    Defines real symbols ``Q``, ``K``, ``V`` of shape ``(n, d)`` and
    ``S = softmax(Q @ K.T / sqrt(d)) @ V``.

    Args:
        n: first dimension of ``Q``, ``K``, ``V`` and ``S``.
        d: second dimension of those symbols; also the argument of the
            square-root scaling.

    Returns:
        ``Equality(S[0], softmax(Q[0] @ K.T / sqrt(d)) @ V)``.
    """
    matrix_shape = (n, d)
    Q = Symbol.Q(shape=matrix_shape, real=True)
    K = Symbol.K(shape=matrix_shape, real=True)
    V = Symbol.V(shape=matrix_shape, real=True)

    scale = sympy.sqrt(d)
    keys_t = K.T

    # `@` and `/` share precedence and associate left-to-right, so the
    # product is formed first and then divided; the parentheses only make
    # that explicit.
    full = softmax((Q @ keys_t) / scale) @ V
    S = Symbol.S(shape=matrix_shape, definition=full)

    first_row = softmax((Q[0] @ keys_t) / scale) @ V
    return Equality(S[0], first_row)
def predefined_symbols(n):
    """Create the symbols ``Q``, ``w`` and ``x`` shared by callers.

    * ``x`` is a nonnegative-integer symbol of shape ``(oo,)``.
    * ``Q`` is a ``LAMBDA`` over ``t`` (limit ``t:n + 1``) of the
      ``conditionset`` of ``x[:n + 1]`` whose set comprehension equals
      ``Interval(0, n, integer=True)`` and whose entry ``x[n]`` equals ``t``.
      NOTE(review): this reads as "arrangements of 0..n ending in t" --
      confirm against the callers.
    * ``w`` is a ``LAMBDA`` over ``j`` and ``i`` (both with limit ``n + 1``)
      of ``Swap(n + 1, i, j)``.

    Args:
        n: size parameter; every limit and slice uses ``n + 1``.

    Returns:
        The tuple ``(Q, w, x)``.
    """
    size = n + 1

    x = Symbol.x(shape=(oo,), integer=True, nonnegative=True)
    t = Symbol.t(integer=True)

    head = x[:size]
    covers_range = Equality(
        head.set_comprehension(),
        Interval(0, n, integer=True),
    )
    ends_with_t = Equality(x[n], t)
    Q = Symbol.Q(
        definition=LAMBDA[t:size](conditionset(head, covers_range & ends_with_t))
    )

    j = Symbol.j(integer=True)
    i = Symbol.i(integer=True)
    w = Symbol.w(definition=LAMBDA[j:size, i:size](Swap(size, i, j)))

    return Q, w, x