示例#1
0
def recognize(test_file_name: str, split_by_paragraphs: bool,
              recognizer: BERT_NER, results_file_name: str):
    """Score a fitted BERT_NER on a CoNLL-2003 test file and save its predictions.

    The test data are read with `load_dataset_from_bio`, entities are
    predicted with `recognizer.predict`, overall and per-entity F1-score,
    precision and recall (from `calculate_prediction_quality`) are printed,
    and the predictions are handed to `save_dataset_as_bio` together with the
    original test file name.

    :param test_file_name: path to the test data file.
    :param split_by_paragraphs: if True, '-DOCSTART-' is also passed as a
        paragraph separator when loading the data.
    :param recognizer: fitted recognizer; its `classes_list_` is passed to
        the scorer as `classes_list`.
    :param results_file_name: destination file for the predictions.
    """
    marker = '-DOCSTART-'
    test_inputs, test_labels = load_dataset_from_bio(
        test_file_name,
        paragraph_separators={marker} if split_by_paragraphs else None,
        stopwords={marker},
    )
    print('The CoNLL-2003 data for final testing have been loaded...')
    print('Number of samples is {0}.'.format(len(test_labels)))
    print('')

    predicted_labels = recognizer.predict(test_inputs)
    quality = calculate_prediction_quality(
        test_labels, predicted_labels, classes_list=recognizer.classes_list_)
    f1, precision, recall, quality_by_entities = quality

    metric_templates = ('    F1-score is {0:.2%}.',
                        '    Precision is {0:.2%}.',
                        '    Recall is {0:.2%}.')

    def show_metrics(scores):
        # scores[0], scores[1], scores[2] hold F1-score, precision and recall.
        for position, template in enumerate(metric_templates):
            print(template.format(scores[position]))

    print('All entities:')
    show_metrics((f1, precision, recall))
    for entity_type in sorted(quality_by_entities):
        print('  {0}'.format(entity_type))
        show_metrics(quality_by_entities[entity_type])
    print('')

    save_dataset_as_bio(test_file_name, test_inputs, predicted_labels,
                        results_file_name, stopwords={marker})
示例#2
0
 def test_calculate_prediction_quality(self):
     """Check calculate_prediction_quality against the bundled JSON fixtures.

     Both fixture files must contain equal inputs (X). Overall F1-score,
     precision and recall must be floats matching reference values to 3
     decimal places. The per-entity result must be a dict keyed by exactly
     LOCATION, PERSON and ORG, whose values are 3-tuples of floats strictly
     between 0 and 1 (F1-score, precision, recall).
     """
     fixtures_dir = os.path.join(os.path.dirname(__file__), 'testdata')
     X_true, y_true = load_dataset(
         os.path.join(fixtures_dir, 'true_named_entities.json'))
     X_pred, y_pred = load_dataset(
         os.path.join(fixtures_dir, 'predicted_named_entities.json'))
     self.assertEqual(X_true, X_pred)

     entity_types = ('LOCATION', 'PERSON', 'ORG')
     f1, precision, recall, quality_by_entities = calculate_prediction_quality(
         y_true, y_pred, entity_types)

     # Overall scores paired with their reference values.
     overall = ((f1, 0.842037), (precision, 0.908352), (recall, 0.784746))
     for value, _ in overall:
         self.assertIsInstance(value, float)
     for value, reference in overall:
         self.assertAlmostEqual(value, reference, places=3)

     self.assertIsInstance(quality_by_entities, dict)
     self.assertEqual(set(entity_types), set(quality_by_entities.keys()))

     # Per-metric sums over entity types, used for the mean-deviation check.
     sums = [0.0, 0.0, 0.0]
     for ne_type in quality_by_entities:
         scores = quality_by_entities[ne_type]
         self.assertIsInstance(scores, tuple)
         self.assertEqual(len(scores), 3)
         for score in scores:
             self.assertIsInstance(score, float)
         for score in scores:
             self.assertLess(score, 1.0)
             self.assertGreater(score, 0.0)
         for idx, score in enumerate(scores):
             sums[idx] += score

     n_types = float(len(quality_by_entities))
     means = [total / n_types for total in sums]
     # Each per-entity score must differ from the mean over entity types by
     # more than 1e-4, so, e.g., identical scores for every type fail.
     for ne_type in quality_by_entities:
         for idx, mean in enumerate(means):
             self.assertGreater(
                 abs(quality_by_entities[ne_type][idx] - mean), 1e-4)
示例#3
0
def train(train_file_name: str, valid_file_name: str,
          split_by_paragraphs: bool, bert_will_be_tuned: bool,
          lstm_layer_size: Union[int, None], l2: float, max_epochs: int,
          batch_size: int, gpu_memory_frac: float,
          model_name: str) -> BERT_NER:
    """Return a BERT_NER, unpickled from `model_name` or trained and pickled there.

    If `model_name` is an existing file, it is unpickled and returned.
    Otherwise the training and validation files are loaded with
    `load_dataset_from_bio`, and a new BERT_NER is fitted (with `patience=5`
    and `random_seed=42`). The model is pickled into `model_name`, and its
    overall and per-entity F1-score, precision and recall on the validation
    data are printed.

    :param train_file_name: path to the training data file.
    :param valid_file_name: path to the validation data file.
    :param split_by_paragraphs: if True, '-DOCSTART-' is also used as a
        paragraph separator when loading the data.
    :param bert_will_be_tuned: passed as `finetune_bert`; also selects the
        learning rate (1e-6 when True, 1e-4 otherwise).
    :param lstm_layer_size: passed as `lstm_units` (may be None).
    :param l2: passed as `l2_reg`.
    :param max_epochs: passed as `max_epochs`.
    :param batch_size: passed as `batch_size`.
    :param gpu_memory_frac: passed as `gpu_memory_frac`.
    :param model_name: path of the pickle file to load from or save to.
    :return: the loaded or freshly fitted recognizer.
    :raises TypeError: if `model_name` exists but does not unpickle to a
        BERT_NER instance.
    """
    if os.path.isfile(model_name):
        # NOTE(review): unpickling runs arbitrary code stored in the file; only
        # load model files that come from a trusted source.
        with open(model_name, 'rb') as fp:
            recognizer = pickle.load(fp)
        # Explicit check rather than `assert`, which `python -O` strips.
        if not isinstance(recognizer, BERT_NER):
            raise TypeError(
                'The file `{0}` does not contain a BERT_NER object!'.format(
                    model_name))
        print('The NER has been successfully loaded from the file `{0}`...'.
              format(model_name))
        print('')
        return recognizer

    def load_bio(file_name):
        # Fresh sets on every call, so the loader never shares mutable args.
        return load_dataset_from_bio(
            file_name,
            paragraph_separators=({'-DOCSTART-'}
                                  if split_by_paragraphs else None),
            stopwords={'-DOCSTART-'})

    X_train, y_train = load_bio(train_file_name)
    X_val, y_val = load_bio(valid_file_name)
    print(
        'The CoNLL-2003 data for training and validation have been loaded...'
    )
    print('Number of samples for training is {0}.'.format(len(y_train)))
    print('Number of samples for validation is {0}.'.format(len(y_val)))
    print('')

    # Use the TF-Hub module only when BERT_NER.PATH_TO_BERT is not set
    # (presumably a local BERT location — confirm in BERT_NER).
    bert_hub_module_handle = (
        'https://tfhub.dev/google/bert_cased_L-12_H-768_A-12/1'
        if BERT_NER.PATH_TO_BERT is None else None)
    recognizer = BERT_NER(finetune_bert=bert_will_be_tuned,
                          batch_size=batch_size,
                          l2_reg=l2,
                          bert_hub_module_handle=bert_hub_module_handle,
                          lstm_units=lstm_layer_size,
                          max_epochs=max_epochs,
                          patience=5,
                          gpu_memory_frac=gpu_memory_frac,
                          verbose=True,
                          random_seed=42,
                          lr=1e-6 if bert_will_be_tuned else 1e-4)
    recognizer.fit(X_train, y_train, validation_data=(X_val, y_val))
    # Save before announcing it: the message below is then true, and the
    # fitted model is kept even if the evaluation step that follows fails.
    with open(model_name, 'wb') as fp:
        pickle.dump(recognizer, fp)
    print('')
    print(
        'The NER has been successfully fitted and saved into the file `{0}`...'
        .format(model_name))

    y_pred = recognizer.predict(X_val)
    f1, precision, recall, quality_by_entities = calculate_prediction_quality(
        y_val, y_pred, classes_list=recognizer.classes_list_)
    print('All entities:')
    print('    F1-score is {0:.2%}.'.format(f1))
    print('    Precision is {0:.2%}.'.format(precision))
    print('    Recall is {0:.2%}.'.format(recall))
    for ne_type in sorted(quality_by_entities):
        ne_f1, ne_precision, ne_recall = quality_by_entities[ne_type][:3]
        print('  {0}'.format(ne_type))
        print('    F1-score is {0:.2%}.'.format(ne_f1))
        print('    Precision is {0:.2%}.'.format(ne_precision))
        print('    Recall is {0:.2%}.'.format(ne_recall))
    print('')
    return recognizer
示例#4
0
def recognize(factrueval2016_testset_dir: str, split_by_paragraphs: bool,
              recognizer: ELMo_NER, results_dir: str):
    """Evaluate an ELMo_NER on the FactRuEval-2016 test set and write task-1 files.

    `factrueval2016_to_json` converts the test set into a temporary JSON
    file, from which both the documents and the true entities are read.
    Every paragraph of every document is recognized, and overall and
    per-entity F1-score, precision and recall are printed. The predictions
    are written as `<base_name>.task1` files in `results_dir`, one line
    `<type> <start> <length>` per entity. `<start>` is an offset into the
    whole document text.

    :param factrueval2016_testset_dir: directory with the FactRuEval-2016
        test set.
    :param split_by_paragraphs: forwarded to `factrueval2016_to_json`.
    :param recognizer: fitted recognizer; its `classes_list_` is passed to
        the scorer.
    :param results_dir: directory receiving the `.task1` files (it is not
        created here).
    :raises RuntimeError: if the recognizer returns a different number of
        samples than there are true samples.
    :raises ValueError: if a predicted entity type is not ORG, PERSON or
        LOCATION.
    """
    # `mkstemp` atomically creates a uniquely named file that stays in place
    # until the `finally` clause removes it. Taking `.name` from a discarded
    # NamedTemporaryFile would delete the file at once (in CPython) and keep
    # only a name another process could claim, like `tempfile.mktemp`.
    fd, temp_json_name = tempfile.mkstemp(suffix='.json')
    os.close(fd)
    try:
        factrueval2016_to_json(factrueval2016_testset_dir, temp_json_name,
                               split_by_paragraphs)
        with codecs.open(temp_json_name,
                         mode='r',
                         encoding='utf-8',
                         errors='ignore') as fp:
            data_for_testing = json.load(fp)
        _, true_entities = load_dataset(temp_json_name)
    finally:
        if os.path.isfile(temp_json_name):
            os.remove(temp_json_name)

    # Flatten documents into paragraphs. For each paragraph, remember its
    # output file and its bounds, which map offsets back to the document.
    texts = []
    additional_info = []
    for cur_document in data_for_testing:
        base_name = os.path.join(results_dir,
                                 cur_document['base_name'] + '.task1')
        for cur_paragraph in cur_document['paragraph_bounds']:
            texts.append(
                cur_document['text'][cur_paragraph[0]:cur_paragraph[1]])
            additional_info.append((base_name, cur_paragraph))
    print('Data for final testing have been loaded...')
    print('Number of samples is {0}.'.format(len(true_entities)))
    print('')

    predicted_entities = recognizer.predict(texts)
    # Explicit check rather than `assert`, which `python -O` strips.
    if len(predicted_entities) != len(true_entities):
        raise RuntimeError(
            'Number of predicted samples ({0}) differs from number of true '
            'samples ({1})!'.format(len(predicted_entities),
                                    len(true_entities)))
    f1, precision, recall, quality_by_entities = calculate_prediction_quality(
        true_entities, predicted_entities, recognizer.classes_list_)
    print('All entities:')
    print('    F1-score is {0:.2%}.'.format(f1))
    print('    Precision is {0:.2%}.'.format(precision))
    print('    Recall is {0:.2%}.'.format(recall))
    for ne_type in sorted(quality_by_entities):
        print('  {0}'.format(ne_type))
        print('    F1-score is {0:.2%}.'.format(
            quality_by_entities[ne_type][0]))
        print('    Precision is {0:.2%}.'.format(
            quality_by_entities[ne_type][1]))
        print('    Recall is {0:.2%}.'.format(quality_by_entities[ne_type][2]))
    print('')

    # Recognizer entity types -> type labels written to the .task1 files.
    factrueval_types = {'ORG': 'org', 'PERSON': 'per', 'LOCATION': 'loc'}
    results_for_factrueval_2016 = dict()
    for sample_idx, cur_result in enumerate(predicted_entities):
        base_name, paragraph_bounds = additional_info[sample_idx]
        for entity_type in cur_result:
            prepared_entity_type = factrueval_types.get(entity_type)
            if prepared_entity_type is None:
                raise ValueError(
                    '`{0}` is unknown entity type!'.format(entity_type))
            for entity_bounds in cur_result[entity_type]:
                # (type, start offset within the document, length)
                postprocessed_entity = (prepared_entity_type,
                                        entity_bounds[0] + paragraph_bounds[0],
                                        entity_bounds[1] - entity_bounds[0])
                results_for_factrueval_2016.setdefault(base_name, []).append(
                    postprocessed_entity)
    # NOTE: a document without any predicted entity gets no `.task1` file.
    for base_name, entities in results_for_factrueval_2016.items():
        with codecs.open(base_name,
                         mode='w',
                         encoding='utf-8',
                         errors='ignore') as fp:
            for cur_entity in sorted(entities,
                                     key=lambda it: (it[1], it[2], it[0])):
                fp.write('{0} {1} {2}\n'.format(*cur_entity))