예제 #1
0
def generate_prim_augmentation(shuffle,nsupport,nquery,input_lang,output_lang,scan_var_tuples,nextra,tabu_list=None):
    # Generate a SCAN episode with primitive augmentation.
    #  The tabu list identifier is determined only by the assignment of the "jump"
    #  primitive ("jump -> <action>"), or 'no jump' when "jump" is not among the
    #  sampled primitives.
    #
    # Input
    #  shuffle: permute how the input primitives map to the output actions? (true/false)
    #  nsupport, nquery: forwarded to ge.sample_augment_scan (presumably the support/query set sizes)
    #  input_lang, output_lang: language objects forwarded to build_sample
    #  scan_var_tuples : scan input/output patterns with placeholder replacement
    #  nextra: number of abstract input/output primitives to add to the set of possibilities
    #  tabu_list: identifiers that must not be produced when shuffling. None (the default)
    #    means no restriction; a None sentinel replaces the old shared mutable [] default.
    #
    # Output
    #  the result of build_sample(...) for the accepted episode, tagged with its identifier
    #
    # Raises
    #  RuntimeError (an Exception subclass) if more than max_try_novel resampled episodes
    #  all land on the tabu list
    if tabu_list is None:
        tabu_list = ()
    special_prim = 'jump'
    count = 0
    while True:
        D_support, D_query, D_primitive = ge.sample_augment_scan(nsupport,nquery,scan_var_tuples,shuffle,nextra,inc_support_in_query=use_resconstruct_loss)
        input_prim_list = [s[0] for s in D_primitive]
        # Only the lookup can raise ValueError (when "jump" was not sampled).
        try:
            index_prim = input_prim_list.index(special_prim)
        except ValueError:
            identifier = 'no jump'
        else:
            identifier = D_primitive[index_prim][0] + ' -> ' + D_primitive[index_prim][1]
        # Ignore the tabu list if we aren't shuffling primitive assignments.
        if not shuffle or identifier not in tabu_list:
            break
        count += 1
        if count > max_try_novel:
            raise RuntimeError('We were unable to generate an episode that is not on the tabu list')
    # Each pair is (input string, output string); tokenize on single spaces.
    x_support = [d[0].split(' ') for d in D_support]
    y_support = [d[1].split(' ') for d in D_support]
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    return build_sample(x_support,y_support,x_query,y_query,input_lang,output_lang,identifier)
예제 #2
0
def generate_prim_permutation(shuffle,nsupport,nquery,input_lang,output_lang,scan_var_tuples,nextra,tabu_list=None):
    # Generate a SCAN episode with primitive permutation.
    #  The tabu list identifier is based on the permutation of primitive inputs to
    #  primitive actions: every "input -> output" line of D_primitive, joined with
    #  newlines and passed through make_hashable (defined elsewhere in this file).
    #
    # Input
    #  shuffle: permute how the input primitives map to the output actions? (true/false)
    #  nsupport, nquery: forwarded to ge.sample_augment_scan (presumably the support/query set sizes)
    #  input_lang, output_lang: language objects forwarded to build_sample
    #  scan_var_tuples : scan input/output sequences with placeholder replacement
    #  nextra: number of abstract input/output primitives to add to the set of possibilities
    #  tabu_list: identifiers that must not be produced when shuffling. None (the default)
    #    means no restriction; a None sentinel replaces the old shared mutable [] default.
    #
    # Output
    #  the result of build_sample(...) for the accepted episode, tagged with its identifier
    #
    # Raises
    #  RuntimeError (an Exception subclass) if more than max_try_novel resampled episodes
    #  all land on the tabu list
    if tabu_list is None:
        tabu_list = ()
    count = 0
    while True:
        D_support, D_query, D_primitive = ge.sample_augment_scan(nsupport,nquery,scan_var_tuples,shuffle,nextra,inc_support_in_query=use_resconstruct_loss)
        D_str = '\n'.join(s[0] + ' -> ' + s[1] for s in D_primitive)
        identifier = make_hashable(D_str)
        # Ignore the tabu list if we aren't shuffling primitive assignments.
        if not shuffle or identifier not in tabu_list:
            break
        count += 1
        if count > max_try_novel:
            raise RuntimeError('We were unable to generate an episode that is not on the tabu list')
    # Each pair is (input string, output string); tokenize on single spaces.
    x_support = [d[0].split(' ') for d in D_support]
    y_support = [d[1].split(' ') for d in D_support]
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    return build_sample(x_support,y_support,x_query,y_query,input_lang,output_lang,identifier)
예제 #3
0
def scan_evaluation_prim_only(mytype, split, input_lang, output_lang):
    # Build a single evaluation episode over an entire SCAN split.
    #   Query set: every example returned by ge.load_scan_file(mytype, split).
    #   Support set: only the primitive mappings from ge.sample_augment_scan
    #   (requested with zero support/query samples, no shuffling), in random order.
    #
    # Input
    #  mytype : type of SCAN experiment
    #  split : 'train' or 'test'
    #  input_lang, output_lang : language objects forwarded to build_sample
    #
    # Output
    #  the result of build_sample(...), with an empty identifier ''
    query_pairs = ge.load_scan_file(mytype, split)
    _unused_support, _unused_query, primitive_pairs = ge.sample_augment_scan(
        0, 0, [], shuffle=False, inc_support_in_query=False)
    # In-place shuffle; the primitive list itself becomes the support set.
    random.shuffle(primitive_pairs)

    def tokenize(pairs, side):
        # side 0 = input string, side 1 = output string; split on single spaces
        return [pair[side].split(' ') for pair in pairs]

    return build_sample(tokenize(primitive_pairs, 0),
                        tokenize(primitive_pairs, 1),
                        tokenize(query_pairs, 0),
                        tokenize(query_pairs, 1),
                        input_lang, output_lang, '')