Example #1
    def test_add(self):
        new_collection = TargetTextCollection()

        assert len(new_collection) == 0
        new_collection.add(self._target_text_example())
        assert len(new_collection) == 1

        assert '2' in new_collection
Example #2
def test__statistics_to_dataframe():
    # Test with just one collection
    target_stats = dataset_target_extraction_statistics([TRAIN_COLLECTION])
    tl_1 = round((17 / 19.0) * 100, 2)
    tl_2 = round((2 / 19.0) * 100, 2)
    true_stats = {
        'Name': 'train',
        'No. Sentences': 6,
        'No. Sentences(t)': 5,
        'No. Targets': 19,
        'No. Uniq Targets': 13,
        'ATS': round(19 / 6.0, 2),
        'ATS(t)': round(19 / 5.0, 2),
        'TL 1 %': tl_1,
        'TL 2 %': tl_2,
        'TL 3+ %': 0,
        'Mean Sentence Length': 15.33,
        'Mean Sentence Length(t)': 16.6
    }
    true_stats_list = {key: [value] for key, value in true_stats.items()}
    true_stats_df = pd.DataFrame(true_stats_list)
    test_stats_df = _statistics_to_dataframe(target_stats)
    pd.testing.assert_frame_equal(true_stats_df,
                                  test_stats_df,
                                  check_less_precise=2)
    # Test with two collections
    subcollection = TargetTextCollection(name='sub')
    subcollection.add(TRAIN_COLLECTION["81207500773427072"])
    subcollection.add(TRAIN_COLLECTION["78522643479064576"])
    target_stats = dataset_target_extraction_statistics(
        [subcollection, TRAIN_COLLECTION])
    tl_1 = round((6 / 7.0) * 100, 2)
    tl_2 = round((1 / 7.0) * 100, 2)
    sub_stats = {
        'Name': 'sub',
        'No. Sentences': 2,
        'No. Sentences(t)': 2,
        'No. Targets': 7,
        'No. Uniq Targets': 7,
        'ATS': round(7 / 2.0, 2),
        'ATS(t)': round(7 / 2.0, 2),
        'TL 1 %': tl_1,
        'TL 2 %': tl_2,
        'TL 3+ %': 0,
        'Mean Sentence Length': 13,
        'Mean Sentence Length(t)': 13
    }
    true_stats_list = {
        key: [value, true_stats[key]]
        for key, value in sub_stats.items()
    }
    true_stats_df = pd.DataFrame(true_stats_list)
    test_stats_df = _statistics_to_dataframe(target_stats)
    pd.testing.assert_frame_equal(true_stats_df,
                                  test_stats_df,
                                  check_less_precise=2)
Example #3
def test_metric_df(metric_function_name: str, metric_name: str, 
                   include_run_number: bool):
    model_1_collection = passable_example_multiple_preds('true_sentiments', 'model_1')
    model_2_collection = passable_example_multiple_preds('true_sentiments', 'model_2')
    combined_collection = TargetTextCollection()
    for key, value in model_1_collection.items():
        combined_collection.add(value)
        combined_collection[key]['model_2'] = model_2_collection[key]['model_2']
    metric_function = getattr(sentiment_metrics, metric_function_name)
    # Test the array score version first
    model_1_scores = metric_function(model_1_collection, 'true_sentiments', 'model_1', 
                                     average=False, array_scores=True)
    model_2_scores = metric_function(model_2_collection, 'true_sentiments', 'model_2', 
                                     average=False, array_scores=True)
    test_df = util.metric_df(combined_collection, metric_function, 'true_sentiments', 
                             predicted_sentiment_keys=['model_1', 'model_2'], 
                             average=False, array_scores=True, metric_name=metric_name,
                             include_run_number=include_run_number)
    get_metric_name = 'metric' if metric_name is None else metric_name
    if include_run_number:
        assert (4, 3) == test_df.shape
    else:
        assert (4, 2) == test_df.shape
    for model_name, true_model_scores in [('model_1', model_1_scores), 
                                          ('model_2', model_2_scores)]:
        test_model_scores = test_df.loc[test_df['prediction key']==f'{model_name}'][f'{get_metric_name}']
        assert true_model_scores == test_model_scores.to_list()
        if include_run_number:
            test_run_numbers = test_df.loc[test_df['prediction key']==f'{model_name}']['run number']
            test_run_numbers = test_run_numbers.to_list()
            assert [0, 1] == test_run_numbers
    # Test the average version
    model_1_scores = metric_function(model_1_collection, 'true_sentiments', 'model_1', 
                                     average=True, array_scores=False)
    model_2_scores = metric_function(model_2_collection, 'true_sentiments', 'model_2', 
                                     average=True, array_scores=False)
    if include_run_number:
        with pytest.raises(ValueError):
            util.metric_df(combined_collection, metric_function, 'true_sentiments', 
                           predicted_sentiment_keys=['model_1', 'model_2'], 
                           average=True, array_scores=False, metric_name=metric_name,
                           include_run_number=include_run_number)
    else:
        test_df = util.metric_df(combined_collection, metric_function, 'true_sentiments', 
                                predicted_sentiment_keys=['model_1', 'model_2'], 
                                average=True, array_scores=False, metric_name=metric_name,
                                include_run_number=include_run_number)
        get_metric_name = 'metric' if metric_name is None else metric_name
        assert (2,2) == test_df.shape
        for model_name, true_model_scores in [('model_1', model_1_scores), 
                                            ('model_2', model_2_scores)]:
            test_model_scores = test_df.loc[test_df['prediction key']==f'{model_name}'][f'{get_metric_name}']
            test_model_scores = test_model_scores.to_list()
            assert 1 == len(test_model_scores)
            assert true_model_scores == test_model_scores[0]
Example #4
    def test_target_count(self):
        # Start with an empty collection
        test_collection = TargetTextCollection()
        nothing = test_collection.target_count()
        assert len(nothing) == 0
        assert not nothing

        # Collection that contains TargetText instances but with no targets
        test_collection.add(TargetText(text='some text', text_id='1'))
        assert len(test_collection) == 1
        nothing = test_collection.target_count()
        assert len(nothing) == 0
        assert not nothing

        # Collection now contains at least one target
        test_collection.add(
            TargetText(text='another item today',
                       text_id='2',
                       spans=[Span(0, 12)],
                       targets=['another item']))
        assert len(test_collection) == 2
        one = test_collection.target_count()
        assert len(one) == 1
        assert one == {'another item': 1}

        # Collection now contains 3 targets but 2 are the same
        test_collection.add(
            TargetText(text='another item today',
                       text_id='3',
                       spans=[Span(0, 12)],
                       targets=['another item']))
        test_collection.add(
            TargetText(text='item today',
                       text_id='4',
                       spans=[Span(0, 4)],
                       targets=['item']))
        assert len(test_collection) == 4
        two = test_collection.target_count()
        assert len(two) == 2
        assert two == {'another item': 2, 'item': 1}
Example #5
    def test_samples_with_targets(self):
        # Test the case where all of the TargetTextCollection contain targets
        test_collection = TargetTextCollection(self._target_text_examples())
        sub_collection = test_collection.samples_with_targets()
        assert test_collection == sub_collection
        assert len(test_collection) == 3

        # Test the case where none of the TargetTextCollection contain targets
        for sample_id in list(test_collection.keys()):
            del test_collection[sample_id]
            test_collection.add(TargetText(text='nothing', text_id=sample_id))
        assert len(test_collection) == 3
        sub_collection = test_collection.samples_with_targets()
        assert len(sub_collection) == 0
        assert sub_collection != test_collection

        # Test the case where only 2 of the three TargetTextCollection
        # samples contain targets.
        test_collection = TargetTextCollection(self._target_text_examples())
        del test_collection['another_id']
        easy_case = TargetText(text='something else', text_id='another_id')
        test_collection.add(easy_case)
        sub_collection = test_collection.samples_with_targets()
        assert len(sub_collection) == 2
        assert sub_collection != test_collection

        # Test the case where the targets are just an empty list rather than
        # None
        test_collection = TargetTextCollection(self._target_text_examples())
        del test_collection['another_id']
        edge_case = TargetText(text='something else',
                               text_id='another_id',
                               targets=[],
                               spans=[])
        test_collection.add(edge_case)
        sub_collection = test_collection.samples_with_targets()
        assert len(sub_collection) == 2
        assert sub_collection != test_collection
Example #6
def test_dataset_target_sentiment_statistics(lower: bool):
    if lower is not None:
        target_stats = dataset_target_sentiment_statistics([TRAIN_COLLECTION],
                                                           lower_target=lower)
    else:
        target_stats = dataset_target_sentiment_statistics([TRAIN_COLLECTION])

    pos_percent = round(
        get_sentiment_counts(TRAIN_COLLECTION, SENTIMENT_KEY)['positive'] *
        100, 2)
    pos_count = get_sentiment_counts(TRAIN_COLLECTION,
                                     SENTIMENT_KEY,
                                     normalised=False)['positive']
    pos_count_percent = f'{pos_count} ({pos_percent})'

    neu_percent = round(
        get_sentiment_counts(TRAIN_COLLECTION, SENTIMENT_KEY)['neutral'] * 100,
        2)
    neu_count = get_sentiment_counts(TRAIN_COLLECTION,
                                     SENTIMENT_KEY,
                                     normalised=False)['neutral']
    neu_count_percent = f'{neu_count} ({neu_percent})'

    neg_percent = round(
        get_sentiment_counts(TRAIN_COLLECTION, SENTIMENT_KEY)['negative'] *
        100, 2)
    neg_count = get_sentiment_counts(TRAIN_COLLECTION,
                                     SENTIMENT_KEY,
                                     normalised=False)['negative']
    neg_count_percent = f'{neg_count} ({neg_percent})'

    tl_1 = round((17 / 19.0) * 100, 2)
    tl_2 = round((2 / 19.0) * 100, 2)
    true_stats = {
        'Name': 'train',
        'No. Sentences': 6,
        'No. Sentences(t)': 5,
        'No. Targets': 19,
        'No. Uniq Targets': 13,
        'ATS': round(19 / 6.0, 2),
        'ATS(t)': round(19 / 5.0, 2),
        'POS (%)': pos_count_percent,
        'NEG (%)': neg_count_percent,
        'NEU (%)': neu_count_percent,
        'TL 1 %': tl_1,
        'TL 2 %': tl_2,
        'TL 3+ %': 0.0,
        'Mean Sentence Length': 15.33,
        'Mean Sentence Length(t)': 16.6
    }
    if lower is False:
        true_stats['No. Uniq Targets'] = 14
    print(target_stats)
    assert 1 == len(target_stats)
    target_stats = target_stats[0]
    assert len(true_stats) == len(target_stats)
    for stat_name, stat in true_stats.items():
        if re.search(r'^TL', stat_name):
            assert math.isclose(stat, target_stats[stat_name], rel_tol=0.001)
        else:
            assert stat == target_stats[stat_name], stat_name

    # Multiple collections, where one collection is just a subset of the other
    subcollection = TargetTextCollection(name='sub')
    subcollection.add(TRAIN_COLLECTION["81207500773427072"])
    subcollection.add(TRAIN_COLLECTION["78522643479064576"])
    if lower is not None:
        target_stats = dataset_target_sentiment_statistics(
            [subcollection, TRAIN_COLLECTION], lower_target=lower)
    else:
        target_stats = dataset_target_sentiment_statistics(
            [subcollection, TRAIN_COLLECTION])

    pos_percent = round(
        get_sentiment_counts(subcollection, SENTIMENT_KEY)['positive'] * 100,
        2)
    pos_count = get_sentiment_counts(subcollection,
                                     SENTIMENT_KEY,
                                     normalised=False)['positive']
    pos_count_percent = f'{pos_count} ({pos_percent})'

    neu_percent = round(
        get_sentiment_counts(subcollection, SENTIMENT_KEY)['neutral'] * 100, 2)
    neu_count = get_sentiment_counts(subcollection,
                                     SENTIMENT_KEY,
                                     normalised=False)['neutral']
    neu_count_percent = f'{neu_count} ({neu_percent})'

    neg_percent = round(
        get_sentiment_counts(subcollection, SENTIMENT_KEY)['negative'] * 100,
        2)
    neg_count = get_sentiment_counts(subcollection,
                                     SENTIMENT_KEY,
                                     normalised=False)['negative']
    neg_count_percent = f'{neg_count} ({neg_percent})'

    tl_1 = round((6 / 7.0) * 100, 2)
    tl_2 = round((1 / 7.0) * 100, 2)
    sub_stats = {
        'Name': 'sub',
        'No. Sentences': 2,
        'No. Sentences(t)': 2,
        'No. Targets': 7,
        'No. Uniq Targets': 7,
        'ATS': round(7 / 2.0, 2),
        'ATS(t)': round(7 / 2.0, 2),
        'POS (%)': pos_count_percent,
        'NEG (%)': neg_count_percent,
        'NEU (%)': neu_count_percent,
        'TL 1 %': tl_1,
        'TL 2 %': tl_2,
        'TL 3+ %': 0.0,
        'Mean Sentence Length': 13,
        'Mean Sentence Length(t)': 13
    }
    true_stats = [sub_stats, true_stats]
    assert len(true_stats) == len(target_stats)
    for stat_index, stat in enumerate(true_stats):
        test_stat = target_stats[stat_index]
        assert len(stat) == len(test_stat)
        for stat_name, stat_value in stat.items():
            if re.search(r'^TL', stat_name):
                assert math.isclose(stat_value,
                                    test_stat[stat_name],
                                    rel_tol=0.001)
            else:
                assert stat_value == test_stat[stat_name], stat_name
Example #7
def test_dataset_target_extraction_statistics(lower: bool,
                                              incl_sentence_statistics: bool):
    if lower is not None:
        target_stats = dataset_target_extraction_statistics(
            [TRAIN_COLLECTION],
            lower_target=lower,
            incl_sentence_statistics=incl_sentence_statistics)
    else:
        target_stats = dataset_target_extraction_statistics(
            [TRAIN_COLLECTION],
            incl_sentence_statistics=incl_sentence_statistics)
    tl_1 = round((17 / 19.0) * 100, 2)
    tl_2 = round((2 / 19.0) * 100, 2)
    true_stats = {
        'Name': 'train',
        'No. Sentences': 6,
        'No. Sentences(t)': 5,
        'No. Targets': 19,
        'No. Uniq Targets': 13,
        'ATS': round(19 / 6.0, 2),
        'ATS(t)': round(19 / 5.0, 2),
        'TL 1 %': tl_1,
        'TL 2 %': tl_2,
        'TL 3+ %': 0.0,
        'Mean Sentence Length': 15.33,
        'Mean Sentence Length(t)': 16.6
    }
    if not incl_sentence_statistics:
        del true_stats['Mean Sentence Length(t)']
        del true_stats['Mean Sentence Length']
    if lower is False:
        true_stats['No. Uniq Targets'] = 14
    assert 1 == len(target_stats)
    target_stats = target_stats[0]
    assert len(true_stats) == len(target_stats)
    for stat_name, stat in true_stats.items():
        if re.search(r'^TL', stat_name):
            assert math.isclose(stat, target_stats[stat_name], rel_tol=0.001)
        else:
            assert stat == target_stats[stat_name], stat_name

    # Multiple collections, where one collection is just a subset of the other
    subcollection = TargetTextCollection(name='sub')
    subcollection.add(TRAIN_COLLECTION["81207500773427072"])
    subcollection.add(TRAIN_COLLECTION["78522643479064576"])
    long_target = TargetText(
        text='some text that contains a long target or two',
        spans=[Span(0, 14), Span(15, 37)],
        targets=['some text that', 'contains a long target'],
        target_sentiments=['positive', 'negative'],
        text_id='100')
    subcollection.add(long_target)
    subcollection.tokenize(whitespace())
    if lower is not None:
        target_stats = dataset_target_extraction_statistics(
            [subcollection, TRAIN_COLLECTION],
            lower_target=lower,
            incl_sentence_statistics=incl_sentence_statistics)
    else:
        target_stats = dataset_target_extraction_statistics(
            [subcollection, TRAIN_COLLECTION],
            incl_sentence_statistics=incl_sentence_statistics)

    tl_1 = round((6 / 9.0) * 100, 2)
    tl_2 = round((1 / 9.0) * 100, 2)
    tl_3 = round((2 / 9.0) * 100, 2)
    sub_stats = {
        'Name': 'sub',
        'No. Sentences': 3,
        'No. Sentences(t)': 3,
        'No. Targets': 9,
        'No. Uniq Targets': 9,
        'ATS': round(9 / 3.0, 2),
        'ATS(t)': round(9 / 3.0, 2),
        'TL 1 %': tl_1,
        'TL 2 %': tl_2,
        'TL 3+ %': tl_3,
        'Mean Sentence Length': 11.67,
        'Mean Sentence Length(t)': 11.67
    }
    if not incl_sentence_statistics:
        del sub_stats['Mean Sentence Length(t)']
        del sub_stats['Mean Sentence Length']
    true_stats = [sub_stats, true_stats]
    assert len(true_stats) == len(target_stats)
    for stat_index, stat in enumerate(true_stats):
        test_stat = target_stats[stat_index]
        assert len(stat) == len(test_stat)
        for stat_name, stat_value in stat.items():
            if re.search(r'^TL', stat_name):
                assert math.isclose(stat_value,
                                    test_stat[stat_name],
                                    rel_tol=0.001)
            else:
                assert stat_value == test_stat[stat_name], stat_name
Example #8
def test_overall_metric_results(true_sentiment_key: str, 
                                include_metadata: bool,
                                strict_accuracy_metrics: bool):
    if true_sentiment_key is None:
        true_sentiment_key = 'target_sentiments'
    model_1_collection = passable_example_multiple_preds(true_sentiment_key, 'model_1')
    model_1_collection.add(TargetText(text='a', text_id='200', spans=[Span(0,1)], targets=['a'],
                                      model_1=[['pos'], ['neg']], **{f'{true_sentiment_key}': ['pos']}))
    model_1_collection.add(TargetText(text='a', text_id='201', spans=[Span(0,1)], targets=['a'],
                                      model_1=[['pos'], ['neg']], **{f'{true_sentiment_key}': ['neg']}))
    model_1_collection.add(TargetText(text='a', text_id='202', spans=[Span(0,1)], targets=['a'],
                                      model_1=[['pos'], ['neg']], **{f'{true_sentiment_key}': ['neu']}))
    print(true_sentiment_key)
    print(model_1_collection['1'])
    print(model_1_collection['200'])
    model_2_collection = passable_example_multiple_preds(true_sentiment_key, 'model_2')
    combined_collection = TargetTextCollection()
    
    standard_columns = ['Dataset', 'Macro F1', 'Accuracy', 'run number', 
                        'prediction key']
    if strict_accuracy_metrics:
        standard_columns = standard_columns + ['STAC', 'STAC 1', 'STAC Multi']
    if include_metadata:
        metadata = {'predicted_target_sentiment_key': {'model_1': {'CWR': True},
                                                       'model_2': {'CWR': False}}}
        combined_collection.name = 'name'
        combined_collection.metadata = metadata
        standard_columns.append('CWR')
    number_df_columns = len(standard_columns)

    for key, value in model_1_collection.items():
        if key in ['200', '201', '202']:
            combined_collection.add(value)
            combined_collection[key]['model_2'] = [['neg'], ['pos']]
            continue
        combined_collection.add(value)
        combined_collection[key]['model_2'] = model_2_collection[key]['model_2']
    if true_sentiment_key is None:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1', 'model_2'], 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    else:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1', 'model_2'],
                                                true_sentiment_key, 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    assert (4, number_df_columns) == result_df.shape
    assert set(standard_columns) == set(result_df.columns)
    if include_metadata:
        assert ['name'] * 4 == result_df['Dataset'].tolist()
    else:
        assert [''] * 4 == result_df['Dataset'].tolist()
    # Test the case where only one model is used
    if true_sentiment_key is None:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1'], 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    else:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1'],
                                                true_sentiment_key, 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    assert (2, number_df_columns) == result_df.shape
    # Test the case where the model names come from the metadata
    if include_metadata:
        result_df = util.overall_metric_results(combined_collection, 
                                                true_sentiment_key=true_sentiment_key, 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
        assert (4, number_df_columns) == result_df.shape
    else:
        with pytest.raises(KeyError):
            util.overall_metric_results(combined_collection, 
                                        true_sentiment_key=true_sentiment_key, 
                                        strict_accuracy_metrics=strict_accuracy_metrics)
Example #9
def test_add_metadata_to_df():
    model_1_collection = passable_example_multiple_preds('true_sentiments', 'model_1')
    model_2_collection = passable_example_multiple_preds('true_sentiments', 'model_2')
    combined_collection = TargetTextCollection()
    for key, value in model_1_collection.items():
        combined_collection.add(value)
        combined_collection[key]['model_2'] = model_2_collection[key]['model_2']
    # get test metric_df
    combined_collection.metadata = None
    metric_df = util.metric_df(combined_collection, sentiment_metrics.accuracy, 
                               'true_sentiments', 
                               predicted_sentiment_keys=['model_1', 'model_2'], 
                               average=False, array_scores=True, metric_name='metric')
    # Test the case where the TargetTextCollection has no metadata, should 
    # just return the dataframe as is without change
    test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                      'non-existing-key')
    assert metric_df.equals(test_df)
    assert combined_collection.metadata is None
    # Testing the case where the metadata is not None but does not contain the 
    # `metadata_prediction_key` = `non-existing-key`
    combined_collection.name = 'combined collection'
    assert combined_collection.metadata is not None
    test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                      'non-existing-key')
    assert metric_df.equals(test_df)

    # Test the normal cases where we are adding metadata
    key_metadata_normal = {'model_1': {'embedding': True, 'value': '10'}, 
                           'model_2': {'embedding': False, 'value': '5'}}
    key_metadata_alt = {'model_1': {'embedding': True, 'value': '10'}, 
                        'model_2': {'embedding': False, 'value': '5'},
                        'model_3': {'embedding': 'low', 'value': '12'}}
    key_metadata_diff = {'model_1': {'embedding': True, 'value': '10', 'special': 12}, 
                         'model_2': {'embedding': False, 'value': '5', 'diff': 30.0}}
    key_metadataer = [key_metadata_normal, key_metadata_alt, key_metadata_diff]
    for key_metadata in key_metadataer:
        combined_collection.metadata['non-existing-key'] = key_metadata
        if 'special' in key_metadata['model_1']:
            test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                              'non-existing-key', 
                                              metadata_keys=['embedding', 'value'])
        else:
            test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                              'non-existing-key')
        assert not metric_df.equals(test_df)
        assert (4, 4) == test_df.shape
        assert [True, True] == test_df.loc[test_df['prediction key']=='model_1']['embedding'].to_list()
        assert [False, False] == test_df.loc[test_df['prediction key']=='model_2']['embedding'].to_list()
        assert ['10', '10'] == test_df.loc[test_df['prediction key']=='model_1']['value'].to_list()
        assert ['5', '5'] == test_df.loc[test_df['prediction key']=='model_2']['value'].to_list()
    # Test the case where some of the metadata exists for some of the models in
    # the collection but not all of them
    combined_collection.metadata['non-existing-key'] = key_metadata_diff
    test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                      'non-existing-key')
    assert not metric_df.equals(test_df)
    assert (4, 6) == test_df.shape
    assert [True, True] == test_df.loc[test_df['prediction key']=='model_1']['embedding'].to_list()
    assert [False, False] == test_df.loc[test_df['prediction key']=='model_2']['embedding'].to_list()
    assert ['10', '10'] == test_df.loc[test_df['prediction key']=='model_1']['value'].to_list()
    assert ['5', '5'] == test_df.loc[test_df['prediction key']=='model_2']['value'].to_list()
    assert [12, 12] == test_df.loc[test_df['prediction key']=='model_1']['special'].to_list()
    nan_values = test_df.loc[test_df['prediction key']=='model_2']['special'].to_list()
    assert 2 == len(nan_values)
    for test_value in nan_values:
        assert math.isnan(test_value)
    nan_values = test_df.loc[test_df['prediction key']=='model_1']['diff'].to_list()
    assert 2 == len(nan_values)
    for test_value in nan_values:
        assert math.isnan(test_value)
    assert [30.0, 30.0] == test_df.loc[test_df['prediction key']=='model_2']['diff'].to_list()

    # Test the KeyError cases
    # Prediction keys that exist in the metric df but not in the collection
    metric_copy_df = metric_df.copy(deep=True)
    alt_metric_df = pd.DataFrame({'prediction key': ['model_3', 'model_3'], 
                                  'metric': [0.4, 0.5]})
    metric_copy_df = metric_copy_df.append(alt_metric_df)
    assert (6, 2) == metric_copy_df.shape
    with pytest.raises(KeyError):
        util.add_metadata_to_df(metric_copy_df, combined_collection, 'non-existing-key')
    # Prediction keys exist in the dataframe and target texts but not in the 
    # metadata
    combined_collection.metadata['non-existing-key'] = {'model_1': {'embedding': True, 'value': '10'}}
    with pytest.raises(KeyError):
        util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 'non-existing-key')
Example #10
    def test_exact_match_score(self):
        # Simple case where it should get perfect score
        test_collection = TargetTextCollection([self._target_text_example()])
        test_collection.tokenize(spacy_tokenizer())
        test_collection.sequence_labels()
        measures = test_collection.exact_match_score('sequence_labels')
        for index, measure in enumerate(measures):
            if index == 3:
                assert measure['FP'] == []
                assert measure['FN'] == []
                assert measure['TP'] == [('2', Span(4, 15)),
                                         ('2', Span(30, 35))]
            else:
                assert measure == 1.0

        # Something that has perfect precision but misses one therefore does
        # not have perfect recall nor f1
        test_collection = TargetTextCollection(
            self._target_text_measure_examples())
        test_collection.tokenize(str.split)
        # text = 'The laptop case was great and cover was rubbish'
        sequence_labels_0 = ['O', 'O', 'O', 'O', 'O', 'O', 'B', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        # text = 'The laptop price was awful'
        sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert precision == 1.0
        assert recall == 2.0 / 3.0
        assert f1 == 0.8
        assert error_analysis['FP'] == []
        assert error_analysis['FN'] == [('0', Span(4, 15))]
        assert error_analysis['TP'] == [('0', Span(30, 35)), ('1', Span(4,
                                                                        16))]

        # Something that has perfect recall but not precision as it over
        # predicts
        sequence_labels_0 = ['O', 'B', 'I', 'B', 'O', 'O', 'B', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert precision == 3 / 4
        assert recall == 1.0
        assert round(f1, 3) == 0.857
        assert error_analysis['FP'] == [('0', Span(16, 19))]
        assert error_analysis['FN'] == []
        assert error_analysis['TP'] == [('0', Span(4, 15)), ('0', Span(30,
                                                                       35)),
                                        ('1', Span(4, 16))]

        # Does not predict anything for a whole sentence therefore will have
        # perfect precision but bad recall (mainly testing that predicting
        # nothing for a whole sentence is handled)
        sequence_labels_0 = ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert precision == 1.0
        assert recall == 1 / 3
        assert f1 == 0.5
        assert error_analysis['FP'] == []
        fn_error = sorted(error_analysis['FN'], key=lambda x: x[1].start)
        assert fn_error == [('0', Span(4, 15)), ('0', Span(30, 35))]
        assert error_analysis['TP'] == [('1', Span(4, 16))]

        # Handle the edge case of not getting anything
        sequence_labels_0 = ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        sequence_labels_1 = ['O', 'O', 'O', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert precision == 0.0
        assert recall == 0.0
        assert f1 == 0.0
        assert error_analysis['FP'] == []
        fn_error = sorted(error_analysis['FN'], key=lambda x: x[1].start)
        assert fn_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                            ('0', Span(30, 35))]
        assert error_analysis['TP'] == []

        # The case where the tokens and the text do not align
        not_align_example = self._target_text_not_align_example()
        # text = 'The laptop case; was awful'
        sequence_labels_align = ['O', 'B', 'I', 'O', 'O']
        test_collection.add(not_align_example)
        test_collection.tokenize(str.split)
        test_collection['inf']['sequence_labels'] = sequence_labels_align
        sequence_labels_0 = ['O', 'B', 'I', 'O', 'O', 'O', 'B', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert recall == 3 / 4
        assert precision == 3 / 4
        assert f1 == 0.75
        assert error_analysis['FP'] == [('inf', Span(4, 16))]
        assert error_analysis['FN'] == [('inf', Span(4, 15))]
        tp_error = sorted(error_analysis['TP'], key=lambda x: x[1].start)
        assert tp_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                            ('0', Span(30, 35))]

        # This time it can get a perfect score as the token alignment will be
        # perfect
        test_collection.tokenize(spacy_tokenizer())
        sequence_labels_align = ['O', 'B', 'I', 'O', 'O', 'O']
        test_collection['inf']['sequence_labels'] = sequence_labels_align
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert recall == 1.0
        assert precision == 1.0
        assert f1 == 1.0
        assert error_analysis['FP'] == []
        assert error_analysis['FN'] == []
        tp_error = sorted(error_analysis['TP'], key=lambda x: x[1].end)
        assert tp_error == [('0', Span(4, 15)), ('inf', Span(4, 15)),
                            ('1', Span(4, 16)), ('0', Span(30, 35))]

        # Handle the case where one of the samples has no spans
        test_example = TargetText(text="I've had a bad day", text_id='50')
        other_examples = self._target_text_measure_examples()
        other_examples.append(test_example)
        test_collection = TargetTextCollection(other_examples)
        test_collection.tokenize(str.split)
        test_collection.sequence_labels()
        measures = test_collection.exact_match_score('sequence_labels')
        for index, measure in enumerate(measures):
            if index == 3:
                assert measure['FP'] == []
                assert measure['FN'] == []
                tp_error = sorted(measure['TP'], key=lambda x: x[1].end)
                assert tp_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                                    ('0', Span(30, 35))]
            else:
                assert measure == 1.0
        # Handle the case where one of the samples has no spans but a span
        # has been predicted for it
        test_collection['50']['sequence_labels'] = ['B', 'I', 'O', 'O', 'O']
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert recall == 1.0
        assert precision == 3 / 4
        assert round(f1, 3) == 0.857
        assert error_analysis['FP'] == [('50', Span(start=0, end=8))]
        assert error_analysis['FN'] == []
        tp_error = sorted(error_analysis['TP'], key=lambda x: x[1].end)
        assert tp_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                            ('0', Span(30, 35))]
        # See if it can handle a collection that only contains no spans
        test_example = TargetText(text="I've had a bad day", text_id='50')
        test_collection = TargetTextCollection([test_example])
        test_collection.tokenize(str.split)
        test_collection.sequence_labels()
        measures = test_collection.exact_match_score('sequence_labels')
        for index, measure in enumerate(measures):
            if index == 3:
                assert measure['FP'] == []
                assert measure['FN'] == []
                assert measure['TP'] == []
            else:
                assert measure == 0.0
        # Handle the case where the collection predicts one span but it is a mistake
        test_collection['50']['sequence_labels'] = ['B', 'I', 'O', 'O', 'O']
        measures = test_collection.exact_match_score('sequence_labels')
        for index, measure in enumerate(measures):
            if index == 3:
                assert measure['FP'] == [('50', Span(0, 8))]
                assert measure['FN'] == []
                assert measure['TP'] == []
            else:
                assert measure == 0.0
        # Should raise a KeyError if one of the TargetText instances does
        # not have a Span key
        del test_collection['50']._storage['spans']
        with pytest.raises(KeyError):
            test_collection.exact_match_score('sequence_labels')
        # should raise a KeyError if one of the TargetText instances does
        # not have a predicted sequence key
        test_collection = TargetTextCollection([self._target_text_example()])
        test_collection.tokenize(spacy_tokenizer())
        test_collection.sequence_labels()
        with pytest.raises(KeyError):
            measures = test_collection.exact_match_score('nothing')

        # Should raise a ValueError if there are multiple same true spans
        a = TargetText(text='hello how are you I am good',
                       text_id='1',
                       targets=['hello', 'hello'],
                       spans=[Span(0, 5), Span(0, 5)])
        test_collection = TargetTextCollection([a])
        test_collection.tokenize(str.split)
        test_collection['1']['sequence_labels'] = [
            'B', 'O', 'O', 'O', 'O', 'O', 'O'
        ]
        with pytest.raises(ValueError):
            test_collection.exact_match_score('sequence_labels')
Example #11
def multi_aspect_multi_sentiment_acsa(
        dataset: str,
        cache_dir: Optional[Path] = None) -> TargetTextCollection:
    '''
    The data for this function, when downloaded, is stored within 
    `Path(cache_dir, 'Jiang 2019 MAMS ACSA')`.

    :NOTE: As each sentence/`TargetText` object has to have a `text_id` and 
           no ids exist in this dataset, the ids are created from the dataset 
           split and the position of the sentence within it, e.g. the first 
           sentence/`TargetText` object of the `train` split has the id 
           `train$0`.

    For reference, this dataset has 8 different aspect categories.

    :param dataset: Either `train`, `val` or `test`, determines the dataset that 
                    is returned.
    :param cache_dir: The directory where all of the data is stored for 
                      this code base. If None then the cache directory is
                      `dataset_parsers.CACHE_DIRECTORY`
    :returns: The `train`, `val`, or `test` dataset from the 
              Multi-Aspect-Multi-Sentiment dataset (MAMS) ACSA version. 
              Dataset came from the `A Challenge Dataset and Effective Models  
              for Aspect-Based Sentiment Analysis, EMNLP 2019 
              <https://www.aclweb.org/anthology/D19-1654.pdf>`_
    :raises ValueError: If the `dataset` value is not `train`, `val`, or `test`
    '''
    accepted_datasets = {'train', 'val', 'test'}
    if dataset not in accepted_datasets:
        raise ValueError('dataset has to be one of these values '
                         f'{accepted_datasets}, not {dataset}')
    if cache_dir is None:
        cache_dir = CACHE_DIRECTORY
    data_folder = Path(cache_dir, 'Jiang 2019 MAMS ACSA')
    data_folder.mkdir(parents=True, exist_ok=True)

    dataset_url = {
        'train':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ACSA/raw/train.xml',
        'val':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ACSA/raw/val.xml',
        'test':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ACSA/raw/test.xml'
    }
    url = dataset_url[dataset]
    data_fp = Path(cached_path(url, cache_dir=data_folder))

    # Parsing the data
    category_text_collection = TargetTextCollection()
    tree = ET.parse(data_fp)
    sentences = tree.getroot()
    for sentence_id, sentence in enumerate(sentences):
        categories: List[str] = []
        category_sentiments: List[Union[str, int]] = []

        for data in sentence:
            if data.tag == 'text':
                text = data.text
                text = text.replace(u'\xa0', u' ')
            elif data.tag == 'aspectCategories':
                for category in data:
                    category_sentiment = category.attrib['polarity']
                    category_sentiments.append(category_sentiment)
                    categories.append(category.attrib['category'].replace(
                        u'\xa0', u' '))
            else:
                raise ValueError(f'This tag {data.tag} should not occur '
                                 'within a sentence tag')
        category_text_kwargs = {
            'targets': None,
            'spans': None,
            'text_id': f'{dataset}${str(sentence_id)}',
            'target_sentiments': None,
            'categories': categories,
            'text': text,
            'category_sentiments': category_sentiments
        }
        for key in category_text_kwargs:
            if not category_text_kwargs[key]:
                category_text_kwargs[key] = None
        category_text = TargetText(**category_text_kwargs)
        category_text_collection.add(category_text)
    return category_text_collection
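A minimal usage sketch for the parser above, assuming it is exposed via `target_extraction.dataset_parsers` as in the apmoore1/target-extraction repository; the printed fields are only illustrative:

from target_extraction.dataset_parsers import multi_aspect_multi_sentiment_acsa  # assumed import path

# Download (or reuse the cached copy of) the MAMS ACSA validation split and
# inspect the first parsed sentence.
val_collection = multi_aspect_multi_sentiment_acsa('val')
print(len(val_collection))                  # number of sentences/TargetText objects
first_sentence = val_collection['val$0']    # ids follow the `{dataset}${index}` pattern used above
print(first_sentence['categories'], first_sentence['category_sentiments'])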
Example #12
def multi_aspect_multi_sentiment_atsa(
        dataset: str,
        cache_dir: Optional[Path] = None,
        original: bool = True) -> TargetTextCollection:
    '''
    The data for this function, when downloaded, is stored within 
    `Path(cache_dir, 'Jiang 2019 MAMS ATSA')`.

    :NOTE: As each sentence/`TargetText` object has to have a `text_id` and 
           no ids exist in this dataset, the ids are created from the dataset 
           split and the position of the sentence within it, e.g. the first 
           sentence/`TargetText` object of the `train` split has the id 
           `train$0`.

    :param dataset: Either `train`, `val` or `test`, determines the dataset that 
                    is returned.
    :param cache_dir: The directory where all of the data is stored for 
                      this code base. If None then the cache directory is
                      `dataset_parsers.CACHE_DIRECTORY`
    :param original: This does not affect `val` or `test`. If True then the 
                     original training data from the `original paper 
                     <https://www.aclweb.org/anthology/D19-1654.pdf>`_ is 
                     downloaded. Otherwise the cleaned training dataset 
                     version is downloaded. The cleaned version only differs 
                     in a few samples, all of which involve overlapping 
                     targets. See this `notebook for the full differences 
                     <https://github.com/apmoore1/target-extraction/blob/master/tutorials/Difference_between_MAMS_ATSA_original_and_MAMS_ATSA_cleaned.ipynb>`_.
                     
    :returns: The `train`, `val`, or `test` dataset from the 
              Multi-Aspect-Multi-Sentiment dataset (MAMS) ATSA version. 
              Dataset came from the `A Challenge Dataset and Effective Models  
              for Aspect-Based Sentiment Analysis, EMNLP 2019 
              <https://www.aclweb.org/anthology/D19-1654.pdf>`_
    :raises ValueError: If the `dataset` value is not `train`, `val`, or `test`
    '''
    accepted_datasets = {'train', 'val', 'test'}
    if dataset not in accepted_datasets:
        raise ValueError('dataset has to be one of these values '
                         f'{accepted_datasets}, not {dataset}')
    if cache_dir is None:
        cache_dir = CACHE_DIRECTORY
    data_folder = Path(cache_dir, 'Jiang 2019 MAMS ATSA')
    data_folder.mkdir(parents=True, exist_ok=True)

    dataset_url = {
        'train':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ATSA/raw/train.xml',
        'val':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ATSA/raw/val.xml',
        'test':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ATSA/raw/test.xml'
    }
    url = dataset_url[dataset]
    if dataset == 'train' and not original:
        url = 'https://raw.githubusercontent.com/apmoore1/target-extraction/master/data/MAMS/MAMS_ATSA_cleaned_train.xml'
    data_fp = Path(cached_path(url, cache_dir=data_folder))

    # Parsing the data
    target_text_collection = TargetTextCollection()
    tree = ET.parse(data_fp)
    sentences = tree.getroot()
    for sentence_id, sentence in enumerate(sentences):
        targets: List[str] = []
        target_sentiments: List[Union[str, int]] = []
        spans: List[Span] = []

        for data in sentence:
            if data.tag == 'text':
                text = data.text
                text = text.replace(u'\xa0', u' ')
            elif data.tag == 'aspectTerms':
                for target in data:
                    target_sentiment = target.attrib['polarity']
                    target_sentiments.append(target_sentiment)
                    targets.append(target.attrib['term'].replace(
                        u'\xa0', u' '))
                    span_from = int(target.attrib['from'])
                    span_to = int(target.attrib['to'])
                    spans.append(Span(span_from, span_to))
            else:
                raise ValueError(f'This tag {data.tag} should not occur '
                                 'within a sentence tag')
        target_text_kwargs = {
            'targets': targets,
            'spans': spans,
            'text_id': f'{dataset}${str(sentence_id)}',
            'target_sentiments': target_sentiments,
            'categories': None,
            'text': text,
            'category_sentiments': None
        }
        for key in target_text_kwargs:
            if not target_text_kwargs[key]:
                target_text_kwargs[key] = None
        target_text = TargetText(**target_text_kwargs)
        target_text_collection.add(target_text)
    return target_text_collection
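A similar sketch for the ATSA parser, again assuming the `target_extraction.dataset_parsers` import path; `original=False` selects the cleaned training file described in the docstring:

from target_extraction.dataset_parsers import multi_aspect_multi_sentiment_atsa  # assumed import path

# Cleaned MAMS ATSA training split; `original` has no effect on 'val' or 'test'.
train_collection = multi_aspect_multi_sentiment_atsa('train', original=False)
first_sentence = train_collection['train$0']
print(first_sentence['targets'], first_sentence['spans'], first_sentence['target_sentiments'])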
Example #13
def _semeval_extract_data(sentence_tree: Element,
                          conflict: bool) -> TargetTextCollection:
    '''
    :param sentence_tree: The root element of the XML tree that has come 
                          from a SemEval XML formatted XML File.
    :param conflict: Whether or not to include targets or categories that 
                     have the `conflict` sentiment value. If True, conflict 
                     targets and categories are included.
    :returns: The SemEval data formatted into a 
              `target_extraction.data_types.TargetTextCollection` object.
    '''
    target_text_collection = TargetTextCollection()
    for sentence in sentence_tree:
        text_id = sentence.attrib['id']

        targets: List[str] = []
        target_sentiments: List[Union[str, int]] = []
        spans: List[Span] = []

        category_sentiments: List[Union[str, int]] = []
        categories: List[str] = []

        for data in sentence:
            if data.tag == 'text':
                text = data.text
                text = text.replace(u'\xa0', u' ')
            elif data.tag == 'aspectTerms':
                for target in data:
                    # Skip this target if it has the `conflict` sentiment
                    # and the conflict argument is False
                    target_sentiment = target.attrib['polarity']
                    if not conflict and target_sentiment == 'conflict':
                        continue
                    targets.append(target.attrib['term'].replace(
                        u'\xa0', u' '))
                    target_sentiments.append(target_sentiment)
                    span_from = int(target.attrib['from'])
                    span_to = int(target.attrib['to'])
                    spans.append(Span(span_from, span_to))
            elif data.tag == 'aspectCategories':
                for category in data:
                    # Skip this category if it has the `conflict` sentiment
                    # and the conflict argument is False
                    category_sentiment = category.attrib['polarity']
                    if not conflict and category_sentiment == 'conflict':
                        continue
                    categories.append(category.attrib['category'])
                    category_sentiments.append(category.attrib['polarity'])
            elif data.tag == 'Opinions':
                for opinion in data:
                    category_target_sentiment = opinion.attrib['polarity']
                    if not conflict and category_target_sentiment == 'conflict':
                        continue
                    # Handle the case where some of the SemEval 16 files do
                    # not contain targets and are only category sentiment files
                    if 'target' in opinion.attrib:
                        # Handle the case where there is a category but no
                        # target
                        target_text = opinion.attrib['target'].replace(
                            u'\xa0', u' ')
                        span_from = int(opinion.attrib['from'])
                        span_to = int(opinion.attrib['to'])
                        # Special cases for poor annotation in SemEval 2016
                        # task 5 subtask 1 Restaurant dataset
                        if text_id == 'DBG#2:15' and target_text == 'NULL':
                            span_from = 0
                            span_to = 0
                        if text_id == "en_Patsy'sPizzeria_478231878:2"\
                           and target_text == 'NULL':
                            span_to = 0
                        if text_id == "en_MercedesRestaurant_478010602:1" \
                           and target_text == 'NULL':
                            span_to = 0
                        if text_id == "en_MiopostoCaffe_479702043:9" \
                           and target_text == 'NULL':
                            span_to = 0
                        if text_id == "en_MercedesRestaurant_478010600:1" \
                           and target_text == 'NULL':
                            span_from = 0
                            span_to = 0
                        if target_text == 'NULL':
                            target_text = None
                            # Special cases for poor annotation in SemEval 2016
                            # task 5 subtask 1 Restaurant dataset
                            if text_id == '1490757:0':
                                target_text = 'restaurant'
                            if text_id == 'TR#1:0' and span_from == 27:
                                target_text = 'spot'
                            if text_id == 'TFS#5:26':
                                target_text = "environment"
                            if text_id == 'en_SchoonerOrLater_477965850:10':
                                target_text = 'Schooner or Later'
                        targets.append(target_text)
                        spans.append(Span(span_from, span_to))
                    categories.append(opinion.attrib['category'])
                    target_sentiments.append(category_target_sentiment)
        target_text_kwargs = {
            'targets': targets,
            'spans': spans,
            'text_id': text_id,
            'target_sentiments': target_sentiments,
            'categories': categories,
            'text': text,
            'category_sentiments': category_sentiments
        }
        for key in target_text_kwargs:
            if not target_text_kwargs[key]:
                target_text_kwargs[key] = None
        target_text = TargetText(**target_text_kwargs)
        target_text_collection.add(target_text)
    return target_text_collection
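To show how `_semeval_extract_data` consumes a SemEval-formatted tree, here is a small self-contained sketch; the XML snippet is made up, and since the function is private it would normally be reached through the public SemEval parsing functions:

import xml.etree.ElementTree as ET

semeval_xml = '''<sentences>
  <sentence id="s1">
    <text>The laptop case was great</text>
    <aspectTerms>
      <aspectTerm term="laptop case" polarity="positive" from="4" to="15"/>
    </aspectTerms>
  </sentence>
</sentences>'''

# The root <sentences> element plays the role of `sentence_tree`.
sentence_tree = ET.fromstring(semeval_xml)
collection = _semeval_extract_data(sentence_tree, conflict=False)
print(len(collection))                            # 1
print(collection['s1']['targets'], collection['s1']['spans'])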