Пример #1
0
def lstm_recursive(x, *limits):
    """Build the symbolic LSTM state at step ``t`` from the state at ``t - 1``.

    ``limits`` must hold four one-element sequences, unpacked in order as the
    input weights ``W``, the recurrent weights ``Wh``, the bias ``b`` and the
    step ``t``.  The columns of ``W`` and ``Wh`` and the entries of ``b`` are
    cut into four chunks of width ``d`` (the last dimension of the previous
    ``h``), feeding the ``i``, ``f``, ``c`` and ``o`` terms respectively.

    Returns a ``Piecewise`` that is ``BlockMatrix(o * tanh(c), c)`` when
    ``t > 0`` and a ``ZeroMatrix`` with the previous state's shape otherwise.
    """
    (W, ), (Wh, ), (b, ), (t, ) = limits
    prev_state = lstm[W, Wh, b, t - 1](x)

    x_t = x[t]
    h_prev = Indexed(prev_state, 0)
    c_prev = Indexed(prev_state, 1)
    d = h_prev.shape[-1]

    # The same four ranges select the i / f / c / o chunk of every parameter;
    # the last chunk is addressed from the end, as ``-d:``.
    gate_slices = (
        slice(None, d),
        slice(d, 2 * d),
        slice(2 * d, 3 * d),
        slice(-d, None),
    )
    Wi, Wf, Wc, Wo = (W[:, chunk] for chunk in gate_slices)
    Whi, Whf, Whc, Who = (Wh[:, chunk] for chunk in gate_slices)
    bi, bf, bc, bo = (b[chunk] for chunk in gate_slices)

    input_gate = sigmoid(x_t @ Wi + h_prev @ Whi + bi)
    forget_gate = sigmoid(x_t @ Wf + h_prev @ Whf + bf)
    c_next = forget_gate * c_prev + input_gate * tanh(x_t @ Wc + h_prev @ Whc + bc)
    output_gate = sigmoid(x_t @ Wo + h_prev @ Who + bo)

    h_next = output_gate * tanh(c_next)
    return Piecewise((BlockMatrix(h_next, c_next), t > 0),
                     (ZeroMatrix(*prev_state.shape), True))
Пример #2
0
def test_Indexed_constructor():
    """Check which arguments the ``Indexed`` constructor accepts and rejects."""
    row, col = symbols('i j', integer=True)
    element = Indexed('A', row, col)

    # A Symbol or an IndexedBase as base gives the same object as the string.
    for base in (Symbol('A'), IndexedBase('A')):
        assert element == Indexed(base, row, col)

    # An Indexed used as base raises TypeError; no indices raises
    # IndexException.
    raises(TypeError, lambda: Indexed(element, row, col))
    raises(IndexException, lambda: Indexed("A"))

    expected_free = {element, element.base.label, row, col}
    assert element.free_symbols == expected_free
Пример #3
0
def test_Indexed_shape_precedence():
    """The shape declared on the ``IndexedBase`` is reported by ``Indexed``.

    The ``Idx`` bounds still show up in ``ranges``, but ``shape`` stays equal
    to the shape passed to the base.
    """
    i, j = symbols('i j', integer=True)
    o, p = symbols('o p', integer=True)
    n, m = symbols('n m', integer=True)

    base = IndexedBase('a', shape=(o, p))
    declared_shape = Tuple(o, p)
    assert base.shape == declared_shape

    # Both indices bounded.
    fully_bounded = Indexed(base, Idx(i, m), Idx(j, n))
    assert fully_bounded.ranges == [Tuple(0, m - 1), Tuple(0, n - 1)]
    assert fully_bounded.shape == declared_shape

    # Second index without bounds.
    half_bounded = Indexed(base, Idx(i, m), Idx(j))
    assert half_bounded.ranges == [Tuple(0, m - 1), Tuple(None, None)]
    assert half_bounded.shape == declared_shape
Пример #4
0
def test_issue_12533():
    """Regression test for issue 12533: iterables used as an indexed base.

    Wrapping a ``range`` in ``IndexedBase`` gives a ``Range``, and
    substituting a ``range`` for the base of an indexed element yields the
    corresponding element.
    """
    base = IndexedBase('d')
    values = range(5)

    assert IndexedBase(values) == Range(0, 5, 1)
    assert base[0].subs(Symbol("d"), values) == 0
    for position in (0, 1):
        assert base[position].subs(base, values) == position
    assert Indexed(Range(5), 2) == 2
Пример #5
0
def test_Indexed_properties():
    """Check ``name``, ``rank``, ``indices``, ``base``, ``ranges``, ``shape``."""
    i, j = symbols('i j', integer=True)
    plain = Indexed('A', i, j)
    assert plain.name == 'A[i, j]'
    assert plain.rank == 2
    assert plain.indices == (i, j)
    assert plain.base == IndexedBase('A')
    assert plain.ranges == [None, None]
    # With plain symbol indices, asking for the shape raises.
    raises(IndexException, lambda: plain.shape)

    n, m = symbols('n m', integer=True)
    bounded = Indexed('A', Idx(i, m), Idx(j, n))
    assert bounded.ranges == [Tuple(0, m - 1), Tuple(0, n - 1)]
    assert bounded.shape == Tuple(m, n)
    # A single Idx without bounds also makes the shape lookup raise.
    raises(IndexException, lambda: Indexed("A", Idx(i, m), Idx(j)).shape)
Пример #6
0
def test_JointRV():
    """Build a joint random variable from an explicit pdf and query it.

    Checks the density at a point, the type of the underlying distribution
    and the value of the marginal over component 0.
    """
    x1 = Indexed('x', 1)
    x2 = Indexed('x', 2)
    pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi)
    joint = JointRV('x', pdf)

    assert density(joint)(1, 2) == exp(-2)/(2*pi)
    assert isinstance(joint.pspace.distribution, JointDistributionHandmade)

    expected_marginal = sqrt(2)*exp(Rational(-1, 2))/(2*sqrt(pi))
    assert marginal_distribution(joint, 0)(2) == expected_marginal
Пример #7
0
def test_IndexedBase_sugar():
    """Subscripting an ``IndexedBase`` matches calling ``Indexed`` directly."""
    i, j = symbols('i j', integer=True)
    a = symbols('a')
    direct = Indexed(a, i, j)
    base = IndexedBase(a)

    assert direct == base[i, j]
    # A tuple, a list and a sympy Tuple are all accepted as the subscript.
    for key in ((i, j), [i, j], Tuple(i, j)):
        assert direct == base[key]

    # Plain Python ints used as indices end up as sympy Integers.
    index_args = base[1, 0].args[1:]
    assert all(arg.is_Integer for arg in index_args)
Пример #8
0
 def _marginal_distribution(self, indices, sym):
     """Return the density of the marginal over the components in ``indices``.

     Rows of ``self.mu`` and rows/columns of ``self.sigma`` whose position
     is not listed in ``indices`` are removed, and the result is a
     ``Lambda`` over ``Indexed(sym, i)`` for each ``i`` in ``indices``
     built from the reduced mean vector and covariance matrix.

     NOTE(review): the bound variables follow the order of ``indices``
     while the retained rows keep their original order, so this presumably
     expects ``indices`` in ascending order -- confirm against callers.
     """
     sym = ImmutableMatrix([Indexed(sym, i) for i in indices])
     _mu, _sigma = self.mu, self.sigma
     k = self.mu.shape[0]
     # Delete from the highest index down: removing a row/column shifts
     # every later position by one, so an ascending loop would drop the
     # wrong component (or index past the end) after the first deletion.
     for i in reversed(range(k)):
         if i not in indices:
             _mu = _mu.row_del(i)
             _sigma = _sigma.col_del(i)
             _sigma = _sigma.row_del(i)
     return Lambda(tuple(sym), S.One/sqrt((2*pi)**(len(_mu))*det(_sigma))*exp(
         Rational(-1, 2)*(_mu - sym).transpose()*(_sigma.inv()*\
             (_mu - sym)))[0])
Пример #9
0
def test_IndexedBase_assumptions():
    """Assumptions given to an ``IndexedBase`` apply to it and its elements."""
    i = Symbol('i', integer=True)
    a = Symbol('a')
    positive_base = IndexedBase(a, positive=True)

    expected_true = ('is_real', 'is_complex', 'is_nonnegative',
                     'is_nonzero', 'is_commutative')
    for obj in (positive_base, positive_base[i]):
        for attribute in expected_true:
            assert getattr(obj, attribute)
        assert not obj.is_imaginary
        assert log(exp(obj)) == obj

    # Assumptions take part in equality of bases and of their elements.
    assert positive_base != IndexedBase(a)
    assert positive_base == IndexedBase(a, positive=True, real=True)
    assert positive_base[i] != Indexed(a, i)
Пример #10
0
def test_issue_12283():
    """Regression test for issue 12283: bare random symbols stay symbolic.

    Random symbols created without a probability space report ``PSpace()``
    and the statistical operators return their unevaluated counterparts.
    """
    x = symbols('x')
    X = RandomSymbol(x)
    Y = RandomSymbol('Y')
    Z = RandomMatrixSymbol('Z', 2, 1)
    W = RandomMatrixSymbol('W', 2, 1)
    RI = RandomIndexedSymbol(Indexed('RI', 3))

    empty_space = PSpace()
    for random_object in (Z, RI, X):
        assert pspace(random_object) == empty_space

    assert E(X) == Expectation(X)
    assert P(Y > 3) == Probability(Y > 3)
    for random_object in (X, RI):
        assert variance(random_object) == Variance(random_object)
    for left, right in ((X, Y), (W, Z)):
        assert covariance(left, right) == Covariance(left, right)
Пример #11
0
def test_sample_seed():
    """Sampling with equal seeds repeats; a different seed gives other values.

    For each of the listed libraries that ``import_module`` can load, three
    samples of size 10 are drawn from the same joint random variable: two
    with ``seed=0`` and one with ``seed=1``.  Libraries for which sampling
    raises ``NotImplementedError`` are skipped.
    """
    x1, x2 = (Indexed('x', i) for i in (1, 2))
    pdf = exp(-x1**2 / 2 + x1 - x2**2 / 2 - S.Half) / (2 * pi)
    X = JointRV('x', pdf)

    libraries = ['scipy', 'numpy', 'pymc3']
    for lib in libraries:
        try:
            imported_lib = import_module(lib)
            if imported_lib:
                # The samples are bound directly; the former empty-list
                # placeholders were overwritten before any use.
                s0 = sample(X, size=10, library=lib, seed=0)
                s1 = sample(X, size=10, library=lib, seed=0)
                s2 = sample(X, size=10, library=lib, seed=1)
                assert all(s0 == s1)
                assert all(s1 != s2)
        except NotImplementedError:
            # Sampling is not available through this library; try the next.
            continue
Пример #12
0
def test_not_interable():
    """An ``Indexed`` object must not be reported as iterable."""
    i, j = symbols('i j', integer=True)
    element = Indexed('A', i, i + j)
    is_iterable = iterable(element)
    assert not is_iterable
Пример #13
0
def test_complex_indices():
    """Expressions such as ``i + j`` are kept as-is when used as indices."""
    i, j = symbols('i j', integer=True)
    index_pair = (i, i + j)
    element = Indexed('A', *index_pair)
    assert element.rank == 2
    assert element.indices == index_pair
Пример #14
0
def test_Lambda():
    """Exercise construction, equality, calling and validation of ``Lambda``."""
    square = Lambda(x, x**2)
    assert square(4) == 16
    assert square(x) == x**2
    assert square(y) == y**2

    # Lambdas with an empty signature.
    constant = Lambda((), 42)
    assert constant() == 42
    assert unchanged(Lambda, (), 42)
    assert constant != Lambda((), 43)
    assert Lambda((), f(x))() == f(x)
    assert constant.nargs == FiniteSet(0)

    # Equality and inequality between lambdas.
    assert unchanged(Lambda, (x,), x**2)
    assert square == Lambda((x,), x**2)
    assert square != Lambda(x, x**2 + 1)
    power = Lambda((x, y), x**y)
    assert power != Lambda((y, x), y**x)
    assert power != Lambda((x, y), y**x)

    # Calling with symbols, numbers and the result of another lambda.
    assert power(x, y) == x**y
    assert power(3, 3) == 3**3
    assert power(x, 3) == x**3
    assert power(3, y) == 3**y
    assert Lambda(x, f(x))(x) == f(x)
    assert Lambda(x, x**2)(square(x)) == x**4
    assert square(square(x)) == x**4

    # Indexed objects used as the bound variables.
    first = Indexed('x', 1)
    second = Indexed('x', 2)
    assert Lambda((first, second), first + second)(x, y) == x + y

    assert Lambda((x, y), x + y).nargs == FiniteSet(2)

    params = x, y, z, t
    assert Lambda(params, t*(x + y + z))(*params) == t * (x + y + z)

    double = Lambda(x, 2*x)
    summed = double + Lambda(y, 2*y)
    assert summed != 2*double
    assert summed.as_dummy() == 2*double.as_dummy()
    assert double not in [ Lambda(x, x) ]
    raises(BadSignatureError, lambda: Lambda(1, x))
    assert Lambda(x, 1)(1) is S.One

    # Signatures with repeated symbols or with numbers are rejected.
    raises(BadSignatureError, lambda: Lambda((x, x), x + 2))
    raises(BadSignatureError, lambda: Lambda(((x, x), y), x))
    raises(BadSignatureError, lambda: Lambda(((y, x), x), x))
    raises(BadSignatureError, lambda: Lambda(((y, 1), 2), x))

    # A list signature still works but emits a deprecation warning.
    with warns_deprecated_sympy():
        assert Lambda([x, y], x+y) == Lambda((x, y), x+y)

    # Nested tuple signatures take correspondingly nested arguments.
    pair_sum = Lambda(((x, y),), x + y)
    assert pair_sum((2, 3)) == 5
    pair_and_scalar = Lambda(((x, y), z), x + y + z)
    assert pair_and_scalar((2, 3), 1) == 6
    nested_sum = Lambda((((x, y), z),), x + y + z)
    assert nested_sum(((2, 3), 1)) == 6
    raises(BadArgumentsError, lambda: nested_sum(1, 2, 3))
    duplicate = Lambda( (x,), (x, x))
    assert duplicate(1,) == (1, 1)
    assert duplicate((1,)) == ((1,), (1,))
    unpacking_duplicate = Lambda( ((x,),), (x, x))
    raises(BadArgumentsError, lambda: unpacking_duplicate(1))
    assert unpacking_duplicate((1,)) == (1, 1)

    # Previously TypeError was raised so this is potentially needed for
    # backwards compatibility.
    for error_class in (BadSignatureError, BadArgumentsError):
        assert issubclass(error_class, TypeError)

    # These are tested to see they don't raise:
    hash(double)
    hash(Lambda(x, x))  # IdentityFunction subclass
Пример #15
0
def LSTMCell(x, *weights):
    """Return component 1 of the ``lstm`` state of ``x`` for every step ``t``.

    ``weights`` must hold three one-element sequences, unpacked in order as
    ``W``, ``Wh`` and ``b``.  The result is a ``Lamda`` over an integer
    symbol ``t`` bounded by ``x.shape[0]``.
    """
    (W, ), (Wh, ), (b, ) = weights
    upper_bound = x.shape[0]
    t = Symbol.t(integer=True)
    state_at_t = lstm[W, Wh, b, t](x)
    return Lamda[t:upper_bound](Indexed(state_at_t, 1))
Пример #16
0
def test_sympy__tensor__indexed__Indexed():
    """Pass a two-index ``Indexed`` instance through ``_test_args``."""
    from sympy.tensor.indexed import Indexed, Idx
    first_index = Idx('i')
    second_index = Idx('j')
    instance = Indexed('A', first_index, second_index)
    assert _test_args(instance)
Пример #17
0
def test_Indexed_func_args():
    """Rebuilding an ``Indexed`` from ``func`` and ``args`` gives it back."""
    i, j = symbols('i j', integer=True)
    a = symbols('a')
    element = Indexed(a, i, j)
    rebuilt = element.func(*element.args)
    assert element == rebuilt
Пример #18
0
def test_Idx_limits():
    """An ``Idx`` can serve as the running variable in sequence limits."""
    idx = symbols('i', cls=Idx)
    term = Indexed('r', idx)
    limits = (idx, 0, 5)

    expected_terms = [term.subs(idx, position) for position in range(6)]
    assert SeqFormula(term, limits)[:] == expected_terms
    assert SeqPer((1, 2), limits)[:] == [1, 2, 1, 2, 1, 2]
Пример #19
0
def test_MarginalDistribution():
    """Compare a ``MarginalDistribution`` call against a hand-built product.

    ``B`` is a multivariate beta whose second parameter is ``C[0]`` of a
    multinomial ``C``; the marginal of ``B`` over ``C[0]``, called on ``C``,
    must equal the explicit ``Mul`` assembled below.
    """
    a1, p1, p2 = symbols('a1 p1 p2', positive=True)
    C = Multinomial('C', 2, p1, p2)
    B = MultivariateBeta('B', a1, C[0])
    MGR = MarginalDistribution(B, (C[0], ))

    # Building blocks shared by several factors of the expected expression.
    minus_one = Integer(-1)
    c_base = IndexedBase(Symbol('C'))
    b_base = IndexedBase(Symbol('B'))
    c0 = Indexed(c_base, Integer(0))
    c1 = Indexed(c_base, Integer(1))
    b0 = Indexed(b_base, Integer(0))
    b1 = Indexed(b_base, Integer(1))

    # Piecewise factor: a product when C[0] + C[1] == 2, zero otherwise.
    product_branch = Mul(
        Integer(2),
        Pow(p1, c0),
        Pow(p2, c1),
        Pow(factorial(c0), minus_one),
        Pow(factorial(c1), minus_one))
    branch_condition = Eq(Add(c0, c1), Integer(2))
    piecewise_factor = Piecewise(
        ExprCondPair(product_branch, branch_condition),
        ExprCondPair(Integer(0), True))

    mgrc = Mul(
        Symbol('B'),
        piecewise_factor,
        Pow(gamma(a1), minus_one),
        gamma(Add(a1, c0)),
        Pow(gamma(c0), minus_one),
        Pow(b0, Add(a1, minus_one)),
        Pow(b1, Add(c0, minus_one)))
    assert MGR(C) == mgrc