Ejemplo n.º 1
0
def main(my_arg):
    """Train the SQuAD machine-reading model and keep the best checkpoint.

    Builds the embedding / attention-flow / modeling / start / end layers,
    optionally resumes them and their Adam optimizers from
    ``mr_model<my_arg>`` (``config['resume']``), then trains for up to
    ``config['epochs']`` epochs.  Layer weights and ``opt.pkl`` are written
    every ``config['save_freq']`` iterations and at the end of every epoch.
    After each epoch ``evaluate_all`` is run; when its F1 does not drop, the
    checkpoint files are copied to ``early_*`` files.  Training stops after
    ``config['early_stop']`` consecutive epochs without improvement.

    Args:
        my_arg: suffix selecting the log dir (``mr_logs<my_arg>``) and model
            dir (``mr_model<my_arg>``); also forwarded to ``evaluate_all``.
    """
    import shutil

    log_dir = 'mr_logs' + str(my_arg)
    logger = Logger(log_dir)
    emb = LoadEmbedding('res/embedding.txt')
    print('finish loading embedding')
    batch_getter = SquadLoader('data/SQuAD/train-v1.1.json',
                               config['batch_size'], True)
    print('finish loading train data')
    embedding_layer = EmbeddingLayer(emb)
    d = config['hidden_size']
    att_layer = AttentionFlowLayer(2 * d)
    model_layer = ModelingLayer(8 * d, d, 2)
    start_layer = StartProbLayer(10 * d)
    end_layer = EndProbLayer(2 * d, d)

    if config['USE_CUDA']:
        for layer in (att_layer, embedding_layer, model_layer, start_layer,
                      end_layer):
            layer.cuda(config['cuda_num'])

    emb_opt = torch.optim.Adam(embedding_layer.parameters())
    att_opt = torch.optim.Adam(att_layer.parameters())
    model_opt = torch.optim.Adam(model_layer.parameters())
    start_opt = torch.optim.Adam(start_layer.parameters())
    end_opt = torch.optim.Adam(end_layer.parameters())

    # (file stem, module) pairs; stems are the on-disk checkpoint file names.
    layers = [('embedding_layer', embedding_layer),
              ('att_layer', att_layer),
              ('model_layer', model_layer),
              ('start_layer', start_layer),
              ('end_layer', end_layer)]
    # (checkpoint key, optimizer) pairs stored together in opt.pkl.
    optimizers = [('emb_opt', emb_opt),
                  ('att_opt', att_opt),
                  ('model_opt', model_opt),
                  ('start_opt', start_opt),
                  ('end_opt', end_opt)]

    model_dir = 'mr_model' + str(my_arg)

    def save_layers(prefix=''):
        """Write each layer's state_dict to <model_dir>/<prefix><stem>.pkl."""
        for stem, layer in layers:
            torch.save(layer.state_dict(),
                       '{}/{}{}.pkl'.format(model_dir, prefix, stem))

    def save_optimizers(epoch, iteration):
        """Write optimizer states plus the resume position to opt.pkl."""
        check_point = {'epoch': epoch, 'iteration': iteration}
        for key, opt in optimizers:
            check_point[key] = opt.state_dict()
        torch.save(check_point, model_dir + '/opt.pkl')

    check_epoch = 0
    check_ex_iteration = 0
    if config['resume']:
        check = torch.load(model_dir + '/opt.pkl')
        for key, opt in optimizers:
            opt.load_state_dict(check[key])
        check_epoch = check['epoch']
        check_ex_iteration = check['iteration']
        for stem, layer in layers:
            layer.load_state_dict(
                torch.load('{}/{}.pkl'.format(model_dir, stem)))

    f_max = 0
    low_epoch = 0
    ex_iterations = check_ex_iteration + 1

    with open('{}/log_file'.format(log_dir), 'w') as log_file:
        for epoch in range(check_epoch, config['epochs']):
            for _, layer in layers:
                layer.train()
            # Keeps the end-of-epoch bookkeeping valid if the loader is empty.
            iteration = -1
            for iteration, this_batch in enumerate(batch_getter):
                global_step = ex_iterations + iteration
                time0 = time.time()
                print('epoch: {}, iteraton: {}'.format(epoch, global_step))
                train_iteration(logger, global_step, embedding_layer,
                                att_layer, model_layer, start_layer,
                                end_layer, emb_opt, att_opt, model_opt,
                                start_opt, end_opt, this_batch)
                time1 = time.time()
                print('this iteration time: ', time1 - time0, '\n')
                if global_step % config['save_freq'] == 0:
                    save_layers()
                    save_optimizers(epoch, global_step)

            last_step = ex_iterations + iteration
            # evaluate_all receives no layers, so it can only score weights it
            # reloads from disk.  Save the end-of-epoch state so it (and the
            # early_* copies below) reflect this epoch, not the last periodic
            # save.  A resume from this checkpoint continues with the next
            # epoch.
            save_layers()
            save_optimizers(epoch + 1, last_step)
            if epoch == 11:
                # Extra snapshot after the 12th (0-based index 11) epoch.
                save_layers('12_')

            ex_iterations = last_step + 1
            batch_getter.reset()
            config['use_dropout'] = False
            exact_match, f = evaluate_all(my_arg, False)
            config['use_dropout'] = True
            log_file.write('epoch: {} exact_match: {} f: {}\n'.format(
                epoch, exact_match, f))
            log_file.flush()
            if f >= f_max:
                f_max = f
                low_epoch = 0
                # copyfile raises on failure instead of silently ignoring it.
                for stem in [name for name, _ in layers] + ['opt']:
                    shutil.copyfile(
                        '{}/{}.pkl'.format(model_dir, stem),
                        '{}/early_{}.pkl'.format(model_dir, stem))
            else:
                low_epoch += 1
                log_file.write('low' + str(low_epoch) + '\n')
                log_file.flush()
            if low_epoch >= config['early_stop']:
                break
Ejemplo n.º 2
0
def evaluate_all(my_arg, pr=True):
    """Evaluate the saved NER boundary model in ``model<my_arg>`` on dev data.

    ``my_arg`` 0 or 2 evaluates on CoNLL-2003 ``bio_eng.testa`` (PER, LOC,
    MISC, ORG); 1 evaluates on ``data/dev`` (PER/FAC/LOC/GPE/ORG ``_NAM``).
    Per-sentence output is written to ``data/eva_result<my_arg>``; ``pr`` is
    forwarded to ``evaluator.evaluate``.

    Returns:
        Whatever ``evaluator.get_performance()`` returns.
    """
    emb = LoadEmbedding('res/emb.txt')
    # print() calls are valid under both Python 2 and Python 3; the former
    # print statements were a SyntaxError under Python 3.
    print('finish loading embedding')
    batch_getter_lst = []
    if my_arg in (0, 2):
        # Both settings use the CoNLL-2003 BIO dev split.
        for entity in ('PER', 'LOC', 'MISC', 'ORG'):
            batch_getter_lst.append(
                ConllBatchGetter('data/conll2003/bio_eng.testa', entity, 1,
                                 False))
    if my_arg == 1:
        for entity in ('PER_NAM', 'FAC_NAM', 'LOC_NAM', 'GPE_NAM',
                       'ORG_NAM'):
            batch_getter_lst.append(
                BatchGetter('data/dev', entity, 1, False))

    batch_getter = MergeBatchGetter(batch_getter_lst, 1, False)
    print('finish loading dev data')
    embedding_layer = EmbeddingLayer(emb, 0)
    d = embedding_layer.get_out_dim()
    att_layer = AttentionFlowLayer(2 * d)
    model_out_layer = ModelingOutLayer(8 * d, d, 2, 3, 0)
    model_dir = 'model' + str(my_arg)

    embedding_layer.load_state_dict(
        torch.load(model_dir + '/embedding_layer.pkl'))
    att_layer.load_state_dict(torch.load(model_dir + '/att_layer.pkl'))
    model_out_layer.load_state_dict(
        torch.load(model_dir + '/model_out_layer.pkl'))

    ner_tag = Vocab('res/ner_xx',
                    unk_id=config['UNK_token'],
                    pad_id=config['PAD_token'])
    evaluator = ConllBoundaryPerformance(ner_tag)
    evaluator.reset()

    if config['USE_CUDA']:
        att_layer.cuda(config['cuda_num'])
        embedding_layer.cuda(config['cuda_num'])
        model_out_layer.cuda(config['cuda_num'])

    # evaluate_one's signature takes optimizers; presumably they are unused
    # during evaluation — NOTE(review): confirm before removing.
    emb_opt = torch.optim.Adam(embedding_layer.parameters())
    att_opt = torch.optim.Adam(att_layer.parameters())
    model_out_opt = torch.optim.Adam(model_out_layer.parameters())

    # The context manager guarantees the result file is closed.
    with codecs.open('data/eva_result' + str(my_arg),
                     mode='wb',
                     encoding='utf-8') as out_file:
        for iteration, this_batch in enumerate(batch_getter):
            target, rec = evaluate_one(iteration, embedding_layer,
                                       att_layer, model_out_layer, emb_opt,
                                       att_opt, model_out_opt, this_batch)
            evaluator.evaluate(iteration,
                               target.numpy().tolist(),
                               rec.numpy().tolist(),
                               out_file,
                               pr=pr)
            if iteration % 100 == 0:
                print('{} sentences processed'.format(iteration))
                evaluator.get_performance()
        return evaluator.get_performance()
Ejemplo n.º 3
0
def main(my_arg):
    """Train the NER tagger, optionally on top of SQuAD-pretrained layers.

    Training data comes from CoNLL-2003 (BIO or BIOES, one getter per entity
    type, MISC optional via ``config['drop_misc']``) or OntoNotes, merged into
    one ``MergeBatchGetter``.  Unless ``config['not_pretrain']`` is set, the
    embedding / attention / modeling layers are initialised from
    ``mr_model1/early_*``.  With ``config['freeze']`` those three layers get
    no optimizer and stay in eval mode.

    Weights go to ``ner_model<my_arg>`` every ``config['save_freq']``
    iterations and at the end of every epoch.  After each epoch
    ``evaluate_all`` is run; when its F1 does not drop, the weights are copied
    to ``early_*`` files.  Training stops after ``config['early_stop']``
    consecutive epochs without improvement.

    Args:
        my_arg: suffix selecting the log dir (``ner_logs<my_arg>``) and model
            dir (``ner_model<my_arg>``); also forwarded to ``evaluate_all``.
    """
    import shutil

    log_dir = 'ner_logs' + str(my_arg)
    logger = Logger(log_dir)
    emb = LoadEmbedding('res/embedding.txt')
    if config['label_emb'] or config['question_alone']:
        onto_emb = LoadEmbedding('res/onto_embedding.txt')
    print('finish loading embedding')

    def conll_getters(path):
        """Return one shuffled ConllBatchGetter per entity type in `path`."""
        getters = []
        for entity in ('PER', 'LOC', 'MISC', 'ORG'):
            if entity == 'MISC' and config['drop_misc']:
                continue
            getters.append(ConllBatchGetter(path, entity, 1, True))
        return getters

    batch_getter_lst = []
    if config['bioes']:
        if config['data'] == 'conll':
            batch_getter_lst.extend(
                conll_getters('data/conll2003/bioes_eng.train'))
        elif config['data'] == 'OntoNotes':
            batch_getter_lst.append(
                OntoNotesGetter(
                    'data/OntoNotes/leaf_train.json',
                    ['/person', '/organization', '/location', '/other'], 1,
                    True))
    else:
        batch_getter_lst.extend(conll_getters('data/conll2003/bio_eng.train'))

    batch_getter = MergeBatchGetter(batch_getter_lst,
                                    config['batch_size'],
                                    True,
                                    data_name=config['data'])
    print('finish loading train data')

    embedding_layer = EmbeddingLayer(emb)
    if config['label_emb']:
        # Fixed (non-trainable) embedding table copied from onto_emb.
        q_word_embedding = nn.Embedding(onto_emb.get_voc_size(),
                                        onto_emb.get_emb_size())
        q_word_embedding.weight.data.copy_(onto_emb.get_embedding_tensor())
        q_word_embedding.weight.requires_grad = False
    else:
        q_word_embedding = None
    d = config['hidden_size']
    q_emb_layer = QLabel(onto_emb) if config['question_alone'] else None
    att_layer = AttentionFlowLayer(2 * d)
    model_layer = ModelingLayer(8 * d, d, 2)
    ner_hw_layer = NerHighway(2 * d, 8 * d, 1)
    ner_out_layer = NerOutLayer(10 * d, len(config['Tags']))
    crf = CRF(config, config['Tags'], 10 * d)

    if config['USE_CUDA']:
        att_layer.cuda(config['cuda_num'])
        embedding_layer.cuda(config['cuda_num'])
        if config['label_emb']:
            q_word_embedding.cuda(config['cuda_num'])
        model_layer.cuda(config['cuda_num'])
        ner_hw_layer.cuda(config['cuda_num'])
        ner_out_layer.cuda(config['cuda_num'])
        crf.cuda(config['cuda_num'])
        if config['question_alone']:
            q_emb_layer.cuda(config['cuda_num'])

    squad_model_dir = 'mr_model1'
    if not config['not_pretrain']:
        for stem, layer in (('att_layer', att_layer),
                            ('model_layer', model_layer),
                            ('embedding_layer', embedding_layer)):
            layer.load_state_dict(
                torch.load('{}/early_{}.pkl'.format(squad_model_dir, stem),
                           map_location=lambda storage, loc: storage))

    def adam(module, **kwargs):
        """Adam over the module's parameters that still require grad."""
        return torch.optim.Adam(
            filter(lambda param: param.requires_grad, module.parameters()),
            **kwargs)

    frozen = config['freeze']
    if frozen:
        for layer in (att_layer, model_layer, embedding_layer):
            for param in layer.parameters():
                param.requires_grad = False
        embedding_layer.eval()
        model_layer.eval()
        att_layer.eval()
        emb_opt = None
        att_opt = None
        model_opt = None
    else:
        # Pretrained layers are fine-tuned with a smaller learning rate;
        # from-scratch layers use Adam's default.
        shared_kwargs = {} if config['not_pretrain'] else {'lr': 1e-4}
        emb_opt = adam(embedding_layer, **shared_kwargs)
        att_opt = adam(att_layer, **shared_kwargs)
        model_opt = adam(model_layer, **shared_kwargs)

    ner_hw_opt = adam(ner_hw_layer)
    ner_out_opt = adam(ner_out_layer)
    crf_opt = adam(crf)
    q_emb_opt = adam(q_emb_layer) if config['question_alone'] else None

    model_dir = 'ner_model' + str(my_arg)
    # (file stem, module) pairs written to model_dir and promoted to early_*.
    saved_modules = [('embedding_layer', embedding_layer),
                     ('att_layer', att_layer),
                     ('model_layer', model_layer),
                     ('ner_hw_layer', ner_hw_layer),
                     ('ner_out_layer', ner_out_layer),
                     ('crf', crf)]
    if config['question_alone']:
        saved_modules.append(('q_emb_layer', q_emb_layer))

    def save_modules():
        """Write every tracked module's state_dict to <model_dir>/<stem>.pkl."""
        for stem, module in saved_modules:
            torch.save(module.state_dict(),
                       '{}/{}.pkl'.format(model_dir, stem))

    f_max = 0
    low_epoch = 0
    ex_iterations = 0
    time0 = time.time()
    with open('{}/log_file'.format(log_dir), 'w') as log_file:
        for epoch in range(config['max_epoch']):
            # Frozen layers were deliberately put in eval mode above; do not
            # switch them back to training mode every epoch.
            if not frozen:
                embedding_layer.train()
                att_layer.train()
                model_layer.train()
            ner_hw_layer.train()
            ner_out_layer.train()
            crf.train()
            if config['question_alone']:
                q_emb_layer.train()

            # Keeps the end-of-epoch bookkeeping valid if the loader is empty.
            iteration = -1
            for iteration, this_batch in enumerate(batch_getter):
                global_step = ex_iterations + iteration
                if global_step % 100 == 0:
                    print('epoch: {}, iteraton: {}'.format(epoch, global_step))

                train_iteration(logger, global_step, embedding_layer,
                                q_word_embedding, q_emb_layer, att_layer,
                                model_layer, ner_hw_layer, ner_out_layer, crf,
                                emb_opt, q_emb_opt, att_opt, model_opt,
                                ner_hw_opt, ner_out_opt, crf_opt, this_batch)
                if global_step % 100 == 0:
                    time1 = time.time()
                    print('this iteration time: ', time1 - time0, '\n')
                    time0 = time1
                if global_step % config['save_freq'] == 0:
                    save_modules()

            # evaluate_all receives no layers, so it can only score weights it
            # reloads from disk.  Save the end-of-epoch state so the scored
            # weights and the early_* copies below match this epoch.
            save_modules()

            ex_iterations += iteration + 1
            batch_getter.reset()
            config['use_dropout'] = False
            f, p, r = evaluate_all(my_arg, False)
            config['use_dropout'] = True
            log_file.write('epoch: {} f: {} p: {} r: {}\n'.format(
                epoch, f, p, r))
            log_file.flush()
            if f >= f_max:
                f_max = f
                low_epoch = 0
                # copyfile raises on failure instead of silently ignoring it.
                for stem, _ in saved_modules:
                    shutil.copyfile(
                        '{}/{}.pkl'.format(model_dir, stem),
                        '{}/early_{}.pkl'.format(model_dir, stem))
            else:
                low_epoch += 1
                log_file.write('low' + str(low_epoch) + '\n')
                log_file.flush()
            if low_epoch >= config['early_stop']:
                break
Ejemplo n.º 4
0
def evaluate_all(my_arg, pr=True):
    """Score the checkpoint stored in ``mr_model<my_arg>`` on SQuAD v1.1 dev.

    The five layers are rebuilt with dropout disabled, loaded from disk, and
    switched to eval mode.  Predicted answer spans are cut from the
    whitespace-tokenised paragraph context, dumped as JSON to
    ``data/squad_pred<my_arg>``, and scored with ``evaluate``.  ``pr`` is
    accepted for interface compatibility but not used here.

    Returns:
        ``(exact_match, f1)`` read from the ``evaluate`` result dict.
    """
    emb = LoadEmbedding('res/emb.txt')
    print('finish loading embedding')
    batch_size = 100
    batch_getter = SquadLoader('data/SQuAD/dev-v1.1.json', batch_size, False)
    print('finish loading dev data')

    embedding_layer = EmbeddingLayer(emb, dropout_p=0)
    hidden = config['hidden_size']
    att_layer = AttentionFlowLayer(2 * hidden)
    model_layer = ModelingLayer(8 * hidden, hidden, 2, dropout=0)
    start_layer = StartProbLayer(10 * hidden, dropout=0)
    end_layer = EndProbLayer(2 * hidden, hidden, dropout=0)

    if config['USE_CUDA']:
        for layer in (att_layer, embedding_layer, model_layer, start_layer,
                      end_layer):
            layer.cuda(config['cuda_num'])

    model_dir = 'mr_model' + str(my_arg)
    named_layers = (
        ('embedding_layer', embedding_layer),
        ('att_layer', att_layer),
        ('model_layer', model_layer),
        ('start_layer', start_layer),
        ('end_layer', end_layer),
    )
    # Load everything first, then switch everything to inference mode.
    for stem, layer in named_layers:
        layer.load_state_dict(torch.load(model_dir + '/' + stem + '.pkl'))
    for _, layer in named_layers:
        layer.eval()

    predictions = {}
    for batch_no, batch in enumerate(batch_getter):
        starts, ends = evaluate_batch(batch_no, embedding_layer, att_layer,
                                      model_layer, start_layer, end_layer,
                                      batch)
        starts = starts.cpu().data.numpy()
        ends = ends.cpu().data.numpy()
        for row, question_id in enumerate(batch['ids']):
            article = batch_getter.dataset[batch['art_ids'][row]]
            paragraph = article['paragraphs'][batch['para_ids'][row]]
            tokens = paragraph['context'].split()
            # The end index is inclusive, hence the +1 on the slice bound.
            predictions[question_id] = ' '.join(
                tokens[starts[row]:ends[row] + 1])
        processed = (batch_no + 1) * batch_size
        if processed % 100 == 0:
            print('{} questions processed'.format(processed))

    with open('data/squad_pred' + str(my_arg), mode='w') as out_f:
        json.dump(predictions, out_f)

    expected_version = '1.1'
    with open('data/SQuAD/dev-v1.1.json') as dataset_file:
        dataset_json = json.load(dataset_file)
    if dataset_json['version'] != expected_version:
        print('Evaluation expects v-' + expected_version +
              ', but got dataset with v-' + dataset_json['version'],
              file=sys.stderr)
    scores = evaluate(dataset_json['data'], predictions)
    print(scores)
    return scores['exact_match'], scores['f1']
Ejemplo n.º 5
0
    def free_evaluate_all(my_arg, pr=True):
        """Evaluate the best (``early_*``) NER checkpoint stored in ``ner_model8``.

        ``my_arg`` selects the evaluation data:
          0 -> CoNLL-2003 ``bio_eng.testa``, MISC only;
          1 -> CoNLL-2003 ``bio_eng.testb``, PER/LOC/MISC/ORG;
          2 -> CoNLL-2003 ``bioes_eng.testb``, PER/LOC/MISC/ORG;
          3 -> OntoNotes ``test.json`` with ``utils.get_ontoNotes_type_lst()``.
        Per-sentence output is written to ``data/eva_result<my_arg>``; ``pr``
        is forwarded to ``evaluator.evaluate``.

        Returns:
            Whatever ``evaluator.get_performance()`` returns.
        """
        emb = LoadEmbedding('res/emb.txt')
        if config['label_emb'] or config['question_alone']:
            onto_emb = LoadEmbedding('res/onto_embedding.txt')
        print('finish loading embedding')
        batch_getter_lst = []
        if my_arg == 0:
            batch_getter_lst.append(
                ConllBatchGetter('data/conll2003/bio_eng.testa', 'MISC', 1,
                                 False))
        if my_arg == 1:
            for entity in ('PER', 'LOC', 'MISC', 'ORG'):
                batch_getter_lst.append(
                    ConllBatchGetter('data/conll2003/bio_eng.testb', entity,
                                     1, False))
        if my_arg == 2:
            for entity in ('PER', 'LOC', 'MISC', 'ORG'):
                batch_getter_lst.append(
                    ConllBatchGetter('data/conll2003/bioes_eng.testb', entity,
                                     1, False))
        if my_arg == 3:
            onto_notes_data = OntoNotesGetter('data/OntoNotes/test.json',
                                              utils.get_ontoNotes_type_lst(),
                                              1, False)
            batch_getter_lst.append(onto_notes_data)
        batch_size = 100
        batch_getter = MergeBatchGetter(batch_getter_lst,
                                        batch_size,
                                        False,
                                        data_name=config['data'])
        print('finish loading dev data')

        embedding_layer = EmbeddingLayer(emb, 0)
        if config['label_emb']:
            # Fixed (non-trainable) embedding table copied from onto_emb.
            q_word_embedding = nn.Embedding(onto_emb.get_voc_size(),
                                            onto_emb.get_emb_size())
            q_word_embedding.weight.data.copy_(onto_emb.get_embedding_tensor())
            q_word_embedding.weight.requires_grad = False
        else:
            q_word_embedding = None
        d = config['hidden_size']
        q_emb_layer = QLabel(onto_emb, 0) if config['question_alone'] else None
        att_layer = AttentionFlowLayer(2 * d)
        model_layer = ModelingLayer(8 * d, d, 2, 0)
        ner_hw_layer = NerHighway(2 * d, 8 * d, 1)
        ner_out_layer = NerOutLayer(10 * d, len(config['Tags']), 0)
        # NOTE(review): other training code builds CRF(config, config['Tags'],
        # 10 * d); confirm len(config['Tags']) is the intended third argument.
        crf = CRF(config, config['Tags'], len(config['Tags']))
        if config['USE_CUDA']:
            att_layer.cuda(config['cuda_num'])
            embedding_layer.cuda(config['cuda_num'])
            if config['label_emb']:
                q_word_embedding.cuda(config['cuda_num'])
            model_layer.cuda(config['cuda_num'])
            ner_hw_layer.cuda(config['cuda_num'])
            ner_out_layer.cuda(config['cuda_num'])
            crf.cuda(config['cuda_num'])
            if config['question_alone']:
                q_emb_layer.cuda(config['cuda_num'])
        model_dir = 'ner_model8'

        def load_early(stem):
            """Load <model_dir>/early_<stem>.pkl with tensors mapped to CPU."""
            return torch.load('{}/early_{}.pkl'.format(model_dir, stem),
                              map_location=lambda storage, loc: storage)

        att_layer.load_state_dict(load_early('att_layer'))
        model_layer.load_state_dict(load_early('model_layer'))
        ner_hw_layer.load_state_dict(load_early('ner_hw_layer'))
        ner_out_layer.load_state_dict(load_early('ner_out_layer'))
        crf.load_state_dict(load_early('crf'))
        embedding_layer.load_state_dict(load_early('embedding_layer'))
        if config['question_alone']:
            # Load the early_* copy like every other layer.  The plain
            # q_emb_layer.pkl is the latest periodic save and would mix
            # weights from a different training step.
            q_emb_layer.load_state_dict(load_early('q_emb_layer'))
            q_emb_layer.eval()
        embedding_layer.eval()
        att_layer.eval()
        model_layer.eval()
        ner_hw_layer.eval()
        ner_out_layer.eval()
        crf.eval()

        ner_tag = Vocab('res/ner_xx',
                        unk_id=config['UNK_token'],
                        pad_id=config['PAD_token'])
        if my_arg == 3:
            evaluator = ConllBoundaryPerformance(ner_tag, onto_notes_data)
        else:
            evaluator = ConllBoundaryPerformance(ner_tag)
        evaluator.reset()

        # The context manager guarantees the result file is closed.
        with codecs.open('data/eva_result' + str(my_arg),
                         mode='wb',
                         encoding='utf-8') as out_file:
            all_emb = None
            all_metadata = []
            ex_iterations = 0
            # Debug switch: when True, evaluate_one also accumulates
            # embeddings for the add_embedding dump below.
            summary_emb = False
            for iteration, this_batch in enumerate(batch_getter):
                if summary_emb:
                    top_path, all_emb, all_metadata, q = evaluate_one(
                        ex_iterations + iteration, embedding_layer,
                        q_word_embedding, q_emb_layer, att_layer, model_layer,
                        ner_hw_layer, ner_out_layer, crf, this_batch,
                        summary_emb, all_emb, all_metadata)
                else:
                    top_path = evaluate_one(ex_iterations + iteration,
                                            embedding_layer, q_word_embedding,
                                            q_emb_layer, att_layer,
                                            model_layer, ner_hw_layer,
                                            ner_out_layer, crf, this_batch)
                for batch_no, path in enumerate(top_path):
                    evaluator.evaluate(
                        iteration * batch_size + batch_no,
                        remove_end_tag(
                            this_batch[1].numpy()[batch_no, :].tolist()),
                        path, out_file, pr)
                if (iteration + 1) * batch_size % 100 == 0:
                    print('{} sentences processed'.format(
                        (iteration + 1) * batch_size))
                    evaluator.get_performance()
            if summary_emb:
                # NOTE(review): `writer` is not defined in this function; this
                # branch presumably relies on a module-level summary writer.
                writer.add_embedding(torch.cat([q, all_emb], 0),
                                     metadata=['question'] + all_metadata)
            return evaluator.get_performance()