示例#1
0
def test_local_sampling_dot_csr():
    """Verify the ``local_sampling_dot_csr`` graph optimization.

    ``sparse.sampling_dot`` is compiled for (matrix, matrix, CSR matrix)
    inputs under the default mode extended with the ``specialize`` and
    ``local_sampling_dot_csr`` optimizations.

    The optimized graph is then checked:

    * With BLAS link flags configured, no generic ``SamplingDot`` node may
      remain.
    * Without them, ``SamplingDotCSR`` must not have been introduced,
      because its C implementation needs BLAS.

    The test is skipped when no C++ compiler is configured.
    """
    if not theano.config.cxx:
        raise SkipTest("G++ not available, so we need to skip this test.")
    opt_mode = theano.compile.mode.get_default_mode().including(
        "specialize", "local_sampling_dot_csr")

    # Only the CSR layout has a specialized implementation.
    for fmt in ['csr']:
        args = [tensor.matrix(), tensor.matrix(),
                getattr(theano.sparse, fmt + '_matrix')()]
        fn = theano.function(args, sparse.sampling_dot(*args), mode=opt_mode)
        nodes = fn.maker.fgraph.toposort()

        if theano.config.blas.ldflags:
            # BLAS is available: the generic op must have been specialized.
            unwanted = sparse.SamplingDot
        else:
            # SamplingDotCSR's C implementation needs blas, so it should not
            # have been inserted.
            unwanted = sparse.opt.SamplingDotCSR
        assert not any(isinstance(node.op, unwanted) for node in nodes)
示例#2
0
文件: test_opt.py 项目: npinto/Theano
def test_local_sampling_dot_csr():
    """Check that ``local_sampling_dot_csr`` specializes ``sampling_dot``.

    Compiles ``sparse.sampling_dot`` on (matrix, matrix, sparse matrix)
    inputs, using a mode with the ``specialize`` and
    ``local_sampling_dot_csr`` optimizations enabled. It then inspects the
    optimized graph:

    * When BLAS link flags are configured, no generic ``SamplingDot`` node
      may remain.
    * Otherwise ``SamplingDotCSR`` must not appear, because its C
      implementation needs BLAS.

    Skipped when no C++ compiler is configured.
    """
    # The CSR specialization relies on compiled C code (see the BLAS note
    # below). Without a C++ compiler there is nothing meaningful to check.
    if not theano.config.cxx:
        raise SkipTest("G++ not available, so we need to skip this test.")
    mode = theano.compile.mode.get_default_mode()
    mode = mode.including("specialize", "local_sampling_dot_csr")

    for sp_format in ['csr']:  # Not implemented for other format
        inputs = [tensor.matrix(),
                  tensor.matrix(),
                  getattr(theano.sparse, sp_format + '_matrix')()]

        f = theano.function(inputs,
                            sparse.sampling_dot(*inputs),
                            mode=mode)
        nodes = f.maker.fgraph.toposort()

        if theano.config.blas.ldflags:
            assert not any(isinstance(node.op, sparse.SamplingDot)
                           for node in nodes), (
                "generic SamplingDot was not specialized for %s input"
                % sp_format)
        else:
            # SamplingDotCSR's C implementation needs blas, so it should not
            # be inserted.
            assert not any(isinstance(node.op, sparse.opt.SamplingDotCSR)
                           for node in nodes), (
                "SamplingDotCSR inserted although BLAS is unavailable")
示例#3
0
def test_local_sampling_dot_csr():
    """Exercise the ``local_sampling_dot_csr`` optimization on CSR inputs.

    Builds ``sparse.sampling_dot(x, y, p)`` with dense ``x``/``y`` and a
    sparse ``p``. It compiles the expression with ``specialize`` and
    ``local_sampling_dot_csr`` enabled, then checks which op may not appear
    in the optimized graph:

    * ``SamplingDot`` when BLAS link flags are set.
    * ``SamplingDotCSR`` when they are not, since its C code needs BLAS.

    Skipped without a configured C++ compiler.
    """
    if not theano.config.cxx:
        raise SkipTest("G++ not available, so we need to skip this test.")

    base_mode = theano.compile.mode.get_default_mode()
    mode = base_mode.including("specialize", "local_sampling_dot_csr")

    for sp_format in ['csr']:  # Other sparse formats are not implemented.
        x = tensor.matrix()
        y = tensor.matrix()
        p = getattr(theano.sparse, sp_format + '_matrix')()

        func = theano.function([x, y, p], sparse.sampling_dot(x, y, p),
                               mode=mode)
        graph_ops = [node.op for node in func.maker.fgraph.toposort()]

        # With BLAS the generic op must be specialized away; without it the
        # BLAS-dependent SamplingDotCSR must never be inserted.
        forbidden_op = (sparse.SamplingDot if theano.config.blas.ldflags
                        else sparse.opt.SamplingDotCSR)
        assert not any(isinstance(op, forbidden_op) for op in graph_ops)