Example #1
def demotest(sentence):
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        demo_sent = sentence
        if demo_sent == '' or demo_sent.isspace():
            print('Sentence is empty')
            PER = ['']
            LOC = ['']
            ORG = ['']
            return (PER, LOC, ORG)
        else:
            demo_sent = list(demo_sent.strip())
            demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            tag = model.demo_one(sess, demo_data)
            PER, LOC, ORG = get_entity(tag, demo_sent)
            print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
            return (PER, LOC, ORG)
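Almost every example on this page follows the same pattern: wrap the characters as demo_data = [(chars, ['O'] * len(chars))], run model.demo_one, then decode the predicted tag sequence with get_entity(tag, demo_sent). The helper itself is not shown here; the sketch below is a hypothetical reconstruction assuming standard BIO tags (B-PER/I-PER, B-LOC/I-LOC, B-ORG/I-ORG, O). get_entity_sketch is an illustrative stand-in, not the project's actual function.

# Hypothetical reconstruction of a BIO-tag decoder like get_entity.
def get_entity_sketch(tags, chars):
    entities = {'PER': [], 'LOC': [], 'ORG': []}
    current_type, current_chars = None, []
    for ch, tag in zip(chars, tags):
        if tag.startswith('B-'):
            if current_type:  # close the previous entity
                entities[current_type].append(''.join(current_chars))
            current_type, current_chars = tag[2:], [ch]
        elif tag.startswith('I-') and current_type == tag[2:]:
            current_chars.append(ch)
        else:  # 'O' or a broken tag sequence
            if current_type:
                entities[current_type].append(''.join(current_chars))
            current_type, current_chars = None, []
    if current_type:  # an entity running to the end of the sentence
        entities[current_type].append(''.join(current_chars))
    return entities['PER'], entities['LOC'], entities['ORG']

# tags for '小明在北京' -> (['小明'], ['北京'], [])
print(get_entity_sketch(['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC'],
                        list('小明在北京')))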
Example #2
def getRest(text):
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        demo_sent = text
        if demo_sent == '' or demo_sent.isspace():
            return {'status': 'fail'}
        else:
            demo_sent = list(demo_sent.strip())
            demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            tag = model.demo_one(sess, demo_data)
            PER, LOC, ORG = get_entity(tag, demo_sent)
            result = {'status': 'success', 'PER': PER, 'LOC': LOC, 'ORG': ORG}
            return result
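A minimal usage sketch for getRest, assuming the module-level globals it relies on (model_path, args, embeddings, tag2label, word2id, paths, config) are already initialized as in the other examples; the sample sentence is illustrative only.

result = getRest('王小明在北京的清华大学读书')
if result['status'] == 'success':
    print(result['PER'], result['LOC'], result['ORG'])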
Example #3
def demo_one(self, model_path):
    '''
    Tag a single input sentence.
    :param model_path: directory containing the trained checkpoint
    input: 武三思與韋後日夜譖敬暉等不已
    :return: [[0, 2, 'PER'], [4, 5, 'PER'], [9, 10, 'PER']]
    '''
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    self.paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       self.embedding,
                       self.tag2id,
                       self.word2id,
                       self.paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('begin to demo one sentence!')
        saver.restore(sess, ckpt_file)
        while True:
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace() or demo_sent == 'end':
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                print(get_ner_demo(tag))
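The docstring's return format ([[0, 2, 'PER'], ...]) suggests that get_ner_demo turns the tag sequence into inclusive index spans rather than strings. A hypothetical sketch of such a converter, again assuming BIO tags; get_ner_demo_sketch is an illustrative stand-in.

# Collect [start, end, TYPE] spans (inclusive indices) from BIO tags.
def get_ner_demo_sketch(tags):
    spans, start, etype = [], None, None
    for i, tag in enumerate(tags):
        if tag.startswith('B-'):
            if etype is not None:
                spans.append([start, i - 1, etype])
            start, etype = i, tag[2:]
        elif not (tag.startswith('I-') and tag[2:] == etype):
            if etype is not None:
                spans.append([start, i - 1, etype])
            start, etype = None, None
    if etype is not None:
        spans.append([start, len(tags) - 1, etype])
    return spans

# ['B-PER', 'I-PER', 'I-PER', 'O'] -> [[0, 2, 'PER']]
print(get_ner_demo_sketch(['B-PER', 'I-PER', 'I-PER', 'O']))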
Example #4
def predict_random(demo_sent):
    word2id, embeddings = getDicEmbed()
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    model = BiLSTM_CRF(batch_size=args.batch_size,
                       epoch_num=args.epoch,
                       hidden_dim=args.hidden_dim,
                       embeddings=embeddings,
                       dropout_keep=args.dropout,
                       optimizer=args.optimizer,
                       lr=args.lr,
                       clip_grad=args.clip,
                       tag2label=tag2label,
                       vocab=word2id,
                       shuffle=args.shuffle,
                       model_path=ckpt_file,
                       summary_path=summary_path,
                       log_path=log_path,
                       result_path=result_path,
                       CRF=args.CRF,
                       update_embedding=args.update_embedding)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        demo_sent = list(demo_sent.strip())
        demo_data = [(demo_sent, ['M'] * len(demo_sent))]
        tag = model.demo_one(sess, demo_data)
    res = segment(demo_sent, tag)
    print(res)
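Unlike the NER examples, this one fills the placeholder labels with 'M' and decodes the result with segment(...), so it is doing word segmentation rather than entity extraction. segment is not shown on this page; below is a hypothetical sketch of a BMES decoder (B begin, M middle, E end, S single-character word), with segment_sketch as an illustrative stand-in.

# Hypothetical BMES decoder: split characters into words on E/S tags.
def segment_sketch(chars, tags):
    words, current = [], []
    for ch, tag in zip(chars, tags):
        current.append(ch)
        if tag in ('E', 'S'):  # a word ends on E (multi-char) or S (single)
            words.append(''.join(current))
            current = []
    if current:  # flush a trailing unfinished word
        words.append(''.join(current))
    return words

# '我爱北京' tagged S S B E -> ['我', '爱', '北京']
print(segment_sketch(list('我爱北京'), ['S', 'S', 'B', 'E']))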
Example #5
def evaluate_words(lines):
    print("start evaluate_words")
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)

        demo_sent = lines
        print(demo_sent)
        demo_sent = list(demo_sent.strip())
        print(demo_sent)
        demo_data = [(demo_sent, ['O'] * len(demo_sent))]
        tag = model.demo_one(sess, demo_data)
        PER, LOC, ORG = get_entity(tag, demo_sent)
        print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
Example #6
class NER_DEMO(object):
    def __init__(self, args):
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.2
        paths, model_path = get_paths(args)
        ckpt_file = tf.train.latest_checkpoint(model_path)

        paths['model_path'] = ckpt_file
        word2id = read_dictionary(
            os.path.join('.', args.train_data, 'word2id.pkl'))
        embeddings = random_embedding(word2id, args.embedding_dim)
        self.model = BiLSTM_CRF(args,
                                embeddings,
                                tag2label,
                                word2id,
                                paths,
                                config=config)
        self.model.build_graph()
        self.saver = tf.train.Saver()
        self.sess = tf.Session(config=config)
        self.saver.restore(self.sess, ckpt_file)

    def predict(self, demo_sent):
        if demo_sent == '' or demo_sent.isspace():
            print('See you next time!')
            return {}
        else:
            demo_sent = list(demo_sent.strip())
            demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            tag = self.model.demo_one(self.sess, demo_data)
            entities = get_entity(tag, demo_sent)
            return entities
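Because the graph is built and the checkpoint restored once in __init__, a single NER_DEMO instance can serve many sentences without reloading the model. A usage sketch, assuming args has been parsed as in the other examples; the sample sentences are illustrative only.

ner = NER_DEMO(args)  # build the graph and restore the checkpoint once
for sent in ['马云创办了阿里巴巴', '小明住在上海']:
    print(ner.predict(sent))  # entities decoded by get_entity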
Example #7
    model.test(test_data)

## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while True:
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                PER, LOC, ORG = get_entity(tag, demo_sent)
                print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
Example #8
    print(ckpt_file)
    model = BiLSTM_CRF(batch_size=args.batch_size, epoch_num=args.epoch, hidden_dim=args.hidden_dim,
                       embeddings=embeddings,
                       dropout_keep=args.dropout, optimizer=args.optimizer, lr=args.lr, clip_grad=args.clip,
                       tag2label=tag2label, vocab=word2id, shuffle=args.shuffle,
                       model_path=ckpt_file, summary_path=summary_path, log_path=log_path, result_path=result_path,
                       CRF=args.CRF, update_embedding=args.update_embedding)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)

        # test the contents of a file
        # with open('../content.txt', 'r') as fr:
        #     content = ''
        #     for line in fr:
        #         content += line.strip()
        while True:
            sent = input("Please enter a symptom description: ")
            sent = list(sent)
            sent_data = [(sent, ['O'] * len(sent))]
            tag = model.demo_one(sess, sent_data)
            for s, t in zip(sent, tag):
                print(s, t)
            # body, chec, cure, dise, symp = get_entity_keys(tag, sent, ['BODY', 'CHECK', 'TREATMENT', 'DISEASE', 'SIGNS'])
            # print('body:{}\nchec:{}\ncure:{}\ndise:{}\nsymp:{}\n'.format(body, chec, cure, dise, symp))
            symptom, disease, treatment, check = get_entity_keys(
                tag, sent, ['SYMPTOM', 'DISEASE', 'TREATMENT', 'CHECK'])
            print('SYMPTOM:{}\nDISEASE:{}\nTREATMENT:{}\nCHECK:{}\n'.format(symptom, disease, treatment, check))
Example #9
def trainAll(args):

    if args.mode == 'train':
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()

        ## hyperparameters-tuning, split train/dev
        # dev_data = train_data[:5000]; dev_size = len(dev_data)
        # train_data = train_data[5000:]; train_size = len(train_data)
        # print("train data: {0}\ndev data: {1}".format(train_size, dev_size))
        # model.train(train=train_data, dev=dev_data)

        ## train model on the whole training data
        print("train data: {}".format(len(train_data)))
        model.train(
            train=train_data, dev=test_data
        )  # use test_data as the dev_data to see overfitting phenomena

    ## testing model
    elif args.mode == 'test':
        ckpt_file = tf.train.latest_checkpoint(model_path)
        print(ckpt_file)
        paths['model_path'] = ckpt_file
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()
        print("test data: {}".format(test_size))
        model.test(test_data)

    ## demo
    elif args.mode == 'demo':
        ckpt_file = tf.train.latest_checkpoint(model_path)
        print(ckpt_file)
        paths['model_path'] = ckpt_file
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            print('============= demo =============')
            saver.restore(sess, ckpt_file)
            while True:
                print('Please input your sentence:')
                demo_sent = input()
                if demo_sent == '' or demo_sent.isspace():
                    print('See you next time!')
                    break
                else:
                    demo_sent = list(demo_sent.strip())
                    demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                    tag = model.demo_one(sess, demo_data)
                    PER, LOC, ORG = get_entity(tag, demo_sent)

                    print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))

    elif args.mode == 'savemodel':

        ckpt_file = tf.train.latest_checkpoint(model_path)
        print(ckpt_file)
        paths['model_path'] = ckpt_file
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            saver.restore(sess, ckpt_file)
            # run one demo prediction on a sample sentence (illustrative only)
            demo_sent = list('武三思與韋後日夜譖敬暉等不已')
            demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            tag = model.demo_one(sess, demo_data)
            PER, LOC, ORG = get_entity(tag, demo_sent)
            print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
            # save a SavedModel; predict_signature_def needs Tensors for both
            # inputs and outputs, so export a string placeholder as the input
            # and a prediction tensor as the output (model.logits is an
            # assumed attribute name)
            input_tensor = tf.placeholder(tf.string, name='input')
            builder = tf.saved_model.builder.SavedModelBuilder('./savemodels')
            signature = predict_signature_def(
                inputs={'input': input_tensor},
                outputs={'output': model.logits})
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={'predict': signature})
            builder.save()
            print('SavedModel saved')
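A model exported by the savemodel branch can be restored with the TF1 SavedModel loader; a minimal sketch, assuming the ./savemodels directory produced above.

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # load the graph and variables tagged SERVING during export
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], './savemodels')
    # tensors from the exported signature can be fetched by name,
    # e.g. the 'input' placeholder defined at export time
    input_tensor = sess.graph.get_tensor_by_name('input:0')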
Example #10
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while True:
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                PER, LOC, ORG = get_entity(tag, demo_sent)
                print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))

## demo
elif args.mode == 'demo1':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
Example #11
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("test data: {}".format(test_size))
    model.test(test_data)

## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while True:
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                PER, LOC, ORG = get_entity(tag, demo_sent)
                print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
Example #12
def main(_):
    print('start app')

    data_path = os.path.join(FLAGS.train_data, 'word2id.pkl')
    word2id = read_dictionary(data_path)
    if FLAGS.pretrain_embedding == 'random':
        data_embeddings = random_embedding(word2id, FLAGS.embedding_dim)
    else:
        embedding_path = 'pretrain_embedding.npy'
        data_embeddings = np.array(np.load(embedding_path), dtype='float32')

    if FLAGS.mode != 'demo':
        train_file = os.path.join(FLAGS.train_data, 'train_data')
        test_file = os.path.join(FLAGS.test_data, 'test_data')
        train_data = read_corpus(train_file)
        test_data = read_corpus(test_file)
        test_size = len(test_data)

    time_stamp = str(int(
        time.time())) if FLAGS.mode == 'train' else FLAGS.demo_model

    def generator_dir(file_path):
        if not os.path.exists(file_path):
            os.makedirs(file_path)
        return file_path

    output_path = generator_dir(
        os.path.join(FLAGS.train_data + '_save', time_stamp))
    summary_path = generator_dir(os.path.join(output_path, 'summary'))
    model_path = generator_dir(os.path.join(output_path, 'checkpoints'))
    ckpt_prefix = os.path.join(model_path, 'model')  # checkpoint file prefix, not a directory
    result_path = generator_dir(os.path.join(output_path, 'results'))

    if FLAGS.mode == 'train':
        print('train ==================')
        """
        def __init__(self, batch_size, epoch, hidden_size, embeddings, crf, update_embedding, dropout_keepprob, optimizer,
         learning_rate, clip, tag2label, vocab, shuffle, model_p, summary_p, results_p, config):
        """
        model = BiLSTM_CRF(FLAGS.batch_size, FLAGS.epoch, FLAGS.hidden_size,
                           data_embeddings, FLAGS.CRF, FLAGS.update_embedding,
                           FLAGS.dropout, FLAGS.optimizer, FLAGS.learning_rate,
                           FLAGS.clipping, tag2label, word2id, FLAGS.shuffle,
                           model_path, summary_path, result_path, config)
        model.build_graph()
        model.train(train_data, test_data)
    elif FLAGS.mode == 'test':
        print('test ===============')
        ckpt_file = tf.train.latest_checkpoint(model_path)
        print('ckpt file {}'.format(ckpt_file))
        model = BiLSTM_CRF(FLAGS.batch_size, FLAGS.epoch, FLAGS.hidden_size,
                           data_embeddings, FLAGS.CRF, FLAGS.update_embedding,
                           FLAGS.dropout, FLAGS.optimizer, FLAGS.learning_rate,
                           FLAGS.clipping, tag2label, word2id, FLAGS.shuffle,
                           ckpt_file, summary_path, result_path, config)
        model.build_graph()
        print('test data {}'.format(test_size))
        model.test(test_data)
    elif FLAGS.mode == 'demo':
        print('demo ===========')
        ckpt_file = tf.train.latest_checkpoint(model_path)
        print('ckpt file {}'.format(ckpt_file))
        model = BiLSTM_CRF(FLAGS.batch_size, FLAGS.epoch, FLAGS.hidden_size,
                           data_embeddings, FLAGS.CRF, FLAGS.update_embedding,
                           FLAGS.dropout, FLAGS.optimizer, FLAGS.learning_rate,
                           FLAGS.clipping, tag2label, word2id, FLAGS.shuffle,
                           ckpt_file, summary_path, result_path, config)
        model.build_graph()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            saver.restore(sess, ckpt_file)
            while True:
                print("please input your sentence:")
                demo_sentence = input()
                if not demo_sentence or demo_sentence.isspace():
                    print('bye')
                    break
                else:
                    demo_sent = list(demo_sentence.strip())
                    demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                    tag = model.demo_one(sess, demo_data)
                    per, loc, org = get_entity(tag, demo_sent)
                    print('per {0} loc {1} org {2}'.format(per, loc, org))
Example #13
File: ner.py Project: Ma-Dan/NER
        print("train data: {}".format(len(train_data)))
        model.train(train_data=train_data, test_data=None)
    elif args.mode == 'demo':
        ckpt_file = tf.train.latest_checkpoint(model_path)
        print(ckpt_file)
        paths['model_path'] = ckpt_file
        model = BiLSTM_CRF(args, embedding, src_vocab, tgt_vocab, src_padding,
                           tgt_padding, paths)
        model.build_graph()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            print('============= demo =============')
            saver.restore(sess, ckpt_file)
            while True:
                print('Please input your sentence:')
                demo_sent = input()
                if demo_sent == '' or demo_sent.isspace():
                    print('See you next time!')
                    break
                else:
                    demo_data, demo_word = read_input(demo_sent, src_vocab,
                                                      src_unknown)
                    #print(demo_data)
                    label = model.demo_one(sess, demo_data)
                    #print(label)
                    for index, word in enumerate(demo_word):
                        print(word + '(' +
                              label2tag(tgt_vocab, label[0][index]) + ')')
Example #14
def ner(sent):
    config = tf.ConfigProto()
    parser = argparse.ArgumentParser(
        description='BiLSTM-CRF for Chinese NER task')
    parser.add_argument('--train_data',
                        type=str,
                        default='data_path',
                        help='train data source')
    parser.add_argument('--test_data',
                        type=str,
                        default='data_path',
                        help='test data source')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='#sample of each minibatch')
    parser.add_argument('--epoch',
                        type=int,
                        default=10,
                        help='#epoch of training')
    parser.add_argument('--hidden_dim',
                        type=int,
                        default=300,
                        help='#dim of hidden state')
    parser.add_argument('--optimizer',
                        type=str,
                        default='Adam',
                        help='Adam/Adadelta/Adagrad/RMSProp/Momentum/SGD')
    parser.add_argument('--CRF',
                        type=str2bool,
                        default=True,
                        help='use CRF at the top layer. if False, use Softmax')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate')
    parser.add_argument('--clip',
                        type=float,
                        default=5.0,
                        help='gradient clipping')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.5,
                        help='dropout keep_prob')
    parser.add_argument('--update_embedding',
                        type=str2bool,
                        default=True,
                        help='update embedding during training')
    parser.add_argument(
        '--pretrain_embedding',
        type=str,
        default='random',
        help='use pretrained char embedding or init it randomly')
    parser.add_argument('--embedding_dim',
                        type=int,
                        default=300,
                        help='random init char embedding_dim')
    parser.add_argument('--shuffle',
                        type=str2bool,
                        default=False,
                        help='shuffle training data before each epoch')
    parser.add_argument('--mode',
                        type=str,
                        default='demo',
                        help='train/test/demo')
    parser.add_argument('--demo_model',
                        type=str,
                        default='1563773712',
                        help='model for test and demo')
    args = parser.parse_args()

    ## get char embeddings
    word2id = read_dictionary(os.path.join('.', args.train_data,
                                           'word2id.pkl'))
    if args.pretrain_embedding == 'random':
        embeddings = random_embedding(word2id, args.embedding_dim)
    else:
        embedding_path = 'pretrain_embedding.npy'
        embeddings = np.array(np.load(embedding_path), dtype='float32')

    ## paths setting
    paths = {}
    paths['summary_path'] = './'
    model_path = r'C:\Users\Houking\Desktop\web_api\ner\checkpoint'
    paths['model_path'] = os.path.join(model_path, "model")
    paths['result_path'] = './'
    paths['log_path'] = './'

    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args,
                       embeddings,
                       tag2label,
                       word2id,
                       paths,
                       config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, ckpt_file)
        while True:
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                sent = list(demo_sent.strip())
                data = [(sent, ['O'] * len(sent))]
                tag = model.demo_one(sess, data)
                PER, SEX, TIT, REA = get_entity(tag, sent)
                print('PER: {}\nSEX: {}\nTIT: {}\nREA: {}'.format(
                    PER, SEX, TIT, REA))
Example #15
def run(sentences):
    # configure the session parameters
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # use GPU 0
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # set the TF log level
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.2  # need ~700MB GPU memory

    # hyperparameter settings
    # create a parser object and declare its expected arguments,
    # so the parser can handle the command line when the program runs
    parser = argparse.ArgumentParser(
        description='BiLSTM-CRF for Chinese NER task')
    parser.add_argument('--train_data',
                        type=str,
                        default='data_path',
                        help='train data source')
    parser.add_argument('--test_data',
                        type=str,
                        default='data_path',
                        help='test data source')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='#sample of each minibatch')
    # batch: training typically uses SGD, where each step trains on batch_size samples from the training set
    # iteration: one iteration trains once on batch_size samples
    # one iteration = one forward pass + one backward pass
    parser.add_argument('--epoch',
                        type=int,
                        default=40,
                        help='#epoch of training')
    # epoch: one epoch trains once on every sample in the training set
    # one epoch = one forward and one backward pass over all training samples. For example, with 1000 training samples and batch_size=10, a full pass over the data takes 100 iterations, i.e. 1 epoch.
    parser.add_argument('--hidden_dim',
                        type=int,
                        default=300,
                        help='#dim of hidden state')
    # dimension of the output vectors: 300
    parser.add_argument('--optimizer',
                        type=str,
                        default='Adam',
                        help='Adam/Adadelta/Adagrad/RMSProp/Momentum/SGD')
    # Adam is used as the optimizer
    parser.add_argument('--CRF',
                        type=str2bool,
                        default=True,
                        help='use CRF at the top layer. if False, use Softmax')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate')
    parser.add_argument('--clip',
                        type=float,
                        default=5.0,
                        help='gradient clipping')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.5,
                        help='dropout keep_prob')
    # dropout temporarily drops units from the network with a given probability during training
    parser.add_argument('--update_embedding',
                        type=str2bool,
                        default=True,
                        help='update embedding during training')
    parser.add_argument(
        '--pretrain_embedding',
        type=str,
        default='random',
        help='use pretrained char embedding or init it randomly')
    parser.add_argument('--embedding_dim',
                        type=int,
                        default=300,
                        help='random init char embedding_dim')
    parser.add_argument('--shuffle',
                        type=str2bool,
                        default=True,
                        help='shuffle training data before each epoch')
    parser.add_argument('--mode',
                        type=str,
                        default='demo',
                        help='train/test/demo')
    parser.add_argument('--demo_model',
                        type=str,
                        default='1559398699',
                        help='model for test and demo')
    # parse the arguments to feed into the model
    args = parser.parse_args()

    # initialize the embedding matrix and load the dictionary
    word2id = read_dictionary(os.path.join('.', args.train_data,
                                           'word2id.pkl'))
    # random_embedding returns a len(vocab) * embedding_dim = 3905 * 300 matrix (elements between -0.25 and 0.25) as the initial embedding values
    if args.pretrain_embedding == 'random':
        embeddings = random_embedding(word2id, args.embedding_dim)
    else:
        embedding_path = 'pretrain_embedding.npy'
        embeddings = np.array(np.load(embedding_path), dtype='float32')

    # load the training and test sets
    if args.mode != 'demo':
        train_path = os.path.join('.', args.train_data, 'train_data')
        test_path = os.path.join('.', args.test_data, 'test_data')
        train_data = read_corpus(train_path)
        test_data = read_corpus(test_path)
        test_size = len(test_data)

    # path settings
    paths = {}
    timestamp = str(int(
        time.time())) if args.mode == 'train' else args.demo_model
    output_path = os.path.join('.', args.train_data + "_save", timestamp)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    summary_path = os.path.join(output_path, "summaries")
    paths['summary_path'] = summary_path
    if not os.path.exists(summary_path):
        os.makedirs(summary_path)
    model_path = os.path.join(output_path, "checkpoints/")
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    ckpt_prefix = os.path.join(model_path, "model")
    paths['model_path'] = ckpt_prefix
    result_path = os.path.join(output_path, "results")
    paths['result_path'] = result_path
    if not os.path.exists(result_path):
        os.makedirs(result_path)
    log_path = os.path.join(result_path, "log.txt")
    paths['log_path'] = log_path
    get_logger(log_path).info(str(args))  # write the arguments to the log file

    if args.mode == 'train':  # train the model
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()
        model.train(train=train_data, dev=test_data)

    elif args.mode == 'test':  # test the model
        ckpt_file = tf.train.latest_checkpoint(model_path)
        print(ckpt_file)
        paths['model_path'] = ckpt_file
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()
        print("test data: {}".format(test_size))
        model.test(test_data)

    # demo
    elif args.mode == 'demo':
        location = []
        ckpt_file = tf.train.latest_checkpoint(model_path)
        print("model path: ", ckpt_file)
        paths['model_path'] = ckpt_file  # set the model path
        model = BiLSTM_CRF(args,
                           embeddings,
                           tag2label,
                           word2id,
                           paths,
                           config=config)
        model.build_graph()
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            saver.restore(sess, ckpt_file)
            for sentence in sentences:
                demo_sent = sentence
                demo_sent = list(demo_sent.strip())  # strip whitespace
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                tag = model.demo_one(sess, demo_data)
                PER, LOC, ORG = get_entity(tag, demo_sent)  # recover the characters for each predicted tag span
                new_LOC = list(set(LOC))  # deduplicate
                loc = ' '.join(new_LOC)
                location.append(loc)
            return location
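A usage sketch for run: it returns one space-joined string of deduplicated LOC entities per input sentence. The sample sentences are illustrative, and entity order within a sentence is not preserved after set().

locations = run(['我从上海去了北京', '他在广州工作'])
print(locations)  # one deduplicated LOC string per sentence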