コード例 #1
0
    def test_enumerate(self):
        """Enumerating T=1 programs over one LIST input yields exactly 8."""
        # Three list functions followed by four predicates, all weighted 1.
        functions = [
            impl.MAP, impl.FILTER, impl.COUNT,
            impl.GT0, impl.LT0, impl.EVEN, impl.ODD,
        ]
        ctx = Context(dict(zip(functions, np.ones(len(functions)))))

        # Single LIST input signature, program length 1; the final argument
        # (1000) is presumably an upper bound on how many programs to return.
        programs = enumerate_programs([[LIST]], 1, ctx, 1000)
        self.assertEqual(len(programs), 8)
コード例 #2
0
    def test_dfs1(self):
        """dfs finds a program (T=3) for the running product of positives.

        The expected output [1, 3, 15, 105] is 1, 1*3, 1*3*5, 1*3*5*7: the
        running product of the positive elements of the input list.
        """
        ctx = Context(dict(zip(impl.FUNCTIONS, np.ones(len(impl.FUNCTIONS)))))

        inputs_list = [
            [ListValue([1, -2, 3, -4, 5, -6, 7])],
        ]
        output_list = [ListValue([1, 3, 15, 105])]
        examples = list(zip(inputs_list, output_list))

        T = 3
        solution, nb_steps = dfs(examples, T, ctx)
        # A failed search returns a falsy solution (see test_impossible);
        # fail with a clear message rather than a TypeError when calling it.
        self.assertTrue(solution, 'dfs found no program')
        for inputs, output in examples:
            self.assertEqual(solution(*inputs), output)
        self.assertGreater(nb_steps, 10)
コード例 #3
0
    def test_impossible(self):
        """dfs must fail to map range(6) to the first six primes.

        Producing the first n primes is not expressible in this language, so
        dfs should return a falsy solution after taking many steps.
        """
        ctx = Context(dict(zip(impl.FUNCTIONS, np.ones(len(impl.FUNCTIONS)))))

        inputs_list = [
            [ListValue(list(range(6)))],
        ]
        output_list = [ListValue([2, 3, 5, 7, 11, 13])]

        examples = list(zip(inputs_list, output_list))

        T = 2
        solution, nb_steps = dfs(examples, T, ctx)
        self.assertFalse(solution)
        self.assertGreater(nb_steps, 2000)
コード例 #4
0
def main():
    """Enumerate programs and write a random train/test split to disk.

    Command-line arguments:
      --nb_inputs  passed to get_input_type_combinations
      --nb_train   number of programs in the train split
      --nb_test    number of programs in the test split
      --prog_len   program length passed to enumerate_programs
      --outfile    base output path; a trailing '.txt' is stripped and
                   '_train.txt' / '_test.txt' are appended

    Each output file receives the split's programs in sorted order, one
    program prefix per line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--nb_inputs', type=int)
    parser.add_argument('--nb_train', type=int)
    parser.add_argument('--nb_test', type=int)
    parser.add_argument('--prog_len', type=int)
    parser.add_argument('--outfile', type=str)
    args = parser.parse_args()

    # Every function gets the same weight of 1.
    ctx = Context(dict(zip(impl.FUNCTIONS, np.ones(len(impl.FUNCTIONS)))))

    input_type_combinations = get_input_type_combinations(args.nb_inputs)

    # Presumably returns at most nb_train + nb_test programs; it must be a
    # mutable sequence because random.shuffle reorders it in place.
    programs = enumerate_programs(input_type_combinations, args.prog_len, ctx,
                                  args.nb_train + args.nb_test)

    # train / test split
    random.shuffle(programs)

    train_programs = programs[:args.nb_train]
    test_programs = programs[args.nb_train:args.nb_train + args.nb_test]

    # TODO: rethink how to enforce a semantically disjoint split. It is hard
    # to get exactly nb_train/nb_test programs when the test set must also be
    # semantically disjoint from the train set, so this split is purely random.

    # Strip only a trailing '.txt'. str.replace would also delete '.txt'
    # appearing elsewhere in the path (e.g. 'runs.txt_old/out.txt').
    base = args.outfile
    if base.endswith('.txt'):
        base = base[:-len('.txt')]
    train_outfile = base + '_train.txt'
    test_outfile = base + '_test.txt'

    for split_programs, outfile in zip([train_programs, test_programs],
                                       [train_outfile, test_outfile]):
        print('writing ', outfile)
        with open(outfile, 'w') as fh:
            for program in sorted(split_programs):
                fh.write(program.prefix + '\n')
コード例 #5
0
    def test_dfs(self):
        """dfs finds a program (T=2) that doubles every list element."""
        ctx = Context(dict(zip(impl.FUNCTIONS, np.ones(len(impl.FUNCTIONS)))))

        inputs_list = [
            [ListValue([1, 2, 3, 4, 5])],
        ]
        output_list = [
            ListValue([2, 4, 6, 8, 10]),
        ]
        examples = list(zip(inputs_list, output_list))

        T = 2
        solution, nb_steps = dfs(examples, T, ctx)
        # A failed search returns a falsy solution (see test_impossible);
        # fail with a clear message rather than a TypeError when calling it.
        self.assertTrue(solution, 'dfs found no program')
        for inputs, output in examples:
            self.assertEqual(solution(*inputs), output)
        # This bound was originally strict (> 10), which always failed.
        self.assertGreaterEqual(nb_steps, 10)
コード例 #6
0
    def test_sort_and_add(self):
        """sort_and_add recovers the positive-running-product program.

        Only FILTER, LTIMES, SCAN1L and >0 carry scores in the context.
        """
        ctx = Context({
            impl.FILTER: .8,
            impl.LTIMES: .9,
            impl.SCAN1L: .5,
            impl.GT0: 1.,
        })

        examples = list(zip(
            [[ListValue([1, -2, 3, -4, 5, -6, 7])]],
            [ListValue([1, 3, 15, 105])],
        ))

        solution, nb_steps_list = sort_and_add(examples, 3, ctx, gas=np.inf)

        # dfs favors more recent inputs, which shapes the expected program.
        expected = 'LIST|FILTER,>0,0|FILTER,>0,1|SCAN1L,*,2'
        self.assertEqual(len(nb_steps_list), 2)
        self.assertEqual(solution.prefix, expected)
コード例 #7
0
    def test_context(self):
        """Context exposes only MAP/FILTER/COUNT here, ordered by score.

        Checks that ctx.functions and every ctx.typemap entry list functions
        in non-increasing score order, and that typemap has 4 entries.
        """
        scores_map = {
            impl.MAP: 1.,
            impl.FILTER: .5,
            impl.COUNT: .5,
            impl.TIMES2: 1.,
            impl.MINUS1: 0.
        }

        ctx = Context(scores_map)

        # Of the five scored functions, only these three appear in
        # ctx.functions (TIMES2 and MINUS1 are expected to be absent).
        self.assertEqual(set(ctx.functions),
                         {impl.MAP, impl.FILTER, impl.COUNT})

        # ctx.functions must be ordered by descending score.
        func_scores = [scores_map[x] for x in ctx.functions]
        self.assertEqual(func_scores, sorted(func_scores, reverse=True))

        # The same ordering must hold for each typemap entry; keys are unused.
        for funcs in ctx.typemap.values():
            func_scores = [scores_map[x] for x in funcs]
            self.assertEqual(func_scores, sorted(func_scores, reverse=True))

        self.assertEqual(len(ctx.typemap), 4)
コード例 #8
0
def _sample_examples(program, max_tries=5):
    """Return non-empty input/output examples for `program`, or None.

    constraint.get_input_output_examples may raise NullInputError; it is
    retried up to `max_tries` times and the first non-empty result is
    returned.
    """
    for _ in range(max_tries):
        try:
            examples = constraint.get_input_output_examples(program, M=2)
        except NullInputError:
            continue
        if examples:
            return examples
    return None


def _disjoint_split(programs, nb_test):
    """Pick up to `nb_test` test programs semantically disjoint from train.

    Walks `programs` in order. Programs for which no examples can be sampled
    are dropped entirely. Every other candidate becomes a test program, and
    each training program that constraint.is_same judges equivalent to it on
    the sampled examples is removed from the training set, as is the
    candidate itself.

    Returns (train_programs, test_programs) as a (set, list) pair.
    """
    train_programs = set(programs)
    test_programs = []
    for program in tqdm.tqdm(programs, total=nb_test):
        # Checked before doing any work so nb_test == 0 yields no test set.
        if len(test_programs) >= nb_test:
            break

        examples = _sample_examples(program)
        if not examples:
            train_programs.discard(program)
            continue

        same_programs = {
            train_program for train_program in train_programs
            if constraint.is_same(program, train_program, examples)
        }
        test_programs.append(program)
        train_programs.difference_update(same_programs)
        # Guarantee the test program never also lands in train, even if
        # is_same(program, program) were not true.
        train_programs.discard(program)
    return train_programs, test_programs


def main():
    """Enumerate programs and write a train/test split to two files.

    Command-line arguments:
      --nb_inputs         passed to get_input_type_combinations
      --nb_train          requested train size (capped by what is available)
      --nb_test           requested test size (capped at half the programs)
      --prog_len          program length passed to enumerate_programs
      --train_out         output path for the train split
      --test_out          output path for the test split
      --enforce_disjoint  make the test set semantically disjoint from train

    Each output file receives its split's programs in sorted order, one
    program prefix per line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--nb_inputs', type=int)
    parser.add_argument('--nb_train', type=int)
    parser.add_argument('--nb_test', type=int)
    parser.add_argument('--prog_len', type=int)
    parser.add_argument('--train_out', type=str)
    parser.add_argument('--test_out', type=str)
    parser.add_argument('--enforce_disjoint', action='store_true')
    args = parser.parse_args()

    # Every function gets the same weight of 1.
    ctx = Context(dict(zip(impl.FUNCTIONS, np.ones(len(impl.FUNCTIONS)))))

    input_type_combinations = get_input_type_combinations(args.nb_inputs)

    programs = enumerate_programs(input_type_combinations, args.prog_len, ctx,
                                  args.nb_train + args.nb_test)

    # train / test split
    random.shuffle(programs)

    # Integer division: a float size would break the slicing below and the
    # test-set size check in _disjoint_split.
    nb_test = min(len(programs) // 2, args.nb_test)
    nb_train = min(len(programs) - nb_test, args.nb_train)

    if args.enforce_disjoint:
        # NOTE(review): nb_train is not applied on this path; every surviving
        # training program is written. Confirm this is intended.
        train_programs, test_programs = _disjoint_split(programs, nb_test)
    else:
        # Slice (the original indexed programs[nb_train], a single program).
        train_programs = programs[:nb_train]
        test_programs = programs[nb_train:nb_train + nb_test]

    for split_programs, outfile in zip([train_programs, test_programs],
                                       [args.train_out, args.test_out]):
        print('writing ', outfile, '({} programs)'.format(len(split_programs)))
        with open(outfile, 'w') as fh:
            for program in sorted(split_programs):
                fh.write(program.prefix + '\n')