コード例 #1
0
ファイル: sketch_synthesizer.py プロジェクト: Marghrid/FOREST
    def fill_smt(self, sketch):
        """Fill the holes of ``sketch`` by solving one SMT query.

        Every combination of hole values from ``self.get_domains(sketch)`` is
        written into a deep copy of the sketch and translated to a z3 regex.
        Each copy gets a fresh Boolean indicator ``m``. ``m`` is asserted to
        hold exactly when that regex accepts every valid example and rejects
        every invalid example. The solver must then make at least one indicator
        true. Enumeration stops early once ``self.configuration.die`` is set.

        Returns:
            When the query is SAT, the concrete programs whose indicator is
            true in the model. Otherwise (UNSAT or UNKNOWN), an empty list.

        Side effects:
            Resets the global ``m_counter`` (presumably the counter used by
            ``_get_new_m`` to name indicators). Updates the sat/unsat
            call counters and timing statistics on ``self``.
        """
        global m_counter
        m_counter = 0

        domains = self.get_domains(sketch)

        z3_solver = z3.Solver()
        try:
            # Best effort: enable derivative-based regex reasoning. check()
            # makes z3 validate the option; if z3 rejects it (raising
            # Z3Exception) we simply continue without it.
            z3_solver.set('smt.seq.use_derivatives', True)
            z3_solver.check()
        except z3.Z3Exception:
            pass

        start = time.time()

        m_vars = {}
        for values in itertools.product(*domains):
            if self.configuration.die:
                break
            concrete = deepcopy(sketch)
            holes = self.traverse_and_save_holes(concrete)
            assert len(values) == len(holes)
            for hole, value in zip(holes, values):
                hole.data = value

            z3re = self.to_z3.eval(concrete)
            m = z3.Bool(_get_new_m())
            m_vars[m] = concrete

            # m <=> (accepts all valid examples and rejects all invalid ones)
            big_and = [z3.InRe(valid[0], z3re) for valid in self.valid]
            big_and += [z3.Not(z3.InRe(invalid[0], z3re))
                        for invalid in self.invalid]
            z3_solver.add(m == z3.And(big_and))

        # At least one candidate must be consistent with every example.
        z3_solver.add(z3.Or(*m_vars.keys()))

        time_encoding = time.time() - start

        start = time.time()
        res = z3_solver.check()
        if res == z3.sat:
            self.count_sat_calls += 1
            self.time_sat_calls += time.time() - start
            self.time_sat_encoding += time_encoding
            # Fetch the model once rather than once per indicator variable.
            model = z3_solver.model()
            ret_val = [candidate for indicator, candidate in m_vars.items()
                       if model[indicator]]
            assert len(ret_val) > 0
            return ret_val
        # UNKNOWN is accounted for exactly like UNSAT.
        self.count_unsat_calls += 1
        self.time_unsat_calls += time.time() - start
        self.time_unsat_encoding += time_encoding
        return []
コード例 #2
0
def string_match(x, args):
    """Return a z3 Boolean expressing that ``x`` matches ``args[0]``.

    When ``args[0]`` is a dict, it is a constraint description built from
    tainted data. It is translated with ``createZ3ExpressionFromConstraint``,
    and the match is approximated as "``x`` contains that value". Otherwise
    ``args[0]`` is converted with ``regex_to_z3``, and ``x`` must belong to
    the resulting regex language.
    """
    pattern = args[0]
    if isinstance(pattern, dict):
        # Parts of the tainted data were used as the regex. This is a direct
        # flow, so a substring test is a sufficient encoding.
        val = createZ3ExpressionFromConstraint(pattern, {})
        return z3.IndexOf(x, val, 0) > -1
    return z3.InRe(x, regex_to_z3(pattern))
コード例 #3
0
    def multi_distinguish(self, regexes):
        """Find one input string on which several candidate regexes disagree.

        Distinguishing more than 4 regexes at once takes too long, so at most
        4 take part in the SMT query. When more are given, 4 are chosen by a
        fixed-seed shuffle and the rest are returned untouched in ``others``.
        The query requires the string to separate at least one pair of the
        selected regexes. It softly prefers strings that separate as many
        pairs as possible (``z3.Optimize.add_soft``).

        Args:
            regexes: candidate regex programs. The list is not modified.

        Returns:
            ``(dist_input, keep_if_valid, keep_if_invalid, others)``.
            ``keep_if_valid`` holds the selected regexes that accept
            ``dist_input``; ``keep_if_invalid`` holds those that reject it.
            Returns ``(None, None, None, None)`` if the query is not SAT.
        """
        start = time.time()
        if len(regexes) <= 4:
            selected_regexes = regexes
            others = []
        else:
            # Shuffle a copy with a private RNG seeded with the same value.
            # The selection stays reproducible, but neither the caller's list
            # nor the global ``random`` state is disturbed.
            shuffled = list(regexes)
            random.Random('regex').shuffle(shuffled)
            selected_regexes = shuffled[:4]
            others = shuffled[4:]
        solver = z3.Optimize()

        z3_regexes = [self._toz3.eval(regex) for regex in selected_regexes]

        dist = z3.String("distinguishing")

        # ro_z3[i] is true iff dist matches selected_regexes[i].
        ro_z3 = []
        for i, z3_regex in enumerate(z3_regexes):
            ro = z3.Bool(f"ro_{i}")
            ro_z3.append(ro)
            solver.add(ro == z3.InRe(dist, z3_regex))

        big_or = []
        for ro_i, ro_j in combinations(ro_z3, 2):
            differs = z3.Xor(ro_i, ro_j)
            big_or.append(differs)
            solver.add_soft(differs)
        solver.add(z3.Or(big_or))  # at least one pair is distinguished

        if solver.check() != z3.sat:
            return None, None, None, None

        print("took", round(time.time() - start, 2), "seconds")
        model = solver.model()
        # as_string() yields the string value itself. The old
        # str(...).strip('"') also removed quote characters that were part
        # of the witness.
        dist_input = model[dist].as_string()
        keep_if_valid = []
        keep_if_invalid = []
        for regex, ro in zip(selected_regexes, ro_z3):
            if model[ro]:
                keep_if_valid.append(regex)
            else:
                keep_if_invalid.append(regex)
        return dist_input, keep_if_valid, keep_if_invalid, others
コード例 #4
0
    def has_failed_examples(self, regex: Node):
        """
        Report whether ``regex`` disagrees with at least one stored example.

        With ``self.use_smt`` set, the regex is translated to z3. The method
        asks whether it can simultaneously accept every positive example's
        ``input[0]`` and reject every negative one. An UNSAT answer means
        some example fails. Without SMT, the regex is interpreted to a Python
        pattern, and each example is checked with ``self._match`` against its
        expected ``output``.
        """
        if self.use_smt:
            z3_regex = self._to_z3.eval(regex)
            constraints = [
                z3.InRe(example.input[0], z3_regex) if example.output
                else z3.Not(z3.InRe(example.input[0], z3_regex))
                for example in self._examples
            ]
            solver = z3.Solver()
            solver.add(z3.And(constraints))
            return solver.check() == z3.unsat

        compiled = re.compile(self._interpreter.eval(regex))
        return any(self._match(compiled, example.input) != example.output
                   for example in self._examples)
コード例 #5
0
def string_search(x, args):
    """Encode a regex search for ``args[0]`` inside ``x``; return a z3 index.

    A fresh string variable ``new_val`` is constrained to lie in the regex
    with its ``^``/``$`` anchors removed. The anchors are encoded separately:

    * both anchors: ``x == new_val``, and the result is ``IntVal(0)``;
    * ``^`` only: ``new_val`` is a prefix of ``x``;
    * ``$`` only: ``new_val`` is a suffix of ``x``.

    Otherwise the result is ``IndexOf(x, new_val, 0)``. All constraints are
    appended to ``GLOBAL_CONSTRAINTS`` as a side effect.
    """
    new_val = z3.String('__ignore_search_helper_' + randomString())
    regex = args[0]

    starts_with = regex.startswith('^')
    if starts_with:
        regex = regex[1:]
    # A trailing '$' is an anchor only if it is not escaped, i.e. it is
    # preceded by an even number of backslashes.
    body = regex[:-1]
    ends_with = (regex.endswith('$')
                 and (len(body) - len(body.rstrip('\\'))) % 2 == 0)
    if ends_with:
        regex = body

    # Encode the anchor-free pattern; the anchors are expressed below.
    GLOBAL_CONSTRAINTS.append(z3.InRe(new_val, regex_to_z3(regex)))

    if starts_with and ends_with:
        # A fully anchored match covers the whole string, so the index is 0.
        GLOBAL_CONSTRAINTS.append(x == new_val)
        return z3.IntVal(0)
    if starts_with:
        GLOBAL_CONSTRAINTS.append(z3.PrefixOf(new_val, x))
    elif ends_with:
        GLOBAL_CONSTRAINTS.append(z3.SuffixOf(new_val, x))
    return z3.IndexOf(x, new_val, 0)
コード例 #6
0
def _internal_match_patterns(space: StateSpace, top_patterns: Any, flags: int,
                             smtstr: z3.ExprRef,
                             offset: int) -> Optional[_Match]:
    """
    Symbolically match parsed regex items against ``smtstr`` from ``offset``.

    ``top_patterns`` is a sequence of ``sre_parse`` items (``(op, arg)``
    pairs). They are matched left to right, and each consumes part of the
    string that remains after ``offset``. When a step depends on the
    symbolic string, the boolean returned by ``space.smt_fork`` decides
    which branch is followed.

    Returns a ``_Match`` spanning from ``offset`` to the end of the consumed
    text, including any recorded group spans. Returns ``None`` if the items
    cannot match on the chosen branch. Raises ``ReUnhandled`` for constructs
    this encoder does not model.

    >>> from crosshair.statespace import SimpleStateSpace
    >>> import sre_parse
    >>> smtstr = z3.String('smtstr')
    >>> space = SimpleStateSpace()
    >>> space.add(smtstr == z3.StringVal('aabb'))
    >>> _internal_match_patterns(space, sre_parse.parse('a+'), 0, smtstr, 0).span()
    (0, 2)
    >>> _internal_match_patterns(space, sre_parse.parse('ab'), 0, smtstr, 1).span()
    (1, 3)
    """
    # The part of the string still to be matched (starting at `offset`).
    matchstr = z3.SubString(smtstr, offset,
                            z3.Length(smtstr)) if offset > 0 else smtstr
    if len(top_patterns) == 0:
        return _Match([(None, offset, offset)])
    pattern = top_patterns[0]

    def continue_matching(prefix):
        # Match the remaining items after `prefix` and append them to it.
        suffix = _internal_match_patterns(space, top_patterns[1:], flags,
                                          smtstr, prefix.end())
        if suffix is None:
            return None
        return prefix._add_match(suffix)

    # TODO: using a typed internal function triggers __hash__es inside the typing module.
    # Seems like this causes nondeterminism due to a global LRU cache used by the typing module.
    def fork_on(expr, sz):
        # If smt_fork says `expr` holds, consume `sz` characters and continue.
        if space.smt_fork(expr):
            return continue_matching(_Match([(None, offset, offset + sz)]))
        else:
            return None

    # Handle simple single-character expressions using z3's built-in capabilities.
    z3_re = single_char_regex(pattern, flags)
    if z3_re is not None:
        ch = z3.SubString(matchstr, 0, 1)
        return fork_on(z3.InRe(ch, z3_re), 1)

    (op, arg) = pattern
    if op is MAX_REPEAT:
        (min_repeat, max_repeat, subpattern) = arg
        if max_repeat < min_repeat:
            return None
        reps = 0
        overall_match = _Match([(None, offset, offset)])
        # The first `min_repeat` repetitions are mandatory.
        while reps < min_repeat:
            submatch = _internal_match_patterns(space, subpattern, flags,
                                                smtstr, overall_match.end())
            if submatch is None:
                return None
            overall_match = overall_match._add_match(submatch)
            reps += 1
        if max_repeat != MAXREPEAT and reps >= max_repeat:
            return continue_matching(overall_match)
        submatch = _internal_match_patterns(space, subpattern, flags, smtstr,
                                            overall_match.end())
        if submatch is None:
            return continue_matching(overall_match)
        # we matched; try to be greedy first, and fall back to `submatch` as the last consumed match
        # The repeat is replaced by one that needs at least one more
        # repetition, with the remaining upper bound reduced accordingly.
        greedy_remainder = _patt_replace(
            top_patterns,
            arg,
            (
                1,
                max_repeat if max_repeat == MAXREPEAT else max_repeat -
                (min_repeat + 1),
                subpattern,
            ),
        )
        greedy_match = _internal_match_patterns(space, greedy_remainder, flags,
                                                smtstr, submatch.end())
        if greedy_match is not None:
            return overall_match._add_match(submatch)._add_match(greedy_match)
        else:
            match_with_optional = continue_matching(
                overall_match._add_match(submatch))
            if match_with_optional is not None:
                return match_with_optional
            else:
                return continue_matching(overall_match)
    elif op is BRANCH and arg[0] is None:
        # NOTE: order matters - earlier branches are more greedily matched than later branches.
        branches = arg[1]
        # Try the first alternative followed by everything after the BRANCH.
        first_path = list(branches[0]) + list(top_patterns)[1:]
        submatch = _internal_match_patterns(space, first_path, flags, smtstr,
                                            offset)
        if submatch is not None:
            return submatch
        if len(branches) <= 1:
            return None
        else:
            # Retry with the first alternative removed from the BRANCH.
            return _internal_match_patterns(
                space,
                _patt_replace(top_patterns, branches, branches[1:]),
                flags,
                smtstr,
                offset,
            )
    elif op is AT:
        # Only end-of-string anchors are modelled; others fall through to
        # the ReUnhandled at the bottom.
        if arg in (AT_END, AT_END_STRING):
            # `$` under MULTILINE can also match before a newline, which
            # this encoding does not model.
            if arg is AT_END and re.MULTILINE & flags:
                raise ReUnhandled("Multiline match with AT_END")
            return fork_on(matchstr == z3.StringVal(""), 0)
    elif op is SUBPATTERN:
        (groupnum, _a, _b, subpatterns) = arg
        if (_a, _b) != (0, 0):
            raise ReUnhandled("unsupported subpattern args")
        # Inline the group's items, followed by a marker that records the
        # group's span once they have been matched.
        new_top = (list(subpatterns) + [(_END_GROUP_MARKER,
                                         (groupnum, offset))] +
                   list(top_patterns)[1:])
        return _internal_match_patterns(space, new_top, flags, smtstr, offset)
    elif op is _END_GROUP_MARKER:
        (group_num, begin) = arg
        match = continue_matching(_Match([(None, offset, offset)]))
        if match is None:
            return None
        # Grow the group list as needed, then store this group's span.
        while len(match._groups) <= group_num:
            match._groups.append(None)
        match._groups[group_num] = (None, begin, offset)
        return match
    raise ReUnhandled(op)
コード例 #7
0
    def distinguish2(self, r1, r2):
        """Find a string accepted by exactly one of two candidate regexes.

        The code indexes ``r[0]`` as the regex program and
        ``r[2] == (conditions, captures)``. This is inferred from the
        indexing below; confirm against the producer.

        If neither candidate has conditions, the first z3 witness is used.
        Otherwise witnesses are re-sampled until one candidate accepts the
        string under Python's ``re.fullmatch`` and ``check_conditions``.

        Returns:
            ``(dist_input, [accepting], [rejecting], [])``, where acceptance
            is read from the z3 model. Returns ``(None, None, None, None)``
            if z3 cannot produce a (further) witness.
        """
        solver = z3.Solver()
        solver.set('random_seed', 7)
        solver.set('sat.random_seed', 7)
        if use_derivatives:
            try:
                # Best effort; if z3 rejects the option (raising
                # Z3Exception), solve without it.
                solver.set('smt.seq.use_derivatives', True)
                solver.check()
            except z3.Z3Exception:
                pass

        z3_r1 = self._toz3.eval(r1[0])
        z3_r2 = self._toz3.eval(r2[0])

        dist = z3.String("distinguishing")

        ro_1 = z3.Bool("ro_1")
        solver.add(ro_1 == z3.InRe(dist, z3_r1))
        ro_2 = z3.Bool("ro_2")
        solver.add(ro_2 == z3.InRe(dist, z3_r2))

        # Exactly one of the two regexes must accept `dist`.
        solver.add(ro_1 != ro_2)

        if solver.check() != z3.sat:
            return None, None, None, None

        model = solver.model()
        dist_input = model[dist].as_string()

        if r1[2][0] or r2[2][0]:
            # Find dist_input that respects conditions
            r1_str = self._printer.eval(r1[0], captures=r1[2][1])
            r1_conditions = [" ".join(map(str, c)) for c in r1[2][0]]
            r2_str = self._printer.eval(r2[0], captures=r2[2][1])
            r2_conditions = [" ".join(map(str, c)) for c in r2[2][0]]

            def satisfies(pattern, conditions, text):
                match = re.fullmatch(pattern, text)
                return match is not None and check_conditions(conditions,
                                                              match)

            while not (satisfies(r1_str, r1_conditions, dist_input)
                       or satisfies(r2_str, r2_conditions, dist_input)):
                # Block this witness and ask z3 for another one.
                solver.add(dist != z3.StringVal(dist_input))
                if solver.check() != z3.sat:
                    return None, None, None, None
                model = solver.model()
                dist_input = model[dist].as_string()

        if model[ro_1]:
            return dist_input, [r1], [r2], []
        return dist_input, [r2], [r1], []
コード例 #8
0
ファイル: sketch_synthesizer.py プロジェクト: Marghrid/FOREST
    def fill_hybrid(self, sketch):
        """Fill ``sketch`` with a time-limited SMT query, then brute force if needed.

        Every combination of hole values from ``self.get_domains(sketch)`` is
        written into a deep copy of the sketch and translated to a z3 regex.
        A fresh Boolean indicator is asserted to hold exactly when that regex
        accepts all valid examples and rejects all invalid ones, and at least
        one indicator must hold. Enumeration stops early once
        ``self.configuration.die`` is set. The check runs with a z3
        ``timeout`` of 100 (z3 interprets this in milliseconds). If z3
        answers UNKNOWN, ``self.fill_brute_force(sketch)`` produces the
        result instead.

        Returns:
            The list of consistent concrete programs (empty if there are none).
        """
        domains = self.get_domains(sketch)

        z3_solver = z3.Solver()
        try:
            # Best effort: enable derivative-based regex reasoning. If z3
            # rejects the option (raising Z3Exception), continue without it.
            z3_solver.set('smt.seq.use_derivatives', True)
            z3_solver.check()
        except z3.Z3Exception:
            pass

        # Note: `start` is not reset before check(), so the reported sat/unsat
        # times below include the encoding time.
        start = time.time()

        m_vars = {}
        for values in itertools.product(*domains):
            if self.configuration.die:
                break
            concrete = deepcopy(sketch)
            holes = self.traverse_and_save_holes(concrete)
            assert len(values) == len(holes)
            for hole, value in zip(holes, values):
                hole.data = value

            z3re = self.to_z3.eval(concrete)
            m = z3.Bool(_get_new_m())
            m_vars[m] = concrete

            # m <=> (accepts all valid examples and rejects all invalid ones)
            big_and = [z3.InRe(valid[0], z3re) for valid in self.valid]
            big_and += [z3.Not(z3.InRe(invalid[0], z3re))
                        for invalid in self.invalid]
            z3_solver.add(m == z3.And(big_and))

        # At least one candidate must be consistent with every example.
        z3_solver.add(z3.Or(*m_vars.keys()))

        z3_solver.set("timeout", 100)
        print("checking...")
        res = z3_solver.check()
        print("done.")

        if res == z3.sat:
            self.count_sat_calls += 1
            self.time_sat_calls += (time.time() - start)
            # Fetch the model once rather than once per indicator variable.
            model = z3_solver.model()
            ret_val = [candidate for indicator, candidate in m_vars.items()
                       if model[indicator]]
            assert len(ret_val) > 0
            return ret_val
        elif res == z3.unsat:
            self.count_unsat_calls += 1
            self.time_unsat_calls += (time.time() - start)
            return []
        elif res == z3.unknown:
            print("SMT unknown.")
            # Only the SMT attempt is timed; the brute-force fallback is not.
            stop_time = time.time()
            correct = self.fill_brute_force(sketch)
            if len(correct) > 0:
                self.time_sat_calls += (stop_time - start)
                self.count_smt_unknown_sat += 1
            else:
                self.time_unsat_calls += (stop_time - start)
                self.count_smt_unknown_unsat += 1

            return correct
        # logging uses %-style lazy formatting: an extra argument needs a
        # placeholder. Still return a list so callers can rely on the type.
        logger.error("Unknown Z3 response: %s", res)
        return []