def case_char(name: str, map_: Dict[str, str], ctx: z3.Context) -> z3.FuncDeclRef:
    """Build a String -> String Z3 function that rewrites characters via ``map_``.

    The returned function is named ``f'{name}_char'``. For every ``src -> dst``
    pair it maps an input equal to ``src`` to ``dst``; any other input is
    returned unchanged.

    Args:
        name: prefix for the Z3 function name.
        map_: a mapping ``{src: dst}``. An iterable of ``(src, dst)`` pairs is
            also accepted. With repeated sources, the last pair wins because it
            becomes the outermost ``If``.
        ctx: the Z3 context in which all terms are created.

    Returns:
        The Z3 function declaration.
    """
    char = z3.FreshConst(z3.StringSort(ctx))
    # Iterating a dict yields only its keys; .items() gives the (src, dst) pairs.
    pairs = map_.items() if hasattr(map_, 'items') else map_
    body = char
    for src, dst in pairs:
        body = z3.If(char == z3.StringVal(src, ctx), z3.StringVal(dst, ctx), body, ctx)
    f = z3.RecFunction(f'{name}_char', z3.StringSort(ctx), z3.StringSort(ctx))
    z3.RecAddDefinition(f, char, body)
    return f
def checkForTypeEqualToString(var, other_var, op):
    """Rewrite a comparison of a ``type:...`` variable against ``'string'``.

    If one operand is an uninterpreted constant whose declaration name starts
    with ``type:`` and ``op`` compares against the value ``'string'``, a
    rewritten Z3 constraint that also involves ``'JSON'`` is returned.
    Otherwise ``None`` is returned.

    Args:
        var: left Z3 operand.
        other_var: right Z3 operand.
        op: dict read for the keys ``'val'`` and ``'op'``.
    """
    def is_type_var(expr):
        declaration = expr.decl()
        return (declaration.kind() == z3.Z3_OP_UNINTERPRETED
                and str(declaration).startswith('type:'))

    def compares_string_with(operators):
        return op['val'] == 'string' and op['op'] in operators

    equality_ops = ['==', '===']
    inequality_ops = ['!=', '!==']

    if is_type_var(var) and compares_string_with(equality_ops):
        return z3.Or(var == other_var, var == z3.StringVal('JSON'))
    if is_type_var(other_var) and compares_string_with(equality_ops):
        return z3.Or(var == other_var, other_var == z3.StringVal('JSON'))
    # NOTE(review): both inequality cases build And(var == other_var, <type> == 'JSON'),
    # i.e. two equalities, which looks like an incomplete negation of the Or above.
    # Kept unchanged; confirm the intended semantics with the caller.
    if is_type_var(var) and compares_string_with(inequality_ops):
        return z3.And(var == other_var, var == z3.StringVal('JSON'))
    if is_type_var(other_var) and compares_string_with(inequality_ops):
        return z3.And(var == other_var, other_var == z3.StringVal('JSON'))
    return None
def values_to_smt(prefix: str, state_or_command: Union[State, 'Command'],
                  declarations: Dict[str, Any]) -> List[z3.ExprRef]:
    """Create Z3 equalities that pin each variable of a state or command to its value.

    Parameters:
        prefix: the prefix that is added to this set of variables in the declarations.
        state_or_command: a state or command which contains the variables or
            parameters we want to add to the query. Iterating it yields objects
            with a ``name`` attribute; indexing it with that name gives the value.
        declarations: a dictionary of all variables (e.g. command parameters and
            state variables), keyed by prefixed name, with the Z3 variable as value.

    Returns:
        A list of Z3 expressions that assert equality of variables and their values.

    Raises:
        KeyError: if a prefixed name has no entry in ``declarations``.
    """
    smt = []
    logger.debug("converting values to SMT: %s", state_or_command)
    for param_or_variable in state_or_command:
        name = param_or_variable.name
        logger.debug("creating SMT assertion for var: %s", name)
        val = state_or_command[name]
        d = declarations[f'{prefix}{name}']
        # Python strings are lifted to Z3 string literals in the declaration's own context.
        rhs = z3.StringVal(val, ctx=d.ctx) if isinstance(val, str) else val
        smt.append(d == rhs)
    logger.debug("converted values to SMT: %s", smt)
    return smt
def z3_val(val, ctx):
    """Lift a Python ``str`` or ``int`` literal to the matching Z3 value in ``ctx``.

    Raises:
        ValueError: for any other Python type.
    """
    if isinstance(val, str):
        return z3.StringVal(val, ctx)
    if isinstance(val, int):
        return z3.IntVal(val, ctx)
    raise ValueError(f'z3_val: unsupported typ ({type(val)})')
def eval_regex(re_string, flags, test_string, offset, endpos=None):
    """Match ``re_string`` against ``test_string`` through the symbolic regex engine.

    A symbolic string is created in a standalone state space and constrained to
    equal ``test_string``. It is then matched with ``_match_pattern`` and the
    result is realized with ``deep_realize``.
    """
    compiled = re.compile(re_string, flags)
    with standalone_statespace as space:
        with NoTracing():
            symbolic = SymbolicStr("symstr" + space.uniq())
            space.add(symbolic.var == z3.StringVal(test_string))
            match = _match_pattern(compiled, re_string, symbolic, offset, endpos)
            return deep_realize(match)
def adapt(self, condition):
    """Convert ``condition['value']`` into something usable in a Z3 constraint.

    Strings become Z3 string literals and ints are returned unchanged.

    Raises:
        Exception: for any other value type.
    """
    value = condition['value']
    # `basestring` only exists on Python 2; on Python 3 every text value is a `str`.
    if isinstance(value, str):
        return z3.StringVal(value)
    if isinstance(value, int):
        return value
    raise Exception('got unknown rawvalue')
def get_z3_val(valtype, value, name, datatype_name=None, ctx=None):
    """Convert a Python value into a Z3 value of the kind described by ``valtype``.

    Args:
        valtype: one of the following:
            - a ``z3.DatatypeSortRef``, where ``value`` names an attribute of the sort;
            - one of ``Types.INT`` (32-bit bit-vector), ``INTEGER``, ``FLOAT`` (Float32),
              ``REAL``, ``BOOL`` or ``STRING``;
            - a list, turned into a datatype by ``_get_datatype_from_list``.
        value: the Python value to convert.
        name: name of the converted variable; only used in error messages.
        datatype_name: passed to ``_get_datatype_from_list`` for list ``valtype``.
        ctx: optional Z3 context for the created value.

    Returns:
        The Z3 value. An extra ``type`` attribute on it holds the resolved ``valtype``.

    Raises:
        ValueError: if ``valtype`` is unknown or the conversion fails.
    """
    def convert(label, make):
        # Run one Z3 constructor and re-raise any failure as a uniform ValueError.
        try:
            return make()
        except Exception as exc:
            raise ValueError(
                f"Error during {label} conversion. Cannot convert value: {value}, type: {type(value)}, name: {name}"
            ) from exc

    val = None
    if isinstance(valtype, z3.z3.DatatypeSortRef):
        # discrete values datatype
        val = getattr(valtype, value)
    elif valtype is Types.INT:
        val = convert("INT", lambda: z3.BitVecVal(value, 32, ctx=ctx))
    elif valtype is Types.INTEGER:
        val = convert("INTEGER", lambda: z3.IntVal(value, ctx=ctx))
    elif valtype is Types.FLOAT:
        # Create the Float32 sort in the same context as the value to avoid a context mismatch.
        val = convert("FLOAT", lambda: z3.FPVal(value, z3.Float32(ctx), ctx=ctx))
    elif valtype is Types.REAL:
        val = convert("REAL", lambda: z3.RealVal(value, ctx=ctx))
    elif valtype is Types.BOOL:
        val = convert("BOOL", lambda: z3.BoolVal(value, ctx=ctx))
    elif valtype is Types.STRING:
        val = convert("STRING", lambda: z3.StringVal(value, ctx=ctx))
    elif isinstance(valtype, list):
        datatype = _get_datatype_from_list(valtype, datatype_name)
        val = getattr(datatype, value)
        valtype = datatype
    else:
        raise ValueError(
            f"I do not know how to create a z3-value for type {valtype}")

    assert val is not None, f"Value wasn't converted: valtype: {valtype}, value: {value}, name: {name}"
    val.type = valtype
    return val
def __k_bool__(self):
    """Yield the truthiness of the proxied Z3 value as one inference result.

    Numeric values (int, real, bit-vector) are truthy when ``!= 0`` and strings
    when non-empty. For any other value the proxy itself is yielded.
    """
    value = self.value
    if z3.is_int(value) or z3.is_real(value) or z3.is_bv(value):
        truth = value != 0
    elif z3.is_string(value):
        # Build the empty literal in value's own context, as `value != 0` does implicitly;
        # a default-context literal would raise a context mismatch otherwise.
        truth = value != z3.StringVal("", value.ctx)
    else:
        yield inference.InferenceResult(self, status=True)
        return
    yield inference.InferenceResult(Z3Proxy.init_expr(truth, self.defaults), status=True)
def numeral_to_z3(num):
    """Translate a numeral symbol to a Z3 constant of its native sort.

    If no native sort is registered for ``num.sort``, an uninterpreted constant
    named after the numeral is returned. If the native sort is the Z3 string
    sort, the (unquoted) name becomes a string literal. Otherwise the name is
    parsed with ``int(name, 0)``, which accepts 0x/0b/0o prefixes, and cast to
    the sort.

    Raises:
        iu.IvyError: if the numeral cannot be converted to the native sort.
    """
    # TODO: allow other numeric types
    z3sort = lookup_native(num.sort, sorts, "sort")
    if z3sort is None:
        return z3.Const(num.name, num.sort.to_z3())  # uninterpreted sort
    try:
        name = num.name[1:-1] if num.name.startswith('"') else num.name
        if isinstance(z3sort, z3.SeqSortRef) and z3sort.is_string():
            return z3.StringVal(name)
        return z3sort.cast(str(int(name, 0)))  # allow 0x, 0b, etc.
    except Exception as exc:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        raise iu.IvyError(
            None, 'Cannot cast "{}" to native sort {}'.format(num, z3sort)) from exc
def string_split(x, args):
    """Model ``x.split(args[0])`` as a Z3 Int -> String array with exactly 4 pieces.

    The first three pieces are taken by repeatedly cutting ``x`` at the first
    occurrence of the separator. The fourth piece is whatever remains. Every
    piece is constrained to be non-empty, and the last piece may not contain the
    separator. The constraints are appended to ``GLOBAL_CONSTRAINTS``, the pieces
    to ``GLOBAL_ARRAY_HANDLER[arr]`` and the guessed length 4 to ``ARRAY_LENGTHS``.

    Args:
        x: the Z3 string being split.
        args: call arguments; ``args[0]`` is the separator (a Python str).

    Returns:
        The Z3 array whose indices 0..3 hold the pieces.
    """
    # The pieces are cut from the untransformed `x`; the transformed `x` below is
    # only used to name the array.
    st = x
    # NOTE(review): the separator is passed to z3.StringVal as bytes; confirm the
    # installed z3 version accepts bytes here.
    split_val = z3.StringVal(args[0].encode())
    x = transformNonBooleanLazyEvaluations(x)
    arr = z3.Array('__ignore_{}.split({})'.format(str(x), str(args[0])), z3.IntSort(), z3.StringSort())
    for i in range(3):
        # If the separator is absent, IndexOf is -1 and `s` becomes "", which the
        # non-empty constraint below rules out, so a separator must be present.
        index = z3.IndexOf(st, split_val, 0)
        s = z3.SubString(st, 0, index)
        st = z3.SubString(st, index + z3.Length(split_val), z3.Length(st))
        GLOBAL_CONSTRAINTS.append(z3.Select(arr, i) == s)
        GLOBAL_CONSTRAINTS.append(s != z3.StringVal(''))
        GLOBAL_ARRAY_HANDLER[arr].append(s)
    # The remainder after three cuts is the fourth and last piece.
    GLOBAL_CONSTRAINTS.append(z3.Select(arr, 3) == st)
    GLOBAL_CONSTRAINTS.append(st != z3.StringVal(''))
    GLOBAL_ARRAY_HANDLER[arr].append(st)
    # We just guess the length here and hope that this works for the program
    ARRAY_LENGTHS[str(arr.decl())] = 4
    # The last piece must not contain the separator, so there are exactly 4 pieces.
    GLOBAL_CONSTRAINTS.append(z3.IndexOf(GLOBAL_ARRAY_HANDLER[arr][-1], split_val, 0) == -1)
    # GLOBAL_CONSTRAINTS.append(z3.PrefixOf(GLOBAL_ARRAY_HANDLER[arr][0], x))
    return arr
def to_case(name: str, name_char: Callable[[z3.Context], z3.FuncDeclRef],
            text: z3.SeqRef) -> z3.SeqRef:
    """Apply a per-character mapping to every character of ``text``.

    Args:
        name: name of the recursive Z3 function that walks the string.
        name_char: given a context, returns a unary String -> String function that
            is applied to each character (taken with ``head``).
        text: the Z3 string to transform. Any string term is accepted, not only
            a constant.

    Returns:
        ``f(text)``, where ``f("") = ""`` and
        ``f(s) = name_char(ctx)(head(s)) ++ f(tail(s))``.
    """
    ctx = text.ctx
    string_sort = z3.StringSort(ctx)
    empty = z3.StringVal("", ctx)
    # RecAddDefinition needs uninterpreted constants as parameters, so a fresh bound
    # variable is used instead of `text`, which may be a compound term.
    s = z3.FreshConst(string_sort, 'text')
    char_fn = name_char(ctx)
    f = z3.RecFunction(name, string_sort, string_sort)
    z3.RecAddDefinition(
        f, s,
        z3.If(s == empty, empty, z3.Concat(char_fn(head(s)), f(tail(s)))))
    return f(text)
def binary_instanceof(left, right):
    """Build the constraint for an ``instanceof`` operation on ``left``.

    An empty string literal for ``right`` is treated as "no type", which is
    encoded as ``type:<left> == 'undefined'`` over the string variable
    ``'type:' + str(left.decl())``.

    Raises:
        Exception: if ``left`` is not an uninterpreted constant, or ``right`` is
            anything other than the empty string literal.
    """
    if z3.is_string_value(right) and right.as_string() == '':
        right = None
    if left.decl().kind() != z3.Z3_OP_UNINTERPRETED:
        raise Exception(
            'We probably need to introduce intermediary variables here and assert that their type is something specific')
    type_of_left = z3.String('type:' + str(left.decl()))
    if right is not None:
        raise Exception('we have not yet seen', right, 'as type')
    return type_of_left == z3.StringVal('undefined')
def _to_smt_constant_expression(expr_object, smt_context_object):
    """Translate a constant expression into the equivalent Z3 literal.

    Boolean, integer, bit-vector (with the type's ``size`` as width) and string
    constants are supported. Each is created in ``smt_context_object.ctx()``.

    Raises:
        basetypes.UnhandledCaseError: for any other type code.
    """
    constant = expr_object.value_object
    constant_type = constant.value_type
    code = constant_type.type_code
    codes = exprtypes.TypeCodes

    if code == codes.boolean_type:
        return z3.BoolVal(constant.value_object, smt_context_object.ctx())
    if code == codes.integer_type:
        return z3.IntVal(constant.value_object, smt_context_object.ctx())
    if code == codes.bit_vector_type:
        bits = constant.value_object.value
        return z3.BitVecVal(bits, constant_type.size, smt_context_object.ctx())
    if code == codes.string_type:
        return z3.StringVal(constant.value_object, smt_context_object.ctx())
    raise basetypes.UnhandledCaseError('Odd type code: %s' % code)
def glean_unknown_symbol(sym):
    """Turn a symbol that could not be resolved into a Z3 term.

    - Text matching ``int_pattern`` becomes a Boolean literal: False for 0, True otherwise.
    - Text matching ``hex_pattern`` becomes a Boolean variable named after the quoted text.
    - Anything else becomes a Z3 string literal.
    """
    text = str(sym)
    if int_pattern.match(text):
        return z3.BoolVal(int(text) != 0)
    if hex_pattern.match(text):
        return z3.Bool('"%s"' % (text))
    return z3.StringVal(text)
def createZ3ForBool(var):
    """Convert a value used in a boolean position into a Z3 condition.

    - ``None`` -> ``None``
    - Int -> ``var != 0``; String -> ``var != ''``; Array -> ``True``
    - Bool and other (non-string) sequence values -> returned unchanged

    Raises:
        Exception: for any other kind of value.
    """
    if var is None:
        # this should be the case when we have a JSON value that is just inside a conditional etc
        return None
    if z3.is_int(var):
        return var != z3.IntVal(0)
    if z3.is_string(var):
        return var != z3.StringVal('')
    if z3.is_array(var):
        return z3.BoolVal(True)
    if z3.is_bool(var) or z3.is_seq(var):
        # Non-string sequences were only seen in string comparisons using <, >, etc.
        return var
    raise Exception('unhandled type in uninterpreted if')
def getZ3ValFromJSVal(val):
    """Convert a concrete value into a Z3 term.

    ``str``, ``bool`` and ``int`` become Z3 literals (exact type match, so
    subclasses are rejected). A ``list`` becomes a fresh Int -> String array.
    Its elements are constrained via ``GLOBAL_CONSTRAINTS`` and its length is
    recorded in ``ARRAY_LENGTHS``.

    Raises:
        NotSupportedException: for ``dict`` values.
        Exception: for any other type.
    """
    val_type = type(val)
    if val_type is str:
        return z3.StringVal(val)
    if val_type is bool:
        return z3.BoolVal(val)
    if val_type is int:
        return z3.IntVal(val)
    if val_type is list:
        arr = z3.Array('ignore_helper_constant_array_' + randomString(), z3.IntSort(), z3.StringSort())
        for i, arg in enumerate(val):
            GLOBAL_CONSTRAINTS.append(z3.Select(arr, i) == createZ3ExpressionFromConstraint(arg, {}))
        ARRAY_LENGTHS[str(arr.decl())] = len(val)
        return arr
    if val_type is dict:
        raise NotSupportedException('Complex Objects as base for operations with proxy strings are not yet supported!')
    raise Exception('Could not transform Js val to Z3 Val' + repr(val))
def SReplace(text: SStr, old: SStr, new: SStr) -> z3.SeqRef:
    """Return a Z3 term for replacing every occurrence of ``old`` in ``text`` with ``new``.

    Occurrences are replaced left to right and never overlap, like Python's
    ``str.replace``. Replaced text is not re-scanned, so ``new`` may contain
    ``old``. If ``old`` is empty, ``text`` is returned unchanged.

    Args:
        text, old, new: symbolic strings; their ``.expr`` Z3 terms are used, in
            ``text.ctx``.
    """
    ctx = text.ctx
    string_sort = z3.StringSort(ctx)
    empty = z3.StringVal("", ctx)
    # RecAddDefinition needs uninterpreted constants as parameters, so fresh bound
    # variables are used instead of the caller's (possibly compound) terms.
    t = z3.FreshConst(string_sort, 'text')
    o = z3.FreshConst(string_sort, 'old')
    n = z3.FreshConst(string_sort, 'new')
    replace_all = z3.RecFunction('replace_all',
                                 string_sort, string_sort, string_sort,  # from
                                 string_sort)  # to
    idx = z3.IndexOf(t, o, 0)
    # Recurse only on the text after the first occurrence: this terminates (t shrinks
    # since o is non-empty) and never rescans the inserted `new`.
    after = z3.SubString(t, idx + z3.Length(o), z3.Length(t))
    z3.RecAddDefinition(
        replace_all, [t, o, n],
        z3.If(z3.Or(o == empty, z3.Not(z3.Contains(t, o))),
              t,
              z3.Concat(z3.SubString(t, 0, idx), n, replace_all(after, o, n))))
    return replace_all(text.expr, old.expr, new.expr)
def coerceTypesIfPossible(var, other_var):
    """Try to give the two operands of a comparison compatible Z3 sorts.

    First, And/Or terms on either side are flattened with
    ``transformNonBooleanLazyEvaluations`` when the other side is not Boolean.

    If ``var`` is an uninterpreted constant, it is re-declared with the sort of
    ``other_var`` (Bool, String or Int), and the inferred type is recorded in the
    global ``infered_types``. A comparison with the empty string literal instead
    compares a Bool ``var`` with False, or an Int ``var`` with 0.

    Otherwise ``other_var`` is converted towards ``var``'s sort: an int literal
    becomes a string, and a string literal is parsed as an int.

    Returns:
        The possibly rewritten ``(var, other_var)`` pair.
    """
    if z3.is_or(other_var) and not z3.is_bool(var):
        other_var = transformNonBooleanLazyEvaluations(other_var)
    if z3.is_or(var) and not z3.is_bool(other_var):
        var = transformNonBooleanLazyEvaluations(var)
    if z3.is_and(other_var) and not z3.is_bool(var):
        other_var = transformNonBooleanLazyEvaluations(other_var)
    if z3.is_and(var) and not z3.is_bool(other_var):
        var = transformNonBooleanLazyEvaluations(var)
    if var.decl().kind() == z3.Z3_OP_UNINTERPRETED:
        if z3.is_bool(other_var) and not z3.is_bool(var):
            infered_types[str(var)] = 'boolean'
            return z3.Bool(str(var)), other_var
        if z3.is_string(other_var) and not z3.is_string(var):
            # NOTE(review): as_string() presumably expects other_var to be a string
            # literal; confirm that symbolic strings cannot reach this line.
            if other_var.as_string() == '':
                # we probably dont want to coerce in this specific case as this is merely a non empty check
                if z3.is_bool(var):
                    return var, z3.BoolVal(False)
                if z3.is_int(var):
                    return var, z3.IntVal(0)
            else:
                infered_types[str(var)] = 'string'
                return z3.String(str(var)), other_var
        if z3.is_int(other_var) and not z3.is_int(var):
            infered_types[str(var)] = 'number'
            return z3.Int(str(var)), other_var
    # NOTE(review): this condition is identical to the `if` above, so this branch is
    # unreachable. Possibly `other_var.decl()` was intended; confirm before changing.
    elif var.decl().kind() == z3.Z3_OP_UNINTERPRETED:
        if z3.is_bool(var):
            infered_types[str(var)] = 'boolean'
        if z3.is_string(var):
            infered_types[str(var)] = 'string'
        if z3.is_int(var):
            infered_types[str(var)] = 'number'
    else:
        # var is not an uninterpreted constant here, so coerce other_var to var's sort instead
        if z3.is_string(var) and z3.is_int_value(other_var):
            other_var = z3.StringVal(str(other_var))
        if z3.is_arith(var) and z3.is_string(other_var):
            other_var = z3.IntVal(int(other_var.as_string()))
    return var, other_var
def s_add_transition_to(self, w):
    """Constrain the stack for ``w`` to be one transition away from the stack for ``w[:-1]``.

    The stack for ``w[:-1]`` is split into a prefix plus a one-symbol top. The
    stack for ``w`` must be that prefix followed by ``self.d(top, last char of w)``.
    The pushed sequence is at most ``self.limitL`` long, with every symbol in
    ``[0, self.limitS)``.

    @pre len(w)>0
    """
    n = self.gennum()
    prefix = Sequence("pre%d" % n)
    top = Sequence("a%d" % n)
    self.s.add(z3.Length(top) == 1)
    self.s.add(z3.Concat(prefix, top) == self.stackvars[w[:-1]])

    pushed = self.d(top, z3.StringVal(w[-1:]))
    self.s.add(z3.Length(pushed) <= self.limitL)
    for k in range(self.limitL):
        symbol_ok = z3.And(pushed[k] < z3.IntVal(self.limitS), pushed[k] >= 0)
        self.s.add(z3.Implies(z3.Length(pushed) > k, symbol_ok))
    self.s.add(z3.Concat(prefix, pushed) == self.stackvars[w])
def reverse(s: z3.SeqRef) -> z3.SeqRef:
    """Return a Z3 term equal to the reverse of the string ``s``.

    Uses an accumulator-style recursive helper:
    ``rev_acc(r, acc) = acc`` if ``r`` is empty, else ``rev_acc(tail(r), head(r) ++ acc)``.
    Any string term is accepted for ``s``, not only a constant.
    """
    ctx = s.ctx
    string_sort = z3.StringSort(ctx)
    empty = z3.StringVal("", ctx)
    # Fresh bound variables: RecAddDefinition needs uninterpreted constants as
    # parameters, and the caller's `s` may be a compound term.
    rest = z3.FreshConst(string_sort, 'rest')
    acc = z3.FreshConst(string_sort, 'acc')
    tail_rev = z3.RecFunction('reverse', string_sort, string_sort, string_sort)
    z3.RecAddDefinition(
        tail_rev, [rest, acc],
        z3.If(rest == empty, acc, tail_rev(tail(rest), z3.Concat(head(rest), acc))))

    arg = z3.FreshConst(string_sort, 's')
    rev = z3.RecFunction('reverse', string_sort, string_sort)
    z3.RecAddDefinition(rev, arg, tail_rev(arg, empty))
    return rev(s)
def distinguish2(self, r1, r2):
    """Find a string accepted by exactly one of the regex entries ``r1`` and ``r2``.

    Layout of the entries, inferred from the indexing below (TODO confirm against
    callers):
        ``[0]``    the regex, given to ``self._toz3.eval`` and ``self._printer.eval``;
        ``[2][0]`` a list of extra conditions;
        ``[2][1]`` the captures used when printing.

    When either entry has conditions, candidate strings from Z3 are re-checked
    with Python's ``re.fullmatch`` plus ``check_conditions`` until one passes.

    Returns:
        ``(dist_input, [accepting], [rejecting], [])``, where ``accepting`` is the
        entry whose Z3 regex accepts ``dist_input`` in the model, or
        ``(None, None, None, None)`` if no such string is found.
    """
    solver = z3.Solver()
    solver.set('random_seed', 7)
    solver.set('sat.random_seed', 7)
    if use_derivatives:
        try:
            solver.set('smt.seq.use_derivatives', True)
            solver.check()
        except Exception:
            # Best effort: keep going without the option if this Z3 build rejects it.
            pass

    z3_r1 = self._toz3.eval(r1[0])
    z3_r2 = self._toz3.eval(r2[0])
    dist = z3.String("distinguishing")
    ro_1 = z3.Bool("ro_1")
    solver.add(ro_1 == z3.InRe(dist, z3_r1))
    ro_2 = z3.Bool("ro_2")
    solver.add(ro_2 == z3.InRe(dist, z3_r2))
    # Exactly one of the two regexes may accept `dist`.
    solver.add(ro_1 != ro_2)

    def ordered(dist_input):
        # Put the entry whose regex accepts dist_input (in the current model) first.
        if solver.model()[ro_1]:
            return dist_input, [r1], [r2], []
        return dist_input, [r2], [r1], []

    if solver.check() != z3.sat:
        return None, None, None, None

    if len(r1[2][0]) == 0 and len(r2[2][0]) == 0:
        return ordered(solver.model()[dist].as_string())

    # Find dist_input that respects conditions
    r1_str = self._printer.eval(r1[0], captures=r1[2][1])
    r1_conditions = [" ".join(map(str, c)) for c in r1[2][0]]
    r2_str = self._printer.eval(r2[0], captures=r2[2][1])
    r2_conditions = [" ".join(map(str, c)) for c in r2[2][0]]
    while True:
        dist_input = solver.model()[dist].as_string()
        match = re.fullmatch(r1_str, dist_input)
        if match is not None and check_conditions(r1_conditions, match):
            break
        match = re.fullmatch(r2_str, dist_input)
        if match is not None and check_conditions(r2_conditions, match):
            break
        # Exclude this candidate and ask the solver for another one.
        solver.add(dist != z3.StringVal(dist_input))
        if solver.check() != z3.sat:
            return None, None, None, None
    return ordered(dist_input)
def s_isFinalConfiguration(self, c):
    """Return the Z3 condition that configuration ``c`` has length <= 1 and satisfies ``self.Qf``."""
    config = z3.StringVal(c)
    short_enough = z3.Length(config) <= 1
    return z3.And(short_enough, self.Qf(config))
def construct_from_z3_model(self, m, d, Qf, alphabet):
    """Rebuild the transition table and final configurations from a Z3 model.

    Starting from stack symbol 0, this evaluates ``d(<symbol>, <byte>)`` in ``m``
    for every byte in ``alphabet``. The result is stored in ``self.D[symbol][byte]``
    and any new symbols on the right-hand side are explored too. Afterwards, the
    empty configuration and every one-symbol configuration seen are added to
    ``self.QF`` when ``Qf`` holds for them in the model.

    Sets ``self.D``, ``self.QF``, ``self.productive`` (to None) and ``self.symbols``.
    """
    def flatten(seq):
        # Flatten a (possibly nested) Z3 sequence term into a List of int symbols.
        out = List()
        for child in seq.children():
            if isinstance(child, z3.SeqRef):
                out += flatten(child)
            else:
                out += List([child.as_long()])
        return out

    pending = [0]
    seen = {0}
    print("Extracting tables")
    self.D = {}
    self.QF = set()
    self.productive = None
    print("m[d] = %s" % m[d])
    print("m[qf] = %s" % m[Qf])
    symbols = {0}

    while pending:
        state = pending.pop()
        config = z3.Unit(z3.IntVal(state))
        for a in alphabet:  # range(0, 256):
            evaluated = m.evaluate(d(config, z3.StringVal(bytes([a]))),
                                   model_completion=True)
            rhs = flatten(evaluated)
            symbols.update(rhs)
            self.D.setdefault(state, {})[a] = rhs
            for sym in rhs:
                if sym not in seen:
                    seen.add(sym)
                    pending.append(sym)

    if m.evaluate(Qf(z3.Empty(z3.SeqSort(z3.IntSort())))):
        self.QF.add(List([]))
    print("(stack/q) symbols encountered: %s" % symbols)
    for sym in symbols:
        if m.evaluate(Qf(z3.Unit(z3.IntVal(sym)))):
            self.QF.add(List([sym]))
    self.symbols = symbols
else: return (result, None)
# NOTE(review): the `else` above is the tail of a definition whose beginning is not
# visible in this chunk; it is left unchanged.

# Module-level Z3 variables describing the URL under analysis.
is_https = z3.Bool('is_https')
is_http = z3.Bool('is_http')
proto = z3.String('proto')
proto_delimiter = z3.String('proto_delimiter')
domain = z3.String('domain')
fqdn = z3.String('fqdn')
root_domain = z3.String('root_domain')
base_url = proto + '://' + fqdn
unknown_string = z3.String('solution')

# Maps each one-byte Z3 string literal back to its byte value (used by z3_str_to_bytes).
z3_byte_to_ord = {z3.StringVal(bytes([ii])): ii for ii in range(256)}


def z3_str_to_bytes(zs: z3.SeqRef) -> bytes:
    """Convert a concrete Z3 string (e.g. a model value) to ``bytes``, one character at a time.

    Each character is looked up in ``z3_byte_to_ord``, so characters outside that
    table raise KeyError.
    """
    # NOTE(review): relies on zs[ii] simplifying to a one-character string literal;
    # this depends on the z3 version's SeqRef indexing semantics - confirm.
    ll = z3.simplify(z3.Length(zs)).as_long()
    return bytes([z3_byte_to_ord[z3.simplify(zs[ii])] for ii in range(ll)])


def z3_str_to_str(zs: z3.SeqRef) -> str:
    """Convert a concrete Z3 string to ``str`` by decoding its bytes as UTF-8."""
    return str(z3_str_to_bytes(zs), encoding='utf8')


def public_vars(model) -> dict:
    """Return the model's assignments whose declaration names do not start with ``_``."""
    return {k:model[k] for k in model.decls() if not k.name().startswith('_')}


WILDCARD_MARKER = '<'
OUTPUT_WILDCARD_MARKER = '☠'  # skull and crossbones (U+2620)
def wildcard_trace(solver, symbols: Dict[int, List[str]], use_priming=True) -> Tuple[z3.CheckSatResult, Dict]:
    """Return the result of the attack (sat means attack was successful) and associated data.

    If the attack was successful, associated data includes what attack was executed,
    witness information, and the solution found. The dictionary keys are "strategy",
    "solution", "witness". In "solution" and "witness", WILDCARD_MARKER is replaced
    with OUTPUT_WILDCARD_MARKER.

    Otherwise, associated data includes the attack that was executed, plus debug info
    under the key "debug_info" (currently always None). If the priming check, or the
    DEBUG check, is not sat, the associated data is None instead.

    Args:
        solver: a Z3 solver that already holds the regex constraints; it is mutated here
            (constraints are added, or it is reset when priming fails).
        symbols: passed to ``concretizations`` together with the solution bytes
            (presumably describing how wildcard markers can be concretized - confirm).
        use_priming: first run an easy satisfiability check on the same solver.
    """
    if use_priming:
        # prime the solver with an easy question.
        # In `test_optional_dollar` it seems to give several folds speedup, if you used the same
        # solver for `test_dot` previously.
        prime_result, prime_model = check_formulae(solver, z3.Not(RegexStringExpr.ignore_wildcards))
        logger.info('check %s', prime_result)
        if prime_result == z3.sat:
            logger.debug(public_vars(prime_model))
        else:
            solver.reset()
            return prime_result, None

    # The solution must be "<proto><delimiter><fqdn>" optionally followed by "/...".
    # The delimiter is "//" or two wildcard markers, and neither proto nor fqdn may be
    # empty or contain "/".
    base = proto + proto_delimiter + fqdn
    solver.add(
        z3.Or(proto_delimiter == z3.StringVal('//'), proto_delimiter == z3.StringVal(WILDCARD_MARKER * 2)),
        z3.Xor(z3.PrefixOf(base + '/', unknown_string), base == unknown_string),
        z3.Not(z3.Contains(proto, '/')),
        z3.Not(z3.Contains(fqdn, '/')),
        z3.Length(proto) > 0,
        z3.Length(fqdn) > 0,
    )

    if DEBUG:
        #debug_result, debug_model = check_formulae(solver, unknown_string == debug_model[unknown_string])
        #logger.debug(debug_model)
        debug_result = solver.check()
        logger.debug(debug_result)
        if debug_result == z3.sat:
            debug_model = solver.model()
            logger.debug(public_vars(debug_model))
            # NOTE(review): `ans` computed here is not used afterwards in this branch.
            ans = z3.simplify(debug_model[proto] + debug_model[proto_delimiter] + debug_model[fqdn])
        else:
            return debug_result, None
        #debug_result, debug_model = check_formulae(solver, (base_url != ans))

    # Attack: ask for a solution in which the wildcard marker appears in proto or fqdn.
    result, model = check_formulae(solver, z3.Not(RegexStringExpr.ignore_wildcards),
                                   z3.Contains(proto + fqdn, z3.StringVal(WILDCARD_MARKER)))
    if result == z3.sat:
        # Concretize a model string: convert it to bytes and take the first concretization.
        _conc1 = lambda zs: tz.first(concretizations(z3_str_to_bytes(zs), symbols))
        logger.info(public_vars(model))
        ans = z3.simplify(model[proto] + model[proto_delimiter] + model[fqdn])
        return result, {
            'solution': _conc1(model[unknown_string]).replace(WILDCARD_MARKER, OUTPUT_WILDCARD_MARKER),
            'strategy': 'wildcard_trace',
            'witness': _conc1(ans).replace(WILDCARD_MARKER, OUTPUT_WILDCARD_MARKER)}
    else:
        return result, {'strategy': 'wildcard_trace', 'debug_info': None}
def SNewLine(ctx: z3.Context) -> z3.SeqRef:
    """Return the one-character Z3 string literal for a newline, created in ``ctx``."""
    newline = '\n'
    return z3.StringVal(newline, ctx)
def is_whitespace(c: z3.SeqRef) -> z3.BoolRef:
    """Return a Z3 condition that holds when ``c`` equals a space, a newline or a tab."""
    ctx = c.ctx
    candidates = [z3.StringVal(w, ctx) for w in (" ", "\n", "\t")]
    return z3.Or(*[c == w for w in candidates], ctx)
def generate_constraints(data):
    """Build Z3 constraints that explain inflected word forms as prefix + stem + suffix.

    Each row of ``data`` is one example with ``I = len(data[0])`` forms. The
    prefix/suffix variables for column ``i`` are shared by all examples, while
    each example gets its own stem. For each column ``i``:

    - ``prefix[i] + stem + suffix[i]`` is written as ``unch1 + ch + unch2``;
    - the observed form ``convert_ipa(example[i], m)`` must equal
      ``unch1 + var + unch2`` with ``len(var) <= 1``, so it may differ by replacing
      the substring ``ch`` with ``var``. Empty forms are skipped;
    - ``sc`` is 0 if ``ch == var`` and 1 otherwise.

    Returns:
        ``(constraints, cost_constraint, column_cost)``, where ``cost_constraint``
        sums all ``sc`` over all examples and ``column_cost[i]`` sums them per
        column. Empty ``data`` gives ``([], 0, [])``.
    """
    if not data:
        # Nothing to explain; previously this crashed on data[0].
        return [], 0, []
    num_inflections = len(data[0])
    letters = [chr(ord('A') + i) for i in range(num_inflections)]
    constraints = []
    cost_constraint = 0
    column_cost = [0] * num_inflections
    for count, example in enumerate(data):
        tag = str(count)
        suffixes = [z3.String('suf' + c) for c in letters]
        prefixes = [z3.String('pre' + c) for c in letters]
        stem = z3.String('stem' + tag)
        # 1 is associated with the prefix, 2 with the suffix
        unch1 = [z3.String('unch1' + tag + c) for c in letters]
        unch2 = [z3.String('unch2' + tag + c) for c in letters]
        ch = [z3.String('ch' + tag + c) for c in letters]
        var = [z3.String('var' + tag + c) for c in letters]
        sc = [z3.Int('sc' + tag + c) for c in letters]
        lc = z3.Int('l' + tag)

        for v in var:
            constraints.append(z3.Length(v) <= 1)
        for i in range(num_inflections):
            constraints.append(z3.Concat(prefixes[i], stem, suffixes[i]) == z3.Concat(unch1[i], ch[i], unch2[i]))
        for i in range(num_inflections):
            if len(example[i]) == 0:
                continue
            constraints.append(z3.StringVal(convert_ipa(example[i], m)) == z3.Concat(unch1[i], var[i], unch2[i]))
        constraints.append(z3.Length(stem) == lc)
        for i in range(num_inflections):
            constraints.append(z3.If(ch[i] == var[i], 0, 1) == sc[i])
        for i in range(num_inflections):
            constraints.append(sc[i] <= 1)

        cost_constraint = cost_constraint + sum(sc)
        for i in range(num_inflections):
            column_cost[i] = column_cost[i] + sc[i]
    return constraints, cost_constraint, column_cost
def eval_regex(re_string, flags, test_string, offset, endpos=None):
    """Match ``re_string`` against ``test_string`` through the symbolic regex engine.

    Uses the current ``context_statespace()``. A fresh symbolic string is
    constrained to equal ``test_string`` and then matched with ``_match_pattern``.
    """
    compiled = re.compile(re_string, flags)
    space = context_statespace()
    symbolic = SymbolicStr("symstr" + space.uniq())
    space.add(symbolic.var == z3.StringVal(test_string))
    return _match_pattern(compiled, re_string, symbolic, offset, endpos)
def _internal_match_patterns(space: StateSpace, top_patterns: Any, flags: int, smtstr: z3.ExprRef, offset: int) -> Optional[_Match]:
    """
    Match the parsed pattern list ``top_patterns`` against ``smtstr`` starting at ``offset``.

    Patterns are consumed one element at a time, recursing on the rest. Symbolic
    choices go through ``space.smt_fork`` (see ``fork_on``), so the match found
    depends on the state space's branch decisions. Only single-character
    patterns, MAX_REPEAT, BRANCH, end-of-string AT anchors, SUBPATTERN groups and
    the internal group-end marker are handled; anything else raises ReUnhandled.

    Returns a ``_Match`` for the match found, or None if there is no match.

    >>> from crosshair.statespace import SimpleStateSpace
    >>> import sre_parse
    >>> smtstr = z3.String('smtstr')
    >>> space = SimpleStateSpace()
    >>> space.add(smtstr == z3.StringVal('aabb'))
    >>> _internal_match_patterns(space, sre_parse.parse('a+'), 0, smtstr, 0).span()
    (0, 2)
    >>> _internal_match_patterns(space, sre_parse.parse('ab'), 0, smtstr, 1).span()
    (1, 3)
    """
    # The part of smtstr not yet consumed; patterns below are matched at its start.
    matchstr = z3.SubString(smtstr, offset, z3.Length(smtstr)) if offset > 0 else smtstr
    if len(top_patterns) == 0:
        return _Match([(None, offset, offset)])
    pattern = top_patterns[0]

    # Match the remaining patterns after `prefix` and join the two matches.
    def continue_matching(prefix):
        suffix = _internal_match_patterns(space, top_patterns[1:], flags, smtstr, prefix.end())
        if suffix is None:
            return None
        return prefix._add_match(suffix)

    # TODO: using a typed internal function triggers __hash__es inside the typing module.
    # Seems like this casues nondeterminism due to a global LRU cache used by the typing module.
    # Branch on the symbolic condition `expr`; on the true branch consume `sz` characters.
    def fork_on(expr, sz):
        if space.smt_fork(expr):
            return continue_matching(_Match([(None, offset, offset + sz)]))
        else:
            return None

    # Handle simple single-character expressions using z3's built-in capabilities.
    z3_re = single_char_regex(pattern, flags)
    if z3_re is not None:
        ch = z3.SubString(matchstr, 0, 1)
        return fork_on(z3.InRe(ch, z3_re), 1)

    (op, arg) = pattern
    if op is MAX_REPEAT:
        (min_repeat, max_repeat, subpattern) = arg
        if max_repeat < min_repeat:
            return None
        reps = 0
        overall_match = _Match([(None, offset, offset)])
        # Consume the mandatory min_repeat repetitions first.
        while reps < min_repeat:
            submatch = _internal_match_patterns(space, subpattern, flags, smtstr, overall_match.end())
            if submatch is None:
                return None
            overall_match = overall_match._add_match(submatch)
            reps += 1
        if max_repeat != MAXREPEAT and reps >= max_repeat:
            return continue_matching(overall_match)
        # Try one optional repetition; without it, continue with the rest of the pattern.
        submatch = _internal_match_patterns(space, subpattern, flags, smtstr, overall_match.end())
        if submatch is None:
            return continue_matching(overall_match)
        # we matched; try to be greedy first, and fall back to `submatch` as the last consumed match
        greedy_remainder = _patt_replace(
            top_patterns,
            arg,
            (
                1,
                max_repeat if max_repeat == MAXREPEAT else max_repeat - (min_repeat + 1),
                subpattern,
            ),
        )
        greedy_match = _internal_match_patterns(space, greedy_remainder, flags, smtstr, submatch.end())
        if greedy_match is not None:
            return overall_match._add_match(submatch)._add_match(greedy_match)
        else:
            match_with_optional = continue_matching(
                overall_match._add_match(submatch))
            if match_with_optional is not None:
                return match_with_optional
            else:
                return continue_matching(overall_match)
    elif op is BRANCH and arg[0] is None:
        # NOTE: order matters - earlier branches are more greedily matched than later branches.
        branches = arg[1]
        first_path = list(branches[0]) + list(top_patterns)[1:]
        submatch = _internal_match_patterns(space, first_path, flags, smtstr, offset)
        # _patt_replace(top_patterns, pattern, branches[0])
        if submatch is not None:
            return submatch
        if len(branches) <= 1:
            return None
        else:
            # Retry with the first alternative removed from the branch.
            return _internal_match_patterns(
                space,
                _patt_replace(top_patterns, branches, branches[1:]),
                flags,
                smtstr,
                offset,
            )
    elif op is AT:
        # Only end-of-string anchors are supported; other AT codes fall through to
        # the ReUnhandled at the bottom.
        if arg in (AT_END, AT_END_STRING):
            # NOTE(review): this message says AT_END_STRING although the check is for AT_END.
            if arg is AT_END and re.MULTILINE & flags:
                raise ReUnhandled("Multiline match with AT_END_STRING")
            return fork_on(matchstr == z3.StringVal(""), 0)
    elif op is SUBPATTERN:
        (groupnum, _a, _b, subpatterns) = arg
        if (_a, _b) != (0, 0):
            raise ReUnhandled("unsupported subpattern args")
        # Inline the group's patterns, followed by a marker that records the group span.
        new_top = (list(subpatterns) + [(_END_GROUP_MARKER, (groupnum, offset))] + list(top_patterns)[1:])
        return _internal_match_patterns(space, new_top, flags, smtstr, offset)
    elif op is _END_GROUP_MARKER:
        (group_num, begin) = arg
        match = continue_matching(_Match([(None, offset, offset)]))
        if match is None:
            return None
        # Pad the group list with None so that index group_num exists, then record the span.
        while len(match._groups) <= group_num:
            match._groups.append(None)
        match._groups[group_num] = (None, begin, offset)
        return match
    raise ReUnhandled(op)