class Evaluator:
    """Run a model over the dataset's 'test' batches and report NER scores.

    The evaluator moves the model to the selected device, optionally restores
    weights from a checkpoint, and exposes :meth:`evaluate`, which stores the
    resulting ``Miner`` on ``self.miner``.
    """

    def __init__(self,
                 model: BiLSTMCRF,
                 dataset: NestedNERDataset,
                 model_path: str = None,
                 use_gpu: bool = True):
        """
        :param model: model to evaluate; moved to the selected device
        :param dataset: dataset providing the test batches and id converters
        :param model_path: path of a saved ``state_dict``; when None the
            weights already held by ``model`` are used as they are
        :param use_gpu: use CUDA when it is available
        """
        cuda_ok = torch.cuda.is_available() and use_gpu
        self.device_str = 'cuda' if cuda_ok else 'cpu'
        self.device = torch.device(self.device_str)
        if model_path is not None:
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only host (and vice versa).
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)
        self.model = model.to(self.device)
        self.dataset = dataset

    def evaluate(self, batch_size: int = 64):
        """
        Predict labels for every 'test' batch and print the default report.

        The model and the dataset's char encoder are switched to eval mode
        for the duration of inference and put back into train mode
        afterwards, even when inference raises.

        :param batch_size: batch size passed to ``dataset.get_batch``
        """
        sentences = []
        predicted_labels = []
        answer_labels = []
        test_iterator = self.dataset.get_batch(batch_size, 'test')

        self.model.eval()
        self.dataset.char_encoder.eval()
        try:
            with torch.no_grad():
                for data in tqdm(test_iterator):
                    # Positions whose label id is 0 are masked out
                    # (presumably padding -- confirm against the dataset).
                    mask = (data.label0 != 0).float().to(self.device)
                    vecs = self.dataset.to_vectors(data.word,
                                                   data.char,
                                                   data.pos,
                                                   data.subpos,
                                                   device=self.device_str)
                    sentences.extend(
                        self.dataset.wordid_to_sentence(sentence)
                        for sentence in data.word)
                    input_embed = self.model.concat_embedding(
                        list(vecs.values()))
                    predicted_labels.extend(
                        self.model.predict(input_embed, mask))

                    answer_label = self.dataset.get_batch_true_label(
                        data, 0, self.device_str)
                    # Tensor.numpy() only works on CPU tensors, so bring the
                    # labels back to the host first.
                    answer_label = answer_label.cpu().numpy()
                    keep = mask.cpu().numpy() > 0
                    answer_labels.extend(
                        row[keep[i]].tolist()
                        for i, row in enumerate(answer_label))
        finally:
            self.model.train()
            self.dataset.char_encoder.train()

        answer_labels = [
            self.dataset.labelid_to_labels(answer)
            for answer in answer_labels
        ]
        predicted_labels = [
            self.dataset.labelid_to_labels(pred) for pred in predicted_labels
        ]
        self.miner = Miner(answer_labels, predicted_labels, sentences)
        self.miner.default_report(True)
def report(self, show_flag: bool = True)\
        -> Dict[str, Dict[str, Dict[str, float]]]:
    """
    Build evaluation reports for the test data with a ``Miner``.

    :param show_flag: forwarded to every ``Miner`` report call
        (if True, reports are output to a console)
    :return: reports keyed by report name
        {'all': {'NELABEL1': {'precision': 0.000,
                              'recall': 0.000,
                              'f1_score': 0.00},
                 'NELABEL2': {'precision': ...},
                 ...},
         'unknown': ...,
         'known': ...,
         'misses': result of ``miner.return_miss_labelings()``,
         'seg': {'all': ..., 'unknown': ..., 'known': ...}
                (``miner.segmentation_score`` per type)}
    """
    test_inputs, test_answers = self.test_data[0], self.test_data[1]
    sentences = self._dataset.get_sentences('test')
    known_words = self._dataset.known_NEs()
    predicts = self._model.predict_all(test_inputs)
    miner = Miner(test_answers, predicts, sentences, known_words)

    # Reports are produced in a fixed order so that console output (when
    # show_flag is set) appears as: all, unknown, known, then segmentation.
    reports = {}
    reports['all'] = miner.default_report(show_flag)
    reports['unknown'] = miner.unknown_only_report(show_flag)
    reports['known'] = miner.known_only_report(show_flag)
    reports['misses'] = miner.return_miss_labelings()

    segmentation = {}
    for type_ in ('all', 'unknown', 'known'):
        segmentation[type_] = miner.segmentation_score(type_, show_flag)
    reports['seg'] = segmentation
    return reports
class TestMiner(unittest.TestCase):
    """Unit tests for ``Miner`` on three small hand-labelled sentences."""

    # Position tags shared by the begin/end-of-label boundary tests.
    _TAGS = ['B', 'I', 'O', 'S', 'B', 'I', 'I', 'E', 'O', 'O', 'B', 'B']

    def setUp(self):
        # Sentence 1: "Hanako went to Tokyo" (IOB2)
        # Sentence 2: "Taro Yamada headed to Tokyo Station" (IOB2)
        # Sentence 3: "Hanako and Bob went to Tokyo Skytree" (BIOES)
        answer_rows = (
            'B-PSN O O B-LOC O O O O',
            'B-PSN I-PSN O O B-LOC I-LOC O O O O',
            'S-PSN O O S-PSN O O B-LOC I-LOC E-LOC O O O O',
        )
        predict_rows = (
            'B-PSN O O B-LOC O O O O',
            'B-PSN B-PSN O O B-LOC I-LOC O O O O',
            'S-PSN O O O O O B-LOC I-LOC E-LOC O O O O',
        )
        sentence_rows = (
            '花子 さん は 東京 に 行き まし た',
            '山田 太郎 君 は 東京 駅 に 向かい まし た',
            '花子 さん と ボブ くん は 東京 スカイ ツリー に 行き まし た',
        )
        self.answers = [row.split(' ') for row in answer_rows]
        self.predicts = [row.split(' ') for row in predict_rows]
        self.sentences = [row.split(' ') for row in sentence_rows]
        self.knowns = {'PSN': ['花子'], 'LOC': ['東京']}
        self.miner = Miner(self.answers, self.predicts, self.sentences,
                           self.knowns)

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _scores(precision, recall, f1_score, num):
        """Build one per-label entry of an expected report."""
        return {
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'num': num
        }

    def _assert_report(self, result, expected):
        """Check label names, metric key order and the exact report values."""
        self.assertTrue(
            all(label in ['PSN', 'LOC'] for label, _ in result.items()))
        self.assertEqual([metric for metric, _ in result['PSN'].items()],
                         ['precision', 'recall', 'f1_score', 'num'])
        self.assertEqual(expected, result)

    def _assert_entities_overlap(self, result, expect):
        """Check each positionally paired group shares at least one entity.

        Groups are paired by dict order (zip), and only a non-empty
        intersection per label is required.
        """
        for (_, found), (_, wanted) in zip(result.items(), expect.items()):
            self.assertTrue(set(found['PSN']) & set(wanted['PSN']))
            self.assertTrue(set(found['LOC']) & set(wanted['LOC']))

    def _assert_flag(self, expected, actual):
        """Dispatch to assertTrue / assertFalse depending on ``expected``."""
        if expected:
            self.assertTrue(actual)
        else:
            self.assertFalse(actual)

    # ------------------------------------------------------------------
    # construction and reports
    # ------------------------------------------------------------------
    def test_initialize(self):
        miner = self.miner
        self.assertEqual(miner.answers, self.answers)
        self.assertEqual(miner.predicts, self.predicts)
        self.assertEqual(miner.sentences, self.sentences)
        self.assertEqual(miner.known_words, self.knowns)

    def test_default_report(self):
        expected = {
            'PSN': self._scores(0.5, 0.5, 0.5, 4),
            'LOC': self._scores(1.0, 1.0, 1.0, 3)
        }
        self._assert_report(self.miner.default_report(False), expected)

    def test_known_only_report(self):
        expected = {
            'PSN': self._scores(1.0, 1.0, 1.0, 2),
            'LOC': self._scores(1.0, 1.0, 1.0, 1)
        }
        self._assert_report(self.miner.known_only_report(False), expected)

    def test_unknown_only_report(self):
        expected = {
            'PSN': self._scores(0, 0, 0, 2),
            'LOC': self._scores(1.0, 1.0, 1.0, 2)
        }
        self._assert_report(self.miner.unknown_only_report(False), expected)

    # ------------------------------------------------------------------
    # entity extraction
    # ------------------------------------------------------------------
    def test__return_entities_indexes(self):
        psn_spans = self.miner._return_entities_indexes(self.answers, 'PSN')
        self.assertEqual(psn_spans, [('PSN', 0, 0), ('PSN', 9, 10),
                                     ('PSN', 20, 20), ('PSN', 23, 23)])

        loc_spans = self.miner._return_entities_indexes(self.answers, 'LOC')
        self.assertEqual(loc_spans, [('LOC', 3, 3), ('LOC', 13, 14),
                                     ('LOC', 26, 28)])

    def test__return_named_entities(self):
        expect = {
            'known': {
                'PSN': ['花子'],
                'LOC': ['東京']
            },
            'unknown': {
                'PSN': ['山田太郎', 'ボブ'],
                'LOC': ['東京スカイツリー', '東京駅']
            }
        }
        result = self.miner._return_named_entities(self.answers)
        self._assert_entities_overlap(result, expect)

    def test_return_answer_named_entities(self):
        result = self.miner.return_answer_named_entities()
        expect = self.miner._return_named_entities(self.answers)
        self._assert_entities_overlap(result, expect)

    def test_return_predict_named_entities(self):
        result = self.miner.return_predict_named_entities()
        expect = self.miner._return_named_entities(self.predicts)
        self._assert_entities_overlap(result, expect)

    # ------------------------------------------------------------------
    # label boundaries
    # ------------------------------------------------------------------
    def test__is_end_of_label(self):
        labels = list(self._TAGS)
        # One expectation per consecutive pair (labels[i], labels[i + 1]).
        expectations = [
            False, True, False, True, False, False, False, True, False,
            False, True
        ]
        for expected, (prev, nxt) in zip(expectations,
                                         zip(labels, labels[1:])):
            self._assert_flag(
                expected, self.miner._is_end_of_label(prev, nxt, 'a', 'a'))

        # Last tag followed by nothing.
        self.assertTrue(self.miner._is_end_of_label(labels[11], '', 'a', ''))
        # Same position tags but differing third/fourth arguments.
        self.assertTrue(self.miner._is_end_of_label('B', 'I', 'a', 'b'))

    def test__is_begin_of_label(self):
        labels = list(self._TAGS)
        # First tag preceded by nothing.
        self.assertTrue(self.miner._is_begin_of_label('', labels[0], 'a', 'a'))

        # One expectation per consecutive pair, starting at
        # (labels[1], labels[2]).
        expectations = [
            False, True, True, False, False, False, False, False, True, True
        ]
        for expected, (prev, nxt) in zip(expectations,
                                         zip(labels[1:], labels[2:])):
            self._assert_flag(
                expected, self.miner._is_begin_of_label(prev, nxt, 'a', 'a'))

        # Same position tags but differing third/fourth arguments.
        self.assertTrue(self.miner._is_begin_of_label('B', 'I', 'a', 'b'))

    def test__check_add_entities(self):
        scenarios = [
            # Everything is accepted.
            (True, True, [('', '', True), ('花子', 'PSN', True)]),
            # Only known entities are accepted.
            (True, False, [('花子', 'PSN', True), ('東京', 'LOC', True),
                           ('ボブ', 'PSN', False)]),
            # Only unknown entities are accepted.
            (False, True, [('ボブ', 'PSN', True),
                           ('東京スカイツリー', 'LOC', True),
                           ('東京', 'LOC', False)]),
        ]
        for check_known, check_unknown, cases in scenarios:
            self.miner.check_known = check_known
            self.miner.check_unknown = check_unknown
            for word, label, expected in cases:
                self._assert_flag(
                    expected, self.miner._check_add_entities(word, label))