def get_episode_generator(episode_type):
    """Return episode-generator functions and Language objects for an episode type.

    Thin wrapper around util.get_episode_generator that validates the
    episode type and drops the program language.

    Args:
        episode_type: name of the episode family, e.g. 'scan_random',
            'wordToNumber', or any type containing 'lang_'.

    Returns:
        generate_episode_train: function handle for generating training episodes
        generate_episode_test: function handle for generating test episodes
        input_lang: Language object for the input sequence
        output_lang: Language object for the output sequence

    Raises:
        ValueError: if episode_type is not a supported type.
    """
    known_types = [
        'scan_random', 'scan_length_original', 'scan_simple_original',
        'scan_around_right_original', 'scan_jump_original', 'wordToNumber',
        'rules_gen', 'rules_gen_xl'
    ]
    if episode_type in known_types or 'lang_' in episode_type:
        # todo: check that it uses generate_episode_train,
        # generate_episode_test, input_lang, output_lang
        # util's generator returns a 5th value (program language); it is
        # intentionally discarded here.
        generate_episode_train, generate_episode_test, input_lang, output_lang, _ = util.get_episode_generator(
            episode_type)
    else:
        # ValueError is more precise than the original bare Exception and is
        # still caught by any caller handling Exception.
        raise ValueError("episode_type is not valid")
    return generate_episode_train, generate_episode_test, input_lang, output_lang
def generate_val_episodes(self):
    """Populate self.samples_val with fresh test episodes, growing the tabu set."""
    # Only the test-episode generator is used; everything else is discarded.
    _, gen_test, _, _, _ = get_episode_generator(self.episode_type)
    self.tabu_episodes = set([])
    self.samples_val = []
    # NOTE(review): num_episodes_val is a free (presumably module-level) name
    # here -- confirm it is defined elsewhere in the file.
    for _ in range(num_episodes_val):
        sample = gen_test(self.tabu_episodes)
        self.samples_val.append(sample)
        # Record the episode identifier so later draws avoid duplicates.
        self.tabu_episodes = tabu_update(self.tabu_episodes, sample['identifier'])
outRules = [] #print(nPrims) for i in range(nPrims): #print(len(rules)) outRules.append(dummySample(i)) dist_obj = make_uniform_under_dist(self.distance) pyprob.observe(dist_obj, name="dist") #last = pyprob.observe(pyprob.distributions.Categorical(torch.tensor([0.5,0.5])), name="dist") #outRules.append(last) return tuple(outRules) if __name__== '__main__': from util import get_episode_generator generate_episode_train, generate_episode_test, input_lang, output_lang, prog_lang = get_episode_generator("scan_random") sample = generate_episode_train(set()) grammar = sample['grammar'] print(grammar) #TRUE # trueRules = [1., 0., 1., 0., 1.] # model = SimpleModel() IO = (sample['xs'], sample['ys']) model = FullModel(IO, input_lang.symbols, output_lang.symbols) for i in range(10): g, distance = model.forward(output_distance=True) print("iiii", i)
# --- Fragment: continuation of a load-or-create-model branch; the opening
# `if` (load path) is outside this chunk. ---
    elif args.type == 'WordToNumber':
        model = WordToNumber.load(path)
    else:
        assert False, "not implemented yet"
else:
    # No saved model found: build a fresh one from the CLI args.
    print("new model ...")
    if args.type == 'miniscanRBbase':
        model = MiniscanRBBaseline.new(args)
    elif args.type == 'WordToNumber':
        model = WordToNumber.new(args)
    else:
        assert False, "not implemented yet"

# Allow the CLI to extend (never shrink) the pretraining budget.
if args.num_pretrain_episodes > model.num_pretrain_episodes:
    model.num_pretrain_episodes = args.num_pretrain_episodes

generate_episode_train, _, _, _, _ = get_episode_generator(
    model.episode_type)

samples_val = model.samples_val
# if args.type in ['WordToNumber', 'NumberToWord']:
#     model.samples_val = []

# Expand each validation episode into per-rule (state, rule) samples.
val_states = []
for s in samples_val:
    states, rules = model.sample_to_statelist(s)
    #for state, rule in zip(states, rules):
    for i in range(len(rules)):
        val_states.append(model.state_rule_to_sample(states[i], rules[i]))

if args.parallel:
    # NOTE(review): this call is cut off at the end of the chunk; remaining
    # arguments continue past this view.
    dataqueue = GenData(lambda: gen_samples(generate_episode_train, model),
                        batchsize=args.batchsize,
# samples_val = model.samples_val #load model if args.type == 'miniscanRBbase': model = MiniscanRBBaseline.load(path) elif args.type == 'WordToNumber': model = WordToNumber.load(path) else: assert False, "not implemented yet" # if args.new_test_ep: print("generating new test examples") generate_episode_train, generate_episode_test, input_lang, output_lang, prog_lang = get_episode_generator( args.new_test_ep, model_in_lang=model.input_lang, model_out_lang=model.output_lang, model_prog_lang=model.prog_lang) #model.tabu_episodes = set([]) model.samples_val = [] for i in range(N_TEST_NEW): sample = generate_episode_test(model.tabu_episodes) if args.hack_gt_g: sample['grammar'] = Grammar(exact_perm_doubled_rules(), model.input_lang.symbols) model.samples_val.append(sample) if not args.duplicate_test: model.tabu_episodes = tabu_update(model.tabu_episodes, sample['identifier']) model.input_lang = input_lang
# --- Fragment: tail of a scoring function; a dict/set comprehension building
# query_examples opens before this chunk. ---
        if cur not in sample['xs']
    }
    rules = model.detokenize_action(candidate)
    test_state = State(query_examples, rules)
    try:
        testout = model.REPL(test_state, None)
    except (ParseError, UnfinishedError, REPLError):
        # Best-guess program failed to execute; score it zero.
        print("YOU ERRORED ON BEST GUESS")
        return 0.0
    # Fraction of held-out examples consumed (solved) by the REPL run.
    return (len(test_state.examples) - len(testout.examples)) / len(
        test_state.examples)


path = os.path.join(args.dir_model, args.fn_out_model)
generate_episode_train, generate_episode_test, input_lang, output_lang, prog_lang = get_episode_generator(
    args.episode_type)
# NOTE(review): model is loaded from args.save_path, not `path` above -- confirm.
m = torch.load(args.save_path)
m.cuda()
m.max_length = 50

with open(args.load_data, 'rb') as h:
    test_samples = dill.load(h)

# if args.new_test_ep:
#     print("generating new test examples")
#     generate_episode_train, generate_episode_test, input_lang, output_lang, prog_lang = get_episode_generator(
#         args.new_test_ep, model_in_lang=model.input_lang,
#         model_out_lang=model.output_lang,
#         model_prog_lang=model.prog_lang)
#     #model.tabu_episodes = set([])
def __init__(self, use_cuda, episode_type, emb_size, nlayers, dropout_p,
             adam_learning_rate, positional, use_prog_lang_for_input=False):
    """Construct the synthesis encoder/decoder pair and their optimizers.

    Languages are looked up from the episode generator for `episode_type`;
    when use_prog_lang_for_input is set, the encoder's input vocabulary is
    the program language rather than the input language.
    """
    # Plain hyperparameter bookkeeping.
    self.USE_CUDA = use_cuda
    self.episode_type = episode_type
    self.emb_size = emb_size
    self.nlayers = nlayers
    self.dropout_p = dropout_p
    self.adam_learning_rate = adam_learning_rate
    self.positional = positional

    # Only the Language objects are needed; the generator handles are discarded.
    _, _, self.input_lang, self.output_lang, self.prog_lang = get_episode_generator(
        episode_type)

    self.input_size = (self.prog_lang.n_symbols if use_prog_lang_for_input
                       else self.input_lang.n_symbols)
    self.output_size = self.output_lang.n_symbols
    self.prog_size = self.prog_lang.n_symbols

    self.encoder = BatchedRuleSynthEncoderRNN(
        emb_size, self.input_size, self.output_size, self.prog_size,
        nlayers, dropout_p, tie_encoders=False, rule_positions=positional)
    self.decoder = BatchedDoubleAttnDecoderRNN(
        emb_size, self.prog_size, nlayers, dropout_p, fancy_attn=False)
    if self.USE_CUDA:
        self.encoder = self.encoder.cuda()
        self.decoder = self.decoder.cuda()

    print(' Set learning rate to ' + str(adam_learning_rate))
    self.encoder_optimizer = optim.Adam(self.encoder.parameters(),
                                        lr=adam_learning_rate)
    self.decoder_optimizer = optim.Adam(self.decoder.parameters(),
                                        lr=adam_learning_rate)

    print("")
    print("Architecture options...")
    print(" Using Synthesis network")
    print("")
    describe_model(self.encoder)
    describe_model(self.decoder)

    # Training-progress counters.
    self.pretrain_episode = 0
    self.rl_episode = 0
def run_search(model, args):
    """Run sampling-based search over the model's validation episodes.

    Optionally regenerates the test episodes first (args.new_test_ep) and
    optionally reports validation log-likelihood (args.val_ll). Each task is
    attempted with batched_test_with_sampling; the function returns as soon as
    one task is solved.

    Returns:
        (tot_time, tot_nodes, n_distinct_inputs): cumulative search time,
        cumulative nodes expanded, and number of distinct input examples seen
        up to and including the first solved task.

    Raises:
        AssertionError: if args.val_ll_only is set (deliberate early stop),
            or if no task at all is solved.
    """
    if args.new_test_ep:
        print("generating new test examples")
        generate_episode_train, generate_episode_test, input_lang, output_lang, prog_lang = get_episode_generator(
            args.new_test_ep, model_in_lang=model.input_lang,
            model_out_lang=model.output_lang,
            model_prog_lang=model.prog_lang)
        #model.tabu_episodes = set([])
        model.samples_val = []
        for i in range(args.n_test):
            sample = generate_episode_test(model.tabu_episodes)
            if args.hack_gt_g:
                # Replace the sampled grammar with the known ground-truth rules.
                sample['grammar'] = Grammar(exact_perm_doubled_rules(),
                                            model.input_lang.symbols)
            model.samples_val.append(sample)
            if not args.duplicate_test:
                model.tabu_episodes = tabu_update(model.tabu_episodes,
                                                  sample['identifier'])
        model.input_lang = input_lang
        model.output_lang = output_lang
        if not args.val_ll_only:
            model.prog_lang = prog_lang

    if args.val_ll:
        val_ll = compute_val_ll(model)
        print("val ll:", val_ll)
    if args.val_ll_only:
        assert False

    print(f"testing using {args.mode}")
    #print("batchsize:", args.batchsize)
    # (removed unused locals `count`, `results`, `frac_exs_hits` -- they were
    # initialized but never read in this function)
    tot_time = 0.
    tot_nodes = 0
    all_examples = set()
    for j, sample in enumerate(model.samples_val):
        print()
        print(f"Task {j+1} out of {len(model.samples_val)}")
        print("ground truth grammar:")
        print(sample['identifier'])
        # RB / Word models synthesize one big rule; others search up to 15.
        hit, solution, stats = batched_test_with_sampling(
            sample, model,
            max_len=1 if 'RB' in args.type or 'Word' in args.type else 15,
            timeout=args.timeout,
            verbose=True,
            min_len=0,
            batch_size=args.batchsize,
            nosearch=args.nosearch,
            partial_credit=args.partial_credit,
            max_rule_size=100 if 'RB' in args.type or 'Word' in args.type else 15)
        tot_time += time.time() - stats['start_time']
        tot_nodes += stats['nodes_expanded']
        for ex in sample['xs']:
            all_examples.add(tuple(ex))
        if hit:
            print('done with one of the runs')
            return tot_time, tot_nodes, len(all_examples)
    else:
        # for/else: reached only when no task hit.
        # NOTE(review): the mangled source makes the original indentation of
        # this `else` ambiguous; for/else is the reading consistent with the
        # message below -- confirm against the original file.
        assert 0, "didn't hit after 20 examples, that's dumb!"
""" test languages """ import random import os import numpy as np #japanese: from util import get_episode_generator generate_episode_train, generate_episode_test, input_lang, output_lang, prog_lang = get_episode_generator( "wordToNumber") input_symbols = input_lang.symbols MAXLEN = 8 MAXNUM = 99999999 def parseExamples(lang): return { 'ja': parseExamples_JA, 'ko': parseExamples_KO, 'vi': parseExamples_VI, 'it': parseExamples_IT, 'es': parseExamples_ES, 'zh': parseExamples_ZH, 'en': parseExamples_EN, 'fr': parseExamples_FR, 'el': parseExamples_EL, }[lang]()
import os
import dill
import pyprob
import time

if __name__ == '__main__':
    # NOTE(review): argparse is used below but not imported in this chunk --
    # presumably imported earlier in the file; verify.
    parser = argparse.ArgumentParser()
    parser.add_argument('--timeout', type=int, default=30)
    parser.add_argument('--num_traces', type=int, default=1500)
    parser.add_argument('--savefile', type=str, default="results/smcREPL.p")
    parser.add_argument('--mode', type=str, default="MCMC")
    parser.add_argument('--load_data', type=str, default='')
    args = parser.parse_args()

    # Only the Language objects are needed; generator handles are discarded.
    _, _, input_lang, output_lang, prog_lang = get_episode_generator(
        "scan_simple_original")

    if args.load_data:
        if os.path.isfile(args.load_data):
            print('loading test data ... ')
            with open(args.load_data, 'rb') as h:
                test_samples = dill.load(h)
        else:
            # Fail loudly if the requested data file is missing.
            assert False

    frac_exs_hits = []
    for i, sample in enumerate(test_samples):
        IO = (sample['xs'], sample['ys'])
        # NOTE(review): the loop body appears to continue past this chunk.
        model = FullModel(IO, input_lang.symbols, output_lang.symbols)