Code Example #1
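    # Early-stopping hook: after each dev evaluation, log the tracked metric and raise EarlyStopError once the patience budget is used up.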
    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        """
        每次执行验证集的evaluation后会调用。

        :param Dict[str: Dict[str: float]] eval_result: , evaluation的结果。一个例子为{'AccuracyMetric':{'acc':1.0}},即
            传入的dict是有两层,第一层是metric的名称,第二层是metric的具体指标。
        :param str metric_key: 初始化Trainer时传入的metric_key。
        :param torch.Optimizer optimizer: Trainer中使用的优化器。
        :param bool is_better_eval: 当前dev结果是否比之前的好。
        :return:
        """
        logger.warning(
            '======epoch : {} , early stopping : {}/{}======'.format(
                self.epoch_no, self.wait, self.patience))
        metric_value = list(eval_result.values())[0].get(metric_key, None)
        logger.info('metric_key : {}, metric_value : {}'.format(
            metric_key, metric_value))
        logger.info('eval_result : \n{}'.format(eval_result))
        self.epoch_no += 1
        if not is_better_eval:
            # current result is getting worse
            if self.wait == self.patience:
                logger.info('reach early stopping patience, stop training.')
                raise EarlyStopError("Early stopping raised.")
            else:
                self.wait += 1
        else:
            self.wait = 0
Code Example #2
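# Training-script setup: load the THUCNews-format text-classification data, run the preprocessing pipe, and prepare an output directory and log file for a TextCNN model.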
from myClue.tools.serialize import save_serialize_obj  # noqa
from myClue.loader.classification import THUCNewsLoader  # noqa
from myClue.pipe.classification import THUCNewsPipe  # noqa
from myClue.core.metrics import ClassifyFPreRecMetric  # noqa
from myClue.core.callback import EarlyStopCallback  # noqa
from myClue.tools.file import init_file_path  # noqa
from myClue.models import CNNText  # noqa

if __name__ == "__main__":
    train_file_config = {
        'train': './data/UCAS_NLP_TC/example.txt',
        'dev': './data/UCAS_NLP_TC/example.txt',
        'test': './data/UCAS_NLP_TC/example.txt',
    }
    logger.info('数据加载')
    data_loader = THUCNewsLoader()
    data_bundle = data_loader.load(train_file_config)
    print_data_bundle(data_bundle)
    logger.info('数据预处理')
    data_pipe = THUCNewsPipe()
    data_bundle = data_pipe.process(data_bundle)
    data_bundle.rename_field(field_name=Const.CHAR_INPUT,
                             new_field_name=Const.INPUT,
                             ignore_miss_dataset=True,
                             rename_vocab=True)
    print_data_bundle(data_bundle)
    model_path = './data/UCAS_NLP_TC/model_textcnn_topk'
    init_file_path(model_path)
    logger.add_file_handler(
        os.path.join(
Code Example #3
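# Data preparation only: load the train/dev files, tokenize on whitespace, and rename the character-input field to the standard input field.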
import sys

from fastNLP.core import Const

sys.path.insert(0, './')  # define the module search path priority; index 0 is the highest priority

import myClue  # noqa
print('myClue module path :{}'.format(myClue.__file__))  # print the imported module's file location for verification
from myClue.core import logger  # noqa
from myClue.core.utils import print_data_bundle  # noqa
from myClue.loader.classification import THUCNewsLoader  # noqa
from myClue.pipe.classification import THUCNewsPipe  # noqa
from myClue.tools.serialize import save_serialize_obj  # noqa

if __name__ == "__main__":
    train_file_config = {
        'train': './data/UCAS_NLP_TC/data_01_shuffle/traindata.txt',
        'dev': './data/UCAS_NLP_TC/data_01_shuffle/devdata.txt',
    }
    logger.info('数据加载')
    data_loader = THUCNewsLoader()
    data_bundle = data_loader.load(train_file_config)
    print_data_bundle(data_bundle)
    logger.info('数据预处理')
    data_pipe = THUCNewsPipe(tokenizer='white_space')
    data_bundle = data_pipe.process(data_bundle)
    data_bundle.rename_field(field_name=Const.CHAR_INPUT,
                             new_field_name=Const.INPUT,
                             ignore_miss_dataset=True,
                             rename_vocab=True)
    print_data_bundle(data_bundle)
Code Example #4
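 # Inference script: restore the trained CNNText model and its vocabularies, map label descriptions back to label ids, and write predictions for the JSON test set to JSON and text files.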
 model_path = './data/tnews_public/model_textcnn'
 test_data_json_file_name = './data/tnews_public/test.json'
 label_json_file_name = './data/tnews_public/labels.json'
 char_vocab_pkl_file = os.path.join(model_path, 'vocab_char.pkl')
 target_vocab_pkl_file = os.path.join(model_path, 'target_char.pkl')
 model_name = os.path.join(model_path, 'best_CNNText_f_2020-05-14-23-33-55')
 predict_output_json_file_name = os.path.join(
     model_path, 'pred_2020-05-14-23-33-55.json')
 predict_output_file_name = os.path.join(model_path,
                                         'pred_2020-05-14-23-33-55.txt')
 logger.warn('加载标签映射关系')
 json_file_iter = read_json_file_iter(label_json_file_name)
 label_link_dict = dict()
 for row_json in json_file_iter:
     label_link_dict[row_json['label_desc']] = row_json['label']
 logger.info(label_link_dict)
 logger.warn('开始加载模型')
 model = torch.load(model_name)
 model.eval()
 logger.info('模型加载完毕:\n{}'.format(model))
 logger.warn('获取词典')
 char_vocab = load_serialize_obj(char_vocab_pkl_file)
 logger.info('char_vocab:{}'.format(char_vocab))
 target_vocab = load_serialize_obj(target_vocab_pkl_file)
 logger.info('target_vocab:{}'.format(target_vocab))
 logger.warn('加载测试数据')
 json_file_iter = read_json_file_iter(test_data_json_file_name)
 predictor = Predictor(model)
 with codecs.open(
         predict_output_json_file_name, mode='w',
         encoding='utf8') as fw_json, codecs.open(predict_output_file_name,
Code Example #5
"""
import codecs
import os
import sys
from tqdm import tqdm

sys.path.insert(0, './')  # define the module search path priority; index 0 is the highest priority

import myClue  # noqa
print('myClue module path :{}'.format(myClue.__file__))  # print the imported module's file location for verification
from myClue.core import logger  # noqa
from myClue.tools.file import read_file_texts  # noqa
from myClue.tools.file import init_file_path  # noqa

stopwords = set(read_file_texts('./data/stopwords/stopwords_mix.txt'))
logger.info('stopwords len:{}, example:{}'.format(len(stopwords),
                                                  list(stopwords)[:20]))


def news_content_process(news_content):
    """"数据转换处理"""
    words = news_content.split(' ')
    words = [word for word in words if word not in stopwords]
    return ' '.join(words)


if __name__ == "__main__":
    train_file_config = {
        'train': './data/UCAS_NLP_TC/data_01_shuffle/traindata.txt',
        'dev': './data/UCAS_NLP_TC/data_01_shuffle/devdata.txt',
        'test': './data/UCAS_NLP_TC/testdata.txt',
    }
Code Example #6
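 # Callback exception hook: log EarlyStopError as a normal stop and re-raise anything unexpected.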
 def on_exception(self, exception):
     if isinstance(exception, EarlyStopError):
         logger.info("Early Stopping triggered in epoch {}!".format(
             self.epoch))
     else:
         raise exception  # re-raise any unrecognized error
Code Example #7
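# Convert Baidu CWS/NER JSON output into plain-text data files: keep each row's label, collect its basic words, and pick out named-entity items (ORG, PER, LOC, ...).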
print('myClue module path :{}'.format(myClue.__file__))  # print the imported module's file location for verification
from myClue.core import logger  # noqa
from myClue.tools.file import read_file_texts  # noqa
from myClue.tools.file import init_file_path  # noqa


if __name__ == "__main__":
    train_file_config = {
        'train': './data/UCAS_NLP_TC/data_baidu_cws/train_cws.json',
        'dev': './data/UCAS_NLP_TC/data_baidu_cws/dev_cws.json',
        'test': './data/UCAS_NLP_TC/data_baidu_cws/test_cws.json',
    }
    output_path = './data/UCAS_NLP_TC/data_11_baidu_nerwords'
    init_file_path(output_path)
    for file_label, file_name in train_file_config.items():
        logger.info('开始处理:{}'.format(file_label))
        texts = read_file_texts(file_name)
        output_file_name = os.path.join(output_path, '{}data.txt'.format(file_label))
        with codecs.open(output_file_name, mode='w', encoding='utf8') as fw:
            for text in tqdm(texts):
                row_data = json.loads(text)
                label = row_data['label']
                cws_items = row_data['cws_items']
                words = list()
                item_filter = set()
                original_words = list()
                for cws_item in cws_items:
                    item_text = cws_item['item']
                    original_words.extend(cws_item['basic_words'])
                    if cws_item['ne'] in {'ORG', 'PER', 'LOC', 'nr', 'ns', 'nt', 'nw', 'nz'}:
                        if item_text in item_filter:
Code Example #8
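# NER data preparation: load the People's Daily corpus, run the preprocessing pipe, rename the char-input field, and serialize the DataBundle for later training runs.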
sys.path.insert(0, './')  # define the module search path priority; index 0 is the highest priority

import myClue  # noqa
print('myClue module path :{}'.format(myClue.__file__))  # print the imported module's file location for verification
from myClue.core import logger  # noqa
from myClue.core.utils import print_data_bundle  # noqa
from myClue.tools.serialize import save_serialize_obj  # noqa

if __name__ == "__main__":
    train_file_config = {
        'train': './data/peopledaily/train.txt',
        'dev': './data/peopledaily/dev.txt',
        'test': './data/peopledaily/test.txt',
    }
    train_data_bundle_pkl_file = './data/peopledaily/train_data_bundle.pkl'
    logger.info('数据加载')
    data_loader = PeopleDailyNERLoader()
    data_bundle = data_loader.load(train_file_config)
    print_data_bundle(data_bundle)
    logger.info('数据预处理')
    data_pipe = PeopleDailyPipe()
    data_bundle = data_pipe.process(data_bundle)
    data_bundle.rename_field(field_name=Const.CHAR_INPUT,
                             new_field_name=Const.INPUT,
                             ignore_miss_dataset=True,
                             rename_vocab=True)
    print_data_bundle(data_bundle)
    save_serialize_obj(data_bundle, train_data_bundle_pkl_file)
    logger.info('数据预处理后进行序列化:{}'.format(train_data_bundle_pkl_file))
Code Example #9
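# Training setup for a BiLSTM-CRF tagger on the Weibo NER data: restore the preprocessed DataBundle, serialize its vocabularies, and plug a pretrained ERNIE BertEmbedding into the model.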
from myClue.tools.serialize import save_serialize_obj  # noqa
from myClue.tools.file import init_file_path  # noqa


if __name__ == "__main__":
    train_data_bundle_pkl_file = './data/weibo_NER/train_data_bundle.pkl'
    model_path = './data/weibo_NER/model_bilstm_crf_bert_embed'
    init_file_path(model_path)
    logger.add_file_handler(os.path.join(model_path, 'log_{}.txt'.format(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))))  # write logs to a file
    char_vocab_pkl_file = os.path.join(model_path, 'vocab_char.pkl')
    target_vocab_pkl_file = os.path.join(model_path, 'target_char.pkl')
    logger.warn('加载数据集')
    data_bundle = load_serialize_obj(train_data_bundle_pkl_file)
    logger.warn('获取词典')
    char_vocab = data_bundle.get_vocab('words')
    logger.info('char_vocab:{}'.format(char_vocab))
    target_vocab = data_bundle.get_vocab('target')
    logger.info('target_vocab:{}'.format(target_vocab))
    save_serialize_obj(char_vocab, char_vocab_pkl_file)
    save_serialize_obj(target_vocab, target_vocab_pkl_file)
    logger.info('词典序列化:{}'.format(char_vocab_pkl_file))
    logger.warn('选择预训练词向量')
    # model_dir_or_name = 'cn-wwm'
    model_dir_or_name = './data/embed/ERNIE_1.0_max-len-512-pytorch'
    bert_embed = BertEmbedding(vocab=char_vocab, model_dir_or_name=model_dir_or_name, requires_grad=False)
    logger.warn('神经网络模型')
    model = BiLSTMCRF(embed=bert_embed, num_classes=len(target_vocab), num_layers=1, hidden_size=200, dropout=0.5,
                      target_vocab=target_vocab)
    logger.info(model)
    logger.warn('训练超参数设定')
    loss = LossInForward()
Code Example #10
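 # Mix original and augmented samples (half as many augmented rows as the original split), shuffle, and write train/dev back out in CoNLL format.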
 train_data.extend(
     random.sample(augmentation_list,
                   k=int(len(data_bundle.datasets['train']) / 2)))
 random.shuffle(train_data)
 dev_data = list()
 dev_data.extend([[datarow[Const.RAW_CHAR], datarow[Const.TARGET]]
                  for datarow in data_bundle.datasets['dev']])
 augmentation_list = [[
     datarow[Const.RAW_CHAR], datarow[Const.TARGET]
 ] for datarow in data_bundle_augmentation.datasets['dev']]
 dev_data.extend(
     random.sample(augmentation_list,
                   k=int(len(data_bundle.datasets['dev']) / 2)))
 random.shuffle(dev_data)
 # dataset split
 logger.info('train:{}, val:{}'.format(len(train_data), len(dev_data)))
 # write data out
 train_augmentated_file = './data/weibo_NER/train_augmentated.conll'
 dev_augmentated_file = './data/weibo_NER/dev_augmentated.conll'
 with codecs.open(train_augmentated_file, mode='w', encoding='utf8') as fw:
     for row_id, (row_chars, target) in enumerate(train_data):
          # write this row out
         for char, label in zip(row_chars, target):
             fw.write('{}\t{}\n'.format(char, label))
         fw.write('\n')
     # fw.write('\n')
 logger.info('train_augmentated_file:{}'.format(train_augmentated_file))
 with codecs.open(dev_augmentated_file, mode='w', encoding='utf8') as fw:
     for row_id, (row_chars, target) in enumerate(dev_data):
          # write this row out
         for char, label in zip(row_chars, target):
Code Example #11
File: predict.py  Project: q759729997/qyt_clue
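 # NER inference: load the saved BiLSTMCRF model and its vocabularies, preprocess the example CoNLL data (rename fields, add sequence lengths, index), and run on GPU when available.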
 model_path = './data/weibo_NER/model_bilstm_crf_random_embed'
 model_file = os.path.join(model_path,
                           'best_BiLSTMCRF_f_2020-05-20-11-08-16-221138')
 train_file = './data/weibo_NER/example.conll'
 predict_output_file = './data/weibo_NER/example_BiLSTMCRF_predict.conll'
 char_vocab_pkl_file = os.path.join(model_path, 'vocab_char.pkl')
 target_vocab_pkl_file = os.path.join(model_path, 'target_char.pkl')
 # load data
 data_loader = PeopleDailyNERLoader()
 data_bundle = data_loader.load({'train': train_file})
 print_data_bundle(data_bundle)
 dataset = data_bundle.datasets['train']
 dataset_original = copy.deepcopy(dataset)
 # load vocabularies
 char_vocab = load_serialize_obj(char_vocab_pkl_file)
 logger.info('char_vocab:{}'.format(char_vocab))
 target_vocab = load_serialize_obj(target_vocab_pkl_file)
 logger.info('target_vocab:{}'.format(target_vocab))
 # load model
 model = torch.load(model_file)
 if torch.cuda.is_available():
     model = model.cuda()
     logger.info('use cuda')
 model.eval()
 logger.info('模型加载完毕:\n{}'.format(model))
 # data preprocessing
 dataset.rename_field(field_name=Const.RAW_CHAR, new_field_name=Const.INPUT)
 dataset.add_seq_len(field_name=Const.INPUT)
 dataset.set_input(Const.INPUT, Const.INPUT_LEN)
 dataset.set_target(Const.TARGET, Const.INPUT_LEN)
 char_vocab.index_dataset(dataset, field_name=Const.INPUT)
Code Example #12
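# Data augmentation with word2vec: load the pretrained weibo vectors and rewrite each sentence of the CoNLL file into an augmented copy using its most similar words.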
print('myClue module path :{}'.format(myClue.__file__))  # print the imported module's file location for verification
from myClue.core import logger  # noqa
from myClue.core.utils import print_data_bundle  # noqa
from myClue.tools.serialize import load_serialize_obj  # noqa
from myClue.tools.decoder import decode_ner_tags  # noqa

if __name__ == "__main__":
    """通过word2vec中最相似的词进行数据增强"""
    random.seed(2020)
    # train_file = './data/weibo_NER/example.conll'
    # augmentation_file = './data/weibo_NER/example_augmentation.conll'
    # train_file = './data/weibo_NER/train.conll'
    # augmentation_file = './data/weibo_NER/train_augmentation.conll'
    train_file = './data/weibo_NER/dev.conll'
    augmentation_file = './data/weibo_NER/dev_augmentation.conll'
    logger.info('加载word2vec')
    word2vec_model_file = './data/embed/sgns.weibo.word/sgns.weibo.word'
    model = KeyedVectors.load_word2vec_format(word2vec_model_file,
                                              binary=False)
    logger.info('word2vec加载完毕,测试一下:{}'.format(
        model.most_similar('扎克伯格', topn=5)))
    # load data
    data_loader = PeopleDailyNERLoader()
    data_bundle = data_loader.load({'train': train_file})
    print_data_bundle(data_bundle)
    dataset = data_bundle.datasets['train']
    # data processing
    with codecs.open(augmentation_file, mode='w', encoding='utf8') as fw:
        for row_id, datarow in enumerate(dataset):
            if row_id % 10 == 0:
                print('row_id:{}'.format(row_id))
Code Example #13
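# Utility: log each dataset's name and size, a few example rows, and which fields are flagged as input/target.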
def print_data_bundle(data_bundle: DataBundle, title: str = None):
    """ 打印输出data_bundle的信息.

    @params:
        data_bundle - 数据集DataBundle.
        title - 打印输出的标题信息.
    """
    if title:
        logger.warning(title)
    for name, dataset in data_bundle.iter_datasets():
        logger.info('dataset name : {}'.format(name))
        logger.info('dataset len : {}'.format(len(dataset)))
        logger.info('dataset example : ')
        logger.info('\n{}'.format(dataset[:5]))
        logger.info('dataset 输出各个field的被设置成input和target的情况 : ')
        logger.info('\n{}'.format(dataset.print_field_meta()))
Code Example #14
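# Single-text inference: wrap one raw string into a DataSet, index it with the saved character vocabulary, and predict its class with the restored CNNText model.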
import myClue  # noqa
print('myClue module path :{}'.format(myClue.__file__))  # print the imported module's file location for verification
from myClue.core import logger  # noqa
from myClue.tools.serialize import load_serialize_obj  # noqa
from myClue.tools.file import read_json_file_iter  # noqa

if __name__ == "__main__":
    model_path = './data/UCAS_NLP_TC/model_textcnn_topk'
    char_vocab_pkl_file = os.path.join(model_path, 'vocab_char.pkl')
    target_vocab_pkl_file = os.path.join(model_path, 'target_char.pkl')
    model_name = os.path.join(model_path,
                              'best_CNNText_f_2020-07-04-23-18-38-341248')
    logger.warn('开始加载模型')
    model = torch.load(model_name)
    model.eval()
    logger.info('模型加载完毕:\n{}'.format(model))
    logger.warn('获取词典')
    char_vocab = load_serialize_obj(char_vocab_pkl_file)
    logger.info('char_vocab:{}'.format(char_vocab))
    target_vocab = load_serialize_obj(target_vocab_pkl_file)
    logger.info('target_vocab:{}'.format(target_vocab))
    logger.warn('加载测试数据')
    text = "世界贸易组织(WTO)17日对美国进行贸易政策审议。在当天的会议上,包括中国、欧"
    test_data = [list(text)]
    dataset = DataSet({Const.INPUT: test_data})
    dataset.add_seq_len(field_name=Const.INPUT)
    dataset.set_input(Const.INPUT, Const.INPUT_LEN)
    char_vocab.index_dataset(dataset, field_name=Const.INPUT)
    predictor = Predictor(model)
    batch_output = predictor.predict(data=dataset,
                                     seq_len_field_name=Const.INPUT_LEN)
Code Example #15
File: data_resplit.py  Project: q759729997/qyt_clue
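 # Re-split train/dev: pool both original splits, draw a new dev set of the original dev size with train_test_split, and write both splits back out in CoNLL format.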
 data_loader = PeopleDailyNERLoader()
 data_bundle = data_loader.load({'train': train_file, 'dev': dev_file})
 print_data_bundle(data_bundle)
 # assemble the data
 data_list = list()
 data_list.extend([[datarow[Const.RAW_CHAR], datarow[Const.TARGET]]
                   for datarow in data_bundle.datasets['train']])
 data_list.extend([[datarow[Const.RAW_CHAR], datarow[Const.TARGET]]
                   for datarow in data_bundle.datasets['dev']])
 # dataset split
 train_data, dev_data = train_test_split(data_list,
                                         test_size=len(
                                             data_bundle.datasets['dev']),
                                         shuffle=True,
                                         random_state=2020)
 logger.info('数据切分结果: all:{}, train:{}, val:{}'.format(
     len(data_list), len(train_data), len(dev_data)))
 # write data out
 train_resplit_file = './data/weibo_NER/train_resplit.conll'
 dev_resplit_file = './data/weibo_NER/dev_resplit.conll'
 with codecs.open(train_resplit_file, mode='w', encoding='utf8') as fw:
     for row_id, (row_chars, target) in enumerate(train_data):
          # write this row out
         for char, label in zip(row_chars, target):
             fw.write('{}\t{}\n'.format(char, label))
         fw.write('\n')
     # fw.write('\n')
 logger.info('train_resplit_file:{}'.format(train_resplit_file))
 with codecs.open(dev_resplit_file, mode='w', encoding='utf8') as fw:
     for row_id, (row_chars, target) in enumerate(dev_data):
          # write this row out
         for char, label in zip(row_chars, target):
Code Example #16
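# Format conversion: turn the tnews JSON files into tab-separated text ("label<TAB>sentence+keywords"), writing only the text when a row has no label.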
import codecs
import os
import sys

sys.path.insert(0, './')  # define the module search path priority; index 0 is the highest priority

import myClue  # noqa
print('myClue module path :{}'.format(myClue.__file__))  # print the imported module's file location for verification
from myClue.core import logger  # noqa
from myClue.tools.file import read_json_file_iter  # noqa
from myClue.tools.text import remove_blank  # noqa

if __name__ == "__main__":
    file_path = './data/tnews_public'
    file_types = ('train', 'dev', 'test')
    for file_type in file_types:
        logger.info('开始处理:{}'.format(file_type))
        json_file_name = os.path.join(file_path, '{}.json'.format(file_type))
        logger.info('json文件名:{}'.format(json_file_name))
        json_file_iter = read_json_file_iter(json_file_name)
        txt_file_name = os.path.join(file_path, '{}.txt'.format(file_type))
        row_count = 0
        with codecs.open(txt_file_name, mode='w', encoding='utf8') as fw:
            for row_json in json_file_iter:
                label = row_json.get('label_desc', None)
                sentence = row_json.get('sentence', '')
                keywords = row_json.get('keywords', '')
                text = remove_blank('{}{}'.format(sentence, keywords))
                if label is None:
                    fw.write('{}\n'.format(text))
                else:
                    fw.write('{}\t{}\n'.format(label, text))