def test_clinker_dups_inner():
    """Check that CLinker accepts a variable used more than once inside the graph.

    The second input feeds both operands of the inner ``mul``.
    """
    a, b, c = inputs()
    squared = mul(b, b)
    total = add(squared, add(a, c))
    compiled = CLinker().accept(Env([a, b, c], [total])).make_function()
    # 2.0 * 2.0 + (1.0 + 3.0) == 8.0
    assert compiled(1.0, 2.0, 3.0) == 8.0
def test_clinker_not_used_inputs():
    """Check that CLinker accepts graph inputs that no output depends on.

    The third input is declared on the Env but never used by the output.
    """
    a, b, unused = inputs()
    graph = Env([a, b, unused], [add(a, b)])
    compiled = CLinker().accept(graph).make_function()
    # The third argument is ignored: 2.0 + 1.5 == 3.5
    assert compiled(2.0, 1.5, 1.0) == 3.5
def test_clinker_dups():
    """Check that CLinker accepts the same variable listed twice as a graph input."""
    a, _, _ = inputs()
    doubled = add(a, a)
    compiled = CLinker().accept(Env([a, a], [doubled])).make_function()
    # One value per declared input, even though both refer to the same variable.
    assert compiled(2.0, 2.0) == 4
def test_clinker_literal_inlining():
    """Check that CLinker inlines a ``Constant``'s literal value into the C code.

    Builds an expression over two inputs and one double constant, checks the
    numeric result, then checks that the constant's digits appear verbatim in
    the generated source.
    """
    # Only two graph inputs are needed; the third slot from inputs() is unused
    # (previously it was silently rebound to the constant below).
    x, y, _ = inputs()
    const = Constant(tdouble, 4.12345678)
    e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), const))
    lnk = CLinker().accept(Env([x, y], [e]))
    fn = lnk.make_function()
    # (2 + 2) * (2 / 2) + ((2 - 2) - 4.12345678) == -0.12345678
    assert abs(fn(2.0, 2.0) + 0.12345678) < 1e-9
    code = lnk.code_gen()
    # We expect the number to be inlined into the generated code.
    assert "4.12345678" in code
def test_duallinker_mismatch():
    """DualLinker must detect that bad_sub's C and Python versions disagree."""
    x, y, z = inputs()
    # bad_sub is correct in C but erroneous in Python
    graph = Env([x, y, z], [bad_sub(mul(x, y), mul(y, z))])
    dual_fn = DualLinker(checker=_my_checker).accept(graph).make_function()

    # Each single linker run on its own, in the same order as before.
    expectations = (
        (CLinker, -4.0),  # good
        (OpWiseCLinker, -4.0),  # good
        (PerformLinker, -10.0),  # (purposely) wrong
    )
    for linker_cls, expected in expectations:
        assert linker_cls().accept(graph).make_function()(1.0, 2.0, 3.0) == expected

    # The dual function runs OpWiseCLinker and PerformLinker in parallel and
    # feeds variables of matching operations to _my_checker to verify that
    # they are the same; the mismatch is expected to surface as MyExc.
    with pytest.raises(MyExc):
        dual_fn(1.0, 2.0, 3.0)
# `logging` is used below for `_logger`; it was not imported anywhere in this
# section, which would raise NameError at import time.
import logging

import theano
from theano import config, gof
from theano.compile.function.types import Supervisor
from theano.link.basic import PerformLinker
from theano.link.c.basic import CLinker, OpWiseCLinker
from theano.link.jax import JAXLinker
from theano.link.vm import VMLinker


_logger = logging.getLogger("theano.compile.mode")

# If a string is passed as the linker argument in the constructor for
# Mode, it will be used as the key to retrieve the real linker in this
# dictionary
predefined_linkers = {
    "py": PerformLinker(),  # Use allow_gc Theano flag
    "c": CLinker(),  # Don't support gc. so don't check allow_gc
    "c|py": OpWiseCLinker(),  # Use allow_gc Theano flag
    "c|py_nogc": OpWiseCLinker(allow_gc=False),
    "vm": VMLinker(use_cloop=False),  # Use allow_gc Theano flag
    "cvm": VMLinker(use_cloop=True),  # Use allow_gc Theano flag
    "vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
    "cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
    "jax": JAXLinker(),
}


def register_linker(name, linker):
    """Add a `Linker` which can be referred to by `name` in `Mode`.

    Parameters
    ----------
    name : str
        Key under which `linker` is stored in `predefined_linkers`.
    linker :
        The linker instance to register.

    Raises
    ------
    ValueError
        If `name` is already a key of `predefined_linkers`; existing
        registrations are never overwritten.
    """
    if name in predefined_linkers:
        raise ValueError(f"Linker name already taken: {name}")
    predefined_linkers[name] = linker
def test_shared_input_output():
    """Check a shared variable that is both the output and an update target.

    The same "return `state`, then add `inc` to it" function is compiled twice:
    once with an explicit CLinker mode and once with the default mode. Both
    functions update the same shared variable, so increments made through one
    must be visible through the other. The scenario is run for a scalar shared
    variable and for a length-3 int32 vector.
    """
    # Test bug reported on the mailing list by Alberto Orlandi
    # https://groups.google.com/d/topic/theano-users/6dLaEqc2R6g/discussion
    # The shared variable is both an input and an output of the function.
    inc = theano.tensor.iscalar("inc")
    state = theano.shared(0)
    state.name = "state"
    linker = CLinker()
    mode = theano.Mode(linker=linker)
    # f is compiled with the CLinker mode, g with the default mode.
    f = theano.function([inc], state, updates=[(state, state + inc)], mode=mode)
    g = theano.function([inc], state, updates=[(state, state + inc)])

    # Initial value
    f0 = f(0)
    g0 = g(0)
    assert f0 == g0 == 0, (f0, g0)

    # Increment state via f, returns the previous value.
    f2 = f(2)
    assert f2 == f0, (f2, f0)

    # The increment made through f must be seen by both f and g.
    f0 = f(0)
    g0 = g(0)
    assert f0 == g0 == 2, (f0, g0)

    # Increment state via g, returns the previous value
    g3 = g(3)
    assert g3 == g0, (g3, g0)

    # Running total is now 2 + 3.
    f0 = f(0)
    g0 = g(0)
    assert f0 == g0 == 5, (f0, g0)

    # Same scenario with a vector-valued shared variable; comparisons are
    # element-wise, hence np.all.
    vstate = theano.shared(np.zeros(3, dtype="int32"))
    vstate.name = "vstate"
    fv = theano.function([inc], vstate, updates=[(vstate, vstate + inc)], mode=mode)
    gv = theano.function([inc], vstate, updates=[(vstate, vstate + inc)])

    # Initial value
    fv0 = fv(0)
    gv0 = gv(0)
    assert np.all(fv0 == 0), fv0
    assert np.all(gv0 == 0), gv0

    # Increment state via f, returns the previous value.
    fv2 = fv(2)
    assert np.all(fv2 == fv0), (fv2, fv0)

    fv0 = fv(0)
    gv0 = gv(0)
    assert np.all(fv0 == 2), fv0
    assert np.all(gv0 == 2), gv0

    # Increment state via g, returns the previous value
    gv3 = gv(3)
    assert np.all(gv3 == gv0), (gv3, gv0)

    fv0 = fv(0)
    gv0 = gv(0)
    assert np.all(fv0 == 5), fv0
    assert np.all(gv0 == 5), gv0
def test_clinker_single_node():
    """Compile a graph made of a single ``add`` node built via ``make_node``."""
    a, b, _ = inputs()
    apply_node = add.make_node(a, b)
    graph = Env(apply_node.inputs, apply_node.outputs)
    compiled = CLinker().accept(graph).make_function()
    assert compiled(2.0, 7.0) == 9
def test_clinker_straightforward():
    """Compile a small expression tree mixing several ops with CLinker.

    bad_sub's C implementation is the correct one, so this exercises the C
    code path and expects an ordinary subtraction.
    """
    a, b, c = inputs()
    left = mul(add(a, b), div(a, b))
    right = bad_sub(bad_sub(a, b), c)
    compiled = CLinker().accept(Env([a, b, c], [add(left, right)])).make_function()
    # (2 + 2) * (2 / 2) + ((2 - 2) - 2) == 2
    assert compiled(2.0, 2.0, 2.0) == 2.0