def test_enumerate(self):
    """Enumerating with T=1 over a single LIST input yields exactly 8 programs."""
    functions = [
        impl.MAP, impl.FILTER, impl.COUNT,
        impl.GT0, impl.LT0, impl.EVEN, impl.ODD,
    ]
    ctx = Context(dict(zip(functions, np.ones(len(functions)))))
    found = enumerate_programs([[LIST]], 1, ctx, 1000)
    self.assertEqual(len(found), 8)
def test_dfs1(self):
    """dfs with T=3 finds a program reproducing every example output."""
    nb_functions = len(impl.FUNCTIONS)
    ctx = Context(dict(zip(impl.FUNCTIONS, np.ones(nb_functions))))
    examples = [
        ([ListValue([1, -2, 3, -4, 5, -6, 7])], ListValue([1, 3, 15, 105])),
    ]
    solution, nb_steps = dfs(examples, 3, ctx)
    for args, expected in examples:
        self.assertEqual(solution(*args), expected)
    self.assertGreater(nb_steps, 10)
def test_impossible(self):
    """dfs must fail on "return the first n primes", which this language cannot express."""
    weights = np.ones(len(impl.FUNCTIONS))
    ctx = Context(dict(zip(impl.FUNCTIONS, weights)))
    examples = [
        ([ListValue(list(range(6)))], ListValue([2, 3, 5, 7, 11, 13])),
    ]
    solution, nb_steps = dfs(examples, 2, ctx)
    self.assertFalse(solution)
    # The search should have explored a sizeable space before giving up.
    self.assertGreater(nb_steps, 2000)
def main():
    """Enumerate programs and write a random train/test split to disk.

    Command-line arguments:
        --nb_inputs: forwarded to ``get_input_type_combinations``.
        --nb_train / --nb_test: sizes of the train and test splits; their sum
            is the number of programs requested from ``enumerate_programs``.
        --prog_len: forwarded to ``enumerate_programs``.
        --outfile: base path; a trailing ``.txt`` is stripped and
            ``_train.txt`` / ``_test.txt`` are appended.

    Each output file gets one ``program.prefix`` per line, in sorted order.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--nb_inputs', type=int)
    # These three were effectively mandatory already (None + None and
    # None.replace would raise); argparse now reports the omission cleanly.
    parser.add_argument('--nb_train', type=int, required=True)
    parser.add_argument('--nb_test', type=int, required=True)
    parser.add_argument('--prog_len', type=int)
    parser.add_argument('--outfile', type=str, required=True)
    args = parser.parse_args()

    ctx = Context(dict(zip(impl.FUNCTIONS, np.ones(len(impl.FUNCTIONS)))))
    input_type_combinations = get_input_type_combinations(args.nb_inputs)
    programs = enumerate_programs(input_type_combinations, args.prog_len, ctx,
                                  args.nb_train + args.nb_test)

    # train / test split
    random.shuffle(programs)
    train_programs = programs[:args.nb_train]
    test_programs = programs[args.nb_train:args.nb_train + args.nb_test]
    # TODO: rethink how to enforce that the test set is semantically disjoint
    # from the train set. It is hard to get exactly nb_train / nb_test
    # programs while honoring that constraint.

    # Strip only a trailing '.txt': str.replace would also remove '.txt'
    # appearing elsewhere in the path (e.g. inside a directory name).
    suffix = '.txt'
    base = args.outfile
    if base.endswith(suffix):
        base = base[:-len(suffix)]
    train_outfile = base + '_train.txt'
    test_outfile = base + '_test.txt'

    for split_programs, outfile in ((train_programs, train_outfile),
                                    (test_programs, test_outfile)):
        print('writing ', outfile)
        with open(outfile, 'w') as fh:
            for program in sorted(split_programs):
                fh.write(program.prefix + '\n')
def test_dfs(self):
    """dfs with T=2 finds a program that doubles every list element."""
    weights = np.ones(len(impl.FUNCTIONS))
    ctx = Context(dict(zip(impl.FUNCTIONS, weights)))
    examples = [
        ([ListValue([1, 2, 3, 4, 5])], ListValue([2, 4, 6, 8, 10])),
    ]
    solution, nb_steps = dfs(examples, 2, ctx)
    for args, expected in examples:
        self.assertEqual(solution(*args), expected)
    # Was > before, not >= but this always failed
    self.assertGreaterEqual(nb_steps, 10)
def test_sort_and_add(self):
    """sort_and_add with unlimited gas recovers the expected program prefix."""
    ctx = Context({
        impl.FILTER: .8,
        impl.LTIMES: .9,
        impl.SCAN1L: .5,
        impl.GT0: 1.,
    })
    # dfs favors more recent inputs
    expected_prefix = 'LIST|FILTER,>0,0|FILTER,>0,1|SCAN1L,*,2'
    examples = [
        ([ListValue([1, -2, 3, -4, 5, -6, 7])], ListValue([1, 3, 15, 105])),
    ]
    solution, nb_steps_list = sort_and_add(examples, 3, ctx, gas=np.inf)
    self.assertEqual(len(nb_steps_list), 2)
    self.assertEqual(solution.prefix, expected_prefix)
def test_context(self):
    """Context exposes MAP/FILTER/COUNT (not TIMES2/MINUS1), ordered by descending score.

    The ordering must hold both for ``ctx.functions`` and within every
    ``ctx.typemap`` bucket, and the typemap must have 4 entries.
    """
    scores = {
        impl.MAP: 1.,
        impl.FILTER: .5,
        impl.COUNT: .5,
        impl.TIMES2: 1.,
        impl.MINUS1: 0.,
    }
    ctx = Context(scores)
    self.assertEqual(set(ctx.functions), {impl.MAP, impl.FILTER, impl.COUNT})

    def assert_descending(funcs):
        ordered = [scores[f] for f in funcs]
        self.assertEqual(ordered, sorted(ordered, reverse=True))

    assert_descending(ctx.functions)
    for bucket in ctx.typemap.values():
        assert_descending(bucket)
    self.assertEqual(len(ctx.typemap), 4)
def _examples_for(program, attempts=5):
    """Return input/output examples for *program*, or None if none were obtained.

    ``constraint.get_input_output_examples`` may raise ``NullInputError``.
    It is retried up to *attempts* times, and the first non-empty result is
    returned. The previous loop never stopped after a success and kept only
    the last attempt's outcome.
    """
    for _ in range(attempts):
        try:
            examples = constraint.get_input_output_examples(program, M=2)
        except NullInputError:
            continue
        if examples:
            return examples
    return None


def _disjoint_split(programs, nb_test):
    """Choose up to *nb_test* test programs and purge their equivalents from train.

    Returns ``(train_programs, test_programs)`` as a set and a list.
    Programs are visited in order. Each visited program for which examples
    can be generated becomes a test program. Every train program that
    ``constraint.is_same`` reports as equivalent on those examples is removed
    from train. Visited programs with no examples are dropped from train.
    """
    train_programs = set(programs)
    test_programs = []
    # The bar tracks chosen test programs, matching its total.
    with tqdm.tqdm(total=nb_test) as pbar:
        for program in programs:
            # Checking before doing any work (rather than after appending)
            # also makes nb_test == 0 select nothing.
            if len(test_programs) >= nb_test:
                break
            examples = _examples_for(program)
            if not examples:
                train_programs.discard(program)
                continue
            same_programs = {
                train_program for train_program in train_programs
                if constraint.is_same(program, train_program, examples)
            }
            test_programs.append(program)
            train_programs.difference_update(same_programs)
            # A test program must never remain in train, whatever is_same
            # reports when comparing a program with itself.
            train_programs.discard(program)
            pbar.update(1)
    return train_programs, test_programs


def main():
    """Enumerate programs and write train/test program lists to disk.

    Programs come from ``enumerate_programs``, which is asked for
    nb_train + nb_test programs, and are then shuffled. nb_test is capped at
    half the enumerated programs and nb_train at the remainder.

    With --enforce_disjoint, the test set is chosen by ``_disjoint_split``
    and the train set is every remaining program (not capped at nb_train).
    Otherwise the shuffled list is sliced.

    Each output file gets one ``program.prefix`` per line, in sorted order.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--nb_inputs', type=int)
    # These were effectively mandatory already (None + None / open(None)
    # raise); argparse now reports the omission cleanly.
    parser.add_argument('--nb_train', type=int, required=True)
    parser.add_argument('--nb_test', type=int, required=True)
    parser.add_argument('--prog_len', type=int)
    parser.add_argument('--train_out', type=str, required=True)
    parser.add_argument('--test_out', type=str, required=True)
    parser.add_argument('--enforce_disjoint', action='store_true')
    args = parser.parse_args()

    ctx = Context(dict(zip(impl.FUNCTIONS, np.ones(len(impl.FUNCTIONS)))))
    input_type_combinations = get_input_type_combinations(args.nb_inputs)
    programs = enumerate_programs(input_type_combinations, args.prog_len, ctx,
                                  args.nb_train + args.nb_test)

    # train / test split
    random.shuffle(programs)
    # Integer division: a float nb_test/nb_train broke the slices below, and
    # an odd program count made the test-set size unreachable.
    args.nb_test = min(len(programs) // 2, args.nb_test)
    args.nb_train = min(len(programs) - args.nb_test, args.nb_train)

    if args.enforce_disjoint:
        train_programs, test_programs = _disjoint_split(programs, args.nb_test)
    else:
        # Slice (not index) so the train split is a list of programs.
        train_programs = programs[:args.nb_train]
        test_programs = programs[args.nb_train:args.nb_train + args.nb_test]

    for split_programs, outfile in ((train_programs, args.train_out),
                                    (test_programs, args.test_out)):
        print('writing ', outfile, '({} programs)'.format(len(split_programs)))
        with open(outfile, 'w') as fh:
            for program in sorted(split_programs):
                fh.write(program.prefix + '\n')