def makeTestdata(synth=True, challenge=False):
    """Build the list of test Datum objects (currently disabled).

    The very first statement raises NotImplementedError, so every call fails
    regardless of the arguments.  The statements after the raise are
    unreachable; they are kept as a reference for the intended behaviour:

    * collect tasks from makeTasks() when ``synth`` is true,
    * append the tasks returned by loadPBETasks() when ``challenge`` is true,
    * keep only tasks with no string constants whose request is
      list(char) -> list(char), and wrap each one in a Datum that carries
      only its IO examples.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

    # ---- unreachable reference implementation below ----
    candidates = makeTasks() if synth else []
    if challenge:
        extra_tasks, _ = loadPBETasks()
        candidates = candidates + extra_tasks

    data = []
    for task in candidates:
        # Only constant-free string -> string tasks are converted.
        if not (task.stringConstants == []):
            continue
        if not (task.request == arrow(tlist(tcharacter), tlist(tcharacter))):
            continue
        IO = tuple((''.join(x[0]), ''.join(y)) for x, y in task.examples)
        # No ground-truth program or sketch is attached to test data.
        data.append(Datum(tprogram, None, None, IO, None, None, None, None))
    return data
def recurse_list(l, target_tp):
    """Fold a non-empty sequence of expressions into one list-typed program.

    The first element is converted to ``target_tp`` and then to
    ``tlist(target_tp)``; it seeds the accumulator.  Every following element
    is converted to ``target_tp`` when its inferred type differs, and is
    combined with the accumulator by applying the primitive whose type is
    ``target_tp -> tlist(target_tp) -> tlist(target_tp)``.

    Args:
        l: indexable, sliceable sequence of expressions accepted by
            ``recurse``.  Must contain at least one element.
        target_tp: element type of the resulting list.

    Returns:
        The accumulated program (nested ``Application`` nodes when ``l`` has
        more than one element).

    Raises:
        IndexError: if ``l`` is empty, or if no registered primitive has the
            required list-converter type.
    """
    # base case
    head = convert_to_tp(recurse(l[0]), target_tp)
    e = convert_to_tp(head, tlist(target_tp))

    # The converter depends only on target_tp, so it is resolved once, and
    # only if there is more than one element (a single-element list never
    # needed it).
    # NOTE(review): assumes Primitive.GLOBALS is not mutated by recurse() /
    # convert_to_tp() between iterations -- confirm.
    list_converter = None
    for exp in l[1:]:
        x = recurse(exp)
        if x.infer() != target_tp:
            x = convert_to_tp(x, target_tp)  # maybe always convert?

        if list_converter is None:
            request = arrow(target_tp, tlist(target_tp), tlist(target_tp))
            # The comprehension variable is deliberately NOT called `x`: on
            # Python 2 list-comprehension variables leak into the enclosing
            # scope and would overwrite the expression converted above.
            matches = [prim for prim in Primitive.GLOBALS.values()
                       if prim.tp == request]  # TODO
            if not matches:
                raise IndexError(
                    "no primitive registered with type %s" % (request,))
            list_converter = matches[0]

        e = Application(Application(list_converter, x), e)
    return e
def convert_source_to_datum(source, N=5, V=512, L=10, compute_sketches=False, top_k_sketches=20, inv_temp=1.0, reward_fn=None, sample_fn=None, dc_model=None, use_timeout=False):
    """Turn DeepCoder source text into a Datum.

    ``' | '`` separators in ``source`` are replaced by newlines before the
    text is handed to ``compile`` (called with ``V``/``L`` keywords, so
    presumably the project's DeepCoder compiler rather than the builtin --
    confirm).

    Args:
        source: program text, statements separated by ``' | '``.
        N, V, L: forwarded to ``generate_IO_examples``; ``V`` and ``L`` are
            also forwarded to ``compile``.
        compute_sketches: when true, a sketch is produced with
            ``make_holey_deepcoder`` and flattened with ``flatten_program``.
        top_k_sketches, inv_temp, reward_fn, sample_fn, use_timeout:
            forwarded to ``make_holey_deepcoder``.
        dc_model: when truthy, ``dc_model.infer_grammar(IO)`` supplies the
            grammar for sketching; otherwise ``basegrammar`` is used.

    Returns:
        A ``Datum(tp, p, pseq, IO, sketch, sketchseq, reward, sketchprob)``,
        with the last four fields set to None when ``compute_sketches`` is
        false; or None when compilation returns None.

    Raises:
        AssertionError: if the program output is neither ``int`` nor
            ``[int]``.
    """
    dc_program = compile(source.replace(' | ', '\n'), V=V, L=L)
    if dc_program is None:
        return None

    # Input/output examples.
    IO = tuple(generate_IO_examples(dc_program, N=N, L=L, V=V))

    # Request type: int maps to tint, anything else to a list of ints.
    arg_types = []
    for inp in dc_program.ins:
        arg_types.append(tint if inp == int else tlist(tint))
    if dc_program.out == int:
        ret_type = tint
    else:
        assert dc_program.out==[int]
        ret_type = tlist(tint)
    tp = arrow(*(arg_types + [ret_type]))

    # Token sequence, then the parsed program.
    pseq = tuple(convert_dc_program_to_ec(dc_program, tp))
    p = parseprogram(pseq, tp)  # TODO: use correct grammar

    if not compute_sketches:
        return Datum(tp, p, pseq, IO, None, None, None, None)

    # Sketch and its flattened form.
    grammar = dc_model.infer_grammar(IO) if dc_model else basegrammar  # This line needs to change
    sketch, reward, sketchprob = make_holey_deepcoder(
        p, top_k_sketches, grammar, tp,
        inv_temp=inv_temp, reward_fn=reward_fn, sample_fn=sample_fn,
        use_timeout=use_timeout)  # TODO
    sketchseq = tuple(flatten_program(sketch))
    return Datum(tp, p, pseq, IO, sketch, sketchseq, reward, sketchprob)
def basePrimitives():
    """Return the base primitive set.

    The list holds, in order: the integer constants "0".."5", the arithmetic
    and predicate primitives ("*", "gt?", "is-prime", "is-square"), and the
    McCarthy-style list primitives plus "if", "eq?", "+" and "-".

    Returns:
        list of Primitive objects (a fresh list on every call).
    """
    # `range` rather than `xrange`: works on both Python 2 and Python 3.
    constants = [Primitive(str(j), tint, j) for j in range(6)]
    return constants + [
        Primitive("*", arrow(tint, tint, tint), _multiplication),
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("is-prime", arrow(tint, tbool), _isPrime),
        Primitive("is-square", arrow(tint, tbool), _isSquare),
        # McCarthy
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction)
    ]
def McCarthyPrimitives():
    """Return roughly the primitives of 1959 Lisp as introduced by McCarthy.

    The list holds the list primitives ("empty", "cons", "car", "cdr",
    "empty?"), the two module-level recursion primitives
    (primitiveRecursion1, primitiveRecursion2), "if", "eq?", "+", "-", and
    finally the integer constants "0" and "1".

    Returns:
        list of Primitive objects (a fresh list on every call).
    """
    return [
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
        #Primitive("unfold", arrow(t0, arrow(t0,t1), arrow(t0,t0), arrow(t0,tbool), tlist(t1)), _isEmpty),
        #Primitive("1+", arrow(tint,tint),None),
        # Primitive("range", arrow(tint, tlist(tint)), range),
        # Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
        # Primitive("index", arrow(tint,tlist(t0),t0),None),
        # Primitive("length", arrow(tlist(t0),tint),None),
        primitiveRecursion1,
        primitiveRecursion2,
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction),
        # `range` rather than `xrange`: works on both Python 2 and Python 3.
    ] + [Primitive(str(j), tint, j) for j in range(2)]
def convert_examples_to_datum(self):
    """Pack this object's examples into a single-element list of Datum.

    The function type starts as ``self.examples[0].typeof()``.  For every
    example whose inputs are all plain ints, the inputs are re-packed as a
    one-element tuple holding a list of those ints, a warning and the new
    inputs are printed, and the function type is replaced by
    ``tlist(tint) -> self.examples[0].outtype``.  Only the first output of
    each example is used.

    Returns:
        ``[Datum(funtype, None, None, ioexamples, None, None, None, None)]``
        where ``ioexamples`` is a tuple of ``(inputs, output)`` pairs.
    """
    # Start from the type reported by the first example.
    funtype = self.examples[0].typeof()

    pairs = []
    for example in self.examples:
        inputs = example.inputs
        outputs = example.outputs[0]

        has_non_int = any(type(inp) != type(1) for inp in inputs)
        if not has_non_int:
            # All integers: present them as a single list argument instead.
            print("Warning: converting type for SketchAdapt to avoid int->int")
            inputs = (list(example.inputs),)
            print(inputs)
            funtype = arrow(tlist(tint), self.examples[0].outtype)

        pairs.append((inputs, outputs))

    return [Datum(funtype, None, None, tuple(pairs), None, None, None, None)]
def algolispPrimitives():
    """Return the AlgoLisp primitive set.

    The result is the concatenation of three groups, built in this order:

    1. call, symbol-converter and list-converter primitives;
    2. one ``tfunction`` primitive per entry of ``fn_lookup``;
    3. one ``tconstant`` primitive per entry of ``const_lookup``.

    For groups 2 and 3 the lookup value is used as the primitive name and
    the lookup key as the primitive value.
    """

    def _tagged_call(tag):
        # Curried builder shared by lambda1_call / lambda2_call:
        # f -> sx -> [tag, [f] + args], wrapping a non-list sx in a list.
        def with_function(f):
            def with_args(sx):
                if type(sx)==list:
                    return [tag, [f] + sx]
                return [tag, [f] + [sx]]
            return with_args
        return with_function

    def _add_symbol(symbol):
        # Append `symbol`; a non-list accumulator is first wrapped in a list.
        def onto(symbols):
            if type(symbols) == list:
                return symbols + [symbol]
            return [symbols] + [symbol]
        return onto

    core = [
        Primitive("fn_call", arrow(tfunction, tlist(tsymbol), tsymbol), _fn_call),
        Primitive("lambda1_call", arrow(tfunction, tlist(tsymbol), tsymbol),
                  _tagged_call("lambda1")),
        Primitive("lambda2_call", arrow(tfunction, tlist(tsymbol), tsymbol),
                  _tagged_call("lambda2")),
        # symbol converters:
        # SYMBOL = constant | argument | function_call | function | lambda
        Primitive("symbol_constant", arrow(tconstant, tsymbol), lambda x: x),
        Primitive("symbol_function", arrow(tfunction, tsymbol), lambda x: x),
        # list converters
        Primitive('list_init_symbol', arrow(tsymbol, tlist(tsymbol)),
                  lambda symbol: [symbol]),
        Primitive('list_add_symbol',
                  arrow(tsymbol, tlist(tsymbol), tlist(tsymbol)),
                  _add_symbol),
    ]

    # functions
    function_prims = [
        Primitive(ec_name, tfunction, algo_name)
        for algo_name, ec_name in fn_lookup.items()
    ]

    # constants
    constant_prims = [
        Primitive(ec_name, tconstant, algo_name)
        for algo_name, ec_name in const_lookup.items()
    ]

    return core + function_prims + constant_prims
def primitives():
    """Return the full list-domain primitive set.

    The list starts with the integer constants "0".."5", followed by list
    constructors, indexed map/reduce, boolean and arithmetic primitives, and
    finally a group of convenience primitives.  The trailing comment on each
    convenience primitive is the original author's note on how it could be
    expressed with the earlier primitives.

    Returns:
        list of Primitive objects (a fresh list on every call).
    """
    # `range` rather than `xrange`: works on both Python 2 and Python 3.
    return [Primitive(str(j), tint, j) for j in range(6)] + [
        Primitive("empty", tlist(t0), []),
        Primitive("singleton", arrow(t0, tlist(t0)), _single),
        Primitive("range", arrow(tint, tlist(tint)), range),
        Primitive("++", arrow(tlist(t0), tlist(t0), tlist(t0)), _append),
        # Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
        Primitive("mapi", arrow(arrow(tint, t0, t1), tlist(t0), tlist(t1)), _mapi),
        # Primitive("reduce", arrow(arrow(t1, t0, t1), t1, tlist(t0), t1), _reduce),
        Primitive("reducei", arrow(arrow(tint, t1, t0, t1), t1, tlist(t0), t1), _reducei),

        Primitive("true", tbool, True),
        Primitive("not", arrow(tbool, tbool), _not),
        Primitive("and", arrow(tbool, tbool, tbool), _and),
        Primitive("or", arrow(tbool, tbool, tbool), _or),
        # Primitive("if", arrow(tbool, t0, t0, t0), _if),

        Primitive("sort", arrow(tlist(tint), tlist(tint)), sorted),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("*", arrow(tint, tint, tint), _multiplication),
        Primitive("negate", arrow(tint, tint), _negate),
        Primitive("mod", arrow(tint, tint, tint), _mod),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("is-prime", arrow(tint, tbool), _isPrime),
        Primitive("is-square", arrow(tint, tbool), _isSquare),

        # these are achievable with above primitives, but unlikely
        #Primitive("flatten", arrow(tlist(tlist(t0)), tlist(t0)), _flatten),
        ## (lambda (reduce (lambda (lambda (++ $1 $0))) empty $0))
        Primitive("sum", arrow(tlist(tint), tint), sum),
        # (lambda (lambda (reduce (lambda (lambda (+ $0 $1))) 0 $0)))
        Primitive("reverse", arrow(tlist(t0), tlist(t0)), _reverse),
        # (lambda (reduce (lambda (lambda (++ (singleton $0) $1))) empty $0))
        Primitive("all", arrow(arrow(t0, tbool), tlist(t0), tbool), _all),
        # (lambda (lambda (reduce (lambda (lambda (and $0 $1))) true (map $1 $0))))
        Primitive("any", arrow(arrow(t0, tbool), tlist(t0), tbool), _any),
        # (lambda (lambda (reduce (lambda (lambda (or $0 $1))) true (map $1 $0))))
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        # (lambda (lambda (reducei (lambda (lambda (lambda (if (eq? $1 $4) $0 0)))) 0 $0)))
        Primitive("filter", arrow(arrow(t0, tbool), tlist(t0), tlist(t0)), _filter),
        # (lambda (lambda (reduce (lambda (lambda (++ $1 (if ($3 $0) (singleton $0) empty)))) empty $0)))
        #Primitive("replace", arrow(arrow(tint, t0, tbool), tlist(t0), tlist(t0), tlist(t0)), _replace),
        ## (FLATTEN (lambda (lambda (lambda (mapi (lambda (lambda (if ($4 $1 $0) $3 (singleton $1)))) $0)))))
        Primitive("slice", arrow(tint, tint, tlist(t0), tlist(t0)), _slice),
        # (lambda (lambda (lambda (reducei (lambda (lambda (lambda (++ $2 (if (and (or (gt? $1 $5) (eq? $1 $5)) (not (or (gt? $4 $1) (eq? $1 $4)))) (singleton $0) empty))))) empty $0))))
    ]
] + [Primitive(str(j), tint, j) for j in xrange(2)] if __name__ == "__main__": import pickle g = Grammar.uniform(McCarthyPrimitives()) with open( "/home/ellisk/om/ec/experimentOutputs/list_aic=1.0_arity=3_ET=1800_expandFrontier=2.0_it=4_likelihoodModel=all-or-nothing_MF=5_baseline=False_pc=10.0_L=1.0_K=5_rec=False.pickle", "rb") as handle: b = pickle.load(handle).grammars[-1] print b p = Program.parse( "(lambda (lambda (lambda (if (empty? $0) empty (cons (+ (car $0) (car $1)) ($2 (cdr $0) (cdr $1)))))))" ) t = arrow(tlist(tint), tlist(tint), tlist(tint)) print g.logLikelihood(arrow(t, t), p) print b.logLikelihood(arrow(t, t), p) # p = Program.parse("""(lambda (lambda # (unfold 0 # (lambda (+ (index $0 $2) (index $0 $1))) # (lambda (1+ $0)) # (lambda (eq? $0 (length $1)))))) # """) p = Program.parse("""(lambda (lambda (map (lambda (+ (index $0 $2) (index $0 $1))) (range (length $0)) )))""") # .replace("unfold", "#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0)))))))))").\ # replace("length", "#(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ ($1 (cdr $0)) 1))))))").\ # replace("forloop", "(#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0))))))))) (lambda (#(eq? 0) $0)) $0 (lambda (#(lambda (- $0 1)) $0)))").\ # replace("inc","#(lambda (+ $0 1))").\
def array_output(self, arr, nolen=False):
    """Record ``arr`` as this example's only output, typed as a list of ints.

    ``self.outputs`` becomes a one-element list holding ``arr`` and
    ``self.outtype`` becomes ``tlist(tint)``.  ``nolen`` is accepted but not
    used.
    """
    self.outputs = [arr]
    int_list_type = tlist(tint)
    self.outtype = int_list_type
def add_array_input(self, arr1, nolen=False):
    """Append an integer-list input to this example.

    ``tlist(tint)`` is appended to ``self.intype`` and ``arr1`` is appended
    to ``self.inputs``.

    Args:
        arr1: the input value to record.
        nolen: accepted for interface compatibility.  Both branches of the
            former ``if nolen`` appended ``arr1`` unchanged, so the flag has
            no effect and the redundant conditional has been removed.
    """
    self.intype.append(tlist(tint))
    self.inputs.append(arr1)
def OldDeepcoderPrimitives():
    """Return the older, lower-case-named DeepCoder primitive set.

    Five groups are built and concatenated in this order: first-order list
    operations, higher-order list operations, int -> int functions,
    int -> bool predicates, and int -> int -> int functions.  The function
    valued primitives carry an ``_fn`` suffix in their names.
    """
    int_list = tlist(tint)

    list_ops = [
        Primitive("head", arrow(int_list, tint), _head),
        Primitive("tail", arrow(int_list, tint), _tail),
        Primitive("take", arrow(tint, int_list, int_list), _take),
        Primitive("drop", arrow(tint, int_list, int_list), _drop),
        Primitive("access", arrow(tint, int_list, tint), _access),
        Primitive("minimum", arrow(int_list, tint), _minimum),
        Primitive("maximum", arrow(int_list, tint), _maximum),
        Primitive("reverse", arrow(int_list, int_list), _reverse),
        Primitive("sort", arrow(int_list, int_list), _sort),
        Primitive("sum", arrow(int_list, tint), _sum),
    ]

    # The original author flagged each of these types with "is this okay???".
    higher_order = [
        Primitive("map", arrow(int_to_int, int_list, int_list), _map),
        Primitive("filter", arrow(int_to_bool, int_list, int_list), _filter),
        Primitive("count", arrow(int_to_bool, int_list, tint), _count),
        Primitive("zipwith",
                  arrow(int_to_int_to_int, int_list, int_list, int_list),
                  _zipwith),
        Primitive("scanl1", arrow(int_to_int_to_int, int_list, int_list),
                  _scanl1),
    ]

    # A disabled variant once registered these helpers under plain names
    # ("succ", "pos", "add", ...) with explicit arrow(tint, ...) types; only
    # the *_fn forms below are active.
    unary_fns = [
        Primitive("succ_fn", int_to_int, _succ),
        Primitive("pred_fn", int_to_int, _pred),
        Primitive("double_fn", int_to_int, _double),
        Primitive("half_fn", int_to_int, _half),
        Primitive("negate_fn", int_to_int, _negate),
        Primitive("square_fn", int_to_int, _square),
        Primitive("triple_fn", int_to_int, _triple),
        Primitive("third_fn", int_to_int, _third),
        Primitive("quad_fn", int_to_int, _quad),
        Primitive("quarter_fn", int_to_int, _quarter),
    ]

    predicates = [
        Primitive("pos_fn", int_to_bool, _pos),
        Primitive("neg_fn", int_to_bool, _neg),
        Primitive("even_fn", int_to_bool, _even),
        Primitive("odd_fn", int_to_bool, _odd),
    ]

    binary_fns = [
        Primitive("add_fn", int_to_int_to_int, _add),
        Primitive("sub_fn", int_to_int_to_int, _sub),
        Primitive("mult_fn", int_to_int_to_int, _mult),
        Primitive("min_fn", int_to_int_to_int, _min),
        Primitive("max_fn", int_to_int_to_int, _max),
    ]

    return list_ops + higher_order + unary_fns + predicates + binary_fns
def bootstrapTarget():
    """Return the primitives the bootstrapping procedure is hoped to learn.

    The result is the "learned" higher-level primitives, then the built-in
    primitives they are meant to be derived from, then the integer constants
    "0" and "1".
    """
    # learned primitives
    learned = [
        Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
        Primitive("unfold",
                  arrow(t0, arrow(t0, tbool), arrow(t0, t1), arrow(t0, t0),
                        tlist(t1)),
                  _unfold),
        Primitive("range", arrow(tint, tlist(tint)), _range),
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold),
        Primitive("length", arrow(tlist(t0), tint), len),
    ]

    # built-ins
    builtins = [
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction),
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
    ]

    constants = []
    for j in range(2):
        constants.append(Primitive(str(j), tint, j))

    return learned + builtins + constants
Primitive("is-square", arrow(tint, tbool), _isSquare), # McCarthy Primitive("empty", tlist(t0), []), Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons), Primitive("car", arrow(tlist(t0), t0), _car), Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr), Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty), Primitive("if", arrow(tbool, t0, t0, t0), _if), Primitive("eq?", arrow(tint, tint, tbool), _eq), Primitive("+", arrow(tint, tint, tint), _addition), Primitive("-", arrow(tint, tint, tint), _subtraction) ] zip_primitive = Primitive( "zip", arrow(tlist(t0), tlist(t1), arrow(t0, t1, t2), tlist(t2)), _zip) def bootstrapTarget(): """These are the primitives that we hope to learn from the bootstrapping procedure""" return [ # learned primitives Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map), Primitive( "unfold", arrow(t0, arrow(t0, tbool), arrow(t0, t1), arrow(t0, t0), tlist(t1)), _unfold), Primitive("range", arrow(tint, tlist(tint)), _range), Primitive("index", arrow(tint, tlist(t0), t0), _index), Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold), Primitive("length", arrow(tlist(t0), tint), len),
# Demo script: build a DeepCoder grammar, draw a program from it for a fixed
# request type, flatten the program, and parse the flattened form back so the
# two pretty-printed programs can be compared by eye.
from grammar import Grammar
from deepcoderPrimitives import deepcoderProductions, flatten_program
#from program import Application, Hole
#import math
from type import Context, arrow, tint, tlist, tbool, UnificationFailure
from program import prettyProgram
from train.main_supervised_deepcoder import parseprogram, make_holey_deepcoder

#g = Grammar.uniform(deepcoderPrimitives())
g = Grammar.fromProductions(
    deepcoderProductions(), logVariable=.9)  #TODO - find correct grammar weights

# Request type: list(int) -> int -> int.
request = arrow(tlist(tint), tint, tint)

p = g.sample(request)

# NOTE(review): `sketch` is never read below; the call is kept as in the
# original. Elsewhere make_holey_deepcoder's result is unpacked into three
# values, so this name may hold a tuple -- confirm.
sketch = make_holey_deepcoder(p, 10, g, request)

print("request:", request)
print("program:")
print(prettyProgram(p))

print("flattened_program:")
flat = flatten_program(p)
print(flat)

# Parse the flattened form back into a program for the same request type.
prog = parseprogram(flat, request)
print("recovered program:")
print(prettyProgram(prog))
def isListFunction(tp):
    """Return True when ``tp`` unifies with ``tlist(tint) -> t0``.

    Unification is attempted in a fresh ``Context``; a
    ``UnificationFailure`` is reported as False.  Any other exception
    propagates.
    """
    try:
        Context().unify(tp, arrow(tlist(tint), t0))
    except UnificationFailure:
        return False
    else:
        return True
def deepcoderPrimitives():
    """Return the DeepCoder primitive set with upper-case DSL names.

    Five groups are built and concatenated in this order: first-order list
    operations, higher-order list operations, int -> int functions,
    int -> bool predicates, and int -> int -> int functions.
    """
    int_list = tlist(tint)

    list_ops = [
        Primitive("HEAD", arrow(int_list, tint), _head),
        Primitive("LAST", arrow(int_list, tint), _tail),
        Primitive("TAKE", arrow(tint, int_list, int_list), _take),
        Primitive("DROP", arrow(tint, int_list, int_list), _drop),
        Primitive("ACCESS", arrow(tint, int_list, tint), _access),
        Primitive("MINIMUM", arrow(int_list, tint), _minimum),
        Primitive("MAXIMUM", arrow(int_list, tint), _maximum),
        Primitive("REVERSE", arrow(int_list, int_list), _reverse),
        Primitive("SORT", arrow(int_list, int_list), _sort),
        Primitive("SUM", arrow(int_list, tint), _sum),
    ]

    # The original author flagged each of these types with "is this okay???".
    higher_order = [
        Primitive("MAP", arrow(int_to_int, int_list, int_list), _map),
        Primitive("FILTER", arrow(int_to_bool, int_list, int_list), _filter),
        Primitive("COUNT", arrow(int_to_bool, int_list, tint), _count),
        Primitive("ZIPWITH",
                  arrow(int_to_int_to_int, int_list, int_list, int_list),
                  _zipwith),
        Primitive("SCANL1", arrow(int_to_int_to_int, int_list, int_list),
                  _scanl1),
    ]

    unary_fns = [
        Primitive("INC", int_to_int, _succ),
        Primitive("DEC", int_to_int, _pred),
        Primitive("SHL", int_to_int, _double),
        Primitive("SHR", int_to_int, _half),
        Primitive("doNEG", int_to_int, _negate),
        Primitive("SQR", int_to_int, _square),
        Primitive("MUL3", int_to_int, _triple),
        Primitive("DIV3", int_to_int, _third),
        Primitive("MUL4", int_to_int, _quad),
        Primitive("DIV4", int_to_int, _quarter),
    ]

    predicates = [
        Primitive("isPOS", int_to_bool, _pos),
        Primitive("isNEG", int_to_bool, _neg),
        Primitive("isEVEN", int_to_bool, _even),
        Primitive("isODD", int_to_bool, _odd),
    ]

    binary_fns = [
        Primitive("+", int_to_int_to_int, _add),
        Primitive("-", int_to_int_to_int, _sub),
        Primitive("*", int_to_int_to_int, _mult),
        Primitive("MIN", int_to_int_to_int, _min),
        Primitive("MAX", int_to_int_to_int, _max),
    ]

    return list_ops + higher_order + unary_fns + predicates + binary_fns