def test_exact_match_score(self):
    """Exercise ``TargetTextCollection.exact_match_score`` end to end.

    The method under test returns a 4-tuple
    ``(recall, precision, f1, error_analysis)`` where ``error_analysis``
    is a dict with ``'FP'``/``'FN'``/``'TP'`` lists of
    ``(text_id, Span)`` pairs.  Scenarios covered, in order: perfect
    score; perfect precision but imperfect recall; perfect recall but
    imperfect precision (over-prediction); a sentence with no
    predictions at all; no predictions anywhere; token/text
    misalignment; samples with no gold spans; a collection of only
    spanless samples; and the ``KeyError``/``ValueError`` failure modes.

    NOTE(review): the expected ``Span`` offsets assume the fixtures
    returned by ``self._target_text_example()`` /
    ``self._target_text_measure_examples()`` — confirm against those
    helpers, which are outside this view.
    """
    # Simple case where it should get perfect score
    test_collection = TargetTextCollection([self._target_text_example()])
    test_collection.tokenize(spacy_tokenizer())
    test_collection.sequence_labels()
    measures = test_collection.exact_match_score('sequence_labels')
    # Index 3 of the returned tuple is the error-analysis dict; the
    # first three entries are recall, precision, and F1.
    for index, measure in enumerate(measures):
        if index == 3:
            assert measure['FP'] == []
            assert measure['FN'] == []
            assert measure['TP'] == [('2', Span(4, 15)),
                                     ('2', Span(30, 35))]
        else:
            assert measure == 1.0
    # Something that has perfect precision but misses one therefore does
    # not have perfect recall nor f1
    test_collection = TargetTextCollection(
        self._target_text_measure_examples())
    test_collection.tokenize(str.split)
    # text = 'The laptop case was great and cover was rubbish'
    sequence_labels_0 = ['O', 'O', 'O', 'O', 'O', 'O', 'B', 'O', 'O']
    test_collection['0']['sequence_labels'] = sequence_labels_0
    # text = 'The laptop price was awful'
    sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
    test_collection['1']['sequence_labels'] = sequence_labels_1
    recall, precision, f1, error_analysis = test_collection.exact_match_score(
        'sequence_labels')
    assert precision == 1.0
    assert recall == 2.0 / 3.0
    assert f1 == 0.8
    assert error_analysis['FP'] == []
    assert error_analysis['FN'] == [('0', Span(4, 15))]
    assert error_analysis['TP'] == [('0', Span(30, 35)),
                                    ('1', Span(4, 16))]
    # Something that has perfect recall but not precision as it over
    # predicts
    sequence_labels_0 = ['O', 'B', 'I', 'B', 'O', 'O', 'B', 'O', 'O']
    test_collection['0']['sequence_labels'] = sequence_labels_0
    sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
    test_collection['1']['sequence_labels'] = sequence_labels_1
    recall, precision, f1, error_analysis = test_collection.exact_match_score(
        'sequence_labels')
    assert precision == 3 / 4
    assert recall == 1.0
    assert round(f1, 3) == 0.857
    assert error_analysis['FP'] == [('0', Span(16, 19))]
    assert error_analysis['FN'] == []
    assert error_analysis['TP'] == [('0', Span(4, 15)),
                                    ('0', Span(30, 35)),
                                    ('1', Span(4, 16))]
    # Does not predict anything for a whole sentence therefore will have
    # perfect precision but bad recall (mainly testing the if not
    # getting anything for a sentence matters)
    sequence_labels_0 = ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
    test_collection['0']['sequence_labels'] = sequence_labels_0
    sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
    test_collection['1']['sequence_labels'] = sequence_labels_1
    recall, precision, f1, error_analysis = test_collection.exact_match_score(
        'sequence_labels')
    assert precision == 1.0
    assert recall == 1 / 3
    assert f1 == 0.5
    assert error_analysis['FP'] == []
    # Sort by span start so the comparison is order-independent.
    fn_error = sorted(error_analysis['FN'], key=lambda x: x[1].start)
    assert fn_error == [('0', Span(4, 15)), ('0', Span(30, 35))]
    assert error_analysis['TP'] == [('1', Span(4, 16))]
    # Handle the edge case of not getting anything
    sequence_labels_0 = ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
    test_collection['0']['sequence_labels'] = sequence_labels_0
    sequence_labels_1 = ['O', 'O', 'O', 'O', 'O']
    test_collection['1']['sequence_labels'] = sequence_labels_1
    recall, precision, f1, error_analysis = test_collection.exact_match_score(
        'sequence_labels')
    assert precision == 0.0
    assert recall == 0.0
    assert f1 == 0.0
    assert error_analysis['FP'] == []
    fn_error = sorted(error_analysis['FN'], key=lambda x: x[1].start)
    assert fn_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                        ('0', Span(30, 35))]
    assert error_analysis['TP'] == []
    # The case where the tokens and the text do not align
    not_align_example = self._target_text_not_align_example()
    # text = 'The laptop case; was awful'
    sequence_labels_align = ['O', 'B', 'I', 'O', 'O']
    test_collection.add(not_align_example)
    test_collection.tokenize(str.split)
    test_collection['inf']['sequence_labels'] = sequence_labels_align
    sequence_labels_0 = ['O', 'B', 'I', 'O', 'O', 'O', 'B', 'O', 'O']
    test_collection['0']['sequence_labels'] = sequence_labels_0
    sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
    test_collection['1']['sequence_labels'] = sequence_labels_1
    recall, precision, f1, error_analysis = test_collection.exact_match_score(
        'sequence_labels')
    assert recall == 3 / 4
    assert precision == 3 / 4
    assert f1 == 0.75
    # `str.split` leaves the ';' attached to 'case', so the predicted
    # span (4, 16) does not exactly match the gold span (4, 15).
    assert error_analysis['FP'] == [('inf', Span(4, 16))]
    assert error_analysis['FN'] == [('inf', Span(4, 15))]
    tp_error = sorted(error_analysis['TP'], key=lambda x: x[1].start)
    assert tp_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                        ('0', Span(30, 35))]
    # This time it can get a perfect score as the token alignment will be
    # perfect
    test_collection.tokenize(spacy_tokenizer())
    sequence_labels_align = ['O', 'B', 'I', 'O', 'O', 'O']
    test_collection['inf']['sequence_labels'] = sequence_labels_align
    recall, precision, f1, error_analysis = test_collection.exact_match_score(
        'sequence_labels')
    assert recall == 1.0
    assert precision == 1.0
    assert f1 == 1.0
    assert error_analysis['FP'] == []
    assert error_analysis['FN'] == []
    tp_error = sorted(error_analysis['TP'], key=lambda x: x[1].end)
    assert tp_error == [('0', Span(4, 15)), ('inf', Span(4, 15)),
                        ('1', Span(4, 16)), ('0', Span(30, 35))]
    # Handle the case where one of the samples has no spans
    test_example = TargetText(text="I've had a bad day", text_id='50')
    other_examples = self._target_text_measure_examples()
    other_examples.append(test_example)
    test_collection = TargetTextCollection(other_examples)
    test_collection.tokenize(str.split)
    test_collection.sequence_labels()
    measures = test_collection.exact_match_score('sequence_labels')
    for index, measure in enumerate(measures):
        if index == 3:
            assert measure['FP'] == []
            assert measure['FN'] == []
            tp_error = sorted(measure['TP'], key=lambda x: x[1].end)
            assert tp_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                                ('0', Span(30, 35))]
        else:
            assert measure == 1.0
    # Handle the case where on the samples has no spans but has predicted
    # there is a span there
    test_collection['50']['sequence_labels'] = ['B', 'I', 'O', 'O', 'O']
    recall, precision, f1, error_analysis = test_collection.exact_match_score(
        'sequence_labels')
    assert recall == 1.0
    assert precision == 3 / 4
    assert round(f1, 3) == 0.857
    assert error_analysis['FP'] == [('50', Span(start=0, end=8))]
    assert error_analysis['FN'] == []
    tp_error = sorted(error_analysis['TP'], key=lambda x: x[1].end)
    assert tp_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                        ('0', Span(30, 35))]
    # See if it can handle a collection that only contains no spans
    test_example = TargetText(text="I've had a bad day", text_id='50')
    test_collection = TargetTextCollection([test_example])
    test_collection.tokenize(str.split)
    test_collection.sequence_labels()
    measures = test_collection.exact_match_score('sequence_labels')
    for index, measure in enumerate(measures):
        if index == 3:
            assert measure['FP'] == []
            assert measure['FN'] == []
            assert measure['TP'] == []
        else:
            assert measure == 0.0
    # Handle the case the collection contains one spans but a mistake
    test_collection['50']['sequence_labels'] = ['B', 'I', 'O', 'O', 'O']
    measures = test_collection.exact_match_score('sequence_labels')
    for index, measure in enumerate(measures):
        if index == 3:
            assert measure['FP'] == [('50', Span(0, 8))]
            assert measure['FN'] == []
            assert measure['TP'] == []
        else:
            assert measure == 0.0
    # Should raise a KeyError if one of the TargetText instances does
    # not have a Span key
    del test_collection['50']._storage['spans']
    with pytest.raises(KeyError):
        test_collection.exact_match_score('sequence_labels')
    # should raise a KeyError if one of the TargetText instances does
    # not have a predicted sequence key
    test_collection = TargetTextCollection([self._target_text_example()])
    test_collection.tokenize(spacy_tokenizer())
    test_collection.sequence_labels()
    with pytest.raises(KeyError):
        measures = test_collection.exact_match_score('nothing')
    # Should raise a ValueError if there are multiple same true spans
    a = TargetText(text='hello how are you I am good', text_id='1',
                   targets=['hello', 'hello'],
                   spans=[Span(0, 5), Span(0, 5)])
    test_collection = TargetTextCollection([a])
    test_collection.tokenize(str.split)
    test_collection['1']['sequence_labels'] = [
        'B', 'O', 'O', 'O', 'O', 'O', 'O'
    ]
    with pytest.raises(ValueError):
        test_collection.exact_match_score('sequence_labels')
dataset.sequence_labels() sizes.append(len(dataset)) print(f'Lengths {sizes[0]}, {sizes[1]}, {sizes[2]}') save_dir = Path('.', 'models', 'glove_model') param_file = Path('.', 'training_configs', 'Target_Extraction', 'General_Domain', 'Glove_LSTM_CRF.jsonnet') model = AllenNLPModel('Glove', param_file, 'target-tagger', save_dir) if not save_dir.exists(): model.fit(train_data, val_data, test_data) else: model.load() import time start_time = time.time() val_iter = iter(val_data.values()) for val_predictions in model.predict_sequences(val_data.values()): relevant_val = next(val_iter) relevant_val['predicted_sequence_labels'] = val_predictions[ 'sequence_labels'] print(time.time() - start_time) another_time = time.time() for val_predictions in model.predict_sequences(val_data.values()): pass print(time.time() - another_time) print('done') print(val_data.exact_match_score('predicted_sequence_labels')[2]) test_iter = iter(test_data.values()) for test_pred in model.predict_sequences(test_data.values()): relevant_test = next(test_iter) relevant_test['predicted_sequence_labels'] = test_pred['sequence_labels'] print(test_data.exact_match_score('predicted_sequence_labels')[2])