Пример #1
0
def test_badoptimization_opt_err():
    """Check how DebugMode reports optimizations that raise or make illegal changes.

    This variant of ``test_badoptimization`` replaces the working code with a
    new apply node that will raise an error. It covers three cases:

    1. ``insert_bigger_b_add`` concatenates the last ``add`` input with
       itself. Running the optimized function is expected to fail with a
       ``ValueError`` whose message names the optimization.
    2. ``insert_bad_dtype`` replaces the ``add`` output with a ``float32``
       cast of it. With ``on_opt_error="raise"`` this illegal change is
       expected to surface as a ``TypeError``.
    3. The ``TypeError`` from case 2 can be re-raised with an extended
       message while keeping its original traceback.
    """

    @local_optimizer([add])
    def insert_bigger_b_add(fgraph, node):
        # Only rewrite ``add`` nodes whose last input has no owner (a graph
        # input). The replacement uses that operand concatenated with itself.
        if node.op == add:
            inputs = list(node.inputs)
            if inputs[-1].owner is None:
                inputs[-1] = aet.concatenate((inputs[-1], inputs[-1]))
                return [node.op(*inputs)]
        return False

    @local_optimizer([add])
    def insert_bad_dtype(fgraph, node):
        # Swap the output for a float32 cast of itself. The test below expects
        # this replacement to be rejected with a TypeError.
        if node.op == add and node.inputs[-1].owner is None:
            return [node.outputs[0].astype("float32")]
        return False

    edb = EquilibriumDB()
    edb.register("insert_bigger_b_add", insert_bigger_b_add, "all")
    opt = edb.query("+all")
    edb2 = EquilibriumDB()
    edb2.register("insert_bad_dtype", insert_bad_dtype, "all")
    opt2 = edb2.query("+all")

    a = dvector()
    b = dvector()

    f = aesara.function([a, b], a + b, mode=DebugMode(optimizer=opt))
    with pytest.raises(ValueError, match=r"insert_bigger_b_add"):
        f([1.0, 2.0, 3.0], [2, 3, 4])

    # Test that an optimization making an illegal change still gets the error
    # from the graph. Whichever stage raises it is accepted: compilation
    # (inside ``change_flags``) or the call to the compiled function.
    with pytest.raises(TypeError) as einfo:
        with config.change_flags(on_opt_error="raise"):
            f2 = aesara.function(
                [a, b],
                a + b,
                mode=DebugMode(optimizer=opt2, stability_patience=1),
            )
        f2([1.0, 2.0, 3.0], [2, 3, 4])

    # Test that we can re-raise the error with an extended message and keep
    # the original traceback. The traceback is taken from ``__traceback__``
    # rather than ``sys.exc_info()``. No exception is being handled here, so
    # ``sys.exc_info()`` would return ``(None, None, None)`` and silently drop it.
    original = einfo.value
    with pytest.raises(TypeError, match=r"^TTT"):
        extended = original.__class__("TTT" + str(original))
        raise extended.with_traceback(original.__traceback__)
Пример #2
0
def test_stochasticoptimization():
    """DebugMode must flag an optimization whose effect changes between runs.

    The rewriter below toggles a flag on every ``add`` node it is offered.
    It only substitutes ``off_by_half`` on every other call, so repeated
    optimization passes disagree. Compiling is expected to raise
    ``StochasticOrder``.
    """
    replaced_previously = False

    @local_optimizer([add])
    def insert_broken_add_sometimes(fgraph, node):
        # Closure state flips on each visited ``add`` node; the rewrite fires
        # only when the flag has just become True.
        nonlocal replaced_previously
        if node.op == add:
            replaced_previously = not replaced_previously
            if replaced_previously:
                return [off_by_half(*node.inputs)]
        return False

    db = EquilibriumDB()
    db.register("insert_broken_add_sometimes", insert_broken_add_sometimes, "all")
    flaky_opt = db.query("+all")

    x = dvector()
    y = dvector()

    # Run the optimizer at least twice so the alternating results get compared.
    debug_mode = DebugMode(
        optimizer=flaky_opt,
        check_c_code=True,
        stability_patience=max(2, config.DebugMode__patience),
    )

    with pytest.raises(StochasticOrder):
        aesara.function([x, y], add(x, y), mode=debug_mode)
Пример #3
0
def test_badoptimization():
    """DebugMode must report a value-changing rewrite as ``BadOptimization``.

    Every ``add`` node is replaced by ``off_by_half``. Evaluating the
    compiled function is expected to raise ``BadOptimization``, whose
    ``reason`` names the offending optimization.
    """

    @local_optimizer([add])
    def insert_broken_add(fgraph, node):
        return [off_by_half(*node.inputs)] if node.op == add else False

    db = EquilibriumDB()
    db.register("insert_broken_add", insert_broken_add, "all")
    broken_opt = db.query("+all")

    x, y = dvector(), dvector()
    fn = aesara.function([x, y], x + y, mode=DebugMode(optimizer=broken_opt))

    with pytest.raises(BadOptimization) as exc_info:
        fn([1.0, 2.0, 3.0], [2, 3, 4])

    # The reported reason must identify the optimizer by name.
    assert str(exc_info.value.reason) == "insert_broken_add"