def test_local_reshape_dimshuffle():
    """Check that `local_reshape_dimshuffle` removes the `DimShuffle` under a reshape.

    A matrix is dimshuffled with the pattern ``('x', 0, 'x', 1)`` and then
    reshaped.  After applying the rewrite to the graph, no `DimShuffle` op
    may remain in it.
    """
    reshape_dimshuffle = out2in(local_reshape_dimshuffle)

    x = tensor.matrix('x')

    y = x.dimshuffle('x', 0, 'x', 1)
    out = tensor.reshape(y, (1, x.shape[0] * x.shape[1], 1))

    g = FunctionGraph([x], [out])
    reshape_dimshuffle(g)

    # `toposort` yields `Apply` nodes, so the op class has to be tested on
    # `node.op`.  The previous `any(not isinstance(node, DimShuffle) ...)`
    # form was true for every non-empty graph and could never fail.
    topo = g.toposort()
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_local_alloc_dimshuffle():
    """Check that `local_alloc_dimshuffle` removes the `DimShuffle` feeding an alloc.

    A vector is dimshuffled with the pattern ``('x', 0)`` and passed to
    `tensor.alloc`.  After applying the rewrite to the graph, no `DimShuffle`
    op may remain in it.
    """
    alloc_dimshuffle = out2in(local_alloc_dimshuffle)

    x = tensor.vector('x')
    m = tensor.iscalar('m')

    y = x.dimshuffle('x', 0)
    out = tensor.alloc(y, m, 1, x.shape[0])

    g = FunctionGraph([x, m], [out])
    alloc_dimshuffle(g)

    # `toposort` yields `Apply` nodes, so the op class has to be tested on
    # `node.op`.  The previous `any(not isinstance(node, DimShuffle) ...)`
    # form was true for every non-empty graph and could never fail.
    topo = g.toposort()
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_local_dimshuffle_subtensor():
    """Check that `local_dimshuffle_subtensor` removes a `DimShuffle` above a subtensor.

    The input is a 4-d tensor whose second axis is marked broadcastable.  The
    dimshuffle ``(0, 2, 3)`` applied to the sliced tensor leaves that axis
    out.  After applying the rewrite to the graph, no `DimShuffle` op may
    remain in it.
    """
    dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)

    x = tensor.tensor4('x')
    x = tensor.patternbroadcast(x, (False, True, False, False))
    i = tensor.iscalar('i')

    out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    # `toposort` yields `Apply` nodes, so the op class has to be tested on
    # `node.op`.  The previous `any(not isinstance(node, DimShuffle) ...)`
    # form was true for every non-empty graph and could never fail.
    topo = g.toposort()
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_local_dimshuffle_alloc():
    """Check that `local_dimshuffle_alloc` removes a `DimShuffle` applied to an alloc.

    A vector is allocated to shape ``(3, 2)`` and dimshuffled with the pattern
    ``('x', 'x', 0, 1)``.  After the rewrite, the graph must still evaluate to
    a 4-d result and must not contain any `DimShuffle` op.
    """
    dimshuffle_alloc = out2in(local_dimshuffle_alloc)

    x = tensor.vector('x')

    out = tensor.alloc(x, 3, 2).dimshuffle('x', 'x', 0, 1)

    g = FunctionGraph([x], [out])
    dimshuffle_alloc(g)

    # Evaluate the rewritten graph to make sure it is still computable.
    linker = theano.gof.PerformLinker()
    linker.accept(g)
    f = linker.make_function()

    assert f([3, 4]).ndim == 4

    # `toposort` yields `Apply` nodes, so the op class has to be tested on
    # `node.op`.  The previous `any(not isinstance(node, DimShuffle) ...)`
    # form was true for every non-empty graph and could never fail.
    topo = g.toposort()
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_local_dimshuffle_subtensor():
    """Check that `local_dimshuffle_subtensor` removes a `DimShuffle` above a subtensor.

    Four scenarios are covered:

    1. a dimshuffle that leaves out a broadcastable axis the subtensor slices
       over;
    2. a dimshuffle that leaves out an axis the subtensor does not index;
    3. the same, with a non-broadcastable axis in between the ones left out
       (also evaluated numerically);
    4. a regression case with a negative-step slice, checked through `eval`.
    """
    dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)

    def has_dimshuffle(fgraph):
        # `toposort` yields `Apply` nodes, so the op class has to be tested on
        # `node.op`.  The previous `any(not isinstance(node, DimShuffle) ...)`
        # form was true for every non-empty graph and could never fail.
        return any(isinstance(node.op, DimShuffle) for node in fgraph.toposort())

    x = tensor.dtensor4('x')
    x = tensor.patternbroadcast(x, (False, True, False, False))
    i = tensor.iscalar('i')

    out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)
    assert not has_dimshuffle(g)

    # Test dimshuffle remove dimensions the subtensor don't "see".
    x = tensor.tensor(broadcastable=(False, True, False), dtype='float64')
    out = x[i].dimshuffle(1)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)
    assert not has_dimshuffle(g)

    # Test dimshuffle remove dimensions the subtensor don't "see" but
    # have in between dimensions.
    x = tensor.tensor(broadcastable=(False, True, False, True),
                      dtype='float64')
    out = x[i].dimshuffle(1)

    # The structural check is done on a graph rewritten by the optimizer
    # under test, so it does not depend on which rewrites the default
    # compilation mode happens to enable.
    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)
    assert not has_dimshuffle(g)

    # The numerical check still goes through a fully compiled function.
    f = theano.function([x, i], out)
    assert f(np.random.rand(5, 1, 4, 1), 2).shape == (4,)

    # Test a corner case that had Theano return a bug.
    x = tensor.dtensor4('x')
    x = tensor.patternbroadcast(x, (False, True, False, False))

    out = x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3)
    assert out.eval({x: np.ones((5, 1, 6, 7))}).shape == (5, 3, 7)
Example #6
0
def test_local_dimshuffle_subtensor():
    """Check that `local_dimshuffle_subtensor` removes a `DimShuffle` above a subtensor.

    Three scenarios are covered:

    1. a dimshuffle that leaves out a broadcastable axis the subtensor slices
       over;
    2. a dimshuffle that leaves out an axis the subtensor does not index;
    3. the same, with a non-broadcastable axis in between the ones left out
       (also evaluated numerically).
    """
    dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)

    def has_dimshuffle(fgraph):
        # `toposort` yields `Apply` nodes, so the op class has to be tested on
        # `node.op`.  The previous `any(not isinstance(node, DimShuffle) ...)`
        # form was true for every non-empty graph and could never fail.
        return any(isinstance(node.op, DimShuffle) for node in fgraph.toposort())

    x = tensor.dtensor4("x")
    x = tensor.patternbroadcast(x, (False, True, False, False))
    i = tensor.iscalar("i")

    out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)
    assert not has_dimshuffle(g)

    # Test dimshuffle remove dimensions the subtensor don't "see".
    x = tensor.tensor(broadcastable=(False, True, False), dtype="float64")
    out = x[i].dimshuffle(1)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)
    assert not has_dimshuffle(g)

    # Test dimshuffle remove dimensions the subtensor don't "see" but
    # have in between dimensions.
    x = tensor.tensor(broadcastable=(False, True, False, True), dtype="float64")
    out = x[i].dimshuffle(1)

    # The structural check is done on a graph rewritten by the optimizer
    # under test, so it does not depend on which rewrites the default
    # compilation mode happens to enable.
    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)
    assert not has_dimshuffle(g)

    # The numerical check still goes through a fully compiled function.
    f = theano.function([x, i], out)
    assert f(numpy.random.rand(5, 1, 4, 1), 2).shape == (4,)
Example #7
0
 def test_constant_cache_error(self):
     """A cached constant in a non-cloned `FunctionGraph` must raise `CachedConstantError`."""
     cached_const = theano.tensor.constant(1)
     # Precondition: the constant has to come from the constant cache.
     assert cached_const.cached
     with pytest.raises(CachedConstantError):
         FunctionGraph(
             [],
             [cached_const + 1],
             clone=False,
         )
 def test_straightforward(self):
     """Compile ``mul(add(x, y), div_proxy(x, y))`` with `DualLinker` and check one value."""
     lhs, rhs, _unused = inputs()
     summed = add(lhs, rhs)
     ratio = div_proxy(lhs, rhs)
     graph = FunctionGraph([lhs, rhs], [mul(summed, ratio)])
     linker = gof.DualLinker()
     compiled = linker.accept(graph).make_function()
     result = compiled(1.0, 2.0)
     assert result == 1.5
Example #9
0
def jax_funcify_Scan(op):
    """Build a callable that evaluates the `Scan` `Op` `op` with `jax.lax.scan`.

    The inner graph (``op.inputs`` -> ``op.outputs``) is converted once with
    `jax_funcify`; the returned function loops over it using `jax.lax.scan`.

    Parameters
    ----------
    op
        The `Scan` `Op` to convert.  Its ``inputs``, ``outputs``, ``n_outs``
        and ``info`` attributes are read.

    Returns
    -------
    callable
        ``scan(*outer_inputs)``.  It returns a list with one array per
        mit-sot (or, failing that, sit-sot) output, or that array itself when
        the list has exactly one element.

    Raises
    ------
    NotImplementedError
        Raised by the returned callable when the `Scan` has neither mit-sot
        nor sit-sot outer inputs; previously this case failed with an
        `UnboundLocalError`.
    """
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    jax_tt_inner_func = jax_funcify(inner_fg)

    def scan(*outer_inputs):
        scan_args = ScanArgs(
            list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
        )

        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs

        # TODO: mit_mots
        mit_mot_in_slices = []

        # For every mit-sot, keep the leading entries of its outer input as
        # the initial carry; the count is the largest negative tap magnitude
        # plus the largest positive tap.
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[: max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        # A sit-sot starts from the first entry of its outer input.
        sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]

        init_carry = (
            mit_mot_in_slices,
            mit_sot_in_slices,
            sit_sot_in_slices,
            scan_args.outer_in_shared,
            scan_args.outer_in_non_seqs,
        )

        def jax_args_to_inner_scan(op, carry, x):
            """Flatten the carry and the current sequence slice into inner-graph inputs."""
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            inner_in_mit_sot_flatten = []
            for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
                inner_in_mit_sot_flatten.extend(array[jnp.array(index)])

            inner_scan_inputs = sum(
                [
                    inner_in_seqs,
                    inner_in_mit_mot,
                    inner_in_mit_sot_flatten,
                    inner_in_sit_sot,
                    inner_in_shared,
                    inner_in_non_seqs,
                ],
                [],
            )

            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            """Build the next carry from the previous carry and the inner-graph outputs."""
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = old_carry

            def update_mit_sot(mit_sot, new_val):
                # Drop the oldest entry and append the newly computed value.
                return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)

            inner_out_mit_sot = [
                update_mit_sot(mit_sot, new_val)
                for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
            ]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            if not inner_in_sit_sot:
                inner_out_sit_sot = []
            else:
                inner_out_sit_sot = inner_scan_outs
            new_carry = (
                inner_in_mit_mot,
                inner_out_mit_sot,
                inner_out_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            )

            return new_carry

        def jax_inner_func(carry, x):
            """Run one step: evaluate every inner output and update the carry."""
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func]
            new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
            return new_carry, inner_scan_outs

        _, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

        # We need to prepend the initial values so that the JAX output will
        # match the raw `Scan` `Op` output and, thus, work with a downstream
        # `Subtensor` `Op` introduced by the `scan` helper function.
        def append_scan_out(scan_in_part, scan_out_part):
            return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)

        if scan_args.outer_in_mit_sot:
            scan_out_final = [
                append_scan_out(init, out)
                for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
            ]
        elif scan_args.outer_in_sit_sot:
            scan_out_final = [
                append_scan_out(init, out)
                for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
            ]
        else:
            # Without this branch `scan_out_final` was left unbound and the
            # code below failed with an `UnboundLocalError`.
            raise NotImplementedError(
                "Only Scans with mit-sot or sit-sot outputs are supported"
            )

        if len(scan_out_final) == 1:
            scan_out_final = scan_out_final[0]
        return scan_out_final

    return scan
 def test_neq(self):
     """The compiled `neq` graph must agree with Python's ``!=`` on a few value pairs."""
     lhs, rhs, _unused = inputs()
     graph = FunctionGraph([lhs, rhs], [neq(lhs, rhs)])
     compiled = gof.DualLinker().accept(graph).make_function()
     value_pairs = ((3., 9), (3, 0.9), (3, 3))
     for left, right in value_pairs:
         expected = left != right
         self.assertTrue(compiled(left, right) == expected)
 def test_xor(self):
     """The compiled ``x ^ y`` graph must agree with Python's ``^`` on all bit pairs."""
     lhs, rhs, _unused = ints('xyz')
     graph = FunctionGraph([lhs, rhs], [lhs ^ rhs])
     compiled = gof.DualLinker().accept(graph).make_function()
     bit_pairs = ((0, 1), (0, 0), (1, 0), (1, 1))
     for left, right in bit_pairs:
         expected = left ^ right
         # The failing pair is passed as the assertion message.
         self.assertTrue(compiled(left, right) == expected, (left, right))
Example #12
0
def jax_funcify_Scan(op):
    """Build a callable meant to evaluate the `Scan` `Op` `op` with `jax.lax.scan`.

    This conversion is unfinished: both helpers that translate between the
    `jax.lax.scan` carry and the inner-graph arguments raise
    `NotImplementedError` unconditionally, so the step function cannot
    complete.

    Parameters
    ----------
    op
        The `Scan` `Op` to convert.  Its ``inputs``, ``outputs``, ``n_outs``
        and ``info`` attributes are read.

    Returns
    -------
    callable
        ``scan(*outer_inputs)``, which returns the result of `jax.lax.scan`.
    """
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    jax_tt_inner_func = jax_funcify(inner_fg)

    def scan(*outer_inputs):
        scan_args = ScanArgs(
            outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info
        )

        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs

        n_non_seqs = len(scan_args.outer_in_non_seqs)

        # TODO: sit_sots
        # For every mit-sot, keep the leading entries of its outer input as
        # the initial carry; the count is the largest negative tap magnitude
        # plus the largest positive tap.
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[: max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        # NOTE(review): this carry has two entries, while
        # `jax_args_to_inner_scan` unpacks five -- to be reconciled when the
        # helpers are implemented.
        init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]

        def jax_args_to_inner_scan(op, carry, x):
            """Map the carry and the current sequence slice to inner-graph inputs (stub)."""
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            inner_scan_inputs = [
                inner_in_seqs,
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_non_seqs,
            ]

            # Not implemented yet; the `return` below is currently unreachable.
            raise NotImplementedError()
            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            """Map the inner-graph outputs to a ``(carry, y)`` pair (stub)."""
            # `inner_scan_outs` is a list with the following
            # composite form:
            # outer_out_mit_mot
            # + outer_out_mit_sot
            # + outer_out_sit_sot
            # + outer_out_nit_sot
            # + outer_out_shared
            # + cond
            (
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_nit_sot,
                outer_out_shared,
                cond,
            ) = inner_scan_outs
            # NOTE(review): `[:-n_non_seqs]` drops the trailing `n_non_seqs`
            # entries rather than selecting them, and is empty when
            # `n_non_seqs` is 0 -- confirm the intended slice.
            outer_out_non_seqs = old_carry[:-n_non_seqs]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            carry = [
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_shared,
                outer_out_non_seqs,
            ]
            # This should contain all inner-outputs that produce
            # outer-outputs
            y = []

            # Not implemented yet; the `return` below is currently unreachable.
            raise NotImplementedError()
            return (carry, y)

        def jax_inner_func(carry, x):
            """Run one step: evaluate the inner function and update the carry."""
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = jax_tt_inner_func(*inner_args)
            # The helper takes `(op, old_carry, inner_scan_outs)`; the carry
            # used to be omitted from this call.
            new_carry, y = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
            return new_carry, y

        return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

    return scan
Example #13
0
def test_mul_add_div_proxy():
    """Compile ``mul(add(x, y), div_proxy(x, y))`` with `DualLinker` and check one value."""
    lhs, rhs, _unused = floats("xyz")
    summed = add(lhs, rhs)
    ratio = div_proxy(lhs, rhs)
    graph = FunctionGraph([lhs, rhs], [mul(summed, ratio)])
    linker = gof.DualLinker()
    compiled = linker.accept(graph).make_function()
    result = compiled(1.0, 2.0)
    assert result == 1.5
Example #14
0
 def test_or(self):
     """The compiled ``x | y`` graph must agree with Python's ``|`` on all bit pairs."""
     lhs, rhs, _unused = ints("xyz")
     graph = FunctionGraph([lhs, rhs], [lhs | rhs])
     compiled = gof.DualLinker().accept(graph).make_function()
     bit_pairs = ((0, 1), (0, 0), (1, 0), (1, 1))
     for left, right in bit_pairs:
         expected = left | right
         # The failing pair is reported as the assertion message.
         assert compiled(left, right) == expected, (left, right)
Example #15
0
 def test_neq(self):
     """The compiled `neq` graph must agree with Python's ``!=`` on a few value pairs."""
     lhs, rhs, _unused = floats("xyz")
     graph = FunctionGraph([lhs, rhs], [neq(lhs, rhs)])
     compiled = gof.DualLinker().accept(graph).make_function()
     value_pairs = ((3.0, 9), (3, 0.9), (3, 3))
     for left, right in value_pairs:
         expected = left != right
         assert compiled(left, right) == expected
Example #16
0
 def test_clone(self):
     """A `FunctionGraph` built with default arguments accepts a cached constant."""
     cached_const = theano.tensor.constant(1)
     # Precondition: the constant has to come from the constant cache.
     assert cached_const.cached
     graph_outputs = [cached_const + 1]
     FunctionGraph([], graph_outputs)