Beispiel #1
0
def dev_one_epoch(model, sess, dev):
    """
    created by jma
    Run the model over the validation data for one epoch and decode labels.

    :param model: the model being run
    :param sess: the session used to evaluate the model
    :param dev: validation data
    :return: tuple ``(label_list, seq_len_list)`` -- the decoded label
        sequence for every validation sentence, and the sequence lengths
        reported by ``train_utils.get_feed_dict`` for those sentences
    """
    decoded_labels = []
    all_lengths = []

    # Batches are produced in order (shuffle=False); the gold labels that come
    # with each batch are not needed here.
    batches = train_utils.batch_yield(dev,
                                      args.batch_size,
                                      word2id,
                                      tag2label,
                                      shuffle=False)
    for seqs, _labels in batches:
        # drop_keep=1.0 -- presumably disables dropout for evaluation;
        # confirm against train_utils.get_feed_dict.
        feed_dict, batch_lengths = train_utils.get_feed_dict(model,
                                                             seqs,
                                                             drop_keep=1.0)
        fetches = [model.log_its, model.transition_params]
        log_its, transition_params = sess.run(fetches, feed_dict=feed_dict)

        # Decode each sentence on its first `length` rows only, keeping the
        # decoded sequence and discarding the second value viterbi_decode
        # returns.
        decoded_labels.extend(
            viterbi_decode(scores[:length], transition_params)[0]
            for scores, length in zip(log_its, batch_lengths))
        all_lengths.extend(batch_lengths)

    return decoded_labels, all_lengths
Beispiel #2
0
def run_one_epoch(model, sess, train_corpus, dev, tag_label, epoch, saver):
    """
    create by ljx
    Train the model for one epoch, then evaluate it on the validation data.

    :param model: the model
    :param sess: the session used to train the model
    :param train_corpus: training data
    :param dev: data used for validation
    :param tag_label: dictionary converting tags to labels
    :param epoch: zero-based epoch counter (logged as ``epoch + 1``)
    :param saver: saver used to store the trained parameters
    :return: None
    """
    # Ceiling division: number of batches needed to cover the corpus.
    num_batches = (len(train_corpus) + args.batch_size - 1) // args.batch_size

    # Captured once, so every log line of this epoch carries the same
    # timestamp.
    start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    batches = train_utils.batch_yield(train_corpus, args.batch_size, word2id,
                                      tag_label)

    for batch_no, (seqs, labels) in enumerate(batches, start=1):
        progress = ' processing: {} batch / {} batches.'.format(
            batch_no, num_batches)
        sys.stdout.write(progress + '\r')

        # Running step count across epochs, computed locally.
        step_num = epoch * num_batches + batch_no
        is_last_batch = batch_no == num_batches

        feed_dict, _ = train_utils.get_feed_dict(model, seqs, labels, args.lr,
                                                 args.dropout)

        fetches = [
            model.train_op, model.loss, model.merged, model.global_step
        ]
        # Only the loss is used; the summary and the graph's global step are
        # fetched but ignored.
        _, loss_train, _summary, _global_step = sess.run(
            fetches, feed_dict=feed_dict)

        # NOTE(review): the logging interval reuses args.batch_size -- confirm
        # this is intended rather than a dedicated logging frequency.
        if (batch_no == 1 or batch_no % args.batch_size == 0
                or is_last_batch):
            # NOTE(review): the print below looks like leftover debug output.
            print('logger info')
            message = '{} epoch {}, step {}, loss: {:.4}, total_step: {}'.format(
                start_time, epoch + 1, batch_no, loss_train, step_num)
            logger.info(message)

        # Checkpoint once per epoch, after the final batch.
        if is_last_batch:
            saver.save(sess, params.store_path, global_step=step_num)

    logger.info('=============test==============')
    predicted_labels, _ = dev_one_epoch(model, sess, dev)
    evaluate(predicted_labels, dev, epoch)
Beispiel #3
0
def demo_one(model, ses, sent, batch_size, vocab, shuffle, tag2label):
    """
    Created by jty
    Predict label ids for the input and convert them to tags.

    :param model: the saved model
    :param ses: the session to use
    :param sent: the input sentence(s) to run entity extraction on
    :param batch_size: number of sentences predicted per batch
    :param vocab: word2id mapping
    :param shuffle: forwarded to ``train_utils.batch_yield`` (the original
        documentation says this is False by default)
    :param tag2label: dictionary converting tags to label ids
    :return: predicted tags of the first sequence only; label id 0 is
        returned as the integer 0 instead of its tag string
    """
    # batch_yield provides the per-character ids of the input (and the labels
    # converted via tag2label, which are not needed for prediction).
    predictions = []
    batches = train_utils.batch_yield(sent, batch_size, vocab, tag2label,
                                      shuffle)
    for seqs, _labels in batches:
        batch_predictions, _ = predict_one_batch(model, ses, seqs)
        predictions.extend(batch_predictions)

    # Invert tag2label; label 0 maps to itself (the integer 0), not to a tag.
    label2tag = {
        label: (tag if label != 0 else label)
        for tag, label in tag2label.items()
    }
    return [label2tag[label_id] for label_id in predictions[0]]