Example #1
0
    def test_case1(self):
        # Check that the metric produces the expected f1/em on a tiny
        # hand-crafted batch where the argmax positions are forced.
        import torch
        metric = CMRC2018Metric()

        contexts = [list("abcsdef"), list("123456s789")]
        lengths = torch.LongTensor([3, 6])
        gold_answers = [["abc", "abc", "abc"], ["12", "12", "12"]]
        max_len = max(len(chars) for chars in contexts)
        start_logits = torch.randn(2, max_len)
        end_logits = torch.randn(2, max_len)
        # Sample 0: force span [0, 2] -> predicts exactly "abc" (exact match).
        start_logits[0, 0] = 1000
        end_logits[0, 2] = 1000
        # Sample 1: force span [1, 3] -> predicts "234" (partial overlap with "12").
        start_logits[1, 1] = 1000
        end_logits[1, 3] = 1000

        metric.evaluate(gold_answers, contexts, lengths, start_logits, end_logits)

        self.assertDictEqual(metric.get_metric(), {'f1': 70.0, 'em': 50.0})
Example #2
0
"""Fine-tune a Chinese BERT question-answering model on CMRC2018 with fastNLP."""
from fastNLP.core.losses import CMRC2018Loss
from fastNLP.core.metrics import CMRC2018Metric
from fastNLP.io.pipe.qa import CMRC2018BertPipe
from fastNLP import Trainer, BucketSampler
from fastNLP import WarmupCallback, GradientClipCallback
from fastNLP.core.optimizer import AdamW
# Fix: BertEmbedding and BertForQuestionAnswering were used below but never
# imported, so the script crashed with NameError before training started.
from fastNLP.embeddings import BertEmbedding
from fastNLP.models import BertForQuestionAnswering


# Load and preprocess the CMRC2018 dataset for BERT-style span prediction.
data_bundle = CMRC2018BertPipe().process_from_file()
# The embedding layer expects the input field to be named 'words'.
data_bundle.rename_field('chars', 'words')

print(data_bundle)

# Chinese BERT embedding; [CLS]/[SEP] handling is done by the model, not the
# embedding (include_cls_sep=False), and over-long inputs are truncated.
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True, include_cls_sep=False, auto_truncate=True,
                      dropout=0.5, word_dropout=0.01)
model = BertForQuestionAnswering(embed)
loss = CMRC2018Loss()
metric = CMRC2018Metric()

# Linear warmup of the learning rate plus gradient-norm clipping at 1.0.
wm_callback = WarmupCallback(schedule='linear')
gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')
callbacks = [wm_callback, gc_callback]

optimizer = AdamW(model.parameters(), lr=5e-5)

# Bucket batches by context length to reduce padding; accumulate gradients
# over 10 steps (update_every) for an effective batch size of 60.
trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,
                  sampler=BucketSampler(seq_len_field_name='context_len'),
                  dev_data=data_bundle.get_dataset('dev'), metrics=metric,
                  callbacks=callbacks, device=0, batch_size=6, num_workers=2, n_epochs=2, print_every=1,
                  test_use_tqdm=False, update_every=10)
# Best-model reloading is skipped; the final-epoch weights are kept.
trainer.train(load_best_model=False)