Пример #1
0
def main(my_arg):
    """Train the question-style NER model and early-stop on dev F1.

    Builds the training batch getters selected by ``config`` (CoNLL-2003 in
    BIO or BIOES tagging, or OntoNotes), assembles the model layers,
    optionally warm-starts / freezes the pretrained MRC layers, then runs the
    training loop with periodic checkpointing. The best checkpoints (by dev
    F1 from ``evaluate_all``) are copied to ``early_*.pkl`` files.

    Args:
        my_arg: integer run id; suffixes the log and model directories and is
            forwarded to ``evaluate_all``.
    """
    log_dir = 'ner_logs' + str(my_arg)
    logger = Logger(log_dir)
    emb = LoadEmbedding('res/embedding.txt')
    if config['label_emb'] or config['question_alone']:
        onto_emb = LoadEmbedding('res/onto_embedding.txt')
    print('finish loading embedding')

    # One getter per entity type; MISC is optional via config['drop_misc'].
    batch_getter_lst = []
    if config['bioes']:
        if config['data'] == 'conll':
            train_path = 'data/conll2003/bioes_eng.train'
            batch_getter_lst.append(ConllBatchGetter(train_path, 'PER', 1, True))
            batch_getter_lst.append(ConllBatchGetter(train_path, 'LOC', 1, True))
            if not config['drop_misc']:
                batch_getter_lst.append(
                    ConllBatchGetter(train_path, 'MISC', 1, True))
            batch_getter_lst.append(ConllBatchGetter(train_path, 'ORG', 1, True))
        elif config['data'] == 'OntoNotes':
            batch_getter_lst.append(OntoNotesGetter(
                'data/OntoNotes/leaf_train.json',
                ['/person', '/organization', '/location', '/other'], 1, True))
    else:
        train_path = 'data/conll2003/bio_eng.train'
        batch_getter_lst.append(ConllBatchGetter(train_path, 'PER', 1, True))
        batch_getter_lst.append(ConllBatchGetter(train_path, 'LOC', 1, True))
        if not config['drop_misc']:
            batch_getter_lst.append(ConllBatchGetter(train_path, 'MISC', 1, True))
        batch_getter_lst.append(ConllBatchGetter(train_path, 'ORG', 1, True))

    batch_getter = MergeBatchGetter(batch_getter_lst, config['batch_size'],
                                    True, data_name=config['data'])
    print('finish loading train data')

    embedding_layer = EmbeddingLayer(emb)
    if config['label_emb']:
        # Frozen label (question) word embeddings copied from the ontology file.
        q_word_embedding = nn.Embedding(onto_emb.get_voc_size(),
                                        onto_emb.get_emb_size())
        q_word_embedding.weight.data.copy_(onto_emb.get_embedding_tensor())
        q_word_embedding.weight.requires_grad = False
    else:
        q_word_embedding = None
    d = config['hidden_size']
    q_emb_layer = QLabel(onto_emb) if config['question_alone'] else None
    att_layer = AttentionFlowLayer(2 * d)
    model_layer = ModelingLayer(8 * d, d, 2)
    ner_hw_layer = NerHighway(2 * d, 8 * d, 1)
    ner_out_layer = NerOutLayer(10 * d, len(config['Tags']))
    crf = CRF(config, config['Tags'], 10 * d)

    # Checkpointed modules keyed by the file stem used for torch.save / cp.
    modules = {
        'embedding_layer': embedding_layer,
        'att_layer': att_layer,
        'model_layer': model_layer,
        'ner_hw_layer': ner_hw_layer,
        'ner_out_layer': ner_out_layer,
        'crf': crf,
    }
    if config['question_alone']:
        modules['q_emb_layer'] = q_emb_layer

    if config['USE_CUDA']:
        for module in modules.values():
            module.cuda(config['cuda_num'])
        if config['label_emb']:
            q_word_embedding.cuda(config['cuda_num'])

    squad_model_dir = 'mr_model1'
    if not config['not_pretrain']:
        # Warm-start the shared reading-comprehension layers from the MRC run.
        for name in ('att_layer', 'model_layer', 'embedding_layer'):
            modules[name].load_state_dict(
                torch.load('{}/early_{}.pkl'.format(squad_model_dir, name),
                           map_location=lambda storage, loc: storage))

    def _adam(module, **kwargs):
        # Adam over the module's trainable parameters only.
        return torch.optim.Adam(
            filter(lambda param: param.requires_grad, module.parameters()),
            **kwargs)

    if config['freeze']:
        # Freeze the pretrained layers entirely: no grads, eval mode, no opts.
        for module in (att_layer, model_layer, embedding_layer):
            for param in module.parameters():
                param.requires_grad = False
            module.eval()
        emb_opt = att_opt = model_opt = None
    else:
        # Fine-tune with a small lr when starting from pretrained weights;
        # use Adam's default lr when training from scratch.
        lr_kwargs = {} if config['not_pretrain'] else {'lr': 1e-4}
        emb_opt = _adam(embedding_layer, **lr_kwargs)
        att_opt = _adam(att_layer, **lr_kwargs)
        model_opt = _adam(model_layer, **lr_kwargs)

    ner_hw_opt = _adam(ner_hw_layer)
    ner_out_opt = _adam(ner_out_layer)
    crf_opt = _adam(crf)
    q_emb_opt = _adam(q_emb_layer) if config['question_alone'] else None

    f_max = 0
    low_epoch = 0
    ex_iterations = 0
    model_dir = 'ner_model' + str(my_arg)
    time0 = time.time()
    # "with" guarantees the log file is closed even if training raises.
    with open('{}/log_file'.format(log_dir), 'w') as log_file:
        for epoch in range(config['max_epoch']):
            for module in modules.values():
                module.train()
            # Guard: keeps ex_iterations correct if the getter yields nothing
            # (the original code would hit an UnboundLocalError here).
            iteration = -1
            for iteration, this_batch in enumerate(batch_getter):
                step = ex_iterations + iteration
                if step % 100 == 0:
                    print('epoch: {}, iteraton: {}'.format(epoch, step))

                train_iteration(logger, step, embedding_layer, q_word_embedding,
                                q_emb_layer, att_layer, model_layer,
                                ner_hw_layer, ner_out_layer, crf, emb_opt,
                                q_emb_opt, att_opt, model_opt, ner_hw_opt,
                                ner_out_opt, crf_opt, this_batch)
                if step % 100 == 0:
                    time1 = time.time()
                    print('this iteration time: ', time1 - time0, '\n')
                    time0 = time1
                if step % config['save_freq'] == 0:
                    for name, module in modules.items():
                        torch.save(module.state_dict(),
                                   '{}/{}.pkl'.format(model_dir, name))

            ex_iterations += iteration + 1
            batch_getter.reset()
            # Dropout off for evaluation; restored right after.
            config['use_dropout'] = False
            f, p, r = evaluate_all(my_arg, False)
            config['use_dropout'] = True
            log_file.write('epoch: {} f: {} p: {} r: {}\n'.format(epoch, f, p, r))
            log_file.flush()
            if f >= f_max:
                # New best dev F1: snapshot every checkpoint as "early_*".
                f_max = f
                low_epoch = 0
                for name in modules:
                    os.system('cp {0}/{1}.pkl {0}/early_{1}.pkl'.format(
                        model_dir, name))
            else:
                low_epoch += 1
                log_file.write('low' + str(low_epoch) + '\n')
                log_file.flush()
            if low_epoch >= config['early_stop']:
                break
Пример #2
0
    def free_evaluate_all(my_arg, pr=True):
        """Evaluate the best saved model on the dataset selected by ``my_arg``.

        my_arg: 0 -> CoNLL dev (testa), MISC question only; 1 -> CoNLL testb,
        BIO tagging; 2 -> CoNLL testb, BIOES tagging; 3 -> OntoNotes test set.
        pr: forwarded to ``evaluator.evaluate`` (presumably controls
        per-sentence printing — TODO confirm against ConllBoundaryPerformance).

        Returns the result of ``evaluator.get_performance()`` (f/p/r metrics).
        """
        emb = LoadEmbedding('res/emb.txt')
        if config['label_emb'] or config['question_alone']:
            onto_emb = LoadEmbedding('res/onto_embedding.txt')
        print('finish loading embedding')

        batch_getter_lst = []
        if my_arg == 0:
            # Dev set: only the MISC question is evaluated here.
            batch_getter_lst.append(ConllBatchGetter(
                'data/conll2003/bio_eng.testa', 'MISC', 1, False))
        if my_arg == 1:
            for ent_type in ('PER', 'LOC', 'MISC', 'ORG'):
                batch_getter_lst.append(ConllBatchGetter(
                    'data/conll2003/bio_eng.testb', ent_type, 1, False))
        if my_arg == 2:
            for ent_type in ('PER', 'LOC', 'MISC', 'ORG'):
                batch_getter_lst.append(ConllBatchGetter(
                    'data/conll2003/bioes_eng.testb', ent_type, 1, False))
        if my_arg == 3:
            onto_notes_data = OntoNotesGetter('data/OntoNotes/test.json',
                                              utils.get_ontoNotes_type_lst(),
                                              1, False)
            batch_getter_lst.append(onto_notes_data)
        batch_size = 100
        batch_getter = MergeBatchGetter(batch_getter_lst, batch_size, False,
                                        data_name=config['data'])
        print('finish loading dev data')

        embedding_layer = EmbeddingLayer(emb, 0)
        if config['label_emb']:
            # Frozen label (question) word embeddings from the ontology file.
            q_word_embedding = nn.Embedding(onto_emb.get_voc_size(),
                                            onto_emb.get_emb_size())
            q_word_embedding.weight.data.copy_(onto_emb.get_embedding_tensor())
            q_word_embedding.weight.requires_grad = False
        else:
            q_word_embedding = None
        d = config['hidden_size']
        q_emb_layer = QLabel(onto_emb, 0) if config['question_alone'] else None
        att_layer = AttentionFlowLayer(2 * d)
        model_layer = ModelingLayer(8 * d, d, 2, 0)
        ner_hw_layer = NerHighway(2 * d, 8 * d, 1)
        ner_out_layer = NerOutLayer(10 * d, len(config['Tags']), 0)
        # NOTE(review): training constructs CRF with input size 10 * d, but
        # here the last argument is len(config['Tags']) — confirm this
        # mismatch is intended before relying on it.
        crf = CRF(config, config['Tags'], len(config['Tags']))

        # Modules restored from "early_*" checkpoints, keyed by file stem.
        modules = {
            'att_layer': att_layer,
            'model_layer': model_layer,
            'ner_hw_layer': ner_hw_layer,
            'ner_out_layer': ner_out_layer,
            'crf': crf,
            'embedding_layer': embedding_layer,
        }
        if config['USE_CUDA']:
            for module in modules.values():
                module.cuda(config['cuda_num'])
            if config['label_emb']:
                q_word_embedding.cuda(config['cuda_num'])
            if config['question_alone']:
                q_emb_layer.cuda(config['cuda_num'])

        model_dir = 'ner_model8'
        # Restore the best ("early_*") checkpoints saved during training.
        for name, module in modules.items():
            module.load_state_dict(
                torch.load('{}/early_{}.pkl'.format(model_dir, name),
                           map_location=lambda storage, loc: storage))
        if config['question_alone']:
            # NOTE(review): unlike the others this loads the non-"early"
            # checkpoint, mirroring the original code — verify intent.
            q_emb_layer.load_state_dict(
                torch.load(model_dir + '/q_emb_layer.pkl',
                           map_location=lambda storage, loc: storage))
            q_emb_layer.eval()
        for module in modules.values():
            module.eval()

        ner_tag = Vocab('res/ner_xx',
                        unk_id=config['UNK_token'],
                        pad_id=config['PAD_token'])
        if my_arg == 3:
            evaluator = ConllBoundaryPerformance(ner_tag, onto_notes_data)
        else:
            evaluator = ConllBoundaryPerformance(ner_tag)
        evaluator.reset()
        out_file = codecs.open('data/eva_result' + str(my_arg),
                               mode='wb',
                               encoding='utf-8')
        all_emb = None
        all_metadata = []
        ex_iterations = 0
        summary_emb = False  # flip on to also collect embeddings for the writer
        for iteration, this_batch in enumerate(batch_getter):
            if summary_emb:
                top_path, all_emb, all_metadata, q = evaluate_one(
                    ex_iterations + iteration, embedding_layer,
                    q_word_embedding, q_emb_layer, att_layer, model_layer,
                    ner_hw_layer, ner_out_layer, crf, this_batch, summary_emb,
                    all_emb, all_metadata)
            else:
                top_path = evaluate_one(ex_iterations + iteration,
                                        embedding_layer, q_word_embedding,
                                        q_emb_layer, att_layer, model_layer,
                                        ner_hw_layer, ner_out_layer, crf,
                                        this_batch)
            for batch_no, path in enumerate(top_path):
                evaluator.evaluate(
                    iteration * batch_size + batch_no,
                    remove_end_tag(
                        this_batch[1].numpy()[batch_no, :].tolist()),
                    path, out_file, pr)
            # Progress report every 100 processed sentences.
            if (iteration + 1) * batch_size % 100 == 0:
                print('{} sentences processed'.format(
                    (iteration + 1) * batch_size))
                evaluator.get_performance()
        if summary_emb:
            writer.add_embedding(torch.cat([q, all_emb], 0),
                                 metadata=['question'] + all_metadata)
        performance = evaluator.get_performance()
        out_file.close()  # original code leaked this handle
        return performance