def ast_test_all_raise(error, test, *invalid):
    """Check that binding every term in *invalid* to *test* raises *error*.

    Each term is bound with ``bind(term, test)`` inside an
    ``assert_raises(error)`` context.  If any exception escapes that context,
    the offending term is printed for diagnostics and the exception is
    re-raised unchanged.
    """
    for term in invalid:
        try:
            with assert_raises(error):
                bind(term, test)
        except BaseException:
            # Same catch-all as a bare ``except:``: report which term failed,
            # then let the original exception propagate.
            print(term)
            raise
def ast_test_all(test, mf_result, mf_expected, *term_pairs):
    """Apply *test* to every ``(term, expected)`` pair after binding both sides.

    ``term`` is bound through *mf_result* and ``expected`` through
    *mf_expected*.  The two bound values are passed to
    ``test(result, expected)``.  If the test raises, the original term and
    both bound values are printed, followed by a blank line, and the
    exception is re-raised.
    """
    for source_term, raw_expected in term_pairs:
        actual = bind(source_term, mf_result)
        wanted = bind(raw_expected, mf_expected)
        try:
            test(actual, wanted)
        except BaseException:
            # Dump diagnostics for the failing pair, then propagate the error.
            for shown in (source_term, actual, wanted):
                print(shown)
            print()
            raise
def test_monad_laws():
    """Test if the basic monadic functions conform to the three Monad Laws.

    Writing ``m >>= f`` as ``bind(m, f)``, the laws checked are:

    * left identity:  ``bind(unit(a), f) == f(a)``
    * right identity: ``bind(m, unit) == m``
    * associativity:  ``bind(bind(m, f), g) == bind(m, lambda v: bind(f(v), g))``
    """
    from hornet.expressions import unit, bind, lift

    x = ast.Name(id='x', ctx=load)
    y = ast.Name(id='y', ctx=load)
    z = ast.Name(id='z', ctx=load)
    mx = unit(x)

    # Functions from an AST node to a unit-wrapped AST, used as the f/g of the
    # laws; the binop-based ones combine their argument with y or z.
    binop = lambda u, op, v: unit(ast.BinOp(left=u, op=op(), right=v))
    and_y = lambda u: binop(u, ast.BitAnd, y)
    or_z = lambda u: binop(u, ast.BitOr, z)
    y_and = lambda v: binop(y, ast.BitAnd, v)
    z_or = lambda v: binop(z, ast.BitOr, v)

    mfuncs = [unit, lift(identity), and_y, or_z, y_and, z_or]

    # left identity:
    for mf in mfuncs:
        ast_eq(bind(mx, mf), mf(x))

    # right identity:
    ast_eq(bind(mx, unit), mx)
    # bind(unit(x), unit) on its own is only the left-identity case with
    # mf=unit, so also check the law on monadic values built by every mf.
    for mf in mfuncs:
        mv = mf(x)
        ast_eq(bind(mv, unit), mv)

    # associativity:
    for mf, mg in itertools.product(mfuncs, repeat=2):
        ast_eq(
            bind(bind(mx, mf), mg),
            bind(mx, lambda v: bind(mf(v), mg)),
        )