Exemple #1
0
 def testEvalSem(self):
     """Evaluate two assignments in sequence and check the Sem theorem."""
     program = Seq(Assign(zero, Lambda(s, one)),
                   Assign(one, Lambda(s, Nat(2))))
     initial = mk_const_fun(NatType, zero)
     final = fun_upd_of_seq(0, 1, 1, 2)
     goal = Sem(program, initial, final)

     # Let the evaluation macro build the proof, then re-check it.
     macro = imp.eval_Sem_macro()
     proof = macro.get_proof_term(goal, []).export()
     checked = theory.check_proof(proof)
     self.assertEqual(checked, Thm([], goal))
Exemple #2
0
    def testBetaNorm(self):
        """Check beta_norm on a few applied abstractions."""
        x = Var('x', Ta)
        y = Var('y', Ta)

        # Each case pairs a reducible term with its expected normal form.
        cases = (
            (Lambda(x, x)(x), x),
            (Lambda(x, Lambda(y, y))(x, y), y),
            (Lambda(x, Lambda(y, x))(x, y), x),
        )

        for redex, expected in cases:
            normal_form = redex.beta_norm()
            self.assertEqual(normal_form, expected)
Exemple #3
0
    def testComputeWP(self):
        """Check the precondition computed by imp.compute_wp for two programs."""
        Q = Var("Q", TFun(natFunT, BoolType))

        # Each case pairs a command with the precondition expected for Q.
        cases = (
            (Assign(zero, Lambda(s, one)),
             Lambda(s, Q(mk_fun_upd(s, zero, one)))),
            (Seq(Assign(zero, Lambda(s, one)), Assign(one, Lambda(s, Nat(2)))),
             Lambda(s, Q(mk_fun_upd(s, zero, one, one, Nat(2))))),
        )

        for com, pre in cases:
            proof = imp.compute_wp(natFunT, com, Q).export()
            checked = theory.check_proof(proof)
            self.assertEqual(checked, Thm([], Valid(pre, com, Q)))
Exemple #4
0
    def testVCG(self):
        """Run imp.vcg and imp.vcg_tactic on two programs and check the result."""
        P = Var("P", TFun(natFunT, BoolType))
        Q = Var("Q", TFun(natFunT, BoolType))

        programs = (
            Assign(zero, Lambda(s, one)),
            Seq(Assign(zero, Lambda(s, one)), Assign(one, Lambda(s, Nat(2)))),
        )

        for com in programs:
            goal = Valid(P, com, Q)

            # Direct call to the verification condition generator.
            direct = imp.vcg(natFunT, goal).export()
            self.assertEqual(theory.check_proof(direct).concl, goal)

            # Same goal, going through the tactic interface.
            tactic_pt = imp.vcg_tactic().get_proof_term(Thm([], goal), None, [])
            via_tactic = tactic_pt.export()
            self.assertEqual(theory.check_proof(via_tactic).prop, goal)
Exemple #5
0
def mk_collect(x, body):
    """Build the set comprehension {x. body}.

    x must be a variable. body, which may mention x, is abstracted over x
    and the result is passed to the collect constant at type x.T.

    """
    assert x.is_var(), "mk_collect"
    collect_t = collect(x.T)
    predicate = Lambda(x, body)
    return collect_t(predicate)
Exemple #6
0
    def testPrintUnicode(self):
        """Check printer.print_term output when the unicode setting is on."""
        # Each case pairs a term with its expected printed form.
        cases = (
            (And(A, B), "A ∧ B"),
            (Or(A, B), "A ∨ B"),
            (Implies(A, B), "A ⟶ B"),
            (Lambda(a, P(a)), "λa. P a"),
            (Forall(a, P(a)), "∀a. P a"),
            (Exists(a, P(a)), "∃a. P a"),
            (Not(A), "¬A"),
            (Lambda(m, m + 2), "λm::nat. m + 2"),
            (Lambda(m, m + n), "λm. m + n"),
        )

        with global_setting(unicode=True):
            for tm, expected in cases:
                printed = printer.print_term(tm)
                self.assertEqual(printed, expected)
Exemple #7
0
def mk_exists1(x, body):
    """Build the unique-existence term ?!x. body.

    x must be a variable. body, which may mention x, is abstracted over x
    and passed to the constant "exists1" of type (x.T => bool) => bool.

    """
    assert x.is_var(), "mk_exists1"
    pred_T = TFun(x.T, BoolType)
    binder = Const("exists1", TFun(pred_T, BoolType))
    return binder(Lambda(x, body))
Exemple #8
0
 def testEvalSem5(self):
     """Evaluate a while loop that increments location 0 until it equals 3."""
     loop_cond = Lambda(s, Not(Eq(s(zero), Nat(3))))
     program = While(loop_cond, assn_true, incr_one)
     initial = mk_const_fun(NatType, zero)
     final = fun_upd_of_seq(0, 3)
     goal = Sem(program, initial, final)

     macro = imp.eval_Sem_macro()
     proof = macro.get_proof_term(goal, []).export()

     # The report object is only handed to the checker; it is not inspected.
     report = ProofReport()
     checked = theory.check_proof(proof, report)
     self.assertEqual(checked, Thm([], goal))
Exemple #9
0
def mk_some(x, body):
    """Build the choice term SOME x. body.

    x must be a variable. body, which may mention x, is abstracted over x
    and passed to the constant "Some" of type (x.T => bool) => x.T.

    """
    assert x.is_var(), "mk_some"
    pred_T = TFun(x.T, BoolType)
    chooser = Const("Some", TFun(pred_T, x.T))
    return chooser(Lambda(x, body))
Exemple #10
0
    def abstraction(x, th):
        """Derivation rule ABSTRACTION:

        A |- t1 = t2
        ------------------------
        A |- (%x. t1) = (%x. t2)  where x does not occur in A.

        Raises InvalidDerivationException if x occurs in a hypothesis of
        th, if th is not an equality, or if building either abstraction
        raises term.TermException.
        """
        # Side condition: x must not occur in any hypothesis.
        for hyp in th.hyps:
            if hyp.occurs_var(x):
                raise InvalidDerivationException("abstraction")

        if not th.is_equals():
            raise InvalidDerivationException("abstraction")

        lhs, rhs = th.prop.args
        try:
            abs_lhs = Lambda(x, lhs)
            abs_rhs = Lambda(x, rhs)
        except term.TermException:
            raise InvalidDerivationException("abstraction")
        return Thm(th.hyps, Eq(abs_lhs, abs_rhs))
Exemple #11
0
def _state_from_items(assignments):
    """Build a state from a mapping of variable names to integer values.

    Starts from the constant-zero function and applies one mk_fun_upd per
    entry, visiting the entries in sorted key order.
    """
    state = mk_const_fun(NatType, nat.zero)
    for k, v in sorted(assignments.items()):
        state = mk_fun_upd(state, Nat(str_to_nat(k)), Nat(v))
    return state


def process_file(input, output):
    """Check the runs of an example file and export them as a JSON theory.

    Reads examples/<input>.json next to this source file, processes the
    first five entries of its 'content' list, and adds one theorem per
    entry ("eval<i>" or "vcg<i>") before exporting the theory.

    input: name of the example file, without the '.json' extension.
    output: passed as the first argument to json_output.JSONTheory
        (presumably the name of the generated theory -- confirm there).

    Raises TypeError if an entry has a 'ty' other than 'eval' or 'vcg'.
    """
    basic.load_theory('hoare')

    dn = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(dn, 'examples/' + input + '.json'),
              encoding='utf-8') as a:
        data = json.load(a)

    # Keep the writer under its own name instead of rebinding the parameter.
    theory_out = json_output.JSONTheory(output, ["hoare"],
                                        "Generated from " + input)
    content = data['content']
    eval_count = 0
    vcg_count = 0
    # Only the first five runs are processed, as in the original code.
    for run in content[:5]:
        if run['ty'] == 'eval':
            # Goal: Sem com st1 st2, proved by the eval_Sem macro.
            com = parse_com(run['com'])
            st1 = _state_from_items(run['init'])
            st2 = _state_from_items(run['final'])
            Sem = imp.Sem(natFunT)
            goal = Sem(com, st1, st2)
            prf = ProofTerm("eval_Sem", goal, []).export()
            rpt = ProofReport()
            th = theory.check_proof(prf, rpt)
            theory_out.add_theorem("eval" + str(eval_count), th, prf)
            eval_count += 1
        elif run['ty'] == 'vcg':
            # Goal: Valid pre com post, solved through imp.vcg_solve.
            com = parse_com(run['com'])
            pre = Lambda(st, parse_cond(run['pre']))
            post = Lambda(st, parse_cond(run['post']))
            Valid = imp.Valid(natFunT)
            goal = Valid(pre, com, post)
            prf = imp.vcg_solve(goal).export()
            rpt = ProofReport()
            th = theory.check_proof(prf, rpt)
            theory_out.add_theorem("vcg" + str(vcg_count), th, prf)
            vcg_count += 1
        else:
            # Same exception type as before, now naming the offending value.
            raise TypeError(
                "process_file: unknown run type %r" % (run['ty'],))

    theory_out.export_json()
Exemple #12
0
 def get_proof_term(self, goal, *, args=None, prevs=None):
     """Apply a named theorem with the goal abstracted over a variable.

     args is a pair (th_name, var). The goal's proposition is abstracted
     over var and used as the instantiation of the head of the theorem's
     conclusion; the rest of the instantiation comes from matching the
     conclusion's single argument against var. The work is then handed
     to rule(). prevs is not used.

     Raises NotImplementedError if the theorem's conclusion does not
     have exactly one argument.
     """
     th_name, var = args
     motive = Lambda(var, goal.prop)

     concl = theory.get_theorem(th_name).concl
     head, head_args = concl.strip_comb()
     if len(head_args) != 1:
         raise NotImplementedError

     inst = matcher.first_order_match(head_args[0], var)
     inst[head.name] = motive
     return rule().get_proof_term(goal, args=(th_name, inst))
Exemple #13
0
    def testEvalSem4(self):
        """Evaluate a conditional from two start states (one per branch)."""
        guard = Lambda(s, Not(Eq(s(zero), one)))
        program = Cond(guard, incr_one, Skip)
        zero_state = mk_const_fun(NatType, zero)
        one_state = fun_upd_of_seq(0, 1)

        # From the all-zero state the program ends with location 0 set to 1.
        goal = Sem(program, zero_state, one_state)
        proof = imp.eval_Sem_macro().get_proof_term(goal, []).export()
        checked = theory.check_proof(proof)
        self.assertEqual(checked, Thm([], goal))

        # Starting from that state, the program leaves it unchanged.
        goal = Sem(program, one_state, one_state)
        proof = imp.eval_Sem_macro().get_proof_term(goal, []).export()
        checked = theory.check_proof(proof)
        self.assertEqual(checked, Thm([], goal))
Exemple #14
0
 def testRule4(self):
     """Apply nat_induct to n + 0 = n and check the two resulting goals."""
     n = Var("n", NatType)
     tac = tactic.rule()
     induct_inst = Inst(P=Lambda(n, Eq(n + 0, n)), x=n)
     expected_goals = [
         "(0::nat) + 0 = 0",
         "!n. n + 0 = n --> Suc n + 0 = Suc n"
     ]
     self.run_test('nat',
                   tac,
                   vars={"n": "nat"},
                   goal="n + 0 = n",
                   args=("nat_induct", induct_inst),
                   new_goals=expected_goals)
Exemple #15
0
    def testInferPrintedType(self):
        """Check which subterms infer_printed_type marks with print_type."""
        def marked(tm):
            # True when infer_printed_type left a print_type attribute on tm.
            return hasattr(tm, "print_type")

        nil_const = Const("nil", ListType(Ta))
        infer_printed_type(nil_const)
        self.assertTrue(marked(nil_const))

        cons_app = cons(Ta)(Var("a", Ta))
        infer_printed_type(cons_app)
        self.assertFalse(marked(cons_app.fun))

        nil_eq = Eq(Const("nil", ListType(Ta)), Const("nil", ListType(Ta)))
        infer_printed_type(nil_eq)
        self.assertFalse(marked(nil_eq.fun.fun))
        self.assertTrue(marked(nil_eq.arg1))
        self.assertFalse(marked(nil_eq.arg))

        append_eq = Eq(mk_append(nil(Ta),nil(Ta)), nil(Ta))
        infer_printed_type(append_eq)
        self.assertTrue(marked(append_eq.arg1.arg1))
        self.assertFalse(marked(append_eq.arg1.arg))
        self.assertFalse(marked(append_eq.arg))

        # The last case only checks that the call completes.
        abs_eq = Lambda(Var("x", Ta), Eq(Var("x", Ta), Var("x", Ta)))
        infer_printed_type(abs_eq)
Exemple #16
0
def compute_wp(T, c, Q):
    """Compute a precondition for command c and postcondition Q.

    The result is a proof term for Valid P c Q at state type T, where P
    is the precondition computed by case analysis on c. For While, the
    first assumption of the while rule is assumed rather than proved,
    and the second is discharged through vcg.

    Raises NotImplementedError for any other form of command.

    """
    if c.is_const("Skip"):  # Skip
        return apply_theorem("skip_rule", concl=Valid(T)(Q, c, Q))

    if c.is_comb("Assign", 2):  # Assign a b
        target, expr = c.args
        state = Var("s", T)
        # Q evaluated in the state where target is updated to expr(state).
        pre = Lambda(
            state,
            Q(function.mk_fun_upd(state, target, expr(state).beta_conv())))
        return apply_theorem("assign_rule",
                             inst=Inst(b=expr),
                             concl=Valid(T)(pre, c, Q))

    if c.is_comb("Seq", 2):  # Seq c1 c2
        first, second = c.args
        # Work backwards: the second command first, then the first command
        # against the precondition just obtained.
        second_pt = compute_wp(T, second, Q)
        first_pt = compute_wp(T, first, second_pt.prop.args[0])
        return apply_theorem("seq_rule", first_pt, second_pt)

    if c.is_comb("Cond", 3):  # Cond b c1 c2
        guard, then_c, else_c = c.args
        then_pt = compute_wp(T, then_c, Q)
        else_pt = compute_wp(T, else_c, Q)
        return apply_theorem("if_rule", then_pt, else_pt, inst=Inst(b=guard))

    if c.is_comb("While", 3):  # While b I c
        _, invariant, _ = c.args
        rule_pt = apply_theorem("while_rule", concl=Valid(T)(invariant, c, Q))
        assumed_pt = ProofTerm.assume(rule_pt.assums[0])
        body_pt = vcg(T, rule_pt.assums[1])
        return rule_pt.implies_elim(assumed_pt, body_pt)

    raise NotImplementedError
Exemple #17
0
    def run(cls, g):
        """Apply induction to goal g and return a Proof with one subgoal per constructor.

        The tactic only applies when the goal formula is a Prod whose
        argument type is an Ind; otherwise cls.TacticFailure is raised.
        The induction theorem is chosen from the sort of the goal body
        ('_rect', '_ind' or '_rec' appended to the inductive's full name).
        The resulting proof formula is also printed before returning.
        """
        gf = g.formula()
        env = g.env()

        # TODO apply induction to a prod, instead of a env
        if isinstance(gf, Prod) and isinstance(gf.arg_type, Ind):
            # Type of the body, computed with the bound argument added to env.
            goal_type = gf.body.type(
                ContextEnvironment(Binding(gf.arg_name, type=gf.arg_type),
                                   env))

            assert isinstance(goal_type, Sort), "goal formula is not a sort."

            # Despite its name, this is appended to the inductive's name below.
            thm_prefix = None
            if goal_type.is_type():
                thm_prefix = '_rect'
            elif goal_type.is_prop():
                thm_prefix = '_ind'
            elif goal_type.is_set():
                thm_prefix = '_rec'
            else:
                assert False, "unknown sort : %s" % goal_type

            mutind_name = gf.arg_type.mutind_name
            ind_index = gf.arg_type.ind_index

            # when an inductive is not the only one of a mut-inductive, we can only obtain
            # the full name of the mut-inductive, i.e. Coq.Init.Datatypes.nat
            # however, all the induction theorems are declared with the full name of the
            # corresponding inductive, in other words, we have to generate them ourselves
            ind = env.mutind(mutind_name).inds[ind_index]
            # Replace the last segment of the mut-inductive name by this inductive's name.
            ind_full_segs = mutind_name.split('.')[:-1] + [ind.name]
            ind_full_name = '.'.join(ind_full_segs)
            thm_to_apply = Const(ind_full_name + thm_prefix)

            # type of an induction theorem is usually:
            #
            #     Prop with a hole -> Proof on Constructor 0 -> Proof on Constructor 1 -> ... -> Prop to Prove
            #
            # which is what we need to build.
            p_hole = Lambda(gf.arg_name, gf.arg_type, gf.body)

            subgoals = []
            for c in ind.constructors:
                # form of a subgoal is usually:
                #
                #     - P c (e.g. c = O)
                #     - forall x : S1, P (c x) (e.g. c : S1 -> S2)
                #     - ...
                ctyp = c.term().type(env)

                # inner starts with no arguments; the loop below appends to
                # inner.args in place, which also changes the copy already
                # embedded in subgoal.
                inner = Apply(c.term())
                subgoal = Apply(p_hole, inner)

                # we keep this variable in case there is no argument given to inner
                # then we need to replace the `Apply` term in the original subgoal
                # to its func (inner.func)
                original_subgoal = subgoal

                depth = 0
                while isinstance(ctyp, Prod):
                    if ctyp.arg_type != gf.arg_type:
                        # Non-recursive argument: one new binder.
                        inner.args.append(Rel(depth))
                        subgoal = Prod(ctyp.arg_name, ctyp.arg_type, subgoal)
                        depth += 1
                    else:
                        # Recursive argument: a binder for the argument plus an
                        # anonymous binder for the hypothesis p_hole applied to it.
                        inner.args.append(Rel(depth + 1))
                        subgoal = Prod(
                            ctyp.arg_name if ctyp.arg_name is not None else
                            'x', ctyp.arg_type,
                            Prod(None, Apply(p_hole, Rel(0)), subgoal))
                        depth += 2

                    ctyp = ctyp.body

                # Constructor without arguments: use the bare constructor term
                # instead of an empty application.
                if inner.args == []:
                    original_subgoal.args[0] = inner.func

                subgoals.append(Goal(subgoal, env))

            # Apply the induction theorem to the motive and the subgoal formulas.
            proof = Proof(
                Apply(thm_to_apply, p_hole,
                      *(map(lambda g: g.formula(), subgoals))), *subgoals)

            # NOTE(review): prints on every successful run; looks like leftover
            # debugging output -- confirm before removing.
            print(proof.proof_formula)
            return proof
        else:
            raise cls.TacticFailure
Exemple #18
0
 def assign_cmd(self, v, e):
     """Build an Assign command from variable name v and expression body e.

     v is converted with str_to_nat and wrapped in Nat; e is abstracted
     over the state variable st.
     """
     mk_assign = imp.Assign(NatType, NatType)
     target = Nat(str_to_nat(v))
     return mk_assign(target, Lambda(st, e))
Exemple #19
0
 def if_cmd(self, b, c1, c2):
     """Build a Cond command with guard b (abstracted over st) and branches c1, c2."""
     mk_cond = imp.Cond(natFunT)
     guard = Lambda(st, b)
     return mk_cond(guard, c1, c2)
Exemple #20
0
 def while_cmd(self, b, c):
     """Build a While command with guard b, body c and the constant-true invariant."""
     mk_while = imp.While(natFunT)
     guard = Lambda(st, b)
     trivial_inv = Lambda(st, true)
     return mk_while(guard, trivial_inv, c)
Exemple #21
0
 def while_cmd_inv(self, b, inv, c):
     """Build a While command with guard b, invariant inv and body c.

     Both b and inv are abstracted over the state variable st.
     """
     mk_while = imp.While(natFunT)
     guard = Lambda(st, b)
     invariant = Lambda(st, inv)
     return mk_while(guard, invariant, c)
Exemple #22
0
from syntax import parser

# Function type nat => nat; used below to specialise the imp constructors.
natFunT = TFun(NatType, NatType)
# imp constructors specialised once so the tests can use them directly.
Sem = imp.Sem(natFunT)
Skip = imp.Skip(natFunT)
Assign = imp.Assign(NatType, NatType)
Seq = imp.Seq(natFunT)
Cond = imp.Cond(natFunT)
While = imp.While(natFunT)
Valid = imp.Valid(natFunT)

# Short names for the naturals 0 and 1.
zero = nat.zero
one = nat.one

# Shared state variable, the constant-true assertion over it, and a command
# that assigns s(zero) + one to location zero.
s = Var("s", natFunT)
assn_true = Lambda(s, true)
incr_one = Assign(zero, Lambda(s, nat.plus(s(zero), one)))

def fun_upd_of_seq(*ns):
    """Return the constant-zero function updated with the given numbers.

    Every integer in ns is wrapped in Nat and the results are passed, in
    order, to mk_fun_upd after the constant-zero base function
    (presumably as alternating location/value pairs -- see mk_fun_upd).
    """
    base = mk_const_fun(NatType, zero)
    updates = [Nat(n) for n in ns]
    return mk_fun_upd(base, *updates)

class HoareTest(unittest.TestCase):
    # Test case built on the 'hoare' theory, which setUp loads before each test.
    def setUp(self):
        basic.load_theory('hoare')

    def testEvalSem(self):
        # Program: two assignments in sequence (location 0 := 1, location 1 := 2).
        com = Seq(Assign(zero, Lambda(s, one)), Assign(one, Lambda(s, Nat(2))))
        # Start from the all-zero state; st2 is the expected final state.
        st = mk_const_fun(NatType, zero)
        st2 = fun_upd_of_seq(0, 1, 1, 2)
        goal = Sem(com, st, st2)
        # Build the proof with the evaluation macro and export it.
        prf = imp.eval_Sem_macro().get_proof_term(goal, []).export()
Exemple #23
0
    def testPrintLogical(self):
        """Check printer.print_term on logical terms with unicode turned off."""
        # Each case pairs a term with its expected ASCII rendering.
        cases = [
            # Variables
            (SVar("P", BoolType), "?P"),
            (a, "a"),

            # Equality and implies
            (Eq(a, b), "a = b"),
            (Implies(A, B), "A --> B"),
            (Implies(A, B, C), "A --> B --> C"),
            (Implies(Implies(A, B), C), "(A --> B) --> C"),
            (Implies(A, Eq(a, b)), "A --> a = b"),
            (Eq(Implies(A, B), Implies(B, C)), "(A --> B) <--> (B --> C)"),
            (Eq(A, Eq(B, C)), "A <--> B <--> C"),
            (Eq(Eq(A, B), C), "(A <--> B) <--> C"),

            # Conjunction and disjunction
            (And(A, B), "A & B"),
            (Or(A, B), "A | B"),
            (And(A, And(B, C)), "A & B & C"),
            (And(And(A, B), C), "(A & B) & C"),
            (Or(A, Or(B, C)), "A | B | C"),
            (Or(Or(A, B), C), "(A | B) | C"),
            (Or(And(A, B), C), "A & B | C"),
            (And(Or(A, B), C), "(A | B) & C"),
            (Or(A, And(B, C)), "A | B & C"),
            (And(A, Or(B, C)), "A & (B | C)"),
            (Or(And(A, B), And(B, C)), "A & B | B & C"),
            (And(Or(A, B), Or(B, C)), "(A | B) & (B | C)"),

            # Negation
            (Not(A), "~A"),
            (Not(Not(A)), "~~A"),

            # Constants
            (true, "true"),
            (false, "false"),

            # Mixed
            (Implies(And(A, B), C), "A & B --> C"),
            (Implies(A, Or(B, C)), "A --> B | C"),
            (And(A, Implies(B, C)), "A & (B --> C)"),
            (Or(Implies(A, B), C), "(A --> B) | C"),
            (Not(And(A, B)), "~(A & B)"),
            (Not(Implies(A, B)), "~(A --> B)"),
            (Not(Eq(A, B)), "~(A <--> B)"),
            (Eq(Not(A), B), "~A <--> B"),
            (Eq(Not(A), Not(B)), "~A <--> ~B"),
            (Implies(A, Eq(B, C)), "A --> B <--> C"),
            (Eq(Implies(A, B), C), "(A --> B) <--> C"),

            # Abstraction
            (Lambda(a, And(P(a), Q(a))), "%a. P a & Q a"),

            # Quantifiers
            (Forall(a, P(a)), "!a. P a"),
            (Forall(a, Forall(b, And(P(a), P(b)))), "!a. !b. P a & P b"),
            (Forall(a, And(P(a), Q(a))), "!a. P a & Q a"),
            (And(Forall(a, P(a)), Q(a)), "(!a1. P a1) & Q a"),
            (Forall(a, Implies(P(a), Q(a))), "!a. P a --> Q a"),
            (Implies(Forall(a, P(a)), Q(a)), "(!a1. P a1) --> Q a"),
            (Implies(Forall(a, P(a)),
                     Forall(a, Q(a))), "(!a. P a) --> (!a. Q a)"),
            (Implies(Exists(a, P(a)),
                     Exists(a, Q(a))), "(?a. P a) --> (?a. Q a)"),
            (Eq(A, Forall(a, P(a))), "A <--> (!a. P a)"),
            (Exists(a, P(a)), "?a. P a"),
            (Exists(a, Forall(b, R(a, b))), "?a. !b. R a b"),
            (logic.mk_exists1(a, P(a)), "?!a. P a"),
            (logic.mk_the(a, P(a)), "THE a. P a"),
            (logic.mk_some(a, P(a)), "SOME a. P a"),
            (Forall(a, Exists(b, R(a, b))), "!a. ?b. R a b"),

            # If
            (mk_if(A, a, b), "if A then a else b"),
            (Eq(mk_if(A, a, b), a), "(if A then a else b) = a"),
            (mk_if(A, P, Q), "if A then P else Q"),
        ]

        with global_setting(unicode=False):
            for tm, expected in cases:
                printed = printer.print_term(tm)
                self.assertEqual(printed, expected)
Exemple #24
0
def get_nat_power_bounds(pt, n):
    """Given a theorem of the form t Mem I, derive one of the form t ^ n Mem J.

    The theorem applied depends on the kind of interval in pt and on
    where its bounds a, b lie relative to 0. A closed or open interval
    that falls in neither the a >= 0 nor the b <= 0 case is split at 0
    and handled recursively on each half. The result is passed through
    norm_mem_interval.

    Raises NotImplementedError if n is not a number or if no case applies.

    """
    a, b = get_mem_bounds(pt)
    if not n.is_number():
        raise NotImplementedError

    def lower_nonneg(th_name):
        # Case a >= 0: the theorem takes a proof of that fact, then pt.
        return apply_theorem(th_name,
                             auto.auto_solve(real_nonneg(a)),
                             pt,
                             inst=Inst(n=n))

    def upper_nonpos(even_th, odd_th):
        # Case b <= 0: the theorem depends on the parity of the exponent.
        int_n = n.dest_number()
        if int_n % 2 == 0:
            parity_pt = nat_as_even(int_n)
            th_name = even_th
        else:
            parity_pt = nat_as_odd(int_n)
            th_name = odd_th
        return apply_theorem(th_name, auto.auto_solve(real_nonpos(b)),
                             parity_pt, pt)

    def split_at_zero(mk_left, mk_right, split_th):
        # Remaining case: bound each half separately, combine with the
        # split theorem, then move to a single interval via subsetE.
        t = pt.prop.arg1
        left = hol_set.mk_mem(t, mk_left(a, Real(0)))
        right = hol_set.mk_mem(t, mk_right(Real(0), b))
        left_pt = get_nat_power_bounds(ProofTerm.assume(left),
                                       n).implies_intr(left)
        right_pt = get_nat_power_bounds(ProofTerm.assume(right),
                                        n).implies_intr(right)
        x = Var('x', RealType)
        union_pt = apply_theorem(split_th,
                                 auto.auto_solve(real.less_eq(a, Real(0))),
                                 auto.auto_solve(real.less_eq(Real(0), b)),
                                 left_pt,
                                 right_pt,
                                 pt,
                                 inst=Inst(x=t, f=Lambda(x, x**n)))
        subset_pt = interval_union_subset(union_pt.prop.arg)
        return apply_theorem('subsetE', subset_pt, union_pt)

    if eval_hol_expr(a) >= 0 and is_mem_closed(pt):
        res = lower_nonneg('nat_power_interval_pos_closed')
    elif eval_hol_expr(a) >= 0 and is_mem_open(pt):
        res = lower_nonneg('nat_power_interval_pos_open')
    elif eval_hol_expr(a) >= 0 and is_mem_lopen(pt):
        res = lower_nonneg('nat_power_interval_pos_lopen')
    elif eval_hol_expr(a) >= 0 and is_mem_ropen(pt):
        res = lower_nonneg('nat_power_interval_pos_ropen')
    elif eval_hol_expr(b) <= 0 and is_mem_closed(pt):
        res = upper_nonpos('nat_power_interval_neg_even_closed',
                           'nat_power_interval_neg_odd_closed')
    elif eval_hol_expr(b) <= 0 and is_mem_open(pt):
        res = upper_nonpos('nat_power_interval_neg_even_open',
                           'nat_power_interval_neg_odd_open')
    elif is_mem_closed(pt):
        # Closed interval containing 0
        res = split_at_zero(real.closed_interval, real.closed_interval,
                            'split_interval_closed')
    elif is_mem_open(pt):
        # Open interval containing 0; the right half includes 0.
        res = split_at_zero(real.open_interval, real.ropen_interval,
                            'split_interval_open')
    else:
        raise NotImplementedError
    return norm_mem_interval(res)
Exemple #25
0
    def match(pat, t):
        """Match pattern pat against term t, extending inst in place.

        The pair (pat, t) is pushed on trace on entry and popped again
        when the match succeeds. A failed match raises
        MatchException(trace); a pattern of an unrecognised form raises
        TypeError.

        trace, inst and bd_vars are not defined in this function; they
        presumably belong to the enclosing scope -- confirm against the
        surrounding code.
        """
        trace.append((pat, t))
        if pat.head.is_svar():
            # Case where the head of the function is a variable.
            if pat.head.name not in inst:
                # If the head variable is not instantiated, check that the
                # arguments are distinct, and each argument is either a
                # bound variable or a matched variable. In addition, all bound
                # variables appearing in t also appear as an argument.
                # If all conditions hold, assign appropriately.

                heuristic_match = False
                # Check each argument is either a bound variable or is a
                # schematic variable that is already matched.
                for v in pat.args:
                    if not (v in bd_vars or (v.is_svar() and v.name in inst)):
                        heuristic_match = True

                # Check arguments of pat are distinct.
                if len(set(pat.args)) != len(pat.args):
                    heuristic_match = True

                # Check t does not contain any extra bound variables.
                t_vars = t.get_vars()
                if any(v in t_vars and v not in pat.args for v in bd_vars):
                    heuristic_match = True

                if heuristic_match:
                    # Heuristic matching: just assign pat.fun to t.fun.
                    if pat.is_svar():
                        # t contains bound variables, so match fails
                        raise MatchException(trace)
                    elif t.is_comb():
                        try:
                            pat.head.T.match_incr(t.fun.get_type(),
                                                  inst.tyinst)
                        except TypeMatchException:
                            raise MatchException(trace)
                        inst[pat.head.name] = t.fun
                        match(pat.arg, t.arg)
                    else:
                        raise MatchException(trace)
                else:
                    # First, obtain and match the expected type of pat_T.
                    Tlist = []
                    for v in pat.args:
                        if v in bd_vars:
                            Tlist.append(v.T)
                        else:
                            Tlist.append(inst[v.name].get_type())
                    Tlist.append(t.get_type())
                    try:
                        pat.head.T.match_incr(TFun(*Tlist), inst.tyinst)
                    except TypeMatchException:
                        raise MatchException(trace)

                    # The instantiation of the head variable is computed by starting
                    # with t, then abstract each of the arguments.
                    inst_t = t
                    for v in reversed(pat.args):
                        if v in bd_vars:
                            # inst_t has the form f v with v not occurring in f.
                            if inst_t.is_comb(
                            ) and inst_t.arg == v and v not in inst_t.fun.get_vars(
                            ):
                                op_data = operator.get_info_for_fun(
                                    inst_t.head)
                                if inst_t.is_comb("IF", 3):
                                    inst_t = Lambda(v, inst_t)
                                elif op_data is None:
                                    # inst_t is of the form f x, where x is the argument.
                                    # In this case, directly reduce to f.
                                    inst_t = inst_t.fun
                                elif op_data.arity == operator.BINARY and len(
                                        inst_t.args) == 2:
                                    inst_t = Lambda(v, inst_t)
                                else:
                                    inst_t = inst_t.fun
                            else:
                                # Otherwise, perform the abstraction.
                                inst_t = Lambda(v, inst_t)
                        else:
                            # v is a schematic variable that is already matched.
                            assert v.name in inst
                            inst_v = inst[v.name]
                            if inst_t.is_comb(
                            ) and inst_t.arg == inst_v and not find_term(
                                    inst_t.fun, inst_v):
                                inst_t = inst_t.fun
                            elif inst_v.is_var():
                                inst_t = Lambda(inst_v, inst_t)
                            else:
                                raise MatchException(trace)
                    inst[pat.head.name] = inst_t
            else:
                # If the head variable is already instantiated, apply the
                # instantiation onto the arguments, simplify using beta-conversion,
                # and match again.
                pat2 = inst[pat.head.name](*pat.args).beta_norm()
                match(pat2, t.beta_norm())
        elif pat.is_var() or pat.is_const():
            # The case where pat is a free variable, constant, or comes
            # from a bound variable.
            if pat.ty != t.ty or pat.name != t.name:
                raise MatchException(trace)
            else:
                try:
                    pat.T.match_incr(t.T, inst.tyinst)
                except TypeMatchException:
                    raise MatchException(trace)
        elif pat.is_comb():
            # In the combination case (where the head is not a variable),
            # match fun and arg.
            if pat.ty != t.ty:
                raise MatchException(trace)
            # Match the function part first when is_pattern accepts it;
            # otherwise match the argument first.
            if is_pattern(pat.fun,
                          list(inst.keys()),
                          bd_vars=[v.name for v in bd_vars]):
                match(pat.fun, t.fun)
                match(pat.arg, t.arg)
            else:
                match(pat.arg, t.arg)
                match(pat.fun, t.fun)
        elif pat.is_abs():
            # When pat is a lambda term, t must also be a lambda term.
            # Replace bound variable by a variable, then match the body.
            if t.is_abs():
                try:
                    pat.var_T.match_incr(t.var_T, inst.tyinst)
                except TypeMatchException:
                    raise MatchException(trace)
                T = pat.var_T.subst(inst.tyinst)

                # Pick a name that clashes with no variable in either body.
                var_names = [
                    v.name for v in pat.body.get_vars() + t.body.get_vars()
                ]
                nm = name.get_variant_name(pat.var_name, var_names)
                v = Var(nm, T)
                pat_body = pat.subst_type(inst.tyinst).subst_bound(v)
                t_body = t.subst_bound(v)
                # v counts as a bound variable only while the bodies are matched.
                bd_vars.append(v)
                match(pat_body, t_body)
                bd_vars.pop()
            else:
                # t is not an abstraction: it must have function type. Wrap it
                # as an abstraction applying t to the bound variable and retry.
                tT = t.get_type()
                if not tT.is_fun():
                    raise MatchException(trace)
                try:
                    pat.var_T.match_incr(tT.domain_type(), inst.tyinst)
                except TypeMatchException:
                    raise MatchException(trace)
                T = pat.var_T.subst(inst.tyinst)
                match(pat, Abs(pat.var_name, T, t(Bound(0))))
        elif pat.is_bound():
            raise MatchException(trace)
        else:
            raise TypeError

        # Reached only on success; failures leave the trace intact.
        trace.pop()