def generate_prim_augmentation(shuffle, nsupport, nquery, input_lang, output_lang, scan_var_tuples, nextra, tabu_list=None):
    # Generate a SCAN episode with primitive augmentation.
    # The tabu list identifier is only determined based on the assignment of the "jump" primitive.
    #
    # Input
    #  shuffle: permute how the input primitives map to the output actions? (true/false)
    #  nsupport / nquery: number of support and query examples to sample
    #  input_lang / output_lang: language objects for the input and output vocabularies
    #  scan_var_tuples : scan input/output patterns with placeholder replacement
    #  nextra: number of abstract input/output primitives to add to the set of possibilities
    #  tabu_list: episode identifiers that must not be regenerated (consulted only when shuffle=True)
    #
    # Raises an Exception after more than max_try_novel failed attempts to avoid the tabu list.
    if tabu_list is None:  # avoid a mutable default argument shared across calls
        tabu_list = []
    special_prim = 'jump'
    count = 0
    while True:
        D_support, D_query, D_primitive = ge.sample_augment_scan(nsupport, nquery, scan_var_tuples, shuffle, nextra, inc_support_in_query=use_resconstruct_loss)
        input_prim_list = [s[0] for s in D_primitive]
        # The identifier is the "jump" primitive's assignment (or 'no jump' if absent),
        # so the tabu check only distinguishes episodes by where "jump" maps.
        try:
            index_prim = input_prim_list.index(special_prim)
            D_str = D_primitive[index_prim][0] + ' -> ' + D_primitive[index_prim][1]
        except ValueError:
            D_str = 'no jump'
        identifier = D_str
        if not shuffle:  # ignore tabu list if we aren't shuffling primitive assignments
            break
        if identifier not in tabu_list:
            break
        count += 1
        if count > max_try_novel:
            raise Exception('We were unable to generate an episode that is not on the tabu list')
    # Tokenize each pattern on spaces before building the episode sample.
    x_support = [d[0].split(' ') for d in D_support]
    y_support = [d[1].split(' ') for d in D_support]
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    return build_sample(x_support, y_support, x_query, y_query, input_lang, output_lang, identifier)
def generate_prim_permutation(shuffle, nsupport, nquery, input_lang, output_lang, scan_var_tuples, nextra, tabu_list=None):
    # Generate a SCAN episode with primitive permutation.
    # The tabu list identifier is based on the permutation of primitive inputs to primitive actions.
    #
    # Input
    #  shuffle: permute how the input primitives map to the output actions? (true/false)
    #  nsupport / nquery: number of support and query examples to sample
    #  input_lang / output_lang: language objects for the input and output vocabularies
    #  scan_var_tuples : scan input/output sequences with placeholder replacement
    #  nextra: number of abstract input/output primitives to add to the set of possibilities
    #  tabu_list: episode identifiers that must not be regenerated (consulted only when shuffle=True)
    #
    # Raises an Exception after more than max_try_novel failed attempts to avoid the tabu list.
    if tabu_list is None:  # avoid a mutable default argument shared across calls
        tabu_list = []
    count = 0
    while True:
        D_support, D_query, D_primitive = ge.sample_augment_scan(nsupport, nquery, scan_var_tuples, shuffle, nextra, inc_support_in_query=use_resconstruct_loss)
        # Identifier encodes the FULL primitive assignment, unlike the augmentation
        # variant which keys only on "jump".
        D_str = '\n'.join([s[0] + ' -> ' + s[1] for s in D_primitive])
        identifier = make_hashable(D_str)
        if not shuffle:  # ignore tabu list if we aren't shuffling primitive assignments
            break
        if identifier not in tabu_list:
            break
        count += 1
        if count > max_try_novel:
            raise Exception('We were unable to generate an episode that is not on the tabu list')
    # Tokenize each pattern on spaces before building the episode sample.
    x_support = [d[0].split(' ') for d in D_support]
    y_support = [d[1].split(' ') for d in D_support]
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    return build_sample(x_support, y_support, x_query, y_query, input_lang, output_lang, identifier)
def scan_evaluation_prim_only(mytype, split, input_lang, output_lang):
    # Load an entire SCAN split as the query set.
    # Use the isolated primitives as the support set.
    #
    # Input
    #  mytype : type of SCAN experiment
    #  split : 'train' or 'test'
    #  ... other inputs are language objects
    D_query = ge.load_scan_file(mytype, split)
    _, _, D_primitive = ge.sample_augment_scan(0, 0, [], shuffle=False, inc_support_in_query=False)
    # Support set only includes the primitive mappings. Copy before shuffling so
    # we don't mutate the list returned by ge.sample_augment_scan in place.
    D_support = list(D_primitive)
    random.shuffle(D_support)
    # Tokenize each pattern on spaces before building the episode sample.
    x_support = [d[0].split(' ') for d in D_support]
    y_support = [d[1].split(' ') for d in D_support]
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    return build_sample(x_support, y_support, x_query, y_query, input_lang, output_lang, '')