示例#1
0
    def test_infer_shape(self):
        """Check the output shape of ``assert_op`` graphs for scalar and matrix values.

        Two graphs are built: one guarding a scalar value and one guarding a
        matrix value. Each is handed to ``self._compile_and_check`` together
        with the expected op class ``Assert``.
        """
        # Scalar case: ``x_scalar`` is the guarded value, ``cond_scalar`` the condition.
        x_scalar = at.dscalar()
        cond_scalar = at.dscalar()
        x_scalar_val = np.random.random()
        # Shifting by 1 puts the condition value in [1, 2), i.e. never zero,
        # presumably so the assertion is always satisfied.
        cond_scalar_val = np.random.random() + 1
        scalar_out = assert_op(x_scalar, cond_scalar)
        self._compile_and_check(
            [x_scalar, cond_scalar],
            [scalar_out],
            [x_scalar_val, cond_scalar_val],
            Assert,
        )

        # Matrix case: both scalars now act as conditions on the matrix value.
        x_matrix = at.dmatrix()
        x_matrix_val = np.random.random((3, 4))
        # ``x_scalar`` is used as a condition below, so move its value out of
        # [0, 1) into [1, 2) to keep it non-zero.
        x_scalar_val = x_scalar_val + 1
        matrix_out = assert_op(x_matrix, x_scalar, cond_scalar)
        self._compile_and_check(
            [x_matrix, x_scalar, cond_scalar],
            [matrix_out],
            [x_matrix_val, x_scalar_val, cond_scalar_val],
            Assert,
        )
示例#2
0
def test_CheckAndRaise_validation():
    """Validate ``CheckAndRaise`` construction and ``assert_op`` input wrapping.

    Passing ``str`` (which is not an exception class) to ``CheckAndRaise`` must
    raise ``ValueError``, and a raw NumPy value given to ``assert_op`` must show
    up as a ``Constant`` input of the resulting node.
    """
    with pytest.raises(ValueError):
        CheckAndRaise(str)

    wrapped = assert_op(np.array(1.0))
    first_input = wrapped.owner.inputs[0]
    assert isinstance(first_input, Constant)
示例#3
0
    def test_one_assert_merge(self):
        """Merge two ``dot`` nodes where only one of them has an ``Assert`` on its input."""
        x1 = matrix("x1")
        x2 = matrix("x2")
        # Build the summands in the same order as a single ``a + b`` expression would.
        plain_term = dot(x1, x2)
        guarded_term = dot(assert_op(x1, (x1 > x2).all()), x2)
        e = plain_term + guarded_term
        g = FunctionGraph([x1, x2], [e], clone=False)
        MergeOptimizer().optimize(g)

        add_node = g.outputs[0].owner
        assert add_node.op == add
        add_inputs = add_node.inputs
        merged_dot = add_inputs[0].owner
        assert isinstance(merged_dot.op, Dot)

        # The surviving ``dot`` must consume the asserted ``x1``.
        expected_assert = assert_op(x1, (x1 > x2).all())
        assert equal_computations([merged_dot.inputs[0]], [expected_assert])

        # After merging, both operands of ``add`` are the very same variable.
        assert add_inputs[0] is add_inputs[1]
示例#4
0
    def test_both_assert_merge_identical(self):
        """Merge two ``dot`` nodes whose first inputs carry identical ``Assert``s.

        Both ``Assert``s guard the same variable (``x1``) with the same condition.
        """
        x1 = matrix("x1")
        x2 = matrix("x2")
        # Summands are built left-to-right, as in a single ``a + b`` expression.
        left_term = dot(assert_op(x1, (x1 > x2).all()), x2)
        right_term = dot(assert_op(x1, (x1 > x2).all()), x2)
        e = left_term + right_term
        g = FunctionGraph([x1, x2], [e], clone=False)
        MergeOptimizer().optimize(g)

        add_node = g.outputs[0].owner
        assert add_node.op == add
        add_inputs = add_node.inputs
        merged_dot = add_inputs[0].owner
        assert isinstance(merged_dot.op, Dot)

        # The surviving ``dot`` must consume a single, equivalent ``Assert``.
        expected_assert = assert_op(x1, (x1 > x2).all())
        assert equal_computations([merged_dot.inputs[0]], [expected_assert])

        # After merging, both operands of ``add`` are the very same variable.
        assert add_inputs[0] is add_inputs[1]
示例#5
0
文件: test_jax.py 项目: mgorny/aesara
def test_jax_checkandraise():
    """Running an ``assert_op`` graph through the JAX comparison must fail.

    This test expects ``compare_jax_and_py`` to raise ``NotImplementedError``
    for a graph containing ``assert_op``.
    """
    p = scalar()
    p.tag.test_value = 0

    res = assert_op(p, p < 1.0)
    res_fg = FunctionGraph([p], [res])

    # Use an input that satisfies ``p < 1.0`` (consistent with the test value
    # above), so the only way to fail is the missing JAX support, not the
    # assertion condition itself being violated.
    with pytest.raises(NotImplementedError):
        compare_jax_and_py(res_fg, [0.0])
示例#6
0
    def test_both_assert_merge_2_reverse(self):
        """Test case "test_both_assert_merge_2" but with the summands in reverse order.

        The first ``dot`` guards its second operand (``x2``) and the second
        ``dot`` guards its first operand (``x1``). After merging, one ``dot``
        should remain whose operands carry both ``Assert``s.
        """
        x1 = matrix("x1")
        x2 = matrix("x2")
        x3 = matrix("x3")
        e = dot(x1, assert_op(x2, (x2 > x3).all())) + dot(
            assert_op(x1, (x1 > x3).all()), x2)
        g = FunctionGraph([x1, x2, x3], [e], clone=False)
        MergeOptimizer().optimize(g)

        assert g.outputs[0].owner.op == add
        add_inputs = g.outputs[0].owner.inputs
        assert isinstance(add_inputs[0].owner.op, Dot)
        # Confirm that the `Assert`s are correct. ``dot`` operands are
        # positional: whatever the order of the summands, the merged node's
        # first operand derives from ``x1`` and its second from ``x2``.
        assert_var_1, assert_var_2 = add_inputs[0].owner.inputs
        assert_ref_1 = assert_op(x1, (x1 > x3).all())
        assert equal_computations([assert_var_1], [assert_ref_1])
        assert_ref_2 = assert_op(x2, (x2 > x3).all())
        assert equal_computations([assert_var_2], [assert_ref_2])
        # Confirm the merge
        assert add_inputs[0] is add_inputs[1]
示例#7
0
def test_CheckAndRaise_equal():
    """Two separately built but identical ``assert_op`` graphs compare equal."""
    x, y = at.vectors("xy")

    def build_guarded():
        # Fresh ``Assert`` of ``x`` under the condition "every x exceeds y".
        return assert_op(x, (x > y).all())

    first = build_guarded()
    second = build_guarded()

    assert equal_computations([first], [second])