コード例 #1
0
ファイル: sample.py プロジェクト: aaronguo1996/ProgramSearch
def insert_regex(in_str, choice_len, regex_constraint):
    """Insert regex samples into ``in_str`` until each constraint is satisfied.

    For every key ``r`` of ``regex_constraint``, count the matches of
    ``ALL_REGEX[r]`` in the current string. If there are fewer than
    ``regex_constraint[r]``, overwrite that many random, not-yet-used
    positions among the first ``choice_len`` with samples drawn from the
    regex.

    Args:
        in_str: list of string pieces. It is mutated in place: chosen
            positions are overwritten with sampled strings.
        choice_len: number of leading positions eligible for insertion. The
            returned string is also truncated to this many characters.
        regex_constraint: mapping from a key of ``ALL_REGEX`` to the desired
            minimum number of matches.

    Returns:
        ``''.join(in_str)`` truncated to at most ``choice_len`` characters.

    Raises:
        IndexError: if not enough unused positions remain to satisfy a
            constraint.
    """
    insert_pos = set(range(choice_len))
    for r, desired_counts in regex_constraint.items():
        actual_counts = len(re.findall(ALL_REGEX[r], ''.join(in_str)))
        num_of_inserts = max(0, desired_counts - actual_counts)

        if len(insert_pos) < num_of_inserts:
            raise IndexError(
                f"need {num_of_inserts} insert positions for {r!r}, "
                f"only {len(insert_pos)} left")

        # random.sample() requires a sequence; passing a set is deprecated
        # since Python 3.9 and a TypeError since 3.11. Sorting also keeps
        # the candidate order deterministic.
        indices_of_inserts = set(
            random.sample(sorted(insert_pos), k=num_of_inserts))

        for i in indices_of_inserts:
            in_str[i] = pre.create(ALL_REGEX[r]).sample()

        # Never overwrite a position filled for an earlier constraint.
        insert_pos -= indices_of_inserts

    # Samples may be longer than one character, so truncate the result.
    joined = ''.join(in_str)
    return joined[:choice_len]
コード例 #2
0
def getValidNetworkOutputs(net, current_trace, examples, maxNetworkEvals=None):
    """Yield network outputs for ``examples`` that parse as valid regexes.

    On the first call for a given (sorted) set of examples, ``net`` is sampled
    ``maxNetworkEvals`` times (500 samples each). Every distinct output is
    tried once; outputs that parse with ``pre.create`` are yielded and
    recorded in ``networkCache``. Later calls with the same examples replay
    the cached valid outputs instead of running the network.

    Args:
        net: model whose ``sample`` method returns a list of outputs.
        current_trace: trace whose ``baseConcepts`` may appear in outputs.
        examples: iterable of examples; sorted and used as the cache key.
        maxNetworkEvals: number of network evaluations (default 10).

    Yields:
        ``(o, count, group_idx)``: a valid output, how often it was sampled
        in its evaluation, and its position among the valid outputs yielded.

    Note:
        If the consumer stops early, the cache entry holds only the outputs
        seen so far and is still treated as complete by later calls.
    """
    if maxNetworkEvals is None:
        maxNetworkEvals = 10

    lookup = {
        concept: RegexWrapper(concept)
        for concept in current_trace.baseConcepts
    }

    examples = tuple(sorted(examples))
    isCached = examples in networkCache

    if isCached:
        o_generator = networkCache[examples]['valid']
    else:
        # Create the cache entry exactly once. It used to be re-created at
        # the start of every network evaluation. That discarded previously
        # cached valid outputs and reset the 'all' dedup set, so outputs
        # from earlier evaluations were yielded again.
        cache_entry = networkCache[examples] = {'valid': [], 'all': set()}

        def get_more_outputs():
            # The same input list is repeated 500 times (shared references).
            inputs = [[(example, ) for example in examples]] * 500
            outputs_count = Counter(net.sample(inputs))
            # NOTE(review): ascending count order, i.e. least frequent output
            # first -- confirm this is intended.
            for o in sorted(outputs_count, key=outputs_count.get):
                yield (o, outputs_count[o])

        o_generator = (o for _ in range(maxNetworkEvals)
                       for o in get_more_outputs())

    group_idx = 0
    for o, count in o_generator:
        if not isCached:
            if o in cache_entry['all']:
                continue  # already tried in this or an earlier evaluation
            cache_entry['all'].add(o)
        try:
            pre.create(o, lookup=lookup)  # raises if o is not a valid regex
        except pre.ParseException:
            continue
        if not isCached:
            cache_entry['valid'].append((o, count))
        yield (o, count, group_idx)
        group_idx += 1
コード例 #3
0
 def __init__(self):
     # Generic character-class regexes. Raw strings keep the backslashes
     # literal explicitly: "\d", "\w", "\s" and "\l" are invalid escape
     # sequences in plain literals (DeprecationWarning since Python 3.6,
     # SyntaxWarning since 3.12). r"\u+" equals the previous "\\u+".
     self.regexes = [
         pre.create(".+"),
         pre.create(r"\d+"),
         pre.create(r"\w+"),
         pre.create(r"\s+"),
         pre.create(r"\u+"),
         pre.create(r"\l+"),
     ]
コード例 #4
0
            def getRelatedRegexConcepts(o):
                """Yield ``(trace, concept)`` proposals related to regex ``o``.

                These proposals are only worthwhile if ``getRegexConcept(o)``
                itself is a good proposal. Depending on ``modes`` and
                ``fuzzConcepts`` they are: ``o``'s concept with PY concepts
                stacked on top, and variants of ``o`` in which one base
                concept is swapped for a similar concept or for a new PY
                concept built on it.
                """

                def extend(trace, concept):
                    """Stack one or two PY concepts on ``concept`` (per ``modes``)."""
                    if not ("regex-crp" in modes or "regex-crp-crp" in modes):
                        return
                    trace, concept = trace.addPY(concept)
                    if "regex-crp" in modes:
                        yield (trace, concept)
                    if "regex-crp-crp" in modes:
                        trace, concept = trace.addPY(concept)
                        yield (trace, concept)

                yield from extend(*getRegexConcept(o))

                for idx, part in enumerate(o):
                    if part not in current_trace.baseConcepts:
                        continue
                    head, tail = o[:idx], o[idx + 1:]

                    if fuzzConcepts:
                        # Swap this concept for each similar concept.
                        for alternative in similarConcepts.get(part, []):
                            regex = pre.create(head + (alternative, ) + tail,
                                               lookup=lookup)
                            new_trace, new_concept = current_trace.addregex(regex)
                            yield (new_trace, new_concept)
                            yield from extend(new_trace, new_concept)

                    if "crp-regex" in modes:
                        # Swap this concept for a fresh PY concept built on it.
                        py_trace, py_concept = current_trace.addPY(part)
                        py_lookup = {**lookup, py_concept: RegexWrapper(py_concept)}
                        regex = pre.create(head + (py_concept, ) + tail,
                                           lookup=py_lookup)
                        new_trace, new_concept = py_trace.addregex(regex)
                        yield (new_trace, new_concept)
                        yield from extend(new_trace, new_concept)
コード例 #5
0
def evaluate_datum(i, datum, model, dcModel, nRepeats, mdl, max_to_check):
    """Propose sketches for one task and enumerate programs from them.

    Sketches come from ``model`` (``beam_decode`` if ``args.beam``, otherwise
    ``sampleAndScore``). Without a model, a single sketch ``(PregHole,)`` is
    used. Each unique sample is parsed into a program sketch and paired with
    a grammar: one per sample when ``holeSpecificDcModel`` is set and a
    ``dcModel`` is given, otherwise a single shared grammar. The unique
    sketches are then passed to ``pypy_enumerate``.

    Args:
        i: task index, used only in the log messages.
        datum: the task; ``datum.IO[:nExamples]`` conditions the models, and
            ``datum.tp`` / ``datum.IO`` are forwarded to ``pypy_enumerate``.
        model: sketch proposer, or a falsy value to skip it.
        dcModel: grammar-inference model, or a falsy value to use
            ``basegrammar``.
        nRepeats: beam size / number of samples requested from ``model``.
        mdl, max_to_check: forwarded to ``pypy_enumerate``.

    Yields:
        ``(RegexResult, -inf)`` for every sample that fails to parse, then
        every element of the results returned by ``pypy_enumerate``.
    """
    t = time.time()
    io_examples = datum.IO[:nExamples]
    samples = {(PregHole, )}  # TODO: make more general
    n_checked, n_hit = 0, 0
    if model:
        if args.beam:
            samples, _scores = model.beam_decode([io_examples],
                                                 beam_size=nRepeats)
        else:
            samples, _scores, _ = model.sampleAndScore([io_examples],
                                                       nRepeats=nRepeats)
        # Only loop over unique samples.
        samples = {tuple(sample) for sample in samples}

    # A per-sample grammar needs an actual dcModel. Previously
    # holeSpecificDcModel combined with dcModel=None fell through to
    # None.infer_grammar(...) inside the loop and crashed.
    per_sample_grammar = bool(holeSpecificDcModel and dcModel)
    if not per_sample_grammar:
        g = untorch(dcModel.infer_grammar(io_examples)
                    if dcModel else basegrammar)  # TODO pp

    sketchtups = []
    for sample in samples:
        try:
            sk = pre_to_prog(pre.create(sample, lookup=lookup_d))
        except ParseException:
            n_checked += 1
            yield (RegexResult(sample, None, float('-inf'), n_checked,
                               time.time() - t), float('-inf'))
            continue

        if per_sample_grammar:
            # TODO: make sure this line is correct
            g = untorch(dcModel.infer_grammar((io_examples, sample)))

        sketchtups.append(SketchTup(sk, g))

    # Only enumerate from unique sketches.
    sketchtups = set(sketchtups)

    results, n_checked, n_hit = pypy_enumerate(datum.tp, datum.IO, mdl,
                                               sketchtups, n_checked, n_hit, t,
                                               max_to_check)
    yield from results

    # TODO: want search time and total time to hit task
    print(f"task {i}:")
    print(f"evaluation for task {i} took {time.time()-t} seconds")
    print(f"For task {i}, tried {n_checked} sketches, found {n_hit} hits")
コード例 #6
0
def get_gt_ll(name, examples):
    """Return the summed ``match`` score of ``examples`` under task ``name``'s ground truth.

    The ground-truth regex string is looked up in ``gt_dict`` and parsed with
    ``pregex``. Examples given as lists of pieces are joined into strings
    first. An empty ``examples`` gives 0.

    If the sum is ``-inf`` (presumably because some example does not match),
    diagnostics are printed, but the value is still returned.
    """
    import pregex as pre
    r_str = gt_dict[name]
    preg = pre.create(r_str)

    # Guard the emptiness check: examples[0] used to raise IndexError on [].
    if examples and isinstance(examples[0], list):
        examples = ["".join(example) for example in examples]

    s = sum(preg.match(example) for example in examples)
    if s == float("-inf"):
        print("bad for ", name)
        print('preg:', preg)
        print('preg sample:', [preg.sample() for _ in range(3)])
        print("exs", examples)
    return s
コード例 #7
0
def regex_bound(X):
    """Return the best mean ``match`` score of ``X`` under generic regexes.

    Each candidate regex (any text, digits, word chars, whitespace, upper
    case, lower case) scores ``X`` by its mean per-example ``match`` score,
    with duplicate examples weighted by their multiplicity. The best (largest)
    score is returned.

    Raises:
        ZeroDivisionError: if ``X`` is empty.
    """
    counts = Counter(X)
    total = sum(counts.values())
    # "\s+" (not "\s") to match the other one-or-more classes here and the
    # identical regex list used elsewhere in this code. Raw strings avoid
    # invalid escape sequences.
    regexes = [
        pre.create(".+"),
        pre.create(r"\d+"),
        pre.create(r"\w+"),
        pre.create(r"\s+"),
        pre.create(r"\u+"),
        pre.create(r"\l+"),
    ]
    # Match each distinct string once and weight it by its count. Iterating
    # over X (with duplicates) while also multiplying by the count, as before,
    # weighted every repeated string count**2 times.
    return max(
        sum(n * r.match(x) for x, n in counts.items()) / total
        for r in regexes)
コード例 #8
0
 def generate():
     # Random constant length in [1, MAX_CONST_LEN - 1]. random.choice on a
     # range draws the same value as on the equivalent list.
     n_chars = random.choice(range(1, MAX_CONST_LEN))
     # Sample a string from a pattern of n_chars "." wildcards.
     pattern = "." * n_chars
     sampled = pre.create(pattern).sample()
     return Constant(sampled)
コード例 #9
0
ファイル: test_pregex.py プロジェクト: liqing-ustc/dreamcoder
import random
import pregex as pre
import math

# Parser checks: pre.create must build exactly these pregex trees.
# Note that "fa|o*" parses as "f" followed by (a | o*).
assert pre.create("f(a|o)*") == pre.Concat([
    pre.String("f"),
    pre.KleeneStar(pre.Alt([pre.String("a"), pre.String("o")]))
])
assert pre.create("fa|o*") == pre.Concat([
    pre.String("f"),
    pre.Alt([pre.String("a"), pre.KleeneStar(pre.String("o"))])
])
assert pre.create("(f.*)+") == pre.Plus(
    pre.Concat([pre.String("f"), pre.KleeneStar(pre.dot)]))

# (string, regex, whether the regex is expected to match the string)
test_cases = [("foo", "fo", False), ("foo", "foo", True),
              ("foo", "fooo", False), ("foo", "fo*", True),
              ("foo", "fo+", True), ("foo", "f(oo)*", True),
              ("foo", "f(a|b)*", False), ("foo", "f(a|o)*", True),
              ("foo", "fa|o*", True), ("foo", "fo|a*", False),
              ("foo", "f|ao|ao|a", True),
              ("f" + "o" * 50, "f" + "o*" * 10, True), ("foo", "fo?+", True),
              ("foo", "fo**", True), ("(foo)", "\\(foo\\)", True),
              ("foo foo. foo foo foo.", "foo(\\.? foo)*\\.", True),
              ("123abcABC ", ".+", True), ("123abcABC ", '\\w+', False),
              ("123abcABC ", "\\w+\\s", True),
              ("123abcABC ", "\\d+\\l+\\u+\\s", True)]
for (string, regex, matches) in test_cases:
    print("Parsing", regex)
    r = pre.create(regex)
    print("Matching", string, r)
    # match() returns a score (log-likelihood) that is -inf for no match.
    # Previously the expected `matches` flag was never checked.
    matched = r.match(string) > float("-inf")
    assert matched == matches, (
        f"{regex!r} vs {string!r}: expected matches={matches}, got {matched}")
コード例 #10
0
        while nextInput:
            s = input()
            if s == "":
                nextInput = False
            else:
                examples.append(s)

    print("calculating... ")
    samples, scores = model.sampleAndScore([examples], nRepeats=2)
    print(samples)
    print(scores)
    print(len(scores), len(samples))
    index = scores.index(max(scores))
    #print(samples[index])
    try:
        sample = pre.create(list(samples[index]))
    except:
        sample = samples[index]
    #sample = samples[index]
    print("best example by nn score:", sample, ", nn score:", max(scores))

    pregexes = []
    pscores = []
    for samp in samples:
        try:
            reg = pre.create(list(samp))
            pregexes.append(reg)
            pscores.append(sum(reg.match(ex) for ex in examples))
        except:
            pregexes.append(samp)
            pscores.append(float('-inf'))
コード例 #11
0
if __name__ == '__main__':
    # Manual smoke test: sample one datum from the base grammar with
    # sketches enabled, then print its program, sketch, sketch sequence and
    # each entry of d.IO.
    # NOTE: this block runs at module scope, so the imports and assignments
    # below (time, pre, g, d, preg, ...) become module-level globals. Other
    # functions in the module may rely on them (e.g. `time`), so do not
    # remove them just because they look unused here.

    import time
    import pregex as pre

    g = basegrammar
    d = sample_datum(g=g,
                     N=4,
                     compute_sketches=True,
                     top_k_sketches=100,
                     inv_temp=1.0,
                     reward_fn=None,
                     sample_fn=None,
                     dc_model=None)
    # The sampled program, and the result of evaluate([]) on it.
    print(d.p)
    print(d.p.evaluate([]))
    # The sketch, and the result of evaluate([]) on it.
    print(d.sketch)
    #print(d.sketch.evaluate([])(pre.String("")))
    print(d.sketch.evaluate([]))
    print(d.sketchseq)
    for o in d.IO:
        print("example")
        print(o)

    from util.regex_util import PregHole, pre_to_prog

    # Round-trip the sketch sequence: build a pregex in which PregHole
    # entries resolve to PregHole() instances, then convert it back to a
    # program with pre_to_prog.
    preg = pre.create(d.sketchseq, lookup={PregHole: PregHole()})
    print(preg)
    print(pre_to_prog(preg))
コード例 #12
0
ファイル: examineFrontier.py プロジェクト: zlapp/ec
        print(task.name)
        totalTasks += 1
        print("\tTRAIN\t", ["".join(example[1]) for example in task.examples])

        testingExamples = regexHeldOutExamples(task)
        print("\tTEST\t", [example[1] for example in testingExamples])

        gt_preg = gt_dict[int(task.name.split(" ")[-1])]
        print("\tHuman written regex:", gt_preg)

        eprint(verbatimTable(["".join(example[1]) for example in task.examples] + [None] + \
                             [gt_preg,None] + \
                             [example[1] for example in testingExamples]))
        eprint("&")

        gt_preg = pre.create(gt_preg)

        def examineProgram(entry):
            global preg
            global diff_lookup
            program = entry.program
            ll = entry.logLikelihood
            program = program.visit(ConstantVisitor(task.str_const))
            print(program)
            preg = program.evaluate([])(pre.String(""))
            Pstring = prettyRegex(preg)

            if autodiff:
                params, diff_lookup = create_params()  #TODO

                opt = optim.Adam(params, lr=0.1)
コード例 #13
0
 def getRegexConcept(o):
     """Parse ``o`` into a regex, add it to ``current_trace``, and return ``(trace, concept)``.

     Concept names in ``o`` are resolved through ``lookup``.
     """
     regex = pre.create(o, lookup=lookup)
     new_trace, concept = current_trace.addregex(regex)
     return (new_trace, concept)
コード例 #14
0
				if type(state.regex) is pre.Alt:
					for x in state.regex.values:
						if type(x) is RegexWrapper:
							addParent(c, x.concept)

		#return {c: descendants[c] + ancestors[c] for c in self.baseConcepts}
		return {c: parents[c] + children[c] for c in self.baseConcepts}

# ------------ Unit tests ------------

if __name__=="__main__":
	import pickle
	import os

	trace = Trace()
	trace, firstName = trace.addPYregex(pre.create("\\w+"))
	trace, lastName  = trace.addPYregex(pre.create("\\w+"))

	regex = pre.create("f l", {"f":RegexWrapper(firstName), "l":RegexWrapper(lastName)})
	trace, fullName = trace.addPYregex(regex)


	trace, observation1 = trace.observe(fullName, "Luke Hewitt")
	trace, observation2 = trace.observe(fullName, "Kevin Ellis")
	trace, observation3 = trace.observe(fullName, "Max Nye")
	trace, observation4 = trace.observe(fullName, "Max Siegel")
	trace, observation5 = trace.observe(fullName, "Max KW")

	with open('trace_Test.p', 'wb') as file:
		pickle.dump(trace, file)
	with open('trace_Test.p', 'rb') as file:
コード例 #15
0
def loadData(file, n_examples, n_tasks, max_length):
    if file[-9:] == "csv_900.p":
        print("Loading csv_900.p, ignoring data params.")
        with open(file, 'rb') as f:
            return pickle.load(f)

    rand = np.random.RandomState()
    rand.seed(0)

    all_tasks = []
    for x in pickle.load(open(file, 'rb')):
        elems_filtered = [elem for elem in x['data'] if len(elem) < max_length]
        if len(elems_filtered) == 0: continue

        task = rand.choice(elems_filtered,
                           size=min(len(elems_filtered), n_examples),
                           replace=False).tolist()
        all_tasks.append(task)

    data = []

    def lenEntropy(examples):
        return (max(len(x) for x in examples), -util.entropy(examples))

    all_tasks = sorted(all_tasks, key=lenEntropy)
    tasks_unique = []

    for task in all_tasks:
        unique = set(task)
        if not any(
                len(unique ^ x) / len(unique) < 0.7 for x in tasks_unique
        ):  #No two tasks should have mostly the same unique elements
            data.append(task)
            tasks_unique.append(unique)

    data = [X for X in data if not all(x == X[0] for x in X)]
    grouped_data = [[
        examples for examples in data if max(len(x)
                                             for x in examples[:100]) == i
    ] for i in range(max_length)]
    grouped_data = [X for X in grouped_data if len(X) > 0]

    #pos_int_regex = pre.create("0|((1|2|3|4|5|6|7|8|9)\d*)")
    #float_regex = pre.Concat([pos_int_regex, pre.create("\.\d+")])
    num_regex = pre.create("-?0|((1|2|3|4|5|6|7|8|9)\d*)(\.\d+)?")

    test_data = []
    for i in range(len(grouped_data)):
        #rand.shuffle(grouped_data[i])
        for fil in [num_regex]:
            fil_idxs = [
                j for j, xs in enumerate(grouped_data[i]) if all(
                    fil.match(x) > float("-inf") for x in Counter(xs))
            ]  #Indexes that match filter
            grouped_data[i] = [
                grouped_data[i][j] for j in range(len(grouped_data[i]))
                if j not in fil_idxs[math.ceil(0.25 * len(grouped_data[i])):]
            ]  #Keep at most 20%

        grouped_data[i].sort(key=len, reverse=True)
        test_data.extend(
            [X for X in grouped_data[i][n_tasks:] if len(set(X)) >= 5])
        grouped_data[i] = grouped_data[i][:n_tasks]
        grouped_data[i].sort(key=lenEntropy)

    data = [x for examples in grouped_data for x in examples]
    #group_idxs = list(np.cumsum([len(X) for X in grouped_data]))
    # rand.shuffle(data)
    # if args.n_tasks is not None:
    # 	data = data[args.skip_tasks:args.n_tasks + args.skip_tasks]

    # data = sorted(data, key=lambda examples: (max(len(x) for x in examples), -len(set(examples))))

    test_data = test_data[:-(len(test_data) % 10)]
    data = data[:-((len(data) + len(test_data)) % 100)]

    data.sort(key=lenEntropy)
    test_data.sort(key=lenEntropy)

    unique_lengths = sorted(list(set([max(len(x) for x in X) for X in data])))
    group_idxs = np.cumsum([
        len([X for X in data if max(len(x) for x in X) == l])
        for l in unique_lengths
    ])
    return data, group_idxs, test_data