Exemplo n.º 1
0
def prove(Eq):
    """Proof script for the statements returned by ``apply(given)``.

    Sets up a set ``S`` whose elements are integer vectors of length ``n``
    (``n`` ranges over ``Interval(2, oo)``), a symbol ``w`` defined by
    ``LAMBDA[j:n, i:n](Swap(n, i, j))`` and indexed below as ``w[i, j]``, and
    the hypothesis ``given``: for every ``x`` in ``S``, the vector whose
    ``k``-th entry is ``x[(w[i, j] @ LAMBDA[k:n](k))[k]]`` is also in ``S``.

    ``apply`` is defined elsewhere (not visible here); the unpacking below
    assumes it returns exactly four statements.

    Args:
        Eq: proof context supplied by the framework. The script appends
            statements with ``Eq << stmt`` and stores some as attributes
            (``Eq.swap``, ...); ``Eq[-1]`` presumably refers to the statement
            appended most recently -- TODO confirm against the framework.
    """
    n = Symbol.n(domain=Interval(2, oo, integer=True))
    # Elements of S have dtype "integer vector of length n".
    S = Symbol.S(dtype=dtype.integer * n)

    # Generic element of S, created with the same dtype as S's elements.
    x = Symbol.x(**S.element_symbol().dtype.dict)

    i = Symbol.i(integer=True)
    j = Symbol.j(integer=True)

    # w is indexed as w[i, j]; its entries are defined as Swap(n, i, j).
    w = Symbol.w(integer=True,
                 shape=(n, n, n, n),
                 definition=LAMBDA[j:n, i:n](Swap(n, i, j)))

    k = Symbol.k(integer=True)

    # Hypothesis: S is closed under re-indexing x by w[i, j] @ [0, ..., n-1].
    given = ForAll[x:S](Contains(LAMBDA[k:n](x[(w[i, j] @ LAMBDA[k:n](k))[k]]),
                                 S))

    Eq.P_definition, Eq.w_definition, Eq.swap, Eq.axiom = apply(given)

    # Library lemma from module algebre.matrix.elementary.swap.identity,
    # instantiated with x and w.
    Eq << algebre.matrix.elementary.swap.identity.apply(x, w)

    # Rewrite Eq.swap using the lemma just appended.
    Eq << Eq.swap.subs(Eq[-1])

    # Feed the rewritten statement to the swapn.permutation lemma.
    Eq << swapn.permutation.apply(Eq[-1])
Exemplo n.º 2
0
def apply(n, d):
    """Build the row-0 identity for the scaled dot-product form softmax(Q K^T / sqrt(d)) V.

    Creates real ``(n, d)`` symbols ``Q``, ``K`` and ``V``. It also creates a symbol
    ``S`` defined as ``softmax(Q @ K.T / sqrt(d)) @ V``.

    Args:
        n: number of rows of ``Q``, ``K``, ``V`` and ``S``.
        d: number of columns; also the argument of the ``sqrt`` scaling.

    Returns:
        ``Equality(S[0], softmax(Q[0] @ K.T / sqrt(d)) @ V)``.
    """
    queries = Symbol.Q(shape=(n, d), real=True)
    keys = Symbol.K(shape=(n, d), real=True)
    values = Symbol.V(shape=(n, d), real=True)

    def attend(q):
        # Same expression shape as the original inline formula.
        return softmax(q @ keys.T / sympy.sqrt(d)) @ values

    attention = Symbol.S(shape=(n, d), definition=attend(queries))

    return Equality(attention[0], attend(queries[0]))
Exemplo n.º 3
0
def prove(Eq):
    """Proof script for the statement built by ``apply(NotContains(e, S))``.

    Hypothesis: the integer ``e`` is not an element of the integer set ``S``.
    ``apply`` and the two lemmas are defined elsewhere (not visible here). The
    lemmas are identified below only by their module paths.

    Args:
        Eq: proof context supplied by the framework. ``Eq << stmt`` appends a
            statement; ``Eq[0]`` / ``Eq[-1]`` presumably refer to the first /
            most recent statement -- TODO confirm.
    """
    S = Symbol.S(dtype=dtype.integer)
    e = Symbol.e(integer=True)

    Eq << apply(NotContains(e, S))

    # Lemma sets.notcontains.imply.equality.emptyset applied to Eq[0].
    Eq << sets.notcontains.imply.equality.emptyset.apply(Eq[0])

    # Lemma sets.equality.imply.equality.given.emptyset.complement applied to
    # the statement just derived.
    Eq << sets.equality.imply.equality.given.emptyset.complement.apply(Eq[-1])
Exemplo n.º 4
0
def prove(Eq):
    """Proof script for ``apply(given)`` with two hypotheses on a vector set.

    ``S`` is a set of integer vectors of length ``n`` (``n`` in
    ``Interval(2, oo)``). The hypotheses in ``given`` are:

    1. for every ``j`` in the limit ``j:1:n - 1`` and every ``x`` in ``S``, the
       vector whose entry ``i`` is ``x[0]`` when ``i == j``, ``x[j]`` when
       ``i == 0``, and ``x[i]`` otherwise is in ``S``. In other words, entries
       ``0`` and ``j`` are exchanged.
    2. for every ``x`` in ``S``, ``abs(x.set_comprehension()) == n``. Here
       ``abs`` is presumably the cardinality of the set of x's entries.

    ``apply`` and the lemmas referenced below are defined elsewhere (not
    visible here); comments only say which lemma is applied to what.

    Args:
        Eq: proof context supplied by the framework. ``Eq << stmt`` appends a
            statement; ``Eq[-1]``/``Eq[-2]``/``Eq[-3]`` presumably index back
            from the most recent statement -- TODO confirm.
    """
    n = Symbol.n(domain=Interval(2, oo, integer=True))
    S = Symbol.S(dtype=dtype.integer * n)

    # Generic element of S, created with the same dtype as S's elements.
    x = Symbol.x(**S.element_symbol().dtype.dict)

    i = Symbol.i(integer=True)
    j = Symbol.j(integer=True)

    given = [
        ForAll[j:1:n - 1, x:S](Contains(
            LAMBDA[i:n](Piecewise((x[0], Equality(i, j)),
                                  (x[j], Equality(i, 0)), (x[i], True))), S)),
        ForAll[x:S](Equality(abs(x.set_comprehension()), n))
    ]

    Eq << apply(given)

    # Lemma swap2.general applied to the first hypothesis.
    Eq << discrete.combinatorics.permutation.adjacent.swap2.general.apply(
        Eq[0])

    # Lemma swapn.permutation applied to that result; kept as Eq.permutation.
    Eq.permutation = discrete.combinatorics.permutation.adjacent.swapn.permutation.apply(
        Eq[-1])

    # Unfold the definition of the set bounding Eq.permutation's first limit.
    Eq << Eq.permutation.limits[0][1].this.definition

    # factorial.definition lemma instantiated at n.
    Eq << discrete.combinatorics.permutation.factorial.definition.apply(n)

    # Rename the bound variable inside the lhs so it matches Eq[-2].rhs's
    # variable.
    Eq << Eq[-1].this.lhs.arg.limits_subs(Eq[-1].lhs.arg.variable,
                                          Eq[-2].rhs.variable)

    # Conjoin the last statement with abs() of the one before it
    # (framework-specific `<<=` form).
    Eq <<= Eq[-1] & Eq[-2].abs()

    # F(e): the elements x of S whose entry set equals e.
    F = Function.F(nargs=(), dtype=dtype.integer * n)
    F.eval = lambda e: conditionset(x, Equality(x.set_comprehension(), e), S)

    e = Symbol.e(dtype=dtype.integer)
    # Claim F(e) is a subset of S (plausible=True presumably marks a claim
    # still to be justified), then unfold F's definition.
    Eq << Subset(F(e), S, plausible=True)
    Eq << Eq[-1].this.lhs.definition

    # Lemma sets.subset.forall.imply.forall with the subset claim and
    # Eq.permutation.
    Eq << sets.subset.forall.imply.forall.apply(Eq[-1], Eq.permutation)

    # Claim: under the same limits, Eq[-1].lhs is an element of F(e).
    Eq.forall_x = ForAll(Contains(Eq[-1].lhs, F(e)),
                         *Eq[-1].limits,
                         plausible=True)

    Eq << Eq.forall_x.definition.split()

    # P: the set bounding the first limit of the last statement.
    P = Eq[-1].limits[0][1]
    Eq << sets.imply.conditionset.apply(P)
    Eq << Eq[-1].apply(sets.equality.imply.equality.permutation, x)

    Eq.equality_e = Eq[-3] & Eq[-1]

    Eq << sets.imply.conditionset.apply(F(e)).reversed
Exemplo n.º 5
0
def prove(Eq):
    """Proof script for ``apply(given)`` about closure of ``S`` under ``w``.

    ``S`` is a set of integer vectors of length ``n`` (``n`` in
    ``Interval(2, oo)``). ``w`` is defined by
    ``LAMBDA[j:n, i:n](Swap(n, i, j))`` and indexed as ``w[i, j]``. The
    hypothesis ``given`` says that for every ``x`` in ``S``, ``w[0, j] @ x``
    is in ``S`` (``j`` is left free).

    ``apply`` and the lemmas referenced below are defined elsewhere (not
    visible here).

    Args:
        Eq: proof context supplied by the framework. ``Eq << stmt`` appends a
            statement; ``Eq[-1]``/``Eq[-2]`` presumably index back from the
            most recent statement -- TODO confirm.
    """
    n = Symbol.n(domain=Interval(2, oo, integer=True))
    S = Symbol.S(dtype=dtype.integer * n)

    # Generic element of S, created with the same dtype as S's elements.
    x = Symbol.x(**S.element_symbol().dtype.dict)

    i = Symbol.i(integer=True)
    j = Symbol.j(integer=True)

    w = Symbol.w(integer=True, shape=(n, n, n, n), definition=LAMBDA[j:n, i:n](Swap(n, i, j)))

    given = ForAll[x:S](Contains(w[0, j] @ x, S))

    Eq << apply(given)

    # The hypothesis with j renamed to i.
    Eq.given_i = given.subs(j, i)

    # Instantiate the hypothesis at x := (lhs of Eq.given_i's body).
    Eq << given.subs(x, Eq.given_i.function.lhs)

    # Conjoin with Eq.given_i and keep the last part of the split.
    Eq << (Eq.given_i & Eq[-1]).split()[-1]

    Eq << Eq.given_i.subs(x, Eq[-1].function.lhs)

    # Conjoin the two most recent statements and keep the first part.
    Eq.final_statement = (Eq[-2] & Eq[-1]).split()[0]

    # Lemma swap2.equality instantiated with n and w, then right-multiplied
    # by x.
    Eq << swap2.equality.apply(n, w)

    Eq << Eq[-1] @ x

    # Quantify over a variable extracted from Eq[-1]'s first limit.
    Eq << Eq[-1].forall((Eq[-1].limits[0].args[1].args[1].arg,))

    # NOTE(review): Eq.i_complement is assigned here and then reassigned
    # below by the unpacking of Eq[-1].split(); confirm whether a different
    # attribute name was intended for this statement.
    Eq.i_complement = Eq.final_statement.subs(Eq[-1])

    # Claim: for x in S and j in Interval(1, n - 1), w[i, j] @ x is in S.
    Eq.plausible = ForAll(Contains(w[i, j] @ x, S), (x, S), (j, Interval(1, n - 1, integer=True)), plausible=True)

    # Split the claim's j-domain using i.set.
    Eq << Eq.plausible.bisect(i.set, wrt=j)

    Eq.i_complement, Eq.i_intersection = Eq[-1].split()

    # Lemma sets.imply.equality.intersection for i and Interval(1, n - 1).
    Eq << sets.imply.equality.intersection.apply(i, Interval(1, n - 1, integer=True))

    Eq << Eq.i_intersection.this.limits[1].subs(Eq[-1])

    # Substitute the definition of w[i, i].
    Eq << Eq[-1].subs(w[i, i].equality_defined())

    Eq << (Eq.i_complement & Eq.i_intersection)

    # Lemma elementary.swap.transpose for w, with j := 0.
    Eq << elementary.swap.transpose.apply(w).subs(j, 0)
    Eq << Eq.given_i.subs(Eq[-1].reversed)

    Eq << (Eq[-1] & Eq.plausible)
Exemplo n.º 6
0
def prove(Eq):
    """Proof script for ``apply(Equality(abs(S), 1))``.

    Hypothesis: ``abs(S) == 1`` for the integer set ``S`` (``abs`` is
    presumably the cardinality). ``apply`` and the lemmas referenced below
    are defined elsewhere (not visible here).

    Args:
        Eq: proof context supplied by the framework. ``Eq << stmt`` appends a
            statement; ``Eq[0]`` / ``Eq[-1]`` presumably refer to the first /
            most recent statement -- TODO confirm.
    """
    S = Symbol.S(dtype=dtype.integer)

    Eq << apply(Equality(abs(S), 1))

    # Claim abs(S) > 0 (plausible=True presumably marks it as still to be
    # justified) ...
    Eq << StrictGreaterThan(abs(S), 0, plausible=True)

    # ... and substitute the hypothesis Eq[0] into it.
    Eq << Eq[-1].subs(Eq[0])

    # Lemma sets.strict_greater_than.imply.inequality.
    Eq << sets.strict_greater_than.imply.inequality.apply(Eq[-1])

    # Lemma sets.inequality.imply.exists.contains.
    Eq << sets.inequality.imply.exists.contains.apply(Eq[-1])

    # Apply sets.contains.imply.subset inside the last statement, without
    # simplification.
    Eq << Eq[-1].apply(sets.contains.imply.subset, simplify=False)

    # Apply sets.subset.equality.imply.equality together with Eq[0].
    Eq << Eq[-1].apply(sets.subset.equality.imply.equality, Eq[0])
Exemplo n.º 7
0
def prove(Eq):
    """Proof script for ``apply(given)`` about exchanging vector entries.

    ``S`` is a set of integer vectors of length ``n`` (``n`` in
    ``Interval(2, oo)``). The hypothesis ``given`` quantifies over the limits
    ``(j, 1, n - 1)`` and ``(x, S)``. It states that the vector whose entry
    ``i`` is ``x[0]`` when ``i == j``, ``x[j]`` when ``i == 0``, and ``x[i]``
    otherwise is in ``S``.

    ``apply`` and the lemmas referenced below are defined elsewhere (not
    visible here).

    Args:
        Eq: proof context supplied by the framework. ``Eq << stmt`` appends a
            statement; ``Eq[0]``, ``Eq[1]`` and ``Eq[-1]``/``Eq[-2]``
            presumably index from the start / the most recent statement --
            TODO confirm.
    """
    n = Symbol.n(domain=Interval(2, oo, integer=True))
    S = Symbol.S(dtype=dtype.integer * n)

    # Generic element of S, created with the same dtype as S's elements.
    x = Symbol.x(**S.element_symbol().dtype.dict)

    i = Symbol.i(integer=True)
    j = Symbol.j(integer=True)

    given = ForAll(
        Contains(
            LAMBDA[i:n](Piecewise((x[0], Equality(i, j)),
                                  (x[j], Equality(i, 0)), (x[i], True))), S),
        (j, 1, n - 1), (x, S))

    Eq << apply(given)

    # w: the base of Eq[0]'s lhs (presumably a symbol introduced by apply).
    w = Eq[0].lhs.base

    # Lemma swap1.utility instantiated with x and w[0].
    Eq << swap1.utility.apply(x, w[0])

    Eq << Eq[-1].reference(*Eq[-1].limits)

    Eq.given = Eq[1].subs(Eq[-1].reversed)

    # Lemma axiom.algebre.matrix.elementary.swap.identity for x and w.
    Eq << axiom.algebre.matrix.elementary.swap.identity.apply(x, w)

    # Set the first index of the rhs term to 0.
    Eq << Eq[-1].subs(Eq[-1].rhs.args[0].indices[0], 0)

    # Rename the lhs bound variable to i.
    Eq << Eq[-1].this.lhs.limits_subs(Eq[-1].lhs.variable, i)

    # Rename the bound variable inside the lhs index expression to i.
    Eq << Eq[-1].this.lhs.function.indices[0].args[1].limits_subs(
        Eq[-1].lhs.function.indices[0].args[1].variable, i)

    # Set the second index of the rhs term to j.
    Eq << Eq[-1].subs(Eq[-1].rhs.args[0].indices[1], j)

    Eq.given = Eq.given.subs(Eq[-1])

    Eq << Eq.given.limits_swap()

    # Claim the j = 0 case for every x in S (plausible=True).
    Eq << ForAll[x:S](Eq[-1].function.subs(j, 0), plausible=True)

    # Substitute the definition of w[0, 0].
    Eq << Eq[-1].subs(w[0, 0].this.definition)

    Eq <<= Eq[-1] & Eq[-2]

    # Lemma combinatorics.permutation.adjacent.swap2.contains.
    Eq << combinatorics.permutation.adjacent.swap2.contains.apply(Eq[-1])
Exemplo n.º 8
0
def prove(Eq):
    """Set up the three hypotheses passed to ``apply`` for a vector set ``S``.

    ``S`` is a set of integer vectors of length ``n`` (``n`` in
    ``Interval(2, oo)``). ``P`` is defined as the condition set of ``p[:n]``
    whose entry set equals ``Interval(0, n - 1, integer=True)``. Presumably
    these are the permutations of ``0 .. n-1``; TODO confirm. The hypotheses are:

    1. every ``x`` in ``S`` has ``x.set_comprehension() == e``;
    2. for every ``x`` in ``S`` and ``p[:n]`` in ``P``, the vector
       ``LAMBDA[k:n](x[p[k]])`` is in ``S``;
    3. ``abs(e) == n``.

    ``apply`` is defined elsewhere (not visible here). In this view the body
    consists only of the ``apply`` call.

    Args:
        Eq: proof context supplied by the framework; ``Eq << stmt`` appends a
            statement.
    """
    n = Symbol.n(domain=Interval(2, oo, integer=True))
    S = Symbol.S(dtype=dtype.integer * n)

    # Generic element of S, created with the same dtype as S's elements.
    x = Symbol.x(**S.element_symbol().dtype.dict)

    # NOTE(review): i and j are created but not used in this block.
    i = Symbol.i(integer=True)
    j = Symbol.j(integer=True)
    k = Symbol.k(integer=True)

    e = Symbol.e(dtype=dtype.integer, given=True)

    # Unbounded nonnegative integer sequence; only its prefix p[:n] is used.
    p = Symbol.p(shape=(oo, ), integer=True, nonnegative=True)

    P = Symbol.P(dtype=dtype.integer * n,
                 definition=conditionset(
                     p[:n],
                     Equality(p[:n].set_comprehension(),
                              Interval(0, n - 1, integer=True))))

    Eq << apply(ForAll[x:S](Equality(x.set_comprehension(), e)),
                ForAll[x:S, p[:n]:P](Contains(LAMBDA[k:n](x[p[k]]), S)),
                Equality(abs(e), n))
Exemplo n.º 9
0
def prove(Eq):
    """Proof script for ``apply(given)`` about closure of ``S`` under ``w``.

    ``S`` is a given set of integer vectors of length ``n`` (``n`` in
    ``Interval(2, oo)``). ``x`` is an unbounded integer sequence whose prefix
    ``x[:n]`` ranges over ``S``. ``w`` is defined by
    ``LAMBDA[j:n, i:n](Swap(n, i, j))``. The hypothesis says that for every
    ``x[:n]`` in ``S``, ``w[i, j] @ x[:n]`` is in ``S``.

    ``apply`` is defined elsewhere (not visible here); the unpacking below
    assumes it returns exactly four statements.

    Args:
        Eq: proof context supplied by the framework. ``Eq << stmt`` appends a
            statement; ``Eq[-1]``/``Eq[-2]`` presumably index back from the
            most recent statement -- TODO confirm.
    """
    n = Symbol.n(domain=Interval(2, oo, integer=True))
    S = Symbol.S(dtype=dtype.integer * n, given=True)

    x = Symbol.x(shape=(oo, ), integer=True)

    i = Symbol.i(integer=True)
    j = Symbol.j(integer=True)

    w = Symbol.w(definition=LAMBDA[j:n, i:n](Swap(n, i, j)))

    given = ForAll[x[:n]:S](Contains(w[i, j] @ x[:n], S))

    Eq.P_definition, Eq.w_definition, Eq.swap, Eq.axiom = apply(given)

    # Lemma factorization instantiated at n.
    Eq << factorization.apply(n)

    # Pull the indexed term b[_i] out of the lemma's rhs, then split it into
    # its base b and index _i.
    *_, b_i = Eq[-1].rhs.args[1].function.args
    b, _i = b_i.args
    # Instantiate w's definition at j := b[_i], i := _i.
    Eq << Eq.w_definition.subs(j, b[_i]).subs(i, _i)

    Eq << Eq[-2].subs(Eq[-1].reversed)

    # k: the bound variable on the lhs of Eq.axiom.
    k = Eq.axiom.lhs.variable
    Eq << Eq[-1][k]

    # Rename the bound variable _i to k inside the rhs term.
    Eq << Eq[-1].this.function.function.rhs.args[0].limits_subs(_i, k)

    # Lemma swapn.utility instantiated with x[:n], b[:n] and w.
    Eq << swapn.utility.apply(x[:n], b[:n], w)

    Eq << Eq[-1].subs(Eq[-2].reversed)

    Eq.plausible = Eq.axiom.subs(Eq[-1])

    # Lemma swapn.mat_product applied to the transpose of Eq.swap.
    Eq << swapn.mat_product.apply(Eq.swap.T, n, b)

    Eq << Eq.plausible.this.function.as_ForAll()
Exemplo n.º 10
0
def prove(Eq):
    """Proof script for ``apply(given)`` about right-multiplication by ``w``.

    ``S`` is a given set of integer vectors of length ``n`` (``n`` in
    ``Interval(2, oo)``). ``w`` is defined by
    ``LAMBDA[j:n, i:n](Swap(n, i, j))``. The hypothesis says that for every
    ``x[:n]`` in ``S``, ``x[:n] @ w[i, j]`` is in ``S``.

    ``apply`` is defined elsewhere (not visible here); the unpacking below
    assumes it returns exactly three statements.

    Args:
        Eq: proof context supplied by the framework. ``Eq << stmt`` appends a
            statement; ``Eq[2]`` and ``Eq[-1]`` presumably index from the start
            / the most recent statement -- TODO confirm.
    """
    n = Symbol.n(domain=Interval(2, oo, integer=True))
    S = Symbol.S(dtype=dtype.integer * n, given=True)

    x = Symbol.x(shape=(n,), integer=True)

    i = Symbol.i(integer=True)
    j = Symbol.j(integer=True)

    w = Symbol.w(definition=LAMBDA[j:n, i:n](Swap(n, i, j)))

    given = ForAll[x[:n]:S](Contains(x[:n] @ w[i, j], S))

    Eq.w_definition, Eq.swap, Eq.mat_product = apply(given)

    # Read the first limit of the product on the lhs of Eq.mat_product.
    # NOTE(review): this rebinds `i` from Symbol.i to that limit variable, and
    # the later `Eq.swap.subs(i, m)` uses the rebound value. Confirm this is
    # intended, i.e. that it matches the i appearing in Eq.swap.
    i, _, m_1 = Eq.mat_product.function.lhs.args[1].limits[0]
    # Assumes the stored upper bound is m - 1 -- TODO confirm limit semantics.
    m = m_1 + 1

    # b: base of the second index inside the product's factor.
    b = Eq.mat_product.function.lhs.args[1].function.indices[1].base

    # Instantiate Eq.mat_product at m = 0 and at m + 1 (presumably the base
    # case and inductive step).
    Eq << Eq.mat_product.subs(m, 0)
    Eq << Eq.mat_product.subs(m, m + 1)

    # Split off the last factor of the product.
    Eq << Eq[-1].function.lhs.args[1].this.bisect(Slice[-1:])

    # Left-multiply by x.
    Eq << x @ Eq[-1]

    # Instantiate Eq.swap at i := m, j := b[m].
    Eq << Eq.swap.subs(i, m).subs(j, b[m])

    # Substitute x with the expression rebuilt from Eq[2].rhs's first two
    # arguments.
    Eq << Eq[-1].subs(x, Eq[2].rhs.func(*Eq[2].rhs.args[:2]))

    # Quantify over x in S.
    Eq << Eq[-1].forall((x, S))

    Eq << (Eq[-1] & Eq.mat_product).split()

    Eq << Eq[-1].subs(Eq[2].reversed)