def insert_regex(in_str, choice_len, regex_constraint):
    """Overwrite tokens of `in_str` with regex samples until each constraint is met.

    Args:
        in_str: mutable list of string tokens; mutated in place.
        choice_len: number of writable positions (indices 0..choice_len-1) and
            the maximum length of the returned string.
        regex_constraint: dict mapping a key of ALL_REGEX to the minimum number
            of matches of that regex required in the joined string.

    Returns:
        The joined string, truncated to at most `choice_len` characters.

    Raises:
        IndexError: if fewer free positions remain than inserts are needed.
    """
    insert_pos = set(range(choice_len))  # positions not yet overwritten
    for r in regex_constraint:
        desired_counts = regex_constraint[r]
        actual_counts = len(re.findall(ALL_REGEX[r], ''.join(in_str)))
        num_of_inserts = max(0, desired_counts - actual_counts)
        if len(insert_pos) < num_of_inserts:
            raise IndexError
        # BUGFIX: random.sample() on a set was deprecated in Python 3.9 and
        # raises TypeError from 3.11 — sample from a sorted sequence instead
        # (sorting also makes results reproducible under a fixed seed).
        indices_of_inserts = set(
            random.sample(sorted(insert_pos), k=num_of_inserts))
        for i in indices_of_inserts:
            in_str[i] = pre.create(ALL_REGEX[r]).sample()
        # print('Expected', desired_counts, r)
        # print('Actual number is', actual_counts)
        # print('After insert', ''.join(in_str))
        insert_pos -= indices_of_inserts
    # renamed from `str`, which shadowed the builtin
    result = ''.join(in_str)
    return result[:choice_len] if len(result) > choice_len else result
def getValidNetworkOutputs(net, current_trace, examples, maxNetworkEvals=None):
    """Yield (output, count, group_idx) for network samples that parse as regexes.

    Samples candidate regex token sequences from `net` conditioned on
    `examples`, keeps only those that `pre.create` accepts, and memoises both
    the set of all seen outputs and the valid ones in the module-level
    `networkCache` keyed by the sorted example tuple.
    """
    if maxNetworkEvals is None: maxNetworkEvals = 10
    # Map each base concept to a RegexWrapper so pre.create can resolve it.
    lookup = {
        concept: RegexWrapper(concept)
        for concept in current_trace.baseConcepts
    }
    examples = tuple(sorted(examples))  # canonical cache key
    isCached = examples in networkCache
    if isCached:
        # Cached path: iterate previously validated (output, count) pairs.
        o_generator = networkCache[examples]['valid']
    else:
        def get_more_outputs():
            # Resets the cache entry each call, then draws a fresh batch of
            # 500 samples and yields unique outputs, least frequent first
            # (ascending count order via outputs_count.get).
            networkCache[examples] = {'valid': [], 'all': set()}
            inputs = [[(example, ) for example in examples]] * 500
            outputs_count = Counter(net.sample(inputs))
            for o in sorted(outputs_count, key=outputs_count.get):
                yield (o, outputs_count[o])
        # Up to maxNetworkEvals rounds of sampling, flattened lazily.
        o_generator = (o for i in range(maxNetworkEvals)
                       for o in get_more_outputs())
    group_idx = 0
    for o, count in o_generator:
        if not isCached:
            # Skip outputs already seen in an earlier round.
            if o in networkCache[examples]['all']:
                continue
            else:
                networkCache[examples]['all'].add(o)
        try:
            # NOTE(review): this add is redundant — `o` was already added to
            # the 'all' set just above; harmless but worth confirming.
            if not isCached:
                networkCache[examples]['all'].add(o)
            pre.create(o, lookup=lookup)  # raises ParseException if o is not a valid regex
            if not isCached:
                networkCache[examples]['valid'].append((o, count))
            yield (o, count, group_idx)
            group_idx += 1
        except pre.ParseException:
            # Invalid regex token sequence: silently discard.
            pass
def __init__(self):
    # Primitive token classes: any run, digits, word chars, whitespace,
    # upper-case run, lower-case run (\u / \l are pregex extensions).
    base_patterns = (".+", "\\d+", "\\w+", "\\s+", "\\u+", "\\l+")
    self.regexes = [pre.create(pattern) for pattern in base_patterns]
def getRelatedRegexConcepts( o ):  #Proposals that will be good only if getRegexConcept(o) is good
    """Yield (trace, concept) proposals derived from the regex token sequence `o`.

    Relies on enclosing-scope names: `modes`, `current_trace`, `lookup`,
    `fuzzConcepts`, `similarConcepts`, `getRegexConcept`, `RegexWrapper`.
    """
    def extend(t, c):  #add CRPs (Pitman-Yor processes) at the end
        # Only does anything when one of the regex-crp modes is enabled;
        # otherwise yields nothing.
        if any(x in modes for x in ("regex-crp", "regex-crp-crp")):
            t, c = t.addPY(c)
            if "regex-crp" in modes:
                yield (t, c)
            if "regex-crp-crp" in modes:
                # Second PY layer on top of the first.
                t, c = t.addPY(c)
                yield (t, c)
    # Base proposal: the plain regex concept, optionally CRP-extended.
    for (t, c) in extend(*getRegexConcept(o)):
        yield (t, c)
    for i in range(len(o)):
        if o[i] in current_trace.baseConcepts:
            if fuzzConcepts:
                #Try replacing one concept in regex with parent or child
                for o_alt in similarConcepts.get(o[i], []):
                    r = pre.create(o[:i] + (o_alt, ) + o[i + 1:], lookup=lookup)
                    t, c = current_trace.addregex(r)
                    yield (t, c)
                    for (t, c) in extend(t, c):
                        yield (t, c)
            if "crp-regex" in modes:
                #Try replacing one concept with a new PYConcept
                t, c = current_trace.addPY(o[i])
                # The fresh concept `c` needs its own lookup entry so the
                # regex parser can resolve it.
                r = pre.create(o[:i] + (c, ) + o[i + 1:],
                               lookup={ **lookup, c: RegexWrapper(c) })
                t, c = t.addregex(r)
                yield (t, c)
                for (t, c) in extend(t, c):
                    yield (t, c)
def evaluate_datum(i, datum, model, dcModel, nRepeats, mdl, max_to_check):
    """Evaluate one task: sample sketches from `model`, pair each with a
    grammar (optionally hole-specific from `dcModel`), then enumerate
    programs via `pypy_enumerate`, yielding RegexResult tuples as it goes.

    Uses enclosing/module scope: args, nExamples, holeSpecificDcModel,
    basegrammar, untorch, pre_to_prog, lookup_d, SketchTup, RegexResult,
    pypy_enumerate, PregHole.
    """
    t = time.time()  # wall-clock start, reported per-result and at the end
    samples = {(PregHole, )}  # default: single all-hole sketch # make more general # TODO, i don't think
    n_checked, n_hit = 0, 0
    if model:
        if args.beam:
            samples, _scores = model.beam_decode([datum.IO[:nExamples]],
                                                 beam_size=nRepeats)
        else:
            samples, _scores, _ = model.sampleAndScore([datum.IO[:nExamples]],
                                                       nRepeats=nRepeats)
        # only loop over unique samples:
        samples = {tuple(sample) for sample in samples}  # only
    # Shared grammar for all sketches, unless a hole-specific dcModel
    # supplies one per sketch below.
    if (not holeSpecificDcModel) or (not dcModel):
        g = basegrammar if not dcModel else dcModel.infer_grammar(
            datum.IO[:nExamples])  # TODO pp
        g = untorch(g)
    sketchtups = []
    for sample in samples:
        try:
            sk = pre_to_prog(pre.create(sample, lookup=lookup_d))
            if holeSpecificDcModel:
                g = untorch(
                    dcModel.infer_grammar(
                        (datum.IO[:nExamples],
                         sample)))  #TODO: make sure this line is correct ..
            sketchtups.append(SketchTup(sk, g))
        except ParseException:
            # Unparseable sample still counts as "checked" and yields a
            # -inf-scored result.
            n_checked += 1
            yield (RegexResult(sample, None, float('-inf'), n_checked,
                               time.time() - t), float('-inf'))
            continue
    # only loop over unique sketches:
    sketchtups = {sk for sk in sketchtups}  #fine
    #alternate which sketch to enumerate from each time
    results, n_checked, n_hit = pypy_enumerate(datum.tp, datum.IO, mdl,
                                               sketchtups, n_checked, n_hit,
                                               t, max_to_check)
    yield from (result for result in results)
    ######TODO: want search time and total time to hit task ######
    print(f"task {i}:")
    print(f"evaluation for task {i} took {time.time()-t} seconds")
    print(f"For task {i}, tried {n_checked} sketches, found {n_hit} hits")
def get_gt_ll(name, examples):  #gets groundtruth from dict
    """Return the log-likelihood of `examples` under the ground-truth regex.

    Args:
        name: key into the module-level `gt_dict` of hand-written regex strings.
        examples: list of strings, or list of character lists (normalised here).

    Returns:
        Sum of per-example log-likelihoods; float('-inf') if any example
        fails to match, in which case diagnostics are printed.
    """
    import pregex as pre
    r_str = gt_dict[name]
    preg = pre.create(r_str)
    # Examples may arrive as lists of characters; normalise to strings.
    # (was `type(examples[0]) == list` — isinstance is the idiomatic check)
    if isinstance(examples[0], list):
        examples = ["".join(example) for example in examples]
    s = sum(preg.match(example) for example in examples)
    if s == float("-inf"):
        # At least one example does not match the ground-truth regex;
        # dump diagnostics instead of asserting so evaluation can continue.
        print("bad for ", name)
        print('preg:', preg)
        print('preg sample:', [preg.sample() for i in range(3)])
        print("exs", examples)
        #assert False
    return s
def regex_bound(X):
    """Best average per-example log-likelihood over a fixed set of
    primitive regexes, weighted by example frequency in X."""
    counts = Counter(X)
    total = sum(counts.values())
    # NOTE(review): "\s" has no "+" here, unlike every other pattern and
    # unlike the sibling list used elsewhere — confirm this is intentional.
    patterns = (".+", "\\d+", "\\w+", "\\s", "\\u+", "\\l+")
    scores = []
    for pattern in patterns:
        matcher = pre.create(pattern)
        scores.append(sum(counts[x] * matcher.match(x) for x in X) / total)
    return max(scores)
def generate():
    """Return a Constant whose value is sampled from the regex "."*l,
    where l is drawn uniformly from [1, MAX_CONST_LEN).

    Raises:
        ValueError: if MAX_CONST_LEN <= 1 (empty range).
    """
    # was random.choice(list(range(1, MAX_CONST_LEN))): randrange gives the
    # same uniform distribution without building a throwaway list.
    length = random.randrange(1, MAX_CONST_LEN)
    value = pre.create("." * length).sample()
    return Constant(value)
import random
import pregex as pre
import math

# Sanity checks: pre.create must parse these patterns into the expected
# pregex AST structures (precedence of | vs *, grouping, Plus/KleeneStar).
assert (pre.create("f(a|o)*") == pre.Concat([
    pre.String("f"),
    pre.KleeneStar(pre.Alt([pre.String("a"), pre.String("o")]))
]))
assert (pre.create("fa|o*") == pre.Concat([
    pre.String("f"),
    pre.Alt([pre.String("a"), pre.KleeneStar(pre.String("o"))])
]))
assert (pre.create("(f.*)+") == pre.Plus(
    pre.Concat([pre.String("f"), pre.KleeneStar(pre.dot)])))
# (string, regex, whether the regex should match the string)
test_cases = [("foo", "fo", False), ("foo", "foo", True),
              ("foo", "fooo", False), ("foo", "fo*", True),
              ("foo", "fo+", True), ("foo", "f(oo)*", True),
              ("foo", "f(a|b)*", False), ("foo", "f(a|o)*", True),
              ("foo", "fa|o*", True), ("foo", "fo|a*", False),
              ("foo", "f|ao|ao|a", True),
              ("f" + "o" * 50, "f" + "o*" * 10, True), ("foo", "fo?+", True),
              ("foo", "fo**", True), ("(foo)", "\\(foo\\)", True),
              ("foo foo. foo foo foo.", "foo(\\.? foo)*\\.", True),
              ("123abcABC ", ".+", True), ("123abcABC ", '\\w+', False),
              ("123abcABC ", "\\w+\\s", True),
              ("123abcABC ", "\\d+\\l+\\u+\\s", True)]
# NOTE(review): `matches` is unpacked but not checked in the visible loop —
# presumably a match assertion follows later in the file; confirm.
for (string, regex, matches) in test_cases:
    print("Parsing", regex)
    r = pre.create(regex)
    print("Matching", string, r)
# Read examples from stdin until a blank line, then sample candidate
# regexes from the model and score them.
while nextInput:
    s = input()
    if s == "":
        nextInput = False
    else:
        examples.append(s)
print("calculating... ")
samples, scores = model.sampleAndScore([examples], nRepeats=2)
print(samples)
print(scores)
print(len(scores), len(samples))
index = scores.index(max(scores))
#print(samples[index])
try:
    sample = pre.create(list(samples[index]))
# was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit —
# especially bad in an interactive loop. Fall back to the raw token list
# when the sample is not a parseable regex.
except Exception:
    sample = samples[index]
#sample = samples[index]
print("best example by nn score:", sample, ", nn score:", max(scores))
pregexes = []
pscores = []
for samp in samples:
    try:
        reg = pre.create(list(samp))
        pregexes.append(reg)
        # score = total log-likelihood of the examples under this regex
        pscores.append(sum(reg.match(ex) for ex in examples))
    except Exception:  # was bare except: — see note above
        pregexes.append(samp)
        pscores.append(float('-inf'))
if __name__ == '__main__':
    # Smoke-test: sample one datum with sketches from the base grammar and
    # print its program, sketch, and examples; then round-trip the sketch
    # sequence through the pregex parser into a program.
    import time
    import pregex as pre
    g = basegrammar
    d = sample_datum(g=g,
                     N=4,
                     compute_sketches=True,
                     top_k_sketches=100,
                     inv_temp=1.0,
                     reward_fn=None,
                     sample_fn=None,
                     dc_model=None)
    print(d.p)
    print(d.p.evaluate([]))
    print(d.sketch)
    #print(d.sketch.evaluate([])(pre.String("")))
    print(d.sketch.evaluate([]))
    print(d.sketchseq)
    for o in d.IO:
        print("example")
        print(o)
    from util.regex_util import PregHole, pre_to_prog
    # PregHole acts as both the lookup key and its own wrapper instance.
    preg = pre.create(d.sketchseq, lookup={PregHole: PregHole()})
    print(preg)
    print(pre_to_prog(preg))
print(task.name) totalTasks += 1 print("\tTRAIN\t", ["".join(example[1]) for example in task.examples]) testingExamples = regexHeldOutExamples(task) print("\tTEST\t", [example[1] for example in testingExamples]) gt_preg = gt_dict[int(task.name.split(" ")[-1])] print("\tHuman written regex:", gt_preg) eprint(verbatimTable(["".join(example[1]) for example in task.examples] + [None] + \ [gt_preg,None] + \ [example[1] for example in testingExamples])) eprint("&") gt_preg = pre.create(gt_preg) def examineProgram(entry): global preg global diff_lookup program = entry.program ll = entry.logLikelihood program = program.visit(ConstantVisitor(task.str_const)) print(program) preg = program.evaluate([])(pre.String("")) Pstring = prettyRegex(preg) if autodiff: params, diff_lookup = create_params() #TODO opt = optim.Adam(params, lr=0.1)
def getRegexConcept(o):
    """Parse token sequence `o` into a regex (resolving concepts via the
    enclosing `lookup`) and register it on the current trace, returning the
    resulting (trace, concept) pair."""
    regex = pre.create(o, lookup=lookup)
    return current_trace.addregex(regex)
if type(state.regex) is pre.Alt: for x in state.regex.values: if type(x) is RegexWrapper: addParent(c, x.concept) #return {c: descendants[c] + ancestors[c] for c in self.baseConcepts} return {c: parents[c] + children[c] for c in self.baseConcepts} # ------------ Unit tests ------------ if __name__=="__main__": import pickle import os trace = Trace() trace, firstName = trace.addPYregex(pre.create("\\w+")) trace, lastName = trace.addPYregex(pre.create("\\w+")) regex = pre.create("f l", {"f":RegexWrapper(firstName), "l":RegexWrapper(lastName)}) trace, fullName = trace.addPYregex(regex) trace, observation1 = trace.observe(fullName, "Luke Hewitt") trace, observation2 = trace.observe(fullName, "Kevin Ellis") trace, observation3 = trace.observe(fullName, "Max Nye") trace, observation4 = trace.observe(fullName, "Max Siegel") trace, observation5 = trace.observe(fullName, "Max KW") with open('trace_Test.p', 'wb') as file: pickle.dump(trace, file) with open('trace_Test.p', 'rb') as file:
def loadData(file, n_examples, n_tasks, max_length):
    """Load, dedupe, filter and group pickled example tasks.

    Args:
        file: path to a pickle of [{'data': [str, ...]}, ...]; a file ending
            in "csv_900.p" is returned verbatim.
        n_examples: max examples sampled per task.
        n_tasks: max tasks kept per length group.
        max_length: exclusive upper bound on example string length.

    Returns:
        (data, group_idxs, test_data): training tasks sorted by
        (max length, -entropy), cumulative group boundaries by max example
        length, and held-out test tasks.
    """
    if file[-9:] == "csv_900.p":
        print("Loading csv_900.p, ignoring data params.")
        with open(file, 'rb') as f:
            return pickle.load(f)
    # Fixed seed so the sampled subsets are reproducible.
    rand = np.random.RandomState()
    rand.seed(0)
    all_tasks = []
    for x in pickle.load(open(file, 'rb')):
        elems_filtered = [elem for elem in x['data'] if len(elem) < max_length]
        if len(elems_filtered) == 0: continue
        task = rand.choice(elems_filtered,
                           size=min(len(elems_filtered), n_examples),
                           replace=False).tolist()
        all_tasks.append(task)
    data = []

    def lenEntropy(examples):
        # Sort key: longest example first dimension, then higher entropy
        # first (negated) within equal lengths.
        return (max(len(x) for x in examples), -util.entropy(examples))

    all_tasks = sorted(all_tasks, key=lenEntropy)
    tasks_unique = []
    for task in all_tasks:
        unique = set(task)
        if not any(
                len(unique ^ x) / len(unique) < 0.7 for x in tasks_unique
        ):  #No two tasks should have mostly the same unique elements
            data.append(task)
            tasks_unique.append(unique)
    # Drop degenerate tasks where every example is identical.
    data = [X for X in data if not all(x == X[0] for x in X)]
    # Bucket tasks by the max length of their first 100 examples.
    grouped_data = [[
        examples for examples in data
        if max(len(x) for x in examples[:100]) == i
    ] for i in range(max_length)]
    grouped_data = [X for X in grouped_data if len(X) > 0]
    #pos_int_regex = pre.create("0|((1|2|3|4|5|6|7|8|9)\d*)")
    #float_regex = pre.Concat([pos_int_regex, pre.create("\.\d+")])
    num_regex = pre.create("-?0|((1|2|3|4|5|6|7|8|9)\d*)(\.\d+)?")
    test_data = []
    for i in range(len(grouped_data)):
        #rand.shuffle(grouped_data[i])
        # Downsample purely-numeric tasks: keep at most the first 25% of
        # the tasks whose every unique example matches num_regex.
        for fil in [num_regex]:
            fil_idxs = [
                j for j, xs in enumerate(grouped_data[i]) if all(
                    fil.match(x) > float("-inf") for x in Counter(xs))
            ]  #Indexes that match filter
            grouped_data[i] = [
                grouped_data[i][j] for j in range(len(grouped_data[i]))
                if j not in fil_idxs[math.ceil(0.25 * len(grouped_data[i])):]
            ]  #Keep at most 20%
        grouped_data[i].sort(key=len, reverse=True)
        # Overflow tasks (beyond n_tasks) with enough distinct examples
        # become held-out test data.
        test_data.extend(
            [X for X in grouped_data[i][n_tasks:] if len(set(X)) >= 5])
        grouped_data[i] = grouped_data[i][:n_tasks]
        grouped_data[i].sort(key=lenEntropy)
    data = [x for examples in grouped_data for x in examples]
    #group_idxs =
    # NOTE(review): this bare expression is dead code — its result is
    # discarded (the commented assignment above was presumably its target).
    list(np.cumsum([len(X) for X in grouped_data]))
    # rand.shuffle(data)
    # if args.n_tasks is not None:
    #     data = data[args.skip_tasks:args.n_tasks + args.skip_tasks]
    #     data = sorted(data, key=lambda examples: (max(len(x) for x in examples), -len(set(examples))))
    # NOTE(review): if len(test_data) % 10 == 0 then [:-0] yields an EMPTY
    # list (same hazard on the `data` slice below) — confirm intended.
    test_data = test_data[:-(len(test_data) % 10)]
    data = data[:-((len(data) + len(test_data)) % 100)]
    data.sort(key=lenEntropy)
    test_data.sort(key=lenEntropy)
    unique_lengths = sorted(list(set([max(len(x) for x in X) for X in data])))
    # Cumulative counts of tasks per max-example-length, for grouped batching.
    group_idxs = np.cumsum([
        len([X for X in data if max(len(x) for x in X) == l])
        for l in unique_lengths
    ])
    return data, group_idxs, test_data