def dev_one_epoch(model, sess, dev):
    """
    created by jma
    Run validation over the dev data for one epoch.

    :param model: the model being run
    :param sess: the session used for training
    :param dev: validation data
    :return: tuple ``(label_list, seq_len_list)`` -- the Viterbi-decoded
        label sequence of every validation sentence, and the sequence
        lengths reported by ``train_utils.get_feed_dict``
    """
    label_list = []
    seq_len_list = []
    dev_batches = train_utils.batch_yield(
        dev, args.batch_size, word2id, tag2label, shuffle=False)
    # Each batch carries the sentences and their labels; only the
    # sentences are needed for decoding.
    for seqs, _ in dev_batches:
        # drop_keep=1.0 is passed for validation.
        feed_dict, batch_lengths = train_utils.get_feed_dict(
            model, seqs, drop_keep=1.0)
        fetches = [model.log_its, model.transition_params]
        log_its, transition_params = sess.run(fetches, feed_dict=feed_dict)
        # Decode every sentence on its first `length` steps only.
        decoded = [
            viterbi_decode(scores[:length], transition_params)[0]
            for scores, length in zip(log_its, batch_lengths)
        ]
        label_list.extend(decoded)
        seq_len_list.extend(batch_lengths)
    return label_list, seq_len_list
def run_one_epoch(model, sess, train_corpus, dev, tag_label, epoch, saver):
    """
    create by ljx
    Train the model for one epoch, then validate on the dev data.

    :param model: the model
    :param sess: the session used to train the model
    :param train_corpus: training data
    :param dev: data used for validation
    :param tag_label: dict converting tags to labels
    :param epoch: zero-based epoch counter
    :param saver: saver used to store the trained parameters
    :return: None
    """
    # Ceiling division: a trailing partial batch still counts as a batch.
    num_batches = (len(train_corpus) + args.batch_size - 1) // args.batch_size
    start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    batches = train_utils.batch_yield(
        train_corpus, args.batch_size, word2id, tag_label)

    for batch_no, (seqs, labels) in enumerate(batches, start=1):
        progress = ' processing: {} batch / {} batches.'.format(
            batch_no, num_batches)
        # The trailing '\r' rewinds to the start of the line so the next
        # progress message overwrites this one.
        sys.stdout.write(progress + '\r')

        # Running step count across all epochs so far.
        step_num = epoch * num_batches + batch_no
        feed_dict, _ = train_utils.get_feed_dict(
            model, seqs, labels, args.lr, args.dropout)
        fetches = [model.train_op, model.loss, model.merged,
                   model.global_step]
        # The merged summary and the graph's global step are fetched but
        # not used afterwards.
        _, loss_train, _summary, _global_step = sess.run(
            fetches, feed_dict=feed_dict)

        is_last_batch = batch_no == num_batches
        # NOTE(review): the logging interval reuses args.batch_size as the
        # step period -- confirm this is intended.
        if batch_no == 1 or batch_no % args.batch_size == 0 or is_last_batch:
            print('logger info')
            message = '{} epoch {}, step {}, loss: {:.4}, total_step: {}'.format(
                start_time, epoch + 1, batch_no, loss_train, step_num)
            logger.info(message)

        # Save a checkpoint once the final batch of the epoch is done.
        if is_last_batch:
            saver.save(sess, params.store_path, global_step=step_num)

    logger.info('=============test==============')
    label_list_dev, seq_len_list_dev = dev_one_epoch(model, sess, dev)
    evaluate(label_list_dev, dev, epoch)
def demo_one(model, ses, sent, batch_size, vocab, shuffle, tag2label):
    """
    Created by jty
    Predict label ids for the input sentence and convert them back to tags.

    :param model: the saved model
    :param ses: the session to use
    :param sent: the input sentence to run entity extraction on
    :param batch_size: number of sentences per prediction batch
    :param vocab: word2id mapping
    :param shuffle: shuffle flag forwarded to ``train_utils.batch_yield``
        (the original notes describe False as the usual value; the
        signature itself has no default)
    :param tag2label: dict mapping each tag to its label id
    :return: list of predicted tags for the first predicted sequence, or
        an empty list when no predictions were produced. Label id 0 is
        returned as the integer 0 rather than as a tag name.
    :raises KeyError: if a predicted label id is missing from ``tag2label``
    """
    # Per the original author's note, batch_yield yields the id of every
    # character in the input plus the labels mapped through tag2label.
    label_list = []
    for seqs, _ in train_utils.batch_yield(sent, batch_size, vocab,
                                           tag2label, shuffle):
        batch_labels, _ = predict_one_batch(model, ses, seqs)
        label_list.extend(batch_labels)

    # With no batches (e.g. empty input) there is nothing to decode;
    # indexing label_list[0] below would raise IndexError.
    if not label_list:
        return []

    # Invert tag2label; label id 0 deliberately maps to itself.
    label2tag = {
        label: (name if label != 0 else label)
        for name, label in tag2label.items()
    }
    # Only the first predicted sequence is converted and returned.
    return [label2tag[label] for label in label_list[0]]