Example #1
def scan_evaluation_dir_only(mytype, split, input_lang, output_lang):
    # Load an entire SCAN split as the query set.
    #  Use only the isolated turn directions as the support set.
    #
    # Input
    #  mytype : type of SCAN experiment
    #  split : 'train' or 'test'
    #  ... other inputs are language objects
    D_query = ge.load_scan_file(mytype, split)
    D_support = [('turn left', 'I_TURN_LEFT'), ('turn right', 'I_TURN_RIGHT')]
    random.shuffle(D_support)
    x_support = [d[0].split(' ') for d in D_support]
    y_support = [d[1].split(' ') for d in D_support]
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    return build_sample(x_support, y_support, x_query, y_query, input_lang,
                        output_lang, '')
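A minimal usage sketch for the evaluator above. The helper names (`ge`, `Lang`, `get_unique_words`) come from these snippets; the module they live in, the split name, and the setup are illustrative assumptions, not fixed by this file.
# Hypothetical setup; Lang/ge/get_unique_words follow the snippets in this file.
scan_all = ge.load_scan_file('all', 'train')
input_lang = Lang(get_unique_words([c[0] for c in scan_all]))
output_lang = Lang(get_unique_words([c[1] for c in scan_all]))
# Evaluate with the whole 'addprim_turn_left' test split as the query set;
# the support set is just the two turn-direction mappings hard-coded above.
sample = scan_evaluation_dir_only('addprim_turn_left', 'test', input_lang, output_lang)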
Example #2
def scan_evaluation_val_support(mytype, split, input_lang, output_lang,
                                samples_val):
    # Use the support sets of the pre-generated validation episodes.
    #  Replace each validation episode's query set with the rest of the SCAN split (e.g., the entire 'length' test set).
    #
    # Input
    #  mytype : type of SCAN experiment
    #  split : 'train' or 'test'
    #  ... other inputs are language objects
    #  samples_val : list of pre-generated validation episodes
    D_query = ge.load_scan_file(
        mytype, split)  # e.g., we can load in the entire "length" test set
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    for idx in range(len(samples_val)):
        samples = samples_val[idx]
        samples_val[idx] = build_sample(samples['xs'], samples['ys'],
                                        deepcopy(x_query), deepcopy(y_query),
                                        input_lang, output_lang, '')
    return samples_val
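A sketch of calling the helper above to rewrite pre-generated validation episodes; the 'xs'/'ys' keys match the loop body, while the toy episode contents and split name are assumptions.
toy_episode = {'xs': [['jump']], 'ys': [['I_JUMP']]}  # stand-in for one pre-generated episode
samples_val = scan_evaluation_val_support('length', 'test', input_lang,
                                          output_lang, [toy_episode])
# each returned episode now shares the full 'length' test split as its query set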
Example #3
def scan_evaluation_prim_only(mytype, split, input_lang, output_lang):
    # Load an entire SCAN split as the query set.
    #   Use the isolated primitives as the support set
    #
    # Input
    #  mytype : type of SCAN experiment
    #  split : 'train' or 'test'
    #  ... other inputs are language objects
    D_query = ge.load_scan_file(mytype, split)
    _, _, D_primitive = ge.sample_augment_scan(0, 0, [],
                                               shuffle=False,
                                               inc_support_in_query=False)
    D_support = D_primitive  # support set only includes the primitive mappings...
    random.shuffle(D_support)
    x_support = [d[0].split(' ') for d in D_support]
    y_support = [d[1].split(' ') for d in D_support]
    x_query = [d[0].split(' ') for d in D_query]
    y_query = [d[1].split(' ') for d in D_query]
    return build_sample(x_support, y_support, x_query, y_query, input_lang,
                        output_lang, '')
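Usage mirrors Example #1; a sketch assuming the same Lang objects and that 'addprim_jump' is the split under evaluation:
sample = scan_evaluation_prim_only('addprim_jump', 'test', input_lang, output_lang)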
Example #4
def get_episode_generator(episode_type):
    #  Returns a function that generates episodes,
    #   along with Language objects for the input and output languages
    #
    # Input
    #  episode_type : string specifying type of episode
    #
    # Output
    #  generate_episode: function handle for generating episodes
    #  input_lang: Language object for input sequence
    #  output_lang: Language object for output sequence
    input_symbols_list_default = [
        'dax', 'lug', 'wif', 'zup', 'fep', 'blicket', 'kiki', 'tufa', 'gazzer'
    ]
    output_symbols_list_default = [
        'RED', 'YELLOW', 'GREEN', 'BLUE', 'PURPLE', 'PINK'
    ]
    input_lang = Lang(input_symbols_list_default)
    output_lang = Lang(output_symbols_list_default)
    if episode_type == 'ME':  # NeurIPS Exp 1 : Mutual exclusivity
        input_lang = Lang(input_symbols_list_default[:4])
        output_lang = Lang(output_symbols_list_default[:4])
        generate_episode_train = lambda tabu_episodes: generate_ME(
            nquery=20,
            nprims=len(input_lang.symbols),
            input_lang=input_lang,
            output_lang=output_lang,
            tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train
    elif episode_type == 'scan_prim_permutation':  # NeurIPS Exp 2 : Adding a new primitive through permutation meta-training
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_var('all', 'train')
        input_symbols_scan = get_unique_words([c[0] for c in scan_all])
        output_symbols_scan = get_unique_words([c[1] for c in scan_all])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_prim_permutation(
            shuffle=True,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            nextra=0,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_prim_permutation(
            shuffle=False,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            nextra=0,
            tabu_list=tabu_episodes)
    elif episode_type == 'scan_prim_augmentation':  # NeurIPS Exp 3 : Adding a new primitive through augmentation meta-training
        nextra_prims = 20
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_var('all', 'train')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] +
            [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] +
            ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_prim_augmentation(
            shuffle=True,
            nextra=nextra_prims,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_prim_augmentation(
            shuffle=False,
            nextra=0,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            tabu_list=tabu_episodes)
    elif episode_type == 'scan_around_right':  # NeurIPS Exp 4 : Combining familiar concepts through meta-training
        nextra_prims = 2
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_dir_var('all', 'train')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] +
            [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] +
            ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_right_augmentation(
            shuffle=True,
            nextra=nextra_prims,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_right_augmentation(
            shuffle=False,
            nextra=0,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            tabu_list=tabu_episodes)
    elif episode_type == 'scan_length':  # NeurIPS Exp 5 : Generalizing to longer instructions through meta-training
        nextra_prims = 20  # number of additional primitives to augment the episodes with
        support_threshold = 12  # items with action length less than this belong in the support,
        # and greater than or equal to this length belong in the query
        scan_length_train = ge.load_scan_file('length', 'train')
        scan_length_test = ge.load_scan_file('length', 'test')
        scan_all = scan_length_train + scan_length_test
        scan_length_train_var = ge.load_scan_var('length', 'train')
        scan_length_test_var = ge.load_scan_var('length', 'test')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] +
            [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] +
            ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        scan_length_support_var = [
            pair for pair in scan_length_train_var
            if len(pair[1].split(' ')) < support_threshold
        ]  # partition based on number of output actions
        scan_length_query_var = [
            pair for pair in scan_length_train_var
            if len(pair[1].split(' ')) >= support_threshold
        ]  # long sequences
        generate_episode_train = lambda tabu_episodes: generate_length(
            shuffle=True,
            nextra=nextra_prims,
            nsupport=100,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_tuples_support_variable=scan_length_support_var,
            scan_tuples_query_variable=scan_length_query_var,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_length(
            shuffle=False,
            nextra=0,
            nsupport=100,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_tuples_support_variable=scan_length_train_var,
            scan_tuples_query_variable=scan_length_test_var,
            tabu_list=tabu_episodes)
    else:
        raise Exception("episode_type is not valid")
    return generate_episode_train, generate_episode_test, input_lang, output_lang
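A short usage sketch for this factory; passing an empty set as the tabu list is an assumption about its expected type, not something this snippet pins down.
generate_train, generate_test, input_lang, output_lang = get_episode_generator('ME')
episode = generate_train(set())  # assumed: tabu list starts empty on the first draw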
Example #5
def get_episode_generator(episode_type, model_in_lang=None, model_out_lang=None, model_prog_lang=None):
    #  Returns a function that generates episodes,
    #   along with Language objects for the input and output languages
    #
    # Input
    #  episode_type : string specifying type of episode
    #
    # Output
    #  generate_episode: function handle for generating episodes
    #  input_lang: Language object for input sequence
    #  output_lang: Language object for output sequence

    input_symbols_list_default = ['dax', 'lug', 'fep', 'blicket', 'kiki', 'tufa','gazzer', 'zup', 'wif'] #changed order for sorting
    output_symbols_list_default = ['RED', 'YELLOW', 'GREEN', 'BLUE', 'PURPLE', 'PINK', 'BLACK', 'WHITE']
    input_lang = Lang(input_symbols_list_default)
    output_lang = Lang(output_symbols_list_default)
    prog_symbols_list = input_symbols_list_default + output_symbols_list_default[:6] + ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]'] #TODO
    prog_lang = Lang(prog_symbols_list)

    if episode_type == 'rules_gen':
        
        input_lang = Lang(input_symbols_list_default + ['mup', 'dox', 'kleek'])  # default list has 9 symbols
        # output_lang defaults to 8 symbols, which works here

        prog_lang = Lang(input_lang.symbols + output_lang.symbols + ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]'])

        # what does it do to have unused query items?
        def generate_episode_train(tabu_episodes):
            nprims = random.choice((3, 4))
            nsupp = random.choice(range(10, 21))
            nrules = random.choice((2, 3, 4))
            return generate_rules_episode(nsupport=nsupp,
                                          nquery=10,
                                          nprims=nprims,
                                          nrules=nrules,
                                          input_lang=input_lang,
                                          output_lang=output_lang,
                                          prog_lang=prog_lang,
                                          tabu_list=tabu_episodes)

        generate_episode_test = generate_episode_train

    elif 'rules_sup_' in episode_type:
        nSupp = int(episode_type.split('_')[-1])
        input_lang = Lang(input_symbols_list_default + ['mup', 'dox', 'kleek'])  # default list has 9 symbols
        # output_lang defaults to 8 symbols, which works here
        prog_lang = Lang(input_lang.symbols + output_lang.symbols + ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]'])
        # what does it do to have unused query items?
        def generate_episode_train(tabu_episodes):
            nprims = random.choice((3, 4))
            nsupp = nSupp
            nrules = random.choice((2, 3, 4))
            return generate_rules_episode(nsupport=nsupp,
                                          nquery=10,
                                          nprims=nprims,
                                          nrules=nrules,
                                          input_lang=input_lang,
                                          output_lang=output_lang,
                                          prog_lang=prog_lang,
                                          tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train

    elif 'rules_horules_' in episode_type:
        nHO = int(episode_type.split('_')[-1])
        input_lang = Lang(input_symbols_list_default + ['mup', 'dox', 'kleek'])  # default list has 9 symbols
        # output_lang defaults to 8 symbols, which works here
        prog_lang = Lang(input_lang.symbols + output_lang.symbols + ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]'])
        # what does it do to have unused query items?
        def generate_episode_train(tabu_episodes):
            nprims = random.choice((3, 4))
            nsupp = 30  # random.choice(range(10, 21))
            nrules = nHO
            return generate_rules_episode(nsupport=nsupp,
                                          nquery=10,
                                          nprims=nprims,
                                          nrules=nrules,
                                          input_lang=input_lang,
                                          output_lang=output_lang,
                                          prog_lang=prog_lang,
                                          tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train

    elif 'rules_prims_' in episode_type:
        nPrims = int(episode_type.split('_')[-1])
        input_lang = Lang(input_symbols_list_default + ['mup', 'dox', 'kleek'])  # default list has 9 symbols
        # output_lang defaults to 8 symbols, which works here
        prog_lang = Lang(input_lang.symbols + output_lang.symbols + ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]'])
        # what does it do to have unused query items?
        def generate_episode_train(tabu_episodes):
            nprims = nPrims
            nsupp = 30  # random.choice(range(10, 21))
            nrules = random.choice((2, 3, 4))
            return generate_rules_episode(nsupport=nsupp,
                                          nquery=10,
                                          nprims=nprims,
                                          nrules=nrules,
                                          input_lang=input_lang,
                                          output_lang=output_lang,
                                          prog_lang=prog_lang,
                                          tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train

    elif 'lang_' in episode_type:
        lang = episode_type.split('_')[-1]
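        # the target language name is whatever follows the final underscore in episode_type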
        from number_generate_model import generate_lang_test_episode
        from number_word_interpret_grammar import RHS_DICT
        tokens = ['token'+format(i, '02d') for i in range(1, 52)]
        input_lang = Lang(tokens)
        output_lang = Lang([str(i) for i in range(10)] )
        prog_symbols = ['1000000*','10000*', '1000*', '100*', '10*', '[x1]*10', '[x1]*100', '[x1]*1000',  '[x1]*10000',  '[x1]*1000000', '[x1]', '[u1]','[y1]', 'x1', 'u1', 'y1', '->', '\n'] + [str(i) for i in range(10)]
        prog_lang = Lang(prog_symbols+input_lang.symbols)
        nsupp = 25
        nquery = 100
        def generate_episode_train(tabu_examples):
            return generate_lang_test_episode(nsupp,
                                nquery,
                                input_lang,
                                output_lang,
                                prog_lang, 
                                tabu_examples,
                                lang=lang)
        generate_episode_test = generate_episode_train

    elif episode_type == 'wordToNumber':
        from number_generate_model import generate_wordToNumber_episode
        from number_word_interpret_grammar import RHS_DICT
        tokens = ['token'+format(i, '02d') for i in range(1, 52)]
        input_lang = Lang(tokens)
        output_lang = Lang([str(i) for i in range(10)] )
        prog_symbols = ['1000000*','10000*', '1000*', '100*', '10*', '[x1]*10', '[x1]*100', '[x1]*1000',  '[x1]*10000',  '[x1]*1000000', '[x1]', '[u1]','[y1]', 'x1', 'u1', 'y1', '->', '\n'] + [str(i) for i in range(10)]
        prog_lang = Lang(prog_symbols+input_lang.symbols)
        
        def generate_episode_train(tabu_examples):
            nsupp = random.choice(range(60,101)) #should vary this ...
            nquery = 10
            return generate_wordToNumber_episode(nsupp,
                                                nquery,
                                                input_lang,
                                                output_lang,
                                                prog_lang, 
                                                tabu_examples)

        generate_episode_test = generate_episode_train

    elif episode_type == 'scan_random':
        words = ['walk', 'look', 'run', 'jump', 'turn', 'left', 'right', 'opposite', 'around', 'twice', 'thrice', 'and', 'after'] + ['dax', 'blicket', 'lug', 'kiki']
        cmds = ['WALK', 'LOOK', 'RUN', 'JUMP', 'LTURN', 'RTURN'] + ['RED', 'BLUE', 'GREEN']
        input_lang = Lang(words)
        output_lang = Lang(cmds)
        prog_lang = Lang(words + cmds + ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]', ''])  # optionally also '[', ']'

        tp = 'random'

        def generate_episode_train(tabu_episodes):
            nprims = random.choice(range(4,9))
            nsupp = random.choice(range(30,51)) 

            nurules = 0
            nxrules = random.choice((3,4,5,6,7))

            return generate_scan_episode(nsupport=nsupp,
                                        nquery=10,
                                        nprims=nprims,
                                        nurules=nurules,
                                        nxrules=nxrules,
                                        input_lang=input_lang,
                                        output_lang=output_lang, 
                                        prog_lang=prog_lang, 
                                        tabu_list=tabu_episodes, u_type=tp)      
        generate_episode_test = generate_episode_train
    
    elif episode_type in ['scan_simple_original', 'scan_jump_original', 'scan_around_right_original', 'scan_length_original']:

        dic = {'scan_simple_original': 'simple',
               'scan_jump_original': 'addprim_jump',
               'scan_around_right_original': 'template_around_right',
               'scan_length_original': 'length'}

        scan_train = ge.load_scan_file(dic[episode_type], 'train')
        scan_test = ge.load_scan_file(dic[episode_type], 'test')
        #assert 0, "deal with langs"
        # input_symbols_scan = get_unique_words([c[0] for c in scan_train+scan_test])
        # output_symbols_scan = get_unique_words([c[1] for c in scan_train+scan_test])
        # input_lang = Lang(input_symbols_scan)
        # output_lang = Lang(output_symbols_scan)
        words = ['walk', 'look', 'run', 'jump', 'turn', 'left', 'right', 'opposite', 'around', 'twice', 'thrice', 'and', 'after'] + ['dax', 'blicket', 'lug', 'kiki']
        cmds = ['I_WALK', 'I_LOOK', 'I_RUN', 'I_JUMP', 'I_TURN_LEFT', 'I_TURN_RIGHT'] + ['RED', 'BLUE', 'GREEN']
        # assert set(words) == set(get_unique_words([c[0] for c in scan_train+scan_test]))
        # assert set(cmds) == set(get_unique_words([c[1] for c in scan_train+scan_test]))
        print("WARNING: vocab includes extra words, so beware")
        input_lang = Lang(words)
        output_lang = Lang(cmds)
        prog_lang = Lang(words + cmds + ['->', '\n', 'x1', 'u1', '[x1]', '[u1]', 'x2', '[x2]', 'u2', '[u2]', ''])  # optionally also '[', ']'


        generate_episode_train = lambda tabu_episodes: generate_traditional_synth_scan_episode(
                                                                nsupport=100, 
                                                                nquery=500, 
                                                                input_lang=input_lang, 
                                                                output_lang=output_lang, 
                                                                train_tuples=scan_train, 
                                                                test_tuples=scan_test, 
                                                                tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_traditional_synth_scan_episode(
                                                                nsupport=100, 
                                                                nquery=500, 
                                                                input_lang=input_lang, 
                                                                output_lang=output_lang, 
                                                                train_tuples=scan_train, 
                                                                test_tuples=scan_test, 
                                                                tabu_list=tabu_episodes)

    else:
        raise Exception("episode_type is not valid")

    return generate_episode_train, generate_episode_test, input_lang, output_lang, prog_lang
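As above, a brief usage sketch; the 'rules_gen' episode type is taken from the branches of this function, and the empty tabu set is an assumption.
gen_train, gen_test, input_lang, output_lang, prog_lang = get_episode_generator('rules_gen')
episode = gen_train(set())  # assumed: tabu list is a set of episode identifiers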
Example #6
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1000,
                        help='evaluate model at this rate')
    parser.add_argument('--model-path',
                        type=str,
                        default='',
                        help='model checkpoint to start from (if any)')

    args = parser.parse_args()
    print(args)

    # Set the random seed manually for reproducibility
    torch.manual_seed(args.seed)

    scan_all = ge.load_scan_file('all', 'train')
    scan_all_var = ge.load_scan_var('all', 'train')

    input_symbols_scan = get_unique_words([c[0] for c in scan_all])
    output_symbols_scan = get_unique_words([c[1] for c in scan_all])

    all_symbols_scan = input_symbols_scan + output_symbols_scan
    all_lang = Lang(all_symbols_scan)
    ntoken = all_lang.n_symbols

    # set up transformer encoder-decoder model, loss, optimizer
    model = TransformerModel(ntoken=ntoken,
                             emsize=args.emsize,
                             nhead=args.nhead,
                             nhid=args.nhid,
                             nlayers=args.nlayers,