def fill_smt(self, sketch):
    """
    Enumerate every concrete filling of ``sketch`` and let z3 select the
    fillings that are consistent with all examples.

    For each tuple in the cartesian product of ``self.get_domains(sketch)``,
    a deep copy of the sketch has its holes filled, is translated to a z3
    regex with ``self.to_z3``, and is tied to a fresh boolean ``m`` that is
    true iff the regex accepts every ``self.valid`` string (element ``[0]``)
    and rejects every ``self.invalid`` string (element ``[0]``). The solver
    is then asked whether at least one ``m`` can be true.

    Enumeration stops early when ``self.configuration.die`` is set.

    :param sketch: the sketch whose holes are filled.
    :return: list of concrete programs whose ``m`` is true in the model, or
        ``[]`` when the solver does not answer ``sat`` (``unsat`` or
        ``unknown``; both are counted as unsat calls).
    """
    global m_counter
    # Reset the global counter; presumably _get_new_m derives fresh names from it.
    m_counter = 0
    domains = self.get_domains(sketch)
    z3_solver = z3.Solver()
    try:
        # Best-effort: some z3 builds may not recognise this parameter.
        z3_solver.set('smt.seq.use_derivatives', True)
        z3_solver.check()
    except z3.Z3Exception:
        pass

    start = time.time()
    m_vars = {}
    for values in itertools.product(*domains):
        if self.configuration.die:
            break
        concrete = deepcopy(sketch)
        holes = self.traverse_and_save_holes(concrete)
        assert len(values) == len(holes)
        for hole, value in zip(holes, values):
            hole.data = value
        z3re = self.to_z3.eval(concrete)
        m = z3.Bool(_get_new_m())
        m_vars[m] = concrete
        big_and = []
        for valid_example in self.valid:
            big_and.append(z3.InRe(valid_example[0], z3re))
        for invalid_example in self.invalid:
            big_and.append(z3.Not(z3.InRe(invalid_example[0], z3re)))
        z3_solver.add(m == z3.And(big_and))
    # At least one candidate must satisfy all examples.
    z3_solver.add(z3.Or(*m_vars.keys()))
    time_encoding = time.time() - start

    start = time.time()
    res = z3_solver.check()
    if res == z3.sat:
        self.count_sat_calls += 1
        self.time_sat_calls += time.time() - start
        self.time_sat_encoding += time_encoding
        # Fetch the model once instead of once per candidate.
        model = z3_solver.model()
        ret_val = [program for m, program in m_vars.items() if model[m]]
        assert len(ret_val) > 0
        return ret_val
    else:
        self.count_unsat_calls += 1
        self.time_unsat_calls += time.time() - start
        self.time_unsat_encoding += time_encoding
        return []
def string_match(x, args):
    """
    Model a string ``match`` call as a z3 boolean constraint on ``x``.

    :param x: z3 string expression being matched.
    :param args: call arguments. ``args[0]`` is either a dict (a constraint
        describing tainted data used as the pattern, converted with
        ``createZ3ExpressionFromConstraint``) or a regex source string
        (converted with ``regex_to_z3``).
    :return: z3 boolean expression that holds when the match succeeds.
    """
    pattern = args[0]
    if isinstance(pattern, dict):
        # Parts of the tainted data were used as the pattern. This is a direct
        # flow, so the match is modelled as a plain substring search.
        val = createZ3ExpressionFromConstraint(pattern, {})
        return z3.IndexOf(x, val, 0) > -1
    # NOTE(review): z3.InRe requires the whole of ``x`` to match; confirm
    # that regex_to_z3 accounts for unanchored (search-style) matching.
    return z3.InRe(x, regex_to_z3(pattern))
def multi_distinguish(self, regexes):
    """
    Find one input string that separates several candidate regexes.

    At most 4 regexes take part in the SMT query. z3's ``Optimize`` is asked
    for a string ``dist`` such that at least one pair of selected regexes
    disagrees on it, preferring (soft constraints) to separate as many pairs
    as possible.

    :param regexes: candidate regex programs (translated with
        ``self._toz3``). The list itself is not modified.
    :return: ``(dist_input, keep_if_valid, keep_if_invalid, others)``, where
        ``keep_if_valid`` holds the selected regexes that accept
        ``dist_input``, ``keep_if_invalid`` those that reject it, and
        ``others`` the regexes left out of the query. Returns four ``None``
        when the solver does not report ``sat``.
    """
    start = time.time()
    # Distinguishing more than 4 regexes at once takes far too long, so only
    # 4 deterministically chosen regexes enter the SMT maximisation and the
    # rest are handed back in ``others``.
    if len(regexes) <= 4:
        selected_regexes = regexes
        others = []
    else:
        # Shuffle a copy with a private fixed-seed RNG. This selects the same
        # regexes as seeding the global RNG with 'regex', but it leaves the
        # caller's list order and the global ``random`` state untouched.
        shuffled = list(regexes)
        random.Random('regex').shuffle(shuffled)
        selected_regexes = shuffled[:4]
        others = shuffled[4:]

    solver = z3.Optimize()
    z3_regexes = [self._toz3.eval(regex) for regex in selected_regexes]
    dist = z3.String("distinguishing")

    ro_z3 = []
    for i, z3_regex in enumerate(z3_regexes):
        ro = z3.Bool(f"ro_{i}")
        ro_z3.append(ro)
        # ro_i holds exactly when dist matches regex i.
        solver.add(ro == z3.InRe(dist, z3_regex))

    big_or = []
    for ro_i, ro_j in combinations(ro_z3, 2):
        big_or.append(z3.Xor(ro_i, ro_j))
        # Soft goal: separate as many pairs as possible.
        solver.add_soft(z3.Xor(ro_i, ro_j))
    # Hard goal: at least one pair of regexes is separated.
    solver.add(z3.Or(big_or))

    if solver.check() != z3.sat:
        return None, None, None, None

    print("took", round(time.time() - start, 2), "seconds")
    model = solver.model()
    # as_string() yields the string value itself; stripping quotes from str()
    # would also strip legitimate leading/trailing '"' characters.
    dist_input = model[dist].as_string()
    keep_if_valid = []
    keep_if_invalid = []
    for regex, ro in zip(selected_regexes, ro_z3):
        if model[ro]:
            keep_if_valid.append(regex)
        else:
            keep_if_invalid.append(regex)
    return dist_input, keep_if_valid, keep_if_invalid, others
def has_failed_examples(self, regex: Node):
    """
    Report whether ``regex`` disagrees with at least one stored example.

    Without SMT (``self.use_smt`` falsy), the regex is rendered by
    ``self._interpreter``, compiled with ``re``, and every example's
    ``input`` is checked with ``self._match``. Any result that differs from
    the example's ``output`` counts as a failure.

    With SMT, the regex is translated by ``self._to_z3`` and z3 is asked
    whether the regex can accept every positive example's ``input[0]`` and
    reject every negative one. Only an ``unsat`` answer is reported as a
    failure; ``sat`` and ``unknown`` both yield ``False``.

    :param regex: the candidate program.
    :return: ``True`` if some example is violated, otherwise ``False``.
    """
    if self.use_smt:
        z3_regex = self._to_z3.eval(regex)
        solver = z3.Solver()
        constraints = [
            z3.InRe(example.input[0], z3_regex)
            if example.output
            else z3.Not(z3.InRe(example.input[0], z3_regex))
            for example in self._examples
        ]
        solver.add(z3.And(constraints))
        return solver.check() == z3.unsat

    compiled = re.compile(self._interpreter.eval(regex))
    for example in self._examples:
        if self._match(compiled, example.input) != example.output:
            return True
    return False
def string_search(x, args):
    """
    Model a string ``search`` call symbolically.

    A fresh z3 string ``new_val`` is constrained (through
    ``GLOBAL_CONSTRAINTS``) to match the regex in ``args[0]``. A leading
    ``^`` and an unescaped trailing ``$`` are removed from the pattern
    before conversion, because they are modelled explicitly: as
    ``x == new_val`` (both anchors), ``PrefixOf`` (``^`` only) or
    ``SuffixOf`` (``$`` only).

    :param x: z3 string expression being searched.
    :param args: call arguments; ``args[0]`` is the regex source string.
    :return: z3 integer expression for the match index (``IntVal(0)`` when
        both anchors are present).
    """
    new_val = z3.String('__ignore_search_helper_' + randomString())
    regex = args[0]
    startsWith = False
    endsWith = False
    # startswith/endswith instead of indexing: safe for "" and for a bare "^".
    if regex.startswith('^'):
        startsWith = True
        regex = regex[1:]
    # A '$' preceded by an odd number of backslashes is an escaped literal
    # dollar sign, not an end anchor.
    body = regex[:-1]
    trailing_backslashes = len(body) - len(body.rstrip('\\'))
    if regex.endswith('$') and trailing_backslashes % 2 == 0:
        endsWith = True
        regex = regex[:-1]
    # Convert the anchor-free pattern; the anchors are encoded below.
    GLOBAL_CONSTRAINTS.append(z3.InRe(new_val, regex_to_z3(regex)))
    if startsWith and endsWith:
        # we need to return the index which should be 0 iff it matches
        GLOBAL_CONSTRAINTS.append(x == new_val)
        return z3.IntVal(0)
    elif startsWith:
        GLOBAL_CONSTRAINTS.append(z3.PrefixOf(new_val, x))
    elif endsWith:
        GLOBAL_CONSTRAINTS.append(z3.SuffixOf(new_val, x))
        # NOTE(review): IndexOf returns the first occurrence of new_val, which
        # can be earlier than the suffix position; confirm this approximation.
    return z3.IndexOf(x, new_val, 0)
def _internal_match_patterns(space: StateSpace, top_patterns: Any, flags: int, smtstr: z3.ExprRef, offset: int) -> Optional[_Match]:
    """
    Match the parsed pattern sequence ``top_patterns`` against the z3 string
    ``smtstr`` starting at ``offset``, forking ``space`` on symbolic decisions.

    Patterns are consumed head-first: ``top_patterns[0]`` is handled here and
    the remaining patterns are matched recursively from where it ended.

    :param space: state space; ``space.smt_fork`` decides symbolic branches.
    :param top_patterns: remaining sequence of ``(op, arg)`` items (as produced
        by ``sre_parse``), possibly including internal ``_END_GROUP_MARKER``
        items inserted when a group is entered.
    :param flags: ``re`` flags; passed to ``single_char_regex`` and checked for
        ``re.MULTILINE``.
    :param smtstr: the full z3 string being matched.
    :param offset: index into ``smtstr`` where matching starts.
    :return: a ``_Match`` for the consumed span (including group spans), or
        ``None`` when no match is found on the current path.
    :raises ReUnhandled: for pattern constructs this matcher does not support.

    >>> from crosshair.statespace import SimpleStateSpace
    >>> import sre_parse
    >>> smtstr = z3.String('smtstr')
    >>> space = SimpleStateSpace()
    >>> space.add(smtstr == z3.StringVal('aabb'))
    >>> _internal_match_patterns(space, sre_parse.parse('a+'), 0, smtstr, 0).span()
    (0, 2)
    >>> _internal_match_patterns(space, sre_parse.parse('ab'), 0, smtstr, 1).span()
    (1, 3)
    """
    # Suffix of smtstr still to be matched.
    matchstr = z3.SubString(smtstr, offset, z3.Length(smtstr)) if offset > 0 else smtstr
    # Nothing left to match: an empty match at the current offset.
    if len(top_patterns) == 0:
        return _Match([(None, offset, offset)])
    pattern = top_patterns[0]

    # Match the remaining patterns after `prefix` and glue the results together.
    def continue_matching(prefix):
        suffix = _internal_match_patterns(space, top_patterns[1:], flags, smtstr, prefix.end())
        if suffix is None:
            return None
        return prefix._add_match(suffix)

    # TODO: using a typed internal function triggers __hash__es inside the typing module.
    # Seems like this causes nondeterminism due to a global LRU cache used by the typing module.
    # Fork on `expr`; on the true branch, consume `sz` characters and continue.
    def fork_on(expr, sz):
        if space.smt_fork(expr):
            return continue_matching(_Match([(None, offset, offset + sz)]))
        else:
            return None

    # Handle simple single-character expressions using z3's built-in capabilities.
    z3_re = single_char_regex(pattern, flags)
    if z3_re is not None:
        ch = z3.SubString(matchstr, 0, 1)
        return fork_on(z3.InRe(ch, z3_re), 1)

    (op, arg) = pattern
    if op is MAX_REPEAT:
        (min_repeat, max_repeat, subpattern) = arg
        if max_repeat < min_repeat:
            return None
        reps = 0
        overall_match = _Match([(None, offset, offset)])
        # The first min_repeat repetitions are mandatory.
        while reps < min_repeat:
            submatch = _internal_match_patterns(space, subpattern, flags, smtstr, overall_match.end())
            if submatch is None:
                return None
            overall_match = overall_match._add_match(submatch)
            reps += 1
        # Bounded repeat already at its maximum: nothing optional remains.
        if max_repeat != MAXREPEAT and reps >= max_repeat:
            return continue_matching(overall_match)
        # Try one optional repetition.
        submatch = _internal_match_patterns(space, subpattern, flags, smtstr, overall_match.end())
        if submatch is None:
            return continue_matching(overall_match)
        # we matched; try to be greedy first, and fall back to `submatch` as the last consumed match
        # The repeat item in top_patterns is replaced by one requiring at least
        # one more repetition (with a correspondingly reduced maximum).
        greedy_remainder = _patt_replace(
            top_patterns,
            arg,
            (
                1,
                max_repeat if max_repeat == MAXREPEAT else max_repeat - (min_repeat + 1),
                subpattern,
            ),
        )
        greedy_match = _internal_match_patterns(space, greedy_remainder, flags, smtstr, submatch.end())
        if greedy_match is not None:
            return overall_match._add_match(submatch)._add_match(greedy_match)
        else:
            # Fallbacks, in order: keep exactly one optional repetition, then none.
            match_with_optional = continue_matching(
                overall_match._add_match(submatch))
            if match_with_optional is not None:
                return match_with_optional
            else:
                return continue_matching(overall_match)
    elif op is BRANCH and arg[0] is None:
        # NOTE: order matters - earlier branches are more greedily matched than later branches.
        branches = arg[1]
        # Splice the first alternative in front of the remaining patterns.
        first_path = list(branches[0]) + list(top_patterns)[1:]
        submatch = _internal_match_patterns(space, first_path, flags, smtstr, offset)
        # Alternative formulation (unused): _patt_replace(top_patterns, pattern, branches[0])
        if submatch is not None:
            return submatch
        if len(branches) <= 1:
            return None
        else:
            # Retry with the first alternative removed from the branch list.
            return _internal_match_patterns(
                space,
                _patt_replace(top_patterns, branches, branches[1:]),
                flags,
                smtstr,
                offset,
            )
    elif op is AT:
        # Only end-of-string anchors are handled; other AT codes fall through
        # to the ReUnhandled raise at the bottom.
        if arg in (AT_END, AT_END_STRING):
            # NOTE(review): this fires for AT_END, although the message names AT_END_STRING.
            if arg is AT_END and re.MULTILINE & flags:
                raise ReUnhandled("Multiline match with AT_END_STRING")
            # The anchor matches only when nothing is left to consume.
            return fork_on(matchstr == z3.StringVal(""), 0)
    elif op is SUBPATTERN:
        (groupnum, _a, _b, subpatterns) = arg
        # Inline flag changes on the group are not supported.
        if (_a, _b) != (0, 0):
            raise ReUnhandled("unsupported subpattern args")
        # Inline the group's patterns followed by a marker that records the group span.
        new_top = (list(subpatterns) + [(_END_GROUP_MARKER, (groupnum, offset))] + list(top_patterns)[1:])
        return _internal_match_patterns(space, new_top, flags, smtstr, offset)
    elif op is _END_GROUP_MARKER:
        (group_num, begin) = arg
        match = continue_matching(_Match([(None, offset, offset)]))
        if match is None:
            return None
        # Grow the group list as needed, then record this group's (begin, end) span.
        while len(match._groups) <= group_num:
            match._groups.append(None)
        match._groups[group_num] = (None, begin, offset)
        return match
    raise ReUnhandled(op)
def distinguish2(self, r1, r2):
    """
    Find an input string accepted by exactly one of two candidate regexes.

    ``r[0]`` is the regex program (translated by ``self._toz3`` and printed
    by ``self._printer``). ``r[2][0]`` is a sequence of conditions, each a
    sequence of tokens joined with spaces for ``check_conditions``.
    ``r[2][1]`` holds the captures passed to the printer.

    When either regex has conditions, candidate strings from z3 are checked
    with Python's ``re.fullmatch`` plus ``check_conditions``. Each rejected
    candidate is excluded and the solver is queried again until one passes
    or no candidate is left.

    :return: ``(dist_input, [accepting], [rejecting], [])``, or four ``None``
        when no distinguishing input is found.
    """
    solver = z3.Solver()
    solver.set('random_seed', 7)
    solver.set('sat.random_seed', 7)
    if use_derivatives:
        try:
            # Best-effort: some z3 builds may not recognise this parameter.
            solver.set('smt.seq.use_derivatives', True)
            solver.check()
        except z3.Z3Exception:
            pass
    z3_r1 = self._toz3.eval(r1[0])
    z3_r2 = self._toz3.eval(r2[0])
    dist = z3.String("distinguishing")
    ro_1 = z3.Bool("ro_1")
    solver.add(ro_1 == z3.InRe(dist, z3_r1))
    ro_2 = z3.Bool("ro_2")
    solver.add(ro_2 == z3.InRe(dist, z3_r2))
    # Exactly one of the two regexes may accept `dist`.
    solver.add(ro_1 != ro_2)

    if solver.check() != z3.sat:
        return None, None, None, None

    if len(r1[2][0]) == 0 and len(r2[2][0]) == 0:
        model = solver.model()
        dist_input = model[dist].as_string()
    else:
        # Find dist_input that respects conditions
        r1_str = self._printer.eval(r1[0], captures=r1[2][1])
        r1_conditions = [" ".join(map(str, c)) for c in r1[2][0]]
        r2_str = self._printer.eval(r2[0], captures=r2[2][1])
        r2_conditions = [" ".join(map(str, c)) for c in r2[2][0]]
        while True:
            model = solver.model()
            dist_input = model[dist].as_string()
            match = re.fullmatch(r1_str, dist_input)
            if match is not None and check_conditions(r1_conditions, match):
                break
            match = re.fullmatch(r2_str, dist_input)
            if match is not None and check_conditions(r2_conditions, match):
                break
            # Exclude this candidate and ask for another one.
            solver.add(dist != z3.StringVal(dist_input))
            if solver.check() != z3.sat:
                return None, None, None, None

    # `model` is the model that produced `dist_input`.
    if model[ro_1]:
        return dist_input, [r1], [r2], []
    return dist_input, [r2], [r1], []
def fill_hybrid(self, sketch):
    """
    Like an SMT fill of ``sketch``, but with a short solver timeout and a
    brute-force fallback.

    For each tuple in the cartesian product of ``self.get_domains(sketch)``,
    a deep copy of the sketch has its holes filled, is translated to a z3
    regex with ``self.to_z3``, and is tied to a fresh boolean ``m`` that is
    true iff the regex accepts every ``self.valid`` string and rejects every
    ``self.invalid`` string (element ``[0]`` of each). The solver, with a
    100 ms timeout, is asked whether some ``m`` can be true. On ``unknown``,
    ``self.fill_brute_force(sketch)`` decides instead.

    Enumeration stops early when ``self.configuration.die`` is set.

    :param sketch: the sketch whose holes are filled.
    :return: list of concrete programs consistent with the examples (``[]``
        if none).
    """
    domains = self.get_domains(sketch)
    z3_solver = z3.Solver()
    try:
        # Best-effort: some z3 builds may not recognise this parameter.
        z3_solver.set('smt.seq.use_derivatives', True)
        z3_solver.check()
    except z3.Z3Exception:
        pass

    start = time.time()
    m_vars = {}
    for values in itertools.product(*domains):
        if self.configuration.die:
            break
        concrete = deepcopy(sketch)
        holes = self.traverse_and_save_holes(concrete)
        assert len(values) == len(holes)
        for hole, value in zip(holes, values):
            hole.data = value
        z3re = self.to_z3.eval(concrete)
        m = z3.Bool(_get_new_m())
        m_vars[m] = concrete
        big_and = []
        for valid_example in self.valid:
            big_and.append(z3.InRe(valid_example[0], z3re))
        for invalid_example in self.invalid:
            big_and.append(z3.Not(z3.InRe(invalid_example[0], z3re)))
        z3_solver.add(m == z3.And(big_and))
    # At least one candidate must satisfy all examples.
    z3_solver.add(z3.Or(*m_vars.keys()))
    # Timeout in milliseconds; an expired check yields `unknown`.
    z3_solver.set("timeout", 100)

    print("checking...")
    res = z3_solver.check()
    print("done.")
    if res == z3.sat:
        self.count_sat_calls += 1
        self.time_sat_calls += time.time() - start
        # Fetch the model once instead of once per candidate.
        model = z3_solver.model()
        ret_val = [program for m, program in m_vars.items() if model[m]]
        assert len(ret_val) > 0
        return ret_val
    elif res == z3.unsat:
        self.count_unsat_calls += 1
        self.time_unsat_calls += time.time() - start
        return []
    elif res == z3.unknown:
        print("SMT unknown.")
        stop_time = time.time()
        correct = self.fill_brute_force(sketch)
        if len(correct) > 0:
            self.time_sat_calls += stop_time - start
            self.count_smt_unknown_sat += 1
        else:
            self.time_unsat_calls += stop_time - start
            self.count_smt_unknown_unsat += 1
        return correct
    else:
        # Use a %s placeholder: an extra argument without one makes logging
        # raise a formatting error instead of emitting the message.
        logger.error("Unknown Z3 response: %s", res)
        return []