def define_nat_double(prelude):
    """Define ``double``, which maps a nat to twice its value.

    The GlobalVar for the new function is stored on the prelude as
    ``prelude.double`` and the definition is registered in ``prelude.mod``:

        double(z)    = z
        double(s(m)) = s(s(double(m)))
    """
    double = prelude.double = GlobalVar("double")
    arg = Var("x", prelude.nat())
    pred = Var("y")

    clauses = [
        # Doubling zero yields zero.
        Clause(PatternConstructor(prelude.z), prelude.z()),
        # Each successor peeled off the input contributes two successors.
        Clause(PatternConstructor(prelude.s, [PatternVar(pred)]),
               prelude.s(prelude.s(double(pred)))),
    ]
    prelude.mod[double] = Function([arg], Match(arg, clauses))
def define_nat_add(prelude):
    """Define ``add``, the sum of two nats, by recursion on its first argument.

    The GlobalVar is stored on the prelude as ``prelude.add`` and the
    definition is registered in ``prelude.mod``:

        add(z, m)    = m
        add(s(k), m) = s(add(k, m))
    """
    add = prelude.add = GlobalVar("add")
    lhs = Var("x", prelude.nat())
    rhs = Var("y", prelude.nat())
    pred = Var("a")

    clauses = [
        # Zero plus rhs is just rhs.
        Clause(PatternConstructor(prelude.z), rhs),
        # s(pred) + rhs = s(pred + rhs); the recursion shrinks the first argument.
        Clause(PatternConstructor(prelude.s, [PatternVar(pred)]),
               prelude.s(add(pred, rhs))),
    ]
    prelude.mod[add] = Function([lhs, rhs], Match(lhs, clauses))
def define_nat_nth(prelude):
    """Defines a function to get the nth element of a list, using a nat to
    index into the list.

    Adds a field 'nat_nth' to the prelude holding the GlobalVar for the
    function, and registers its definition in ``prelude.mod``.

    nat_nth(l, n): fun<a>(list[a], nat) -> a

        nat_nth(l, z)    = hd(l)
        nat_nth(l, s(m)) = nat_nth(tl(l), m)

    NOTE(review): an index >= the list length ends up applying ``hd``/``tl``
    to an empty list; the outcome then depends on how the prelude defines
    those functions.
    """
    prelude.nat_nth = GlobalVar("nat_nth")
    a = TypeVar("a")
    x = Var("x", prelude.l(a))
    n = Var("n", prelude.nat())
    y = Var("y")

    # Index zero: the answer is the head of the list.
    z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x))
    # Index s(y): drop the head and look up index y in the tail.
    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
                    prelude.nat_nth(prelude.tl(x), y))
    # Polymorphic over the element type `a`; the result has type `a`.
    prelude.mod[prelude.nat_nth] = Function([x, n],
                                            Match(n, [z_case, s_case]),
                                            a, [a])
def define_nat_iterate(prelude):
    """Define ``nat_iterate``, the n-fold self-composition of a function.

    nat_iterate(f, n) evaluates to a closure that applies ``f`` to its
    argument n times; for n = z the result is ``prelude.id``. The GlobalVar is
    stored as ``prelude.nat_iterate`` and the definition is registered in
    ``prelude.mod``.

    Signature: fn<a>(fn(a) -> a, nat) -> fn(a) -> a
    """
    iterate = prelude.nat_iterate = GlobalVar("nat_iterate")
    elem_ty = TypeVar("a")
    func = Var("f", FuncType([elem_ty], elem_ty))
    count = Var("x", prelude.nat())
    pred = Var("y", prelude.nat())

    on_zero = Clause(PatternConstructor(prelude.z), prelude.id)
    # One more application of f, composed with the (n-1)-fold iterate.
    on_succ = Clause(
        PatternConstructor(prelude.s, [PatternVar(pred)]),
        prelude.compose(func, iterate(func, pred)),
    )
    body = Match(count, [on_zero, on_succ])
    prelude.mod[iterate] = Function(
        [func, count], body, FuncType([elem_ty], elem_ty), [elem_ty])
def define_nat_update(prelude):
    """Define ``nat_update``, which replaces the list element at a nat index.

    nat_update(l, i, v) returns a list equal to ``l`` except that position
    ``i`` holds ``v``. The GlobalVar is stored as ``prelude.nat_update`` and
    the definition is registered in ``prelude.mod``.

    Signature: fun<a>(list[a], nat, a) -> list[a]

        nat_update(l, z, v)    = cons(v, tl(l))
        nat_update(l, s(m), v) = cons(hd(l), nat_update(tl(l), m, v))
    """
    update = prelude.nat_update = GlobalVar("nat_update")
    elem = TypeVar("a")
    lst = Var("l", prelude.l(elem))
    idx = Var("n", prelude.nat())
    val = Var("v", elem)
    pred = Var("y")

    # At index zero, swap out the head and keep the tail.
    at_target = Clause(PatternConstructor(prelude.z),
                       prelude.cons(val, prelude.tl(lst)))
    # Otherwise keep the head and update index pred within the tail.
    before_target = Clause(
        PatternConstructor(prelude.s, [PatternVar(pred)]),
        prelude.cons(prelude.hd(lst), update(prelude.tl(lst), pred, val)),
    )
    prelude.mod[update] = Function([lst, idx, val],
                                   Match(idx, [at_target, before_target]),
                                   prelude.l(elem), [elem])