def scan_evaluation_dir_only(mytype, split, input_lang, output_lang):
    # Load an entire SCAN pattern file as the query set.
    # Just use the isolated directions as the support set.
    #
    # Input
    #  mytype : type of SCAN experiment
    #  split : 'train' or 'test'
    #  ... other inputs are language objects
    D_query = ge.load_scan_file(mytype, split)
    D_support = [('turn left', 'I_TURN_LEFT'), ('turn right', 'I_TURN_RIGHT')]
    random.shuffle(D_support)
    x_support = [d[0].split(' ') for d in D_support]
    y_support = [d[1].split(' ') for d in D_support]
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    return build_sample(x_support, y_support, x_query, y_query, input_lang, output_lang, '')
def scan_evaluation_val_support(mytype, split, input_lang, output_lang, samples_val):
    # Use the pre-generated validation episodes as the support sets.
    # Replace each validation episode's query set with the rest of the SCAN split
    #  (e.g., the entire length test set).
    #
    # Input
    #  mytype : type of SCAN experiment
    #  split : 'train' or 'test'
    #  ... other inputs are language objects
    #  samples_val : list of pre-generated validation episodes
    D_query = ge.load_scan_file(mytype, split)  # e.g., we can load in the entire "length" test set
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    for idx in range(len(samples_val)):
        samples = samples_val[idx]
        samples_val[idx] = build_sample(samples['xs'], samples['ys'],
                                        deepcopy(x_query), deepcopy(y_query),
                                        input_lang, output_lang, '')
    return samples_val
def scan_evaluation_prim_only(mytype, split, input_lang, output_lang):
    # Load an entire SCAN split as the query set.
    # Use the isolated primitives as the support set.
    #
    # Input
    #  mytype : type of SCAN experiment
    #  split : 'train' or 'test'
    #  ... other inputs are language objects
    D_query = ge.load_scan_file(mytype, split)
    _, _, D_primitive = ge.sample_augment_scan(0, 0, [], shuffle=False, inc_support_in_query=False)
    D_support = D_primitive  # support set only includes the primitive mappings...
    random.shuffle(D_support)
    x_support = [d[0].split(' ') for d in D_support]
    y_support = [d[1].split(' ') for d in D_support]
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    return build_sample(x_support, y_support, x_query, y_query, input_lang, output_lang, '')
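# A minimal usage sketch tying the evaluation helpers above together.
# This is hypothetical driver code, not part of the original file: it assumes the
# Lang objects come from get_episode_generator below, and that build_sample returns
# a dict whose 'xs'/'ys' fields hold the support set (as scan_evaluation_val_support
# relies on).
def _example_evaluation_setup():
    _, _, input_lang, output_lang = get_episode_generator('scan_prim_permutation')
    # Full "length" test split as queries, with only the two turn directions as support:
    sample_dir = scan_evaluation_dir_only('length', 'test', input_lang, output_lang)
    # Same queries, but with the isolated primitive mappings as support:
    sample_prim = scan_evaluation_prim_only('length', 'test', input_lang, output_lang)
    return sample_dir, sample_prim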
def get_episode_generator(episode_type):
    # Returns function that generates episodes,
    #  and language class for the input and output language
    #
    # Input
    #  episode_type : string specifying type of episode
    #
    # Output
    #  generate_episode: function handle for generating episodes
    #  input_lang: Language object for input sequence
    #  output_lang: Language object for output sequence
    input_symbols_list_default = ['dax', 'lug', 'wif', 'zup', 'fep', 'blicket', 'kiki', 'tufa', 'gazzer']
    output_symbols_list_default = ['RED', 'YELLOW', 'GREEN', 'BLUE', 'PURPLE', 'PINK']
    input_lang = Lang(input_symbols_list_default)
    output_lang = Lang(output_symbols_list_default)
    if episode_type == 'ME':  # NeurIPS Exp 1 : Mutual exclusivity
        input_lang = Lang(input_symbols_list_default[:4])
        output_lang = Lang(output_symbols_list_default[:4])
        generate_episode_train = lambda tabu_episodes: generate_ME(
            nquery=20, nprims=len(input_lang.symbols), input_lang=input_lang,
            output_lang=output_lang, tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train
    elif episode_type == 'scan_prim_permutation':  # NeurIPS Exp 2 : Adding a new primitive through permutation meta-training
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_var('all', 'train')
        input_symbols_scan = get_unique_words([c[0] for c in scan_all])
        output_symbols_scan = get_unique_words([c[1] for c in scan_all])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_prim_permutation(
            shuffle=True, nsupport=20, nquery=20, input_lang=input_lang,
            output_lang=output_lang, scan_var_tuples=scan_all_var, nextra=0,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_prim_permutation(
            shuffle=False, nsupport=20, nquery=20, input_lang=input_lang,
            output_lang=output_lang, scan_var_tuples=scan_all_var, nextra=0,
            tabu_list=tabu_episodes)
    elif episode_type == 'scan_prim_augmentation':  # NeurIPS Exp 3 : Adding a new primitive through augmentation meta-training
        nextra_prims = 20
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_var('all', 'train')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] + [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] + ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_prim_augmentation(
            shuffle=True, nextra=nextra_prims, nsupport=20, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_var_tuples=scan_all_var, tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_prim_augmentation(
            shuffle=False, nextra=0, nsupport=20, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_var_tuples=scan_all_var, tabu_list=tabu_episodes)
    elif episode_type == 'scan_around_right':  # NeurIPS Exp 4 : Combining familiar concepts through meta-training
        nextra_prims = 2
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_dir_var('all', 'train')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] + [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] + ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_right_augmentation(
            shuffle=True, nextra=nextra_prims, nsupport=20, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_var_tuples=scan_all_var, tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_right_augmentation(
            shuffle=False, nextra=0, nsupport=20, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_var_tuples=scan_all_var, tabu_list=tabu_episodes)
    elif episode_type == 'scan_length':  # NeurIPS Exp 5 : Generalizing to longer instructions through meta-training
        nextra_prims = 20  # number of additional primitives to augment the episodes with
        support_threshold = 12  # items with action length less than this belong in the support,
                                #  and greater than or equal to this length belong in the query
        scan_length_train = ge.load_scan_file('length', 'train')
        scan_length_test = ge.load_scan_file('length', 'test')
        scan_all = scan_length_train + scan_length_test
        scan_length_train_var = ge.load_scan_var('length', 'train')
        scan_length_test_var = ge.load_scan_var('length', 'test')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] + [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] + ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        scan_length_support_var = [pair for pair in scan_length_train_var
                                   if len(pair[1].split(' ')) < support_threshold]  # partition based on number of output actions
        scan_length_query_var = [pair for pair in scan_length_train_var
                                 if len(pair[1].split(' ')) >= support_threshold]  # long sequences
        generate_episode_train = lambda tabu_episodes: generate_length(
            shuffle=True, nextra=nextra_prims, nsupport=100, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_tuples_support_variable=scan_length_support_var,
            scan_tuples_query_variable=scan_length_query_var,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_length(
            shuffle=False, nextra=0, nsupport=100, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_tuples_support_variable=scan_length_train_var,
            scan_tuples_query_variable=scan_length_test_var,
            tabu_list=tabu_episodes)
    else:
        raise Exception("episode_type is not valid")
    return generate_episode_train, generate_episode_test, input_lang, output_lang
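# Hypothetical driver sketch (not from the original file) showing how the returned
# train/test handles are typically used together with a tabu list; representing the
# tabu list as a set of episode identifiers is an assumption based on the tabu_list
# arguments threaded through the generators above.
def _example_meta_training_step(episode_type='scan_prim_permutation'):
    generate_episode_train, generate_episode_test, input_lang, output_lang = \
        get_episode_generator(episode_type)
    tabu_episodes = set()
    test_episode = generate_episode_test(tabu_episodes)
    train_episode = generate_episode_train(tabu_episodes)  # sampled while avoiding tabu episodes
    return train_episode, test_episode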
def get_episode_generator(episode_type, model_in_lang=None, model_out_lang=None, model_prog_lang=None):
    # Returns function that generates episodes,
    #  and language class for the input and output language
    #
    # Input
    #  episode_type : string specifying type of episode
    #
    # Output
    #  generate_episode: function handle for generating episodes
    #  input_lang: Language object for input sequence
    #  output_lang: Language object for output sequence
    input_symbols_list_default = ['dax', 'lug', 'fep', 'blicket', 'kiki', 'tufa',
                                  'gazzer', 'zup', 'wif']  # changed order for sorting
    output_symbols_list_default = ['RED', 'YELLOW', 'GREEN', 'BLUE', 'PURPLE', 'PINK', 'BLACK', 'WHITE']
    input_lang = Lang(input_symbols_list_default)
    output_lang = Lang(output_symbols_list_default)
    prog_symbols_list = input_symbols_list_default + output_symbols_list_default[:6] + \
        ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]']  # TODO
    prog_lang = Lang(prog_symbols_list)
    if episode_type == 'rules_gen':
        input_lang = Lang(input_symbols_list_default + ['mup', 'dox', 'kleek'])  # default has 9 symbols
        # output_lang defaults to 8 symbols, which works here
        prog_lang = Lang(input_lang.symbols + output_lang.symbols +
                         ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]'])
        # what does it do to have unused query items?

        def generate_episode_train(tabu_episodes):
            nprims = random.choice((3, 4))
            nsupp = random.choice(range(10, 21))
            nrules = random.choice((2, 3, 4))
            return generate_rules_episode(nsupport=nsupp, nquery=10, nprims=nprims, nrules=nrules,
                                          input_lang=input_lang, output_lang=output_lang,
                                          prog_lang=prog_lang, tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train
    elif 'rules_sup_' in episode_type:
        nSupp = int(episode_type.split('_')[-1])
        input_lang = Lang(input_symbols_list_default + ['mup', 'dox', 'kleek'])  # default has 9 symbols
        # output_lang defaults to 8 symbols, which works here
        prog_lang = Lang(input_lang.symbols + output_lang.symbols +
                         ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]'])
        # what does it do to have unused query items?

        def generate_episode_train(tabu_episodes):
            nprims = random.choice((3, 4))
            nsupp = nSupp
            nrules = random.choice((2, 3, 4))
            return generate_rules_episode(nsupport=nsupp, nquery=10, nprims=nprims, nrules=nrules,
                                          input_lang=input_lang, output_lang=output_lang,
                                          prog_lang=prog_lang, tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train
    elif 'rules_horules_' in episode_type:
        nHO = int(episode_type.split('_')[-1])
        input_lang = Lang(input_symbols_list_default + ['mup', 'dox', 'kleek'])  # default has 9 symbols
        # output_lang defaults to 8 symbols, which works here
        prog_lang = Lang(input_lang.symbols + output_lang.symbols +
                         ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]'])
        # what does it do to have unused query items?

        def generate_episode_train(tabu_episodes):
            nprims = random.choice((3, 4))
            nsupp = 30  # random.choice(range(10,21))
            nrules = nHO
            return generate_rules_episode(nsupport=nsupp, nquery=10, nprims=nprims, nrules=nrules,
                                          input_lang=input_lang, output_lang=output_lang,
                                          prog_lang=prog_lang, tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train
    elif 'rules_prims_' in episode_type:
        nPrims = int(episode_type.split('_')[-1])
        input_lang = Lang(input_symbols_list_default + ['mup', 'dox', 'kleek'])  # default has 9 symbols
        # output_lang defaults to 8 symbols, which works here
        prog_lang = Lang(input_lang.symbols + output_lang.symbols +
                         ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]'])
        # what does it do to have unused query items?

        def generate_episode_train(tabu_episodes):
            nprims = nPrims
            nsupp = 30  # random.choice(range(10,21))
            nrules = random.choice((2, 3, 4))
            return generate_rules_episode(nsupport=nsupp, nquery=10, nprims=nprims, nrules=nrules,
                                          input_lang=input_lang, output_lang=output_lang,
                                          prog_lang=prog_lang, tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train
    elif 'lang_' in episode_type:
        lang = episode_type.split('_')[-1]
        from number_generate_model import generate_lang_test_episode
        from number_word_interpret_grammar import RHS_DICT
        tokens = ['token' + format(i, '02d') for i in range(1, 52)]
        input_lang = Lang(tokens)
        output_lang = Lang([str(i) for i in range(10)])
        prog_symbols = ['1000000*', '10000*', '1000*', '100*', '10*',
                        '[x1]*10', '[x1]*100', '[x1]*1000', '[x1]*10000', '[x1]*1000000',
                        '[x1]', '[u1]', '[y1]', 'x1', 'u1', 'y1', '->', '\n'] + [str(i) for i in range(10)]
        prog_lang = Lang(prog_symbols + input_lang.symbols)
        nsupp = 25
        nquery = 100

        def generate_episode_train(tabu_examples):
            return generate_lang_test_episode(nsupp, nquery, input_lang, output_lang,
                                              prog_lang, tabu_examples, lang=lang)
        generate_episode_test = generate_episode_train
    elif episode_type == 'wordToNumber':
        from number_generate_model import generate_wordToNumber_episode
        from number_word_interpret_grammar import RHS_DICT
        tokens = ['token' + format(i, '02d') for i in range(1, 52)]
        input_lang = Lang(tokens)
        output_lang = Lang([str(i) for i in range(10)])
        prog_symbols = ['1000000*', '10000*', '1000*', '100*', '10*',
                        '[x1]*10', '[x1]*100', '[x1]*1000', '[x1]*10000', '[x1]*1000000',
                        '[x1]', '[u1]', '[y1]', 'x1', 'u1', 'y1', '->', '\n'] + [str(i) for i in range(10)]
        prog_lang = Lang(prog_symbols + input_lang.symbols)

        def generate_episode_train(tabu_examples):
            nsupp = random.choice(range(60, 101))  # should vary this ...
            nquery = 10
            return generate_wordToNumber_episode(nsupp, nquery, input_lang, output_lang,
                                                 prog_lang, tabu_examples)
        generate_episode_test = generate_episode_train
    elif episode_type == 'scan_random':
        words = ['walk', 'look', 'run', 'jump', 'turn', 'left', 'right', 'opposite',
                 'around', 'twice', 'thrice', 'and', 'after'] + ['dax', 'blicket', 'lug', 'kiki']
        cmds = ['WALK', 'LOOK', 'RUN', 'JUMP', 'LTURN', 'RTURN'] + ['RED', 'BLUE', 'GREEN']
        input_lang = Lang(words)
        output_lang = Lang(cmds)
        prog_lang = Lang(words + cmds + ['->', '\n', 'x1', 'u1', '[x1]', '[u1]',
                                         'x2', '[x2]', 'u2', '[u2]', ""])  # , '[', ']'
        tp = 'random'

        def generate_episode_train(tabu_episodes):
            nprims = random.choice(range(4, 9))
            nsupp = random.choice(range(30, 51))
            nurules = 0
            nxrules = random.choice((3, 4, 5, 6, 7))
            return generate_scan_episode(nsupport=nsupp, nquery=10, nprims=nprims,
                                         nurules=nurules, nxrules=nxrules,
                                         input_lang=input_lang, output_lang=output_lang,
                                         prog_lang=prog_lang, tabu_list=tabu_episodes, u_type=tp)
        generate_episode_test = generate_episode_train
    elif episode_type in ['scan_simple_original', 'scan_jump_original',
                          'scan_around_right_original', 'scan_length_original']:
        dic = {'scan_simple_original': 'simple',
               'scan_jump_original': 'addprim_jump',
               'scan_around_right_original': 'template_around_right',
               'scan_length_original': 'length'}
        scan_train = ge.load_scan_file(dic[episode_type], 'train')
        scan_test = ge.load_scan_file(dic[episode_type], 'test')
        # assert 0, "deal with langs"
        # input_symbols_scan = get_unique_words([c[0] for c in scan_train+scan_test])
        # output_symbols_scan = get_unique_words([c[1] for c in scan_train+scan_test])
        # input_lang = Lang(input_symbols_scan)
        # output_lang = Lang(output_symbols_scan)
        words = ['walk', 'look', 'run', 'jump', 'turn', 'left', 'right', 'opposite',
                 'around', 'twice', 'thrice', 'and', 'after'] + ['dax', 'blicket', 'lug', 'kiki']
        cmds = ['I_WALK', 'I_LOOK', 'I_RUN', 'I_JUMP', 'I_TURN_LEFT', 'I_TURN_RIGHT'] + ['RED', 'BLUE', 'GREEN']
        # assert set(words) == set(get_unique_words([c[0] for c in scan_train+scan_test]))
        # assert set(cmds) == set(get_unique_words([c[1] for c in scan_train+scan_test]))
        print("WARNING: vocab includes extra words, so beware")
        input_lang = Lang(words)
        output_lang = Lang(cmds)
        prog_lang = Lang(words + cmds + ['->', '\n', 'x1', 'u1', '[x1]', '[u1]',
                                         'x2', '[x2]', 'u2', '[u2]', ""])  # , '[', ']'
        generate_episode_train = lambda tabu_episodes: generate_traditional_synth_scan_episode(
            nsupport=100, nquery=500, input_lang=input_lang, output_lang=output_lang,
            train_tuples=scan_train, test_tuples=scan_test, tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_traditional_synth_scan_episode(
            nsupport=100, nquery=500, input_lang=input_lang, output_lang=output_lang,
            train_tuples=scan_train, test_tuples=scan_test, tabu_list=tabu_episodes)
    else:
        raise Exception("episode_type is not valid")
    return generate_episode_train, generate_episode_test, input_lang, output_lang, prog_lang
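# Hypothetical sketch (not from the original file) for this five-output variant;
# note the extra prog_lang return value covering the grammar/program tokens, and
# that several episode types are parameterized through the string suffix: e.g.
# 'rules_sup_20' fixes the support size at 20, 'rules_horules_3' sets nrules=3,
# and 'rules_prims_4' sets nprims=4.
def _example_rules_episode():
    gen_train, gen_test, input_lang, output_lang, prog_lang = \
        get_episode_generator('rules_sup_20')
    return gen_train(set())  # one episode with 20 support items and 10 queries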
parser.add_argument('--eval-interval', type=int, default=1000,
                    help='evaluate model at this rate')
parser.add_argument('--model-path', type=str, default='',
                    help='model checkpoint to start from (if any)')
args = parser.parse_args()
print(args)

# Set the random seed manually for better reproducibility
torch.manual_seed(args.seed)

scan_all = ge.load_scan_file('all', 'train')
scan_all_var = ge.load_scan_var('all', 'train')
input_symbols_scan = get_unique_words([c[0] for c in scan_all])
output_symbols_scan = get_unique_words([c[1] for c in scan_all])
all_symbols_scan = input_symbols_scan + output_symbols_scan
all_lang = Lang(all_symbols_scan)
ntoken = all_lang.n_symbols

# Set up transformer encoder-decoder model, loss, optimizer
model = TransformerModel(ntoken=ntoken, emsize=args.emsize, nhead=args.nhead,
                         nhid=args.nhid, nlayers=args.nlayers,