Example #1

import argparse

import torch

# Project-local dependencies whose import lines are not shown in this snippet:
# Lang, get_unique_words, TransformerModel, the generate_* episode builders,
# and the SCAN loaders referenced as `ge` are assumed to come from the
# surrounding codebase.

def get_episode_generator(episode_type):
    # Returns episode-generation functions, plus Language objects for the
    #  input and output languages.
    #
    # Input
    #  episode_type : string specifying the type of episode
    #
    # Output
    #  generate_episode_train : function handle for generating training episodes
    #  generate_episode_test : function handle for generating test episodes
    #  input_lang : Lang object for the input sequences
    #  output_lang : Lang object for the output sequences
    input_symbols_list_default = [
        'dax', 'lug', 'wif', 'zup', 'fep', 'blicket', 'kiki', 'tufa', 'gazzer'
    ]
    output_symbols_list_default = [
        'RED', 'YELLOW', 'GREEN', 'BLUE', 'PURPLE', 'PINK'
    ]
    input_lang = Lang(input_symbols_list_default)
    output_lang = Lang(output_symbols_list_default)
    if episode_type == 'ME':  # NeurIPS Exp 1 : Mutual exclusivity
        input_lang = Lang(input_symbols_list_default[:4])
        output_lang = Lang(output_symbols_list_default[:4])
        generate_episode_train = lambda tabu_episodes: generate_ME(
            nquery=20,
            nprims=len(input_lang.symbols),
            input_lang=input_lang,
            output_lang=output_lang,
            tabu_list=tabu_episodes)
        generate_episode_test = generate_episode_train
    elif episode_type == 'scan_prim_permutation':  # NeurIPS Exp 2 : Adding a new primitive through permutation meta-training
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_var('all', 'train')
        input_symbols_scan = get_unique_words([c[0] for c in scan_all])
        output_symbols_scan = get_unique_words([c[1] for c in scan_all])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_prim_permutation(
            shuffle=True,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            nextra=0,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_prim_permutation(
            shuffle=False,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            nextra=0,
            tabu_list=tabu_episodes)
    elif episode_type == 'scan_prim_augmentation':  # NeurIPS Exp 3 : Adding a new primitive through augmentation meta-training
        nextra_prims = 20
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_var('all', 'train')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] +
            [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] +
            ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_prim_augmentation(
            shuffle=True,
            nextra=nextra_prims,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_prim_augmentation(
            shuffle=False,
            nextra=0,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            tabu_list=tabu_episodes)
    elif episode_type == 'scan_around_right':  # NeurIPS Exp 4 : Combining familiar concepts through meta-training
        nextra_prims = 2
        scan_all = ge.load_scan_file('all', 'train')
        scan_all_var = ge.load_scan_dir_var('all', 'train')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] +
            [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] +
            ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        generate_episode_train = lambda tabu_episodes: generate_right_augmentation(
            shuffle=True,
            nextra=nextra_prims,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_right_augmentation(
            shuffle=False,
            nextra=0,
            nsupport=20,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_var_tuples=scan_all_var,
            tabu_list=tabu_episodes)
    elif episode_type == 'scan_length':  # NeurIPS Exp 5 : Generalizing to longer instructions through meta-training
        nextra_prims = 20  # number of additional primitives to augment the episodes with
        # items with fewer output actions than this threshold go in the
        # support set; items at least this long go in the query set
        support_threshold = 12
        scan_length_train = ge.load_scan_file('length', 'train')
        scan_length_test = ge.load_scan_file('length', 'test')
        scan_all = scan_length_train + scan_length_test
        scan_length_train_var = ge.load_scan_var('length', 'train')
        scan_length_test_var = ge.load_scan_var('length', 'test')
        input_symbols_scan = get_unique_words(
            [c[0] for c in scan_all] +
            [str(i) for i in range(1, nextra_prims + 1)])
        output_symbols_scan = get_unique_words(
            [c[1] for c in scan_all] +
            ['I_' + str(i) for i in range(1, nextra_prims + 1)])
        input_lang = Lang(input_symbols_scan)
        output_lang = Lang(output_symbols_scan)
        scan_length_support_var = [
            pair for pair in scan_length_train_var
            if len(pair[1].split(' ')) < support_threshold
        ]  # partition based on number of output actions
        scan_length_query_var = [
            pair for pair in scan_length_train_var
            if len(pair[1].split(' ')) >= support_threshold
        ]  # long sequences
        generate_episode_train = lambda tabu_episodes: generate_length(
            shuffle=True,
            nextra=nextra_prims,
            nsupport=100,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_tuples_support_variable=scan_length_support_var,
            scan_tuples_query_variable=scan_length_query_var,
            tabu_list=tabu_episodes)
        generate_episode_test = lambda tabu_episodes: generate_length(
            shuffle=False,
            nextra=0,
            nsupport=100,
            nquery=20,
            input_lang=input_lang,
            output_lang=output_lang,
            scan_tuples_support_variable=scan_length_train_var,
            scan_tuples_query_variable=scan_length_test_var,
            tabu_list=tabu_episodes)
    else:
        raise ValueError('episode_type is not valid: ' + str(episode_type))
    return generate_episode_train, generate_episode_test, input_lang, output_lang
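

# Usage sketch (illustrative, not from the original source). It shows how the
# returned handles might be driven: a test episode is drawn first and recorded
# in a tabu set so the training generator never reproduces it. The string form
# used as the tabu key is an assumption about the episode representation.
def _demo_episode_generation():
    generate_train, generate_test, input_lang, output_lang = \
        get_episode_generator('scan_prim_permutation')
    tabu_episodes = set()
    test_episode = generate_test(tabu_episodes)  # draw a held-out episode
    tabu_episodes.add(str(test_episode))  # assumed hashable key for the tabu list
    train_episode = generate_train(tabu_episodes)  # fresh training episode
    return train_episode, test_episode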
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # ... (the earlier arguments are missing from this snippet; later code
    #      reads args.seed, args.emsize, args.nhead, args.nhid, args.nlayers,
    #      and args.dropout, so flags for those presumably come before this
    #      point) ...
    parser.add_argument('--eval-interval',
                        # flag name was truncated in the snippet; this one is a guess
                        type=int,
                        default=1000,
                        help='evaluate model at this rate')
    parser.add_argument('--model-path',
                        type=str,
                        default='',
                        help='model checkpoint to start from (if any)')

    args = parser.parse_args()
    print(args)

    # Set the random seed manually for reproducibility
    torch.manual_seed(args.seed)

    scan_all = ge.load_scan_file('all', 'train')
    scan_all_var = ge.load_scan_var('all', 'train')

    input_symbols_scan = get_unique_words([c[0] for c in scan_all])
    output_symbols_scan = get_unique_words([c[1] for c in scan_all])

    all_symbols_scan = input_symbols_scan + output_symbols_scan
    all_lang = Lang(all_symbols_scan)
    ntoken = all_lang.n_symbols

    # set up transformer encoder-decoder model, loss, optimizer
    model = TransformerModel(ntoken=ntoken,
                             emsize=args.emsize,
                             nhead=args.nhead,
                             nhid=args.nhid,
                             nlayers=args.nlayers,
                             dropout=args.dropout)
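
    # A minimal sketch of the loss and optimizer set-up that the comment above
    # promises; these lines are missing from the snippet. CrossEntropyLoss and
    # Adam are standard choices for this model, and the learning rate is a
    # placeholder (the original presumably exposes it as a command-line flag).
    criterion = torch.nn.CrossEntropyLoss()  # token-level loss over output symbols
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # placeholder lr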