예제 #1
0
    def test_not(self):
        """Bitwise NOT must agree with Python's ``~`` for both spellings.

        The graph is built once with ``invert(x)`` and once with the ``~``
        operator; each compiled function is checked on every 0/1 input pair.
        """
        bit_pairs = ((0, 1), (0, 0), (1, 0), (1, 1))
        for negate in (invert, lambda var: ~var):
            # Fresh symbolic inputs for each spelling of the operation.
            x, y, z = ints("xyz")
            graph = FunctionGraph([x, y], [negate(x)])
            fn = gof.DualLinker().accept(graph).make_function()
            for a, b in bit_pairs:
                assert fn(a, b) == ~a, (a,)
예제 #2
0
    def test_and(self):
        """Bitwise AND must agree with Python's ``&``.

        Both the explicit ``and_`` call and the overloaded ``&`` operator are
        compiled and evaluated on every 0/1 input pair.
        """
        truth_table = ((0, 1), (0, 0), (1, 0), (1, 1))

        def compile_binary(build):
            # Build a two-input graph around `build` and link it.
            x, y, _ = ints('xyz')
            graph = FunctionGraph([x, y], [build(x, y)])
            return gof.DualLinker().accept(graph).make_function()

        for build in (and_, lambda lhs, rhs: lhs & rhs):
            fn = compile_binary(build)
            for a, b in truth_table:
                self.assertTrue(fn(a, b) == (a & b), (a, b))
예제 #3
0
    def test_and(self):
        """Compare compiled bitwise AND against Python's ``&`` on 0/1 pairs."""
        pairs = [(0, 1), (0, 0), (1, 0), (1, 1)]

        # Explicit op call.
        x, y, z = ints("xyz")
        op_graph = FunctionGraph([x, y], [and_(x, y)])
        fn = DualLinker().accept(op_graph).make_function()
        for a, b in pairs:
            assert fn(a, b) == (a & b), (a, b)

        # Operator overload, on fresh variables.
        x, y, z = ints("xyz")
        operator_graph = FunctionGraph([x, y], [x & y])
        fn = DualLinker().accept(operator_graph).make_function()
        for a, b in pairs:
            assert fn(a, b) == (a & b), (a, b)
예제 #4
0
 def merge(self):
     """Return a new ``ComputationGraph`` built from merge-optimized outputs.

     A cloned ``FunctionGraph`` is created over this graph's inputs plus
     shared variables, ``merge_optimizer`` is run on it, and the resulting
     outputs are wrapped in a ``ComputationGraph``.
     """
     from theano.gof import FunctionGraph
     from theano.gof.opt import merge_optimizer

     graph_inputs = self.inputs + self.shared_variables
     fgraph = FunctionGraph(graph_inputs, self.outputs, clone=True)
     merge_optimizer.optimize(fgraph)
     return ComputationGraph(fgraph.outputs)
예제 #5
0
def test_local_dimshuffle_subtensor():
    """Exercise the ``local_dimshuffle_subtensor`` rewrite on several graphs.

    Every case applies a ``DimShuffle`` that drops dimensions from the result
    of an indexing operation and then inspects the resulting graph.
    """

    def dimshuffle_free(nodes):
        # `toposort()` yields Apply nodes, so the op class has to be read
        # from `node.op`.  The earlier check tested the nodes themselves and
        # used `any`, which made the assertion pass unconditionally.
        return all(not isinstance(node.op, DimShuffle) for node in nodes)

    dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)

    x = tensor.dtensor4("x")
    x = tensor.patternbroadcast(x, (False, True, False, False))
    i = tensor.iscalar("i")

    out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    assert dimshuffle_free(g.toposort())

    # The dimshuffle removes a dimension the subtensor does not "see".
    x = tensor.tensor(broadcastable=(False, True, False), dtype="float64")
    out = x[i].dimshuffle(1)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    assert dimshuffle_free(g.toposort())

    # The dimshuffle removes dimensions the subtensor does not "see", with
    # a kept dimension in between them.
    x = tensor.tensor(broadcastable=(False, True, False, True),
                      dtype="float64")
    out = x[i].dimshuffle(1)

    f = theano.function([x, i], out)

    topo = f.maker.fgraph.toposort()
    # NOTE(review): this graph went through the full `theano.function`
    # pipeline, so which rewrites ran presumably depends on the active
    # compilation mode.  Only require that the graph is not made up solely of
    # DimShuffle ops here and rely on the numerical check below -- confirm
    # whether a strict "no DimShuffle" check is wanted.
    assert any(not isinstance(node.op, DimShuffle) for node in topo)
    assert f(np.random.rand(5, 1, 4, 1), 2).shape == (4, )

    # Corner case that previously triggered a bug in Theano.
    x = tensor.dtensor4("x")
    x = tensor.patternbroadcast(x, (False, True, False, False))

    assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval({
        x: np.ones((5, 1, 6, 7))
    }).shape == (5, 3, 7)
예제 #6
0
파일: test_basic.py 프로젝트: yubow/Theano
 def test_straightforward(self):
     """A single-output ``Composite`` of ``(x + y) * (x / y)`` evaluates to 1.5
     for inputs ``(1.0, 2.0)``.
     """
     x, y, z = inputs()
     expr = mul(add(x, y), div_proxy(x, y))
     composite_node = Composite([x, y], [expr]).make_node(x, y)
     fgraph = FunctionGraph([x, y], [composite_node.out])
     linker = gof.DualLinker().accept(fgraph)
     fn = linker.make_function()
     assert fn(1.0, 2.0) == 1.5
예제 #7
0
파일: test_basic.py 프로젝트: yubow/Theano
 def test_with_constants(self):
     """A ``Composite`` containing the literal 70.0 keeps it in its C code
     and evaluates ``(70 + y) * (x / y)`` correctly.
     """
     x, y, z = inputs()
     expr = mul(add(70.0, y), div_proxy(x, y))
     composite_op = Composite([x, y], [expr])
     node = composite_op.make_node(x, y)

     # The constant must be emitted literally in the generated C source.
     c_source = node.op.c_code(node, 'dummy', ['x', 'y'], ['z'], dict(id=0))
     assert "70.0" in c_source

     fgraph = FunctionGraph([x, y], [node.out])
     fn = gof.DualLinker().accept(fgraph).make_function()
     assert fn(1.0, 2.0) == 36.0
예제 #8
0
 def test_with_constants(self):
     """Check C code generation and evaluation of a ``Composite`` that
     embeds the constant 70.0.
     """
     x, y, z = floats("xyz")
     total = add(70.0, y)
     ratio = true_div(x, y)
     node = Composite([x, y], [mul(total, ratio)]).make_node(x, y)

     generated = node.op.c_code(node, "dummy", ["x", "y"], ["z"], dict(id=0))
     assert "70.0" in generated

     fn = DualLinker().accept(FunctionGraph([x, y], [node.out])).make_function()
     assert fn(1.0, 2.0) == 36.0
예제 #9
0
파일: test_basic.py 프로젝트: yubow/Theano
 def test_many_outputs(self):
     """A three-output ``Composite`` returns all of its results in order."""
     x, y, z = inputs()
     expressions = [x + y + z, x + y * z, x / y]
     node = Composite([x, y, z], expressions).make_node(x, y, z)
     fgraph = FunctionGraph([x, y, z], node.outputs)
     fn = gof.DualLinker().accept(fgraph).make_function()
     expected = [6.0, 7.0, 0.5]
     assert fn(1.0, 2.0, 3.0) == expected
예제 #10
0
    def tes_mod(self):
        """Compare the compiled ``%`` against Python's ``%`` on signed operands.

        NOTE(review): the method name lacks the ``test_`` prefix, so test
        runners will presumably not collect it -- confirm whether that is
        deliberate.
        """
        # Languages and C implementations disagree on the sign of the
        # result, so this checks that the c_code of `Mod` follows Python's
        # convention, which is the behaviour we want.
        operand_pairs = (
            (0, 1), (1, 1), (0, -1), (1, -1), (-1, -1), (1, 2), (-1, 2),
            (1, -2), (-1, -2), (5, 3), (-5, 3), (5, -3), (-5, -3),
        )

        x, y = ints('xy')
        graph = FunctionGraph([x, y], [x % y])
        fn = gof.DualLinker().accept(graph).make_function()
        for a, b in operand_pairs:
            self.assertTrue(fn(a, b) == a % b, (a, ))
def test_local_reshape_dimshuffle():
    """Apply ``local_reshape_dimshuffle`` to a reshape of a dimshuffled matrix
    and check that no ``DimShuffle`` op is left in the graph.
    """

    reshape_dimshuffle = out2in(local_reshape_dimshuffle)

    x = tensor.matrix("x")

    y = x.dimshuffle("x", 0, "x", 1)
    out = tensor.reshape(y, (1, x.shape[0] * x.shape[1], 1))

    g = FunctionGraph([x], [out])
    reshape_dimshuffle(g)

    topo = g.toposort()
    # `toposort()` yields Apply nodes, so the op class has to be read from
    # `node.op`.  The earlier check tested the nodes themselves and used
    # `any`, which made the assertion pass unconditionally.
    assert all(not isinstance(node.op, DimShuffle) for node in topo)
def test_local_dimshuffle_subtensor():
    """Exercise the ``local_dimshuffle_subtensor`` rewrite on several graphs.

    Every case applies a ``DimShuffle`` that drops dimensions from the result
    of an indexing operation and then inspects the resulting graph.
    """

    def dimshuffle_free(nodes):
        # `toposort()` yields Apply nodes, so the op class has to be read
        # from `node.op`.  The earlier check tested the nodes themselves and
        # used `any`, which made the assertion pass unconditionally.
        return all(not isinstance(node.op, DimShuffle) for node in nodes)

    dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)

    x = tensor.dtensor4('x')
    x = tensor.patternbroadcast(x, (False, True, False, False))
    i = tensor.iscalar('i')

    out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    assert dimshuffle_free(g.toposort())

    # The dimshuffle removes a dimension the subtensor does not "see".
    x = tensor.tensor(broadcastable=(False, True, False), dtype='float64')
    out = x[i].dimshuffle(1)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    assert dimshuffle_free(g.toposort())

    # The dimshuffle removes dimensions the subtensor does not "see", with
    # a kept dimension in between them.
    x = tensor.tensor(broadcastable=(False, True, False, True),
                      dtype='float64')
    out = x[i].dimshuffle(1)

    f = theano.function([x, i], out)

    topo = f.maker.fgraph.toposort()
    # NOTE(review): this graph went through the full `theano.function`
    # pipeline, so which rewrites ran presumably depends on the active
    # compilation mode.  Only require that the graph is not made up solely of
    # DimShuffle ops here and rely on the numerical check below -- confirm
    # whether a strict "no DimShuffle" check is wanted.
    assert any(not isinstance(node.op, DimShuffle) for node in topo)
    assert f(np.random.rand(5, 1, 4, 1), 2).shape == (4, )
def test_local_alloc_dimshuffle():
    """Apply ``local_alloc_dimshuffle`` to an ``alloc`` of a dimshuffled
    vector and check that no ``DimShuffle`` op is left in the graph.
    """

    alloc_dimshuffle = out2in(local_alloc_dimshuffle)

    x = tensor.vector("x")
    m = tensor.iscalar("m")

    y = x.dimshuffle("x", 0)
    out = tensor.alloc(y, m, 1, x.shape[0])

    g = FunctionGraph([x, m], [out])
    alloc_dimshuffle(g)

    topo = g.toposort()
    # `toposort()` yields Apply nodes, so the op class has to be read from
    # `node.op`.  The earlier check tested the nodes themselves and used
    # `any`, which made the assertion pass unconditionally.
    assert all(not isinstance(node.op, DimShuffle) for node in topo)
예제 #14
0
    def test_clip_grad(self):
        """Gradient of ``clip`` w.r.t. ``x`` when ``x`` is also used as a bound.

        Regression test for issue #633.  Three graphs are compiled --
        ``clip(x, y, x)``, ``clip(x, x, y)`` and ``clip(x, x, x)`` -- and the
        gradient with respect to ``x`` is expected to equal 1 for random
        inputs.
        """
        x, y = floats('xy')
        a = theano.tensor.clip(x, y, x)
        g = theano.gradient.grad(a, x)
        fn = gof.DualLinker().accept(FunctionGraph([x, y],
                                                   [g])).make_function()

        # Test the other way around as well
        a2 = theano.tensor.clip(x, x, y)
        g2 = theano.gradient.grad(a2, x)
        fn2 = gof.DualLinker().accept(FunctionGraph([x, y],
                                                    [g2])).make_function()

        # Test for the equal case too.
        a3 = theano.tensor.clip(x, x, x)
        g3 = theano.gradient.grad(a3, x)
        fn3 = gof.DualLinker().accept(FunctionGraph([x], [g3])).make_function()

        rng = np.random.RandomState(utt.fetch_seed())

        ntests = 50
        # `range` instead of `xrange`: the latter does not exist on Python 3
        # unless a compatibility shim is imported, and 50 iterations are
        # cheap on Python 2 as well.
        for _ in range(ntests):
            xval = rng.rand(1)
            # To ensure that the min < x.
            yval_mn = rng.rand(1) - 1.0

            # To ensure that the max > x.
            yval_mx = rng.rand(1) + 1.0

            aval = fn(xval, yval_mn)
            aval2 = fn2(xval, yval_mx)
            aval3 = fn3(xval)
            self.assertTrue(aval == 1.)
            self.assertTrue(aval2 == 1.)
            self.assertTrue(aval3 == 1.)
def test_local_dimshuffle_alloc():
    """Apply ``local_dimshuffle_alloc`` to a dimshuffled ``alloc``.

    The rewritten graph must still produce a 4-dimensional result and must
    not contain any ``DimShuffle`` op.
    """

    dimshuffle_alloc = out2in(local_dimshuffle_alloc)

    x = tensor.vector("x")

    out = tensor.alloc(x, 3, 2).dimshuffle("x", "x", 0, 1)

    g = FunctionGraph([x], [out])
    dimshuffle_alloc(g)

    # Execute the rewritten graph to make sure it is still valid.
    linker = PerformLinker()
    linker.accept(g)
    f = linker.make_function()

    assert f([3, 4]).ndim == 4

    topo = g.toposort()
    # `toposort()` yields Apply nodes, so the op class has to be read from
    # `node.op`.  The earlier check tested the nodes themselves and used
    # `any`, which made the assertion pass unconditionally.
    assert all(not isinstance(node.op, DimShuffle) for node in topo)
예제 #16
0
    def test_composite_printing(self):
        """The string form of a graph holding an eight-output ``Composite``
        matches the expected rendering.
        """
        x, y, z = floats('xyz')
        expressions = [
            x + y + z,
            x + y * z,
            x / y,
            x // 5,
            -x,
            x - y,
            x**y + (-z),
            x % 3,
        ]
        node = Composite([x, y, z], expressions).make_node(x, y, z)
        g = FunctionGraph([x, y, z], node.outputs)
        # Linking is done for its side effects only; the function is unused.
        gof.DualLinker().accept(g).make_function()

        expected = (
            '[*1 -> Composite{'
            '((i0 + i1) + i2), '
            '(i0 + (i1 * i2)), '
            '(i0 / i1), '
            '(i0 // Constant{5}), '
            '(-i0), '
            '(i0 - i1), '
            '((i0 ** i1) + (-i2)), '
            '(i0 % Constant{3})'
            '}(x, y, z), '
            '*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7]'
        )
        assert str(g) == expected
예제 #17
0
def jax_funcify_Scan(op):
    """Return a callable meant to evaluate the ``Scan`` ``op`` via ``jax.lax.scan``.

    The inner graph of ``op`` is converted with ``jax_funcify`` once; the
    returned ``scan`` closure maps the outer inputs onto a ``jax.lax.scan``
    call.

    NOTE(review): this looks like an unfinished implementation -- both
    conversion helpers raise ``NotImplementedError`` and several pieces do
    not line up with each other (see the notes inside).
    """
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    jax_tt_inner_func = jax_funcify(inner_fg)

    def scan(*outer_inputs):
        scan_args = ScanArgs(outer_inputs, [None] * op.n_outs, op.inputs,
                             op.outputs, op.info)

        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs

        n_non_seqs = len(scan_args.outer_in_non_seqs)

        # TODO: sit_sots
        # For every mit-sot input keep only the leading slice covering the
        # largest negative plus the largest positive tap.
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices,
                            scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[:max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        # NOTE(review): the carry has two entries here, while
        # `jax_args_to_inner_scan` unpacks five -- confirm the intended layout.
        init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]

        def jax_args_to_inner_scan(op, carry, x):
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            # NOTE(review): `inner_in_shared` is listed above but not included
            # in the list built below.
            inner_scan_inputs = [
                inner_in_seqs,
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_non_seqs,
            ]

            # Not implemented yet; the `return` below is unreachable.
            raise NotImplementedError()
            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            # `inner_scan_outs` is a list with the following
            # composite form:
            # outer_out_mit_mot
            # + outer_out_mit_sot
            # + outer_out_sit_sot
            # + outer_out_nit_sot
            # + outer_out_shared
            # + cond
            (
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_nit_sot,
                outer_out_shared,
                cond,
            ) = inner_scan_outs
            # NOTE(review): this slice keeps everything *except* the last
            # `n_non_seqs` entries of the old carry -- confirm that is intended.
            outer_out_non_seqs = old_carry[:-n_non_seqs]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            carry = [
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_shared,
                outer_out_non_seqs,
            ]
            # This should contain all inner-outputs that produce
            # outer-outputs
            y = []

            # Not implemented yet; the `return` below is unreachable.
            raise NotImplementedError()
            return (carry, y)

        def jax_inner_func(carry, x):
            # One scan step: carry/x -> inner inputs -> inner function ->
            # new carry and per-step output.
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = jax_tt_inner_func(*inner_args)
            # NOTE(review): `inner_scan_outs_to_jax_outs` declares three
            # parameters (op, old_carry, inner_scan_outs) but only two
            # arguments are passed here.
            new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs)
            return new_carry, y

        return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

    return scan
예제 #18
0
파일: test_basic.py 프로젝트: yubow/Theano
 def test_straightforward(self):
     """``(x + y) * (x / y)`` compiled through ``DualLinker`` gives 1.5 for
     inputs ``(1.0, 2.0)``.
     """
     x, y, z = inputs()
     summed = add(x, y)
     quotient = div_proxy(x, y)
     fgraph = FunctionGraph([x, y], [mul(summed, quotient)])
     fn = gof.DualLinker().accept(fgraph).make_function()
     assert fn(1.0, 2.0) == 1.5
예제 #19
0
파일: test_basic.py 프로젝트: yubow/Theano
 def test_neq(self):
     """``neq`` must agree with Python's ``!=`` on mixed float/int inputs."""
     x, y, z = inputs()
     graph = FunctionGraph([x, y], [neq(x, y)])
     fn = gof.DualLinker().accept(graph).make_function()
     samples = ((3., 9), (3, 0.9), (3, 3))
     for a, b in samples:
         self.assertTrue(fn(a, b) == (a != b))
예제 #20
0
 def test_clone(self):
     """Building a ``FunctionGraph`` (default arguments) from an expression
     on a cached constant must not raise.
     """
     const = theano.tensor.constant(1)
     assert const.cached
     outputs = [const + 1]
     FunctionGraph([], outputs)
예제 #21
0
 def test_xor(self):
     """The compiled ``^`` operator must agree with Python's ``^`` on 0/1 pairs."""
     x, y, z = ints("xyz")
     graph = FunctionGraph([x, y], [x ^ y])
     fn = gof.DualLinker().accept(graph).make_function()
     bit_pairs = ((0, 1), (0, 0), (1, 0), (1, 1))
     for a, b in bit_pairs:
         assert fn(a, b) == (a ^ b), (a, b)
예제 #22
0
 def test_ge(self):
     """The compiled ``>=`` operator must agree with Python's ``>=``."""
     x, y, z = floats("xyz")
     comparison = x >= y
     linker = DualLinker().accept(FunctionGraph([x, y], [comparison]))
     fn = linker.make_function()
     samples = [(3.0, 9), (3, 0.9), (3, 3)]
     for a, b in samples:
         assert fn(a, b) == (a >= b)
예제 #23
0
def test_mul_add_div_proxy():
    """``mul(add(x, y), div_proxy(x, y))`` evaluates to 1.5 at ``(1.0, 2.0)``."""
    x, y, z = floats("xyz")
    summed = add(x, y)
    quotient = div_proxy(x, y)
    fgraph = FunctionGraph([x, y], [mul(summed, quotient)])
    linker = gof.DualLinker().accept(fgraph)
    fn = linker.make_function()
    assert fn(1.0, 2.0) == 1.5
예제 #24
0
 def test_neq(self):
     """Compare the compiled ``neq`` op with Python's ``!=``."""
     x, y, z = floats("xyz")
     fgraph = FunctionGraph([x, y], [neq(x, y)])
     linker = gof.DualLinker().accept(fgraph)
     fn = linker.make_function()
     value_pairs = [(3.0, 9), (3, 0.9), (3, 3)]
     for a, b in value_pairs:
         assert fn(a, b) == (a != b)
예제 #25
0
def jax_funcify_Scan(op):
    """Return a callable that evaluates the ``Scan`` ``op`` via ``jax.lax.scan``.

    The inner graph of ``op`` is converted with ``jax_funcify`` once.  The
    returned ``scan`` closure takes the outer inputs of the ``Scan`` node,
    builds the initial carry from them, runs ``jax.lax.scan`` and prepends
    the initial values to each output.

    Only scans with mit-sot or sit-sot inputs are handled; mit-mot inputs
    are still a TODO.  A scan with neither kind raises
    ``NotImplementedError``.
    """
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    jax_tt_inner_func = jax_funcify(inner_fg)

    def scan(*outer_inputs):
        scan_args = ScanArgs(list(outer_inputs), [None] * op.n_outs, op.inputs,
                             op.outputs, op.info)

        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs

        # TODO: mit_mots
        mit_mot_in_slices = []

        # For every mit-sot input keep only the leading slice covering the
        # largest negative plus the largest positive tap.
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices,
                            scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[:max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        # A sit-sot input starts from the first entry of its outer sequence.
        sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]

        init_carry = (
            mit_mot_in_slices,
            mit_sot_in_slices,
            sit_sot_in_slices,
            scan_args.outer_in_shared,
            scan_args.outer_in_non_seqs,
        )

        def jax_args_to_inner_scan(op, carry, x):
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            inner_in_mit_sot_flatten = []
            for array, index in zip(inner_in_mit_sot,
                                    scan_args.mit_sot_in_slices):
                inner_in_mit_sot_flatten.extend(array[jnp.array(index)])

            # Concatenate the per-category lists into one flat argument list.
            inner_scan_inputs = sum(
                [
                    inner_in_seqs,
                    inner_in_mit_mot,
                    inner_in_mit_sot_flatten,
                    inner_in_sit_sot,
                    inner_in_shared,
                    inner_in_non_seqs,
                ],
                [],
            )

            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = old_carry

            def update_mit_sot(mit_sot, new_val):
                # Drop the oldest entry and append the new value at the end.
                return jnp.concatenate([mit_sot[1:], new_val[None, ...]],
                                       axis=0)

            inner_out_mit_sot = [
                update_mit_sot(mit_sot, new_val)
                for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
            ]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            if not inner_in_sit_sot:
                inner_out_sit_sot = []
            else:
                inner_out_sit_sot = inner_scan_outs
            new_carry = (
                inner_in_mit_mot,
                inner_out_mit_sot,
                inner_out_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            )

            return new_carry

        def jax_inner_func(carry, x):
            # One scan step: carry/x -> inner inputs -> inner functions ->
            # new carry and per-step outputs.
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func]
            new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
            return new_carry, inner_scan_outs

        _, scan_out = jax.lax.scan(jax_inner_func,
                                   init_carry,
                                   seqs,
                                   length=n_steps)

        # We need to prepend the initial values so that the JAX output will
        # match the raw `Scan` `Op` output and, thus, work with a downstream
        # `Subtensor` `Op` introduced by the `scan` helper function.
        def append_scan_out(scan_in_part, scan_out_part):
            return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part],
                                   axis=0)

        if scan_args.outer_in_mit_sot:
            scan_out_final = [
                append_scan_out(init, out)
                for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
            ]
        elif scan_args.outer_in_sit_sot:
            scan_out_final = [
                append_scan_out(init, out)
                for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
            ]
        else:
            # Without this branch `scan_out_final` would be unbound below and
            # the failure would surface as an UnboundLocalError.
            raise NotImplementedError(
                "JAX conversion of Scan currently requires mit-sot or "
                "sit-sot inputs"
            )

        if len(scan_out_final) == 1:
            scan_out_final = scan_out_final[0]
        return scan_out_final

    return scan
예제 #26
0
 def test_eq(self):
     """``eq`` must agree with Python's ``==`` on mixed float/int inputs."""
     x, y, z = inputs()
     graph = FunctionGraph([x, y], [eq(x, y)])
     fn = gof.DualLinker().accept(graph).make_function()
     samples = [(3.0, 9), (3, 0.9), (3, 3)]
     for a, b in samples:
         assert fn(a, b) == (a == b)
예제 #27
0
 def test_constant_cache_error(self):
     """Building a ``FunctionGraph`` with ``clone=False`` from an expression
     on a cached constant must raise ``CachedConstantError``.
     """
     cached_const = theano.tensor.constant(1)
     assert cached_const.cached

     with pytest.raises(CachedConstantError):
         FunctionGraph([], [cached_const + 1], clone=False)