Пример #1
0
from sympy.functions.elementary.miscellaneous import Min, Max
from sympy.core.function import Function


def clip(a, a_min, a_max):
    """Evaluate ``clip``: bound ``a`` below by ``a_min``, then above by ``a_max``.

    Computed as ``Min(a_max, Max(a, a_min))``. Because the lower bound is
    applied first, comparable inputs with ``a_min > a_max`` yield ``a_max``.

    Args:
        a: value to clip.
        a_min: lower bound.
        a_max: upper bound.

    Returns:
        ``Min(a_max, Max(a, a_min))``.
    """
    return Min(a_max, Max(a, a_min))


# Rebind ``clip`` to a symbolic Function whose evaluator is the def above.
# ``nargs=(3,)`` states the three-argument arity the evaluator requires.
# This matches the other ``clip`` registration in this file.
# ``shape=()`` presumably marks the result as a scalar (TODO confirm against
# this Function API).
clip = Function.clip(nargs=(3,), shape=(), eval=clip)


from axiom.utility import plausible
from sympy.core.relational import Equality
from sympy.functions.elementary.exponential import softmax
from sympy import Symbol
from sympy.functions.elementary.miscellaneous import sqrt, Min, Max
from sympy.matrices.expressions.matmul import MatMul
from sympy.concrete.summations import Sum
from sympy.concrete.expr_with_limits import LAMBDA
from sympy.core.function import Function
from sympy.sets.contains import Contains
from sympy.sets.sets import Interval

# Symbolic ``clip(a, a_min, a_max)``: evaluates to ``Min(a_max, Max(a, a_min))``.
# That is, ``a`` is bounded below by ``a_min`` and then above by ``a_max``.
# ``nargs=(3,)`` presumably fixes the arity at three arguments.
# ``shape=()`` presumably marks the result as a scalar (TODO confirm both
# against this Function API).
clip = Function.clip(
    nargs=(3,),
    shape=(),
    eval=lambda a, a_min, a_max: Min(a_max, Max(a, a_min)),
)

@plausible
def apply(n, dx, dz):
    x = Symbol.x(shape=(n, dx), real=True)
    W_Q = Symbol("W^Q", shape=(dx, dz), real=True)
    W_K = Symbol("W^K", shape=(dx, dz), real=True)
    W_V = Symbol("W^V", shape=(dx, dz), real=True)
    
    Q = Symbol.Q(definition=x @ W_Q)
    K = Symbol.K(definition=x @ W_K)
    
    i = Symbol.i(integer=True)
    j = Symbol.j(integer=True)
    
    k = Symbol.k(integer=True, positive=True)
    w_K = Symbol("w^K", shape=(2 * k + 1, dz), real=True)
    w_V = Symbol("w^V", shape=(2 * k + 1, dz), real=True)