def retrieveJSONTasks(filename, features=False):
    """Load single-argument list-domain tasks from a JSON file.

    The file must contain a JSON array of objects of the form::

        {"name": str,
         "type": {"input": "bool" | "int" | "list-of-bool" | "list-of-int",
                  "output": "bool" | "int" | "list-of-bool" | "list-of-int"},
         "examples": [{"i": data, "o": data}, ...]}

    :param filename: path of the JSON file (read as UTF-8).
    :param features: when true, attach ``list_features`` computed from the
        examples to every task; otherwise the task's features are ``None``.
    :returns: one ``Task`` per JSON object, in file order.  Each task has
        request type ``input -> output``, examples ``((i,), o)`` and
        ``cache=False``.
    :raises KeyError: if a task uses a type name not listed above, or an
        object is missing one of the expected keys.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale.
    with open(filename, "r", encoding="utf-8") as f:
        loaded = json.load(f)

    TP = {
        "bool": tbool,
        "int": tint,
        "list-of-bool": tlist(tbool),
        "list-of-int": tlist(tint),
    }

    tasks = []
    for item in loaded:
        # Every task takes exactly one argument, hence the 1-tuple inputs.
        # Build the examples once and reuse them for the optional features.
        examples = [((ex["i"], ), ex["o"]) for ex in item["examples"]]
        request = arrow(TP[item["type"]["input"]], TP[item["type"]["output"]])
        tasks.append(
            Task(
                item["name"],
                request,
                examples,
                features=list_features(examples) if features else None,
                cache=False,
            ))
    return tasks
def genericType(t):
    """Rewrite type ``t`` using only ``treal``, lists and arrows.

    ``real`` and ``positive`` both map to ``treal`` and ``vector`` maps to
    ``tlist(treal)``.  ``list`` and arrow types are converted recursively on
    their arguments.  Any other type fails the final assertion.
    """
    kind = t.name
    if kind in ("real", "positive"):
        return treal
    if kind == "vector":
        return tlist(treal)
    if kind == "list":
        return tlist(genericType(t.arguments[0]))
    if t.isArrow():
        domain, codomain = t.arguments[0], t.arguments[1]
        return arrow(genericType(domain), genericType(codomain))
    assert False, "could not make type generic: %s" % t
def basePrimitives():
    """Return the "base" list-domain primitive set.

    Order: the integer constants 0-5, integer arithmetic and predicates, then
    McCarthy-style list primitives together with if / eq? / + / -.
    """
    constants = [Primitive(str(k), tint, k) for k in range(6)]
    numeric = [
        Primitive("*", arrow(tint, tint, tint), _multiplication),
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("is-prime", arrow(tint, tbool), _isPrime),
        Primitive("is-square", arrow(tint, tbool), _isSquare),
    ]
    # McCarthy
    mccarthy = [
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction),
    ]
    return constants + numeric + mccarthy
def McCarthyPrimitives():
    """Return the list primitives of McCarthy's 1959 Lisp, plus a few extras.

    Contents, in order:
    - the list primitives (empty, cons, car, cdr, empty?);
    - ``primitiveRecursion1``;
    - integer comparison (gt?), ``if``, integer equality and + / -;
    - the constants 0 and 1.

    Deliberately excluded (commented out in earlier revisions): unfold, 1+,
    range, map, index, length and primitiveRecursion2.
    """
    lisp = [
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
    ]
    extras = [
        primitiveRecursion1,
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction),
    ]
    return lisp + extras + [Primitive(str(k), tint, k) for k in range(2)]
def algolispPrimitives():
    """Return the primitives that build AlgoLisp symbol trees.

    Order:
    - call constructors (fn_call, lambda1_call, lambda2_call);
    - symbol converters;
    - symbol-list builders;
    - one ``tfunction`` primitive per ``fn_lookup`` entry;
    - one ``tconstant`` primitive per ``const_lookup`` entry.

    For the lookup tables, the dictionary value is used as the primitive name
    and the key as its value.
    """

    def tagged_call(tag):
        # Curried builder for a [tag, [f, *args]] node; a bare (non-list)
        # argument is treated as a one-element argument list.
        def with_function(f):
            def with_args(sx):
                if type(sx) is list:
                    return [tag, [f] + sx]
                return [tag, [f] + [sx]]
            return with_args
        return with_function

    def append_symbol(symbol):
        # Curried append; a non-list accumulator is first wrapped in a list.
        def onto(symbols):
            if type(symbols) is list:
                return symbols + [symbol]
            return [symbols] + [symbol]
        return onto

    structural = [
        Primitive("fn_call", arrow(tfunction, tlist(tsymbol), tsymbol), _fn_call),
        Primitive("lambda1_call", arrow(tfunction, tlist(tsymbol), tsymbol), tagged_call("lambda1")),
        Primitive("lambda2_call", arrow(tfunction, tlist(tsymbol), tsymbol), tagged_call("lambda2")),
        # symbol converters:
        # SYMBOL = constant | argument | function_call | function | lambda
        Primitive("symbol_constant", arrow(tconstant, tsymbol), lambda x: x),
        Primitive("symbol_function", arrow(tfunction, tsymbol), lambda x: x),
        # list converters
        Primitive('list_init_symbol', arrow(tsymbol, tlist(tsymbol)), lambda symbol: [symbol]),
        Primitive('list_add_symbol', arrow(tsymbol, tlist(tsymbol), tlist(tsymbol)), append_symbol),
    ]
    functions = [
        Primitive(ec_name, tfunction, algo_name)
        for algo_name, ec_name in fn_lookup.items()
    ]
    constants = [
        Primitive(ec_name, tconstant, algo_name)
        for algo_name, ec_name in const_lookup.items()
    ]
    return structural + functions + constants
def isListFunction(tp):
    """Return True when ``tp`` unifies with ``tlist(tint) -> t0``.

    Unification runs in a fresh ``Context`` that is discarded afterwards.
    """
    try:
        Context().unify(tp, arrow(tlist(tint), t0))
    except UnificationFailure:
        return False
    else:
        return True
def bootstrapTarget():
    """Return the primitives we hope the bootstrapping procedure will learn.

    The first six entries (map, unfold, range, index, fold, length) are the
    learning targets.  They are followed by built-ins and the constants 0 and 1.
    """
    # learned primitives
    targets = [
        Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
        Primitive("unfold", arrow(t0, arrow(t0, tbool), arrow(t0, t1), arrow(t0, t0), tlist(t1)), _unfold),
        Primitive("range", arrow(tint, tlist(tint)), _range),
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold),
        Primitive("length", arrow(tlist(t0), tint), len),
    ]
    # built-ins
    builtins = [
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction),
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
    ]
    return targets + builtins + [Primitive(str(k), tint, k) for k in range(2)]
def primitives():
    """Return the "rich" list-domain primitive set.

    Besides the constants 0-5 and basic list, boolean and integer operations,
    this set includes several shortcuts (sum, reverse, all, any, index,
    filter, slice).  Those are achievable by longer programs but unlikely to
    be found.  The derivations noted beside them use map / reduce / if, which
    are intentionally NOT part of this set, as are flatten and replace.
    """
    constants = [Primitive(str(k), tint, k) for k in range(6)]
    lists = [
        Primitive("empty", tlist(t0), []),
        Primitive("singleton", arrow(t0, tlist(t0)), _single),
        Primitive("range", arrow(tint, tlist(tint)), _range),
        Primitive("++", arrow(tlist(t0), tlist(t0), tlist(t0)), _append),
        # The function arguments of mapi / reducei take an extra leading
        # tint (presumably the element index - confirm against _mapi/_reducei).
        Primitive("mapi", arrow(arrow(tint, t0, t1), tlist(t0), tlist(t1)), _mapi),
        Primitive("reducei", arrow(arrow(tint, t1, t0, t1), t1, tlist(t0), t1), _reducei),
    ]
    logic = [
        Primitive("true", tbool, True),
        Primitive("not", arrow(tbool, tbool), _not),
        Primitive("and", arrow(tbool, tbool, tbool), _and),
        Primitive("or", arrow(tbool, tbool, tbool), _or),
    ]
    numbers = [
        Primitive("sort", arrow(tlist(tint), tlist(tint)), sorted),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("*", arrow(tint, tint, tint), _multiplication),
        Primitive("negate", arrow(tint, tint), _negate),
        Primitive("mod", arrow(tint, tint, tint), _mod),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("is-prime", arrow(tint, tbool), _isPrime),
        Primitive("is-square", arrow(tint, tbool), _isSquare),
    ]
    shortcuts = [
        # (lambda (lambda (reduce (lambda (lambda (+ $0 $1))) 0 $0)))
        Primitive("sum", arrow(tlist(tint), tint), sum),
        # (lambda (reduce (lambda (lambda (++ (singleton $0) $1))) empty $0))
        Primitive("reverse", arrow(tlist(t0), tlist(t0)), _reverse),
        # (lambda (lambda (reduce (lambda (lambda (and $0 $1))) true (map $1 $0))))
        Primitive("all", arrow(arrow(t0, tbool), tlist(t0), tbool), _all),
        # An `or` fold must start from false: (lambda (lambda (reduce
        #   (lambda (lambda (or $0 $1))) (not true) (map $1 $0))))
        Primitive("any", arrow(arrow(t0, tbool), tlist(t0), tbool), _any),
        # (lambda (lambda (reducei (lambda (lambda (lambda (if (eq? $1 $4) $0 0)))) 0 $0)))
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        # (lambda (lambda (reduce (lambda (lambda (++ $1 (if ($3 $0) (singleton $0) empty)))) empty $0)))
        Primitive("filter", arrow(arrow(t0, tbool), tlist(t0), tlist(t0)), _filter),
        # (lambda (lambda (lambda (reducei (lambda (lambda (lambda (++ $2 (if (and (or (gt? $1 $5) (eq? $1 $5))
        #   (not (or (gt? $4 $1) (eq? $1 $4)))) (singleton $0) empty))))) empty $0))))
        Primitive("slice", arrow(tint, tint, tlist(t0), tlist(t0)), _slice),
    ]
    return constants + lists + logic + numbers + shortcuts
Primitive("gt?", arrow(tint, tint, tbool), _gt), Primitive("is-prime", arrow(tint, tbool), _isPrime), Primitive("is-square", arrow(tint, tbool), _isSquare), # McCarthy Primitive("empty", tlist(t0), []), Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons), Primitive("car", arrow(tlist(t0), t0), _car), Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr), Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty), Primitive("if", arrow(tbool, t0, t0, t0), _if), Primitive("eq?", arrow(tint, tint, tbool), _eq), Primitive("+", arrow(tint, tint, tint), _addition), Primitive("-", arrow(tint, tint, tint), _subtraction) ] zip_primitive = Primitive("zip", arrow(tlist(t0), tlist(t1), arrow(t0, t1, t2), tlist(t2)), _zip) def bootstrapTarget(): """These are the primitives that we hope to learn from the bootstrapping procedure""" return [ # learned primitives Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map), Primitive("unfold", arrow(t0, arrow(t0,tbool), arrow(t0,t1), arrow(t0,t0), tlist(t1)), _unfold), Primitive("range", arrow(tint, tlist(tint)), _range), Primitive("index", arrow(tint, tlist(t0), t0), _index), Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold), Primitive("length", arrow(tlist(t0), tint), len), # built-ins Primitive("if", arrow(tbool, t0, t0, t0), _if), Primitive("+", arrow(tint, tint, tint), _addition),
def deepcoderPrimitives():
    """Return the DeepCoder DSL as primitives (upper-case names).

    Groups, in order:
    - first-order list functions;
    - higher-order list functions;
    - int -> int lambdas;
    - int -> bool lambdas;
    - int -> int -> int lambdas.

    The higher-order functions receive their function argument through the
    ``int_to_int`` / ``int_to_bool`` / ``int_to_int_to_int`` types.
    NOTE(review): the original author flagged these typings with
    "is this okay???" - confirm they behave as intended.
    """
    list_ops = [
        Primitive("HEAD", arrow(tlist(tint), tint), _head),
        Primitive("LAST", arrow(tlist(tint), tint), _tail),
        Primitive("TAKE", arrow(tint, tlist(tint), tlist(tint)), _take),
        Primitive("DROP", arrow(tint, tlist(tint), tlist(tint)), _drop),
        Primitive("ACCESS", arrow(tint, tlist(tint), tint), _access),
        Primitive("MINIMUM", arrow(tlist(tint), tint), _minimum),
        Primitive("MAXIMUM", arrow(tlist(tint), tint), _maximum),
        Primitive("REVERSE", arrow(tlist(tint), tlist(tint)), _reverse),
        Primitive("SORT", arrow(tlist(tint), tlist(tint)), _sort),
        Primitive("SUM", arrow(tlist(tint), tint), _sum),
    ]
    higher_order = [
        Primitive("MAP", arrow(int_to_int, tlist(tint), tlist(tint)), _map),
        Primitive("FILTER", arrow(int_to_bool, tlist(tint), tlist(tint)), _filter),
        Primitive("COUNT", arrow(int_to_bool, tlist(tint), tint), _count),
        Primitive("ZIPWITH", arrow(int_to_int_to_int, tlist(tint), tlist(tint), tlist(tint)), _zipwith),
        Primitive("SCANL1", arrow(int_to_int_to_int, tlist(tint), tlist(tint)), _scanl1),
    ]
    unary = [
        Primitive(label, int_to_int, fn)
        for label, fn in (
            ("INC", _succ),
            ("DEC", _pred),
            ("SHL", _double),
            ("SHR", _half),
            ("doNEG", _negate),
            ("SQR", _square),
            ("MUL3", _triple),
            ("DIV3", _third),
            ("MUL4", _quad),
            ("DIV4", _quarter),
        )
    ]
    predicates = [
        Primitive(label, int_to_bool, fn)
        for label, fn in (
            ("isPOS", _pos),
            ("isNEG", _neg),
            ("isEVEN", _even),
            ("isODD", _odd),
        )
    ]
    binary = [
        Primitive(label, int_to_int_to_int, fn)
        for label, fn in (
            ("+", _add),
            ("-", _sub),
            ("*", _mult),
            ("MIN", _min),
            ("MAX", _max),
        )
    ]
    return list_ops + higher_order + unary + predicates + binary
def napsPrimitives():
    """Return grammar primitives whose values build program-AST fragments.

    Groups, in order:
    - declarations (program / func / ctor / var);
    - STMT, EXPR and LHS wrappers (identity functions that only change the type);
    - concrete statement / expression forms;
    - TYPE strings;
    - list builders for STMT / EXPR / VAR;
    - small vocabularies of values, function names, field names and the
      variable names var0..var11.

    Several entries are marked TODO and are incomplete.
    """
    declarations = [
        Primitive("program", arrow(tlist(RECORD), tlist(FUNC), PROGRAM), _program),  # TODO
        # RECORD
        # NOTE(review): func/ctor end in tlist(STMT), not FUNC, so `program`
        # (which takes tlist(FUNC)) cannot consume them - confirm intent.
        Primitive("func", arrow(TYPE, name, tlist(VAR), tlist(VAR), tlist(VAR), tlist(STMT)), _func('func')),  # TODO
        Primitive("ctor", arrow(TYPE, name, tlist(VAR), tlist(VAR), tlist(VAR), tlist(STMT)), _func('ctor')),
        Primitive("var", arrow(TYPE, name, VAR), _var),
    ]
    # STMT ::= EXPR | IF | FOREACH | WHILE | BREAK | CONTINUE | RETURN | NOOP
    stmt_wrappers = [
        Primitive(label, arrow(source, STMT), lambda x: x)
        for label, source in (
            ("stmt_expr", EXPR),
            ("stmt_if", IF),
            ("stmt_foreach", FOREACH),
            ("stmt_while", WHILE),
            ("stmt_break", BREAK),
            ("stmt_continue", CONTINUE),
            ("stmt_return", RETURN),
            ("stmt_noop", NOOP),
        )
    ]
    # EXPR ::= ASSIGN | VAR | FIELD | CONSTANT | INVOKE | TERNARY | CAST
    expr_wrappers = [
        Primitive(label, arrow(source, EXPR), lambda x: x)
        for label, source in (
            ("expr_assign", ASSIGN),
            ("expr_var", VAR),
            ("expr_field", FIELD),
            ("expr_constant", CONSTANT),
            ("expr_invoke", INVOKE),
            ("expr_ternary", TERNARY),
            ("expr_cast", CAST),
        )
    ]
    assignment = [Primitive("assign", arrow(TYPE, LHS, EXPR, ASSIGN), _assign)]
    # LHS ::= VAR | FIELD | INVOKE
    lhs_wrappers = [
        Primitive(label, arrow(source, LHS), lambda x: x)
        for label, source in (
            ("lhs_var", VAR),
            ("lhs_field", FIELD),
            ("lhs_invoke", INVOKE),
        )
    ]
    forms = [
        Primitive("if", arrow(TYPE, EXPR, tlist(STMT), tlist(STMT), IF), _if),
        Primitive("foreach", arrow(TYPE, VAR, EXPR, tlist(STMT), FOREACH), _foreach),
        Primitive("while", arrow(TYPE, EXPR, tlist(STMT), tlist(STMT), WHILE), _while),
        Primitive("break", arrow(TYPE, BREAK), lambda tp: ['break', tp]),
        Primitive("continue", arrow(TYPE, CONTINUE), lambda tp: ['continue', tp]),
        Primitive("return", arrow(TYPE, EXPR, RETURN), _return),
        Primitive("noop", NOOP, ['noop']),
        Primitive("field", arrow(TYPE, EXPR, field_name, FIELD), _field),  # TODO
        Primitive("constant", arrow(TYPE, value, CONSTANT), _constant),
        Primitive("invoke", arrow(TYPE, function_name, tlist(EXPR), INVOKE), _invoke),  # TODO
        Primitive("ternary", arrow(TYPE, EXPR, EXPR, EXPR, TERNARY), _ternary),
        Primitive("cast", arrow(TYPE, EXPR, CAST), _cast),
    ]
    # TYPE values are plain strings; container types decorate their element type.
    types = [
        Primitive("bool", TYPE, 'bool'),
        Primitive("char", TYPE, 'char'),
        Primitive("char*", TYPE, 'char*'),
        Primitive("int", TYPE, 'int'),
        Primitive("real", TYPE, 'real'),
        Primitive("array", arrow(TYPE, TYPE), lambda tp: tp + '*'),
        Primitive("set", arrow(TYPE, TYPE), lambda tp: tp + '%'),
        Primitive("map", arrow(TYPE, TYPE, TYPE), lambda tp1: lambda tp2: '<' + tp1 + '|' + tp2 + '>'),
        Primitive("record_name", TYPE, 'record_name#'),  # TODO
    ]
    # List builders for STMTs, EXPRs and VARs (maybe FUNCs and RECORDs later).
    list_builders = [
        Primitive('list_init_stmt', arrow(STMT, tlist(STMT)), lambda stmt: [stmt]),
        Primitive('list_add_stmt', arrow(STMT, tlist(STMT), tlist(STMT)), lambda stmt: lambda stmts: stmts + [stmt]),
        Primitive('list_init_expr', arrow(EXPR, tlist(EXPR)), lambda expr: [expr]),
        Primitive('list_add_expr', arrow(EXPR, tlist(EXPR), tlist(EXPR)), lambda expr: lambda exprs: exprs + [expr]),
        Primitive('list_init_var', arrow(VAR, tlist(VAR)), lambda var: [var]),
        Primitive('list_add_var', arrow(VAR, tlist(VAR), tlist(VAR)), lambda var: lambda _vars: _vars + [var]),
    ]
    # NOTE(review): '0' carries the int 0 while "1" / "-1" carry strings;
    # confirm which representation downstream code expects.
    literals = [
        Primitive('0', value, 0),
        Primitive("1", value, "1"),
        Primitive("-1", value, "-1"),
    ]
    callees = [
        Primitive('+', function_name, '+'),
        Primitive('&&', function_name, "&&"),
        Primitive("!", function_name, "!"),
        Primitive("!=", function_name, "!="),
        Primitive("string_find", function_name, "string_find"),
    ]
    fields = [Primitive('', field_name, '')]
    variables = [Primitive(f'var{i}', name, f'var{i}') for i in range(12)]
    return (declarations + stmt_wrappers + expr_wrappers + assignment
            + lhs_wrappers + forms + types + list_builders + literals
            + callees + fields + variables)
def deepcoderProductions():
    """Return one ``(0.0, primitive)`` production per DeepCoder primitive.

    Every primitive gets the same weight, 0.0 (presumably a log-probability,
    as consumed by ``Grammar.fromProductions`` below).
    """
    productions = []
    for primitive in deepcoderPrimitives():
        productions.append((0.0, primitive))
    return productions


# Disabled helper that turns a program into a flat token list, replacing
# de Bruijn variables $k with input_i:
# def flatten_program(p):
#     string = p.show(False)
#     num_inputs = string.count('lambda')
#     string = string.replace('lambda', '')
#     string = string.replace('(', '')
#     string = string.replace(')', '')
#     # remove '_fn' (optional)
#     for i in range(num_inputs):
#         string = string.replace('$' + str(num_inputs - i - 1), 'input_' + str(i))
#     string = string.split(' ')
#     string = list(filter(lambda x: x != '', string))
#     return string

if __name__ == "__main__":
    # Smoke test: sample one program of type [int] -> int -> int and print it.
    # NOTE(review): the flatten_program definition above is commented out, so
    # it must be provided by an import elsewhere - confirm before running.
    # (Alternative: g = Grammar.uniform(deepcoderPrimitives()))
    g = Grammar.fromProductions(deepcoderProductions(), logVariable=.9)
    request = arrow(tlist(tint), tint, tint)
    p = g.sample(request)
    print("request:", request)
    print("program:")
    print(prettyProgram(p))
    print("flattened_program:")
    flat = flatten_program(p)
    print(flat)
p = program.visit(RandomParameterization.single) return super(LearnedFeatureExtractor, self).featuresOfProgram(p, tp) if __name__ == "__main__": pi = 3.14 # I think this is close enough to pi # Data taken from: # https://secure-media.collegeboard.org/digitalServices/pdf/ap/ap-physics-1-equations-table.pdf # https://secure-media.collegeboard.org/digitalServices/pdf/ap/physics-c-tables-and-equations-list.pdf # http://mcat.prep101.com/wp-content/uploads/ES_MCATPhysics.pdf # some linear algebra taken from "parallel distributed processing" tasks = [ # parallel distributed processing makeTask("vector addition (2)", arrow(tvector, tvector, tvector), vectorAddition), makeTask("vector addition (many)", arrow(tlist(tvector), tvector), lambda vs: reduce(vectorAddition, vs)), makeTask("vector norm", arrow(tvector, treal), lambda v: innerProduct(v, v)**0.5), # mcat makeTask("freefall velocity = (2gh)**.5", arrow(tpositive, treal), lambda h: (2 * 9.8 * h)**0.5), makeTask("v^2 = v_0^2 + 2a(x-x0)", arrow(treal, treal, treal, treal, treal), lambda v0, a, x, x0: v0**2 + 2 * a * (x - x0)), makeTask("v = (vx**2 + vy**2)**0.5", arrow(treal, treal, treal), lambda vx, vy: (vx**2 + vy**2)**0.5), makeTask("a_r = v**2/R", arrow(treal, tpositive, treal), lambda v, r: v * v / r), makeTask("e = mc^2", arrow(tpositive, tpositive, treal), lambda m, c: m * c * c),
def OldDeepcoderPrimitives():
    """Return the older, lower-case variant of the DeepCoder DSL.

    Lambda arguments use the ``int_to_int`` / ``int_to_bool`` /
    ``int_to_int_to_int`` types, and their primitives carry an ``_fn``
    suffix.  Earlier arrow-typed versions (succ, pred, ..., max typed as
    ``tint -> tint`` etc.) are no longer included.
    NOTE(review): the original author flagged the higher-order typings with
    "is this okay???" - confirm.
    """
    list_ops = [
        Primitive("head", arrow(tlist(tint), tint), _head),
        Primitive("tail", arrow(tlist(tint), tint), _tail),
        Primitive("take", arrow(tint, tlist(tint), tlist(tint)), _take),
        Primitive("drop", arrow(tint, tlist(tint), tlist(tint)), _drop),
        Primitive("access", arrow(tint, tlist(tint), tint), _access),
        Primitive("minimum", arrow(tlist(tint), tint), _minimum),
        Primitive("maximum", arrow(tlist(tint), tint), _maximum),
        Primitive("reverse", arrow(tlist(tint), tlist(tint)), _reverse),
        Primitive("sort", arrow(tlist(tint), tlist(tint)), _sort),
        Primitive("sum", arrow(tlist(tint), tint), _sum),
    ]
    higher_order = [
        Primitive("map", arrow(int_to_int, tlist(tint), tlist(tint)), _map),
        Primitive("filter_int", arrow(int_to_bool, tlist(tint), tlist(tint)), _filter),
        Primitive("count", arrow(int_to_bool, tlist(tint), tint), _count),
        Primitive("zipwith", arrow(int_to_int_to_int, tlist(tint), tlist(tint), tlist(tint)), _zipwith),
        Primitive("scanl1", arrow(int_to_int_to_int, tlist(tint), tlist(tint)), _scanl1),
    ]
    unary = [
        Primitive(label, int_to_int, fn)
        for label, fn in (
            ("succ_fn", _succ),
            ("pred_fn", _pred),
            ("double_fn", _double),
            ("half_fn", _half),
            ("negate_fn", _negate),
            ("square_fn", _square),
            ("triple_fn", _triple),
            ("third_fn", _third),
            ("quad_fn", _quad),
            ("quarter_fn", _quarter),
        )
    ]
    predicates = [
        Primitive(label, int_to_bool, fn)
        for label, fn in (
            ("pos_fn", _pos),
            ("neg_fn", _neg),
            ("even_fn", _even),
            ("odd_fn", _odd),
        )
    ]
    binary = [
        Primitive(label, int_to_int_to_int, fn)
        for label, fn in (
            ("add_fn", _add),
            ("sub_fn", _sub),
            ("mult_fn", _mult),
            ("min_fn", _min),
            ("max_fn", _max),
        )
    ]
    return list_ops + higher_order + unary + predicates + binary
t.features = drawFunction(200, 10., t.f) delattr(t, 'f') test, _ = testTrainSplit(tasks, 100) random.seed(42) random.shuffle(test) test = test[:100] g = Grammar.uniform( [real, real_division, real_addition, real_multiplication]) fe = FeatureExtractor([]) BATCHSIZE = 64 elif arguments.domain == "list": BATCHSIZE = 16 tasks = retrieveJSONTasks("data/list_tasks.json") + sortBootstrap() tasks.extend([ Task("remove empty lists", arrow(tlist(tlist(tbool)), tlist(tlist(tbool))), [((ls, ), list(filter(lambda l: len(l) > 0, ls))) for _ in range(15) for ls in [[[ random.random() < 0.5 for _ in range(random.randint(0, 3)) ] for _ in range(4)]]]), Task("keep squares", arrow(tlist(tint), tlist(tint)), [ ((xs, ), list(filter(lambda x: int(math.sqrt(x))**2 == x, xs))) for _ in range(15) for xs in [[ random.choice([0, 1, 4, 9, 16, 25]) if random.random() < 0.5 else random.randint(0, 9) for _ in range(7) ]] ]), Task("keep primes", arrow(tlist(tint), tlist(tint)), [ ((xs, ),
def _lucasFilterTasks():
    """Build the hand-written filtering tasks appended to the Lucas datasets.

    Example inputs are drawn from the module-level ``random`` generator, so
    the result depends on, and advances, its state.
    """
    tasks = [
        Task("remove empty lists",
             arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
             [((ls, ), list(filter(lambda l: len(l) > 0, ls)))
              for _ in range(15)
              for ls in [[[random.random() < 0.5 for _ in range(random.randint(0, 3))]
                          for _ in range(4)]]]),
        Task("keep squares",
             arrow(tlist(tint), tlist(tint)),
             [((xs, ), list(filter(lambda x: int(math.sqrt(x))**2 == x, xs)))
              for _ in range(15)
              for xs in [[random.choice([0, 1, 4, 9, 16, 25])
                          if random.random() < 0.5
                          else random.randint(0, 9)
                          for _ in range(7)]]]),
        Task("keep primes",
             arrow(tlist(tint), tlist(tint)),
             [((xs, ), list(filter(
                 lambda x: x in {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37},
                 xs)))
              for _ in range(15)
              for xs in [[random.choice([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37])
                          if random.random() < 0.5
                          else random.randint(0, 9)
                          for _ in range(7)]]]),
    ]
    for i in range(4):
        # The filter lambdas run immediately inside list(), so capturing the
        # loop variable `i` is safe here.
        tasks.extend([
            Task("keep eq %s" % i,
                 arrow(tlist(tint), tlist(tint)),
                 [((xs, ), list(filter(lambda x: x == i, xs)))
                  for _ in range(15)
                  for xs in [[random.randint(0, 6) for _ in range(5)]]]),
            Task("remove eq %s" % i,
                 arrow(tlist(tint), tlist(tint)),
                 [((xs, ), list(filter(lambda x: x != i, xs)))
                  for _ in range(15)
                  for xs in [[random.randint(0, 6) for _ in range(5)]]]),
            Task("keep gt %s" % i,
                 arrow(tlist(tint), tlist(tint)),
                 [((xs, ), list(filter(lambda x: x > i, xs)))
                  for _ in range(15)
                  for xs in [[random.randint(0, 6) for _ in range(5)]]]),
            Task("remove gt %s" % i,
                 arrow(tlist(tint), tlist(tint)),
                 [((xs, ), list(filter(lambda x: not x > i, xs)))
                  for _ in range(15)
                  for xs in [[random.randint(0, 6) for _ in range(5)]]]),
        ])
    return tasks


def main(args):
    """
    Takes the return value of the `commandlineArguments()` function as input and
    trains/tests the model on manipulating sequences of numbers.

    Pops these keys from ``args``: random_seed, dataset, maxTasks,
    primitives, noLength, noMap, noUnfold, extractor, hidden, split.
    Everything left in ``args`` is forwarded to ``explorationCompression``,
    together with featureExtractor, outputPrefix and evaluationTimeout.
    """
    random.seed(args.pop("random_seed"))

    dataset = args.pop("dataset")
    tasks = {
        "Lucas-old": lambda: retrieveJSONTasks("data/list_tasks.json") + sortBootstrap(),
        "bootstrap": make_list_bootstrap_tasks,
        "sorting": sortBootstrap,
        "Lucas-depth1": lambda: retrieveJSONTasks("data/list_tasks2.json")[:105],
        "Lucas-depth2": lambda: retrieveJSONTasks("data/list_tasks2.json")[:4928],
        "Lucas-depth3": lambda: retrieveJSONTasks("data/list_tasks2.json"),
    }[dataset]()

    maxTasks = args.pop("maxTasks")
    if maxTasks and len(tasks) > maxTasks:
        eprint("Unwilling to handle {} tasks, truncating..".format(len(tasks)))
        # The first 105 tasks of list_tasks2.json are exactly "Lucas-depth1".
        # For the deeper datasets they are always kept, and maxTasks does not
        # count them.  They are removed from the pool before sampling so they
        # cannot appear twice.
        necessaryTasks = []
        if dataset.startswith("Lucas-depth") and dataset != "Lucas-depth1":
            necessaryTasks, tasks = tasks[:105], tasks[105:]
        random.shuffle(tasks)
        del tasks[maxTasks:]
        tasks = necessaryTasks + tasks

    if dataset.startswith("Lucas"):
        # extra tasks for filter
        tasks.extend(_lucasFilterTasks())

    def isIdentityTask(t):
        return all(len(xs) == 1 and xs[0] == y for xs, y in t.examples)

    eprint("Removed", sum(isIdentityTask(t) for t in tasks),
           "tasks that were just the identity function")
    tasks = [t for t in tasks if not isIdentityTask(t)]

    prims = {
        "base": basePrimitives,
        "McCarthy": McCarthyPrimitives,
        "common": bootstrapTarget_extra,
        "noLength": no_length,
        "rich": primitives,
    }[args.pop("primitives")]()
    haveLength = not args.pop("noLength")
    haveMap = not args.pop("noMap")
    haveUnfold = not args.pop("noUnfold")
    eprint(f"Including map as a primitive? {haveMap}")
    eprint(f"Including length as a primitive? {haveLength}")
    eprint(f"Including unfold as a primitive? {haveUnfold}")
    excluded = {
        primitiveName
        for primitiveName, keep in (("map", haveMap),
                                    ("unfold", haveUnfold),
                                    ("length", haveLength))
        if not keep
    }
    baseGrammar = Grammar.uniform([p for p in prims if p.name not in excluded])

    extractor = {
        "learned": LearnedFeatureExtractor,
    }[args.pop("extractor")]
    extractor.H = args.pop("hidden")

    timestamp = datetime.datetime.now().isoformat()
    outputDirectory = "experimentOutputs/list/%s" % timestamp
    # Create the directory directly instead of shelling out to `mkdir -p`.
    # This is portable, avoids the shell, and raises if creation fails.
    os.makedirs(outputDirectory, exist_ok=True)

    args.update({
        "featureExtractor": extractor,
        "outputPrefix": "%s/list" % outputDirectory,
        "evaluationTimeout": 0.0005,
    })

    eprint("Got {} list tasks".format(len(tasks)))
    split = args.pop("split")
    if split:
        # Force selected tasks into the training set.  For each "some" group
        # (keyed by the first word of the task name), one random member is
        # forced in.
        train_some = defaultdict(list)
        for t in tasks:
            necessary = train_necessary(t)
            if not necessary:
                continue
            if necessary == "some":
                train_some[t.name.split()[0]].append(t)
            else:
                t.mustTrain = True
        for k in sorted(train_some):
            ts = train_some[k]
            random.shuffle(ts)
            ts.pop().mustTrain = True

        test, train = testTrainSplit(tasks, split)
        test = [t for t in test if t.name not in EASYLISTTASKS]

        eprint("Alotted {} tasks for training and {} for testing".format(
            len(train), len(test)))
    else:
        train = tasks
        test = []

    explorationCompression(baseGrammar, train, testingTasks=test, **args)
Primitive("is-square", arrow(tint, tbool), _isSquare), # McCarthy Primitive("empty", tlist(t0), []), Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons), Primitive("car", arrow(tlist(t0), t0), _car), Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr), Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty), Primitive("if", arrow(tbool, t0, t0, t0), _if), Primitive("eq?", arrow(tint, tint, tbool), _eq), Primitive("+", arrow(tint, tint, tint), _addition), Primitive("-", arrow(tint, tint, tint), _subtraction) ] zip_primitive = Primitive( "zip", arrow(tlist(t0), tlist(t1), arrow(t0, t1, t2), tlist(t2)), _zip) def bootstrapTarget(): """These are the primitives that we hope to learn from the bootstrapping procedure""" return [ # learned primitives Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map), Primitive( "unfold", arrow(t0, arrow(t0, tbool), arrow(t0, t1), arrow(t0, t0), tlist(t1)), _unfold), Primitive("range", arrow(tint, tlist(tint)), _range), Primitive("index", arrow(tint, tlist(t0), t0), _index), Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold), Primitive("length", arrow(tlist(t0), tint), len),