def prog_io_pair(prog, max_len, counter=0):
  """Sample one random (input, output) token-id pair for a program.

  Draws a random integer list `inp`, evaluates `prog` on it with
  `program_utils.evaluate`, and tokenizes both the input and the output with
  `program_utils.tokenize`, mapping each token to its id through
  `program_utils.prog_rev_vocab` (comma tokens are dropped).

  Rejected samples raise ValueError internally and are retried recursively
  with `counter + 1`. Every 20 retries the range of input values shrinks.
  After more than 400 retries the output is forced to be empty, which always
  passes the checks, so the retries terminate.

  NOTE(review): this module defines `prog_io_pair` again immediately after
  this definition; the later definition wins at import time, so this one is
  effectively dead code. Consider deleting one of the two.

  Args:
    prog: program source string passed to `program_utils.evaluate`.
    max_len: maximum number of output tokens. The input list has between 1
      and max_len - 3 elements, so max_len must be at least 4.
    counter: number of retries so far (internal; callers normally omit it).

  Returns:
    A tuple (inp_toks, out_toks) of token-id lists.

  Raises:
    ValueError: if max_len < 4. Without this check, np.random.randint would
      raise ValueError on every attempt and the retry loop would never end.
  """
  if max_len < 4:
    raise ValueError("max_len must be at least 4, got %r" % (max_len,))
  try:
    ilen = np.random.randint(max_len - 3) + 1
    # Narrow the value range as retries accumulate so outputs are more likely
    # to tokenize cleanly. `//` keeps `bound` an int, which range() requires.
    bound = max(15 - counter // 20, 1)
    inp = [random.choice(range(-bound, bound)) for _ in range(ilen)]
    inp_toks = [program_utils.prog_rev_vocab[t]
                for t in program_utils.tokenize(str(inp)) if t != ","]
    out = program_utils.evaluate(prog, {"a": inp})
    out_str = str(out)
    out_toks = [program_utils.prog_rev_vocab[t]
                for t in program_utils.tokenize(out_str) if t != ","]
    if counter > 400:
      # Give up on finding a usable output: an empty output passes every
      # check below, which ends the retries.
      out_toks = []
    is_list = bool(out_toks) and (
        out_toks[0] == program_utils.prog_rev_vocab["["])
    # A list of n single-token ints has n - 1 commas and, with commas dropped,
    # tokenizes to n + 2 tokens ("[", n ints, "]"). Commas are counted in
    # str(out) so this works whether `evaluate` returns a string or a list.
    if is_list and len(out_toks) != out_str.count(",") + 3:
      raise ValueError("generated list with too long ints")
    if out_toks and not is_list and len(out_toks) > 1:
      raise ValueError("generated one int but tokenized it to many")
    if len(out_toks) > max_len:
      raise ValueError("output too long")
    return (inp_toks, out_toks)
  except ValueError:
    return prog_io_pair(prog, max_len, counter + 1)
def prog_io_pair(prog, max_len, counter=0):
  """Sample one random (input, output) token-id pair for a program.

  Draws a random integer list `inp`, evaluates `prog` on it with
  `program_utils.evaluate`, and tokenizes both the input and the output with
  `program_utils.tokenize`, mapping each token to its id through
  `program_utils.prog_rev_vocab` (comma tokens are dropped).

  Rejected samples raise ValueError internally and are retried recursively
  with `counter + 1`. Every 20 retries the range of input values shrinks.
  After more than 400 retries the output is forced to be empty, which always
  passes the checks, so the retries terminate.

  Args:
    prog: program source string passed to `program_utils.evaluate`.
    max_len: maximum number of output tokens. The input list has between 1
      and max_len - 3 elements, so max_len must be at least 4.
    counter: number of retries so far (internal; callers normally omit it).

  Returns:
    A tuple (inp_toks, out_toks) of token-id lists.

  Raises:
    ValueError: if max_len < 4. Without this check, np.random.randint would
      raise ValueError on every attempt and the retry loop would never end.
  """
  if max_len < 4:
    raise ValueError("max_len must be at least 4, got %r" % (max_len,))
  try:
    ilen = np.random.randint(max_len - 3) + 1
    # Narrow the value range as retries accumulate so outputs are more likely
    # to tokenize cleanly. `//` keeps `bound` an int, which range() requires.
    bound = max(15 - counter // 20, 1)
    inp = [random.choice(range(-bound, bound)) for _ in range(ilen)]
    inp_toks = [program_utils.prog_rev_vocab[t]
                for t in program_utils.tokenize(str(inp)) if t != ","]
    out = program_utils.evaluate(prog, {"a": inp})
    out_str = str(out)
    out_toks = [program_utils.prog_rev_vocab[t]
                for t in program_utils.tokenize(out_str) if t != ","]
    if counter > 400:
      # Give up on finding a usable output: an empty output passes every
      # check below, which ends the retries.
      out_toks = []
    is_list = bool(out_toks) and (
        out_toks[0] == program_utils.prog_rev_vocab["["])
    # A list of n single-token ints has n - 1 commas and, with commas dropped,
    # tokenizes to n + 2 tokens ("[", n ints, "]"). Commas are counted in
    # str(out) so this works whether `evaluate` returns a string or a list.
    if is_list and len(out_toks) != out_str.count(",") + 3:
      raise ValueError("generated list with too long ints")
    if out_toks and not is_list and len(out_toks) > 1:
      raise ValueError("generated one int but tokenized it to many")
    if len(out_toks) > max_len:
      raise ValueError("output too long")
    return (inp_toks, out_toks)
  except ValueError:
    return prog_io_pair(prog, max_len, counter + 1)
def init_data(task, length, nbr_cases, nclass):
  """Data initialization.

  Generates random examples for `task` and appends them to the `train_set`
  and `test_set` buckets (both defined outside this function), choosing the
  bucket with `bin_for`.

  Args:
    task: task name. Handled tasks: "add", "badd", "qadd", "mul", "bmul",
      "dup", "rev2", "search", "kvsort", "progeval", "progsynth", plus those
      handled by `spec` below ("sort", "id", "rev", "incr", "left", "right",
      "left-shift", "right-shift").
    length: target sequence length `l` for generated examples.
    nbr_cases: for non-program tasks, the number of cases generated for each
      of train_set and test_set. For program tasks it is further capped by
      the number of available programs.
    nclass: number of symbols; the generic tasks draw tokens from
      [1, nclass - 1].
  """
  def rand_pair(l, task):
    """Random data pair for a task. Total length should be <= l.

    Builds two random k-digit numbers (least-significant digit first) in base
    10, 2 ("b" tasks) or 4 ("q" tasks). Returns their digits joined by a
    separator token as input, and the digits of their sum or product as
    target. Every digit is shifted by +1.
    """
    k = int((l-1)/2)
    base = 10
    if task[0] == "b": base = 2
    if task[0] == "q": base = 4
    d1 = [np.random.randint(base) for _ in xrange(k)]
    d2 = [np.random.randint(base) for _ in xrange(k)]
    if task in ["add", "badd", "qadd"]:
      res = add(d1, d2, base)
    elif task in ["mul", "bmul"]:
      d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
      d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
      if task == "bmul":
        # bin() yields "0b..."; after reversing, [:-2] drops the "b0" suffix.
        res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
      else:
        res = [int(x) for x in list(reversed(str(d1n * d2n)))]
    else:
      sys.exit()
    sep = [12]
    if task in ["add", "badd", "qadd"]: sep = [11]
    inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
    return inp, [r + 1 for r in res]

  def rand_dup_pair(l):
    """Random data pair for duplication task. Total length should be <= l."""
    k = int(l/2)
    x = [np.random.randint(nclass - 1) + 1 for _ in xrange(k)]
    inp = x + [0 for _ in xrange(l - k)]
    res = x + x + [0 for _ in xrange(l - 2*k)]
    return inp, res

  def rand_rev2_pair(l):
    """Random data pair for reverse2 task. Total length should be <= l."""
    inp = [(np.random.randint(nclass - 1) + 1,
            np.random.randint(nclass - 1) + 1) for _ in xrange(l // 2)]
    res = [i for i in reversed(inp)]
    return [x for p in inp for x in p], [x for p in res for x in p]

  def rand_search_pair(l):
    """Random data pair for search task. Total length should be <= l."""
    # (l - 1) // 2 key/value pairs plus one query token keeps the input
    # within l tokens. The former `l-1/2` parsed as `l - (1/2)`, i.e. l pairs.
    inp = [(np.random.randint(nclass - 1) + 1,
            np.random.randint(nclass - 1) + 1) for _ in xrange((l - 1) // 2)]
    q = np.random.randint(nclass - 1) + 1
    res = 0
    # Scanning in reverse leaves `res` as the value of the earliest pair whose
    # key matches the query (0 if none).
    for (k, v) in reversed(inp):
      if k == q:
        res = v
    return [x for p in inp for x in p] + [q], [res]

  def rand_kvsort_pair(l):
    """Random data pair for key-value sort. Total length should be <= l."""
    keys = [(np.random.randint(nclass - 1) + 1, i) for i in xrange(l // 2)]
    vals = [np.random.randint(nclass - 1) + 1 for _ in xrange(l // 2)]
    kv = [(k, vals[i]) for (k, i) in keys]
    # Ties on key are broken by original position, so the sort is stable.
    sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
    return [x for p in kv for x in p], [x for p in sorted_kv for x in p]

  def prog_io_pair(prog, max_len, counter=0):
    """Sample one random (input, output) token-id pair for `prog`.

    Rejected samples are retried recursively with `counter + 1`. Every 20
    retries the input value range shrinks. After more than 400 retries the
    output is forced to be empty, which ends the retries. `max_len` must be
    at least 4 (the input list has 1..max_len - 3 elements).
    """
    if max_len < 4:
      raise ValueError("max_len must be at least 4, got %r" % (max_len,))
    try:
      ilen = np.random.randint(max_len - 3) + 1
      # `//` keeps `bound` an int, which range() requires.
      bound = max(15 - counter // 20, 1)
      inp = [random.choice(range(-bound, bound)) for _ in range(ilen)]
      inp_toks = [program_utils.prog_rev_vocab[t]
                  for t in program_utils.tokenize(str(inp)) if t != ","]
      out = program_utils.evaluate(prog, {"a": inp})
      out_str = str(out)
      out_toks = [program_utils.prog_rev_vocab[t]
                  for t in program_utils.tokenize(out_str) if t != ","]
      if counter > 400:
        out_toks = []
      is_list = bool(out_toks) and (
          out_toks[0] == program_utils.prog_rev_vocab["["])
      # n single-token ints -> n - 1 commas -> n + 2 tokens once commas are
      # dropped. Commas are counted in str(out) so the check also works if
      # `evaluate` returns a list rather than a string.
      if is_list and len(out_toks) != out_str.count(",") + 3:
        raise ValueError("generated list with too long ints")
      if out_toks and not is_list and len(out_toks) > 1:
        raise ValueError("generated one int but tokenized it to many")
      if len(out_toks) > max_len:
        raise ValueError("output too long")
      return (inp_toks, out_toks)
    except ValueError:
      return prog_io_pair(prog, max_len, counter + 1)

  def spec(inp):
    """Return the target given the input for some tasks."""
    if task == "sort":
      return sorted(inp)
    elif task == "id":
      return inp
    elif task == "rev":
      return [i for i in reversed(inp)]
    elif task == "incr":
      # Add 1 with carry, least-significant symbol first. Symbol values are
      # 1..nclass-1; overflow wraps to 1 and carries. The final carry is
      # dropped, so the output keeps the input's length.
      carry = 1
      res = []
      for i in xrange(len(inp)):
        if inp[i] + carry < nclass:
          res.append(inp[i] + carry)
          carry = 0
        else:
          res.append(1)
          carry = 1
      return res
    elif task == "left":
      return [inp[0]]
    elif task == "right":
      return [inp[-1]]
    elif task == "left-shift":
      # inp[-1] wraps around for i == 0, so this is a rotation.
      return [inp[i-1] for i in xrange(len(inp))]
    elif task == "right-shift":
      # Wrap around like "left-shift"; plain inp[i+1] overran the last index.
      return [inp[(i + 1) % len(inp)] for i in xrange(len(inp))]
    else:
      print_out("Unknown spec for task " + str(task))
      sys.exit()

  l = length
  cur_time = time.time()
  total_time = 0.0

  is_prog = task in ["progeval", "progsynth"]
  if is_prog:
    inputs_per_prog = 5
    program_utils.make_vocab()
    progs = read_tmp_file("programs_len%d" % (l / 10))
    if not progs:
      progs = program_utils.gen(l / 10, 1.2 * nbr_cases / inputs_per_prog)
      write_tmp_file("programs_len%d" % (l / 10), progs)
    prog_ios = read_tmp_file("programs_len%d_io" % (l / 10))
    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
    if not prog_ios:
      # Generate program io data.
      prog_ios = []
      for pidx, prog in enumerate(progs):
        if pidx % 500 == 0:
          print_out("== generating io pairs for program %d" % pidx)
        if pidx * inputs_per_prog > nbr_cases * 1.2:
          break
        ptoks = [program_utils.prog_rev_vocab[t]
                 for t in program_utils.tokenize(prog)]
        ptoks.append(program_utils.prog_rev_vocab["_EOS"])
        plen = len(ptoks)
        for _ in xrange(inputs_per_prog):
          if task == "progeval":
            inp, out = prog_io_pair(prog, plen)
            prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
          elif task == "progsynth":
            plen = max(len(ptoks), 8)
            for _ in xrange(3):
              inp, out = prog_io_pair(prog, plen // 2)
              prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
      write_tmp_file("programs_len%d_io" % (l / 10), prog_ios)
    # Each line is "<input ids>\t<output ids>\t<program>". Keep only the
    # digits (and spaces) of the id lists and group pairs by program.
    prog_ios_dict = {}
    for s in prog_ios:
      i, o, p = s.split("\t")
      i_clean = "".join([c for c in i if c.isdigit() or c == " "])
      o_clean = "".join([c for c in o if c.isdigit() or c == " "])
      inp = [int(x) for x in i_clean.split()]
      out = [int(x) for x in o_clean.split()]
      if inp and out:
        prog_ios_dict.setdefault(p, []).append([inp, out])
    # Use prog_ios_dict to create data. Each program is consumed
    # inputs_per_prog times (3 pairs at a time for progsynth). Skip programs
    # whose pairs were partly dropped by the `if inp and out` filter above,
    # otherwise pop() below would fail on an empty list.
    pairs_needed = inputs_per_prog * (3 if task == "progsynth" else 1)
    progs = [prog for prog in prog_ios_dict
             if prog.count(";") <= (l / 10)
             and len(prog_ios_dict[prog]) >= pairs_needed]
    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
    print_out("== %d training cases on %d progs" % (nbr_cases, len(progs)))
    for pidx, prog in enumerate(progs):
      if pidx * inputs_per_prog > nbr_cases * 1.2:
        break
      ptoks = [program_utils.prog_rev_vocab[t]
               for t in program_utils.tokenize(prog)]
      ptoks.append(program_utils.prog_rev_vocab["_EOS"])
      plen = len(ptoks)
      # The first nbr_cases / inputs_per_prog programs go to training; the
      # remaining ~20% (from the 1.2 factor) go to the test set.
      dset = train_set if pidx < nbr_cases / inputs_per_prog else test_set
      for _ in xrange(inputs_per_prog):
        if task == "progeval":
          inp, out = prog_ios_dict[prog].pop()
          dset[task][bin_for(plen)].append([[ptoks, inp, [], []], [out]])
        elif task == "progsynth":
          plen, ilist = max(len(ptoks), 8), [[]]
          for _ in xrange(3):
            inp, out = prog_ios_dict[prog].pop()
            ilist.append(inp + out)
          dset[task][bin_for(plen)].append([ilist, [ptoks]])

  # Tasks whose examples are stored as [[inp], [target]], binned by input
  # length. Each case draws one train pair, then one test pair.
  pair_gens = {"dup": rand_dup_pair, "rev2": rand_rev2_pair,
               "search": rand_search_pair, "kvsort": rand_kvsort_pair}
  for case in xrange(0 if is_prog else nbr_cases):
    total_time += time.time() - cur_time
    cur_time = time.time()
    if l > 10000 and case % 100 == 1:
      print_out(" avg gen time %.4f s" % (total_time / float(case)))
    if task in ["add", "badd", "qadd", "bmul", "mul"]:
      for dset in (train_set, test_set):
        i, t = rand_pair(l, task)
        dset[task][bin_for(len(i))].append([[[], i, [], []], [t]])
    elif task in pair_gens:
      for dset in (train_set, test_set):
        i, t = pair_gens[task](l)
        dset[task][bin_for(len(i))].append([[i], [t]])
    elif task not in ["progeval", "progsynth"]:
      for dset in (train_set, test_set):
        inp = [np.random.randint(nclass - 1) + 1 for _ in xrange(l)]
        target = spec(inp)
        dset[task][bin_for(l)].append([[inp], [target]])