from myClue.core.metrics import ClassifyFPreRecMetric  # noqa
from myClue.core.callback import EarlyStopCallback  # noqa
from myClue.tools.serialize import save_serialize_obj  # noqa
from myClue.tools.file import init_file_path  # noqa
from myClue.models import CNNText  # noqa

if __name__ == "__main__":
    # Every split (train / dev / test) reads the same example file.
    example_file = './data/UCAS_NLP_TC/example.txt'
    train_file_config = {
        split: example_file for split in ('train', 'dev', 'test')
    }

    # Load the raw splits with the THUCNews loader and print a summary.
    logger.info('数据加载')
    data_loader = THUCNewsLoader()
    data_bundle = data_loader.load(train_file_config)
    print_data_bundle(data_bundle)

    # Run the THUCNews pipe, then rename the character-input field to the
    # generic input field name (renaming its vocabulary too, and skipping
    # datasets that lack the field), and print the result.
    logger.info('数据预处理')
    data_pipe = THUCNewsPipe()
    data_bundle = data_pipe.process(data_bundle)
    data_bundle.rename_field(
        field_name=Const.CHAR_INPUT,
        new_field_name=Const.INPUT,
        ignore_miss_dataset=True,
        rename_vocab=True,
    )
    print_data_bundle(data_bundle)

    # Output location for this run. NOTE(review): init_file_path presumably
    # ensures model_path exists before the log file is opened; confirm.
    model_path = './data/UCAS_NLP_TC/model_textcnn_topk'
    init_file_path(model_path)

    # Attach a log file named after the current local time, so that
    # subsequent logger output is also written to disk.
    run_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    log_file = os.path.join(model_path, 'log_{}.txt'.format(run_timestamp))
    logger.add_file_handler(log_file)
import myClue # noqa print('myClue module path :{}'.format(myClue.__file__)) # 输出测试模块文件位置 from myClue.core import logger # noqa from myClue.core.utils import print_data_bundle # noqa from myClue.tools.serialize import load_serialize_obj # noqa from myClue.tools.decoder import decode_ner_tags # noqa if __name__ == "__main__": """使用同义词增强的数据扩充训练与开发数据""" random.seed(2020) train_file = './data/weibo_NER/train.conll' dev_file = './data/weibo_NER/dev.conll' # 加载数据 data_loader = PeopleDailyNERLoader() data_bundle = data_loader.load({'train': train_file, 'dev': dev_file}) print_data_bundle(data_bundle) # 加载扩充的数据 train_file = './data/weibo_NER/train_augmentation.conll' dev_file = './data/weibo_NER/dev_augmentation.conll' # 加载数据 data_loader_augmentation = PeopleDailyNERLoader() data_bundle_augmentation = data_loader_augmentation.load({ 'train': train_file, 'dev': dev_file }) print_data_bundle(data_bundle_augmentation) # 组装数据 train_data = list() train_data.extend([[datarow[Const.RAW_CHAR], datarow[Const.TARGET]] for datarow in data_bundle.datasets['train']]) augmentation_list = [[