예제 #1
0
    def __init__(self, valid_examples, invalid_examples, captured,
                 condition_invalid, dsl: TyrellSpec, ground_truth: str,
                 configuration: Configuration):
        """ Store examples and configuration and build the synthesis components.

        :param valid_examples: examples the regex must match.
        :param invalid_examples: examples the regex must reject.
        :param captured: expected capture-group strings for the valid examples.
        :param condition_invalid: examples the regex must match but the capture
            conditions must reject.
        :param dsl: the DSL specification.
        :param ground_truth: optional ground truth, a comma-separated string
            whose pieces starting with "$" are capture conditions and whose
            remaining pieces form the regex.
        :param configuration: synthesizer configuration.
        """
        self.max_before_distinguishing = 2  # 2 for conversational clarification
        self.valid = valid_examples
        self.invalid = invalid_examples
        self.captured = captured
        self.condition_invalid = condition_invalid
        self.dsl = dsl
        self.configuration = configuration
        if ground_truth is None:
            self.ground_truth_conditions = None
            self.ground_truth_regex = None
        else:
            # Splitting on ',' also splits commas inside the regex (e.g. 'a{1,3}');
            # the non-condition pieces are re-joined with ',' so the regex text
            # is preserved.
            pieces = ground_truth.split(',')
            self.ground_truth_conditions = [
                p.lstrip() for p in pieces if p.lstrip().startswith("$")]
            self.ground_truth_regex = ",".join(
                p for p in pieces if not p.lstrip().startswith("$"))

        if not configuration.pruning:
            logger.warning('Synthesizing without pruning the search space.')
        # If auto-interaction is enabled, the ground truth must be a valid regex.
        # Check for None first: len(None) would raise TypeError instead of
        # reporting the missing ground truth.
        if self.configuration.self_interact:
            assert self.ground_truth_regex is not None \
                and len(self.ground_truth_regex) > 0 \
                and is_regex(self.ground_truth_regex), \
                'self-interaction requires a valid ground-truth regex'

        # Initialize components
        self._printer = RegexInterpreter()  # Works like to_string
        self._distinguisher = RegexDistinguisher()
        # Condition-invalid examples must still be matched by the regex itself,
        # so the decider treats them as valid.
        self._decider = RegexDecider(interpreter=RegexInterpreter(),
                                     valid_examples=self.valid +
                                     self.condition_invalid,
                                     invalid_examples=self.invalid)

        # Capturer works like a synthesizer of capturing groups
        self._capturer = Capturer(self.valid, self.captured,
                                  self.condition_invalid,
                                  self.ground_truth_regex,
                                  self.ground_truth_conditions,
                                  self.configuration)
        self._node_counter = NodeCounter()

        # Subclass decides which enumerator to use
        self._enumerator = None

        # To store synthesized regexes and captures:
        self.solutions = []
        self.first_regex = None

        # counters and timers:
        self.indistinguishable = 0
        # Number of indistinguishable programs after which the synthesizer returns.
        self.max_indistinguishable = 3
        self.start_time = None
        self.last_print_time = time.time()
예제 #2
0
파일: capturer.py 프로젝트: Marghrid/FOREST
 def __init__(self, valid: List[List[str]], captures: List[List[Optional[str]]],
              condition_invalid: List[List[str]], ground_truth_regex: str,
              ground_truth_conditions: List[str], configuration):
     """ Keep the examples, the expected captures and the ground truth that
     capturing-group and capture-condition synthesis work from. """
     # Examples together with the expected string for each capturing group.
     self.valid, self.captures = valid, captures
     self.condition_invalid = condition_invalid
     # Ground truth: the regex and its capture conditions.
     self.ground_truth_regex, self.ground_truth_conditions = \
         ground_truth_regex, ground_truth_conditions
     self.configuration = configuration
     self.interpreter = RegexInterpreter()
     # 2 for conversational clarification
     self.max_before_distinguish = 2
예제 #3
0
파일: forest.py 프로젝트: Marghrid/FOREST
def synthesize(type_validation):
    """ Run the module-level synthesizer and print what it found.

    Returns the ``(regex, capturing_groups, capture_conditions)`` triple
    produced by ``synthesizer.synthesize()``, or None when it found nothing.
    ``type_validation`` is accepted for interface compatibility but unused.
    """
    global synthesizer
    assert synthesizer is not None
    printer = RegexInterpreter()
    program = synthesizer.synthesize()
    if program is None:
        print('Solution not found!')
        return program

    regex, capturing_groups, (conditions, condition_captures) = program
    solution_str = printer.eval(regex, captures=condition_captures)
    if len(conditions):
        solution_str += ', ' + conditions_to_str(conditions)
    print(f'\nSolution:\n  {solution_str}')
    if len(capturing_groups):
        groups_str = printer.eval(regex, captures=capturing_groups)
        print(f'Captures:\n  {groups_str}')
    return program
예제 #4
0
class MultipleSynthesizer(ABC):
    """ Interactive synthesizer. Finds more than one program consistent with the
    examples.

    Subclasses choose the enumerator and implement ``synthesize``. Each
    candidate solution is stored in ``self.solutions`` as a
    ``(regex, capturing_groups, capture_conditions)`` tuple, where
    ``capture_conditions`` is a ``(conditions, condition_captures)`` pair.
    Whenever ``max_before_distinguishing`` candidates are pending, the
    synthesizer asks the user (or, in self-interaction mode, the ground truth)
    about a distinguishing input to discard some of them.
    """
    def __init__(self, valid_examples, invalid_examples, captured,
                 condition_invalid, dsl: TyrellSpec, ground_truth: str,
                 configuration: Configuration):
        """ Store examples and configuration and build the synthesis components.

        :param valid_examples: examples the regex must match.
        :param invalid_examples: examples the regex must reject.
        :param captured: expected capture-group strings for the valid examples.
        :param condition_invalid: examples the regex must match but the capture
            conditions must reject.
        :param dsl: the DSL specification.
        :param ground_truth: optional ground truth, a comma-separated string
            whose pieces starting with "$" are capture conditions and whose
            remaining pieces form the regex.
        :param configuration: synthesizer configuration.
        """
        self.max_before_distinguishing = 2  # 2 for conversational clarification
        self.valid = valid_examples
        self.invalid = invalid_examples
        self.captured = captured
        self.condition_invalid = condition_invalid
        self.dsl = dsl
        self.configuration = configuration
        if ground_truth is None:
            self.ground_truth_conditions = None
            self.ground_truth_regex = None
        else:
            # Splitting on ',' also splits commas inside the regex (e.g. 'a{1,3}');
            # the non-condition pieces are re-joined with ',' so the regex text
            # is preserved.
            pieces = ground_truth.split(',')
            self.ground_truth_conditions = [
                p.lstrip() for p in pieces if p.lstrip().startswith("$")]
            self.ground_truth_regex = ",".join(
                p for p in pieces if not p.lstrip().startswith("$"))

        if not configuration.pruning:
            logger.warning('Synthesizing without pruning the search space.')
        # If auto-interaction is enabled, the ground truth must be a valid regex.
        # Check for None first: len(None) would raise TypeError instead of
        # reporting the missing ground truth.
        if self.configuration.self_interact:
            assert self.ground_truth_regex is not None \
                and len(self.ground_truth_regex) > 0 \
                and is_regex(self.ground_truth_regex), \
                'self-interaction requires a valid ground-truth regex'

        # Initialize components
        self._printer = RegexInterpreter()  # Works like to_string
        self._distinguisher = RegexDistinguisher()
        # Condition-invalid examples must still be matched by the regex itself,
        # so the decider treats them as valid.
        self._decider = RegexDecider(interpreter=RegexInterpreter(),
                                     valid_examples=self.valid +
                                     self.condition_invalid,
                                     invalid_examples=self.invalid)

        # Capturer works like a synthesizer of capturing groups
        self._capturer = Capturer(self.valid, self.captured,
                                  self.condition_invalid,
                                  self.ground_truth_regex,
                                  self.ground_truth_conditions,
                                  self.configuration)
        self._node_counter = NodeCounter()

        # Subclass decides which enumerator to use
        self._enumerator = None

        # To store synthesized regexes and captures:
        self.solutions = []
        self.first_regex = None

        # counters and timers:
        self.indistinguishable = 0
        # Number of indistinguishable programs after which the synthesizer returns.
        self.max_indistinguishable = 3
        self.start_time = None
        self.last_print_time = time.time()

    @property
    def enumerator(self):
        """ The enumerator chosen by the subclass (None until set). """
        return self._enumerator

    @property
    def decider(self):
        """ The decider that checks regexes against the examples. """
        return self._decider

    @abstractmethod
    def synthesize(self):
        """ Main synthesis procedure. Implemented in subclasses. """
        raise NotImplementedError

    def terminate(self):
        """ Record the total synthesis time and log a summary of the run.

        The summary (host, date, enumerator, statistics, first regex, chosen
        solution, capturing groups and ground truth) is also written to
        ``configuration.log_path`` when that path is non-empty.
        """
        stats.total_synthesis_time = round(time.time() - self.start_time, 2)
        logger.info('Synthesizer done.')

        now = datetime.datetime.now()
        info_str = f'On {socket.gethostname()} on {now.strftime("%Y-%m-%d %H:%M:%S")}.\n'
        info_str += f'Enumerator: {self._enumerator}' \
                    f'{" (no pruning)" if not self.configuration.pruning else ""}\n'
        info_str += f'Terminated: {self.configuration.die}\n'
        info_str += str(stats) + "\n\n"

        if len(self.solutions) > 0:
            if self.configuration.print_first_regex:
                first_regex_str = self._decider.interpreter.eval(
                    self.first_regex)
                info_str += f'First regex: {first_regex_str}\n'
            regex, capturing_groups, capture_conditions = self.solutions[0]
            conditions, conditions_captures = capture_conditions
            solution_str = self._decider.interpreter.eval(
                regex, captures=conditions_captures)
            if len(conditions) > 0:
                solution_str += ', ' + conditions_to_str(conditions)
            info_str += f'Solution: {solution_str}\n' \
                        f'  Nodes: {self._node_counter.eval(self.solutions[0][0])}\n'
            if len(capturing_groups) > 0:
                info_str += f'  Cap. groups: ' \
                            f'{self._decider.interpreter.eval(regex, captures=capturing_groups)}\n' \
                            f'  Num. cap. groups: {len(capturing_groups)}'
            else:
                info_str += "  No capturing groups."
        else:
            info_str += '  No solution.'

        info_str += '\n'

        if self.ground_truth_regex is not None:
            info_str += f'  Ground truth: {self.ground_truth_regex}' \
                        f' {", ".join(self.ground_truth_conditions)}'
        logger.info(info_str)

        # Truthiness also skips a None log path; the context manager makes sure
        # the file is flushed and closed.
        if self.configuration.log_path:
            with open(self.configuration.log_path, "w") as f:
                f.write(info_str)

    def distinguish(self):
        """ Generate a distinguishing input between programs (if there is one),
        and interact with the user to disambiguate.

        If no distinguishing input exists, only the solution with the shortest
        printed form is kept and the indistinguishable counter is incremented.
        """
        distinguish_start = time.time()
        dist_input, keep_if_valid, keep_if_invalid, unknown = \
            self._distinguisher.distinguish(self.solutions)
        if dist_input is not None:
            stats.regex_interactions += 1
            logger.info(f'Distinguishing input "{dist_input}" in '
                        f'{round(time.time() - distinguish_start, 2)} seconds')

            # Programs the distinguisher did not check: sort them by whether
            # they fully match the distinguishing input.
            # NOTE(review): entries of self.solutions are (regex, groups,
            # conditions) tuples; confirm the interpreter accepts a whole tuple
            # here and in the min() below, or pass regex[0].
            for regex in unknown:
                r0 = self._decider.interpreter.eval(regex)
                if re.fullmatch(r0, dist_input):
                    keep_if_valid.append(regex)
                else:
                    keep_if_invalid.append(regex)

            if not self.configuration.self_interact:
                self.interact(dist_input, keep_if_valid, keep_if_invalid)
            else:
                self.auto_distinguish(dist_input, keep_if_valid,
                                      keep_if_invalid)

        else:  # programs are indistinguishable
            logger.info("Regexes are indistinguishable")
            self.indistinguishable += 1
            smallest_regex = min(self.solutions,
                                 key=lambda r: len(self._printer.eval(r)))
            self.solutions = [smallest_regex]
        stats.regex_distinguishing_time += time.time() - distinguish_start
        stats.regex_synthesis_time += time.time() - distinguish_start

    def enumerate(self):
        """ Request new program from the enumerator.

        Returns the next program, or None when the enumerator is exhausted.
        Logs progress at most every 30 seconds.
        """
        stats.enumerated_regexes += 1
        program = self._enumerator.next()
        if program is None:  # enumerator is exhausted
            return None
        if self._printer is not None:
            logger.debug(
                f'Enumerator generated: {self._printer.eval(program)}')
        else:
            logger.debug(f'Enumerator generated: {program}')

        # The counter was just incremented, so only the time check matters.
        if time.time() - self.last_print_time > 30:
            logger.info(f'Enumerated {stats.enumerated_regexes} regexes in '
                        f'{nice_time(time.time() - self.start_time)}.')
            self.last_print_time = time.time()

        return program

    def interact(self, dist_input, keep_if_valid, keep_if_invalid):
        """ Interact with user to ascertain whether the distinguishing input is valid.

        The answer is added to the decider's examples and ``self.solutions`` is
        replaced by the matching keep-list. If ``configuration.die`` is set
        before a yes/no answer is given, nothing is changed.
        """
        valid_answer = False
        # Keep asking until a yes/no answer is given or the synthesizer is told to die.
        while not valid_answer and not self.configuration.die:
            x = input(f'Is "{dist_input}" valid? (y/n)\n')
            answer = x.lower().rstrip()
            if answer in yes_values:
                logger.info(f'"{dist_input}" is {colored("valid", "green")}.')
                valid_answer = True
                self._decider.add_example([dist_input], True)
                self.solutions = keep_if_valid
            elif answer in no_values:
                logger.info(f'"{dist_input}" is {colored("invalid", "red")}.')
                valid_answer = True
                self._decider.add_example([dist_input], False)
                self.solutions = keep_if_invalid
            else:
                logger.info(
                    f"Invalid answer {x}! Please answer 'yes' or 'no'.")

    def auto_distinguish(self, dist_input: str, keep_if_valid: List,
                         keep_if_invalid: List):
        """ Simulate interaction: the input is valid iff the ground-truth regex
        fully matches it. Capture conditions are not consulted here. """
        match = re.fullmatch(self.ground_truth_regex, dist_input)
        if match is not None:
            logger.info(
                f'Auto: "{dist_input}" is {colored("valid", "green")}.')
            self._decider.add_example([dist_input], True)
            self.solutions = keep_if_valid
        else:
            logger.info(
                f'Auto: "{dist_input}" is {colored("invalid", "red")}.')
            self._decider.add_example([dist_input], False)
            self.solutions = keep_if_invalid

    def try_for_depth(self):
        """ Enumerate with the current enumerator until it is exhausted, the
        synthesizer is told to die, a solution is found with disambiguation
        disabled, or too many indistinguishable solutions have been seen.
        Then distinguish until at most one solution remains (unless dying).
        """
        stats.first_regex_time = -1
        while True:
            regex = self.try_regex()

            # enumerator is exhausted or user interrupted synthesizer
            if regex is None or self.configuration.die:
                break

            if regex == -1:  # enumerated a regex that is not correct
                continue

            if self.configuration.synth_captures:
                capturing_groups = self.try_capturing_groups(regex)

                if capturing_groups is None:
                    logger.info(
                        "Failed to find capture groups for the given captures."
                    )
                    continue
            else:
                capturing_groups = []

            if self.configuration.synth_conditions:
                capture_conditions = self.try_capture_conditions(regex)

                if capture_conditions[0] is None:
                    logger.info(
                        "Failed to find capture conditions that invalidate condition_invalid."
                    )
                    continue
            else:
                # Same (conditions, condition_captures) shape the capturer
                # returns, so every consumer can unpack two values.
                capture_conditions = ([], [])

            self.solutions.append(
                (regex, capturing_groups, capture_conditions))

            # Without disambiguation the first solution is accepted.
            if not self.configuration.disambiguation:
                break

            if len(self.solutions) >= self.max_before_distinguishing:
                # if there are more than max_before_disambiguating solutions, disambiguate.
                self.distinguish()

            if self.indistinguishable >= self.max_indistinguishable:
                break
        # interact() leaves the solutions untouched once `die` is set, so
        # without the die check this loop would never end.
        while len(self.solutions) > 1 and not self.configuration.die:
            self.distinguish()
        # only one regex remains (unless synthesis was interrupted)
        assert len(self.solutions) <= 1 or self.configuration.die

    def try_capture_conditions(self, regex):
        """ Synthesize capture conditions for *regex*, timing the attempt.
        Returns the capturer's ``(conditions, capture_groups)`` pair. """
        cap_conditions_synthesis_start = time.time()
        capture_conditions = self._capturer.synthesize_capture_conditions(
            regex)
        stats.cap_conditions_synthesis_time += time.time(
        ) - cap_conditions_synthesis_start
        return capture_conditions

    def try_capturing_groups(self, regex):
        """ Synthesize capturing groups for *regex*, timing the attempt.
        Returns the capturer's result (None when no groups fit). """
        cap_groups_synthesis_start = time.time()
        # synthesize captures that reflect the desired captured strings.
        captures = self._capturer.synthesize_capturing_groups(regex)
        stats.cap_groups_synthesis_time += time.time(
        ) - cap_groups_synthesis_start
        return captures

    def try_regex(self):
        """ Enumerate one regex and check it against the examples.

        Returns the regex if the decider accepts it, None if the enumerator is
        exhausted, or -1 if it was rejected. On rejection with pruning enabled,
        the decider's explanation predicates are fed back to the enumerator.
        """
        regex_synthesis_start = time.time()

        regex = self.enumerate()
        if regex is None:
            return None

        analysis_result = self._decider.analyze(regex)

        if analysis_result.is_ok():  # program satisfies I/O examples
            logger.info(
                f'Regex accepted. {self._node_counter.eval(regex, [0])} nodes. '
                f'{stats.enumerated_regexes} attempts '
                f'and {round(time.time() - self.start_time, 2)} seconds:')
            logger.info(self._printer.eval(regex))
            self._enumerator.update()
            stats.regex_synthesis_time += time.time() - regex_synthesis_start
            if stats.first_regex_time == -1:
                stats.first_regex_time = time.time() - self.start_time
                self.first_regex = regex
            return regex

        elif self.configuration.pruning:
            new_predicates = analysis_result.why()
            if new_predicates is not None:
                for pred in new_predicates:
                    pred_str = self._printer.eval(pred.args[0])
                    if len(pred.args) > 1:
                        pred_str = str(pred.args[1]) + " " + pred_str
                    logger.debug(f'New predicate: {pred.name} {pred_str}')
        else:
            new_predicates = None
        self._enumerator.update(new_predicates)
        stats.regex_synthesis_time += time.time() - regex_synthesis_start
        return -1
예제 #5
0
파일: capturer.py 프로젝트: Marghrid/FOREST
class Capturer:
    """ 'capturer is one who, or that which, captures'

    For a given regex, synthesizes capturing groups that reproduce the expected
    captures of the valid examples, and capture conditions over integer
    captures that reject the condition-invalid examples. Ambiguity between
    candidate conditions is resolved by asking the user or, in
    self-interaction mode, the ground truth.
    """

    def __init__(self, valid: List[List[str]], captures: List[List[Optional[str]]],
                 condition_invalid: List[List[str]], ground_truth_regex: str,
                 ground_truth_conditions: List[str], configuration):
        self.valid = valid
        self.captures = captures
        self.condition_invalid = condition_invalid
        self.ground_truth_regex = ground_truth_regex
        self.ground_truth_conditions = ground_truth_conditions
        self.configuration = configuration
        self.interpreter = RegexInterpreter()
        self.max_before_distinguish = 2  # 2 for conversational clarification
        # Capture-conditions enumerator, (re)created by
        # _synthesize_conditions_for_captures.
        self._cc_enumerator = None

    def synthesize_capturing_groups(self, regex: Node):
        """ Given regex, find capturing groups which match self.captures.

        Tries the sublists of *regex*'s leaves produced by ``all_sublists_n``
        (of size ``len(self.captures[0])``) as capturing groups. Returns the
        first one under which every valid example is fully matched and its
        groups equal the expected captures; ``[]`` when no captures are
        expected; None when no sublist works.
        """
        if len(self.captures) == 0 or len(self.captures[0]) == 0:
            return []
        nodes = regex.get_leaves()
        n_groups = len(self.captures[0])
        # try placing a capture group in each node
        for sub in all_sublists_n(nodes, n_groups):
            stats.enumerated_cap_groups += 1
            compiled_re = re.compile(self.interpreter.eval(regex, captures=sub))
            # Match each valid example once and reuse the match objects below.
            matches = [compiled_re.fullmatch(ex[0]) for ex in self.valid]
            if any(m is None for m in matches):
                continue
            if all(elementwise_eq(matches[i].groups(), expected)
                   for i, expected in enumerate(self.captures)):
                return sub
        return None

    def synthesize_capture_conditions(self, regex: Node):
        """ Given regex, synthesise capture conditions that reject self.condition_invalid.

        Condition-invalid examples that *regex* does not even match are dropped
        first. Returns a ``(conditions, capture_groups)`` pair: ``([], [])`` when
        no conditions are needed, ``(None, None)`` when none could be found.

        :raises ValueError: if *regex* does not match every valid example.
        """
        if len(self.condition_invalid) == 0:
            return [], []
        nodes = regex.get_leaves()
        compiled_re = re.compile(self.interpreter.eval(regex))
        # Test that regex satisfies
        if not all(compiled_re.fullmatch(ex[0]) for ex in self.valid):
            raise ValueError("Regex doesn't match all valid examples")
        if not all(compiled_re.fullmatch(ex[0]) for ex in self.condition_invalid):
            logger.info("Regex doesn't match all condition invalid examples. Removing.")
            self.condition_invalid = [ex for ex in self.condition_invalid
                                      if compiled_re.fullmatch(ex[0])]
            if len(self.condition_invalid) == 0:
                logger.info("No condition invalid examples left. No capture conditions needed.")
                return [], []

        # Try capturing 1, 2, ... leaves, up to and including all of them: the
        # inclusive bound lets single-leaf regexes get conditions too.
        for n in range(1, len(nodes) + 1):
            for sub in all_sublists_n(nodes, n):
                stats.enumerated_cap_conditions += 1
                compiled_re = re.compile(self.interpreter.eval(regex, captures=sub))
                matches = [compiled_re.fullmatch(ex[0]) for ex in self.valid]
                if any(m is None for m in matches):
                    continue
                # Every captured substring must satisfy is_int
                # (or is_float(g), currently disabled).
                if not all(is_int(g) for m in matches for g in m.groups()):
                    continue
                condition = self._synthesize_conditions_for_captures(regex, sub)
                if condition is not None:
                    return condition, sub
        return None, None

    def _synthesize_conditions_for_captures(self, regex, capture_groups):
        """ Given capturing groups, try to find conditions that satisfy examples.

        With disambiguation disabled, the first enumerated condition set is
        returned. Otherwise candidates are accumulated and, whenever
        ``max_before_distinguish`` of them are pending, narrowed down through a
        distinguishing input. Returns the single surviving condition set, or
        None if the enumerator yields none.
        """
        assert len(self.condition_invalid) > 0
        self._cc_enumerator = CaptureConditionsEnumerator(
            self.interpreter.eval(regex, captures=capture_groups),
            len(capture_groups), self.valid, self.condition_invalid)
        condition_distinguisher = ConditionDistinguisher(
            regex, capture_groups, self.valid[0][0])

        conditions = []
        while True:
            new_condition = self._cc_enumerator.next()
            if new_condition is None:  # enumerator exhausted
                if len(conditions) == 0:
                    return None
                while len(conditions) > 1:
                    conditions = self._distinguish_conditions(
                        condition_distinguisher, conditions)
                assert len(conditions) == 1
                return conditions[0]

            if not self.configuration.disambiguation:
                return new_condition
            self._cc_enumerator.update()
            conditions.append(new_condition)
            if len(conditions) >= self.max_before_distinguish:
                conditions = self._distinguish_conditions(
                    condition_distinguisher, conditions)

    def _distinguish_conditions(self, condition_distinguisher, conditions):
        """ Ask about one input that distinguishes *conditions* and return the
        candidates consistent with the answer; updates the distinguishing stats.

        NOTE(review): _interact returns None when configuration.die is set,
        which the caller does not handle.
        """
        start_distinguish_time = time.time()
        dist_input, keep_if_valid, keep_if_invalid = \
            condition_distinguisher.distinguish(conditions)
        stats.cap_conditions_distinguishing_time += time.time() - start_distinguish_time
        stats.cap_conditions_interactions += 1
        if not self.configuration.self_interact:
            return self._interact(dist_input, keep_if_valid, keep_if_invalid)
        return self._auto_distinguish(dist_input, keep_if_valid, keep_if_invalid)

    def _interact(self, dist_input, keep_if_valid, keep_if_invalid):
        """ Interact with user to ascertain whether the distinguishing input is valid.

        A valid input becomes a valid example and *keep_if_valid* is returned;
        an invalid one becomes a condition-invalid example and *keep_if_invalid*
        is returned. Returns None if ``configuration.die`` is set first.
        """
        while not self.configuration.die:
            x = input(f'Is "{dist_input}" valid? (y/n)\n')
            answer = x.lower().rstrip()
            if answer in yes_values:
                logger.info(f'"{dist_input}" is {colored("valid", "green")}.')
                self.valid.append([dist_input])
                self._cc_enumerator.add_valid(dist_input)
                return keep_if_valid
            elif answer in no_values:
                logger.info(f'"{dist_input}" is {colored("conditional invalid", "red")}.')
                self.condition_invalid.append([dist_input])
                self._cc_enumerator.add_conditional_invalid(dist_input)
                return keep_if_invalid
            else:
                logger.info(f"Invalid answer {x}! Please answer 'yes' or 'no'.")
        return None

    def _auto_distinguish(self, dist_input: str, keep_if_valid: List, keep_if_invalid: List):
        """ Given distinguishing input, simulate user interaction based on ground truth.

        The input is valid iff the ground-truth regex fully matches it and the
        ground-truth conditions hold for that match.
        """
        match = re.fullmatch(self.ground_truth_regex, dist_input)
        if match is not None and check_conditions(self.ground_truth_conditions, match):
            logger.info(f'Auto: "{dist_input}" is {colored("valid", "green")}.')
            self.valid.append([dist_input])
            self._cc_enumerator.add_valid(dist_input)
            return keep_if_valid

        logger.info(f'Auto: "{dist_input}" is {colored("conditional invalid", "red")}.')
        self.condition_invalid.append([dist_input])
        self._cc_enumerator.add_conditional_invalid(dist_input)
        return keep_if_invalid
예제 #6
0
 def __init__(self):
     """ Set up the z3 translator and the regex printer; no forcing flags. """
     # Translator to z3 regex terms, then the regex pretty-printer.
     self._toz3, self._printer = ToZ3(), RegexInterpreter()
     # Both strategy-forcing switches start disabled.
     self.force_multi_distinguish = self.force_distinguish2 = False
예제 #7
0
class RegexDistinguisher:
    """ Finds inputs that tell candidate regex solutions apart, using z3. """
    def __init__(self):
        self._toz3 = ToZ3()
        self._printer = RegexInterpreter()
        # Use multi_distinguish even when there are exactly two programs.
        self.force_multi_distinguish = False
        # With more than two programs, distinguish only the first two and
        # report the rest as unknown.
        self.force_distinguish2 = False

    def distinguish(self, programs):
        """ Return ``(dist_input, keep_if_valid, keep_if_invalid, unknown)``.

        ``keep_if_valid`` / ``keep_if_invalid`` are the programs that are
        consistent with *dist_input* being valid / invalid; ``unknown`` holds
        programs that were not checked against it. When no distinguishing input
        is found the first three entries are None.
        """
        logger.debug(f"Distinguishing {len(programs)}: "
                     f"{','.join(map(self._printer.eval, programs))}")
        assert len(programs) >= 2
        if not self.force_multi_distinguish and len(programs) == 2:
            return self.distinguish2(programs[0], programs[1])
        if self.force_distinguish2:
            dist_input, keep_if_valid, keep_if_invalid, _ = \
                self.distinguish2(programs[0], programs[1])
            return dist_input, keep_if_valid, keep_if_invalid, programs[2:]
        return self.multi_distinguish(programs)

    def distinguish2(self, r1, r2):
        """ Find an input matched by exactly one of two solutions.

        *r1* and *r2* are ``(regex, capturing_groups, (conditions,
        condition_captures))`` solutions. When either has capture conditions,
        the input must also satisfy the conditions of the solution whose regex
        matches it. Returns ``(dist_input, [matching], [other], [])``, or four
        Nones if no such input exists.
        """
        global use_derivatives
        solver = z3.Solver()
        # Fixed seeds, presumably so distinguishing inputs are reproducible.
        solver.set('random_seed', 7)
        solver.set('sat.random_seed', 7)
        if use_derivatives:
            # Best-effort: ignore failures to enable the option, but do not
            # swallow KeyboardInterrupt/SystemExit as a bare except would.
            try:
                solver.set('smt.seq.use_derivatives', True)
                solver.check()
            except Exception:
                logger.debug('Could not enable smt.seq.use_derivatives; ignoring.')

        z3_r1 = self._toz3.eval(r1[0])
        z3_r2 = self._toz3.eval(r2[0])

        dist = z3.String("distinguishing")

        # ro_k is true iff dist is in the language of regex k.
        ro_1 = z3.Bool("ro_1")
        solver.add(ro_1 == z3.InRe(dist, z3_r1))
        ro_2 = z3.Bool("ro_2")
        solver.add(ro_2 == z3.InRe(dist, z3_r2))

        solver.add(ro_1 != ro_2)

        if solver.check() != z3.sat:
            return None, None, None, None

        if len(r1[2][0]) == 0 and len(r2[2][0]) == 0:
            dist_input = solver.model()[dist].as_string()
            if solver.model()[ro_1]:
                return dist_input, [r1], [r2], []
            return dist_input, [r2], [r1], []

        # Find dist_input that respects conditions
        r1_str = self._printer.eval(r1[0], captures=r1[2][1])
        r1_conditions = [" ".join(map(str, c)) for c in r1[2][0]]
        r2_str = self._printer.eval(r2[0], captures=r2[2][1])
        r2_conditions = [" ".join(map(str, c)) for c in r2[2][0]]

        while True:
            dist_input = solver.model()[dist].as_string()

            match = re.fullmatch(r1_str, dist_input)
            if match is not None and check_conditions(r1_conditions, match):
                break

            match = re.fullmatch(r2_str, dist_input)
            if match is not None and check_conditions(r2_conditions, match):
                break

            # Block this candidate and ask the solver for another one.
            solver.add(dist != z3.StringVal(dist_input))
            if not solver.check() == z3.sat:
                return None, None, None, None

        if solver.model()[ro_1]:
            return dist_input, [r1], [r2], []
        return dist_input, [r2], [r1], []

    def multi_distinguish(self, regexes):
        """ Find an input that splits several solutions into those matching it
        and those that do not, maximising the number of pairs told apart.

        At most 4 solutions, chosen by a fixed-seed shuffle, go to the
        optimizer; the rest are returned as ``unknown``. Returns
        ``(dist_input, keep_if_valid, keep_if_invalid, unknown)`` or four Nones.
        The caller's list is not modified.
        """
        start = time.time()
        # Problem: cannot distinguish more than 4 regexes at once: it takes forever.
        # Solution: use only 4 randomly selected regexes for the SMT maximization,
        # and then add the others to the solution.
        if len(regexes) <= 4:
            selected_regexes = regexes
            others = []
        else:
            # A private RNG seeded with 'regex' yields the same order as seeding
            # the global RNG, without resetting global random state; shuffling a
            # copy leaves the caller's list untouched.
            shuffled = list(regexes)
            random.Random('regex').shuffle(shuffled)
            selected_regexes = shuffled[:4]
            others = shuffled[4:]
        solver = z3.Optimize()

        # NOTE(review): distinguish2 translates r[0] while this translates the
        # whole solution entry; confirm ToZ3.eval accepts both.
        z3_regexes = [self._toz3.eval(regex) for regex in selected_regexes]

        dist = z3.String("distinguishing")

        # ro_z3[i] == true if dist matches regex[i].
        ro_z3 = []
        for i, z3_regex in enumerate(z3_regexes):
            ro = z3.Bool(f"ro_{i}")
            ro_z3.append(ro)
            solver.add(ro == z3.InRe(dist, z3_regex))

        # Soft objective: as many pairs as possible disagree on dist.
        big_or = []
        for ro_i, ro_j in combinations(ro_z3, 2):
            big_or.append(z3.Xor(ro_i, ro_j))
            solver.add_soft(z3.Xor(ro_i, ro_j))
        solver.add(z3.Or(big_or))  # at least one regex is distinguished

        if solver.check() != z3.sat:
            return None, None, None, None

        logger.debug(f'multi_distinguish took {round(time.time() - start, 2)} seconds')
        model = solver.model()
        # as_string() decodes z3 string escapes, consistent with distinguish2.
        dist_input = model[dist].as_string()
        keep_if_valid = []
        keep_if_invalid = []
        for regex, ro in zip(selected_regexes, ro_z3):
            if model[ro]:
                keep_if_valid.append(regex)
            else:
                keep_if_invalid.append(regex)
        return dist_input, keep_if_valid, keep_if_invalid, others
예제 #8
0
    def synthesize(self):
        """ Run synthesis and return the first solution found, or None.

        If the valid examples can be split into more than one field (and dynamic
        mode is not forced), a StaticMultiTreeEnumerator over the DSLs built by
        DSLBuilder is tried for depths 3..9. Otherwise a
        DynamicMultiTreeEnumerator is tried over (depth, length) pairs in
        increasing ``(2**depth - 1) * length`` order. Stops at the first
        depth/size that yields a solution; returns None if ``configuration.die``
        is set first or nothing is found.
        """
        self.start_time = time.time()
        try:
            valid, invalid = self.split_examples()
        except Exception:
            # Best-effort: any splitting failure falls back to the dynamic
            # enumerator. (A bare except would also swallow KeyboardInterrupt.)
            valid = None
            invalid = None

        # `valid` may be empty; guard before indexing valid[0].
        if valid and len(valid[0]) > 1 and not self.configuration.force_dynamic:
            self._decider = RegexDecider(interpreter=RegexInterpreter(),
                                         valid_examples=self.valid,
                                         invalid_examples=self.invalid,
                                         split_valid=valid)

            self.valid = valid
            self.invalid = invalid

            # Every split example must have the same number of fields.
            assert all(map(lambda l: len(l) == len(self.valid[0]), self.valid))
            assert all(
                map(lambda l: len(l) == len(self.invalid[0]), self.invalid))

            type_validations = ['regex'] * len(self.valid[0])
            builder = DSLBuilder(type_validations, self.valid, self.invalid)
            dsls = builder.build()

            for depth in range(3, 10):
                self._enumerator = StaticMultiTreeEnumerator(
                    self.main_dsl, dsls, depth)
                depth_start = time.time()
                self.try_for_depth()
                stats.per_depth_times[depth] = time.time() - depth_start
                if len(self.solutions) > 0:
                    self.terminate()
                    return self.solutions[0]
                elif self.configuration.die:
                    self.terminate()
                    return None

        else:
            self._decider = RegexDecider(RegexInterpreter(), self.valid,
                                         self.invalid)
            # Cheapest (depth, length) combinations first.
            sizes = sorted(itertools.product(range(3, 10), range(1, 10)),
                           key=lambda t: (2**t[0] - 1) * t[1])
            for dep, length in sizes:
                self._enumerator = DynamicMultiTreeEnumerator(self.main_dsl,
                                                              depth=dep,
                                                              length=length)
                depth_start = time.time()
                self.try_for_depth()
                stats.per_depth_times[(dep,
                                       length)] = time.time() - depth_start

                if len(self.solutions) > 0:
                    self.terminate()
                    return self.solutions[0]
                elif self.configuration.die:
                    self.terminate()
                    return None

        return None
예제 #9
0
 def __init__(self, regex: Node, capture_groups: List, valid_example: str):
     """ Remember the regex, its capturing groups and one valid example. """
     # The regex whose capture conditions are disambiguated, and the leaves
     # used as capturing groups.
     self._regex, self._capture_groups = regex, capture_groups
     # A known-valid input.
     self._valid_example = valid_example
     self._interpreter = RegexInterpreter()
     self.condition_operators = utils.condition_operators
예제 #10
0
class ConditionDistinguisher:
    """Find inputs that separate candidate sets of capture-group conditions.

    A candidate ("model") is an iterable of ``(capture_index, operator_name,
    bound)`` triples. Capture group ``i`` is modelled by the z3 integer
    ``c{i}``, and each triple becomes
    ``condition_operators[operator_name](c_i, bound)``. ``distinguish`` asks
    z3 for capture values that satisfy some candidates but not others, then
    writes those values into the valid example to obtain a concrete input.
    """

    def __init__(self, regex: Node, capture_groups: List, valid_example: str):
        """Store the regex context used to build distinguishing inputs.

        Args:
            regex: regex program; rendered with ``RegexInterpreter.eval``
                (one capture group at a time) to locate each capture in the
                example string.
            capture_groups: capture groups of ``regex``; position ``i`` in
                this list is the ``capture_index`` used in model triples.
            valid_example: template string that every rendered single-capture
                regex must fully match; distinguishing values are spliced
                into it.
        """
        self._regex = regex
        self._capture_groups = capture_groups
        self._valid_example = valid_example
        self._interpreter = RegexInterpreter()
        # Operator name -> callable building the z3 comparison op(c_i, bound).
        self.condition_operators = utils.condition_operators

    def distinguish(self, models: List[List[Tuple]]):
        """Find an input that distinguishes the candidate condition sets.

        Duplicate candidates are removed with ``keep_distinct`` first. The
        solver must separate at least one pair of distinct candidates (hard
        constraint) and separates as many pairs as it can (soft constraints).

        Args:
            models: candidate condition sets (see the class docstring).

        Returns:
            ``(input, keep_if_valid, keep_if_invalid)``: ``input`` is the
            distinguishing string; ``keep_if_valid`` holds the distinct
            candidates whose conditions hold under the solver's capture
            values and ``keep_if_invalid`` the remaining distinct candidates.
            ``(None, None, None)`` if no distinguishing input exists,
            including when fewer than two distinct candidates are given.

        Raises:
            ValueError: if a rendered capture regex does not fully match the
                current example, or its capture group does not participate.
        """
        distinct_models = list(keep_distinct(models))
        logger.info(f"Distinguishing {distinct_models}")
        if len(distinct_models) < 2:
            # Nothing to tell apart; also z3.Or over an empty list raises.
            logger.info("Indistinguishable")
            return None, None, None

        solver = z3.Optimize()
        cap_vars = [z3.Int(f"c{cap_idx}")
                    for cap_idx in range(len(self._capture_groups))]

        # sat_m{i} <=> every condition of distinct model i holds.
        sat_ms = []
        for m_idx, model in enumerate(distinct_models):
            sat_m = z3.Bool(f"sat_m{m_idx}")
            sat_ms.append(sat_m)
            solver.add(self._get_sat_m_constraint(cap_vars, model, sat_m))

        # Maximisation objective from multi-distinguish: prefer separating as
        # many pairs as possible, but require at least one separated pair.
        pair_xors = [z3.Xor(m_i, m_j) for m_i, m_j in combinations(sat_ms, 2)]
        for pair_xor in pair_xors:
            solver.add_soft(pair_xor)
        solver.add(z3.Or(pair_xors))

        if solver.check() != z3.sat:
            logger.info("Indistinguishable")
            return None, None, None

        z3_model = solver.model()
        valid_ex = self._substitute_values(z3_model, cap_vars)

        # Partition the *distinct* models: sat_ms is indexed by position in
        # distinct_models, not in the (possibly duplicated) input list.
        keep_if_valid = []
        keep_if_invalid = []
        for model, sat_m in zip(distinct_models, sat_ms):
            if z3.is_true(z3_model.eval(sat_m, model_completion=True)):
                keep_if_valid.append(model)
            else:
                keep_if_invalid.append(model)
        logger.info(f"Dist. input: {valid_ex}")
        return valid_ex, keep_if_valid, keep_if_invalid

    def _substitute_values(self, z3_model, cap_vars):
        """Write the solver's capture values into a copy of the valid example.

        Capture variables the model leaves unassigned keep their original
        text. Each value is left-padded with '0' to the length of the text
        it replaces. Captures are processed in order and each regex is
        matched against the already-updated string.

        NOTE(review): a negative model value would be padded as e.g. "0-3";
        confirm capture groups can never require negative values.
        """
        valid_ex = self._valid_example
        for c_idx, c_var in enumerate(cap_vars):
            value = z3_model[c_var]
            if value is None:
                continue
            regex_str = self._interpreter.eval(
                self._regex, captures=[self._capture_groups[c_idx]])
            match = re.fullmatch(regex_str, valid_ex)
            if match is None or match.start(1) == -1:
                raise ValueError(
                    f"capture {c_idx} not found in example {valid_ex!r} "
                    f"using regex {regex_str!r}")
            start, end = match.span(1)
            c_val = str(value).rjust(end - start, '0')
            # Splice at the capture's own position; str.replace would hit the
            # first equal substring, which may belong to a different field.
            valid_ex = valid_ex[:start] + c_val + valid_ex[end:]
        return valid_ex

    def _get_sat_m_constraint(self, cs, model, sat_m):
        """Return ``sat_m == And(conditions of model)`` over capture vars ``cs``.

        A model with no conditions is encoded as ``True``, because z3's
        ``And`` cannot infer a context from an empty argument list.
        """
        conditions = [self._get_reverse_cond(cond, bound_val, cs[c_idx])
                      for c_idx, cond, bound_val in model]
        conjunction = z3.And(conditions) if conditions else z3.BoolVal(True)
        return sat_m == conjunction

    def _get_reverse_cond(self, cond: str, bound: z3.IntNumRef,
                          cap_var: z3.ArithRef):
        """Build ``op(cap_var, bound)`` for the operator named ``cond``.

        Raises:
            KeyError: if ``cond`` is not in ``condition_operators``.
        """
        op = self.condition_operators[cond]
        return op(cap_var, bound)
예제 #11
0
    def synthesize(self):
        self.start_time = time.time()
        try:
            valid, invalid = self.split_examples()
        except:
            valid = None
            invalid = None

        if valid is not None and len(valid[0]) > 1 and not self.configuration.force_dynamic:
            # self.valid = valid
            # self.invalid = invalid
            self._decider = RegexDecider(RegexInterpreter(), valid, invalid, split_valid=valid)

            assert all(map(lambda l: len(l) == len(valid[0]), valid))
            assert all(map(lambda l: len(l) == len(invalid[0]), invalid))

            type_validations = ['regex'] * len(valid[0])
            builder = DSLBuilder(type_validations, valid, invalid, sketches=True)
            dsls = builder.build()

            for depth in range(2, 10):
                self._enumerator = StaticMultiTreeEnumerator(self.main_dsl, dsls, depth)

                depth_start = time.time()
                self.try_for_depth()
                stats.per_depth_times[depth] = time.time() - depth_start

                print("level sketches", self.count_level_sketches)
                self.count_total_sketches += self.count_level_sketches
                self.count_level_sketches = 0
                print("good sketches", self.count_good_sketches)
                print("\ntotal sketches", self.count_total_sketches)

                if self.count_good_sketches > 0:
                    self.terminate()
                    return self.solutions[0]
                self.count_good_sketches = 0

                if len(self.solutions) > 0:
                    self.terminate()
                    return self.solutions[0]
                elif self.configuration.die:
                    self.terminate()
                    return

        else:
            self._decider = RegexDecider(RegexInterpreter(), valid, invalid)
            sizes = list(itertools.product(range(3, 10), range(1, 10)))
            sizes.sort(key=lambda t: (2 ** t[0] - 1) * t[1])
            for dep, length in sizes:
                logger.info(f'Sketching programs of depth {dep} and length {length}...')
                self._enumerator = DynamicMultiTreeEnumerator(self.main_dsl, depth=dep, length=length)

                depth_start = time.time()
                self.try_for_depth()
                stats.per_depth_times[(dep, length)] = time.time() - depth_start

                print("level sketches", self.count_level_sketches)
                self.count_total_sketches += self.count_level_sketches
                self.count_level_sketches = 0
                print("good sketches", self.count_good_sketches)
                print("\ntotal sketches", self.count_total_sketches)

                if self.count_good_sketches > 0:
                    self.terminate()
                    return self.solutions[0]
                self.count_good_sketches = 0

                if len(self.solutions) > 0:
                    self.terminate()
                    return self.solutions[0]
                elif self.configuration.die:
                    self.terminate()
                    return