def scan_evaluation_dir_only(mytype, split, input_lang, output_lang):
    """Build one evaluation episode whose support set is only the two turn directions.

    The query set is every example in the SCAN file selected by
    ``mytype``/``split``. The support set is fixed to the isolated
    'turn left' / 'turn right' commands, in random order.

    Input
      mytype : type of SCAN experiment
      split : 'train' or 'test'
      input_lang, output_lang : language objects forwarded to build_sample

    Returns whatever build_sample returns for these support/query token lists.
    """
    query_pairs = ge.load_scan_file(mytype, split)

    direction_pairs = [('turn left', 'I_TURN_LEFT'), ('turn right', 'I_TURN_RIGHT')]
    random.shuffle(direction_pairs)

    def column(pairs, k):
        # Tokenize the k-th field (0 = command, 1 = action sequence) of each pair.
        return [pair[k].split(' ') for pair in pairs]

    return build_sample(
        column(direction_pairs, 0),
        column(direction_pairs, 1),
        column(query_pairs, 0),
        column(query_pairs, 1),
        input_lang,
        output_lang,
        '',
    )
def scan_evaluation_val_support(mytype, split, input_lang, output_lang, samples_val):
    """Re-target pre-generated validation episodes at an entire SCAN split.

    Each validation episode keeps its own pre-generated support set (its
    'xs'/'ys' entries). Its query set is replaced by every example in the
    SCAN file selected by ``mytype``/``split`` (e.g., the entire "length"
    test set).

    Input
      mytype : type of SCAN experiment
      split : 'train' or 'test'
      input_lang, output_lang : language objects forwarded to build_sample
      samples_val : list of pre-generated validation episodes; each entry must
        support ``['xs']`` and ``['ys']`` lookups. The list is modified IN
        PLACE: every entry is overwritten with the rebuilt episode.

    Returns
      samples_val, the same list object, now holding the rebuilt episodes.
    """
    D_query = ge.load_scan_file(mytype, split)  # e.g., the entire "length" test set
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]

    for idx, samples in enumerate(samples_val):
        # Each episode receives its own copy of the query token lists so that
        # episodes never share list objects (presumably because build_sample
        # may modify its inputs -- TODO confirm).
        samples_val[idx] = build_sample(
            samples['xs'],
            samples['ys'],
            deepcopy(x_query),
            deepcopy(y_query),
            input_lang,
            output_lang,
            '',
        )
    return samples_val
def scan_evaluation_prim_only(mytype, split, input_lang, output_lang):
    """Build one evaluation episode: isolated primitives as support, a full SCAN split as query.

    Input
      mytype : type of SCAN experiment
      split : 'train' or 'test'
      input_lang, output_lang : language objects forwarded to build_sample

    The support set consists only of the primitive mappings, which are taken
    from the third value returned by
    ``ge.sample_augment_scan(0, 0, [], shuffle=False, inc_support_in_query=False)``.
    That list is shuffled in place with random.shuffle. The query set is every
    example in the SCAN file selected by ``mytype``/``split``.

    Returns whatever build_sample returns for these support/query token lists.
    """
    query_pairs = ge.load_scan_file(mytype, split)

    _, _, primitive_pairs = ge.sample_augment_scan(
        0, 0, [], shuffle=False, inc_support_in_query=False)
    # NOTE(review): this reorders the list object handed back by
    # sample_augment_scan itself; harmless if it is freshly built -- confirm.
    random.shuffle(primitive_pairs)

    def column(pairs, k):
        # Tokenize the k-th field (0 = command, 1 = action sequence) of each pair.
        return [pair[k].split(' ') for pair in pairs]

    return build_sample(
        column(primitive_pairs, 0),
        column(primitive_pairs, 1),
        column(query_pairs, 0),
        column(query_pairs, 1),
        input_lang,
        output_lang,
        '',
    )