def trainer(data_folder, write2model, write2vocab):
    """Train a BERT + BiLSTM-CRF NER model on the People's Daily corpus.

    Args:
        data_folder: directory containing the People's Daily NER data.
        write2model: file path the trained model is saved to.
        write2vocab: file path the vocabularies are dumped to as JSON.
    """
    # Load the raw corpus and run the standard NER preprocessing pipeline.
    data_bundle = PeopleDailyNERLoader().load(data_folder)
    data_bundle = PeopleDailyPipe().process(data_bundle)
    # BertEmbedding looks the input field up under the name 'words'.
    data_bundle.rename_field('chars', 'words')

    # Persist both vocabularies (and their word counts) so the model can be
    # reloaded for inference without re-reading the training data.
    target_vocab = data_bundle.vocabs['target']
    words_vocab = data_bundle.vocabs['words']
    vocab_payload = {
        "targetVocab": dict(target_vocab),
        "wordsVocab": dict(words_vocab),
        "targetWc": dict(target_vocab.word_count),
        "wordsWc": dict(words_vocab.word_count),
    }
    with open(write2vocab, "w", encoding="utf-8") as vocab_out:
        # ensure_ascii=False keeps the Chinese characters human-readable.
        json.dump(vocab_payload, vocab_out, ensure_ascii=False)

    embed = BertEmbedding(vocab=data_bundle.get_vocab('words'),
                          model_dir_or_name='cn',
                          requires_grad=False,
                          auto_truncate=True)
    model = BiLSTMCRF(embed=embed,
                      num_classes=len(data_bundle.get_vocab('target')),
                      num_layers=1,
                      hidden_size=100,
                      dropout=0.5,
                      target_vocab=data_bundle.get_vocab('target'))

    metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))
    optimizer = Adam(model.parameters(), lr=2e-5)
    loss = LossInForward()
    # fastNLP accepts a GPU index or the string 'cpu' for Trainer's device.
    device = 0 if torch.cuda.is_available() else 'cpu'
    # Renamed from `trainer` so the local no longer shadows this function.
    ner_trainer = Trainer(data_bundle.get_dataset('train'),
                          model,
                          loss=loss,
                          optimizer=optimizer,
                          batch_size=8,
                          dev_data=data_bundle.get_dataset('dev'),
                          metrics=metric,
                          device=device,
                          n_epochs=1)
    ner_trainer.train()
    tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)
    tester.test()
    # param_only=False stores the full model object, not just the state dict.
    saver = ModelSaver(write2model)
    saver.save_pytorch(model, param_only=False)
 def __init__(self,
              init_embedding=None,
              hidden_size=100,
              dropout=0.5,
              tag_vocab=None,
              model_path=None):
     """Wrap a BiLSTM-CRF tagger around a pre-built embedding layer.

     Args:
         init_embedding: embedding module handed to BiLSTMCRF as-is.
         hidden_size: hidden state size of the BiLSTM.
         dropout: dropout rate used inside BiLSTMCRF.
         tag_vocab: target tag vocabulary; its length fixes num_classes.
         model_path: optional checkpoint to restore weights from.
     """
     super().__init__()
     self.emb_layer = init_embedding
     self.model = BiLSTMCRF(
         self.emb_layer,
         num_classes=len(tag_vocab),
         hidden_size=hidden_size,
         dropout=dropout,
         target_vocab=tag_vocab,
     )
     # Restore previously trained weights when a checkpoint path is given.
     if model_path is not None:
         self.load_from_disk(model_path)
     # NOTE(review): `device` is not defined in this scope — presumably a
     # module-level global set elsewhere; confirm before relying on it.
     self.to(device)
# Example #3
 # Restore the preprocessed DataBundle from its pickle serialization.
 logger.warn('加载数据集')
 data_bundle = load_serialize_obj(train_data_bundle_pkl_file)
 # Extract the character and target vocabularies and serialize them so
 # inference code can rebuild the model without the training data.
 logger.warn('获取词典')
 char_vocab = data_bundle.get_vocab('words')
 logger.info('char_vocab:{}'.format(char_vocab))
 target_vocab = data_bundle.get_vocab('target')
 logger.info('target_vocab:{}'.format(target_vocab))
 save_serialize_obj(char_vocab, char_vocab_pkl_file)
 save_serialize_obj(target_vocab, target_vocab_pkl_file)
 logger.info('词典序列化:{}'.format(char_vocab_pkl_file))
 # Choose the pretrained embedding; a local ERNIE checkpoint is used
 # instead of the downloadable 'cn-wwm' BERT weights.
 logger.warn('选择预训练词向量')
 # model_dir_or_name = 'cn-wwm'
 model_dir_or_name = './data/embed/ERNIE_1.0_max-len-512-pytorch'
 bert_embed = BertEmbedding(vocab=char_vocab, model_dir_or_name=model_dir_or_name, requires_grad=False)
 # Build the tagging model on top of the frozen embedding.
 logger.warn('神经网络模型')
 model = BiLSTMCRF(embed=bert_embed, num_classes=len(target_vocab), num_layers=1, hidden_size=200, dropout=0.5,
                   target_vocab=target_vocab)
 logger.info(model)
 # Training hyperparameters: loss comes from the model's forward pass;
 # only parameters with requires_grad are handed to the optimizer.
 logger.warn('训练超参数设定')
 loss = LossInForward()
 optimizer = Adam([param for param in model.parameters() if param.requires_grad])
 # metric = AccuracyMetric()
 metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab(Const.TARGET), only_gross=False)  # only_gross=False additionally reports per-label metric statistics
 device = 'cuda' if torch.cuda.is_available() else 'cpu'  # run on GPU when available — training is much faster
 logger.info('device:{}'.format(device))
 batch_size = 32
 n_epochs = 10
 early_stopping = 10
 trainer = Trainer(
     save_path=model_path,
     train_data=data_bundle.get_dataset('train'),
     model=model,