コード例 #1
0
ファイル: test_basic_ops.py プロジェクト: xsongx/Theano
def test_shape():
    """Check ``x.shape`` on a 3-d float32 GPU array, in two compilation modes.

    The symbolic shape is evaluated on a zero-filled (3, 4, 5) array and must
    equal ``(3, 4, 5)``. The compiled graph is also inspected:

    * default mode (skipped when ``theano.config.mode`` is ``'FAST_COMPILE'``):
      three ``Shape_i`` nodes followed by a ``MakeVector`` node;
    * ``mode_with_gpu`` with the ``local_shape_to_shape_i`` rewrite excluded:
      a single ``Shape`` node.
    """
    expected = (3, 4, 5)
    inp = GpuArrayType(dtype='float32', broadcastable=[False] * 3)()
    data = gpuarray.zeros(expected, dtype='float32')

    fn = theano.function([inp], inp.shape)
    nodes = fn.maker.fgraph.toposort()
    assert numpy.all(fn(data) == expected)
    # NOTE(review): the graph structure is only asserted outside FAST_COMPILE,
    # presumably because that mode does not apply the rewrites producing it.
    if theano.config.mode != 'FAST_COMPILE':
        assert len(nodes) == 4
        for node in nodes[:3]:
            assert isinstance(node.op, T.opt.Shape_i)
        assert isinstance(nodes[-1].op, T.opt.MakeVector)

    # Without local_shape_to_shape_i the graph should keep one plain Shape op.
    no_shape_i = mode_with_gpu.excluding("local_shape_to_shape_i")
    fn = theano.function([inp], inp.shape, mode=no_shape_i)
    nodes = fn.maker.fgraph.toposort()
    assert numpy.all(fn(data) == expected)
    assert len(nodes) == 1
    assert isinstance(nodes[0].op, T.Shape)
コード例 #2
0
ファイル: test_basic_ops.py プロジェクト: Ambier/Theano
def test_shape():
    """Verify the value and compiled graph of ``shape`` for a 3-d GPU array.

    A float32 ``GpuArrayType`` variable with three non-broadcastable
    dimensions is fed a zero-filled (3, 4, 5) GPU array. The result must be
    ``(3, 4, 5)`` both in the default mode and in ``mode_with_gpu`` minus the
    ``local_shape_to_shape_i`` rewrite. The toposort is also checked:

    * default mode (unless ``theano.config.mode`` is ``'FAST_COMPILE'``):
      ``Shape_i``, ``Shape_i``, ``Shape_i``, ``MakeVector``;
    * rewrite excluded: one ``Shape`` node.
    """
    dims = (3, 4, 5)
    sym = GpuArrayType(dtype='float32', broadcastable=[False, False, False])()
    gpu_zeros = gpuarray.zeros(dims, dtype='float32')

    # Default compilation mode.
    compiled = theano.function([sym], sym.shape)
    order = compiled.maker.fgraph.toposort()
    assert numpy.all(compiled(gpu_zeros) == dims)
    # NOTE(review): structure checks are skipped under FAST_COMPILE, likely
    # because the shape rewrites are not run there - confirm.
    if theano.config.mode != 'FAST_COMPILE':
        assert len(order) == 4
        assert all(isinstance(n.op, T.opt.Shape_i) for n in order[:3])
        assert isinstance(order[3].op, T.opt.MakeVector)

    # GPU mode with the Shape -> Shape_i rewrite disabled.
    compiled = theano.function(
        [sym],
        sym.shape,
        mode=mode_with_gpu.excluding("local_shape_to_shape_i"),
    )
    order = compiled.maker.fgraph.toposort()
    assert numpy.all(compiled(gpu_zeros) == dims)
    assert len(order) == 1
    assert isinstance(order[0].op, T.Shape)