def get_episode_generator(episode_type):
    # Returns functions that generate episodes,
    #  and language classes for the input and output languages
    #
    # Input
    #  episode_type : string specifying type of episode
    #
    # Output
    #  generate_episode_train : function handle for generating training episodes
    #  generate_episode_test : function handle for generating test episodes
    #  input_lang : Language object for input sequences
    #  output_lang : Language object for output sequences
    input_symbols_list_default = ['dax', 'lug', 'wif', 'zup', 'fep',
                                  'blicket', 'kiki', 'tufa', 'gazzer']
    output_symbols_list_default = ['RED', 'YELLOW', 'GREEN', 'BLUE',
                                   'PURPLE', 'PINK']
    input_lang = Lang(input_symbols_list_default)
    output_lang = Lang(output_symbols_list_default)
    if episode_type == 'ME':  # NeurIPS Exp 1 : Mutual exclusivity
        input_lang = Lang(input_symbols_list_default[:4])
        output_lang = Lang(output_symbols_list_default[:4])
        generate_episode_train = lambda tabu_episodes: generate_ME(
            nquery=20, nprims=len(input_lang.symbols), input_lang=input_lang,
            output_lang=output_lang, tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train
    elif episode_type == 'scan_prim_permutation':  # NeurIPS Exp 2 : Adding a new primitive through permutation meta-training
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_var('all', 'train')
        input_symbols_scan = get_unique_words([c[0] for c in scan_all])
        output_symbols_scan = get_unique_words([c[1] for c in scan_all])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_prim_permutation(
            shuffle=True, nsupport=20, nquery=20, input_lang=input_lang,
            output_lang=output_lang, scan_var_tuples=scan_all_var, nextra=0,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_prim_permutation(
            shuffle=False, nsupport=20, nquery=20, input_lang=input_lang,
            output_lang=output_lang, scan_var_tuples=scan_all_var, nextra=0,
            tabu_list=tabu_episodes)
    elif episode_type == 'scan_prim_augmentation':  # NeurIPS Exp 3 : Adding a new primitive through augmentation meta-training
        nextra_prims = 20
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_var('all', 'train')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] + [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] + ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_prim_augmentation(
            shuffle=True, nextra=nextra_prims, nsupport=20, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_var_tuples=scan_all_var, tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_prim_augmentation(
            shuffle=False, nextra=0, nsupport=20, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_var_tuples=scan_all_var, tabu_list=tabu_episodes)
    elif episode_type == 'scan_around_right':  # NeurIPS Exp 4 : Combining familiar concepts through meta-training
        nextra_prims = 2
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_dir_var('all', 'train')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] + [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] + ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_right_augmentation(
            shuffle=True, nextra=nextra_prims, nsupport=20, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_var_tuples=scan_all_var, tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_right_augmentation(
            shuffle=False, nextra=0, nsupport=20, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_var_tuples=scan_all_var, tabu_list=tabu_episodes)
    elif episode_type == 'scan_length':  # NeurIPS Exp 5 : Generalizing to longer instructions through meta-training
        nextra_prims = 20  # number of additional primitives to augment the episodes with
        support_threshold = 12  # items with action length less than this belong in the support set;
                                #  items at or above this length belong in the query set
        scan_length_train = ge.load_scan_file('length', 'train')
        scan_length_test = ge.load_scan_file('length', 'test')
        scan_all = scan_length_train + scan_length_test
        scan_length_train_var = ge.load_scan_var('length', 'train')
        scan_length_test_var = ge.load_scan_var('length', 'test')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] + [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] + ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        # partition based on number of output actions: short sequences form the
        #  support pool, long sequences form the query pool
        scan_length_support_var = [pair for pair in scan_length_train_var
                                   if len(pair[1].split(' ')) < support_threshold]
        scan_length_query_var = [pair for pair in scan_length_train_var
                                 if len(pair[1].split(' ')) >= support_threshold]
        generate_episode_train = lambda tabu_episodes: generate_length(
            shuffle=True, nextra=nextra_prims, nsupport=100, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_tuples_support_variable=scan_length_support_var,
            scan_tuples_query_variable=scan_length_query_var,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_length(
            shuffle=False, nextra=0, nsupport=100, nquery=20,
            input_lang=input_lang, output_lang=output_lang,
            scan_tuples_support_variable=scan_length_train_var,
            scan_tuples_query_variable=scan_length_test_var,
            tabu_list=tabu_episodes)
    else:
        raise Exception("episode_type is not valid")
    return generate_episode_train, generate_episode_test, input_lang, output_lang
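# Illustrative usage sketch (not from the original source): fetch the episode
#  generators for one experiment and sample a training episode. This assumes
#  the generators accept any container of tabu episode identifiers and that
#  starting from an empty set is valid; the structure of the returned episode
#  is defined by the underlying generate_* function.
#
#   generate_episode_train, generate_episode_test, input_lang, output_lang = \
#       get_episode_generator('scan_prim_permutation')
#   tabu_episodes = set()
#   episode = generate_episode_train(tabu_episodes)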
                    type=int, default=1000,
                    help='evaluate model at this rate')
parser.add_argument('--model-path', type=str, default='',
                    help='model checkpoint to start from (if any)')
args = parser.parse_args()
print(args)

# Set the random seed manually for better reproducibility
torch.manual_seed(args.seed)

# Build a single joint vocabulary over SCAN input and output symbols
scan_all = ge.load_scan_file('all', 'train')
scan_all_var = ge.load_scan_var('all', 'train')
input_symbols_scan = get_unique_words([c[0] for c in scan_all])
output_symbols_scan = get_unique_words([c[1] for c in scan_all])
all_symbols_scan = input_symbols_scan + output_symbols_scan
all_lang = Lang(all_symbols_scan)
ntoken = all_lang.n_symbols

# set up transformer encoder-decoder model, loss, optimizer
model = TransformerModel(ntoken=ntoken, emsize=args.emsize, nhead=args.nhead,
                         nhid=args.nhid, nlayers=args.nlayers,
                         dropout=args.dropout)
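# Hedged sketch (assumption, not shown in this excerpt): the loss and optimizer
#  mentioned in the comment above could be defined along these lines, masking
#  the padding index in the loss. PAD_token and args.lr are assumed names,
#  not confirmed by this file.
#
#   criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_token)
#   optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)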