def test_infer_shape(self):
    """Run ``self._compile_and_check`` on two ``assert_op`` graphs.

    The first graph passes two double scalars to ``assert_op``; the second
    passes a double matrix followed by the same two scalars.  Both are
    checked against the ``Assert`` op class.
    """
    scalar_a = at.dscalar()
    scalar_b = at.dscalar()

    # The random draws happen in the same order as the symbolic variables
    # are declared: first scalar, second scalar, then (later) the matrix.
    value_a = np.random.random()
    # `random()` samples from [0, 1); the offset keeps this value non-zero
    # (presumably because it is fed to `assert_op` as a condition -- confirm).
    value_b = np.random.random() + 1

    scalar_graph = assert_op(scalar_a, scalar_b)
    self._compile_and_check(
        [scalar_a, scalar_b],
        [scalar_graph],
        [value_a, value_b],
        Assert,
    )

    matrix_a = at.dmatrix()
    matrix_value = np.random.random((3, 4))
    # In the second graph the first scalar is no longer the leading argument,
    # so its value is shifted away from zero as well.
    value_a = value_a + 1

    matrix_graph = assert_op(matrix_a, scalar_a, scalar_b)
    self._compile_and_check(
        [matrix_a, scalar_a, scalar_b],
        [matrix_graph],
        [matrix_value, value_a, value_b],
        Assert,
    )
def test_CheckAndRaise_validation():
    """Check `CheckAndRaise` argument validation and constant conversion.

    ``CheckAndRaise(str)`` is expected to raise ``ValueError``, and a NumPy
    array handed to ``assert_op`` is expected to show up in the graph as a
    ``Constant`` input.
    """
    with pytest.raises(ValueError):
        CheckAndRaise(str)

    checked = assert_op(np.array(1.0))
    leading_input = checked.owner.inputs[0]
    assert isinstance(leading_input, Constant)
def test_one_assert_merge(self):
    """Merge two `Dot` nodes where only one of them has an `Assert` input."""
    x1 = matrix("x1")
    x2 = matrix("x2")

    plain_term = dot(x1, x2)
    guarded_term = dot(assert_op(x1, (x1 > x2).all()), x2)
    total = plain_term + guarded_term

    fgraph = FunctionGraph([x1, x2], [total], clone=False)
    MergeOptimizer().optimize(fgraph)

    root = fgraph.outputs[0].owner
    assert root.op == add

    terms = root.inputs
    merged_dot = terms[0].owner
    assert isinstance(merged_dot.op, Dot)

    # The first input of the surviving `Dot` must match a freshly built
    # `Assert` on `x1` with the same condition.
    actual_assert = merged_dot.inputs[0]
    expected_assert = assert_op(x1, (x1 > x2).all())
    assert equal_computations([actual_assert], [expected_assert])

    # Both operands of the addition must now be the very same variable.
    assert terms[0] is terms[1]
def test_both_assert_merge_identical(self):
    """Merge two `Dot` nodes whose `Assert` inputs are built identically.

    Both terms wrap ``x1`` in an `Assert` using the same condition.
    """
    x1 = matrix("x1")
    x2 = matrix("x2")

    left_term = dot(assert_op(x1, (x1 > x2).all()), x2)
    right_term = dot(assert_op(x1, (x1 > x2).all()), x2)
    total = left_term + right_term

    fgraph = FunctionGraph([x1, x2], [total], clone=False)
    MergeOptimizer().optimize(fgraph)

    root = fgraph.outputs[0].owner
    assert root.op == add

    terms = root.inputs
    merged_dot = terms[0].owner
    assert isinstance(merged_dot.op, Dot)

    # The first input of the surviving `Dot` must match a freshly built
    # `Assert` on `x1` with the same condition.
    actual_assert = merged_dot.inputs[0]
    expected_assert = assert_op(x1, (x1 > x2).all())
    assert equal_computations([actual_assert], [expected_assert])

    # Both operands of the addition must now be the very same variable.
    assert terms[0] is terms[1]
def test_jax_checkandraise():
    """Expect `compare_jax_and_py` to raise for an `assert_op` graph.

    A scalar guarded by ``assert_op(p, p < 1.0)`` is wrapped in a
    `FunctionGraph`; running the comparison helper on it with the input
    ``1.0`` must raise ``NotImplementedError``.
    """
    value = scalar()
    value.tag.test_value = 0

    guarded = assert_op(value, value < 1.0)
    fgraph = FunctionGraph([value], [guarded])

    with pytest.raises(NotImplementedError):
        compare_jax_and_py(fgraph, [1.0])
def test_both_assert_merge_2_reverse(self):
    """Test case "test_both_assert_merge_2" but in reverse order.

    The first `Dot` carries an `Assert` on its second operand (``x2``) and
    the second `Dot` carries an `Assert` on its first operand (``x1``).
    """
    x1 = matrix("x1")
    x2 = matrix("x2")
    x3 = matrix("x3")

    left_term = dot(x1, assert_op(x2, (x2 > x3).all()))
    right_term = dot(assert_op(x1, (x1 > x3).all()), x2)
    total = left_term + right_term

    fgraph = FunctionGraph([x1, x2, x3], [total], clone=False)
    MergeOptimizer().optimize(fgraph)

    root = fgraph.outputs[0].owner
    assert root.op == add

    terms = root.inputs
    merged_dot = terms[0].owner
    assert isinstance(merged_dot.op, Dot)

    # Compare each input of the surviving `Dot` with a reference `Assert`,
    # in this order: first input against the `x2` reference, second input
    # against the `x1` reference.
    first_input, second_input = merged_dot.inputs

    x2_reference = assert_op(x2, (x2 > x3).all())
    assert equal_computations([first_input], [x2_reference])

    x1_reference = assert_op(x1, (x1 > x3).all())
    assert equal_computations([second_input], [x1_reference])

    # Both operands of the addition must now be the very same variable.
    assert terms[0] is terms[1]
def test_CheckAndRaise_equal():
    """Two separately built, identical `assert_op` graphs compare as equal."""
    x, y = at.vectors("xy")

    first_graph = assert_op(x, (x > y).all())
    second_graph = assert_op(x, (x > y).all())

    assert equal_computations([first_graph], [second_graph])