예제 #1
0
def test_MatrixSlice():
    """Check the theano translation of slicing a sympy MatrixSymbol.

    The first part slices with literal bounds ``X[1:2:3, 4:5:6]`` and inspects
    the resulting theano op: its ``idx_list``, the sliced operand, and the six
    constant slice-bound inputs. The second part only checks that a slice whose
    stop is a SymPy symbol can be converted when explicit dtypes are supplied.
    """
    n = sympy.Symbol('n', integer=True)
    X = sympy.MatrixSymbol('X', n, n)

    Y = X[1:2:3, 4:5:6]
    Yt = theano_code(Y)
    from theano.scalar import Scalar
    from theano import Constant

    s = Scalar('int64')
    # Each of the two axes is sliced with start/stop/step int64 placeholders.
    assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s))
    # The sliced operand must be the variable theano_code returns for X.
    assert Yt.owner.inputs[0] == theano_code(X)
    # == doesn't work in theano like it does in SymPy. You have to use
    # equals. The comparisons must be wrapped in all(): asserting a non-empty
    # list is always true, so a bare list of results could never fail.
    # inputs[1:7] hold the six slice bounds 1, 2, ..., 6 in order; indexing
    # (rather than zip) also fails loudly if fewer inputs are present.
    assert all(Yt.owner.inputs[i].equals(Constant(s, i)) for i in range(1, 7))

    # Smoke test: a slice whose stop is a SymPy symbol must convert when
    # dtypes are given. No structural assertion is made on the result here.
    k = sympy.Symbol('k')
    kt = theano_code(k, dtypes={k: 'int32'})
    start, stop, step = 4, k, 2
    Y = X[start:stop:step]
    Yt = theano_code(Y, dtypes={n: 'int32', k: 'int32'})
예제 #2
0
def test_MatrixSlice():
    """Verify how slicing a ``MatrixSymbol`` is lowered to theano.

    A fixed-bound slice ``X[1:2:3, 4:5:6]`` is converted with a shared cache
    and its op, operand and six constant bound inputs are inspected. A second
    slice with a symbolic stop is then converted with explicit int32 dtypes to
    make sure the conversion goes through.
    """
    from theano import Constant

    shared_cache = {}

    dim = sy.Symbol('n', integer=True)
    mat = sy.MatrixSymbol('X', dim, dim)

    fixed_slice = mat[1:2:3, 4:5:6]
    fixed_slice_t = theano_code_(fixed_slice, cache=shared_cache)

    int64 = ts.Scalar('int64')
    # Two sliced axes, each described by start/stop/step int64 placeholders.
    expected_idx = (slice(int64, int64, int64), slice(int64, int64, int64))
    assert tuple(fixed_slice_t.owner.op.idx_list) == expected_idx
    # Re-converting X through the same cache must yield the sliced operand.
    operand_t = theano_code_(mat, cache=shared_cache)
    assert fixed_slice_t.owner.inputs[0] == operand_t

    # Theano's == is not SymPy-style structural equality, so compare with
    # .equals. inputs[1] .. inputs[6] carry the bounds 1 .. 6 in order.
    for pos in range(1, 7):
        assert fixed_slice_t.owner.inputs[pos].equals(Constant(int64, pos))

    sym_k = sy.Symbol('k')
    sym_k_t = theano_code_(sym_k, dtypes={sym_k: 'int32'})
    # Slice whose stop is symbolic; only checks that conversion succeeds.
    symbolic_slice = mat[4:sym_k:2]
    symbolic_slice_t = theano_code_(
        symbolic_slice, dtypes={dim: 'int32', sym_k: 'int32'}
    )