Example #1
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        n_src_vcb, n_trg_vcb))

    # wv = KeyedVectors.load('word_vector_en', mmap='r')
    # voc = list(wv.vocab)
    # weight = tc.zeros(n_trg_vcb, 100)
    # for i, k in trg_vocab.idx2key.items():
    #     if k not in wv:
    #         continue
    #     weight[i, :] = tc.Tensor(wv[k])
    # dic = {
    #     'w2v': weight
    # }
    # tc.save(dic, '../zh_en_w2v_embedding.pt')
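    # A minimal sketch (kept commented out, like the block above) of how the
    # saved tensor could be read back once trg_emb is built below; `trg_emb.we`
    # being an nn.Embedding and its width being at least 100 are assumptions:
    # w2v = tc.load('../zh_en_w2v_embedding.pt')['w2v']   # shape (n_trg_vcb, 100)
    # trg_emb.we.weight.data[:, :w2v.size(1)].copy_(w2v)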

    model_dict, e_idx, e_bidx, n_steps, optim = load_model(model_file)
    from models.embedding import WordEmbedding

    src_emb = WordEmbedding(n_src_vcb,
                            wargs.d_src_emb,
                            position_encoding=wargs.position_encoding,
                            prefix='Src')
    trg_emb = WordEmbedding(n_trg_vcb,
                            wargs.d_trg_emb,
                            position_encoding=wargs.position_encoding,
                            prefix='Trg')
    from models.model_builder import build_NMT

    nmtModel = build_NMT(src_emb, trg_emb)
    if args.gpu_ids is not None:
        wlog('push model onto GPU {} ... '.format(args.gpu_ids[0]), 0)
Example #2
def main():

    #if wargs.ss_type is not None: assert wargs.model == 1, 'Only rnnsearch support schedule sample'
    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    src = os.path.join(wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_src_suffix))
    trg = os.path.join(wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_trg_suffix))
    vocabs = {}
    wlog('\nPreparing source vocabulary from {} ... '.format(src))
    src_vocab = extract_vocab(src, wargs.src_vcb, wargs.n_src_vcb_plan,
                              wargs.max_seq_len, char=wargs.src_char)
    wlog('\nPreparing target vocabulary from {} ... '.format(trg))
    trg_vocab = extract_vocab(trg, wargs.trg_vcb, wargs.n_trg_vcb_plan, wargs.max_seq_len)
    n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(n_src_vcb, n_trg_vcb))
    vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab

    wlog('\nPreparing training set from {} and {} ... '.format(src, trg))
    trains = {}
    train_src_tlst, train_trg_tlst = wrap_data(wargs.dir_data, wargs.train_prefix,
                                               wargs.train_src_suffix, wargs.train_trg_suffix,
                                               src_vocab, trg_vocab, shuffle=True,
                                               sort_k_batches=wargs.sort_k_batches,
                                               max_seq_len=wargs.max_seq_len,
                                               char=wargs.src_char)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size,
                        batch_type=wargs.batch_type, bow=wargs.trg_bow, batch_sort=False)
    wlog('Sentence-pairs count in training data: {}'.format(len(train_src_tlst)))

    batch_valid = None
    if wargs.val_prefix is not None:
        val_src_file = os.path.join(wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix, wargs.val_src_suffix))
        val_trg_file = os.path.join(wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix, wargs.val_ref_suffix))
        wlog('\nPreparing validation set from {} and {} ... '.format(val_src_file, val_trg_file))
        valid_src_tlst, valid_trg_tlst = wrap_data(wargs.val_tst_dir, wargs.val_prefix,
                                                   wargs.val_src_suffix, wargs.val_ref_suffix,
                                                   src_vocab, trg_vocab, shuffle=False,
                                                   max_seq_len=wargs.dev_max_seq_len,
                                                   char=wargs.src_char)
        batch_valid = Input(valid_src_tlst, valid_trg_tlst, 1, batch_sort=False)

    batch_tests = None
    if wargs.tests_prefix is not None:
        assert isinstance(wargs.tests_prefix, list), 'Test files should be a list.'
        init_dir(wargs.dir_tests)
        batch_tests = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix, wargs.val_src_suffix)
            wlog('\nPreparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = wrap_tst_data(test_file, src_vocab, char=wargs.src_char)
            batch_tests[prefix] = Input(test_src_tlst, None, 1, batch_sort=False)
    wlog('\n## Finished preparing the dataset ##\n')

    src_emb = WordEmbedding(n_src_vcb, wargs.d_src_emb, wargs.input_dropout,
                            wargs.position_encoding, prefix='Src')
    trg_emb = WordEmbedding(n_trg_vcb, wargs.d_trg_emb, wargs.input_dropout,
                            wargs.position_encoding, prefix='Trg')
    # share the embedding matrix - preprocess with share_vocab required.
    if wargs.embs_share_weight:
        if n_src_vcb != n_trg_vcb:
            raise AssertionError('The `-share_vocab` should be set during '
                                 'preprocess if you use share_embeddings!')
        src_emb.we.weight = trg_emb.we.weight
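        # (sketch) after the assignment above both modules hold the same
        # Parameter object, so `src_emb.we.weight is trg_emb.we.weight` holds
        # and gradients from both sides update a single embedding matrix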

    nmtModel = build_NMT(src_emb, trg_emb)

    if not wargs.copy_attn:
        classifier = Classifier(wargs.d_model if wargs.decoder_type == 'att' else 2 * wargs.d_enc_hid,
                                n_trg_vcb, trg_emb, loss_norm=wargs.loss_norm,
                                label_smoothing=wargs.label_smoothing,
                                emb_loss=wargs.emb_loss, bow_loss=wargs.bow_loss)
        # attach the classifier inside the branch that builds it; with copy
        # attention enabled `classifier` would otherwise be undefined here
        nmtModel.decoder.classifier = classifier

    if wargs.gpu_id is not None:
        wlog('push model onto GPU {} ... '.format(wargs.gpu_id), 0)
        #nmtModel = nn.DataParallel(nmtModel, device_ids=wargs.gpu_id)
        nmtModel.to(tc.device('cuda'))
    else:
        wlog('push model onto CPU ... ', 0)
        nmtModel.to(tc.device('cpu'))
    wlog('done.')

    if wargs.pre_train is not None:
        assert os.path.exists(wargs.pre_train)
        from tools.utils import load_model
        _dict = load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
        elif len(_dict) == 4:
            model_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            else: init_params(param, name, init_D=wargs.param_init_D, a=float(wargs.u_gain))

        wargs.start_epoch = eid + 1

    else:
        optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm)
        #for n, p in nmtModel.named_parameters():
            # bias can not be initialized uniformly
            #if wargs.encoder_type != 'att' and wargs.decoder_type != 'att':
            #    init_params(p, n, init_D=wargs.param_init_D, a=float(wargs.u_gain))

    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('parameters number: {}/{}'.format(pcnt1, pcnt2))

    wlog('\n' + '*' * 30 + ' trainable parameters ' + '*' * 30)
    for n, p in nmtModel.named_parameters():
        if p.requires_grad: wlog('{:60} : {}'.format(n, p.size()))

    optim.init_optimizer(nmtModel.parameters())

    trainer = Trainer(nmtModel, batch_train, vocabs, optim, batch_valid, batch_tests)

    trainer.train()
Example #3
        logging.info('latency_cpu (batch 1): %.2fms' % latency_cpu)
        latency_gpu = utils.latency_measure(model, (3, 224, 224),
                                            32,
                                            5000,
                                            mode='gpu')
        logging.info('latency_gpu (batch 32): %.2fms' % latency_gpu)
    params = utils.count_parameters_in_MB(model)
    logging.info("Params = %.2fMB" % params)
    mult_adds = comp_multadds(model, input_size=config.data.input_size)
    logging.info("Mult-Adds = %.2fMB" % mult_adds)

    model = nn.DataParallel(model)

    # whether to resume from a checkpoint
    if config.optim.if_resume:
        utils.load_model(model, config.optim.resume.load_path)
        start_epoch = config.optim.resume.load_epoch + 1
    else:
        start_epoch = 0

    model = model.cuda()

    if config.optim.label_smooth:
        criterion = utils.cross_entropy_with_label_smoothing
    else:
        criterion = nn.CrossEntropyLoss()
        criterion = criterion.cuda()
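    # (sketch, not utils' implementation) label smoothing mixes the one-hot
    # target with a uniform distribution before the cross entropy, e.g.:
    #   smooth = torch.full((n_classes,), eps / n_classes)
    #   smooth[target] += 1.0 - eps
    #   loss = -(smooth * F.log_softmax(logits, dim=-1)).sum()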

    optimizer = torch.optim.SGD(model.parameters(),
                                config.optim.init_lr,
                                momentum=config.optim.momentum,
Example #4
def main():
    # if wargs.ss_type is not None: assert wargs.model == 1, 'Only rnnsearch support schedule sample'
    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    src = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_src_suffix))
    trg = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_trg_suffix))
    src, trg = os.path.abspath(src), os.path.abspath(trg)
    vocabs = {}
    if wargs.share_vocab is False:
        wlog('\nPreparing source vocabulary from {} ... '.format(src))
        src_vocab = extract_vocab(src,
                                  wargs.src_vcb,
                                  wargs.n_src_vcb_plan,
                                  wargs.max_seq_len,
                                  char=wargs.src_char)
        wlog('\nPreparing target vocabulary from {} ... '.format(trg))
        trg_vocab = extract_vocab(trg, wargs.trg_vcb, wargs.n_trg_vcb_plan,
                                  wargs.max_seq_len)
        n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size()
        wlog('Vocabulary size: |source|={}, |target|={}'.format(
            n_src_vcb, n_trg_vcb))
    else:
        wlog('\nPreparing the shared vocabulary from \n\t{}\n\t{}'.format(
            src, trg))
        trg_vocab = src_vocab = extract_vocab(src,
                                              wargs.src_vcb,
                                              wargs.n_src_vcb_plan,
                                              wargs.max_seq_len,
                                              share_vocab=True,
                                              trg_file=trg)
        n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size()
        wlog('Shared vocabulary size: |vocab|={}'.format(src_vocab.size()))

    vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab

    wlog('\nPreparing training set from {} and {} ... '.format(src, trg))
    trains = {}
    train_src_tlst, train_trg_tlst = wrap_data(
        wargs.dir_data,
        wargs.train_prefix,
        wargs.train_src_suffix,
        wargs.train_trg_suffix,
        src_vocab,
        trg_vocab,
        shuffle=True,
        sort_k_batches=wargs.sort_k_batches,
        max_seq_len=wargs.max_seq_len,
        char=wargs.src_char)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    batch_train = Input(train_src_tlst,
                        train_trg_tlst,
                        wargs.batch_size,
                        batch_type=wargs.batch_type,
                        bow=wargs.trg_bow,
                        batch_sort=False,
                        gpu_ids=device_ids)
    wlog('Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst)))

    batch_valid = None
    if wargs.val_prefix is not None:
        val_src_file = os.path.join(
            wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix,
                                              wargs.val_src_suffix))
        val_trg_file = os.path.join(
            wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix,
                                              wargs.val_ref_suffix))
        val_src_file, val_trg_file = os.path.abspath(
            val_src_file), os.path.abspath(val_trg_file)
        wlog('\nPreparing validation set from {} and {} ... '.format(
            val_src_file, val_trg_file))
        valid_src_tlst, valid_trg_tlst = wrap_data(
            wargs.val_tst_dir,
            wargs.val_prefix,
            wargs.val_src_suffix,
            wargs.val_ref_suffix,
            src_vocab,
            trg_vocab,
            shuffle=False,
            max_seq_len=wargs.dev_max_seq_len,
            char=wargs.src_char)
        batch_valid = Input(valid_src_tlst,
                            valid_trg_tlst,
                            batch_size=wargs.valid_batch_size,
                            batch_sort=False,
                            gpu_ids=device_ids)

    batch_tests = None
    if wargs.tests_prefix is not None:
        assert isinstance(wargs.tests_prefix,
                          list), 'Test files should be a list.'
        init_dir(wargs.dir_tests)
        batch_tests = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            test_file = os.path.abspath(test_file)
            wlog('\nPreparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = wrap_tst_data(test_file,
                                             src_vocab,
                                             char=wargs.src_char)
            batch_tests[prefix] = Input(test_src_tlst,
                                        None,
                                        batch_size=wargs.test_batch_size,
                                        batch_sort=False,
                                        gpu_ids=device_ids)
    wlog('\n## Finished preparing the dataset ##\n')

    src_emb = WordEmbedding(n_src_vcb,
                            wargs.d_src_emb,
                            wargs.input_dropout,
                            wargs.position_encoding,
                            prefix='Src')
    trg_emb = WordEmbedding(n_trg_vcb,
                            wargs.d_trg_emb,
                            wargs.input_dropout,
                            wargs.position_encoding,
                            prefix='Trg')
    # share the embedding matrix between the source and target
    if wargs.share_vocab is True: src_emb.we.weight = trg_emb.we.weight

    nmtModel = build_NMT(src_emb, trg_emb)

    if device_ids is not None:
        wlog('push model onto GPU {} ... '.format(device_ids[0]), 0)
        nmtModel_par = nn.DataParallel(nmtModel, device_ids=device_ids)
        nmtModel_par.to(device)
    else:
        wlog('push model onto CPU ... ', 0)
        nmtModel.to(tc.device('cpu'))
        # keep a handle under the name the trainer expects below
        nmtModel_par = nmtModel
    wlog('done.')

    if wargs.pre_train is not None:
        wlog(wargs.pre_train)
        assert os.path.exists(wargs.pre_train)
        from tools.utils import load_model
        _dict = load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 5:
            # model_dict, e_idx, e_bidx, n_steps, optim = _dict['model'], _dict['epoch'], _dict['batch'], _dict['steps'], _dict['optim']
            model_dict, e_idx, e_bidx, n_steps, optim = _dict
        elif len(_dict) == 4:
            model_dict, e_idx, e_bidx, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                # wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    # wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    # wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            else:
                init_params(param,
                            name,
                            init_D=wargs.param_init_D,
                            a=float(wargs.u_gain))

        # wargs.start_epoch = e_idx + 1
        # # do not restart the step counter
        # optim.n_current_steps = 0

    else:
        optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm)
        for n, p in nmtModel.named_parameters():
            # bias can not be initialized uniformly
            if 'norm' in n:
                wlog('ignore layer norm init ...')
                continue
            if 'emb' in n:
                wlog('ignore word embedding weight init ...')
                continue
            if 'vcb_proj' in n:
                wlog('ignore vcb_proj weight init ...')
                continue
            init_params(p, n, init_D=wargs.param_init_D, a=float(wargs.u_gain))
            # if wargs.encoder_type != 'att' and wargs.decoder_type != 'att':
            #    init_params(p, n, init_D=wargs.param_init_D, a=float(wargs.u_gain))

    # wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('parameters number: {}/{}'.format(pcnt1, pcnt2))

    # wlog('\n' + '*' * 30 + ' trainable parameters ' + '*' * 30)
    # for n, p in nmtModel.named_parameters():
    #     if p.requires_grad: wlog('{:60} : {}'.format(n, p.size()))
    opt_state = None
    if wargs.pre_train:
        opt_state = optim.optimizer.state_dict()

    if wargs.use_reinfore_ce is False:
        criterion = LabelSmoothingCriterion(
            trg_emb.n_vocab, label_smoothing=wargs.label_smoothing)
    else:
        word2vec = tc.load(wargs.word2vec_weight)['w2v']
        # criterion = Word2VecDistanceCriterion(word2vec)
        criterion = CosineDistance(word2vec)
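        # (sketch of the idea, not the repo's CosineDistance class) such a
        # criterion scores predictions by cosine similarity in word2vec space:
        #   sim = F.cosine_similarity(word2vec[pred_ids], word2vec[gold_ids], dim=-1)
        #   loss = (1.0 - sim).mean()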

    if device_ids is not None:
        wlog('push criterion onto GPU {} ... '.format(device_ids[0]), 0)
        criterion = criterion.to(device)
        wlog('done.')
    # if wargs.reinfore_type == 0 or wargs.reinfore_type == 1:
    #     param = list(nmtModel.parameters())
    # else:
    #     param = list(nmtModel.parameters()) + list(criterion.parameters())
    param = list(nmtModel.parameters())
    optim.init_optimizer(param)

    lossCompute = MultiGPULossCompute(
        nmtModel.generator,
        criterion,
        wargs.d_model if wargs.decoder_type == 'att' else 2 * wargs.d_enc_hid,
        n_trg_vcb,
        trg_emb,
        nmtModel.bowMapper,
        loss_norm=wargs.loss_norm,
        chunk_size=wargs.chunk_size,
        device_ids=device_ids)

    trainer = Trainer(nmtModel_par, batch_train, vocabs, optim, lossCompute,
                      nmtModel, batch_valid, batch_tests, writer)

    trainer.train()
    writer.close()
Example #5
    cudnn.benchmark = True
    cudnn.enabled = True
    
    logging.info("args = %s", args)
    logging.info('Training with config:')
    logging.info(pprint.pformat(config))

    config.net_config, net_type = utils.load_net_config(os.path.join(args.load_path, 'net_config'))

    derivedNetwork = getattr(model_derived, '%s_Net' % net_type.upper())
    model = derivedNetwork(config.net_config, config=config)
    
    logging.info("Network Structure: \n" + '\n'.join(map(str, model.net_config)))
    logging.info("Params = %.2fMB" % utils.count_parameters_in_MB(model))
    logging.info("Mult-Adds = %.2fMB" % comp_multadds(model, input_size=config.data.input_size))

    model = model.cuda()
    model = nn.DataParallel(model)
    utils.load_model(model, os.path.join(args.load_path, 'weights.pt'))

    imagenet = imagenet_data.ImageNet12(trainFolder=os.path.join(args.data_path, 'train'),
                            testFolder=os.path.join(args.data_path, 'val'),
                            num_workers=config.data.num_workers,
                            data_config=config.data)
    valid_queue = imagenet.getTestLoader(config.data.batch_size)
    trainer = Trainer(None, valid_queue, None, None,
                      None, config, args.report_freq)

    with torch.no_grad():
        val_acc_top1, val_acc_top5, valid_obj, batch_time = trainer.infer(model)
Example #6
    mergeWay = args.merge_way
    avgAtt = args.avg_att
    m_threshold = args.m_threshold
    switchs = [useBatch, vocabNorm, lenNorm, useMv, mergeWay, avgAtt]
    '''

    wlog('Starting load vocabularies ... ')
    assert os.path.exists(wargs.src_vcb) and os.path.exists(
        wargs.trg_vcb), 'need vocabulary ...'
    src_vocab = extract_vocab(None, wargs.src_vcb)
    trg_vocab = extract_vocab(None, wargs.trg_vcb)
    n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        n_src_vcb, n_trg_vcb))

    _dict = load_model(model_file)
    if len(_dict) == 4: model_dict, eid, bid, optim = _dict
    elif len(_dict) == 5: model_dict, class_dict, eid, bid, optim = _dict
    from models.embedding import WordEmbedding
    src_emb = WordEmbedding(n_src_vcb,
                            wargs.d_src_emb,
                            wargs.position_encoding,
                            prefix='Src')
    trg_emb = WordEmbedding(n_trg_vcb,
                            wargs.d_trg_emb,
                            wargs.position_encoding,
                            prefix='Trg')
    from models.model_builder import build_NMT
    nmtModel = build_NMT(src_emb, trg_emb)
    classifier = Classifier(wargs.d_dec_hid,
                            n_trg_vcb,
Example #7
                         train_cols=['Question', 'Dialogue', 'Report'],
                         test_path='input/AutoMaster_TestSet.csv',
                         stop_list=['|', '[', ']', '语音', '图片', ' '])
    # # sentences = MyCorpus1()

    model_path = 'output/w2v.wv'
    embed_path = 'output/embedding.mat'
    vocab_path = 'output/vocab.dict'
    model = build_model(sentences=sentences,
                        size=256,
                        skip_gram=1,
                        hs=1,
                        save_path=model_path)
    print(model['说'])
    save_embedding_vocab(model, embed_path, vocab_path)
    with open(vocab_path, 'rb') as f:
        vocab = load_model(f)
    print(vocab['说'])
    with open(embed_path, 'rb') as f:
        matrix = load_model(f)
    print(matrix[1])

    # wv = KeyedVectors.load_word2vec_format(model_path, binary=True)  # load the word vectors as a KeyedVectors instance
    # print(wv['车主'])

    # model1 = KeyedVectors.load(model_path)  # load the model in Word2Vec form
    # print('word-vector similarity between 车主 and 技师: {}'.format(model1.similarity('技师', '车主')))
    # print(model1.wv.get_vector('语音'))
    # Questions:
    # 1. Why is loading the corpus with MyCorpus1 a bit more than twice as slow as with MyCorpus?
distilgpt2, distilgpt2_tokenizer = make_pretrained_transformer_and_tokenizer(
    'distilgpt2')
xlmroberta, xlmroberta_tokenizer = make_pretrained_transformer_and_tokenizer(
    'xlm-roberta-base')
lstm, lstm_tokenizer = make_pretrained_lstm_and_tokenizer()

# %%
# Load the appropriate probing-models
path_to_lstm_POS_probe = 'storage/saved_models/pos_probes/full_run_monday_LSTM_pos_probe.pt'
path_to_distilgpt2_POS_probe = 'storage/saved_models/pos_probes/full_run_monday_distilgpt2_pos_probe.pt'
# TODO: Add XLM Roberta

POS_lstm_probe = SimpleProbe(650, 19)

POS_distilgpt2_probe = SimpleProbe(768, 19)
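# (note) the probe input widths appear to track the feature models' hidden
# sizes (650 for the LSTM, 768 for distilgpt2); 19 is presumably the size of
# the POS tag set produced by init_pos_data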
utils.load_model(path_to_lstm_POS_probe, POS_lstm_probe)
# utils.load_model(path_to_distilgpt2_POS_probe, POS_distilgpt2_probe)

# %%
POS_pairs = [('distilgpt2', distilgpt2, distilgpt2_tokenizer,
              POS_distilgpt2_probe),
             ('lstm', lstm, lstm_tokenizer, POS_lstm_probe)]

# %%
for pair in POS_pairs:
    feature_name, feature_model, feature_model_tokenizer, probe = pair

    eval_reps, eval_pos_tags, eval_pos_vocab = init_pos_data(
        config.path_to_data_test,
        feature_model,
        feature_model_tokenizer,