def setup_split_search (rep, head, restrs, hyps, i_opts, j_opts,
        unfold_limit = None, tags = None):
    p = rep.p
    if not tags:
        tags = p.pairing.tags
    if unfold_limit == None:
        unfold_limit = max ([start + (2 * step) + 1
            for (start, step) in i_opts + j_opts])

    trace ('Split search at %d, unfold limit %d.' % (head, unfold_limit))

    l_tag, r_tag = tags
    loop_elts = [(n, start, step) for n in p.splittable_points (head)
        for (start, step) in i_opts]
    init_to_split = init_loops_to_split (p, restrs)
    r_to_split = [n for n in init_to_split if p.node_tags[n][0] == r_tag]
    cand_r_loop_elts = [(n2, start, step) for n in r_to_split
        for n2 in p.splittable_points (n)
        for (start, step) in j_opts]

    err_restrs = restr_others (p, tuple ([(sp, vc_upto (unfold_limit))
        for sp in r_to_split]) + restrs, 1)
    nrerr_pc = mk_not (rep.get_pc (('Err', err_restrs), tag = r_tag))

    def get_pc (n, k):
        head = p.loop_id (n)
        assert head in init_to_split
        if n != head:
            k += 1
        restrs2 = restrs + ((head, vc_num (k)), )
        return rep.get_pc ((n, restrs2))

    for n in r_to_split:
        get_pc (n, unfold_limit)
    get_pc (head, unfold_limit)

    premise = foldr1 (mk_and, [nrerr_pc] + map (rep.interpret_hyp, hyps))
    premise = logic.weaken_assert (premise)

    knowledge = SearchKnowledge (rep,
        'search at %d (unfold limit %d)' % (head, unfold_limit),
        restrs, hyps, tags, (loop_elts, cand_r_loop_elts))
    knowledge.premise = premise
    last_knowledge[0] = knowledge

    # make sure the representation is in sync
    rep.test_hyp_whyps (true_term, hyps)

    # make sure all mem eqs are being tracked
    mem_vs = [v for v in knowledge.v_ids if v[0].typ == builtinTs['Mem']]
    for (i, v) in enumerate (mem_vs):
        for v2 in mem_vs[:i]:
            for pred in expand_var_eqs (knowledge, (v, v2)):
                smt_expr (pred, {}, rep.solv)
    for v in knowledge.v_ids:
        for pred in expand_var_eqs (knowledge, (v, 'Const')):
            smt_expr (pred, {}, rep.solv)

    return knowledge

def find_split_limit (p, n, restrs, hyps, kind, bound = 51,
        must_find = True, hints = [], use_rep = None):
    tag = p.node_tags[n][0]
    trace ('Finding split limit: %d (%s) %s' % (n, tag, restrs))
    trace (' (restrs = %s)' % (restrs, ))
    trace (' (hyps = %s)' % (hyps, ), push = 1)
    if use_rep == None:
        rep = mk_graph_slice (p, fast = True)
    else:
        rep = use_rep
    check_order = hints + [i for i in split_sample_set (bound)
        if i not in hints]
    for i in check_order:
        restrs2 = restrs + ((n, VisitCount (kind, i)), )
        pc = rep.get_pc ((n, restrs2))
        restrs3 = restr_others (p, restrs2, 2)
        epc = rep.get_pc (('Err', restrs3), tag = tag)
        hyp = mk_implies (mk_not (epc), mk_not (pc))
        if rep.test_hyp_whyps (hyp, hyps):
            trace ('split limit found: %d' % i, push = -1)
            return i

    trace ('No split limit found for %d (%s).' % (n, tag), push = -1)
    if must_find:
        assert not 'split limit found'
    return None

def hyps_add_model (self, hyps, assert_progress = True):
    if hyps:
        test_expr = foldr1 (mk_and, hyps)
    else:
        # we want to learn something, either a new model, or
        # that all hyps are true. if there are no hyps,
        # learning they're all true is learning nothing.
        # instead force a model
        test_expr = false_term
    test_expr = mk_implies (self.premise, test_expr)
    m = {}
    (r, _) = self.rep.solv.parallel_check_hyps ([(1, test_expr)],
        {}, model = m)
    if r == 'unsat':
        if not hyps:
            trace ('WARNING: SearchKnowledge: premise unsat.')
            trace (" ... learning procedure isn't going to work.")
            return
        if assert_progress:
            assert not (set (hyps) <= self.facts), hyps
        for hyp in hyps:
            self.facts.add (hyp)
    else:
        assert r == 'sat', r
        self.add_model (m)
        if assert_progress:
            assert self.model_trace[-2:-1] != [m]

def save_bound (glob, split_bin_addr, call_ctxt, prob_hash, prev_bounds,
        bound, time = None):
    f_names = [trace_refute.get_body_addrs_fun (x)
        for x in call_ctxt + [split_bin_addr]]
    loop_name = '<%s>' % ' -> '.join (f_names)
    comment = '# bound for loop in %s:' % loop_name
    ss = ['LoopBound'] + serialise_bound (split_bin_addr, bound)
    if glob:
        ss[0] = 'GlobalLoopBound'
    ss += [str (len (call_ctxt))] + map (hex, call_ctxt)
    ss += [str (prob_hash)]
    if glob:
        assert prev_bounds == None
    else:
        ss += [str (len (prev_bounds))]
        for (split, bound) in prev_bounds:
            ss += serialise_bound (split, bound)
    s = ' '.join (ss)
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a')
    f.write (comment + '\n')
    f.write (s + '\n')
    if time != None:
        ctxt2 = call_ctxt + [split_bin_addr]
        ctxt2 = ' '.join ([str (len (ctxt2))] + map (hex, ctxt2))
        f.write ('LoopBoundTiming %s %s\n' % (ctxt2, time))
    f.close ()
    trace ('Found bound %s for 0x%x in %s.' % (bound, split_bin_addr,
        loop_name))

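# A sketch of the records save_bound appends to LoopBounds.txt. The
# exact tokens produced by serialise_bound are defined elsewhere, so
# the '<bound tokens>' parts below are placeholders, and all addresses
# and hashes are made up:
#
#   # bound for loop in <f1 -> f2>:
#   LoopBound <bound tokens> 2 0x1200 0x1340 <prob_hash> 1 <prev bound tokens>
#   LoopBoundTiming 3 0x1200 0x1340 0x4f0 42.5
#
# With glob=True the record starts with 'GlobalLoopBound' instead and
# carries no prev_bounds section.
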
def do_loop_analysis(self):
    entries = [e for (e, tag, nm, args) in self.entries]
    self.loop_data = {}

    graph = self.mk_node_graph()
    comps = logic.tarjan(graph, entries)
    self.tarjan_order = []
    for (head, tail) in comps:
        self.tarjan_order.append(head)
        self.tarjan_order.extend(tail)
        if not tail and head not in graph[head]:
            continue
        trace("Loop (%d, %s)" % (head, tail))
        loop_set = set(tail)
        loop_set.add(head)
        r = self.force_single_loop_return(head, loop_set)
        if r != None:
            tail.append(r)
            loop_set.add(r)
            self.tarjan_order.append(r)
            self.compute_preds()
        self.loop_data[head] = ("Head", loop_set)
        for t in tail:
            self.loop_data[t] = ("Mem", head)

    # put this in first-to-last order.
    self.tarjan_order.reverse()

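# A minimal sketch of the loop_data shape this produces, using a
# hand-written SCC list in place of logic.tarjan (graph and comps here
# are hypothetical toy data):
def _loop_data_demo():
    comps = [(1, []), (2, [3, 4]), (5, [])]   # (head, tail) per SCC
    graph = {1: [2], 2: [3], 3: [4], 4: [2], 5: []}
    loop_data = {}
    for (head, tail) in comps:
        if not tail and head not in graph[head]:
            continue   # trivial SCC with no self-edge: not a loop
        loop_data[head] = ("Head", set([head] + tail))
        for t in tail:
            loop_data[t] = ("Mem", head)
    return loop_data

# _loop_data_demo() == {2: ("Head", set([2, 3, 4])),
#                       3: ("Mem", 2), 4: ("Mem", 2)}
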
def search_bin_bound (p, restrs, hyps, split):
    trace ('Searching for bound for 0x%x in %s.' % (split, p.name))
    bound = search_bound (p, restrs, hyps, split)
    if bound:
        return bound

    # try to use a bound inferred from C
    if avoid_C_information[0]:
        # OK, told not to
        return None
    if get_prior_loop_heads (p, split):
        # too difficult for now
        return None
    asm_tag = p.node_tags[split][0]
    (_, fname, _) = p.get_entry_details (asm_tag)
    funs = [f for pair in target_objects.pairings[fname]
        for f in pair.funs.values ()]
    c_tags = [tag for tag in p.tags ()
        if p.get_entry_details (tag)[1] in funs and tag != asm_tag]
    if len (c_tags) != 1:
        print 'Surprised to see multiple matching tags %s' % c_tags
        return None

    [c_tag] = c_tags

    return getBinaryBoundFromC (p, c_tag, split, restrs, hyps)

def save_compiled_funcs (fname):
    out = open (fname, 'w')
    for (f, func) in functions.iteritems ():
        trace ('Saving %s' % f)
        for s in func.serialise ():
            out.write (s + '\n')
    out.close ()

def v_eqs_to_split (p, pair, v_eqs, restrs, hyps, tags = None):
    trace ('v_eqs_to_split: (%s, %s)' % pair)

    ((l_n, l_init, l_step), (r_n, r_init, r_step)) = pair
    l_details = (l_n, (l_init, l_step), mk_seq_eqs (p, l_n, l_step, True)
        + [v_i[0] for (v_i, v_j) in v_eqs if v_j == 'Const'])
    r_details = (r_n, (r_init, r_step), mk_seq_eqs (p, r_n, r_step, False)
        + c_memory_loop_invariant (p, r_n, l_n))

    eqs = [(v_i[0], mk_cast (v_j[0], v_i[0].typ))
        for (v_i, v_j) in v_eqs if v_j != 'Const']

    n = 2
    split = (l_details, r_details, eqs, n, (n * r_step) - 1)
    trace ('Split: %s' % (split, ))
    if tags == None:
        tags = p.pairing.tags
    hyps = hyps + check.split_loop_hyps (tags, split, restrs, exit = True)

    r_max = find_split_limit (p, r_n, restrs, hyps, 'Offset',
        bound = (n + 2) * r_step, must_find = False,
        hints = [n * r_step, n * r_step + 1])
    if r_max == None:
        trace ('v_eqs_to_split: no RHS limit')
        return None

    if r_max > n * r_step:
        trace ('v_eqs_to_split: RHS limit not %d' % (n * r_step))
        return None

    trace ('v_eqs_to_split: split %s' % (split,))
    return split

def load_proofs_from_file (fname):
    f = open (fname)
    proofs = {}

    lines = None
    for line in f:
        line = line.strip ()
        if line.startswith ('ProblemProof'):
            assert line.endswith ('{'), line
            name_bit = line[len ('ProblemProof') : -1].strip ()
            assert name_bit.startswith ('('), name_bit
            assert name_bit.endswith (')'), name_bit
            name = name_bit[1:-1]
            lines = []
        elif line == '}':
            assert lines[0] == 'Problem'
            assert lines[-2] == 'EndProblem'
            import problem
            trace ('loading proof from %d lines' % len (lines))
            p = problem.deserialise (name, lines[:-1])
            proof = deserialise (lines[-1])
            proofs.setdefault (name, [])
            proofs[name].append ((p, proof))
            trace ('loaded proof %s' % name)
            lines = None
        elif line.startswith ('#'):
            pass
        elif line:
            lines.append (line)
    assert not lines
    return proofs

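# The proof file format this parser accepts, reconstructed from the
# assertions above (names and payload lines are illustrative):
#
#   ProblemProof (<name>) {
#   Problem
#   ... serialised problem lines ...
#   EndProblem
#   <one serialised proof line>
#   }
#
# '#'-comment lines and blank lines are ignored anywhere in the file.
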
def test_hyp_whyps (self, hyp, hyps, cache = None, fast = False,
        model = None):
    self.avail_hyps = set (hyps)
    if not self.used_hyps <= self.avail_hyps:
        self.rebuild ()

    last_test[0] = (hyp, hyps, list (self.pc_env_requests))

    expr = self.interpret_hyp_imps (hyps, hyp)

    trace ('Testing hyp whyps', push = 1)
    trace ('requests = %s' % self.pc_env_requests)

    expr_s = smt_expr (expr, {}, self.solv)
    if cache and expr_s in cache:
        trace ('Cached: %s' % cache[expr_s])
        return cache[expr_s]
    if fast:
        trace ('(not in cache)')
        return None

    self.solv.add_pvalid_dom_assertions ()

    if model == None:
        (result, _) = self.solv.parallel_test_hyps (
            [(None, expr)], {})
    else:
        result = self.solv.test_hyp (expr, {}, model = model)
    trace ('Result: %s' % result, push = -1)

    if cache != None:
        cache[expr_s] = result
    if not result:
        last_failed_test[0] = last_test[0]
    return result

def expr_eval_before (n2, expr):
    if expr.kind == 'Op':
        if expr.vals == []:
            return (set(), expr)
        vals = [expr_eval_before (n2, v) for v in expr.vals]
        if None in vals:
            return None
        s = set.union (* [s for (s, v) in vals])
        if len(s) > 1:
            if diag:
                trace ('too many vars for %s @ %d: %s'
                    % (expr, n2, s))
            return None
        return (s, Expr ('Op', expr.typ, name = expr.name,
            vals = [v for (s, v) in vals]))
    elif expr.kind == 'Num':
        return (set(), expr)
    elif expr.kind == 'Var':
        return var_eval_before (n2, (expr.name, expr.typ))
    else:
        if diag:
            trace ('Unwalkable expr %s' % expr)
        return None

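# A self-contained sketch of the same walk over a toy expression type
# (the tuples here are hypothetical stand-ins for the real Expr
# objects): collect the set of variables feeding an expression, and
# give up as soon as more than one distinct variable is involved.
def _walk_demo (expr):
    (kind, payload) = expr
    if kind == 'Num':
        return (set (), expr)
    elif kind == 'Var':
        return (set ([payload]), expr)
    elif kind == 'Op':
        (name, vals) = payload
        walked = [_walk_demo (v) for v in vals]
        if None in walked:
            return None
        s = set.union (set (), * [s2 for (s2, v) in walked])
        if len (s) > 1:
            return None
        return (s, expr)

# _walk_demo (('Op', ('Plus', [('Var', 'i'), ('Num', 1)])))
#   ==> (set (['i']), ...)   single-variable: walkable
# _walk_demo (('Op', ('Plus', [('Var', 'i'), ('Var', 'j')])))
#   ==> None                 two variables: rejected
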
def build_proof_rec (searcher, p, restrs, hyps):
    trace ('doing build proof rec with restrs = %r, hyps = %r'
        % (restrs, hyps))

    (kind, details) = searcher (p, restrs, hyps)
    last_searcher_results.append ((p, restrs, hyps, kind, details))
    del last_searcher_results[:-10]
    trace ('proof searcher found %s, %s' % (kind, details))
    if kind == 'Restr':
        (restr_kind, restr_points) = details
        return build_proof_rec_with_restrs (restr_points, restr_kind,
            searcher, p, restrs, hyps)
    elif kind == 'Leaf':
        return ProofNode ('Leaf', None, ())
    assert kind in ['CaseSplit', 'Split']
    split = details
    [(_, hyps1, _), (_, hyps2, _)] = check.proof_subproblems (p, kind,
        split, restrs, hyps, '')
    if kind == 'CaseSplit':
        return ProofNode ('CaseSplit', split,
            [build_proof_rec (searcher, p, restrs, hyps1),
                build_proof_rec (searcher, p, restrs, hyps2)])
    split_points = check.split_heads (split)
    no_loop_proof = build_proof_rec_with_restrs (split_points,
        'Number', searcher, p, restrs, hyps1)
    loop_proof = build_proof_rec_with_restrs (split_points,
        'Offset', searcher, p, restrs, hyps2)
    return ProofNode ('Split', split,
        [no_loop_proof, loop_proof])

def find_split_loop (p, head, restrs, hyps, unfold_limit = 9):
    assert p.loop_data[head][0] == 'Head'
    assert p.node_tags[head][0] == p.pairing.tags[0]
    # the idea is to loop through testable hyps, starting with ones that
    # need smaller models (the most unfolded models will time out for
    # large problems like finaliseSlot)

    rep = mk_graph_slice (p, fast = True)

    nec = get_necessary_split_opts (p, head, restrs, hyps)
    if nec and nec[0] == 'CaseSplit':
        return nec
    elif nec:
        i_j_opts = nec
    else:
        i_j_opts = default_i_j_opts (unfold_limit)

    ind_fails = []
    for (i_opts, j_opts) in i_j_opts:
        result = find_split (rep, head, restrs, hyps,
            i_opts, j_opts)
        if result[0] != None:
            return result
        ind_fails.extend (result[1])

    if ind_fails:
        trace ('Warning: inductive failures: %s' % ind_fails)
    raise NoSplit ()

def build_rodata (rodata_stream):
    rodata = {}
    for line in rodata_stream:
        if not is_rodata_line.match (line):
            continue
        bits = line.split ()
        rodata[int (bits[0][:-1], 16)] = int (bits[1], 16)

    rodata_min = min (rodata.keys ())
    rodata_max = max (rodata.keys ()) + 4

    assert rodata_min % 4 == 0
    rodata_range = range (rodata_min, rodata_max, 4)
    for x in rodata_range:
        if x not in rodata:
            trace ('.rodata section gap at address %x' % x)

    struct_name = fresh_name ('rodata', structs)
    struct = Struct (struct_name, rodata_max - rodata_min, 1)
    structs[struct_name] = struct

    (start, end) = sections['.rodata']
    assert start <= rodata_min
    assert end + 1 >= rodata_max

    return (rodata, mk_word32 (rodata_min), struct.typ)

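# build_rodata reads objdump-style section dump lines; only lines
# matching is_rodata_line (a regex defined elsewhere) are used. Each
# used line carries a hex address with a trailing colon and a 32-bit
# hex word, e.g. (values here are made up):
#
#     8a40:   00012c3f
#     8a44:   0000ffff
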
def parse_all(lines):
    '''Toplevel parser for input information. Accepts an iterator over
    lines. See syntax.quick_reference for an explanation.'''
    structs = {}
    functions = {}
    const_globals = {}
    for line in lines:
        bits = line.split()
        # empty lines and #-comments ignored
        if not bits or bits[0][0] == '#':
            continue
        if bits[0] == 'Struct':
            # Struct <name> <size> <alignment>
            # followed by block of StructField lines
            assert bits[1] not in structs
            current_struct = Struct (bits[1], parse_int (bits[2]),
                parse_int (bits[3]))
            structs[bits[1]] = current_struct
        elif bits[0] == 'StructField':
            # StructField <name> <type (encoded)> <offset>
            (_, typ) = parse_typ(bits, 2, symbolic_types = True)
            current_struct.add_field (bits[1], typ,
                parse_int (bits[-1]))
        elif bits[0] == 'ConstGlobalDef':
            # ConstGlobalDef <name> <value>
            name = bits[1]
            (_, val) = parse_expr (bits, 2)
            const_globals[name] = val
        elif bits[0] == 'Function':
            # Function <name> <inputs> <outputs>
            # followed by optional block of node lines
            # concluded by EntryPoint line
            (n, inputs) = parse_list (parse_arg, bits, 2)
            (_, outputs) = parse_list (parse_arg, bits, n)
            trace ('Function %s' % bits[1])
            current_function = Function (bits[1], inputs, outputs)
            functions[bits[1]] = current_function
        elif bits[0] == 'EntryPoint':
            # EntryPoint <entry point>
            entry = node_name(bits[1])
            # instead of setting function.entry to this value,
            # create a dummy node. this ensures there is always
            # at least one node (EntryPoint Ret is valid) and
            # also that the entry point is not in a loop
            name = fresh_node (current_function.nodes)
            current_function.nodes[name] = Node ('Basic', entry, [])
            current_function.entry = name
            # ensure that the function graph is closed
            check_cfg (current_function)
            current_function = None
        else:
            # <node name> <node (encoded)>
            name = node_name(bits[0])
            assert name not in current_function.nodes, (name, bits)
            current_function.nodes[name] = parse_node (bits, 1)
    return (structs, functions, const_globals)

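# A sketch of the overall input layout this accepts, assembled from the
# per-keyword comments above; the <...> encodings are placeholders, see
# syntax.quick_reference for the real grammar:
#
#   Struct <name> <size> <alignment>
#   StructField <name> <type (encoded)> <offset>
#   ConstGlobalDef <name> <value>
#   Function <name> <inputs> <outputs>
#   <node name> <node (encoded)>
#   ...
#   EntryPoint <entry point>
#
# Struct opens a block of StructField lines; Function opens an optional
# block of node lines, closed by its EntryPoint line.
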
def build_proof (p):
    init_hyps = check.init_point_hyps (p)
    proof = build_proof_rec (default_searcher, p, (), list (init_hyps))

    trace ('Built proof for %s' % p.name)
    printout (repr (proof))
    last_proof[0] = proof

    return proof

def toplevel_check (pair, check_loops = True, report = False,
        count = None, only_build_problem = False):
    if not only_build_problem:
        printout ('Testing Function pair %s' % pair)
    if count and not only_build_problem:
        (i, n) = count
        printout (' (function pairing %d of %d)' % (i + 1, n))

    for (tag, fname) in pair.funs.iteritems ():
        if not functions[fname].entry:
            printout ('Skipping %s, underspecified %s' % (pair, tag))
            return 'None'
    prev_tracer = tracer[0]
    if report:
        tracer[0] = lambda s, n: ()

    exception = None

    trace (time.asctime ())
    start_time = time.time()
    sys.stdout.flush ()
    try:
        p = check.build_problem (pair)
        if only_build_problem:
            tracer[0] = prev_tracer
            return 'True'
        if report:
            printout (' .. built problem, finding proof')
        if not check_loops and p.loop_data:
            printout ('Problem has loop!')
            tracer[0] = prev_tracer
            return 'Loop'
        if check_loops == 'only' and not p.loop_data:
            printout ('No loop in problem.')
            tracer[0] = prev_tracer
            return 'NoLoop'
        proof = search.build_proof (p)
        if report:
            printout (' .. proof found.')

        try:
            if report:
                result = check.check_proof_report (p, proof)
            else:
                result = check.check_proof (p, proof)
                if result:
                    printout ('Refinement proven.')
                else:
                    printout ('Refinement NOT proven.')
        except solver.SolverFailure, e:
            printout ('Solver timeout/failure in proof check.')
            result = 'CheckSolverFailure'
        except Exception, e:
            trace ('EXCEPTION in checking %s:' % p.name)
            exception = sys.exc_info ()
            result = 'CheckEXCEPT'

def get_bound_ctxt(split, call_ctxt, use_cache=True):
    trace('Getting bound for 0x%x in context %s.' % (split, call_ctxt))
    (p, hyps, addr_map) = get_call_ctxt_problem(split, call_ctxt)

    orig_split = split
    split = p.loop_id(addr_map[split])
    assert split, (orig_split, call_ctxt)
    split_bin_addr = min([addr for addr in addr_map
        if p.loop_id(addr_map[addr]) == split])

    prior = get_prior_loop_heads(p, split)
    restrs = ()
    prev_bounds = []
    for split2 in prior:
        # recursion!
        split2 = p.loop_id(split2)
        assert split2
        addr = min([addr for addr in addr_map
            if p.loop_id(addr_map[addr]) == split2])
        bound = get_bound_ctxt(addr, call_ctxt)
        prev_bounds.append((addr, bound))
        k = (p.name, split2, bound, restrs, tuple(hyps))
        if k in known_bound_restr_hyps:
            (restrs, hyps) = known_bound_restr_hyps[k]
        else:
            (restrs, hyps) = add_loop_bound_restrs_hyps(p, restrs, hyps,
                split2, bound, call_ctxt + [orig_split])
            known_bound_restr_hyps[k] = (restrs, hyps)

    # start timing now. we miss some setup time, but it avoids double
    # counting the recursive searches.
    start = time.time()

    p_h = problem_hash(p)
    prev_bounds = sorted(prev_bounds)
    if not known_bounds:
        load_bounds()
    known = known_bounds.get(split_bin_addr, [])
    for (call_ctxt2, h, prev_bounds2, bound) in known:
        match = (not call_ctxt2
            or call_ctxt[- len(call_ctxt2):] == call_ctxt2)
        if match and use_cache and h == p_h and prev_bounds2 == prev_bounds:
            return bound
    bound = search_bin_bound(p, restrs, hyps, split)
    known = known_bounds.setdefault(split_bin_addr, [])
    known.append((call_ctxt, p_h, prev_bounds, bound))
    end = time.time()
    save_bound(False, split_bin_addr, call_ctxt, p_h, prev_bounds,
        bound, time=end - start)
    return bound

def test_hyp_group (rep, group):
    imps = [(hyps, hyp) for (hyps, hyp, _) in group]
    names = set ([name for (_, _, name) in group])

    trace ('Testing group of hyps: %s' % list (names), push = 1)
    (res, i) = rep.test_hyp_imps (imps)
    trace ('Group result: %r' % res, push = -1)

    if res:
        return (res, None)
    else:
        return (res, group[i])

def get_ptr_offsets(p, n_ptrs, bases, hyps=[], cache=None,
        fail_early=False):
    """detect which ptrs are guaranteed to be at constant offsets
    from some set of basis ptrs"""
    rep = rep_graph.mk_graph_slice(p, fast=True)
    if cache == None:
        cache = {}
    last_get_ptr_offsets[0] = (p, n_ptrs, bases, hyps)

    smt_bases = []
    for (n, ptr, k) in bases:
        n_vc = default_n_vc(p, n)
        (_, env) = rep.get_node_pc_env(n_vc)
        smt = solver.smt_expr(ptr, env, rep.solv)
        smt_bases.append((smt, k))
        ptr_typ = ptr.typ

    smt_ptrs = []
    for (n, ptr) in n_ptrs:
        n_vc = default_n_vc(p, n)
        pc_env = rep.get_node_pc_env(n_vc)
        if not pc_env:
            continue
        smt = solver.smt_expr(ptr, pc_env[1], rep.solv)
        hyp = rep_graph.pc_true_hyp((n_vc, p.node_tags[n][0]))
        smt_ptrs.append(((n, ptr), smt, hyp))

    hyps = hyps + mk_not_callable_hyps(p)
    for tag in set([p.node_tags[n][0] for (n, _) in n_ptrs]):
        hyps = hyps + init_correctness_hyps(p, tag)

    tags = set([p.node_tags[n][0] for (n, ptr) in n_ptrs])
    ex_defs = {}
    for t in tags:
        ex_defs.update(get_extra_sp_defs(rep, t))

    offs = []
    for (v, ptr, hyp) in smt_ptrs:
        off = None
        for (ptr2, k) in smt_bases:
            off = offs_expr_const(ptr, ptr2, rep, [hyp] + hyps,
                cache=cache, extra_defs=ex_defs, typ=ptr_typ)
            if off != None:
                offs.append((v, off, k))
                break
        if off == None:
            trace('get_ptr_offs fallthrough at %d: %s' % v)
            trace(str([hyp] + hyps))
            assert not fail_early, (v, ptr)
    return offs

def compute_recursion_idents (group, extra_unfolds):
    idents = {}
    group = set (group)
    recursion_trace.append ('Computing for group %s' % group)
    trace ('Computing recursion idents for group %s' % group)

    prevs = set ([f for f in functions
        if [f2 for f2 in functions[f].function_calls ()
            if f2 in group]])

    for f in prevs - group:
        recursion_trace.append (' checking for %s' % f)
        trace ('Checking idents for %s' % f)
        while add_recursion_ident (f, group, idents, extra_unfolds):
            pass

    return idents

def var_eval_after (n2, v):
    node = nodes[n2]
    if node.kind == 'Call' and v in node.rets:
        if diag:
            trace ('fetched %s from call at %d' % (v, n2))
        return None
    elif node.kind == 'Basic':
        for (lv, val) in node.upds:
            if lv == v:
                return expr_eval_before (n2, val)
        return var_eval_before (n2, v)
    else:
        return var_eval_before (n2, v)

def test_hyp_group(rep, group, detail=None):
    imps = [(hyps, hyp) for (hyps, hyp, _) in group]
    names = set([name for (_, _, name) in group])

    trace('Testing group of hyps: %s' % list(names), push=1)
    (res, i, res_kind) = rep.test_hyp_imps(imps)
    trace('Group result: %r' % res, push=-1)

    if res:
        return (res, None)
    else:
        if detail:
            detail[0] = res_kind
        return (res, group[i])

def inline_no_pre_pairing(p):
    # FIXME: handle code sharing with check.inline_completely_unmatched
    while True:
        ns = [n for n in p.nodes if p.nodes[n].kind == 'Call'
            if p.nodes[n].fname not in pre_pairings
            if not is_instruction(p.nodes[n].fname)]
        for n in ns:
            trace('Inlining %s at %d.' % (p.nodes[n].fname, n))
            problem.inline_at_point(p, n)
        if not ns:
            return

def eval_model (m, s, toplevel = None):
    if s in m:
        return m[s]
    if toplevel == None:
        toplevel = s
    if type (s) == str:
        try:
            result = solver.smt_to_val (s)
        except Exception, e:
            trace ('Error with eval_model')
            trace (toplevel)
            raise e
        return result

def toplevel_check (pair, check_loops = True, report = False,
        count = None):
    printout ('Testing Function pair %s' % pair)
    if count:
        (i, n) = count
        printout (' (function pairing %d of %d)' % (i + 1, n))

    for (tag, fname) in pair.funs.iteritems ():
        if not functions[fname].entry:
            printout ('Skipping %s, underspecified %s' % (pair, tag))
            return 'None'
    prev_tracer = tracer[0]
    if report:
        tracer[0] = lambda s, n: ()

    exception = None

    trace (time.asctime ())
    start_time = time.time()
    sys.stdout.flush ()
    try:
        p = check.build_problem (pair)
        if report:
            printout (' .. built problem, finding proof')
        if not check_loops and p.loop_data:
            printout ('Problem has loop!')
            tracer[0] = prev_tracer
            return 'Loop'
        if check_loops == 'only' and not p.loop_data:
            printout ('No loop in problem.')
            tracer[0] = prev_tracer
            return 'NoLoop'
        proof = search.build_proof (p)
        if report:
            printout (' .. proof found.')

        try:
            if report:
                result = check.check_proof_report (p, proof)
            else:
                result = check.check_proof (p, proof)
                if result:
                    printout ('Refinement proven.')
                else:
                    printout ('Refinement NOT proven.')
        except solver.SolverFailure, e:
            printout ('Solver timeout/failure in proof check.')
            result = 'CheckSolverFailure'
        except Exception, e:
            trace ('EXCEPTION in checking %s:' % p.name)
            exception = sys.exc_info ()
            result = 'CheckEXCEPT'

def mk_inp_env (n, args, rep):
    trace ('rep_graph setting up input env at %d' % n, push = 1)
    inp_env = {}

    for (v_nm, typ) in args:
        inp_env[(v_nm, typ)] = rep.solv.add_var_restr (v_nm + '_init',
            typ, mem_name = 'Init')
    for (v_nm, typ) in args:
        z = rep.var_rep_request ((v_nm, typ), 'Init', (n, ()), inp_env)
        if z:
            inp_env[(v_nm, typ)] = z

    trace ('done setting up input env at %d' % n, push = -1)
    return inp_env

def get_bound_ctxt (split, call_ctxt):
    trace ('Getting bound for 0x%x in context %s.' % (split, call_ctxt))
    (p, hyps, addr_map) = get_call_ctxt_problem (split, call_ctxt)

    orig_split = split
    split = p.loop_id (addr_map[split])
    assert split, (orig_split, call_ctxt)
    split_bin_addr = min ([addr for addr in addr_map
        if p.loop_id (addr_map[addr]) == split])

    prior = get_prior_loop_heads (p, split)
    restrs = ()
    prev_bounds = []
    for split2 in prior:
        # recursion!
        split2 = p.loop_id (split2)
        assert split2
        addr = min ([addr for addr in addr_map
            if p.loop_id (addr_map[addr]) == split2])
        bound = get_bound_ctxt (addr, call_ctxt)
        prev_bounds.append ((addr, bound))
        k = (p.name, split2, bound, restrs, tuple (hyps))
        if k in known_bound_restr_hyps:
            (restrs, hyps) = known_bound_restr_hyps[k]
        else:
            (restrs, hyps) = add_loop_bound_restrs_hyps (p, restrs, hyps,
                split2, bound, call_ctxt + [orig_split])
            known_bound_restr_hyps[k] = (restrs, hyps)

    # start timing now. we miss some setup time, but it avoids double
    # counting the recursive searches.
    start = time.time ()

    p_h = problem_hash (p)
    prev_bounds = sorted (prev_bounds)
    if not known_bounds:
        load_bounds ()
    known = known_bounds.get (split_bin_addr, [])
    for (call_ctxt2, h, prev_bounds2, bound) in known:
        match = (not call_ctxt2
            or call_ctxt[- len (call_ctxt2):] == call_ctxt2)
        if match and h == p_h and prev_bounds2 == prev_bounds:
            return bound
    bound = search_bin_bound (p, restrs, hyps, split)
    known = known_bounds.setdefault (split_bin_addr, [])
    known.append ((call_ctxt, p_h, prev_bounds, bound))
    end = time.time ()
    save_bound (False, split_bin_addr, call_ctxt, p_h, prev_bounds,
        bound, time = end - start)
    return bound

def consider_inline_c1 (p, n, c_funs, inline_tag, force_inline,
        skip_underspec):
    node = p.nodes[n]
    assert node.kind == 'Call'
    if p.node_tags[n][0] != inline_tag:
        return False

    f_nm = node.fname
    if skip_underspec and not functions[f_nm].entry:
        trace ('Skipping inlining underspecified %s' % f_nm)
        return False
    if f_nm not in c_funs or (force_inline and force_inline (f_nm)):
        return lambda: inline_at_point (p, n)
    else:
        return False

def consider_inline1(p, n, matched_funs, inline_tag, force_inline,
        skip_underspec):
    node = p.nodes[n]
    assert node.kind == 'Call'
    if p.node_tags[n][0] != inline_tag:
        return False

    f_nm = node.fname
    if skip_underspec and not functions[f_nm].entry:
        trace('Skipping inlining underspecified %s' % f_nm)
        return False
    if f_nm not in matched_funs or (force_inline and force_inline(f_nm)):
        return lambda: inline_at_point(p, n)
    else:
        return False

def try_inline (self, n, pc, env):
    if not self.inliner:
        return False

    inline = self.inliner ((self.p, n))
    if not inline:
        return False

    # make sure this node is reachable before inlining
    if self.solv.test_hyp (mk_not (pc), env):
        trace ('Skipped inlining at %d.' % n)
        return False

    trace ('Inlining at %d.' % n)
    inline ()
    raise InlineEvent ()

def compute_recursion_idents(group, extra_unfolds):
    idents = {}
    group = set(group)
    recursion_trace.append('Computing for group %s' % group)
    printout('Doing recursion analysis for function group:')
    printout(' %s' % list(group))

    prevs = set([f for f in functions
        if [f2 for f2 in functions[f].function_calls()
            if f2 in group]])

    for f in prevs - group:
        recursion_trace.append(' checking for %s' % f)
        trace('Checking idents for %s' % f)
        while add_recursion_ident(f, group, idents, extra_unfolds):
            pass

    return idents

def getBinaryBoundFromC(p, c_tag, asm_split, restrs, hyps):
    c_heads = [h for h in search.init_loops_to_split(p, restrs)
        if p.node_tags[h][0] == c_tag]
    c_bounds = [(p.loop_id(split), search_bound(p, (), hyps, split))
        for split in c_heads]
    if not [b for (n, b) in c_bounds if b]:
        trace('no C bounds found (%s).' % c_bounds)
        return None

    asm_tag = p.node_tags[asm_split][0]

    rep = rep_graph.mk_graph_slice(p)
    i_seq_opts = [(0, 1), (1, 1), (2, 1)]
    j_seq_opts = [(0, 1), (0, 2), (1, 1)]
    tags = [p.node_tags[asm_split][0], c_tag]
    try:
        split = search.find_split(rep, asm_split, restrs, hyps,
            i_seq_opts, j_seq_opts, 5, tags=[asm_tag, c_tag])
    except solver.SolverFailure, e:
        return None

def trace_search_fail (knowledge):
    trace ('Exhausted split candidates for %s' % knowledge.name)
    fails = [it for it in knowledge.pairs.items ()
        if it[1][0] == 'Failed']
    last_failed_pairings.append (fails)
    del last_failed_pairings[:-10]
    fails10 = fails[:10]
    trace (' %d of %d failed pairings:' % (len (fails10),
        len (fails)))
    for f in fails10:
        trace (' %s' % (f,))
    ind_fails = [it for it in fails
        if str (it[1][1]) == 'InductFailed']
    if ind_fails:
        trace ('Inductive failures!')
        for f in ind_fails:
            trace (' %s' % (f,))
    return ind_fails

def find_actual_call_node(p, n):
    """we're getting call addresses from the binary trace, and using
    the node naming convention to find a relevant graph node, but it
    might not be the actual call node. a short breadth-first-search
    should hopefully find it."""
    stack = [(n, 3)]
    init_n = n
    while stack:
        (n, limit) = stack.pop(0)
        if limit < 0:
            continue
        if p.nodes[n].kind == "Call":
            return n
        else:
            for c in p.nodes[n].get_conts():
                stack.append((c, limit - 1))
    trace("failed to find Call node near %s" % init_n)
    return None

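# A self-contained sketch of the same depth-limited BFS over a plain
# adjacency dict (toy data; the real search walks p.nodes):
def _bfs_demo(graph, kinds, n, depth=3):
    queue = [(n, depth)]
    while queue:
        (n, limit) = queue.pop(0)
        if limit < 0:
            continue   # past the depth limit on this path
        if kinds[n] == "Call":
            return n
        queue.extend([(c, limit - 1) for c in graph[n]])
    return None

# _bfs_demo({1: [2], 2: [3], 3: []},
#           {1: "Basic", 2: "Basic", 3: "Call"}, 1)
#   ==> 3 (the Call node, found within the depth limit)
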
def loop_var_analysis(p, head, tail):
    # getting the set of variables that go round the loop
    nodes = set(tail)
    nodes.add(head)
    used_vs = set([])
    created_vs_at = {}
    visit = []

    def process_node(n, created):
        if p.nodes[n].is_noop():
            lvals = set([])
        else:
            vs = syntax.get_node_rvals(p.nodes[n])
            for rv in vs.iteritems():
                if rv not in created:
                    used_vs.add(rv)
            lvals = set(p.nodes[n].get_lvals())

        created = set.union(created, lvals)
        created_vs_at[n] = created

        visit.extend(p.nodes[n].get_conts())

    process_node(head, set([]))

    while visit:
        n = visit.pop()
        if (n not in nodes) or (n in created_vs_at):
            continue
        if not all([pr in created_vs_at for pr in p.preds[n]]):
            continue

        pre_created = [created_vs_at[pr] for pr in p.preds[n]]
        process_node(n, set.union(*pre_created))

    final_pre_created = [created_vs_at[pr] for pr in p.preds[head]
        if pr in nodes]
    created = set.union(*final_pre_created)

    loop_vs = set.intersection(created, used_vs)
    trace('Loop vars at head: %s' % loop_vs)

    return loop_vs

def loop_no_match_unroll (rep, restrs, hyps, split, other_tag, unroll):
    p = rep.p
    assert p.node_tags[split][0] != other_tag
    restr = ((split, vc_num (unroll)), )
    restrs2 = restr_others (p, restr + restrs, 2)
    loop_cond = rep.get_pc ((split, restr + restrs))
    ret_cond = rep.get_pc (('Ret', restrs2), tag = other_tag)
    # loop should be reachable
    if rep.test_hyp_whyps (mk_not (loop_cond), hyps):
        trace ('Loop weak at %d (unroll count %d).'
            % (split, unroll))
        return True
    # reaching the loop should imply reaching a loop on the other side
    hyp = mk_not (mk_and (loop_cond, ret_cond))
    if not rep.test_hyp_whyps (hyp, hyps):
        trace ('Loop independent at %d (unroll count %d).'
            % (split, unroll))
        return True
    return False

def inline_completely_unmatched(p, ref_tags=None, skip_underspec=False):
    if ref_tags == None:
        ref_tags = p.pairing.tags
    while True:
        ns = [(n, skip_underspec
                and not functions[p.nodes[n].fname].entry)
            for n in p.nodes
            if p.nodes[n].kind == 'Call'
            if not [pair for pair
                in pairings.get(p.nodes[n].fname, [])
                if pair.tags == ref_tags]]
        for (n, skip) in ns:
            if skip:
                trace('Skipped inlining underspecified %s.'
                    % p.nodes[n].fname)
        ns = [n for (n, skip) in ns if not skip]
        for n in ns:
            trace('Function %s at %d - %s - completely unmatched.'
                % (p.nodes[n].fname, n, p.node_tags[n][0]))
            inline_at_point(p, n, do_analysis=False)
        if not ns:
            p.do_analysis()
            return

def inline_at_point(p, n, do_analysis=True):
    node = p.nodes[n]
    if node.kind != 'Call':
        return

    f_nm = node.fname
    fun = functions[f_nm]
    (tag, detail) = p.node_tags[n]
    idx = p.node_tag_revs[(tag, detail)].index(n)
    p.inline_scripts[tag].append((detail, idx, f_nm))

    trace('Inlining %s into %s' % (f_nm, p.name))
    if n in p.loop_data:
        trace(' inlining into loop %d!' % p.loop_id(n))

    ex = p.alloc_node(tag, (f_nm, 'RetToCaller'))

    (ns, vs) = p.add_function(fun, tag, {'Ret': ex})
    en = ns[fun.entry]

    inp_lvs = [(vs[v], typ) for (v, typ) in fun.inputs]
    p.nodes[n] = Node('Basic', en, azip(inp_lvs, node.args))

    out_vs = [mk_var(vs[v], typ) for (v, typ) in fun.outputs]
    p.nodes[ex] = Node('Basic', node.cont, azip(node.rets, out_vs))

    p.cached_analysis.clear()
    if do_analysis:
        p.do_analysis()

    trace('Problem size now %d' % len(p.nodes))
    sys.stdout.flush()

    return ns.values()

def compute_recursive_stack_bounds(immed):
    assert not immediate_stack_bounds_loop(immed)
    bounds = {}
    todo = immed.keys()
    report = 1000
    while todo:
        if len(todo) >= report:
            trace('todo length %d' % len(todo))
            trace('tail: %s' % todo[-20:])
            report += 1000
        (fname, ident) = todo.pop()
        if (fname, ident) in bounds:
            continue
        (static, calls) = immed[(fname, ident)]
        if [1 for k in calls if k not in bounds]:
            todo.append((fname, ident))
            todo.extend(calls.keys())
            continue
        else:
            bounds[(fname, ident)] = max([static]
                + [bounds[k] + calls[k] for k in calls])
    return bounds

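# A self-contained sketch of the same worklist computation on a toy
# acyclic call graph (hypothetical data). immed maps each function to
# (static frame size, {callee: cost of the call}), and the recurrence
# is bound(f) = max(static(f), max over callees k of bound(k) + cost(k)).
def _stack_bounds_demo():
    immed = {
        'leaf': (64, {}),
        'mid': (32, {'leaf': 8}),
        'main': (48, {'mid': 8, 'leaf': 8}),
    }
    bounds = {}
    todo = list(immed.keys())
    while todo:
        f = todo.pop()
        if f in bounds:
            continue
        (static, calls) = immed[f]
        missing = [k for k in calls if k not in bounds]
        if missing:
            # callees not bounded yet: requeue f behind them
            todo.append(f)
            todo.extend(missing)
            continue
        bounds[f] = max([static] + [bounds[k] + calls[k]
            for k in calls])
    return bounds

# _stack_bounds_demo() ==> {'leaf': 64, 'mid': 72, 'main': 80}
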
def compile_struct_use(function):
    trace('Compiling in %s.' % function.name)
    vs = get_vars(function)
    max_node = max(function.nodes.keys() + [2])

    visit_vs = vs.keys()
    replaces = {}
    while visit_vs:
        v = visit_vs.pop()
        typ = vs[v]
        if typ.kind == 'Struct':
            fields = structs[typ.name].field_list
        elif typ.kind == 'Array':
            fields = [(i, typ.el_typ) for i in range(typ.num)]
        else:
            continue
        new_vs = [(nm, fresh_name('%s.%s' % (v, nm), vs, f_typ), f_typ)
            for (nm, f_typ) in fields]
        replaces[v] = new_vs
        visit_vs.extend([v_nm for (_, v_nm, _) in new_vs])

    for n in function.nodes:
        node = function.nodes[n]
        if node.kind == 'Basic':
            node.upds = compile_upds(replaces, node.upds)
        elif node.kind == 'Call':
            node.args = expand_arg_fields(replaces, node.args)
            node.rets = expand_lval_list(replaces, node.rets)
        elif node.kind == 'Cond':
            node.cond = compile_accs(replaces, node.cond)
        else:
            assert not 'node kind understood'

    function.inputs = expand_lval_list(replaces, function.inputs)
    function.outputs = expand_lval_list(replaces, function.outputs)
    return len(replaces) == 0

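# A self-contained sketch of the flattening step above (toy types and
# hypothetical names): a struct-typed variable 'p' with two word fields
# expands into scalar variables 'p.fst' and 'p.snd', and the worklist
# recurses until only scalars remain.
def _flatten_demo():
    fields = {'pair': [('fst', 'Word32'), ('snd', 'Word32')]}
    vs = {'p': 'pair', 'x': 'Word32'}
    visit = list(vs.keys())
    replaces = {}
    while visit:
        v = visit.pop()
        if vs[v] not in fields:
            continue   # already scalar
        new_vs = []
        for (nm, f_typ) in fields[vs[v]]:
            v_nm = '%s.%s' % (v, nm)
            vs[v_nm] = f_typ
            new_vs.append((nm, v_nm, f_typ))
        replaces[v] = new_vs
        visit.extend([v_nm for (_, v_nm, _) in new_vs])
    return replaces

# _flatten_demo() ==> {'p': [('fst', 'p.fst', 'Word32'),
#                            ('snd', 'p.snd', 'Word32')]}
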
def get_asm_calling_convention(fname):
    if fname in asm_cc_cache:
        return asm_cc_cache[fname]
    if fname not in pre_pairings:
        bits = fname.split("'")
        if not is_instruction(fname):
            trace("Warning: unusual unmatched function (%s, %s)."
                % (fname, bits))
        return None
    pair = pre_pairings[fname]
    assert pair['ASM'] == fname
    c_fun = functions[pair['C']]
    from logic import split_scalar_pairs
    (var_c_args, c_imem, glob_c_args) = split_scalar_pairs(c_fun.inputs)
    (var_c_rets, c_omem, glob_c_rets) = split_scalar_pairs(c_fun.outputs)

    num_args = len(var_c_args)
    num_rets = len(var_c_rets)
    const_mem = not c_omem

    cc = get_asm_calling_convention_inner(num_args, num_rets, const_mem)
    asm_cc_cache[fname] = cc
    return cc

def check_proof(p, proof, use_rep=None):
    checks = proof_checks(p, proof)
    groups = proof_check_groups(checks)

    for group in groups:
        if use_rep == None:
            rep = rep_graph.mk_graph_slice(p)
        else:
            rep = use_rep
        detail = [0]
        (verdict, elt) = test_hyp_group(rep, group, detail)
        if verdict:
            continue
        (hyps, hyp, name) = elt
        last_failed_check[0] = elt
        trace('%s: proof failed!' % name)
        trace(' (failure kind: %r)' % detail[0])
        return False
    if save_checked_proofs[0]:
        save = save_checked_proofs[0]
        save(p, proof)
    return True

def build_problem(pairing, force_inline=None, avoid_abort=False):
    p = Problem(pairing)

    for (tag, fname) in pairing.funs.items():
        p.add_entry_function(functions[fname], tag)

    p.do_analysis()

    # FIXME: the inlining is heuristic, and arguably belongs in 'search'
    inline_completely_unmatched(p, skip_underspec=avoid_abort)

    # now do any C inlining
    inline_reachable_unmatched_C(p, force_inline,
        skip_underspec=avoid_abort)

    trace('Done inlining.')

    p.pad_merge_points()
    p.do_analysis()

    if not avoid_abort:
        p.check_no_inner_loops()

    return p

def offs_expr_const(addr_expr, sp_expr, rep, hyps, extra_defs={},
        cache=None, typ=syntax.word32T):
    """if the offset between a stack addr and the initial stack pointer
    is a constant offset, try to compute it."""
    addr_x = solver.parse_s_expression(addr_expr)
    sp_x = solver.parse_s_expression(sp_expr)
    vs = [(addr_x, 1), (sp_x, -1)]
    const = 0

    while True:
        start_vs = list(vs)
        new_vs = {}
        for (x, mult) in vs:
            (var, c) = split_sum_s_expr(x, rep.solv, extra_defs,
                typ=typ)
            for v in var:
                new_vs.setdefault(v, 0)
                new_vs[v] += var[v] * mult
            const += c * mult
        vs = [(x, n) for (x, n) in new_vs.iteritems()
            if n % (2 ** typ.num) != 0]
        if not vs:
            return const
        vs = [(simplify_expr_whyps(x, rep, hyps,
                cache=cache, extra_defs=extra_defs), n)
            for (x, n) in vs]
        if sorted(vs) == sorted(start_vs):
            pass  # vs = split_merge_ite_sum_sexpr (vs)
        if sorted(vs) == sorted(start_vs):
            trace('offs_expr_const: not const')
            trace('%s - %s' % (addr_expr, sp_expr))
            trace(str(vs))
            trace(str(hyps))
            last_10_non_const.append((addr_expr, sp_expr, vs, hyps))
            del last_10_non_const[:-10]
            return None

    i_seq_opts = [(0, 1), (1, 1), (2, 1)]
    j_seq_opts = [(0, 1), (0, 2), (1, 1)]
    tags = [p.node_tags[asm_split][0], c_tag]
    try:
        split = search.find_split(rep, asm_split, restrs, hyps,
            i_seq_opts, j_seq_opts, 5, tags=[asm_tag, c_tag])
    except solver.SolverFailure, e:
        return None
    if not split or split[0] != 'Split':
        trace('no split found (%s).' % repr(split))
        return None
    (_, split) = split
    rep = rep_graph.mk_graph_slice(p)
    checks = check.split_checks(p, (), hyps, split,
        tags=[asm_tag, c_tag])
    groups = check.proof_check_groups(checks)
    try:
        for group in groups:
            (res, el) = check.test_hyp_group(rep, group)
            if not res:
                trace('split check failed!')
                trace('failed at %s' % el)
                return None
    except solver.SolverFailure, e:
        return None
    (as_details, c_details, _, n, _) = split

def get_bound_super_ctxt_inner(split, call_ctxt,
        no_splitting=(False, None)):
    first_f = trace_refute.identify_function([],
        (call_ctxt + [split])[:1])
    call_sites = all_call_sites(first_f)

    if function_limit(first_f) == 0:
        return (0, 'FunctionLimit')
    safe_call_sites = [cs for cs in call_sites
        if ctxt_within_function_limits([cs] + call_ctxt)]
    if call_sites and not safe_call_sites:
        return (0, 'FunctionLimit')

    if len(call_ctxt) < 3 and len(safe_call_sites) == 1:
        call_ctxt2 = list(safe_call_sites) + call_ctxt
        if call_ctxt_computable(split, call_ctxt2):
            trace('using unique calling context %s'
                % str((split, call_ctxt2)))
            return get_bound_super_ctxt(split, call_ctxt2)

    fname = trace_refute.identify_function(call_ctxt, [split])
    bound = function_limit_bound(fname, split)
    if bound:
        return bound

    bound = get_bound_ctxt(split, call_ctxt)
    if bound:
        return bound

    trace('no bound found immediately.')

    if no_splitting[0]:
        assert no_splitting[1], no_splitting
        no_splitting[1][0] = True
        trace('cannot split by context (recursion).')
        return None

    # try to split over potential call sites
    if len(call_ctxt) >= 3:
        trace('cannot split by context (context depth).')
        return None

    if len(call_sites) == 0:
        # either entry point or nonsense
        trace('cannot split by context (reached top level).')
        return None

    problem_sites = [call_site for call_site in safe_call_sites
        if not call_ctxt_computable(split, [call_site] + call_ctxt)]
    if problem_sites:
        trace('cannot split by context (issues in %s).'
            % problem_sites)
        return None

    anc_bounds = [get_bound_super_ctxt(split,
            [call_site] + call_ctxt, no_splitting=True)
        for call_site in safe_call_sites]
    if None in anc_bounds:
        return None
    (bound, kind) = max(anc_bounds)
    return (bound, 'MergedBound')

            result = 'CheckSolverFailure'
        except Exception, e:
            trace ('EXCEPTION in checking %s:' % p.name)
            exception = sys.exc_info ()
            result = 'CheckEXCEPT'
    except problem.Abort:
        result = 'ProblemAbort'
    except search.NoSplit:
        result = 'ProofNoSplit'
    except solver.SolverFailure, e:
        printout ('Solver timeout/failure in proof search.')
        result = 'ProofSolverFailure'
    except Exception, e:
        trace ('EXCEPTION in handling %s:' % pair)
        exception = sys.exc_info ()
        result = 'ProofEXCEPT'

    end_time = time.time ()
    tracer[0] = prev_tracer
    if exception:
        (etype, evalue, tb) = exception
        traceback.print_exception (etype, evalue, tb,
            file = sys.stdout)

    if not only_build_problem:
        printout ('Result %s for pair %s, time taken: %.2fs'
            % (result, pair, end_time - start_time))
    sys.stdout.flush ()
    return result
