def test_monad_laws():
    """Check that unit, bind and lift obey the three Monad Laws.

    The monadic functions under test are ``unit``, ``lift(identity)`` and
    four helpers that wrap their operand into an ``ast.BinOp`` together
    with one of the fixed names ``y`` or ``z``.
    """
    from hornet.expressions import unit, bind, lift

    x, y, z = (ast.Name(id=ident, ctx=load) for ident in ('x', 'y', 'z'))
    mx = unit(x)

    def binop(left, op_type, right):
        # Build a fresh BinOp node and lift it back into the monad.
        return unit(ast.BinOp(left=left, op=op_type(), right=right))

    def and_y(operand):
        return binop(operand, ast.BitAnd, y)

    def or_z(operand):
        return binop(operand, ast.BitOr, z)

    def y_and(operand):
        return binop(y, ast.BitAnd, operand)

    def z_or(operand):
        return binop(z, ast.BitOr, operand)

    monadic_functions = [unit, lift(identity), and_y, or_z, y_and, z_or]

    # Left identity: bind(unit(a), f) == f(a)
    for func in monadic_functions:
        ast_eq(bind(mx, func), func(x))

    # Right identity: bind(m, unit) == m
    ast_eq(bind(mx, unit), mx)

    # Associativity: bind(bind(m, f), g) == bind(m, lambda a: bind(f(a), g))
    for f, g in itertools.product(monadic_functions, repeat=2):
        ast_eq(
            bind(bind(mx, f), g),
            bind(mx, lambda value: bind(f(value), g)),
        )
def expand_call(node):
    """Turn a Name or Call node into a functor goal and register it.

    A bare name is called with no arguments (``unit(node)()``). A call node
    is deep-copied before wrapping; presumably this is so that later edits
    to its argument list do not touch the caller's tree.

    Raises:
        TypeError: if ``node`` is neither a Name nor a Call node.
    """
    if is_name(node):
        functor = unit(node)()
    elif is_call(node):
        functor = unit(copy.deepcopy(node))
    else:
        raise TypeError('Name or Call node expected, not {}'.format(node))
    return collect_functor(functor)
def test_expression_operators():
    """Test all Expression factory functions that are called as operators.

    Applying an operator to the symbols ``x`` and ``y`` must produce the
    same AST as the corresponding hand-built node wrapped with ``unit``.
    """
    from hornet.expressions import unit, Num
    from hornet.symbols import x, y

    x_name = x.node
    y_name = y.node
    items = [Num(1), Num(2), Num(3)]

    # Subscription and call.
    # NOTE(review): ast.Index is deprecated since Python 3.9, and the
    # starargs/kwargs fields were removed from ast.Call in Python 3.5.
    # This test mirrors whatever layout hornet.expressions builds; confirm
    # the target Python version.
    pairs = (
        [x[y], ast.Subscript(value=x_name, slice=ast.Index(y_name), ctx=load)],
        [x(1, 2, 3), ast.Call(
            func=x_name,
            args=items,
            keywords=[],
            starargs=None,
            kwargs=None,
        )],
    )
    for expr, node in pairs:
        ast_eq(expr, unit(node))

    # Unary operators.
    pairs = (
        [-x, ast.USub],
        [+x, ast.UAdd],
        [~x, ast.Invert],
    )
    for expr, op in pairs:
        ast_eq(expr, unit(ast.UnaryOp(op(), x_name)))

    # Binary operators.
    pairs = (
        [x ** y, ast.Pow],
        [x * y, ast.Mult],
        [x / y, ast.Div],
        [x // y, ast.FloorDiv],
        [x % y, ast.Mod],
        [x + y, ast.Add],
        [x - y, ast.Sub],
        [x << y, ast.LShift],
        [x >> y, ast.RShift],
        [x & y, ast.BitAnd],
        [x ^ y, ast.BitXor],
        [x | y, ast.BitOr],
    )
    for expr, op in pairs:
        ast_eq(expr, unit(ast.BinOp(x_name, op(), y_name)))
def test_expression_factories():
    """Test all Expression factory functions that are called directly.

    Each factory call must produce the same AST as the corresponding
    hand-built node wrapped with ``unit``.
    """
    from hornet.expressions import (
        unit, Name, Str, Bytes, Num, Tuple, List, Set, Wrapper, AstWrapper
    )

    obj = object()
    name = 'joe'
    num = 123
    keys = [Str('a'), Str('b'), Str('c')]

    pairs = (
        [Name(name), ast.Name(id=name, ctx=load)],
        [Str(name), ast.Str(s=name)],
        [Bytes(name), ast.Bytes(s=name)],
        [Num(num), ast.Num(n=num)],
        [Tuple(keys), ast.Tuple(elts=keys, ctx=load)],
        [List(keys), ast.List(elts=keys, ctx=load)],
        [Set(keys), ast.Set(elts=keys)],
        [Wrapper(obj), AstWrapper(wrapped=obj)],
    )
    for expr, node in pairs:
        ast_eq(expr, unit(node))
def expand_pushbacks(node):
    """Expand a DCG pushback list into a conjunction of pushback goals.

    Every element is wrapped as ``_C_(unit(element))`` and registered with
    ``collect_pushback``; the results are folded with ``conjunction``.

    Returns:
        The folded goal, or None when the list is empty.

    Raises:
        TypeError: if any element of the list is not a terminal.
    """
    items = node.elts
    if not items:
        return None
    if not all(map(is_terminal, items)):
        raise TypeError(
            'Non-terminal in DCG pushback list found: {}'.format(node))
    goals = [collect_pushback(_C_(unit(item))) for item in items]
    return foldr(conjunction, goals)
def expand_terminals(node, cont):
    """Expand a DCG terminal list such as ``[a, b, c]``.

    Each terminal is wrapped as ``_C_(unit(terminal))`` and registered with
    ``collect_terminal``. All but the last are folded with ``conjunction``.
    The last one is passed to ``cont``, so whatever follows in the body can
    be joined to it.

    Args:
        node: a list node whose ``elts`` are the terminals.
        cont: continuation that receives the rightmost goal (or None for
            an empty list) and returns the goal to continue with.

    Raises:
        TypeError: if any element of the list is not a terminal.
    """
    if not node.elts:
        return cont(None)
    elif all(is_terminal(each) for each in node.elts):
        # The generator is consumed here, so terminals are collected in
        # order *before* cont(last) expands anything that follows.
        *elts, last = (
            collect_terminal(_C_(unit(each))) for each in node.elts)
        return foldr(conjunction, elts, cont(last))
    else:
        raise TypeError(
            'Non-terminal in DCG terminal list found: {}'.format(node))
def expand_body(node, cont):
    """Expand a DCG body expression in continuation-passing style.

    ``cont`` receives the rightmost goal produced so far and returns that
    goal combined with whatever has to follow it.

    * ``left & right``: expand ``left``, continuing with the expansion of
      ``right``.
    * ``[...]``: a terminal list, handled by ``expand_terminals``.
    * ``{goal}``: a single goal passed through via ``unit`` without being
      registered with any collect_* helper.
    * anything else: a functor goal, handled by ``expand_call``.
    """
    if is_bitand(node):
        def continue_right(left_tail):
            # Join the tail of the left operand with the expanded right one.
            return conjunction(left_tail, expand_body(node.right, cont))
        return expand_body(node.left, continue_right)
    if is_list(node):
        return expand_terminals(node, cont)
    if is_set(node):
        assert len(node.elts) == 1
        embedded = node.elts[0]
        return cont(unit(embedded))
    return cont(expand_call(node))
def expand(root):
    """Expand a DCG rule expression into an ordinary clause.

    If ``root`` is not a right-shift node (``head >> body``), it is not a
    DCG rule; it is wrapped with ``unit`` and returned unchanged.

    Otherwise the head and body are rewritten into a ``rule``. While the
    goals are expanded, every functor, terminal and pushback goal registers
    "setter" callbacks that will later insert state variables into its
    argument list. At the end the setters are paired up and each pair gets
    one fresh variable from ``numbered_vars('_')``. Adjacent goals end up
    sharing a state variable; the exact pairing depends on ``rotate`` and
    ``splitpairs``, which are defined elsewhere.

    Raises:
        TypeError: if a goal is neither a Name nor a Call node, or if a
            terminal or pushback list contains a non-terminal.
        AssertionError: if a ``{...}`` body goal does not contain exactly
            one element.
    """
    if not is_rshift(root):
        return unit(root)

    def expand_call(node):
        # A bare name is called with no arguments; a call is deep-copied
        # first because collect_functor appends to its args list, and the
        # caller's tree should presumably stay untouched.
        if is_name(node):
            return collect_functor(unit(node)())
        elif is_call(node):
            return collect_functor(unit(copy.deepcopy(node)))
        else:
            raise TypeError('Name or Call node expected, not {}'.format(node))

    def expand_terminals(node, cont):
        # All terminals but the last are folded with conjunction; the last
        # one is handed to cont so the rest of the body can follow it.
        if not node.elts:
            return cont(None)
        elif all(is_terminal(each) for each in node.elts):
            *elts, last = (
                collect_terminal(_C_(unit(each))) for each in node.elts)
            return foldr(conjunction, elts, cont(last))
        else:
            raise TypeError(
                'Non-terminal in DCG terminal list found: {}'.format(node))

    def expand_pushbacks(node):
        # Pushback list from a ``head & [...] >> body`` rule; None if empty.
        if not node.elts:
            return None
        elif all(is_terminal(each) for each in node.elts):
            elts = [collect_pushback(_C_(unit(each))) for each in node.elts]
            return foldr(conjunction, elts)
        else:
            raise TypeError(
                'Non-terminal in DCG pushback list found: {}'.format(node))

    def expand_body(node, cont):
        # Continuation-passing expansion: cont receives the rightmost goal
        # produced so far and returns it combined with what follows.
        if is_bitand(node):
            def right_side(rightmost_of_left_side):
                return conjunction(
                    rightmost_of_left_side, expand_body(node.right, cont))
            return expand_body(node.left, right_side)
        elif is_list(node):
            return expand_terminals(node, cont)
        elif is_set(node):
            # {goal}: passed through without registering any setters.
            assert len(node.elts) == 1
            return cont(unit(node.elts[0]))
        else:
            return cont(expand_call(node))

    def expand_clause(node):
        # node is ``head >> body``. A head of the form ``name & [...]``
        # carries a pushback list that is conjoined after the body.
        head = node.left
        body = node.right
        if is_bitand(head):
            def pushback(rightmost_of_body):
                return conjunction(
                    rightmost_of_body, expand_pushbacks(head.right))
            return rule(
                expand_call(head.left), expand_body(body, pushback))
        else:
            return rule(expand_call(head), expand_body(body, identity))

    # Setter callbacks in the order goals are expanded. Pushback setters
    # are added with appendleft, so from_right holds them in reverse order.
    from_left = []
    from_right = collections.deque()

    def collect_functor(call):
        # Two slots, both appended to the end of the call's arguments.
        args = call.node.args
        from_left.append(args.append)
        from_left.append(args.append)
        return call

    def collect_terminal(call):
        # One slot inserted before the last existing argument, one appended.
        args = call.node.args
        from_left.append(functools.partial(args.insert, -1))
        from_left.append(args.append)
        return call

    def collect_pushback(call):
        args = call.node.args
        from_right.appendleft(functools.partial(args.insert, -2))
        from_right.appendleft(args.append)
        return call

    # Expansion must happen first: it is what fills from_left/from_right.
    clause = expand_clause(root)
    pairs = splitpairs(rotate(itertools.chain(from_left, from_right)))
    for (set_left, set_right), var in zip(pairs, numbered_vars('_')):
        # Both setters of a pair receive the same fresh variable.
        set_left(var)
        set_right(var)
    return clause