def __init__(self, valid_examples, invalid_examples, captured,
             condition_invalid, dsl: TyrellSpec, ground_truth: str,
             configuration: Configuration):
    """
    Store the examples and configuration, parse the optional ground truth and
    create the helper components used during synthesis.

    :param valid_examples: examples handed to the decider as valid.
    :param invalid_examples: examples handed to the decider as invalid.
    :param captured: expected captures, forwarded to the Capturer.
    :param condition_invalid: examples the decider treats as valid (they are
        appended to the valid examples) and the Capturer receives separately.
    :param dsl: stored as ``self.dsl``.
    :param ground_truth: ``None`` or a comma-separated string. Parts that
        start with ``$`` (after leading whitespace) are kept as conditions;
        the remaining parts are joined back with ``,`` to form the regex.
    :param configuration: synthesis options; ``pruning`` and
        ``self_interact`` are read here.
    """
    self.max_before_distinguishing = 2  # 2 for conversational clarification
    self.valid = valid_examples
    self.invalid = invalid_examples
    self.captured = captured
    self.condition_invalid = condition_invalid
    self.dsl = dsl
    self.configuration = configuration

    if ground_truth is None:
        self.ground_truth_conditions = None
        self.ground_truth_regex = None
    else:
        parts = ground_truth.split(',')
        self.ground_truth_conditions = [
            part.lstrip() for part in parts
            if part.lstrip().startswith("$")
        ]
        self.ground_truth_regex = ",".join(
            part for part in parts if not part.lstrip().startswith("$"))

    if not configuration.pruning:
        logger.warning('Synthesizing without pruning the search space.')

    # If auto-interaction is enabled, the ground truth must be a valid regex.
    # The explicit None check turns a missing ground truth into a readable
    # assertion failure instead of a TypeError raised by len(None).
    if self.configuration.self_interact:
        assert (self.ground_truth_regex is not None
                and len(self.ground_truth_regex) > 0
                and is_regex(self.ground_truth_regex)), \
            'Auto-interaction requires a valid ground-truth regex.'

    # Initialize components
    self._printer = RegexInterpreter()  # Works like to_string
    self._distinguisher = RegexDistinguisher()
    self._decider = RegexDecider(
        interpreter=RegexInterpreter(),
        valid_examples=self.valid + self.condition_invalid,
        invalid_examples=self.invalid)
    # Capturer works like a synthesizer of capturing groups
    self._capturer = Capturer(self.valid, self.captured,
                              self.condition_invalid,
                              self.ground_truth_regex,
                              self.ground_truth_conditions,
                              self.configuration)
    self._node_counter = NodeCounter()

    # Subclass decides which enumerator to use
    self._enumerator = None

    # To store synthesized regexes and captures:
    self.solutions = []
    self.first_regex = None

    # counters and timers:
    self.indistinguishable = 0
    # Number of indistinguishable programs after which the synthesizer returns.
    self.max_indistinguishable = 3
    self.start_time = None
    self.last_print_time = time.time()
def __init__(self, valid: List[List[str]],
             captures: List[List[Optional[str]]],
             condition_invalid: List[List[str]], ground_truth_regex: str,
             ground_truth_conditions: List[str], configuration):
    """
    Keep references to the examples, the ground truth and the configuration,
    and create the interpreter used to print regexes.

    The arguments are stored as given (no copies are made).
    """
    # Examples.
    self.valid, self.captures = valid, captures
    self.condition_invalid = condition_invalid

    # Ground truth, stored as received.
    self.ground_truth_regex, self.ground_truth_conditions = \
        ground_truth_regex, ground_truth_conditions

    self.configuration = configuration
    self.interpreter = RegexInterpreter()
    # 2 for conversational clarification
    self.max_before_distinguish = 2
def synthesize(type_validation):
    """
    Run the module-level ``synthesizer`` and print the result.

    The solution is printed as the regex (with the condition captures
    marked), followed by the conditions when there are any; the capturing
    groups are printed separately when there are any.

    :param type_validation: not used by this function.
    :return: whatever ``synthesizer.synthesize()`` returned (``None`` when no
        solution was found).
    """
    global synthesizer
    assert synthesizer is not None

    printer = RegexInterpreter()
    program = synthesizer.synthesize()

    # Guard clause: nothing to print besides the failure message.
    if program is None:
        print('Solution not found!')
        return program

    regex, capturing_groups, capture_conditions = program
    conditions, condition_captures = capture_conditions

    solution_str = printer.eval(regex, captures=condition_captures)
    if len(conditions) > 0:
        solution_str = solution_str + ', ' + conditions_to_str(conditions)
    print(f'\nSolution:\n {solution_str}')

    if len(capturing_groups) > 0:
        captures_str = printer.eval(regex, captures=capturing_groups)
        print(f'Captures:\n {captures_str}')

    return program
class MultipleSynthesizer(ABC):
    """
    Interactive synthesizer. Finds more than one program consistent with the
    examples and asks the user (or a ground truth) to tell them apart.

    Each entry of ``self.solutions`` is a triple
    ``(regex, capturing_groups, capture_conditions)`` where
    ``capture_conditions`` is a pair ``(conditions, condition_captures)``.
    """

    def __init__(self, valid_examples, invalid_examples, captured,
                 condition_invalid, dsl: TyrellSpec, ground_truth: str,
                 configuration: Configuration):
        """
        Store the examples and configuration, parse the optional ground truth
        and create the helper components used during synthesis.

        :param ground_truth: ``None`` or a comma-separated string. Parts that
            start with ``$`` (after leading whitespace) are kept as
            conditions; the remaining parts are joined back with ``,`` to
            form the regex.
        """
        self.max_before_distinguishing = 2  # 2 for conversational clarification
        self.valid = valid_examples
        self.invalid = invalid_examples
        self.captured = captured
        self.condition_invalid = condition_invalid
        self.dsl = dsl
        self.configuration = configuration

        if ground_truth is None:
            self.ground_truth_conditions = None
            self.ground_truth_regex = None
        else:
            parts = ground_truth.split(',')
            self.ground_truth_conditions = [
                part.lstrip() for part in parts
                if part.lstrip().startswith("$")
            ]
            self.ground_truth_regex = ",".join(
                part for part in parts
                if not part.lstrip().startswith("$"))

        if not configuration.pruning:
            logger.warning('Synthesizing without pruning the search space.')

        # If auto-interaction is enabled, the ground truth must be a valid
        # regex. The explicit None check reports a missing ground truth as an
        # assertion failure instead of a TypeError raised by len(None).
        if self.configuration.self_interact:
            assert (self.ground_truth_regex is not None
                    and len(self.ground_truth_regex) > 0
                    and is_regex(self.ground_truth_regex)), \
                'Auto-interaction requires a valid ground-truth regex.'

        # Initialize components
        self._printer = RegexInterpreter()  # Works like to_string
        self._distinguisher = RegexDistinguisher()
        self._decider = RegexDecider(
            interpreter=RegexInterpreter(),
            valid_examples=self.valid + self.condition_invalid,
            invalid_examples=self.invalid)
        # Capturer works like a synthesizer of capturing groups
        self._capturer = Capturer(self.valid, self.captured,
                                  self.condition_invalid,
                                  self.ground_truth_regex,
                                  self.ground_truth_conditions,
                                  self.configuration)
        self._node_counter = NodeCounter()

        # Subclass decides which enumerator to use
        self._enumerator = None

        # To store synthesized regexes and captures:
        self.solutions = []
        self.first_regex = None

        # counters and timers:
        self.indistinguishable = 0
        # Number of indistinguishable programs after which the synthesizer
        # returns.
        self.max_indistinguishable = 3
        self.start_time = None
        self.last_print_time = time.time()

    @property
    def enumerator(self):
        return self._enumerator

    @property
    def decider(self):
        return self._decider

    @abstractmethod
    def synthesize(self):
        """ Main synthesis procedure. Implemented in subclasses. """
        raise NotImplementedError

    def terminate(self):
        """
        Record the total synthesis time, log a summary of the run and, when
        ``configuration.log_path`` is not empty, write the summary to it.
        """
        stats.total_synthesis_time = round(time.time() - self.start_time, 2)
        logger.info('Synthesizer done.')

        now = datetime.datetime.now()
        info_str = f'On {socket.gethostname()} on {now.strftime("%Y-%m-%d %H:%M:%S")}.\n'
        info_str += f'Enumerator: {self._enumerator}' \
                    f'{" (no pruning)" if not self.configuration.pruning else ""}\n'
        info_str += f'Terminated: {self.configuration.die}\n'
        info_str += str(stats) + "\n\n"

        if len(self.solutions) > 0:
            if self.configuration.print_first_regex:
                first_regex_str = self._decider.interpreter.eval(
                    self.first_regex)
                info_str += f'First regex: {first_regex_str}\n'
            regex, capturing_groups, capture_conditions = self.solutions[0]
            conditions, conditions_captures = capture_conditions
            solution_str = self._decider.interpreter.eval(
                regex, captures=conditions_captures)
            if len(conditions) > 0:
                solution_str += ', ' + conditions_to_str(conditions)
            info_str += f'Solution: {solution_str}\n' \
                        f' Nodes: {self._node_counter.eval(self.solutions[0][0])}\n'
            if len(capturing_groups) > 0:
                info_str += f' Cap. groups: ' \
                            f'{self._decider.interpreter.eval(regex, captures=capturing_groups)}\n' \
                            f' Num. cap. groups: {len(capturing_groups)}'
            else:
                info_str += " No capturing groups."
        else:
            info_str += ' No solution.'
        info_str += '\n'

        if self.ground_truth_regex is not None:
            info_str += f' Ground truth: {self.ground_truth_regex}' \
                        f' {", ".join(self.ground_truth_conditions)}'
        logger.info(info_str)

        if len(self.configuration.log_path) > 0:
            # Context manager so the handle is flushed and closed even if the
            # write fails.
            with open(self.configuration.log_path, "w") as log_file:
                log_file.write(info_str)

    def distinguish(self):
        """
        Generate a distinguishing input between programs (if there is one),
        and interact with the user to disambiguate.
        """
        distinguish_start = time.time()
        dist_input, keep_if_valid, keep_if_invalid, unknown = \
            self._distinguisher.distinguish(self.solutions)

        if dist_input is not None:
            stats.regex_interactions += 1
            logger.info(f'Distinguishing input "{dist_input}" in '
                        f'{round(time.time() - distinguish_start, 2)} seconds')

            # Programs the distinguisher did not classify are sorted here by
            # matching them against the distinguishing input.
            # NOTE(review): elsewhere solutions are (regex, groups,
            # conditions) triples; confirm the interpreter accepts a whole
            # triple here.
            for regex in unknown:
                r0 = self._decider.interpreter.eval(regex)
                if re.fullmatch(r0, dist_input):
                    keep_if_valid.append(regex)
                else:
                    keep_if_invalid.append(regex)

            if not self.configuration.self_interact:
                self.interact(dist_input, keep_if_valid, keep_if_invalid)
            else:
                self.auto_distinguish(dist_input, keep_if_valid,
                                      keep_if_invalid)
        else:  # programs are indistinguishable
            logger.info("Regexes are indistinguishable")
            self.indistinguishable += 1
            # Keep only the solution with the shortest printed form.
            # NOTE(review): same question as above about evaluating a triple.
            smallest_regex = min(self.solutions,
                                 key=lambda r: len(self._printer.eval(r)))
            self.solutions = [smallest_regex]

        stats.regex_distinguishing_time += time.time() - distinguish_start
        stats.regex_synthesis_time += time.time() - distinguish_start

    def enumerate(self):
        """
        Request a new program from the enumerator.

        :return: the program, or ``None`` when the enumerator is exhausted.
        """
        stats.enumerated_regexes += 1
        program = self._enumerator.next()
        if program is None:  # enumerator is exhausted
            return None

        if self._printer is not None:
            logger.debug(
                f'Enumerator generated: {self._printer.eval(program)}')
        else:
            logger.debug(f'Enumerator generated: {program}')

        # Progress report at most once every 30 seconds.
        if stats.enumerated_regexes > 0 and \
                time.time() - self.last_print_time > 30:
            logger.info(f'Enumerated {stats.enumerated_regexes} regexes in '
                        f'{nice_time(time.time() - self.start_time)}.')
            self.last_print_time = time.time()

        return program

    def interact(self, dist_input, keep_if_valid, keep_if_invalid):
        """
        Interact with user to ascertain whether the distinguishing input is
        valid. The answer is added to the decider as a new example and the
        matching group of programs replaces ``self.solutions``.

        The loop also stops, without changing anything, when
        ``configuration.die`` becomes true.
        """
        valid_answer = False
        while not valid_answer and not self.configuration.die:
            x = input(f'Is "{dist_input}" valid? (y/n)\n')
            answer = x.lower().rstrip()
            if answer in yes_values:
                logger.info(f'"{dist_input}" is {colored("valid", "green")}.')
                valid_answer = True
                self._decider.add_example([dist_input], True)
                self.solutions = keep_if_valid
            elif answer in no_values:
                logger.info(f'"{dist_input}" is {colored("invalid", "red")}.')
                valid_answer = True
                self._decider.add_example([dist_input], False)
                self.solutions = keep_if_invalid
            else:
                logger.info(
                    f"Invalid answer {x}! Please answer 'yes' or 'no'.")

    def auto_distinguish(self, dist_input: str, keep_if_valid: List,
                         keep_if_invalid: List):
        """
        Simulate interaction: the input is valid iff it fully matches the
        ground truth regex.
        """
        match = re.fullmatch(self.ground_truth_regex, dist_input)
        if match is not None:
            logger.info(
                f'Auto: "{dist_input}" is {colored("valid", "green")}.')
            self._decider.add_example([dist_input], True)
            self.solutions = keep_if_valid
        else:
            logger.info(
                f'Auto: "{dist_input}" is {colored("invalid", "red")}.')
            self._decider.add_example([dist_input], False)
            self.solutions = keep_if_invalid

    def try_for_depth(self):
        """
        Enumerate programs with the current enumerator, collecting solutions
        and disambiguating them, until the enumerator is exhausted, the user
        interrupts, or a stopping criterion is met. At most one solution
        remains in ``self.solutions`` afterwards.
        """
        stats.first_regex_time = -1
        while True:
            regex = self.try_regex()

            if regex is None or self.configuration.die:
                # enumerator is exhausted or user interrupted synthesizer
                break

            if regex == -1:  # enumerated a regex that is not correct
                continue

            if self.configuration.synth_captures:
                capturing_groups = self.try_capturing_groups(regex)
                if capturing_groups is None:
                    logger.info(
                        "Failed to find capture groups for the given captures."
                    )
                    continue
            else:
                capturing_groups = []

            if self.configuration.synth_conditions:
                capture_conditions = self.try_capture_conditions(regex)
                if capture_conditions[0] is None:
                    logger.info(
                        "Failed to find capture conditions that invalidate condition_invalid."
                    )
                    continue
            else:
                # Same (conditions, captures) shape the Capturer returns when
                # no conditions are needed; terminate() unpacks this pair.
                capture_conditions = [], []

            self.solutions.append(
                (regex, capturing_groups, capture_conditions))

            if len(self.solutions) > 0 and \
                    not self.configuration.disambiguation:
                break

            if len(self.solutions) >= self.max_before_distinguishing:
                # enough candidate solutions: disambiguate.
                self.distinguish()

            if self.indistinguishable >= self.max_indistinguishable:
                break

        while len(self.solutions) > 1:
            self.distinguish()

        assert len(self.solutions) <= 1  # only one regex remains

    def try_capture_conditions(self, regex):
        """ Ask the capturer for capture conditions and time the call. """
        cap_conditions_synthesis_start = time.time()
        capture_conditions = self._capturer.synthesize_capture_conditions(
            regex)
        stats.cap_conditions_synthesis_time += \
            time.time() - cap_conditions_synthesis_start
        return capture_conditions

    def try_capturing_groups(self, regex):
        """ Ask the capturer for capturing groups and time the call. """
        cap_groups_synthesis_start = time.time()
        # synthesize captures that reflect the desired captured strings.
        captures = self._capturer.synthesize_capturing_groups(regex)
        stats.cap_groups_synthesis_time += \
            time.time() - cap_groups_synthesis_start
        return captures

    def try_regex(self):
        """
        Enumerate one program and check it against the examples.

        :return: ``None`` when the enumerator is exhausted, the regex when the
            decider accepts it, and ``-1`` when it is rejected.
        """
        regex_synthesis_start = time.time()
        regex = self.enumerate()
        if regex is None:
            return None

        analysis_result = self._decider.analyze(regex)

        if analysis_result.is_ok():  # program satisfies I/O examples
            logger.info(
                f'Regex accepted. {self._node_counter.eval(regex, [0])} nodes. '
                f'{stats.enumerated_regexes} attempts '
                f'and {round(time.time() - self.start_time, 2)} seconds:')
            logger.info(self._printer.eval(regex))
            self._enumerator.update()
            stats.regex_synthesis_time += time.time() - regex_synthesis_start
            if stats.first_regex_time == -1:
                stats.first_regex_time = time.time() - self.start_time
                self.first_regex = regex
            return regex

        if self.configuration.pruning:
            # The decider's explanation is handed to the enumerator below.
            new_predicates = analysis_result.why()
            if new_predicates is not None:
                for pred in new_predicates:
                    pred_str = self._printer.eval(pred.args[0])
                    if len(pred.args) > 1:
                        pred_str = str(pred.args[1]) + " " + pred_str
                    logger.debug(f'New predicate: {pred.name} {pred_str}')
        else:
            new_predicates = None

        self._enumerator.update(new_predicates)
        stats.regex_synthesis_time += time.time() - regex_synthesis_start
        return -1
class Capturer:
    """ 'capturer is one who, or that which, captures' """

    def __init__(self, valid: List[List[str]],
                 captures: List[List[Optional[str]]],
                 condition_invalid: List[List[str]], ground_truth_regex: str,
                 ground_truth_conditions: List[str], configuration):
        """
        Keep references to the examples, ground truth and configuration.

        The lists are stored as given; ``_interact`` and ``_auto_distinguish``
        later append to ``valid`` and ``condition_invalid`` in place.
        """
        self.valid = valid
        self.captures = captures
        self.condition_invalid = condition_invalid
        self.ground_truth_regex = ground_truth_regex
        self.ground_truth_conditions = ground_truth_conditions
        self.configuration = configuration
        self.interpreter = RegexInterpreter()
        self.max_before_distinguish = 2  # 2 for conversational clarification

    def synthesize_capturing_groups(self, regex: Node):
        """
        Given regex, find capturing groups which match self.captures.

        :return: ``[]`` when no captures were requested, the list of captured
            nodes when a placement reproduces every expected capture, and
            ``None`` when no placement does.
        """
        if len(self.captures) == 0 or len(self.captures[0]) == 0:
            return []

        nodes = regex.get_leaves()
        # try placing a capture group in each node
        for sub in all_sublists_n(nodes, len(self.captures[0])):
            stats.enumerated_cap_groups += 1
            regex_str = self.interpreter.eval(regex, captures=sub)
            compiled_re = re.compile(regex_str)
            if not all(compiled_re.fullmatch(ex[0]) is not None
                       for ex in self.valid):
                continue
            if all(elementwise_eq(
                    compiled_re.fullmatch(self.valid[i][0]).groups(),
                    self.captures[i])
                   for i in range(len(self.captures))):
                return sub
        return None

    def synthesize_capture_conditions(self, regex: Node):
        """
        Given regex, synthesise capture conditions that invalidate
        self.condition_invalid.

        Condition-invalid examples that the regex does not match are dropped
        from ``self.condition_invalid`` (the attribute is replaced by the
        filtered list).

        :return: ``(conditions, captured_nodes)``; ``([], [])`` when no
            conditions are needed and ``(None, None)`` when none were found.
        :raises ValueError: if the regex does not match every valid example.
        """
        if len(self.condition_invalid) == 0:
            return [], []

        nodes = regex.get_leaves()
        regex_str = self.interpreter.eval(regex)
        compiled_re = re.compile(regex_str)

        # Test that regex satisfies the valid examples
        if not all(compiled_re.fullmatch(ex[0]) for ex in self.valid):
            raise ValueError("Regex doesn't match all valid examples")

        if not all(compiled_re.fullmatch(ex[0])
                   for ex in self.condition_invalid):
            logger.info("Regex doesn't match all condition invalid examples. Removing.")
            self.condition_invalid = [
                ex for ex in self.condition_invalid
                if compiled_re.fullmatch(ex[0])
            ]
            if len(self.condition_invalid) == 0:
                logger.info("No condition invalid examples left. No capture conditions needed.")
                return [], []

        # NOTE(review): the upper bound excludes n == len(nodes), so capturing
        # every leaf at once is never tried -- confirm this is intended.
        for n in range(1, len(nodes)):
            for sub in all_sublists_n(nodes, n):
                stats.enumerated_cap_conditions += 1
                regex_str = self.interpreter.eval(regex, captures=sub)
                compiled_re = re.compile(regex_str)
                if not all(compiled_re.fullmatch(ex[0]) is not None
                           for ex in self.valid):
                    continue
                # Only captures that are integers in every valid example are
                # candidates for conditions.
                if not all(all(is_int(g)
                               for g in compiled_re.fullmatch(ex[0]).groups())
                           for ex in self.valid):
                    continue
                condition = self._synthesize_conditions_for_captures(
                    regex, sub)
                if condition is not None:
                    return condition, sub

        return None, None

    def _synthesize_conditions_for_captures(self, regex, capture_groups):
        """
        Given capturing groups, try to find conditions that satisfy examples.

        :return: one set of conditions, or ``None`` when the enumerator
            produced none.
        """
        assert len(self.condition_invalid) > 0
        self._cc_enumerator = CaptureConditionsEnumerator(
            self.interpreter.eval(regex, captures=capture_groups),
            len(capture_groups), self.valid, self.condition_invalid)
        condition_distinguisher = ConditionDistinguisher(
            regex, capture_groups, self.valid[0][0])

        conditions = []
        while True:
            new_condition = self._cc_enumerator.next()
            if new_condition is None:  # enumerator is exhausted
                break
            if not self.configuration.disambiguation:
                return new_condition
            self._cc_enumerator.update()
            conditions.append(new_condition)
            if len(conditions) >= self.max_before_distinguish:
                conditions = self._disambiguate(condition_distinguisher,
                                                conditions)

        if len(conditions) == 0:
            return None
        while len(conditions) > 1:
            conditions = self._disambiguate(condition_distinguisher,
                                            conditions)
        assert len(conditions) == 1
        return conditions[0]

    def _disambiguate(self, condition_distinguisher, conditions):
        """
        Run one distinguishing round over ``conditions`` and return the
        candidates that survive it.

        The distinguisher returns ``None`` as input when it cannot separate
        the candidates; in that case no question is asked and only the first
        candidate is kept. The same fallback is used when the interaction
        ends without an answer.
        """
        start_distinguish_time = time.time()
        dist_input, keep_if_valid, keep_if_invalid = \
            condition_distinguisher.distinguish(conditions)
        stats.cap_conditions_distinguishing_time += \
            time.time() - start_distinguish_time

        if dist_input is None:
            return conditions[:1]

        stats.cap_conditions_interactions += 1
        if not self.configuration.self_interact:
            remaining = self._interact(dist_input, keep_if_valid,
                                       keep_if_invalid)
        else:
            remaining = self._auto_distinguish(dist_input, keep_if_valid,
                                               keep_if_invalid)
        # _interact returns None when configuration.die stops its loop.
        return conditions[:1] if remaining is None else remaining

    def _interact(self, dist_input, keep_if_valid, keep_if_invalid):
        """
        Interact with user to ascertain whether the distinguishing input is
        valid. The input is recorded as a new example and the matching group
        of conditions is returned. Returns ``None`` when
        ``configuration.die`` ends the loop before an answer is given.
        """
        while not self.configuration.die:
            x = input(f'Is "{dist_input}" valid? (y/n)\n')
            answer = x.lower().rstrip()
            if answer in yes_values:
                logger.info(f'"{dist_input}" is {colored("valid", "green")}.')
                self.valid.append([dist_input])
                self._cc_enumerator.add_valid(dist_input)
                return keep_if_valid
            elif answer in no_values:
                logger.info(f'"{dist_input}" is {colored("conditional invalid", "red")}.')
                self.condition_invalid.append([dist_input])
                self._cc_enumerator.add_conditional_invalid(dist_input)
                return keep_if_invalid
            else:
                logger.info(f"Invalid answer {x}! Please answer 'yes' or 'no'.")
        return None

    def _auto_distinguish(self, dist_input: str, keep_if_valid: List,
                          keep_if_invalid: List):
        """
        Given distinguishing input, simulate user interaction based on ground
        truth: the input is valid iff it fully matches the ground truth regex
        and satisfies the ground truth conditions.
        """
        match = re.fullmatch(self.ground_truth_regex, dist_input)
        if match is not None and \
                check_conditions(self.ground_truth_conditions, match):
            logger.info(f'Auto: "{dist_input}" is {colored("valid", "green")}.')
            self.valid.append([dist_input])
            self._cc_enumerator.add_valid(dist_input)
            return keep_if_valid

        logger.info(f'Auto: "{dist_input}" is {colored("conditional invalid", "red")}.')
        self.condition_invalid.append([dist_input])
        self._cc_enumerator.add_conditional_invalid(dist_input)
        return keep_if_invalid
def __init__(self):
    """
    Create the helpers used for distinguishing and set both strategy flags
    to their default (off).
    """
    # Helper objects: ``_toz3.eval`` feeds z3, ``_printer.eval`` is used for
    # printing.
    self._toz3 = ToZ3()
    self._printer = RegexInterpreter()

    # Both strategy overrides start disabled.
    self.force_multi_distinguish = self.force_distinguish2 = False
class RegexDistinguisher:
    """
    Finds an input string on which candidate programs disagree, using z3.

    ``distinguish2`` treats each program as a
    ``(regex, capturing_groups, (conditions, condition_captures))`` triple.
    """

    def __init__(self):
        self._toz3 = ToZ3()
        self._printer = RegexInterpreter()
        self.force_multi_distinguish = False
        self.force_distinguish2 = False

    def distinguish(self, programs):
        """
        Pick a distinguishing strategy for ``programs`` (at least two).

        :return: ``(dist_input, keep_if_valid, keep_if_invalid, unknown)``.
            ``unknown`` holds programs that were not classified; every element
            is ``None`` when no distinguishing input exists.
        """
        # NOTE(review): the printer and multi_distinguish evaluate whole
        # programs while distinguish2 evaluates program[0]; confirm both
        # forms are accepted.
        logger.debug(f"Distinguishing {len(programs)}: "
                     f"{','.join(map(self._printer.eval, programs))}")
        assert len(programs) >= 2
        if not self.force_multi_distinguish and len(programs) == 2:
            return self.distinguish2(programs[0], programs[1])
        if self.force_distinguish2:
            # Only the first two programs are compared; the rest are returned
            # as unknown for the caller to classify.
            dist_input, keep_if_valid, keep_if_invalid, _ = \
                self.distinguish2(programs[0], programs[1])
            return dist_input, keep_if_valid, keep_if_invalid, programs[2:]
        return self.multi_distinguish(programs)

    def distinguish2(self, r1, r2):
        """
        Find a string accepted by exactly one of the two regexes.

        When either program has capture conditions, candidate strings are
        requested from the solver until one fully matches a program's regex
        and satisfies that program's conditions.
        """
        global use_derivatives
        solver = z3.Solver()
        solver.set('random_seed', 7)
        solver.set('sat.random_seed', 7)
        if use_derivatives:
            # Best effort: set the option and run a check; any error raised
            # by either call is ignored and solving continues.
            try:
                solver.set('smt.seq.use_derivatives', True)
                solver.check()
            except Exception:
                logger.debug("Could not enable smt.seq.use_derivatives.")

        z3_r1 = self._toz3.eval(r1[0])
        z3_r2 = self._toz3.eval(r2[0])

        dist = z3.String("distinguishing")

        ro_1 = z3.Bool("ro_1")
        solver.add(ro_1 == z3.InRe(dist, z3_r1))
        ro_2 = z3.Bool("ro_2")
        solver.add(ro_2 == z3.InRe(dist, z3_r2))

        solver.add(ro_1 != ro_2)

        if solver.check() != z3.sat:
            return None, None, None, None

        if len(r1[2][0]) == 0 and len(r2[2][0]) == 0:
            # No conditions on either side: the first model is enough.
            dist_input = solver.model()[dist].as_string()
            if solver.model()[ro_1]:
                return dist_input, [r1], [r2], []
            return dist_input, [r2], [r1], []

        # Find dist_input that respects conditions
        r1_str = self._printer.eval(r1[0], captures=r1[2][1])
        r1_conditions = [" ".join(map(str, c)) for c in r1[2][0]]
        r2_str = self._printer.eval(r2[0], captures=r2[2][1])
        r2_conditions = [" ".join(map(str, c)) for c in r2[2][0]]

        while True:
            dist_input = solver.model()[dist].as_string()
            match = re.fullmatch(r1_str, dist_input)
            if match is not None and check_conditions(r1_conditions, match):
                break
            match = re.fullmatch(r2_str, dist_input)
            if match is not None and check_conditions(r2_conditions, match):
                break
            # Reject this candidate and ask the solver for another one.
            solver.add(dist != z3.StringVal(dist_input))
            if solver.check() != z3.sat:
                return None, None, None, None

        if solver.model()[ro_1]:
            return dist_input, [r1], [r2], []
        return dist_input, [r2], [r1], []

    def multi_distinguish(self, regexes):
        """
        Find one input that separates as many pairs of regexes as possible.

        :return: ``(dist_input, keep_if_valid, keep_if_invalid, others)``
            where ``others`` are the regexes left out of the optimisation;
            every element is ``None`` when no input separates any pair.
        """
        start = time.time()
        # Problem: cannot distinguish more than 4 regexes at once: it takes
        # forever.
        # Solution: use only 4 randomly selected regexes for the SMT
        # maximization, and then add the others to the solution.
        if len(regexes) <= 4:
            selected_regexes = regexes
            others = []
        else:
            random.seed('regex')
            # Shuffle a copy so the caller's list keeps its order.
            shuffled = list(regexes)
            random.shuffle(shuffled)
            selected_regexes = shuffled[:4]
            others = shuffled[4:]

        solver = z3.Optimize()

        z3_regexes = [self._toz3.eval(regex) for regex in selected_regexes]

        dist = z3.String("distinguishing")

        ro_z3 = []
        for i, z3_regex in enumerate(z3_regexes):
            ro = z3.Bool(f"ro_{i}")
            ro_z3.append(ro)
            # ro_z3[i] == true if dist matches regex[i].
            solver.add(ro == z3.InRe(dist, z3_regex))

        big_or = []
        for ro_i, ro_j in combinations(ro_z3, 2):
            big_or.append(z3.Xor(ro_i, ro_j))
            solver.add_soft(z3.Xor(ro_i, ro_j))
        solver.add(z3.Or(big_or))  # at least one regex is distinguished

        if solver.check() != z3.sat:
            return None, None, None, None

        print("took", round(time.time() - start, 2), "seconds")
        model = solver.model()
        # Read the string value the same way distinguish2 does, instead of
        # stripping quotes from its printed form.
        dist_input = model[dist].as_string()
        keep_if_valid = []
        keep_if_invalid = []
        for i, ro in enumerate(ro_z3):
            if model[ro]:
                keep_if_valid.append(selected_regexes[i])
            else:
                keep_if_invalid.append(selected_regexes[i])
        return dist_input, keep_if_valid, keep_if_invalid, others
def synthesize(self):
    """
    Main synthesis procedure.

    If the examples can be split into more than one field (and
    ``configuration.force_dynamic`` is off), a static multi-tree enumerator
    is tried for depths 3 to 9; otherwise a dynamic multi-tree enumerator is
    tried for (depth, length) pairs ordered by ``(2**depth - 1) * length``.

    :return: the first solution found, or ``None`` when there is none or
        the run was interrupted (``configuration.die``).
    """
    self.start_time = time.time()

    # Splitting is best effort: on failure the dynamic enumerator is used.
    # Only ordinary exceptions are swallowed, so Ctrl-C still propagates.
    try:
        valid, invalid = self.split_examples()
    except Exception:
        valid = None
        invalid = None

    if valid is not None and len(valid[0]) > 1 \
            and not self.configuration.force_dynamic:
        # The decider is built from the unsplit examples before they are
        # replaced by the split ones.
        self._decider = RegexDecider(interpreter=RegexInterpreter(),
                                     valid_examples=self.valid,
                                     invalid_examples=self.invalid,
                                     split_valid=valid)
        self.valid = valid
        self.invalid = invalid

        # Every example must have been split into the same number of fields.
        assert all(len(ex) == len(self.valid[0]) for ex in self.valid)
        assert all(len(ex) == len(self.invalid[0]) for ex in self.invalid)

        type_validations = ['regex'] * len(self.valid[0])
        builder = DSLBuilder(type_validations, self.valid, self.invalid)
        dsls = builder.build()

        for depth in range(3, 10):
            self._enumerator = StaticMultiTreeEnumerator(
                self.main_dsl, dsls, depth)
            depth_start = time.time()
            self.try_for_depth()
            stats.per_depth_times[depth] = time.time() - depth_start

            if len(self.solutions) > 0:
                self.terminate()
                return self.solutions[0]
            elif self.configuration.die:
                self.terminate()
                return None
    else:
        self._decider = RegexDecider(RegexInterpreter(), self.valid,
                                     self.invalid)
        sizes = sorted(itertools.product(range(3, 10), range(1, 10)),
                       key=lambda t: (2**t[0] - 1) * t[1])
        for dep, length in sizes:
            self._enumerator = DynamicMultiTreeEnumerator(self.main_dsl,
                                                          depth=dep,
                                                          length=length)
            depth_start = time.time()
            self.try_for_depth()
            stats.per_depth_times[(dep, length)] = time.time() - depth_start

            if len(self.solutions) > 0:
                self.terminate()
                return self.solutions[0]
            elif self.configuration.die:
                self.terminate()
                return None

    return None
def __init__(self, regex: Node, capture_groups: List, valid_example: str):
    """
    Remember the regex, its capture groups and one valid example, and
    create the interpreter used to print the regex.
    """
    # Inputs, stored as given.
    self._regex, self._capture_groups = regex, capture_groups
    self._valid_example = valid_example

    self._interpreter = RegexInterpreter()
    # Operator table taken from the utils module.
    self.condition_operators = utils.condition_operators
class ConditionDistinguisher:
    """
    Finds an input that tells apart candidate sets of capture conditions.

    The input is built from one valid example by replacing captured
    substrings with values chosen by z3.
    """

    def __init__(self, regex: Node, capture_groups: List, valid_example: str):
        self._regex = regex
        self._capture_groups = capture_groups
        self._valid_example = valid_example
        self._interpreter = RegexInterpreter()
        self.condition_operators = utils.condition_operators

    def distinguish(self, models: List[List[Tuple]]):
        """
        Find distinguishing input for sets of conditions in models.

        :param models: candidate condition sets; each condition is a
            ``(capture_index, operator_name, bound)`` tuple.
        :return: ``(input, keep_if_valid, keep_if_invalid)``, or
            ``(None, None, None)`` when the candidates cannot be told apart.
        """
        solver = z3.Optimize()
        distinct_models = keep_distinct(models)
        logger.info(f"Distinguishing {distinct_models}")

        # One integer variable per capture group.
        cs = [z3.Int(f"c{cap_idx}")
              for cap_idx in range(len(self._capture_groups))]

        # Pair every distinct model with the flag that says whether the
        # chosen capture values satisfy it.
        flagged_models = []
        for m_idx, model in enumerate(distinct_models):
            sat_m = z3.Bool(f"sat_m{m_idx}")
            solver.add(self._get_sat_m_constraint(cs, model, sat_m))
            flagged_models.append((model, sat_m))

        # maximisation objective from multi-distinguish
        big_or = []
        for (_, m_i), (_, m_j) in combinations(flagged_models, 2):
            big_or.append(z3.Xor(m_i, m_j))
            solver.add_soft(z3.Xor(m_i, m_j))
        # at least one set of conditions is distinguished from the rest
        solver.add(z3.Or(big_or))

        if solver.check() != z3.sat:
            logger.info("Indistinguishable")
            return None, None, None

        z3_model = solver.model()
        valid_ex = self._valid_example
        for c_idx, c_var in enumerate(cs):
            if z3_model[c_var] is None:  # value is irrelevant
                continue
            c_val = str(z3_model[c_var])
            regex_str = self._interpreter.eval(
                self._regex, captures=[self._capture_groups[c_idx]])
            match = re.compile(regex_str).fullmatch(valid_ex)
            cap_substr = match.group(1)
            # Pad with zeros so the new value is at least as long as the
            # captured text.
            c_val = c_val.rjust(len(cap_substr), '0')
            # Substitute at the captured position; replacing the first
            # textual occurrence could change an earlier identical substring.
            valid_ex = valid_ex[:match.start(1)] + c_val \
                + valid_ex[match.end(1):]

        keep_if_valid = []
        keep_if_invalid = []
        # Classify the models the flags were built from (the distinct ones).
        for model, sat_m in flagged_models:
            if z3_model[sat_m]:
                keep_if_valid.append(model)
            else:
                keep_if_invalid.append(model)

        logger.info(f"Dist. input: {valid_ex}")
        return valid_ex, keep_if_valid, keep_if_invalid

    def _get_sat_m_constraint(self, cs, model, sat_m):
        """
        Build the constraint ``sat_m == (all conditions in model hold)``.
        """
        big_and = []
        for c_idx, cond, bound_val in model:
            big_and.append(self._get_reverse_cond(cond, bound_val, cs[c_idx]))
        return sat_m == z3.And(big_and)

    def _get_reverse_cond(self, cond: str, bound: z3.IntNumRef,
                          cap_var: z3.IntNumRef):
        """
        Apply the operator registered for ``cond`` to ``(cap_var, bound)``.
        """
        op = self.condition_operators[cond]
        return op(cap_var, bound)
def synthesize(self):
    """
    Main synthesis procedure, with sketch statistics printed after every
    enumeration level.

    If the examples can be split into more than one field (and
    ``configuration.force_dynamic`` is off), a static multi-tree enumerator
    is tried for depths 2 to 9; otherwise a dynamic multi-tree enumerator is
    tried for (depth, length) pairs ordered by ``(2**depth - 1) * length``.

    :return: the first solution, or ``None`` when there is none or the run
        was interrupted (``configuration.die``).
    """
    self.start_time = time.time()

    # Splitting is best effort: on failure the dynamic enumerator is used.
    # Only ordinary exceptions are swallowed, so Ctrl-C still propagates.
    try:
        valid, invalid = self.split_examples()
    except Exception:
        valid = None
        invalid = None

    if valid is not None and len(valid[0]) > 1 \
            and not self.configuration.force_dynamic:
        # self.valid / self.invalid are not reassigned here; the split
        # examples are only handed to the decider and the DSL builder.
        self._decider = RegexDecider(RegexInterpreter(), valid, invalid,
                                     split_valid=valid)

        # Every example must have been split into the same number of fields.
        assert all(len(ex) == len(valid[0]) for ex in valid)
        assert all(len(ex) == len(invalid[0]) for ex in invalid)

        type_validations = ['regex'] * len(valid[0])
        builder = DSLBuilder(type_validations, valid, invalid, sketches=True)
        dsls = builder.build()

        for depth in range(2, 10):
            self._enumerator = StaticMultiTreeEnumerator(
                self.main_dsl, dsls, depth)
            depth_start = time.time()
            self.try_for_depth()
            stats.per_depth_times[depth] = time.time() - depth_start

            print("level sketches", self.count_level_sketches)
            self.count_total_sketches += self.count_level_sketches
            self.count_level_sketches = 0
            print("good sketches", self.count_good_sketches)
            print("\ntotal sketches", self.count_total_sketches)
            # NOTE(review): this indexes self.solutions without checking its
            # length -- confirm a good sketch implies a solution exists.
            if self.count_good_sketches > 0:
                self.terminate()
                return self.solutions[0]
            self.count_good_sketches = 0

            if len(self.solutions) > 0:
                self.terminate()
                return self.solutions[0]
            elif self.configuration.die:
                self.terminate()
                return None
    else:
        # NOTE(review): when split_examples() failed, valid and invalid are
        # None here -- confirm the decider is meant to receive None.
        self._decider = RegexDecider(RegexInterpreter(), valid, invalid)
        sizes = sorted(itertools.product(range(3, 10), range(1, 10)),
                       key=lambda t: (2 ** t[0] - 1) * t[1])
        for dep, length in sizes:
            logger.info(f'Sketching programs of depth {dep} and length {length}...')
            self._enumerator = DynamicMultiTreeEnumerator(self.main_dsl,
                                                          depth=dep,
                                                          length=length)
            depth_start = time.time()
            self.try_for_depth()
            stats.per_depth_times[(dep, length)] = time.time() - depth_start

            print("level sketches", self.count_level_sketches)
            self.count_total_sketches += self.count_level_sketches
            self.count_level_sketches = 0
            print("good sketches", self.count_good_sketches)
            print("\ntotal sketches", self.count_total_sketches)
            if self.count_good_sketches > 0:
                self.terminate()
                return self.solutions[0]
            self.count_good_sketches = 0

            if len(self.solutions) > 0:
                self.terminate()
                return self.solutions[0]
            elif self.configuration.die:
                self.terminate()
                return None