def primitive(self, e):
    """Give an ``r_const`` primitive a freshly sampled string constant.

    When ``e`` is the ``r_const`` primitive, one regex is picked uniformly
    from ``self.regexes`` and a string is sampled from it. That string is
    wrapped as ``PRC(pre.String(...), arity=0)`` and stored on ``e.value``.
    All other primitives are left untouched.

    NOTE(review): this mutates ``e`` itself. If primitive objects are shared
    between programs, every holder sees the new value — confirm this is
    intended.

    Returns ``e`` in every case.
    """
    if e.name != "r_const":
        return e
    sampled = random.choice(self.regexes).sample()  # random string const
    e.value = PRC(pre.String(sampled), arity=0)
    return e
def taskOfProgram(self, p, t):
    """Build a synthetic "Helm" task whose examples are sampled from ``p``.

    Steps:
    1. The example count is drawn with ``random.choice`` from
       ``self.num_examples_list``.
    2. The program's string constants are instantiated with
       ``ConstantInstantiateVisitor.SINGLE``.
    3. The result is evaluated and applied to ``pre.String("")`` to get a
       regex.
    4. Each example is ``((), list(regex.sample()))``, i.e. it has no inputs.

    :param p: program to sample examples from
    :param t: request type given to the new ``Task``
    :return: the constructed ``Task``
    """
    count = random.choice(self.num_examples_list)
    instantiated = p.visit(ConstantInstantiateVisitor.SINGLE)
    regex = instantiated.evaluate([])(pre.String(""))
    examples = []
    for _ in range(count):
        examples.append(((), list(regex.sample())))
    return Task("Helm", t, examples)
def examineProgram(entry):
    """Print a diagnostic report for one frontier entry and record hit/miss flags.

    Reads module-level state: ``task``, ``autodiff``, ``AUTODIFF_ITER``,
    ``testingExamples``, ``gt_preg`` and the helpers ``create_params``,
    ``map_fun``, ``optim``, ``prettyRegex``, ``verbatimTable`` and ``eprint``.
    Rebinds the global ``preg``. When ``autodiff`` is true it also rebinds
    the global ``diff_lookup``.

    Sets on ``entry``:
      * ``trainHit``: the training log-likelihood ``ll`` is >= ``task.gt``.
      * ``testHit``: the held-out log-likelihood is >= that of ``gt_preg``.
    """
    global preg
    global diff_lookup
    program = entry.program
    # Default training score; replaced below when autodiff re-scores the regex.
    ll = entry.logLikelihood
    # Fill r_const primitives with the task's string constant before evaluating.
    program = program.visit(ConstantVisitor(task.str_const))
    print(program)
    # The evaluated program applied to pre.String("") gives the regex object.
    preg = program.evaluate([])(pre.String(""))
    # Pretty form of the regex taken BEFORE any autodiff optimization below.
    Pstring = prettyRegex(preg)
    if autodiff:
        params, diff_lookup = create_params() #TODO
        opt = optim.Adam(params, lr=0.1)
        for i in range(AUTODIFF_ITER):
            opt.zero_grad()
            #normalize_diff_lookup(diff_lookup) #todo, using softmax and such
            preg = preg.map(map_fun(diff_lookup)) #preg.map(fun)
            score = sum( preg.match( "".join(example[1])) for example in task.examples)
            # Gradient step that increases the summed training match score.
            (-score).backward(retain_graph=True)
            opt.step()
            # if i%10==0:
            #     print(i, params)

        #post-optimization score
        #normalize_diff_lookup(diff_lookup) #todo, using softmax and such
        preg = preg.map(map_fun(diff_lookup))
        #print("parameters:")
        ll = sum( preg.match( "".join(example[1])) for example in task.examples )

    # Held-out score of the found regex vs. the ground-truth regex gt_preg.
    testing_likelihood = sum(preg.match(testingString) for _,testingString in testingExamples)
    ground_truth_testing = sum(gt_preg.match(testingString) for _,testingString in testingExamples)

    eprint("&")
    eprint(verbatimTable([Pstring] + [preg.sample() for i in range(5)]))
    print("\t", Pstring)
    print("\t", "samples:")
    print("\t", [preg.sample() for i in range(5)])
    entry.trainHit = ll >= task.gt
    if ll >= task.gt:
        print(f"\t HIT (train), Ground truth: {task.gt}, found ll: {ll}")
    else:
        print(f"\t MISS (train), Ground truth: {task.gt}, found ll: {ll}")
    entry.testHit = testing_likelihood >= ground_truth_testing
    if testing_likelihood >= ground_truth_testing:
        print(f"\t HIT (test), Ground truth: {ground_truth_testing}, found ll: {testing_likelihood}")
    else:
        print(f"\t MISS (test), Ground truth: {ground_truth_testing}, found ll: {testing_likelihood}")
def testingRegexLikelihood(task, program):
    """Return the summed match score of ``task``'s held-out examples under ``program``.

    ``program`` first has ``task.str_const`` substituted for its string
    constants (via ``ConstantVisitor``). It is then evaluated against
    ``pre.String("")`` to obtain a regex.

    Each held-out string from ``regexHeldOutExamples(task)`` is scored with
    ``regex.match``. Scores are memoised in the module-level
    ``REGEXCACHINGTABLE`` dict, keyed on ``(regex, string)``.
    """
    global REGEXCACHINGTABLE
    from dreamcoder.domains.regex.makeRegexTasks import regexHeldOutExamples
    import pregex as pre

    heldOut = regexHeldOutExamples(task)
    regex = program.visit(ConstantVisitor(task.str_const)).evaluate([])(pre.String(""))

    total = 0.
    for _, held in heldOut:
        key = (regex, held)
        try:
            score = REGEXCACHINGTABLE[key]
        except KeyError:
            score = REGEXCACHINGTABLE[key] = regex.match(held)
        total += score
    return total
def addStupidRegex(frontier, g):
    """Return ``frontier`` with a catch-all regex program added if it is missing.

    The catch-all program ``(lambda (r_kleene (lambda (r_dot $0)) $0))`` and
    its evaluated regex are built lazily on the first call. They are cached
    in the module globals ``stupidProgram`` / ``stupidRegex``.
    ``reducedConcatPrimitives()`` is called before parsing, presumably to
    register the primitives the parser needs — TODO confirm.

    If an entry of ``frontier`` already has that program, ``frontier`` is
    returned unchanged. Otherwise a new, normalized ``Frontier`` is returned
    with one extra entry:
      * prior from ``g.logLikelihood``;
      * likelihood summed over ``regex.match`` of each task example.
    """
    global stupidProgram, stupidRegex
    import pregex as pre

    if stupidProgram is None:
        from dreamcoder.domains.regex.regexPrimitives import reducedConcatPrimitives
        reducedConcatPrimitives()
        stupidProgram = Program.parse("(lambda (r_kleene (lambda (r_dot $0)) $0))")
        stupidRegex = stupidProgram.evaluate([])(pre.String(""))

    alreadyPresent = any(entry.program == stupidProgram for entry in frontier)
    if alreadyPresent:
        return frontier

    task = frontier.task
    logPrior = g.logLikelihood(task.request, stupidProgram)
    logLikelihood = sum(stupidRegex.match("".join(chars)) for _, chars in task.examples)
    extra = FrontierEntry(logPrior=logPrior, logLikelihood=logLikelihood, program=stupidProgram)
    return Frontier(frontier.entries + [extra], task=task).normalize()
# Parser smoke tests for the pregex library, plus a table of match cases.
import random
import pregex as pre
import math

# pre.create parses a regex string into a pregex AST; check a few parses
# structurally against hand-built ASTs.
assert (pre.create("f(a|o)*") == pre.Concat([
    pre.String("f"),
    pre.KleeneStar(pre.Alt([pre.String("a"), pre.String("o")]))
]))
# Per this assertion, "|" groups tighter than concatenation in pregex:
# "fa|o*" parses as "f" followed by ("a" | "o*").
assert (pre.create("fa|o*") == pre.Concat([
    pre.String("f"),
    pre.Alt([pre.String("a"), pre.KleeneStar(pre.String("o"))])
]))
assert (pre.create("(f.*)+") == pre.Plus(
    pre.Concat([pre.String("f"), pre.KleeneStar(pre.dot)])))

# Each case is (string, regex, expected_to_match).
test_cases = [("foo", "fo", False),
              ("foo", "foo", True),
              ("foo", "fooo", False),
              ("foo", "fo*", True),
              ("foo", "fo+", True),
              ("foo", "f(oo)*", True),
              ("foo", "f(a|b)*", False),
              ("foo", "f(a|o)*", True),
              ("foo", "fa|o*", True),
              ("foo", "fo|a*", False),
              ("foo", "f|ao|ao|a", True),
              ("f" + "o" * 50, "f" + "o*" * 10, True),
              ("foo", "fo?+", True),
              ("foo", "fo**", True),
              ("(foo)", "\\(foo\\)", True),
              ("foo foo. foo foo foo.", "foo(\\.? foo)*\\.", True),
              ("123abcABC ", ".+", True),
              ("123abcABC ", '\\w+', False),
              ("123abcABC ", "\\w+\\s", True),
              ("123abcABC ", "\\d+\\l+\\u+\\s", True)]

# NOTE(review): in the lines visible here the loop only parses and prints;
# the expected-match flag `matches` is never checked. Confirm whether an
# assertion on r.match(string) is missing.
for (string, regex, matches) in test_cases:
    print("Parsing", regex)
    r = pre.create(regex)
    print("Matching", string, r)
def test_task(m, task, timeout):
    """Sample programs from model ``m`` for ``task`` until ``timeout`` seconds pass.

    Sampling and filtering:
      * Candidates are drawn in batches of ``BATCHSIZE`` via ``m.sample`` and
        parsed with ``Program.parse``.
      * Unparsable candidates are skipped.
      * Programs that already scored -inf are remembered in ``failed_cands``
        and not re-scored.

    Return value outside the regex domain:
      * ``True`` as soon as a candidate has finite log-likelihood.
      * ``False`` when the time budget runs out.

    Return value in the regex domain: there is no early exit. Every
    well-typed candidate is evaluated to a regex and accumulated. After the
    budget, a score combined with ``lse`` (presumably log-sum-exp) is
    returned:
      * ``arguments.sampleLikelihood``: over the distinct sampled regexes,
        ``log(sample count) + held-out log-likelihood``.
      * ``arguments.taskLikelihood``: ``logPrior + logLikelihood`` over the
        frontier, after the catch-all regex is added.
      * otherwise: ``logPosterior + held-out log-likelihood`` over that
        frontier.

    Relies on module-level names defined elsewhere in the file (``arguments``,
    ``g``, ``BATCHSIZE``, ``makeExamples``, ``Text``, ``drawLogo``, ``lse``, ...).
    """
    start = time.time()
    failed_cands = set()
    print(task.examples)
    frontier = []
    sampleFrequency = {}
    while time.time() - start < timeout:
        query = makeExamples(task)
        candidates = m.sample([query] * BATCHSIZE)
        for cand in candidates:
            try:
                p = Program.parse(" ".join(cand))
            except (ParseFailure, IndexError, AssertionError):
                # Malformed model output: skip it.
                continue
            if p in failed_cands:
                continue

            if "STRING" in str(p):
                assert arguments.domain == 'text'
                if len(task.stringConstants) == 0:
                    ll = float('-inf')
                else:
                    ci = Text.ConstantInstantiateVisitor(
                        [[cc for cc in sc] for sc in task.stringConstants],
                        sample=False)
                    ll = min(
                        task.logLikelihood(
                            pp,
                            timeout=0.1 if arguments.domain != 'rational' else None)
                        for pp in p.visit(ci))
            # NOTE(review): this is a fresh `if`, not `elif`. For the text
            # domain, an `ll` computed above for a STRING program is
            # overwritten by the generic task.logLikelihood branch below —
            # confirm whether that is intended.
            if arguments.domain == 'regex':
                # Regex is handled specially: collect every well-typed,
                # well-formed candidate and marginalize over them afterwards.
                ll = float('-inf')
                if not p.canHaveType(task.request):
                    p = None
                else:
                    from examineFrontier import ConstantVisitor
                    p = p.visit(ConstantVisitor(task.str_const))
                    try:
                        regex = p.evaluate([])(pre.String(""))
                        if arguments.sampleLikelihood:
                            # Count how often each distinct regex is sampled.
                            # (Without the default, the first sighting raised
                            # TypeError and the sample was silently dropped.)
                            sampleFrequency[regex] = 1 + sampleFrequency.get(regex, 0)
                            p = None
                        else:
                            dataLikelihood = sum(
                                regex.match("".join(y)) for _, y in task.examples)
                            logPrior = g.logLikelihood(task.request, p)
                            frontier.append(
                                FrontierEntry(
                                    p,
                                    logPrior=logPrior,
                                    logLikelihood=dataLikelihood))
                    except Exception:
                        # Best effort: a candidate that fails to evaluate or
                        # score is dropped (and not recorded as failed).
                        p = None
            elif arguments.domain != 'logo':
                ll = task.logLikelihood(
                    p, timeout=0.1 if arguments.domain != 'rational' else None)
            else:
                try:
                    yh = drawLogo(p, timeout=1., resolution=28)
                    if isinstance(yh, list) and list(map(int, yh)) == task.examples[0][1]:
                        ll = 0.
                    else:
                        ll = float('-inf')
                except JSONDecodeError:
                    eprint(
                        "WARNING: Could not decode json. If this occurs occasionally it might be because the neural network is producing invalid code. Otherwise, if this occurs frequently, then this is a bug."
                    )
                    ll = float('-inf')

            if ll > float('-inf'):
                return True
            elif p is not None:
                failed_cands.add(p)

    if arguments.domain != 'regex':
        return False

    from examineFrontier import testingRegexLikelihood
    if arguments.sampleLikelihood:
        # sampleFrequency is keyed by evaluated regex objects, not programs.
        # testingRegexLikelihood expects a program (it calls .visit), so
        # score the held-out strings against each regex directly.
        from dreamcoder.domains.regex.makeRegexTasks import regexHeldOutExamples
        heldOut = regexHeldOutExamples(task)
        return lse([
            math.log(frequency) + sum(regex.match(s) for _, s in heldOut)
            for regex, frequency in sampleFrequency.items()
        ])

    # Marginalize over the collected regex frontier.
    frontier = Frontier(frontier, task)
    from graphs import addStupidRegex
    frontier = addStupidRegex(frontier, g)
    print("for this task I think that the following is the map estimate:\n",
          frontier.topK(1))
    if arguments.taskLikelihood:
        return lse([e.logPrior + e.logLikelihood for e in frontier])
    return lse([
        e.logPosterior + testingRegexLikelihood(task, e.program)
        for e in frontier
    ])
def primitive(self, e):
    """Bind an ``r_const`` primitive to the fixed string constant ``self.const``.

    For ``r_const``, ``e.value`` is replaced in place with
    ``PRC(pre.String(self.const))``. Every other primitive is left alone.
    ``e`` is returned either way.
    """
    if e.name != "r_const":
        return e
    constant = pre.String(self.const)
    e.value = PRC(constant)
    return e
# NOTE(review): this span starts mid-sequence. The statement that chose
# `entry` for the first examineProgram call is outside this view; the same
# sequence elsewhere in this file picks the best-posterior entry there —
# confirm.
examineProgram(entry)
posteriorHits += int(entry.trainHit)
posteriorHits_test += int(entry.testHit)
print("\t", "best Likelihood:")
# Entry with the highest likelihood alone (prior ignored).
entry = max(frontier.entries, key=lambda e: e.logLikelihood)
examineProgram(entry)
likelihoodHits += int(entry.trainHit)
likelihoodHits_test += int(entry.testHit)
print()
print("\t", "Posterior predictive samples...")
# Take 5 programs via frontier.sample() and turn each into a regex, with
# constants filled from task.str_const. Then draw one string from each
# regex and pretty-print the regexes.
programSamples = [frontier.sample().program for _ in range(5)]
programSamples = [
    p.visit(ConstantVisitor(task.str_const)).evaluate([])(
        pre.String("")) for p in programSamples
]
stringSamples = [p.sample() for p in programSamples]
programSamples = [prettyRegex(p) for p in programSamples]
eprint("&")
eprint(
    verbatimTable(list(zip(stringSamples, programSamples)), columns=2))
# Presumably ends a LaTeX table row.
eprint("\\\\")
eprint()
# (logPosterior, regex) pairs over the normalized frontier.
posterior = [(e.logPosterior, e.program.visit(ConstantVisitor(
    task.str_const)).evaluate([])(pre.String(""))) for e in frontier.normalize()]
def getProposals(
    net, current_trace, target_examples, net_examples=None, depth=0,
    modes=("crp", "regex-crp", "crp-regex"), nProposals=10,
    likelihoodWeighting=1, subsampleSize=None, altWith=None,
    maxNetworkEvals=None, doPrint=True, fuzzConcepts=True
):
    """Yield candidate proposals for explaining a set of examples.

    Proposals come from the network (``net``) and from concepts that already
    exist in ``current_trace``.

    :param net: network queried through ``getValidNetworkOutputs``. If
        ``None``, only proposals built from existing concepts are made.
    :param current_trace: the trace being extended.
    :param target_examples: examples the final proposals are attached to
        (``target_examples`` field and ``related`` proposals).
    :param net_examples: examples used to build and score proposals; defaults
        to ``target_examples``.
    :param depth: stored on every ``Proposal``.
    :param modes: subset of ``{"crp", "regex-crp", "regex-crp-crp",
        "crp-regex"}`` selecting which proposal kinds to generate.
    :param nProposals: at most ``ceil(n/2)`` proposals from existing concepts
        plus ``floor(n/2)`` from the network are yielded.
    :param likelihoodWeighting: scaled by
        ``len(target_examples) / len(examples)`` when a proposal is evaluated.
    :param subsampleSize: optional ``(min, max)``. When given, 10 rounds are
        run. Each round uses a random subsample of the distinct examples,
        weighted by multiplicity and drawn without replacement. Results are
        de-duplicated by their string form.
    :param altWith: passed through to ``Proposal``.
    :param maxNetworkEvals: not used in this function body; forwarded only.
    :param doPrint: print a summary line, unless the examples are already in
        ``networkCache``.
    :param fuzzConcepts: also try swapping concepts in network regexes for
        similar concepts.
    :yields: proposals sorted by ``final_trace.score`` (descending). Each is
        ``.strip()``-ped when the examples differ from ``target_examples``.
    """
    assert (all(x in ["crp", "regex-crp", "regex-crp-crp", "crp-regex"] for x in modes))
    examples = net_examples if net_examples is not None else target_examples

    if subsampleSize is not None:
        counter = Counter(examples)
        min_examples, max_examples = subsampleSize
        nSubsamples = 10
        # Proposals are de-duplicated by their string representation.
        proposal_strings_sofar = set()
        for _ in range(nSubsamples):
            num_examples = random.randint(min_examples, max_examples)
            sampled_examples = list(
                np.random.choice(list(counter.keys()),
                                 size=min(num_examples, len(counter)),
                                 p=np.array(list(counter.values())) / sum(counter.values()),
                                 replace=False))
            # Forward every setting to the recursive call, not just the
            # positional ones, so subsampling does not silently drop
            # altWith / doPrint / fuzzConcepts.
            for proposal in getProposals(net, current_trace, target_examples,
                                         sampled_examples, depth, modes,
                                         int(nProposals / nSubsamples),
                                         likelihoodWeighting,
                                         subsampleSize=None,
                                         altWith=altWith,
                                         maxNetworkEvals=maxNetworkEvals,
                                         doPrint=doPrint,
                                         fuzzConcepts=fuzzConcepts):
                proposal_string = proposal.concept.str(proposal.trace, depth=-1)
                if proposal_string not in proposal_strings_sofar:
                    proposal_strings_sofar.add(proposal_string)
                    yield proposal
        return

    examples = tuple(sorted(examples))
    isCached = examples in networkCache
    cur_proposals = []
    net_proposals = []

    def byScore(proposal):
        # Sort key: score of the proposal's final trace.
        return proposal.final_trace.score

    def getProposalID(proposal):
        # Identity used to avoid duplicate proposals.
        return proposal.concept.str(proposal.trace, depth=-1)

    proposalIDs_so_far = set()

    def addProposal(trace, concept, add_to, related=()):
        # Evaluate (trace, concept). If it is valid and new, append it to
        # add_to together with its (lazily built) related proposals.
        def f(t, c, final):
            return Proposal(
                depth, tuple(sorted(examples)),
                tuple(target_examples) if final else tuple(examples),
                current_trace, t, c, (), altWith, None, None, None)

        p = evalProposal(f(trace, concept, final=False),
                         likelihoodWeighting=likelihoodWeighting * len(target_examples) / len(examples))
        if p.valid and getProposalID(p) not in proposalIDs_so_far:
            relatedProposals = tuple(f(t, c, final=True) for (t, c) in related)
            p = p._replace(related=relatedProposals,
                           target_examples=tuple(target_examples))
            proposalIDs_so_far.add(getProposalID(p))
            add_to.append(p)
        return p if p.valid else None

    # Exactly the examples. Distinct examples are taken in sorted order, not
    # by iterating a set, so the Alt is reproducible across runs (str hashing
    # is randomized per process).
    distinct = list(dict.fromkeys(examples))
    addProposal(*current_trace.addregex(
        pre.String(examples[0]) if len(distinct) == 1
        else pre.Alt([pre.String(x) for x in distinct])), cur_proposals)

    for c in current_trace.baseConcepts:
        addProposal(current_trace.fork(), c, cur_proposals)
        if "crp" in modes:
            t, c = current_trace.addPY(c)
            addProposal(t, c, cur_proposals)

    n_cur = math.ceil(nProposals / 2)
    n_net = math.floor(nProposals / 2)
    # Stop consuming network outputs once group_idx reaches this bound.
    m_net = n_net * 5

    if net is not None:
        similarConcepts = current_trace.getSimilarConcepts()
        lookup = {
            concept: RegexWrapper(concept)
            for concept in current_trace.baseConcepts
        }

        def getRegexConcept(o):
            r = pre.create(o, lookup=lookup)
            t, c = current_trace.addregex(r)
            return (t, c)

        def getRelatedRegexConcepts(o):
            # Proposals that will be good only if getRegexConcept(o) is good.
            def extend(t, c):
                # Add CRPs at the end.
                if any(x in modes for x in ("regex-crp", "regex-crp-crp")):
                    t, c = t.addPY(c)
                    if "regex-crp" in modes:
                        yield (t, c)
                    if "regex-crp-crp" in modes:
                        t, c = t.addPY(c)
                        yield (t, c)

            for (t, c) in extend(*getRegexConcept(o)):
                yield (t, c)
            for i in range(len(o)):
                if o[i] in current_trace.baseConcepts:
                    if fuzzConcepts:
                        # Try replacing one concept in the regex with a parent or child.
                        for o_alt in similarConcepts.get(o[i], []):
                            r = pre.create(o[:i] + (o_alt, ) + o[i + 1:], lookup=lookup)
                            t, c = current_trace.addregex(r)
                            yield (t, c)
                            for (t, c) in extend(t, c):
                                yield (t, c)
                    if "crp-regex" in modes:
                        # Try replacing one concept with a new PYConcept.
                        t, c = current_trace.addPY(o[i])
                        r = pre.create(o[:i] + (c, ) + o[i + 1:],
                                       lookup={**lookup, c: RegexWrapper(c)})
                        t, c = t.addregex(r)
                        yield (t, c)
                        for (t, c) in extend(t, c):
                            yield (t, c)

        for (o, count, group_idx) in getValidNetworkOutputs(net, current_trace, examples):
            t, c = getRegexConcept(o)
            addProposal(t, c, net_proposals, related=getRelatedRegexConcepts(o))
            if group_idx >= m_net:
                break

    cur_proposals.sort(key=byScore, reverse=True)
    net_proposals.sort(key=byScore, reverse=True)
    proposals = cur_proposals[:n_cur] + net_proposals[:n_net]
    proposals.sort(key=byScore, reverse=True)

    if not isCached and doPrint:
        print("Proposals (ll*%2.2f): " % likelihoodWeighting, ", ".join(examples), "--->",
              ", ".join(("N:" if proposal in net_proposals else "") + proposal.concept.str(proposal.trace)
                        for proposal in proposals),
              flush=True)

    # Proposals built on a subset of the targets are stripped before yielding.
    usesAllTargets = tuple(sorted(examples)) == tuple(sorted(target_examples))
    for p in proposals:
        if usesAllTargets:
            yield p
        else:
            yield p.strip()
# Per-task report: examine the best-posterior and best-likelihood entries
# and tally their train/test hits, then show posterior-predictive samples.
print("\t", "best Posterior:")
# Best posterior (up to normalization) = logLikelihood + logPrior.
entry = max(frontier.entries, key=lambda e: e.logLikelihood + e.logPrior)
examineProgram(entry)
posteriorHits += int(entry.trainHit)
posteriorHits_test += int(entry.testHit)
print("\t", "best Likelihood:")
# Entry with the highest likelihood alone (prior ignored).
entry = max(frontier.entries, key=lambda e: e.logLikelihood)
examineProgram(entry)
likelihoodHits += int(entry.trainHit)
likelihoodHits_test += int(entry.testHit)
print()
print("\t","Posterior predictive samples...")
# Take 5 programs via frontier.sample() and turn each into a regex, with
# constants filled from task.str_const. Then draw one string from each
# regex and pretty-print the regexes.
programSamples = [frontier.sample().program for _ in range(5)]
programSamples = [p.visit(ConstantVisitor(task.str_const)).evaluate([])(pre.String("")) for p in programSamples ]
stringSamples = [p.sample() for p in programSamples ]
programSamples = [prettyRegex(p) for p in programSamples ]
eprint("&")
eprint(verbatimTable(list(zip(stringSamples, programSamples)),columns=2))
# Presumably ends a LaTeX table row.
eprint("\\\\")
eprint()
# (logPosterior, regex) pairs over the normalized frontier.
posterior = [(e.logPosterior, e.program.visit(ConstantVisitor(task.str_const)).evaluate([])(pre.String(""))) for e in frontier.normalize() ]
# Keep only the second component of each held-out pair.
testingExamples = [te for _,te in testingExamples]