Example #1
0
def retrieveJSONTasks(filename, features=False):
    """
    Load single-argument tasks from a JSON file.

    The file must hold a list of objects of the form:
        {"name": str,
         "type": {"input" : bool|int|list-of-bool|list-of-int,
                  "output": bool|int|list-of-bool|list-of-int},
         "examples": [{"i": data, "o": data}]}

    :param filename: path of the JSON file to read (decoded as UTF-8).
    :param features: when truthy, attach ``list_features`` of the examples to
        each task; otherwise the task's features are ``None``.
    :returns: a list of ``Task`` objects in file order, each typed
        ``arrow(input, output)`` and created with ``cache=False``.
    :raises KeyError: if a type name is not one of the four above, or a
        required field is missing.
    """
    with open(filename, "r", encoding="utf-8") as f:
        loaded = json.load(f)
    TP = {
        "bool": tbool,
        "int": tint,
        "list-of-bool": tlist(tbool),
        "list-of-int": tlist(tint),
    }

    def _toTask(item):
        # Each input is wrapped in a 1-tuple: the task takes exactly one argument.
        # The example list is built once and reused for the features.
        examples = [((ex["i"], ), ex["o"]) for ex in item["examples"]]
        return Task(
            item["name"],
            arrow(TP[item["type"]["input"]], TP[item["type"]["output"]]),
            examples,
            features=list_features(examples) if features else None,
            cache=False,
        )

    return [_toTask(item) for item in loaded]
Example #2
0
 def genericType(t):
     """Rewrite a type in terms of the generic real-valued types.

     ``real`` and ``positive`` both become ``treal``, ``vector`` becomes
     ``tlist(treal)``, and ``list`` and arrow types are rewritten recursively
     on their arguments.

     :raises AssertionError: if ``t`` is none of the handled types.
     """
     if t.name in ("real", "positive"):
         return treal
     if t.name == "vector":
         return tlist(treal)
     if t.name == "list":
         return tlist(genericType(t.arguments[0]))
     if t.isArrow():
         return arrow(genericType(t.arguments[0]),
                      genericType(t.arguments[1]))
     # Raise explicitly: a bare `assert False` is removed under `python -O`,
     # which would make this function silently return None.
     raise AssertionError("could not make type generic: %s" % t)
Example #3
0
def basePrimitives():
    """Return the base primitive set.

    The integer constants 0..5 come first. Integer multiplication and
    predicates follow, then the McCarthy-style list operations together with
    `if`, `eq?`, `+` and `-`.
    """
    constants = [Primitive(str(k), tint, k) for k in range(6)]
    integerOps = [
        Primitive("*", arrow(tint, tint, tint), _multiplication),
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("is-prime", arrow(tint, tbool), _isPrime),
        Primitive("is-square", arrow(tint, tbool), _isSquare),
    ]
    # McCarthy
    mcCarthyOps = [
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction),
    ]
    return constants + integerOps + mcCarthyOps
Example #4
0
def McCarthyPrimitives():
    """These are the primitives provided by 1959 Lisp as introduced by McCarthy.

    They are extended with ``primitiveRecursion1``, integer comparison and
    arithmetic, `if`, and the integer constants 0 and 1.
    """
    listOps = [
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
    ]
    # The following candidates were previously listed here but commented out:
    # unfold, 1+, range, map, index, length and primitiveRecursion2.
    recursionAndArithmetic = [
        primitiveRecursion1,
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction),
    ]
    constants = [Primitive(str(k), tint, k) for k in range(2)]
    return listOps + recursionAndArithmetic + constants
Example #5
0
def algolispPrimitives():
    """Return the AlgoLisp primitives.

    The fixed converters come first: function-call builders, symbol
    converters and list converters. They are followed by one ``tfunction``
    primitive per ``fn_lookup`` entry, then one ``tconstant`` primitive per
    ``const_lookup`` entry. Both lookups are read as ``algo_name -> ec_name``;
    the primitive is named ``ec_name`` and its value is ``algo_name``.
    """

    def lambdaCall(tag):
        # Curried builder f -> sx -> [tag, [f, *sx]]; a sx that is not exactly
        # a list is treated as a single symbol.
        def withFunction(f):
            def withArguments(sx):
                if type(sx) == list:
                    return [tag, [f] + sx]
                return [tag, [f] + [sx]]
            return withArguments
        return withFunction

    def listInitSymbol(symbol):
        return [symbol]

    def listAddSymbol(symbol):
        # Curried: symbol -> symbols -> symbols with `symbol` appended; a
        # `symbols` that is not exactly a list is wrapped in a list first.
        def appendTo(symbols):
            if type(symbols) == list:
                return symbols + [symbol]
            return [symbols] + [symbol]
        return appendTo

    converters = [
        Primitive("fn_call", arrow(tfunction, tlist(tsymbol), tsymbol), _fn_call),
        Primitive("lambda1_call", arrow(tfunction, tlist(tsymbol), tsymbol), lambdaCall("lambda1")),
        Primitive("lambda2_call", arrow(tfunction, tlist(tsymbol), tsymbol), lambdaCall("lambda2")),
        # symbol converters:
        # SYMBOL = constant | argument | function_call | function | lambda
        Primitive("symbol_constant", arrow(tconstant, tsymbol), lambda x: x),
        Primitive("symbol_function", arrow(tfunction, tsymbol), lambda x: x),
        # list converters
        Primitive('list_init_symbol', arrow(tsymbol, tlist(tsymbol)), listInitSymbol),
        Primitive('list_add_symbol', arrow(tsymbol, tlist(tsymbol), tlist(tsymbol)), listAddSymbol),
    ]
    functions = [
        Primitive(ec_name, tfunction, algo_name)
        for algo_name, ec_name in fn_lookup.items()
    ]
    constants = [
        Primitive(ec_name, tconstant, algo_name)
        for algo_name, ec_name in const_lookup.items()
    ]
    return converters + functions + constants
Example #6
0
def isListFunction(tp):
    """Return True iff ``tp`` unifies with ``tlist(tint) -> t0``.

    In other words, ``tp`` can be a function from a list of ints to any type.
    Unification is done in a fresh ``Context``.
    """
    try:
        listOfIntToAnything = arrow(tlist(tint), t0)
        Context().unify(tp, listOfIntToAnything)
    except UnificationFailure:
        return False
    else:
        return True
Example #7
0
def bootstrapTarget():
    """These are the primitives that we hope to learn from the bootstrapping procedure.

    They are followed by built-in list and arithmetic primitives and the
    integer constants 0 and 1.
    """
    learned = [
        Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
        Primitive("unfold",
                  arrow(t0, arrow(t0, tbool), arrow(t0, t1), arrow(t0, t0), tlist(t1)),
                  _unfold),
        Primitive("range", arrow(tint, tlist(tint)), _range),
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold),
        Primitive("length", arrow(tlist(t0), tint), len),
    ]
    builtins = [
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction),
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
    ]
    constants = [Primitive(str(k), tint, k) for k in range(2)]
    return learned + builtins + constants
Example #8
0
def primitives():
    """Return the "rich" list primitive set.

    The groups appear in this order:
      1. the integer constants 0..5;
      2. list constructors and indexed higher-order functions;
      3. boolean logic;
      4. integer arithmetic and predicates;
      5. list operations that could be derived from the above but are
         unlikely to be found by search. Each has a comment giving one such
         derivation.
    """
    constants = [Primitive(str(k), tint, k) for k in range(6)]

    listBuilders = [
        Primitive("empty", tlist(t0), []),
        Primitive("singleton", arrow(t0, tlist(t0)), _single),
        Primitive("range", arrow(tint, tlist(tint)), _range),
        Primitive("++", arrow(tlist(t0), tlist(t0), tlist(t0)), _append),
        # A non-indexed "map" is disabled. The function given to "mapi" takes
        # an extra leading int (presumably the element index).
        Primitive("mapi",
                  arrow(arrow(tint, t0, t1), tlist(t0), tlist(t1)),
                  _mapi),
        # A non-indexed "reduce" is disabled. "reducei" likewise passes an
        # extra leading int.
        Primitive("reducei",
                  arrow(arrow(tint, t1, t0, t1), t1, tlist(t0), t1),
                  _reducei),
    ]

    # "if" is disabled in this set.
    logic = [
        Primitive("true", tbool, True),
        Primitive("not", arrow(tbool, tbool), _not),
        Primitive("and", arrow(tbool, tbool, tbool), _and),
        Primitive("or", arrow(tbool, tbool, tbool), _or),
    ]

    arithmetic = [
        Primitive("sort", arrow(tlist(tint), tlist(tint)), sorted),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("*", arrow(tint, tint, tint), _multiplication),
        Primitive("negate", arrow(tint, tint), _negate),
        Primitive("mod", arrow(tint, tint, tint), _mod),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("is-prime", arrow(tint, tbool), _isPrime),
        Primitive("is-square", arrow(tint, tbool), _isSquare),
    ]

    # These are achievable with the above primitives, but unlikely to be found.
    # Disabled: "flatten", arrow(tlist(tlist(t0)), tlist(t0)), _flatten
    #   derivable as (lambda (reduce (lambda (lambda (++ $1 $0))) empty $0))
    derivable = [
        # derivable as (lambda (reduce (lambda (lambda (+ $0 $1))) 0 $0))
        Primitive("sum", arrow(tlist(tint), tint), sum),
        # derivable as (lambda (reduce (lambda (lambda (++ (singleton $0) $1))) empty $0))
        Primitive("reverse", arrow(tlist(t0), tlist(t0)), _reverse),
        # derivable as (lambda (lambda (reduce (lambda (lambda (and $0 $1))) true (map $1 $0))))
        Primitive("all", arrow(arrow(t0, tbool), tlist(t0), tbool), _all),
        # derivable as (lambda (lambda (reduce (lambda (lambda (or $0 $1))) (not true) (map $1 $0))))
        # (an or-reduction must start from false; starting from true is always true)
        Primitive("any", arrow(arrow(t0, tbool), tlist(t0), tbool), _any),
        # derivable as (lambda (lambda (reducei (lambda (lambda (lambda (if (eq? $1 $4) $0 0)))) 0 $0)))
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        # derivable as (lambda (lambda (reduce (lambda (lambda (++ $1 (if ($3 $0) (singleton $0) empty)))) empty $0)))
        Primitive("filter", arrow(arrow(t0, tbool), tlist(t0), tlist(t0)), _filter),
        # Disabled: "replace", arrow(arrow(tint, t0, tbool), tlist(t0), tlist(t0), tlist(t0)), _replace
        #   derivable as (FLATTEN (lambda (lambda (lambda (mapi (lambda (lambda (if ($4 $1 $0) $3 (singleton $1)))) $0)))))
        # derivable as (lambda (lambda (lambda (reducei (lambda (lambda (lambda (++ $2 (if (and (or (gt? $1 $5) (eq? $1 $5)) (not (or (gt? $4 $1) (eq? $1 $4)))) (singleton $0) empty))))) empty $0))))
        Primitive("slice", arrow(tint, tint, tlist(t0), tlist(t0)), _slice),
    ]

    return constants + listBuilders + logic + arithmetic + derivable
Example #9
0
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("is-prime", arrow(tint, tbool), _isPrime),
        Primitive("is-square", arrow(tint, tbool), _isSquare),
        # McCarthy
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction)
    ]

# "zip": two lists and a binary function t0 -> t1 -> t2, giving a list of t2.
zip_primitive = Primitive(
    "zip",
    arrow(tlist(t0), tlist(t1), arrow(t0, t1, t2), tlist(t2)),
    _zip,
)

def bootstrapTarget():
    """These are the primitives that we hope to learn from the bootstrapping procedure"""
    return [
        # learned primitives
        Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
        Primitive("unfold", arrow(t0, arrow(t0,tbool), arrow(t0,t1), arrow(t0,t0), tlist(t1)), _unfold),
        Primitive("range", arrow(tint, tlist(tint)), _range),
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold),
        Primitive("length", arrow(tlist(t0), tint), len),

        # built-ins
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("+", arrow(tint, tint, tint), _addition),
Example #10
0
def deepcoderPrimitives():
    """Return the DeepCoder DSL primitives (upper-case names).

    The groups appear in this order:
      1. first-order list functions;
      2. higher-order list functions;
      3. the int->int lambdas;
      4. the int->bool lambdas;
      5. the int->int->int lambdas.
    """
    firstOrder = [
        Primitive("HEAD", arrow(tlist(tint), tint), _head),
        Primitive("LAST", arrow(tlist(tint), tint), _tail),
        Primitive("TAKE", arrow(tint, tlist(tint), tlist(tint)), _take),
        Primitive("DROP", arrow(tint, tlist(tint), tlist(tint)), _drop),
        Primitive("ACCESS", arrow(tint, tlist(tint), tint), _access),
        Primitive("MINIMUM", arrow(tlist(tint), tint), _minimum),
        Primitive("MAXIMUM", arrow(tlist(tint), tint), _maximum),
        Primitive("REVERSE", arrow(tlist(tint), tlist(tint)), _reverse),
        Primitive("SORT", arrow(tlist(tint), tlist(tint)), _sort),
        Primitive("SUM", arrow(tlist(tint), tint), _sum),
    ]
    # The original author flagged each of these with "is this okay???".
    higherOrder = [
        Primitive("MAP", arrow(int_to_int, tlist(tint), tlist(tint)), _map),
        Primitive("FILTER", arrow(int_to_bool, tlist(tint), tlist(tint)), _filter),
        Primitive("COUNT", arrow(int_to_bool, tlist(tint), tint), _count),
        Primitive("ZIPWITH",
                  arrow(int_to_int_to_int, tlist(tint), tlist(tint), tlist(tint)),
                  _zipwith),
        Primitive("SCANL1",
                  arrow(int_to_int_to_int, tlist(tint), tlist(tint)),
                  _scanl1),
    ]
    intToInt = [
        Primitive("INC", int_to_int, _succ),
        Primitive("DEC", int_to_int, _pred),
        Primitive("SHL", int_to_int, _double),
        Primitive("SHR", int_to_int, _half),
        Primitive("doNEG", int_to_int, _negate),
        Primitive("SQR", int_to_int, _square),
        Primitive("MUL3", int_to_int, _triple),
        Primitive("DIV3", int_to_int, _third),
        Primitive("MUL4", int_to_int, _quad),
        Primitive("DIV4", int_to_int, _quarter),
    ]
    intToBool = [
        Primitive("isPOS", int_to_bool, _pos),
        Primitive("isNEG", int_to_bool, _neg),
        Primitive("isEVEN", int_to_bool, _even),
        Primitive("isODD", int_to_bool, _odd),
    ]
    intToIntToInt = [
        Primitive("+", int_to_int_to_int, _add),
        Primitive("-", int_to_int_to_int, _sub),
        Primitive("*", int_to_int_to_int, _mult),
        Primitive("MIN", int_to_int_to_int, _min),
        Primitive("MAX", int_to_int_to_int, _max),
    ]
    return firstOrder + higherOrder + intToInt + intToBool + intToIntToInt
Example #11
0
def napsPrimitives():
    """Return the NAPS grammar primitives.

    Groups follow the grammar rules noted in the comments (e.g.
    ``STMT ::= EXPR | IF | ...``). The single-argument ``stmt_*``, ``expr_*``
    and ``lhs_*`` converters are identities that let a more specific node be
    used where the more general nonterminal is expected.
    """
    def identity(x):
        return x

    records = [
        Primitive("program", arrow(tlist(RECORD), tlist(FUNC), PROGRAM),
                  _program),  # TODO
        # RECORD
        # NOTE(review): func/ctor are typed to return tlist(STMT), while
        # "program" consumes tlist(FUNC) — confirm the intended result type.
        Primitive("func",
                  arrow(TYPE, name, tlist(VAR), tlist(VAR), tlist(VAR), tlist(STMT)),
                  _func('func')),  # TODO
        Primitive("ctor",
                  arrow(TYPE, name, tlist(VAR), tlist(VAR), tlist(VAR), tlist(STMT)),
                  _func('ctor')),
        Primitive("var", arrow(TYPE, name, VAR), _var),
    ]
    # STMT ::= EXPR | IF | FOREACH | WHILE | BREAK | CONTINUE | RETURN | NOOP
    statements = [
        Primitive("stmt_expr", arrow(EXPR, STMT), lambda x: x),
        Primitive("stmt_if", arrow(IF, STMT), lambda x: x),
        Primitive("stmt_foreach", arrow(FOREACH, STMT), lambda x: x),
        Primitive("stmt_while", arrow(WHILE, STMT), lambda x: x),
        Primitive("stmt_break", arrow(BREAK, STMT), lambda x: x),
        Primitive("stmt_continue", arrow(CONTINUE, STMT), lambda x: x),
        Primitive("stmt_return", arrow(RETURN, STMT), lambda x: x),
        Primitive("stmt_noop", arrow(NOOP, STMT), lambda x: x),
    ]
    # EXPR ::= ASSIGN | VAR | FIELD | CONSTANT | INVOKE | TERNARY | CAST
    expressions = [
        Primitive("expr_assign", arrow(ASSIGN, EXPR), lambda x: x),
        Primitive("expr_var", arrow(VAR, EXPR), lambda x: x),
        Primitive("expr_field", arrow(FIELD, EXPR), lambda x: x),
        Primitive("expr_constant", arrow(CONSTANT, EXPR), lambda x: x),
        Primitive("expr_invoke", arrow(INVOKE, EXPR), lambda x: x),
        Primitive("expr_ternary", arrow(TERNARY, EXPR), lambda x: x),
        Primitive("expr_cast", arrow(CAST, EXPR), lambda x: x),
    ]
    assignments = [Primitive("assign", arrow(TYPE, LHS, EXPR, ASSIGN), _assign)]
    # LHS ::= VAR | FIELD | INVOKE
    leftHandSides = [
        Primitive("lhs_var", arrow(VAR, LHS), lambda x: x),
        Primitive("lhs_field", arrow(FIELD, LHS), lambda x: x),
        Primitive("lhs_invoke", arrow(INVOKE, LHS), lambda x: x),
    ]
    nodes = [
        Primitive("if", arrow(TYPE, EXPR, tlist(STMT), tlist(STMT), IF), _if),
        Primitive("foreach", arrow(TYPE, VAR, EXPR, tlist(STMT), FOREACH), _foreach),
        Primitive("while", arrow(TYPE, EXPR, tlist(STMT), tlist(STMT), WHILE), _while),
        Primitive("break", arrow(TYPE, BREAK), lambda tp: ['break', tp]),
        Primitive("continue", arrow(TYPE, CONTINUE), lambda tp: ['continue', tp]),
        Primitive("return", arrow(TYPE, EXPR, RETURN), _return),
        Primitive("noop", NOOP, ['noop']),
        Primitive("field", arrow(TYPE, EXPR, field_name, FIELD), _field),  # TODO
        Primitive("constant", arrow(TYPE, value, CONSTANT), _constant),
        Primitive("invoke", arrow(TYPE, function_name, tlist(EXPR), INVOKE),
                  _invoke),  # TODO
        Primitive("ternary", arrow(TYPE, EXPR, EXPR, EXPR, TERNARY), _ternary),
        Primitive("cast", arrow(TYPE, EXPR, CAST), _cast),
    ]
    # TYPE values are strings; array/set/map build composite type strings.
    types = [
        Primitive("bool", TYPE, 'bool'),
        Primitive("char", TYPE, 'char'),
        Primitive("char*", TYPE, 'char*'),
        Primitive("int", TYPE, 'int'),
        Primitive("real", TYPE, 'real'),
        Primitive("array", arrow(TYPE, TYPE), lambda tp: tp + '*'),
        Primitive("set", arrow(TYPE, TYPE), lambda tp: tp + '%'),
        Primitive("map", arrow(TYPE, TYPE, TYPE),
                  lambda tp1: lambda tp2: '<' + tp1 + '|' + tp2 + '>'),
        Primitive("record_name", TYPE, 'record_name#'),  # TODO
    ]
    # List builders for STMTs, EXPRs and VARs (maybe FUNCs and records later).
    lists = [
        Primitive('list_init_stmt', arrow(STMT, tlist(STMT)),
                  lambda stmt: [stmt]),
        Primitive('list_add_stmt', arrow(STMT, tlist(STMT), tlist(STMT)),
                  lambda stmt: lambda stmts: stmts + [stmt]),
        Primitive('list_init_expr', arrow(EXPR, tlist(EXPR)),
                  lambda expr: [expr]),
        Primitive('list_add_expr', arrow(EXPR, tlist(EXPR), tlist(EXPR)),
                  lambda expr: lambda exprs: exprs + [expr]),
        Primitive('list_init_var', arrow(VAR, tlist(VAR)), lambda var: [var]),
        Primitive('list_add_var', arrow(VAR, tlist(VAR), tlist(VAR)),
                  lambda var: lambda _vars: _vars + [var]),
    ]
    # value (more to be added)
    # NOTE(review): '0' carries the int 0 while "1" and "-1" carry strings —
    # confirm which representation _constant expects.
    values = [
        Primitive('0', value, 0),
        Primitive("1", value, "1"),
        Primitive("-1", value, "-1"),
    ]
    # function_name (more to be added)
    functionNames = [
        Primitive('+', function_name, '+'),
        Primitive('&&', function_name, "&&"),
        Primitive("!", function_name, "!"),
        Primitive("!=", function_name, "!="),
        Primitive("string_find", function_name, "string_find"),
    ]
    # field_name (more to be added)
    fieldNames = [Primitive('', field_name, '')]
    # name: var0 .. var11
    names = [Primitive(f'var{i}', name, f'var{i}') for i in range(12)]

    return (records + statements + expressions + assignments + leftHandSides
            + nodes + types + lists + values + functionNames + fieldNames
            + names)
Example #12
0

def deepcoderProductions():
    """Pair every DeepCoder primitive with the weight 0.0.

    The weight is presumably a log-probability, which makes the grammar
    uniform. The result is in the ``(weight, primitive)`` form passed to
    ``Grammar.fromProductions`` below.
    """
    weight = 0.0
    productions = []
    for primitive in deepcoderPrimitives():
        productions.append((weight, primitive))
    return productions


# def flatten_program(p):
#     string = p.show(False)
#     num_inputs = string.count('lambda')
#     string = string.replace('lambda', '')
#     string = string.replace('(', '')
#     string = string.replace(')', '')
#     #remove '_fn' (optional)
#     for i in range(num_inputs):
#         string = string.replace('$' + str(num_inputs-i-1),'input_' + str(i))
#     string = string.split(' ')
#     string = list(filter(lambda x: x is not '', string))
#     return string

if __name__ == "__main__":
    # Sample one program of type [int] -> int -> int from the DeepCoder grammar
    # and print it in pretty and flattened form.
    # Alternative grammar: Grammar.uniform(deepcoderPrimitives())
    g = Grammar.fromProductions(deepcoderProductions(), logVariable=.9)
    request = arrow(tlist(tint), tint, tint)
    p = g.sample(request)
    for line in ("request:", request), ("program:",):
        print(*line)
    print(prettyProgram(p))
    print("flattened_program:")
    # NOTE(review): the local flatten_program definition above is commented out,
    # so this name must be supplied elsewhere (e.g. an import) — confirm.
    flat = flatten_program(p)
    print(flat)
Example #13
0
        p = program.visit(RandomParameterization.single)
        return super(LearnedFeatureExtractor, self).featuresOfProgram(p, tp)


if __name__ == "__main__":
    pi = 3.14  # I think this is close enough to pi
    # Data taken from:
    # https://secure-media.collegeboard.org/digitalServices/pdf/ap/ap-physics-1-equations-table.pdf
    # https://secure-media.collegeboard.org/digitalServices/pdf/ap/physics-c-tables-and-equations-list.pdf
    # http://mcat.prep101.com/wp-content/uploads/ES_MCATPhysics.pdf
    # some linear algebra taken from "parallel distributed processing"
    tasks = [
        # parallel distributed processing
        makeTask("vector addition (2)", arrow(tvector, tvector, tvector),
                 vectorAddition),
        makeTask("vector addition (many)", arrow(tlist(tvector), tvector),
                 lambda vs: reduce(vectorAddition, vs)),
        makeTask("vector norm", arrow(tvector, treal),
                 lambda v: innerProduct(v, v)**0.5),
        # mcat
        makeTask("freefall velocity = (2gh)**.5", arrow(tpositive, treal),
                 lambda h: (2 * 9.8 * h)**0.5),
        makeTask("v^2 = v_0^2 + 2a(x-x0)",
                 arrow(treal, treal, treal, treal, treal),
                 lambda v0, a, x, x0: v0**2 + 2 * a * (x - x0)),
        makeTask("v = (vx**2 + vy**2)**0.5", arrow(treal, treal, treal),
                 lambda vx, vy: (vx**2 + vy**2)**0.5),
        makeTask("a_r = v**2/R", arrow(treal, tpositive, treal),
                 lambda v, r: v * v / r),
        makeTask("e = mc^2", arrow(tpositive, tpositive, treal),
                 lambda m, c: m * c * c),
Example #14
0
def OldDeepcoderPrimitives():
    """Return the older, lower-case variant of the DeepCoder primitives.

    The groups appear in this order:
      1. first-order list functions;
      2. higher-order list functions;
      3. ``*_fn`` lambdas typed int_to_int, int_to_bool and int_to_int_to_int,
         which are the argument types the higher-order functions expect.
    """
    firstOrder = [
        Primitive("head", arrow(tlist(tint), tint), _head),
        Primitive("tail", arrow(tlist(tint), tint), _tail),
        Primitive("take", arrow(tint, tlist(tint), tlist(tint)), _take),
        Primitive("drop", arrow(tint, tlist(tint), tlist(tint)), _drop),
        Primitive("access", arrow(tint, tlist(tint), tint), _access),
        Primitive("minimum", arrow(tlist(tint), tint), _minimum),
        Primitive("maximum", arrow(tlist(tint), tint), _maximum),
        Primitive("reverse", arrow(tlist(tint), tlist(tint)), _reverse),
        Primitive("sort", arrow(tlist(tint), tlist(tint)), _sort),
        Primitive("sum", arrow(tlist(tint), tint), _sum),
    ]
    # The original author flagged each of these with "is this okay???".
    higherOrder = [
        Primitive("map", arrow(int_to_int, tlist(tint), tlist(tint)), _map),
        Primitive("filter_int", arrow(int_to_bool, tlist(tint), tlist(tint)), _filter),
        Primitive("count", arrow(int_to_bool, tlist(tint), tint), _count),
        Primitive("zipwith",
                  arrow(int_to_int_to_int, tlist(tint), tlist(tint), tlist(tint)),
                  _zipwith),
        Primitive("scanl1",
                  arrow(int_to_int_to_int, tlist(tint), tlist(tint)),
                  _scanl1),
    ]
    # A commented-out variant exposed these same helpers as plain
    # arrow(tint, ...) primitives (succ, pred, ..., add, ..., max); only the
    # *_fn forms below are active.
    intToInt = [
        Primitive("succ_fn", int_to_int, _succ),
        Primitive("pred_fn", int_to_int, _pred),
        Primitive("double_fn", int_to_int, _double),
        Primitive("half_fn", int_to_int, _half),
        Primitive("negate_fn", int_to_int, _negate),
        Primitive("square_fn", int_to_int, _square),
        Primitive("triple_fn", int_to_int, _triple),
        Primitive("third_fn", int_to_int, _third),
        Primitive("quad_fn", int_to_int, _quad),
        Primitive("quarter_fn", int_to_int, _quarter),
    ]
    intToBool = [
        Primitive("pos_fn", int_to_bool, _pos),
        Primitive("neg_fn", int_to_bool, _neg),
        Primitive("even_fn", int_to_bool, _even),
        Primitive("odd_fn", int_to_bool, _odd),
    ]
    intToIntToInt = [
        Primitive("add_fn", int_to_int_to_int, _add),
        Primitive("sub_fn", int_to_int_to_int, _sub),
        Primitive("mult_fn", int_to_int_to_int, _mult),
        Primitive("min_fn", int_to_int_to_int, _min),
        Primitive("max_fn", int_to_int_to_int, _max),
    ]
    return firstOrder + higherOrder + intToInt + intToBool + intToIntToInt
Example #15
0
         t.features = drawFunction(200, 10., t.f)
         delattr(t, 'f')
     test, _ = testTrainSplit(tasks, 100)
     random.seed(42)
     random.shuffle(test)
     test = test[:100]
     g = Grammar.uniform(
         [real, real_division, real_addition, real_multiplication])
     fe = FeatureExtractor([])
     BATCHSIZE = 64
 elif arguments.domain == "list":
     BATCHSIZE = 16
     tasks = retrieveJSONTasks("data/list_tasks.json") + sortBootstrap()
     tasks.extend([
         Task("remove empty lists",
              arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
              [((ls, ), list(filter(lambda l: len(l) > 0, ls)))
               for _ in range(15) for ls in [[[
                   random.random() < 0.5
                   for _ in range(random.randint(0, 3))
               ] for _ in range(4)]]]),
         Task("keep squares", arrow(tlist(tint), tlist(tint)), [
             ((xs, ), list(filter(lambda x: int(math.sqrt(x))**2 == x, xs)))
             for _ in range(15) for xs in [[
                 random.choice([0, 1, 4, 9, 16, 25])
                 if random.random() < 0.5 else random.randint(0, 9)
                 for _ in range(7)
             ]]
         ]),
         Task("keep primes", arrow(tlist(tint), tlist(tint)), [
             ((xs, ),
Example #16
0
def main(args):
    """
    Takes the return value of the `commandlineArguments()` function as input and
    trains/tests the model on manipulating sequences of numbers.

    ``args`` is consumed as a dict: every option read here is popped, and the
    remainder is forwarded to ``explorationCompression`` as keyword arguments.

    Steps:
      1. Seed ``random`` and load the chosen dataset.
      2. Optionally truncate the dataset to ``maxTasks``.
      3. For the "Lucas*" datasets, add hand-written filter tasks.
      4. Drop tasks that are just the identity function.
      5. Build a uniform grammar over the chosen primitive set.
      6. Create a timestamped output directory.
      7. Optionally split into train/test, then run ``explorationCompression``.
    """
    random.seed(args.pop("random_seed"))

    dataset = args.pop("dataset")
    tasks = {
        "Lucas-old":
        lambda: retrieveJSONTasks("data/list_tasks.json") + sortBootstrap(),
        "bootstrap":
        make_list_bootstrap_tasks,
        "sorting":
        sortBootstrap,
        "Lucas-depth1":
        lambda: retrieveJSONTasks("data/list_tasks2.json")[:105],
        "Lucas-depth2":
        lambda: retrieveJSONTasks("data/list_tasks2.json")[:4928],
        "Lucas-depth3":
        lambda: retrieveJSONTasks("data/list_tasks2.json"),
    }[dataset]()

    maxTasks = args.pop("maxTasks")
    if maxTasks and len(tasks) > maxTasks:
        eprint("Unwilling to handle {} tasks, truncating..".format(len(tasks)))
        necessaryTasks = []  # maxTasks will not consider these
        if dataset.startswith("Lucas-depth") and dataset != "Lucas-depth1":
            # depth2/depth3 start with the 105 tasks that make up "Lucas-depth1".
            # Always keep those, and sample the maxTasks budget only from the
            # rest so they are not duplicated. (This check used to test for
            # "Lucas2.0-*" names, which no dataset key above uses, so it never
            # fired.)
            necessaryTasks, tasks = tasks[:105], tasks[105:]
        random.shuffle(tasks)
        del tasks[maxTasks:]
        tasks = necessaryTasks + tasks

    if dataset.startswith("Lucas"):
        # extra tasks for filter
        tasks.extend([
            Task("remove empty lists",
                 arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
                 [((ls, ), list(filter(lambda l: len(l) > 0, ls)))
                  for _ in range(15) for ls in [[[
                      random.random() < 0.5
                      for _ in range(random.randint(0, 3))
                  ] for _ in range(4)]]]),
            Task("keep squares", arrow(tlist(tint), tlist(tint)), [
                ((xs, ), list(filter(lambda x: int(math.sqrt(x))**2 == x, xs)))
                for _ in range(15) for xs in [[
                    random.choice([0, 1, 4, 9, 16, 25])
                    if random.random() < 0.5 else random.randint(0, 9)
                    for _ in range(7)
                ]]
            ]),
            Task("keep primes", arrow(tlist(tint), tlist(tint)), [
                ((xs, ),
                 list(
                     filter(
                         lambda x: x in
                         {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}, xs)))
                for _ in range(15) for xs in [[
                    random.choice([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37])
                    if random.random() < 0.5 else random.randint(0, 9)
                    for _ in range(7)
                ]]
            ]),
        ])
        # The lambdas below read `i` while the comprehension runs, i.e. during
        # the current loop iteration, so late binding is not an issue here.
        for i in range(4):
            tasks.extend([
                Task("keep eq %s" % i, arrow(tlist(tint), tlist(tint)),
                     [((xs, ), list(filter(lambda x: x == i, xs)))
                      for _ in range(15)
                      for xs in [[random.randint(0, 6) for _ in range(5)]]]),
                Task("remove eq %s" % i, arrow(tlist(tint), tlist(tint)),
                     [((xs, ), list(filter(lambda x: x != i, xs)))
                      for _ in range(15)
                      for xs in [[random.randint(0, 6) for _ in range(5)]]]),
                Task("keep gt %s" % i, arrow(tlist(tint), tlist(tint)),
                     [((xs, ), list(filter(lambda x: x > i, xs)))
                      for _ in range(15)
                      for xs in [[random.randint(0, 6) for _ in range(5)]]]),
                Task("remove gt %s" % i, arrow(tlist(tint), tlist(tint)),
                     [((xs, ), list(filter(lambda x: not x > i, xs)))
                      for _ in range(15)
                      for xs in [[random.randint(0, 6) for _ in range(5)]]])
            ])

    def isIdentityTask(t):
        # Every example has a single input that equals its output.
        return all(len(xs) == 1 and xs[0] == y for xs, y in t.examples)

    nonIdentityTasks = [t for t in tasks if not isIdentityTask(t)]
    eprint("Removed", len(tasks) - len(nonIdentityTasks),
           "tasks that were just the identity function")
    tasks = nonIdentityTasks

    prims = {
        "base": basePrimitives,
        "McCarthy": McCarthyPrimitives,
        "common": bootstrapTarget_extra,
        "noLength": no_length,
        "rich": primitives
    }[args.pop("primitives")]()
    haveLength = not args.pop("noLength")
    haveMap = not args.pop("noMap")
    haveUnfold = not args.pop("noUnfold")
    eprint(f"Including map as a primitive? {haveMap}")
    eprint(f"Including length as a primitive? {haveLength}")
    eprint(f"Including unfold as a primitive? {haveUnfold}")

    def keepPrimitive(p):
        # Drop map/unfold/length when the matching noMap/noUnfold/noLength
        # option was set.
        return ((p.name != "map" or haveMap) and
                (p.name != "unfold" or haveUnfold) and
                (p.name != "length" or haveLength))

    baseGrammar = Grammar.uniform([p for p in prims if keepPrimitive(p)])

    extractor = {
        "learned": LearnedFeatureExtractor,
    }[args.pop("extractor")]
    extractor.H = args.pop("hidden")

    timestamp = datetime.datetime.now().isoformat()
    outputDirectory = "experimentOutputs/list/%s" % timestamp
    # Create the directory directly instead of shelling out to `mkdir -p`.
    os.makedirs(outputDirectory, exist_ok=True)

    args.update({
        "featureExtractor": extractor,
        "outputPrefix": "%s/list" % outputDirectory,
        "evaluationTimeout": 0.0005,
    })

    eprint("Got {} list tasks".format(len(tasks)))
    split = args.pop("split")
    if split:
        # Mark tasks that must appear in training. For the "some" category,
        # one randomly chosen task per name prefix (first word) is marked.
        train_some = defaultdict(list)
        for t in tasks:
            necessary = train_necessary(t)
            if not necessary:
                continue
            if necessary == "some":
                train_some[t.name.split()[0]].append(t)
            else:
                t.mustTrain = True
        for k in sorted(train_some):
            ts = train_some[k]
            random.shuffle(ts)
            ts.pop().mustTrain = True

        test, train = testTrainSplit(tasks, split)
        test = [t for t in test if t.name not in EASYLISTTASKS]

        eprint("Alotted {} tasks for training and {} for testing".format(
            len(train), len(test)))
    else:
        train = tasks
        test = []

    explorationCompression(baseGrammar, train, testingTasks=test, **args)
Example #17
0
        Primitive("is-square", arrow(tint, tbool), _isSquare),
        # McCarthy
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction)
    ]


# "zip": two lists plus a binary function t0 -> t1 -> t2, yielding a list of t2.
zip_primitive = Primitive("zip",
                          arrow(tlist(t0), tlist(t1),
                                arrow(t0, t1, t2),
                                tlist(t2)),
                          _zip)


def bootstrapTarget():
    """These are the primitives that we hope to learn from the bootstrapping procedure"""
    return [
        # learned primitives
        Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
        Primitive(
            "unfold",
            arrow(t0, arrow(t0, tbool), arrow(t0, t1), arrow(t0, t0),
                  tlist(t1)), _unfold),
        Primitive("range", arrow(tint, tlist(tint)), _range),
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold),
        Primitive("length", arrow(tlist(t0), tint), len),