示例#1
0
def ast_test_all_raise(error, test, *invalid):
    """Run ``bind(item, test)`` under ``assert_raises(error)`` for each item.

    Every element of ``invalid`` is tried in order.  If any exception
    escapes the ``assert_raises`` context, the element that triggered it
    is printed and the exception is re-raised unchanged.

    NOTE(review): ``assert_raises`` and ``bind`` are defined outside this
    block; presumably ``assert_raises`` fails when ``error`` is not raised
    -- confirm against its definition.

    Args:
        error: Forwarded as the sole argument to ``assert_raises``.
        test: Second argument of every ``bind`` call.
        *invalid: Items passed, one at a time, as first argument to ``bind``.
    """
    for candidate in invalid:
        try:
            with assert_raises(error):
                bind(candidate, test)
        except BaseException:
            # Same reach as a bare ``except``: show the offending input,
            # then let the original exception propagate.
            print(candidate)
            raise
示例#2
0
def ast_test_all(test, mf_result, mf_expected, *term_pairs):
    """Apply ``test`` to each bound (term, expected) pair.

    For every pair, the term is bound with ``mf_result`` and the expected
    value with ``mf_expected`` (both through ``bind``, which is defined
    outside this block), then ``test(result, expected)`` is called.  If
    ``test`` raises, the raw term, the bound result and the bound expected
    value are printed, followed by a blank line, and the exception is
    re-raised.  Exceptions raised by the ``bind`` calls themselves happen
    outside the ``try`` and propagate without any printing.

    Args:
        test: Callable invoked as ``test(result, expected)``.
        mf_result: Second argument to ``bind`` for each term.
        mf_expected: Second argument to ``bind`` for each expected value.
        *term_pairs: Two-element iterables ``(term, expected)``.
    """
    for term, raw_expected in term_pairs:
        actual = bind(term, mf_result)
        wanted = bind(raw_expected, mf_expected)
        try:
            test(actual, wanted)
        except BaseException:
            # Dump the inputs of the failing comparison, one per line,
            # then an empty separator line before propagating.
            for item in (term, actual, wanted):
                print(item)
            print()
            raise
示例#3
0
def test_monad_laws():
    """Check ``unit``, ``bind`` and ``lift`` against the three monad laws.

    A handful of monadic functions is built around the names ``x``, ``y``
    and ``z``; left identity, right identity and associativity are then
    verified with ``ast_eq`` (defined outside this block).
    """

    from hornet.expressions import unit, bind, lift

    # Three plain name nodes; ``load`` is the expression context supplied
    # by the enclosing module.
    x, y, z = (ast.Name(id=name, ctx=load) for name in 'xyz')
    mx = unit(x)

    def binop(u, op, v):
        # Wrap ``u <op> v`` as a monadic value; ``op`` is an operator class.
        return unit(ast.BinOp(left=u, op=op(), right=v))

    def and_y(u):
        return binop(u, ast.BitAnd, y)

    def or_z(u):
        return binop(u, ast.BitOr, z)

    def y_and(v):
        return binop(y, ast.BitAnd, v)

    def z_or(v):
        return binop(z, ast.BitOr, v)

    mfuncs = [unit, lift(identity), and_y, or_z, y_and, z_or]

    # Left identity: bind(unit(x), f) == f(x)
    for func in mfuncs:
        bound = bind(mx, func)
        direct = func(x)
        ast_eq(bound, direct)

    # Right identity: bind(m, unit) == m
    ast_eq(bind(mx, unit), mx)

    # Associativity: bind(bind(m, f), g) == bind(m, v -> bind(f(v), g))
    for first, second in itertools.product(mfuncs, repeat=2):
        chained = bind(bind(mx, first), second)
        nested = bind(mx, lambda v: bind(first(v), second))
        ast_eq(chained, nested)