Beispiel #1
0
    def test_modes(self):
        # Derive a mode from the default one: it keeps the same linker but
        # additionally includes the "specialize" rewrites.
        base_mode = get_default_mode()
        specialized_mode = base_mode.including("specialize")

        # Deep-copying such a mode used to fail (Python 2.4, July 2012)
        # because an fgraph remained attached to the default linker.
        copy.deepcopy(specialized_mode)

        # Check directly that the default linker does not hold on to an fgraph.
        default_linker = get_default_mode().linker
        assert getattr(default_linker, "fgraph", None) is None
Beispiel #2
0
def test_alloc_inputs2():
    """Check that the inner graph of a `Scan` fed with `zeros_like` inputs
    ends up with no `Elemwise` nodes once the "scan" rewrites are applied."""
    W1 = matrix()
    W2 = matrix()
    h0 = vector()

    def step(w1_row, h_prev, w2_mat):
        return w1_row * dot(h_prev, w2_mat)

    outputs, _ = scan(
        step,
        sequences=at.zeros_like(W1),
        outputs_info=h0,
        non_sequences=[at.zeros_like(W2)],
        n_steps=5,
    )

    scan_mode = get_default_mode().including("scan")
    f = function([h0, W1, W2], outputs, mode=scan_mode)
    outer_topo = f.maker.fgraph.toposort()
    scan_node = [node for node in outer_topo if isinstance(node.op, Scan)][0]

    # The compiled inner graph is expected to contain no Elemwise node.
    inner_topo = scan_node.op.fn.maker.fgraph.toposort()
    inner_elemwise = [node for node in inner_topo if isinstance(node.op, Elemwise)]
    assert len(inner_elemwise) == 0
Beispiel #3
0
def test_inner_replace_dot():
    """
    Check that rewrites are applied to a `Scan`'s inner-graph, in particular
    the BLAS-based rewrites that remove the original `Dot` node.

    This test used to carry a name implying that it exercised the `Scan`
    push-out rewrites. It never did, because those rewrites were not applied.
    """
    W = matrix("W")
    h = matrix("h")

    mode = get_default_mode().including("scan")

    def step(hi, him1, W):
        return hi, dot(hi + him1, W)

    o, _ = scan(
        step,
        outputs_info=[at.zeros([h.shape[1]]), None],
        sequences=[h],
        non_sequences=[W],
        mode=mode,
    )

    f = function([W, h], o, mode=mode)

    outer_topo = f.maker.fgraph.toposort()
    scan_nodes = [node for node in outer_topo if isinstance(node.op, Scan)]
    assert len(scan_nodes) == 1
    inner_fgraph = scan_nodes[0].op.fn.maker.fgraph
    # No plain `Dot` may survive in the compiled inner graph.
    assert all(not isinstance(node.op, Dot) for node in inner_fgraph.apply_nodes)
Beispiel #4
0
def test_local_sampling_dot_csr():
    mode = get_default_mode().including("specialize", "local_sampling_dot_csr")

    # The rewrite is only implemented for the CSR format.
    for sp_format in ["csr"]:
        inputs = [
            matrix(),
            matrix(),
            getattr(aesara.sparse, f"{sp_format}_matrix")(),
        ]

        f = aesara.function(inputs, sparse.sampling_dot(*inputs), mode=mode)
        topo = f.maker.fgraph.toposort()

        if aesara.config.blas__ldflags:
            # With BLAS available, the generic SamplingDot must be replaced.
            forbidden_op = sparse.SamplingDot
        else:
            # SamplingDotCSR's C implementation needs BLAS, so it must not be
            # inserted when BLAS is unavailable.
            forbidden_op = sparse.opt.SamplingDotCSR
        assert not any(isinstance(node.op, forbidden_op) for node in topo)
Beispiel #5
0
    def test_constant_output(self):
        # A constant output must respect Aesara's memory interface: mutating
        # the array returned by one call must not affect later calls.
        f = function([], aet.constant([4]))
        first = f()
        assert (first == 4).all()
        first[0] = 3
        second = f()
        # If either assert below fails, Aesara has broken its memory contract.
        assert second is not first
        assert (second == 4).all()

        # Same check for a constant output compiled with borrow=True.
        f = function([], Out(aet.constant([4]), borrow=True))
        first = f()
        assert (first == 4).all()
        first[0] = 3
        second = f()

        if not isinstance(get_default_mode(), DebugMode):
            # With borrow, the very same buffer is handed back.
            assert second is first
            assert (second == 3).all()
        else:
            # DebugMode does not implement borrow-based optimizations on
            # outputs.
            assert (second == 4).all()
Beispiel #6
0
def test_local_mul_s_v():
    mode = get_default_mode().including("specialize", "local_mul_s_v")

    # The rewrite is only implemented for the CSR format.
    for sp_format in ["csr"]:
        sparse_var = getattr(aesara.sparse, f"{sp_format}_matrix")()
        inputs = [sparse_var, vector()]

        f = aesara.function(inputs, sparse.mul_s_v(*inputs), mode=mode)

        # The generic MulSV op must have been rewritten away.
        topo = f.maker.fgraph.toposort()
        assert all(not isinstance(node.op, sparse.MulSV) for node in topo)
Beispiel #7
0
def test_local_dense_from_sparse_sparse_from_dense():
    """Check that a dense -> sparse -> dense round trip is rewritten away.

    For both the CSR and the CSC conversion, the compiled graph must reduce to
    a single node, and evaluating it must return the input values unchanged.
    """
    mode = get_default_mode().including("local_dense_from_sparse_sparse_from_dense")

    m = matrix()
    value = [[1, 2], [3, 4]]
    for op in [aesara.sparse.csr_from_dense, aesara.sparse.csc_from_dense]:
        s = op(m)
        o = aesara.sparse.dense_from_sparse(s)
        f = aesara.function([m], o, mode=mode)
        # We should just have a deep copy.
        assert len(f.maker.fgraph.apply_nodes) == 1
        # The rewritten round trip must not alter the values.
        res = f(value)
        assert (res == value).all()
Beispiel #8
0
def test_local_mul_s_d():
    mode = get_default_mode().including("specialize", "local_mul_s_d")

    # Exercise the rewrite for every supported sparse format.
    for sp_format in sparse.sparse_formats:
        sparse_var = getattr(aesara.sparse, f"{sp_format}_matrix")()
        inputs = [sparse_var, matrix()]

        f = aesara.function(inputs, sparse.mul_s_d(*inputs), mode=mode)

        # The generic MulSD op must have been rewritten away.
        topo = f.maker.fgraph.toposort()
        assert all(not isinstance(node.op, sparse.MulSD) for node in topo)
Beispiel #9
0
def test_local_csm_properties_csm():
    data = vector()
    indices, indptr, shape = (ivector(), ivector(), ivector())
    mode = get_default_mode().including("specialize", "local_csm_properties_csm")
    cases = [
        (sparse.CSC, sp.sparse.csc_matrix),
        (sparse.CSR, sp.sparse.csr_matrix),
    ]
    for CS, cast in cases:
        built = CS(data, indices, indptr, shape)
        f = aesara.function(
            [data, indices, indptr, shape],
            sparse.csm_properties(built),
            mode=mode,
        )
        # Neither the CSM construction nor the property extraction may remain.
        topo = f.maker.fgraph.toposort()
        assert not any(
            isinstance(node.op, (sparse.CSM, sparse.CSMProperties)) for node in topo
        )
        v = cast(random_lil((10, 40), config.floatX, 3))
        f(v.data, v.indices, v.indptr, v.shape)
Beispiel #10
0
def inplace_func(
    inputs,
    outputs,
    mode=None,
    allow_input_downcast=False,
    on_unused_input="raise",
    name=None,
):
    """Compile ``function(inputs, outputs)`` with ``accept_inplace=True``.

    ``mode`` falls back to ``get_default_mode()`` when omitted. Every other
    keyword is forwarded to ``function`` unchanged.
    """
    compile_mode = get_default_mode() if mode is None else mode
    return function(
        inputs,
        outputs,
        mode=compile_mode,
        allow_input_downcast=allow_input_downcast,
        accept_inplace=True,
        on_unused_input=on_unused_input,
        name=name,
    )
Beispiel #11
0
    def get_mode(self, excluding=None):
        """
        Return the mode to use for these tests.

        :param excluding: Optional list of optimization names to exclude.

        :return: `get_default_mode()`, or the 'FAST_RUN' mode when
        `config.mode` is 'FAST_COMPILE', with every optimization listed in
        `excluding` removed.
        """
        if config.mode == "FAST_COMPILE":
            mode = get_mode("FAST_RUN")
        else:
            mode = get_default_mode()
        if not excluding:
            return mode
        return mode.excluding(*excluding)
Beispiel #12
0
def test_local_csm_grad_c():
    data = vector()
    indices, indptr, shape = (ivector(), ivector(), ivector())
    mode = get_default_mode()

    if aesara.config.mode == "FAST_COMPILE":
        # Under FAST_COMPILE, build an explicit mode with the "c|py" linker
        # and the fast_compile optimizer instead.
        mode = Mode(linker="c|py", optimizer="fast_compile")

    mode = mode.including("specialize", "local_csm_grad_c")
    cases = [
        (sparse.CSC, sp.sparse.csc_matrix),
        (sparse.CSR, sp.sparse.csr_matrix),
    ]
    for CS, cast in cases:
        dense = sparse.DenseFromSparse()(CS(data, indices, indptr, shape))
        cost = aet_sum(dense)
        grad_data = aesara.grad(cost, data)
        f = aesara.function([data, indices, indptr, shape], grad_data, mode=mode)
        # The generic CSMGrad op must have been rewritten away.
        topo = f.maker.fgraph.toposort()
        assert all(not isinstance(node.op, sparse.CSMGrad) for node in topo)
        v = cast(random_lil((10, 40), config.floatX, 3))
        f(v.data, v.indices, v.indptr, v.shape)
Beispiel #13
0
class TestSaveMem:
    """Tests for the `Scan` save-memory rewrites."""

    # Default mode with the save-memory `Scan` rewrites explicitly included.
    mode = get_default_mode().including("scan_save_mem", "save_mem_new_scan")

    def test_save_mem(self):
        rng = np.random.default_rng(utt.fetch_seed())

        vW_in2 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
        vW = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
        vWout = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
        vW_in1 = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
        v_u1 = asarrayX(rng.uniform(-0.5, 0.5, size=(8, 2)))
        v_u2 = asarrayX(rng.uniform(-0.5, 0.5, size=(8,)))
        v_x0 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
        v_y0 = asarrayX(rng.uniform(size=(3,)))

        W_in2 = shared(vW_in2, name="win2")
        W = shared(vW, name="w")
        W_out = shared(vWout, name="wout")
        W_in1 = matrix("win")
        u1 = matrix("u1")
        u2 = vector("u2")
        x0 = vector("x0")
        y0 = vector("y0")

        def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1):
            return [
                y_tm3 + 1,
                dot(u1_t, W_in1) + u2_t * W_in2 + dot(x_tm1, W),
                y_tm1 + dot(x_tm1, W_out),
            ]

        _outputs, updates = scan(
            f_rnn_cmpl,
            [u1, u2],
            [None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])],
            W_in1,
            n_steps=None,
            truncate_gradient=-1,
            go_backwards=False,
        )
        # Only the last step of each output is requested, which gives the
        # save-memory rewrite something to shrink.
        outputs = [_outputs[0][-1], _outputs[1][-1], _outputs[2][-1]]
        f4 = function(
            [u1, u2, x0, y0, W_in1],
            outputs,
            updates=updates,
            allow_input_downcast=True,
            mode=self.mode,
        )

        # compute the values in numpy
        v_x = np.zeros((8, 2), dtype=config.floatX)
        v_y = np.zeros((8,), dtype=config.floatX)
        v_x[0] = np.dot(v_u1[0], vW_in1) + v_u2[0] * vW_in2 + np.dot(v_x0, vW)
        v_y[0] = np.dot(v_x0, vWout) + v_y0[2]

        for i in range(1, 8):
            v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW)
            v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1]

        (aesara_dump, aesara_x, aesara_y) = f4(v_u1, v_u2, v_x0, v_y0, vW_in1)

        utt.assert_allclose(aesara_x, v_x[-1:])
        utt.assert_allclose(aesara_y, v_y[-1:])

    def test_save_mem_reduced_number_of_steps(self):
        def f_rnn(u_t):
            return (
                u_t + 1.0,
                u_t + 2.0,
                u_t + 3.0,
                u_t + 4.0,
                u_t + 5.0,
                u_t + 6.0,
                u_t + 7.0,
            )

        u = vector("u")
        idx = iscalar("idx")
        jdx = iscalar("jdx")
        [x1, x2, x3, x4, x5, x6, x7], updates = scan(
            f_rnn, u, n_steps=None, truncate_gradient=-1, go_backwards=False
        )

        f2 = function(
            [u, idx, jdx],
            [x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]],
            updates=updates,
            allow_input_downcast=True,
            mode=self.mode,
        )
        # get random initial values
        rng = np.random.default_rng(utt.fetch_seed())
        v_u = rng.uniform(-5.0, 5.0, size=(20,))

        # evaluate the compiled function; the expected values are derived
        # from `v_u` with numpy below
        tx1, tx2, tx3, tx4, tx5, tx6, tx7 = f2(v_u, 3, 15)

        utt.assert_allclose(tx1, v_u[:2] + 1.0)
        utt.assert_allclose(tx2, v_u[4] + 2.0)
        utt.assert_allclose(tx3, v_u[3] + 3.0)
        utt.assert_allclose(tx4, v_u[:3] + 4.0)
        utt.assert_allclose(tx5, v_u[-10] + 5.0)
        utt.assert_allclose(tx6, v_u[-15] + 6.0)
        utt.assert_allclose(tx7, v_u[:-15] + 7.0)

    def test_save_mem_store_steps(self):
        def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
            return (
                u_t + 1.0,
                u_t + 2.0,
                u_t + 3.0,
                u_t + 4.0,
                u_t + 5.0,
                u_t + 6.0,
                u_t + 7.0,
            )

        u = vector("u")
        x10 = vector("x10")
        x20 = scalar("x20")
        x30 = vector("x30")
        x40 = scalar("x40")
        [x1, x2, x3, x4, x5, x6, x7], updates = scan(
            f_rnn,
            u,
            [
                None,
                None,
                None,
                dict(initial=x10, taps=[-1, -2]),
                x20,
                dict(initial=x30, taps=[-1, -2]),
                x40,
            ],
            n_steps=None,
            truncate_gradient=-1,
            go_backwards=False,
        )

        f2 = function(
            [u, x10, x20, x30, x40],
            [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
            updates=updates,
            allow_input_downcast=True,
            mode=self.mode,
        )

        # get random initial values
        rng = np.random.default_rng(utt.fetch_seed())
        v_u = rng.uniform(-5.0, 5.0, size=(20,))

        # evaluate the compiled function; the expected values are derived
        # from `v_u` with numpy below
        tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)

        utt.assert_allclose(tx1, v_u[-7] + 1.0)
        utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
        utt.assert_allclose(tx3, v_u[-6:] + 3.0)
        utt.assert_allclose(tx4, v_u[-1] + 4.0)
        utt.assert_allclose(tx5, v_u[-1] + 5.0)

    def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
        var = at.ones(())
        values, _ = scan(
            lambda x: ([x], (), until(x)),
            outputs_info=[var],
            n_steps=2,
        )

        tmp_fn = function([var], values, mode=self.mode)
        scan_nodes = [
            x for x in tmp_fn.maker.fgraph.toposort() if isinstance(x.op, Scan)
        ]
        assert len(scan_nodes) == 1

    def test_savemem_opt(self):
        y0 = shared(np.ones((2, 10)))
        [y1, y2], updates = scan(
            lambda y: [y, y],
            outputs_info=[dict(initial=y0, taps=[-2]), None],
            n_steps=5,
        )
        # Every step just forwards a row of the all-ones initial state, so
        # `y2` is a (5, 10) array of ones and its sum is 50.
        res = function([], y2.sum(), mode=self.mode)()
        utt.assert_allclose(res, 50.0)

    def test_savemem_opt_0_step(self):
        """
        Test a case where the savemem optimization has the opportunity to
        lower the number of steps of a Scan to 0. It tests that the
        optimization doesn't do so since Scan nodes with 0
        steps are not currently supported and doing so would result in a
        crash during the function execution.
        """

        def inner_scan_step(x_t_t, h_tm1, w):
            return dot(h_tm1, w) + x_t_t

        def outer_scan_step(x_t, w):
            h, _ = scan(
                inner_scan_step,
                sequences=[x_t[1:]],
                outputs_info=[x_t[0]],
                non_sequences=[w],
                strict=True,
                name="the_inner_scan",
            )
            return h

        def get_outputs(x, w):
            features, _ = scan(
                outer_scan_step,
                sequences=[x],
                non_sequences=[w],
                strict=True,
                name="the_outer_scan",
            )

            return_val = grad(features.sum(), w)
            return return_val

        # Compile the aesara function
        x = tensor3("x")
        w = matrix("w")
        f = function(inputs=[x, w], outputs=get_outputs(x, w), mode=self.mode)

        # Test the function to ensure it returns valid results
        x_value = (
            np.random.default_rng(utt.fetch_seed())
            .random((2, 2, 3))
            .astype(config.floatX)
        )
        w_value = (
            np.random.default_rng(utt.fetch_seed()).random((3, 3)).astype(config.floatX)
        )
        expected_output = np.tile(x_value[:, 0].sum(0), (3, 1)).transpose()

        output = f(x_value, w_value)
        utt.assert_allclose(output, expected_output)

    @pytest.mark.skip(
        reason="The 'assertion' of this test relied on something that no longer exists "
    )
    def test_subtensor_multiple_slices(self):
        r"""
        This addresses a bug that happens when you have multiple subtensors
        on the output of `Scan`.  The bug requires the reshape to be produced,
        and it has something to do with how the `Subtensor`\s overlap.
        """

        def f_pow2(x_tm1):
            return 2 * x_tm1

        state = vector("state")
        n_steps = iscalar("nsteps")
        output, updates = scan(
            f_pow2,
            [],
            state,
            [],
            n_steps=n_steps,
            truncate_gradient=-1,
            go_backwards=False,
        )
        nw_shape = ivector("nw_shape")
        # Note that the output is reshaped to a 3-dimensional tensor and then
        # sliced, while the un-reshaped output is sliced as well, so several
        # `Subtensor`s act on the `Scan` output.
        my_f = function(
            [state, n_steps, nw_shape],
            [reshape(output, nw_shape, ndim=3)[:-2], output[:-4]],
            updates=updates,
            allow_input_downcast=True,
        )
        nodes = [x for x in my_f.maker.fgraph.toposort() if isinstance(x.op, Scan)]
        # This assertion fails if savemem optimization failed on scan
        if config.mode != "FAST_COMPILE":
            assert nodes[0].op._scan_savemem_visited
        rng = np.random.default_rng(utt.fetch_seed())
        my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3]))
Beispiel #14
0
    _good_broadcast_unary_normal_float_no_complex_small_neg_range,
    _good_broadcast_unary_normal_no_complex,
    _grad_broadcast_unary_0_2_no_complex,
    _grad_broadcast_unary_abs1_no_complex,
    _grad_broadcast_unary_normal,
    _grad_broadcast_unary_normal_small_neg_range,
    check_floatX,
    copymod,
    makeBroadcastTester,
    rand_ranged,
    randint_ranged,
    upcast_int8_nfunc,
)

# Optional SciPy support. When SciPy is missing under FAST_COMPILE, the tests
# use the "FAST_RUN" mode instead of the default one.
imported_scipy_special = False
mode_no_scipy = get_default_mode()
try:
    import scipy.special
    import scipy.stats
except ImportError:
    if config.mode == "FAST_COMPILE":
        mode_no_scipy = "FAST_RUN"
else:
    imported_scipy_special = True


def scipy_special_gammau(k, x):
    """Return the (non-regularized) upper incomplete gamma function Gamma(k, x).

    SciPy's ``gammaincc`` is regularized, so it is rescaled by ``gamma(k)``.
    """
    regularized_upper = scipy.special.gammaincc(k, x)
    return regularized_upper * scipy.special.gamma(k)


def scipy_special_gammal(k, x):
Beispiel #15
0
class TestPushOutDot:
    """Tests for the rewrites that push computations out of `Scan`."""

    # Default mode with the "scan" rewrites explicitly included.
    mode = get_default_mode().including("scan")

    def test_pushout_all(self):
        W1 = matrix("W1")
        W2 = matrix("W2")
        h0 = vector("h0")

        def lambda_fn(h, W1, W2):
            return dot(h, W1 + W2)

        o, _ = scan(lambda_fn, non_sequences=[h0, W1, W2], n_steps=5)

        f = function([h0, W1, W2], o, mode=self.mode)

        # Everything depends only on non-sequences, so the whole Scan should
        # have been pushed out.
        scan_nodes = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)]
        assert len(scan_nodes) == 0

        seed = utt.fetch_seed()
        rng = np.random.default_rng(seed)
        floatX = config.floatX
        v_h = np.array(rng.uniform(size=(2,)), dtype=floatX)
        v_W1 = np.array(rng.uniform(size=(2, 2)), dtype=floatX)
        v_W2 = np.array(rng.uniform(size=(2, 2)), dtype=floatX)

        v_out = np.dot(v_h, v_W1 + v_W2)
        sol = np.zeros((5, 2))
        # This line is here to make sol have the same shape as the output of
        # aesara. Note that what we ask aesara to do is to repeat the 2
        # elements vector v_out 5 times
        sol[:, :] = v_out
        utt.assert_allclose(sol, f(v_h, v_W1, v_W2))

    def test_pushout_while(self):
        """
        Ensure that the optimizations for Scan that push computation out of
        the Scan don't alter the result for 'as_while' scans.
        """

        W1 = matrix("W1")
        W2 = matrix("W2")
        step_indices = vector("step_indices")

        def lambda_fn(step_idx, W1, W2):
            until_condition = until(step_idx > 2)
            return dot(W1, W2), until_condition

        # Compile a function with the optimization
        o, _ = scan(
            lambda_fn, sequences=[step_indices, W1], non_sequences=[W2], n_steps=5
        )

        f = function([W1, W2, step_indices], o, mode=self.mode)

        # Compile an aesara function without the optimization. The push-out
        # rewrites act on the outer graph, so the outer function must also be
        # compiled without the "scan" rewrites; otherwise the reference would
        # be optimized as well.
        o, _ = scan(
            lambda_fn,
            sequences=[step_indices, W1],
            non_sequences=[W2],
            n_steps=5,
            mode="FAST_COMPILE",
        )

        f_ref = function([W1, W2, step_indices], o, mode="FAST_COMPILE")

        # Compare the results of the two implementations
        input_values = [
            np.random.default_rng(utt.fetch_seed()).random((5, 5)).astype("float32"),
            np.random.default_rng(utt.fetch_seed()).random((5, 5)).astype("float32"),
            np.arange(5).astype("float32"),
        ]

        out = f(*input_values)
        out_ref = f_ref(*input_values)
        utt.assert_allclose(out, out_ref)

    def test_pushout(self):
        W1 = matrix("W1")
        W2 = matrix("W2")
        h0 = vector("h0")

        def lambda_fn(h, W1, W2):
            return dot(h, W1 + W2)

        o, _ = scan(lambda_fn, outputs_info=h0, non_sequences=[W1, W2], n_steps=5)

        f = function([h0, W1, W2], o, mode=self.mode)

        # `W1 + W2` only depends on non-sequences, so no Elemwise should be
        # left inside the Scan.
        scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
        assert (
            len(
                [
                    x
                    for x in scan_node.op.fn.maker.fgraph.toposort()
                    if isinstance(x.op, Elemwise)
                ]
            )
            == 0
        )

    def test_pushout_nomodif(self):
        inp = matrix("inp")

        def fn(i, i_tm1):
            return i + 10, i_tm1

        ([i_t, i_tm1], _) = scan(
            fn,
            sequences=[inp],
            outputs_info=[np.asarray([0.0, 0.0], config.floatX), None],
        )
        # Compile with the class mode so the push-out rewrites are exercised.
        f = function([inp], [i_t, i_tm1], mode=self.mode)
        val = np.arange(10).reshape(5, 2).astype(config.floatX)
        ret = f(val)
        utt.assert_allclose(ret[0], val + 10)
        utt.assert_allclose(
            ret[1], [[0.0, 0.0], [10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]]
        )
Beispiel #16
0
class TestScanInplaceOptimizer:
    """Tests for the rewrite that makes `Scan` outputs work inplace."""

    # Default mode with the inplace `Scan` rewrites explicitly included.
    mode = get_default_mode().including("scan_make_inplace", "inplace")

    @utt.assertFailure_fast
    def test_simple_rnn(self):
        """Simple RNN; compute inplace version 1."""
        rng = np.random.default_rng(utt.fetch_seed())
        # Draw every value from the seeded generator so the test is
        # reproducible (the global `np.random` state is not seeded here).
        vW = asarrayX(rng.uniform())
        vW_in = asarrayX(rng.uniform())
        vu0 = asarrayX(rng.uniform(-5.0, 5.0, size=(3,)))
        vu1 = asarrayX(rng.uniform(-5.0, 5.0, size=(3,)))
        vu2 = asarrayX(rng.uniform(-5.0, 5.0, size=(3,)))
        vx0 = asarrayX(rng.uniform())
        vx1 = asarrayX(rng.uniform())

        u0 = vector("u0")
        u1 = vector("u1")
        u2 = vector("u2")
        mu0 = In(u0, mutable=False)
        mu1 = In(u1, mutable=True)
        mu2 = In(u2, mutable=True)
        x0 = scalar("x0")
        x1 = scalar("y0")
        W_in = shared(vW_in, "Win")
        W = shared(vW, "W")

        def f_rnn_shared(u0_t, u1_t, u2_t, x0_tm1, x1_tm1):
            return [
                u0_t * W_in + x0_tm1 * W + u1_t * u2_t,
                u0_t * W_in + x1_tm1 * W + u1_t + u2_t,
            ]

        outputs, updates = scan(
            f_rnn_shared,
            [u0, u1, u2],
            [dict(initial=x0, inplace=u2), dict(initial=x1, inplace=u1)],
            [],
            n_steps=None,
            truncate_gradient=-1,
            go_backwards=False,
            mode=self.mode,
        )

        f9 = function(
            [mu0, mu1, mu2, x0, x1],
            outputs,
            updates=updates,
            mode=self.mode,
            allow_input_downcast=True,
        )
        scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
        assert 0 in scan_node[0].op.destroy_map.keys()
        assert 1 in scan_node[0].op.destroy_map.keys()
        # compute output in numpy
        numpy_x0 = np.zeros((3,))
        numpy_x1 = np.zeros((3,))
        numpy_x0[0] = vu0[0] * vW_in + vx0 * vW + vu1[0] * vu2[0]
        numpy_x1[0] = vu0[0] * vW_in + vx1 * vW + vu1[0] + vu2[0]
        for i in range(1, 3):
            numpy_x0[i] = vu0[i] * vW_in + numpy_x0[i - 1] * vW + vu1[i] * vu2[i]
            numpy_x1[i] = vu0[i] * vW_in + numpy_x1[i - 1] * vW + vu1[i] + vu2[i]

        # note aesara computes inplace, so call function after numpy
        # equivalent is done
        (aesara_x0, aesara_x1) = f9(vu0, vu1, vu2, vx0, vx1)
        # assert that aesara does what it should
        utt.assert_allclose(aesara_x0, numpy_x0)
        utt.assert_allclose(aesara_x1, numpy_x1)

    @utt.assertFailure_fast
    def test_simple_rnn_2(self):
        """Simple RNN; compute inplace version 2."""
        rng = np.random.default_rng(utt.fetch_seed())
        # Draw every value from the seeded generator so the test is
        # reproducible (the global `np.random` state is not seeded here).
        vW = asarrayX(rng.uniform())
        vW_in = asarrayX(rng.uniform())
        vu0 = asarrayX(rng.uniform(-5.0, 5.0, size=(3,)))
        vu1 = asarrayX(rng.uniform(-5.0, 5.0, size=(4,)))
        vu2 = asarrayX(rng.uniform(-5.0, 5.0, size=(5,)))
        vx0 = asarrayX(rng.uniform())
        vx1 = asarrayX(rng.uniform())

        u0 = vector("u0")
        u1 = vector("u1")
        u2 = vector("u2")
        mu0 = In(u0, mutable=True)
        mu1 = In(u1, mutable=True)
        mu2 = In(u2, mutable=True)
        x0 = scalar("x0")
        x1 = scalar("y0")
        W_in = shared(vW_in, "Win")
        W = shared(vW, "W")

        def f_rnn_shared(u0_t, u1_t, u1_tp1, u2_tm1, u2_t, u2_tp1, x0_tm1, x1_tm1):
            return [
                u0_t * W_in + x0_tm1 * W + u1_t * u1_tp1,
                u0_t * W_in + x1_tm1 * W + u2_tm1 + u2_t + u2_tp1,
            ]

        outputs, updates = scan(
            f_rnn_shared,
            [u0, dict(input=u1, taps=[0, 1]), dict(input=u2, taps=[-1, 0, +1])],
            [dict(initial=x0), dict(initial=x1)],
            [],
            n_steps=None,
            truncate_gradient=-1,
            go_backwards=False,
            mode=self.mode,
        )
        f9 = function(
            [mu0, mu1, mu2, x0, x1],
            outputs,
            updates=updates,
            mode=self.mode,
            allow_input_downcast=True,
        )

        scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
        assert 0 in scan_node[0].op.destroy_map.keys()
        assert 1 in scan_node[0].op.destroy_map.keys()
        # compute output in numpy
        numpy_x0 = np.zeros((3,))
        numpy_x1 = np.zeros((3,))
        numpy_x0[0] = vu0[0] * vW_in + vx0 * vW + vu1[0] * vu1[1]
        numpy_x1[0] = vu0[0] * vW_in + vx1 * vW + vu2[0] + vu2[1] + vu2[2]
        for i in range(1, 3):
            numpy_x0[i] = vu0[i] * vW_in + numpy_x0[i - 1] * vW + vu1[i] * vu1[i + 1]
            numpy_x1[i] = (
                vu0[i] * vW_in + numpy_x1[i - 1] * vW + vu2[i] + vu2[i + 1] + vu2[i + 2]
            )

        # note aesara computes inplace, so call function after numpy
        # equivalent is done
        (aesara_x0, aesara_x1) = f9(vu0, vu1, vu2, vx0, vx1)
        # assert that aesara does what it should
        utt.assert_allclose(aesara_x0, numpy_x0)
        utt.assert_allclose(aesara_x1, numpy_x1)

    @utt.assertFailure_fast
    def test_inplace3(self):
        """Only the output whose initial state is still a shared variable
        (output 1) may be computed inplace; output 0, whose initial state is
        replaced by a constant, must not be."""
        rng = np.random.default_rng(utt.fetch_seed())

        vx0 = asarrayX(rng.uniform())
        vx1 = asarrayX(rng.uniform())
        x0 = shared(vx0)
        x1 = shared(vx1)
        outputs, updates = scan(
            lambda x, y: (x + asarrayX(1), y + asarrayX(1)), [], [x0, x1], n_steps=3
        )
        x0 = asarrayX(np.zeros((3,)))
        x0[0] = vx0
        x0 = at.constant(x0)

        to_replace = outputs[0].owner.inputs[0].owner.inputs[1]
        outputs = clone_replace(outputs, replace=[(to_replace, x0)])

        f9 = function([], outputs, updates=updates, mode=self.mode)
        scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
        assert 0 not in scan_node[0].op.destroy_map.keys()
        assert 1 in scan_node[0].op.destroy_map.keys()
Beispiel #17
0
class TestScanMerge:
    """Tests for the `ScanMerge` rewrite."""

    # Default mode with the "scan" rewrites explicitly included.
    mode = get_default_mode().including("scan")

    def test_basic(self):
        x = vector()
        y = vector()

        # Named `add_one` rather than `sum` to avoid shadowing the builtin.
        def add_one(s):
            return s + 1

        # Scans over different sequences of unknown length are not merged.
        sx, upx = scan(add_one, sequences=[x])
        sy, upy = scan(add_one, sequences=[y])

        f = function(
            [x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
        )
        topo = f.maker.fgraph.toposort()
        scans = [n for n in topo if isinstance(n.op, Scan)]
        assert len(scans) == 2

        # Different numbers of steps: not merged.
        sx, upx = scan(add_one, sequences=[x], n_steps=2)
        sy, upy = scan(add_one, sequences=[y], n_steps=3)

        f = function(
            [x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
        )
        topo = f.maker.fgraph.toposort()
        scans = [n for n in topo if isinstance(n.op, Scan)]
        assert len(scans) == 2

        # Same number of steps: merged.
        sx, upx = scan(add_one, sequences=[x], n_steps=4)
        sy, upy = scan(add_one, sequences=[y], n_steps=4)

        f = function(
            [x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
        )
        topo = f.maker.fgraph.toposort()
        scans = [n for n in topo if isinstance(n.op, Scan)]
        assert len(scans) == 1

        # Same sequence: merged.
        sx, upx = scan(add_one, sequences=[x])
        sy, upy = scan(add_one, sequences=[x])

        f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
        topo = f.maker.fgraph.toposort()
        scans = [n for n in topo if isinstance(n.op, Scan)]
        assert len(scans) == 1

        # A different inner-graph compilation mode does not prevent merging.
        sx, upx = scan(add_one, sequences=[x])
        sy, upy = scan(add_one, sequences=[x], mode="FAST_COMPILE")

        f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
        topo = f.maker.fgraph.toposort()
        scans = [n for n in topo if isinstance(n.op, Scan)]
        assert len(scans) == 1

        # A different `truncate_gradient` prevents merging.
        sx, upx = scan(add_one, sequences=[x])
        sy, upy = scan(add_one, sequences=[x], truncate_gradient=1)

        f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
        topo = f.maker.fgraph.toposort()
        scans = [n for n in topo if isinstance(n.op, Scan)]
        assert len(scans) == 2

    def test_three_scans(self):
        r"""
        This test checks a case where we have three `Scan`\s, two of them
        cannot be merged together, but the third one can be merged with
        either.
        """
        x = vector()
        y = vector()

        # Named `add_one` rather than `sum` to avoid shadowing the builtin.
        def add_one(s):
            return s + 1

        sx, upx = scan(add_one, sequences=[x], n_steps=4, name="X")
        # We need to use an expression of y rather than y so the toposort
        # comes up with the 'Y' scan last.
        sy, upy = scan(add_one, sequences=[2 * y + 2], n_steps=4, name="Y")
        sz, upz = scan(add_one, sequences=[sx], n_steps=4, name="Z")

        f = function(
            [x, y], [sy, sz], mode=self.mode.excluding("scan_pushout_seqs_ops")
        )
        topo = f.maker.fgraph.toposort()
        scans = [n for n in topo if isinstance(n.op, Scan)]
        assert len(scans) == 2

        rng = np.random.default_rng(utt.fetch_seed())
        x_val = rng.uniform(size=(4,)).astype(config.floatX)
        y_val = rng.uniform(size=(4,)).astype(config.floatX)
        # Run it so DebugMode can detect optimization problems.
        f(x_val, y_val)

    def test_belongs_to_set(self):
        """
        Test `ScanMerge.belongs_to_set`. Specifically, check that it
        detects two `Scan` nodes (a plain one and an `until`-terminated one)
        as not being similar, in either argument order.
        """
        inps = vector()
        state = scalar()
        y1, _ = scan(lambda x, y: x * y, sequences=inps, outputs_info=state, n_steps=5)

        y2, _ = scan(
            lambda x, y: (x + y, until(x > 0)),
            sequences=inps,
            outputs_info=state,
            n_steps=5,
        )
        scan_node1 = y1.owner.inputs[0].owner
        assert isinstance(scan_node1.op, Scan)
        scan_node2 = y2.owner.inputs[0].owner
        assert isinstance(scan_node2.op, Scan)
        opt_obj = ScanMerge()
        assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
        assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
Beispiel #18
0
class TestRemoveConstantsAndUnusedInputsScan:
    mode = get_default_mode().including("scan")

    def test_remove_constants_and_unused_inputs_scan_non_seqs(self):
        """Test the rewrite `remove_constants_and_unused_inputs_scan` for non-sequences.

        Every inner function below reads only ``i`` and ``W``. The extra
        ``W[0]`` non-sequences, which appear in different positions, are
        ignored by the inner function. After rewriting, each compiled graph
        must contain a single `Scan` with exactly one inner and one outer
        non-sequence and no duplicated outer inputs, and it must still
        compute ``W[v]``.
        """
        W = matrix(name="W")
        v = ivector(name="v")
        y1, _ = scan(
            lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W]
        )
        y2, _ = scan(
            lambda i, _, W: W[i],
            sequences=v,
            outputs_info=None,
            non_sequences=[W[0], W],
        )
        y3, _ = scan(
            lambda i, W, _: W[i],
            sequences=v,
            outputs_info=None,
            non_sequences=[W, W[0]],
        )
        y4, _ = scan(
            lambda i, _, _2, W: W[i],
            sequences=v,
            outputs_info=None,
            non_sequences=[W[0], W[0], W],
        )
        y5, _ = scan(
            lambda i, _, W, _2: W[i],
            sequences=v,
            outputs_info=None,
            non_sequences=[W[0], W, W[0]],
        )
        y6, _ = scan(
            lambda i, W, _, _2: W[i],
            sequences=v,
            outputs_info=None,
            non_sequences=[W, W[0], W[0]],
        )
        # TODO: y7 have problem during run time. I think it should
        # raise an error during the scan construction.
        # y7, _ = scan(lambda i, W, _, _2: W[i], sequences=v,
        #                    outputs_info=None, non_sequences=[v, W[0], W])

        # Seeded generator (as used elsewhere in this file) so any failure
        # is reproducible.
        rng = np.random.default_rng(utt.fetch_seed())
        W_val = rng.normal(size=(3, 3)).astype(config.floatX)
        exp_val = W_val[np.r_[1, 2]]

        for out in [y1, y2, y3, y4, y5, y6]:
            f = function([W, v], out, mode=self.mode)

            res = f(W_val, [1, 2])
            assert np.array_equal(res, exp_val)

            scan_nodes = scan_nodes_from_fct(f)
            assert len(scan_nodes) == 1
            scan_node = scan_nodes[0]

            # No outer input (after the first one) may appear twice.
            outer_inputs = scan_node.inputs[1:]
            assert len(outer_inputs) == len(set(outer_inputs))

            # Exactly one non-sequence must survive, on both sides of the node.
            inner_non_seqs = scan_node.op.inner_non_seqs(scan_node.op.inner_inputs)
            assert len(inner_non_seqs) == 1

            outer_non_seqs = scan_node.op.outer_non_seqs(scan_node.inputs)
            assert len(outer_non_seqs) == 1

    def test_remove_constants_and_unused_inputs_scan_seqs(self):
        """Test the rewrite `remove_constants_and_unused_inputs_scan` for sequences.

        Every inner function below reads only the index ``i`` (taken from
        ``v``) and ``W``. The other sequences (``vv``, ``vv[0]`` or a
        duplicated ``v``) and, in ``y8``, the extra ``W[0]`` non-sequences
        are ignored by the inner function. After rewriting, each compiled
        graph must contain a single `Scan` with exactly one sequence and one
        non-sequence on both the inner and outer side, and it must still
        compute ``W[v]``.
        """
        W = matrix(name="W")
        v = ivector(name="v")
        vv = matrix(name="vv")
        y1, _ = scan(
            lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W]
        )
        y2, _ = scan(
            lambda i, _, W: W[i], sequences=[v, v], outputs_info=None, non_sequences=W
        )
        y3, _ = scan(
            lambda i, _, W: W[i],
            sequences=[v, vv[0]],
            outputs_info=None,
            non_sequences=W,
        )
        y4, _ = scan(
            lambda _, i, W: W[i],
            sequences=[vv[0], v],
            outputs_info=None,
            non_sequences=W,
        )
        y5, _ = scan(
            lambda _, i, _2, W: W[i],
            sequences=[vv, v, vv[0]],
            outputs_info=None,
            non_sequences=W,
        )
        y6, _ = scan(
            lambda _, _2, i, W: W[i],
            sequences=[vv[0], vv, v],
            outputs_info=None,
            non_sequences=W,
        )
        y7, _ = scan(
            lambda i, _, _2, W: W[i],
            sequences=[v, vv[0], vv[0]],
            outputs_info=None,
            non_sequences=W,
        )
        y8, _ = scan(
            lambda _, i, W, _2, _3: W[i],
            sequences=[vv[0], v],
            outputs_info=None,
            non_sequences=[W, W[0], W[0]],
        )

        # Seeded generator (as used elsewhere in this file) so any failure
        # is reproducible.
        rng = np.random.default_rng(utt.fetch_seed())
        W_val = rng.normal(size=(3, 3)).astype(config.floatX)
        exp_val = W_val[np.r_[1, 2]]

        for out in [y1, y2, y3, y4, y5, y6, y7, y8]:
            # `vv` is not part of every output's graph (e.g. `y1`, `y2`), so
            # unused inputs must be tolerated.
            f = function(
                [W, v, vv],
                out,
                on_unused_input="ignore",
                mode=self.mode,
            )

            res = f(W_val, [1, 2], W_val)
            assert np.array_equal(res, exp_val)

            scan_nodes = scan_nodes_from_fct(f)
            assert len(scan_nodes) == 1
            scan_node = scan_nodes[0]

            # No outer input (after the first one) may appear twice.
            assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:]))
            # Exactly one sequence and one non-sequence must survive on both
            # the inner and the outer side.
            inp = scan_node.op.inner_seqs(scan_node.op.inner_inputs)
            assert len(inp) == 1
            inp = scan_node.op.outer_seqs(scan_node.inputs)
            assert len(inp) == 1
            inp = scan_node.op.inner_non_seqs(scan_node.op.inner_inputs)
            assert len(inp) == 1
            inp = scan_node.op.outer_non_seqs(scan_node.inputs)
            assert len(inp) == 1