Exemplo n.º 1
0
def _filterExamples(keep, makeInput, count=15):
    """
    Build `count` input/output examples for a list-filtering task.

    Each input is produced by calling `makeInput()`; the expected output is
    the list of its elements for which `keep` returns a true value. Every
    example has the shape `((input, ), output)`.
    """
    examples = []
    for _ in range(count):
        xs = makeInput()
        examples.append(((xs, ), [x for x in xs if keep(x)]))
    return examples


def _extraFilterTasks():
    """
    Create the extra filtering tasks that are appended to the "Lucas" datasets.

    Inputs are drawn from the `random` module, so the examples depend on the
    current seed. Tasks (and the draws inside them) are generated in a fixed
    order so that a given seed always yields the same examples.
    """
    squares = [0, 1, 4, 9, 16, 25]
    primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
    primeSet = set(primes)

    def boolLists():
        # Four lists holding between 0 and 3 random booleans each.
        return [[random.random() < 0.5
                 for _ in range(random.randint(0, 3))]
                for _ in range(4)]

    def mixedInts(special):
        # Seven ints: a coin flip picks either a member of `special` or a
        # uniform draw from 0..9.
        return [random.choice(special)
                if random.random() < 0.5 else random.randint(0, 9)
                for _ in range(7)]

    def smallInts():
        # Five uniform draws from 0..6.
        return [random.randint(0, 6) for _ in range(5)]

    tasks = [
        Task("remove empty lists",
             arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
             _filterExamples(lambda l: len(l) > 0, boolLists)),
        Task("keep squares", arrow(tlist(tint), tlist(tint)),
             _filterExamples(lambda x: int(math.sqrt(x))**2 == x,
                             lambda: mixedInts(squares))),
        Task("keep primes", arrow(tlist(tint), tlist(tint)),
             _filterExamples(lambda x: x in primeSet,
                             lambda: mixedInts(primes))),
    ]

    for i in range(4):
        # `i` is bound as a default argument so every predicate keeps the
        # threshold it was created with.
        predicates = [
            ("keep eq %s", lambda x, i=i: x == i),
            ("remove eq %s", lambda x, i=i: x != i),
            ("keep gt %s", lambda x, i=i: x > i),
            ("remove gt %s", lambda x, i=i: not x > i),
        ]
        for template, keep in predicates:
            tasks.append(
                Task(template % i, arrow(tlist(tint), tlist(tint)),
                     _filterExamples(keep, smallInts)))
    return tasks


def main(args):
    """
    Takes the return value of the `commandlineArguments()` function as input and
    trains/tests the model on manipulating sequences of numbers.

    `args` is a dict that is modified in place. The keys "random_seed",
    "dataset", "maxTasks", "primitives", "noLength", "noMap", "noUnfold",
    "extractor", "hidden" and "split" are popped and consumed here; the keys
    "featureExtractor", "outputPrefix" and "evaluationTimeout" are added.
    Whatever remains is forwarded as keyword arguments to
    `explorationCompression`.

    Raises KeyError if any of the consumed keys is missing, or if "dataset",
    "primitives" or "extractor" names an unknown option. Raises OSError if the
    output directory cannot be created.
    """
    random.seed(args.pop("random_seed"))

    dataset = args.pop("dataset")
    tasks = {
        "Lucas-old":
        lambda: retrieveJSONTasks("data/list_tasks.json") + sortBootstrap(),
        "bootstrap":
        make_list_bootstrap_tasks,
        "sorting":
        sortBootstrap,
        "Lucas-depth1":
        lambda: retrieveJSONTasks("data/list_tasks2.json")[:105],
        "Lucas-depth2":
        lambda: retrieveJSONTasks("data/list_tasks2.json")[:4928],
        "Lucas-depth3":
        lambda: retrieveJSONTasks("data/list_tasks2.json"),
    }[dataset]()

    maxTasks = args.pop("maxTasks")
    if maxTasks and len(tasks) > maxTasks:
        necessaryTasks = []  # maxTasks will not consider these
        # NOTE(review): no key of the dataset table above starts with
        # "Lucas2.0", so this condition is never true and `necessaryTasks`
        # always stays empty. Left unchanged because repairing the prefix
        # would alter which tasks are trained on -- confirm the intent.
        if dataset.startswith("Lucas2.0") and dataset != "Lucas2.0-depth1":
            necessaryTasks = tasks[:105]

        eprint("Unwilling to handle {} tasks, truncating..".format(len(tasks)))
        random.shuffle(tasks)
        del tasks[maxTasks:]
        tasks = necessaryTasks + tasks

    if dataset.startswith("Lucas"):
        # extra tasks for filter
        tasks.extend(_extraFilterTasks())

    def isIdentityTask(t):
        # True when every example has a single input equal to its output.
        return all(len(xs) == 1 and xs[0] == y for xs, y in t.examples)

    keptTasks = [t for t in tasks if not isIdentityTask(t)]
    eprint("Removed", len(tasks) - len(keptTasks),
           "tasks that were just the identity function")
    tasks = keptTasks

    prims = {
        "base": basePrimitives,
        "McCarthy": McCarthyPrimitives,
        "common": bootstrapTarget_extra,
        "noLength": no_length,
        "rich": primitives
    }[args.pop("primitives")]()
    haveLength = not args.pop("noLength")
    haveMap = not args.pop("noMap")
    haveUnfold = not args.pop("noUnfold")
    eprint(f"Including map as a primitive? {haveMap}")
    eprint(f"Including length as a primitive? {haveLength}")
    eprint(f"Including unfold as a primitive? {haveUnfold}")
    # Drop map/unfold/length from the grammar when the matching flag
    # disabled them.
    baseGrammar = Grammar.uniform([
        p for p in prims
        if (p.name != "map" or haveMap)
        and (p.name != "unfold" or haveUnfold)
        and (p.name != "length" or haveLength)
    ])

    extractor = {
        "learned": LearnedFeatureExtractor,
    }[args.pop("extractor")]
    # The hidden size is stored on the extractor class itself.
    extractor.H = args.pop("hidden")

    timestamp = datetime.datetime.now().isoformat()
    outputDirectory = "experimentOutputs/list/%s" % timestamp
    # Create the directory without going through a shell; a failure now
    # raises OSError instead of being silently ignored.
    os.makedirs(outputDirectory, exist_ok=True)

    args.update({
        "featureExtractor": extractor,
        "outputPrefix": "%s/list" % outputDirectory,
        "evaluationTimeout": 0.0005,
    })

    eprint("Got {} list tasks".format(len(tasks)))
    split = args.pop("split")
    if split:
        # Group the tasks for which `train_necessary` answers "some" by the
        # first word of their name; one task per group is forced into the
        # training set below.
        train_some = defaultdict(list)
        for t in tasks:
            necessary = train_necessary(t)
            if not necessary:
                continue
            if necessary == "some":
                train_some[t.name.split()[0]].append(t)
            else:
                t.mustTrain = True
        # Sorted iteration keeps the sequence of random draws reproducible.
        for k in sorted(train_some):
            ts = train_some[k]
            random.shuffle(ts)
            ts.pop().mustTrain = True

        test, train = testTrainSplit(tasks, split)
        # Tasks listed in EASYLISTTASKS are never used for testing.
        test = [t for t in test if t.name not in EASYLISTTASKS]

        eprint("Alotted {} tasks for training and {} for testing".format(
            len(train), len(test)))
    else:
        train = tasks
        test = []

    explorationCompression(baseGrammar, train, testingTasks=test, **args)
Exemplo n.º 2
0
 elif arguments.domain == "rational":
     tasks = rational.makeTasks()
     for t in tasks:
         t.features = drawFunction(200, 10., t.f)
         delattr(t, 'f')
     test, _ = testTrainSplit(tasks, 100)
     random.seed(42)
     random.shuffle(test)
     test = test[:100]
     g = Grammar.uniform(
         [real, real_division, real_addition, real_multiplication])
     fe = FeatureExtractor([])
     BATCHSIZE = 64
 elif arguments.domain == "list":
     BATCHSIZE = 16
     tasks = retrieveJSONTasks("data/list_tasks.json") + sortBootstrap()
     tasks.extend([
         Task("remove empty lists",
              arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
              [((ls, ), list(filter(lambda l: len(l) > 0, ls)))
               for _ in range(15) for ls in [[[
                   random.random() < 0.5
                   for _ in range(random.randint(0, 3))
               ] for _ in range(4)]]]),
         Task("keep squares", arrow(tlist(tint), tlist(tint)), [
             ((xs, ), list(filter(lambda x: int(math.sqrt(x))**2 == x, xs)))
             for _ in range(15) for xs in [[
                 random.choice([0, 1, 4, 9, 16, 25])
                 if random.random() < 0.5 else random.randint(0, 9)
                 for _ in range(7)
             ]]