Example #1
0
    def __init__(self, valid_examples, invalid_examples, captured,
                 condition_invalid, dsl: TyrellSpec, ground_truth: str,
                 configuration: Configuration):
        """Initialise the synthesizer state and its helper components.

        :param valid_examples: examples handed to the decider as valid
            (together with ``condition_invalid``) and to the ``Capturer``.
        :param invalid_examples: examples handed to the decider as invalid.
        :param captured: expected captures, forwarded to the ``Capturer``.
        :param condition_invalid: examples the decider treats as valid; also
            forwarded to the ``Capturer``.
        :param dsl: the DSL specification, stored as ``self.dsl``.
        :param ground_truth: optional (may be ``None``) comma-separated ground
            truth. Elements whose first non-blank character is ``$`` are kept,
            left-stripped, as ground-truth conditions; the remaining elements
            are re-joined with ``,`` to form the ground-truth regex.
        :param configuration: run options; ``pruning`` and ``self_interact``
            are read here, and it is forwarded to the ``Capturer``.
        :raises AssertionError: if ``configuration.self_interact`` is set but
            there is no non-empty ground-truth regex accepted by ``is_regex``.
        """
        self.max_before_distinguishing = 2  # 2 for conversational clarification
        self.valid = valid_examples
        self.invalid = invalid_examples
        self.captured = captured
        self.condition_invalid = condition_invalid
        self.dsl = dsl
        self.configuration = configuration
        if ground_truth is None:
            self.ground_truth_conditions = None
            self.ground_truth_regex = None
        else:
            parts = ground_truth.split(',')
            # Conditions start with '$'; everything else is part of the regex.
            # The regex itself may contain commas (e.g. "{1,3}"), so its parts
            # are re-joined with ',' rather than taken individually.
            self.ground_truth_conditions = [
                part.lstrip() for part in parts
                if part.lstrip().startswith("$")
            ]
            self.ground_truth_regex = ",".join(
                part for part in parts if not part.lstrip().startswith("$"))

        if not configuration.pruning:
            logger.warning('Synthesizing without pruning the search space.')
        # If auto-interaction is enabled, the ground truth must be a valid regex.
        # Check for None first: without a ground truth, len(None) would raise a
        # TypeError instead of the intended assertion failure.
        if self.configuration.self_interact:
            assert (self.ground_truth_regex is not None
                    and len(self.ground_truth_regex) > 0
                    and is_regex(self.ground_truth_regex)), \
                'self_interact requires a valid ground-truth regex'

        # Initialize components
        self._printer = RegexInterpreter()  # Works like to_string
        self._distinguisher = RegexDistinguisher()
        self._decider = RegexDecider(interpreter=RegexInterpreter(),
                                     valid_examples=self.valid +
                                     self.condition_invalid,
                                     invalid_examples=self.invalid)

        # Capturer works like a synthesizer of capturing groups
        self._capturer = Capturer(self.valid, self.captured,
                                  self.condition_invalid,
                                  self.ground_truth_regex,
                                  self.ground_truth_conditions,
                                  self.configuration)
        self._node_counter = NodeCounter()

        # Subclass decides which enumerator to use
        self._enumerator = None

        # To store synthesized regexes and captures:
        self.solutions = []
        self.first_regex = None

        # counters and timers:
        self.indistinguishable = 0
        # Number of indistinguishable programs after which the synthesizer returns.
        self.max_indistinguishable = 3
        self.start_time = None
        self.last_print_time = time.time()
Example #2
0
File: capturer.py  Project: Marghrid/FOREST
 def __init__(self, valid: List[List[str]], captures: List[List[Optional[str]]],
              condition_invalid: List[List[str]], ground_truth_regex: str,
              ground_truth_conditions: List[str], configuration):
     """Store the examples, ground truth and configuration for capture synthesis.

     Also creates a ``RegexInterpreter`` and sets
     ``max_before_distinguish`` to 2.
     """
     # Example data.
     self.valid, self.captures = valid, captures
     self.condition_invalid = condition_invalid
     # Ground truth, exactly as handed in by the caller.
     self.ground_truth_regex, self.ground_truth_conditions = (
         ground_truth_regex, ground_truth_conditions)
     self.configuration = configuration
     self.interpreter = RegexInterpreter()
     # 2 for conversational clarification
     self.max_before_distinguish = 2
Example #3
0
File: forest.py  Project: Marghrid/FOREST
def synthesize(type_validation):
    """Run the module-level ``synthesizer`` and print the solution it finds.

    The solution line shows the regex (rendered with the condition captures)
    followed by its conditions, if any; a separate line shows the capturing
    groups when there are any.

    :param type_validation: not used by this function.
    :return: whatever ``synthesizer.synthesize()`` returned; ``None`` means
        no solution was found.
    """
    assert synthesizer is not None
    printer = RegexInterpreter()
    program = synthesizer.synthesize()
    if program is None:
        print('Solution not found!')
        return program

    regex, capturing_groups, (conditions, condition_captures) = program
    solution_str = printer.eval(regex, captures=condition_captures)
    if conditions:
        solution_str += ', ' + conditions_to_str(conditions)
    print(f'\nSolution:\n  {solution_str}')

    if capturing_groups:
        captures_str = printer.eval(regex, captures=capturing_groups)
        print(f'Captures:\n  {captures_str}')
    return program
Example #4
0
 def __init__(self):
     """Create the Z3 converter and the regex printer, and clear both force flags."""
     self._toz3, self._printer = ToZ3(), RegexInterpreter()
     # Both force_* flags start disabled.
     self.force_multi_distinguish = self.force_distinguish2 = False
Example #5
0
    def synthesize(self):
        """Enumerate candidate programs of growing size until one is found.

        If the examples split into more than one field (and
        ``configuration.force_dynamic`` is off), a static multi-tree
        enumerator is tried for depths 3..9 over one DSL per field.
        Otherwise a dynamic multi-tree enumerator is tried over
        (depth, length) pairs, ordered by ``(2**depth - 1) * length``.
        The time spent per size is recorded in ``stats.per_depth_times``.

        :return: ``self.solutions[0]`` once a solution is found; ``None`` when
            ``configuration.die`` stops the search or every size is exhausted.
        """
        self.start_time = time.time()
        try:
            valid, invalid = self.split_examples()
        except Exception:
            # Splitting is best-effort: on failure use the dynamic enumerator.
            # ``except Exception`` (rather than a bare ``except``) keeps
            # KeyboardInterrupt/SystemExit from being swallowed here.
            valid = None
            invalid = None

        # Truthiness also rules out an empty split, so valid[0] cannot raise.
        if valid and len(valid[0]) > 1 and not self.configuration.force_dynamic:
            # The decider gets the unsplit examples (plus the split valid ones);
            # it is built before self.valid/self.invalid are replaced below.
            self._decider = RegexDecider(interpreter=RegexInterpreter(),
                                         valid_examples=self.valid,
                                         invalid_examples=self.invalid,
                                         split_valid=valid)

            self.valid = valid
            self.invalid = invalid

            # Every split example must have the same number of fields.
            assert all(len(fields) == len(self.valid[0])
                       for fields in self.valid)
            assert all(len(fields) == len(self.invalid[0])
                       for fields in self.invalid)

            type_validations = ['regex'] * len(self.valid[0])
            builder = DSLBuilder(type_validations, self.valid, self.invalid)
            dsls = builder.build()

            for depth in range(3, 10):
                self._enumerator = StaticMultiTreeEnumerator(
                    self.main_dsl, dsls, depth)
                depth_start = time.time()
                self.try_for_depth()
                stats.per_depth_times[depth] = time.time() - depth_start
                if self.solutions:
                    self.terminate()
                    return self.solutions[0]
                if self.configuration.die:
                    self.terminate()
                    return None

        else:
            self._decider = RegexDecider(RegexInterpreter(), self.valid,
                                         self.invalid)
            # Smallest sizes first; (2**depth - 1) * length presumably bounds
            # the number of nodes across `length` trees of that depth.
            sizes = sorted(itertools.product(range(3, 10), range(1, 10)),
                           key=lambda t: (2**t[0] - 1) * t[1])
            for dep, length in sizes:
                self._enumerator = DynamicMultiTreeEnumerator(self.main_dsl,
                                                              depth=dep,
                                                              length=length)
                depth_start = time.time()
                self.try_for_depth()
                stats.per_depth_times[(dep, length)] = (time.time()
                                                        - depth_start)

                if self.solutions:
                    self.terminate()
                    return self.solutions[0]
                if self.configuration.die:
                    self.terminate()
                    return None

        return None
Example #6
0
 def __init__(self, regex: Node, capture_groups: List, valid_example: str):
     """Store a regex, its capture groups and one valid example.

     Also creates a ``RegexInterpreter`` and keeps a reference to
     ``utils.condition_operators``.
     """
     self._interpreter = RegexInterpreter()
     self._regex, self._capture_groups, self._valid_example = (
         regex, capture_groups, valid_example)
     self.condition_operators = utils.condition_operators
Example #7
0
    def synthesize(self):
        """Sketch-based search over program sizes that grow level by level.

        If the examples split into more than one field (and
        ``configuration.force_dynamic`` is off), a static multi-tree
        enumerator is tried for depths 2..9 over sketch-enabled DSLs.
        Otherwise a dynamic multi-tree enumerator is tried over
        (depth, length) pairs, ordered by ``(2 ** depth - 1) * length``.
        After every level the sketch counters are printed, the level count is
        added to ``count_total_sketches``, and the per-level counters reset.

        :return: ``self.solutions[0]`` as soon as a level yields good sketches
            or solutions; ``None`` when ``configuration.die`` stops the search.
        """
        self.start_time = time.time()
        try:
            valid, invalid = self.split_examples()
        except Exception:
            # Splitting is best-effort; ``except Exception`` (rather than a bare
            # ``except``) keeps KeyboardInterrupt/SystemExit from being swallowed.
            valid = None
            invalid = None

        # Truthiness also rules out an empty split, so valid[0] cannot raise.
        if valid and len(valid[0]) > 1 and not self.configuration.force_dynamic:
            self._decider = RegexDecider(RegexInterpreter(), valid, invalid, split_valid=valid)

            # Every split example must have the same number of fields.
            assert all(len(fields) == len(valid[0]) for fields in valid)
            assert all(len(fields) == len(invalid[0]) for fields in invalid)

            type_validations = ['regex'] * len(valid[0])
            builder = DSLBuilder(type_validations, valid, invalid, sketches=True)
            dsls = builder.build()

            for depth in range(2, 10):
                self._enumerator = StaticMultiTreeEnumerator(self.main_dsl, dsls, depth)

                depth_start = time.time()
                self.try_for_depth()
                stats.per_depth_times[depth] = time.time() - depth_start

                print("level sketches", self.count_level_sketches)
                self.count_total_sketches += self.count_level_sketches
                self.count_level_sketches = 0
                print("good sketches", self.count_good_sketches)
                print("\ntotal sketches", self.count_total_sketches)

                if self.count_good_sketches > 0:
                    self.terminate()
                    return self.solutions[0]
                self.count_good_sketches = 0

                if self.solutions:
                    self.terminate()
                    return self.solutions[0]
                elif self.configuration.die:
                    self.terminate()
                    return None

        else:
            # When splitting failed, valid/invalid are None: fall back to the
            # unsplit examples instead of handing None to the decider.
            if valid is None:
                valid, invalid = self.valid, self.invalid
            self._decider = RegexDecider(RegexInterpreter(), valid, invalid)
            # Smallest sizes first; (2 ** depth - 1) * length presumably bounds
            # the number of nodes across `length` trees of that depth.
            sizes = sorted(itertools.product(range(3, 10), range(1, 10)),
                           key=lambda t: (2 ** t[0] - 1) * t[1])
            for dep, length in sizes:
                logger.info(f'Sketching programs of depth {dep} and length {length}...')
                self._enumerator = DynamicMultiTreeEnumerator(self.main_dsl, depth=dep, length=length)

                depth_start = time.time()
                self.try_for_depth()
                stats.per_depth_times[(dep, length)] = time.time() - depth_start

                print("level sketches", self.count_level_sketches)
                self.count_total_sketches += self.count_level_sketches
                self.count_level_sketches = 0
                print("good sketches", self.count_good_sketches)
                print("\ntotal sketches", self.count_total_sketches)

                if self.count_good_sketches > 0:
                    self.terminate()
                    return self.solutions[0]
                self.count_good_sketches = 0

                if self.solutions:
                    self.terminate()
                    return self.solutions[0]
                elif self.configuration.die:
                    self.terminate()
                    return None