Example #1
    def test_eq(self):
        examples = self._regular_examples()

        example_0 = examples[0]
        example_1 = examples[1]
        example_2 = examples[2]
        assert example_0 != example_1

        assert example_2 == example_2
        assert example_0 == example_0

        del example_2['2']
        assert example_2 == example_1

        # Ensure that it only relies on the text_id and not the content of the
        # target text
        example_2['2'] = TargetText('hello how', '2')
        example_1['2'] = TargetText('another test', '2')
        assert example_2 == example_1

        # Test the instance check: comparing against anything that is not a
        # TargetTextCollection should return False
        text = 'today was good'
        text_id = '1'
        dict_version = [{'text_id': text_id, 'text': text}]
        collection_version = TargetTextCollection(
            [TargetText(**dict_version[0])])
        assert dict_version != collection_version

        # Should return False as they have different text_id's but the same
        # content
        alt_collection_version = TargetTextCollection(
            [TargetText(text_id='2', text=text)])
        assert collection_version != alt_collection_version
        assert len(collection_version) == len(alt_collection_version)
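Based on the behaviour exercised above, here is a minimal sketch (assuming the `target_extraction` package is installed) of how collection equality depends only on the stored `text_id`s:

from target_extraction.data_types import TargetText, TargetTextCollection

a = TargetTextCollection([TargetText('today was good', '1')])
b = TargetTextCollection([TargetText('completely different text', '1')])
assert a == b  # equal because both collections hold the single text_id '1'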
Example #2
    def test_one_sample_per_span(self, remove_empty: bool):
        # Case where the number of spans should not change, but values such
        # as target_sentiments will be reset to None
        target_text = TargetText(
            text_id='0',
            spans=[Span(4, 15)],
            text='The laptop case was great and cover was rubbish',
            target_sentiments=[0],
            targets=['laptop case'])
        collection = TargetTextCollection([target_text])
        new_collection = collection.one_sample_per_span(
            remove_empty=remove_empty)
        assert new_collection == collection
        assert new_collection['0']['spans'] == [Span(4, 15)]
        assert new_collection['0']['target_sentiments'] == None
        assert collection['0']['target_sentiments'] == [0]

        # Should change the number of Spans.
        assert target_text['target_sentiments'] == [0]
        target_text._storage['spans'] = [Span(4, 15), Span(4, 15)]
        target_text._storage['targets'] = ['laptop case', 'laptop case']
        target_text._storage['target_sentiments'] = [0, 1]
        diff_collection = TargetTextCollection([target_text])
        new_collection = diff_collection.one_sample_per_span(
            remove_empty=remove_empty)
        assert new_collection == collection
        assert new_collection['0']['spans'] == [Span(4, 15)]
        assert new_collection['0']['target_sentiments'] == None
        assert diff_collection['0']['target_sentiments'] == [0, 1]
        assert diff_collection['0']['spans'] == [Span(4, 15), Span(4, 15)]
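Judging by the assertions above, `one_sample_per_span` deduplicates repeated spans (keeping one sample per unique span) and resets per-span values such as `target_sentiments` to None, while leaving the original collection untouched.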
Example #3
    def test_force_targets(self):
        text = 'The laptop casewas great and cover was rubbish'
        spans = [Span(4, 15), Span(29, 34)]
        targets = ['laptop case', 'cover']
        target_text = TargetText(text=text,
                                 text_id='0',
                                 targets=targets,
                                 spans=spans)
        text_1 = 'The laptop casewas great andcover was rubbish'
        spans_1 = [Span(4, 15), Span(28, 33)]
        target_text_1 = TargetText(text=text_1,
                                   text_id='1',
                                   targets=targets,
                                   spans=spans_1)

        perfect_text = 'The laptop case was great and cover was rubbish'
        perfect_spans = [Span(4, 15), Span(30, 35)]

        # Test the single case
        test_collection = TargetTextCollection([target_text])
        test_collection.force_targets()
        assert test_collection['0']['text'] == perfect_text
        assert test_collection['0']['spans'] == perfect_spans

        # Test the multiple case
        test_collection = TargetTextCollection([target_text, target_text_1])
        test_collection.force_targets()
        for target_key in ['0', '1']:
            assert test_collection[target_key]['text'] == perfect_text
            assert test_collection[target_key]['spans'] == perfect_spans
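Judging by the assertions above, `force_targets` rewrites the text so that every stored target is surrounded by whitespace (re-inserting the spaces missing in 'casewas' and 'andcover') and shifts the subsequent spans to match the repaired text.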
Example #4
def no_pred_values(true_sentiment_key: str, predicted_sentiment_key: str
                   ) -> Tuple[TargetTextCollection, List[str], List[List[str]]]:
    example_1 = TargetText(text_id='1', text='some text')
    example_1[true_sentiment_key] = []
    example_1[predicted_sentiment_key] = []
    example_2 = TargetText(text_id='2', text='some text')
    example_2[true_sentiment_key] = []
    example_2[predicted_sentiment_key] = []

    true_labels = []
    pred_labels = []
    return TargetTextCollection([example_1, example_2]), true_labels, pred_labels
Example #5
def passable_diff_num_labels(true_sentiment_key: str, 
                             predicted_sentiment_key: str
                             ) -> Tuple[TargetTextCollection, List[str], List[List[str]]]:
    example_1 = TargetText(text_id='1', text='some text')
    example_1[true_sentiment_key] = ['pos', 'neg']
    example_1[predicted_sentiment_key] = [['neg', 'pos']]
    example_2 = TargetText(text_id='2', text='some text')
    example_2[true_sentiment_key] = ['pos', 'neg', 'neu']
    example_2[predicted_sentiment_key] = [['neg', 'neg', 'pos']]

    true_labels = ['pos', 'neg', 'pos', 'neg', 'neu']
    pred_labels = [['neg', 'pos', 'neg', 'neg', 'pos']]
    return TargetTextCollection([example_1, example_2]), true_labels, pred_labels
Example #6
def empty_preds_examples(true_sentiment_key: str, 
                         predicted_sentiment_key: str,
                         labels_per_text: bool = False
                         ) -> Tuple[TargetTextCollection, List[str], List[List[str]]]:
    example_1 = TargetText(text_id='1', text='some text')
    example_1[true_sentiment_key] = ['pos', 'neg']
    example_1[predicted_sentiment_key] = [[]]
    example_2 = TargetText(text_id='2', text='some text')
    example_2[true_sentiment_key] = ['pos', 'neg', 'neu']
    example_2[predicted_sentiment_key] = [[]]

    true_labels = ['pos', 'neg', 'pos', 'neg', 'neu']
    pred_labels = [[]]
    return TargetTextCollection([example_1, example_2]), true_labels, pred_labels
Example #7
def passable_example_multiple_preds(true_sentiment_key: str, 
                                    predicted_sentiment_key: str
                                    ) -> TargetTextCollection:
    example_1 = TargetText(text_id='1', text='some text', targets=['some', 'text'], 
                           spans=[Span(0,4), Span(5, 9)])
    example_1[true_sentiment_key] = ['pos', 'neg']
    example_1[predicted_sentiment_key] = [['pos', 'neg'], ['pos', 'neg']]
    example_2 = TargetText(text_id='2', text='some text is', targets=['some', 'text', 'is'], 
                           spans=[Span(0,4), Span(5, 9), Span(10,12)])
    example_2[true_sentiment_key] = ['pos', 'neg', 'neu']
    if predicted_sentiment_key == 'model_2':
        example_2[predicted_sentiment_key] = [['pos', 'neg', 'neu'], ['neg', 'neg', 'pos']]
    else:
        example_2[predicted_sentiment_key] = [['neu', 'neg', 'pos'], ['neg', 'neg', 'pos']]
    return TargetTextCollection([example_1, example_2])
Example #8
def test_strict_text_accuracy(true_sentiment_key: str, 
                              predicted_sentiment_key: str):
    # Test that strict text accuracy works as it should on one set of predictions
    example, _, _ = passable_example(true_sentiment_key, predicted_sentiment_key)
    score = strict_text_accuracy(example, true_sentiment_key, 
                                 predicted_sentiment_key, False, False, None)
    assert 0.0 == score
    # Test it works on multiple predictions, averaging over the predictions
    example, _, _ = passable_example_multiple_preds(true_sentiment_key, 
                                                    predicted_sentiment_key)
    score = strict_text_accuracy(example, true_sentiment_key, 
                                 predicted_sentiment_key, True, False, None)
    assert 0.25 == score
    # Test it works on multiple predictions, returning a score per prediction
    example, _, _ = passable_example_multiple_preds(true_sentiment_key, 
                                                    predicted_sentiment_key)
    score = strict_text_accuracy(example, true_sentiment_key, 
                                 predicted_sentiment_key, False, True, None)
    assert [0.0, 0.5] == score
    # Test the case where the TargetCollection has a sentence/text with no 
    # Targets/Predictions. This should raise a ValueError.
    no_target = TargetText(text='hello how are you', text_id='10', 
                           target_sentiments=[], targets=[], spans=[])
    no_target[true_sentiment_key] = []
    no_target[predicted_sentiment_key] = []
    target_examples, _, _ = passable_example_multiple_preds(true_sentiment_key, 
                                                            predicted_sentiment_key)
    all_targets = list(target_examples.values())
    all_targets.append(no_target)      
    test_collection = TargetTextCollection(all_targets)
    assert 3 == len(test_collection)
    with pytest.raises(ValueError):
        strict_text_accuracy(test_collection, true_sentiment_key, 
                             predicted_sentiment_key, True, False, None)
Example #9
    def test_sanitize(self):
        # The normal case where no errors should be raised.
        target_text = TargetText(
            text_id='0',
            spans=[Span(4, 15)],
            text='The laptop case was great and cover was rubbish',
            target_sentiments=[0],
            targets=['laptop case'])
        collection = TargetTextCollection([target_text])
        collection.sanitize()

        # The case where an error should be raised
        with pytest.raises(ValueError):
            target_text._storage['spans'] = [Span(3, 15)]
            collection = TargetTextCollection([target_text])
            collection.sanitize()
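The shifted span is what triggers the error here: with Span(3, 15) the covered substring is ' laptop case', which no longer matches the stored target 'laptop case', so `sanitize` appears to raise a ValueError whenever a span does not align with its target.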
Example #10
    def _target_text_not_align_example(self) -> TargetText:
        text = 'The laptop case; was awful'
        text_id = 'inf'
        spans = [Span(4, 15)]
        targets = ['laptop case']
        return TargetText(text=text,
                          text_id=text_id,
                          spans=spans,
                          targets=targets)
Example #11
def passable_subset_multiple_preds(true_sentiment_key: str, 
                                   predicted_sentiment_key: str,
                                   labels_per_text: bool = False
                                   ) -> Tuple[TargetTextCollection, List[str], List[List[str]]]:
    example_1 = TargetText(text_id='1', text='some text')
    example_1[true_sentiment_key] = ['pos', 'neg']
    example_1[predicted_sentiment_key] = [['neg', 'neg'], ['neg', 'neg']]
    example_2 = TargetText(text_id='2', text='some text')
    example_2[true_sentiment_key] = ['pos', 'neg', 'neu']
    example_2[predicted_sentiment_key] = [['neu', 'neg', 'neu'], ['neg', 'neg', 'neu']]

    true_labels = ['pos', 'neg', 'pos', 'neg', 'neu']
    pred_labels = [['neg', 'neg', 'neu', 'neg', 'neu'], 
                   ['neg', 'neg', 'neg', 'neg', 'neu']]
    if labels_per_text:
        true_labels = [['pos', 'neg'], ['pos', 'neg', 'neu']]
        pred_labels = [[['neg', 'neg'], ['neu', 'neg', 'neu']], 
                       [['neg', 'neg'], ['neg', 'neg', 'neu']]]
    return TargetTextCollection([example_1, example_2]), true_labels, pred_labels
Example #12
    def _target_text_measure_examples(self) -> List[TargetText]:
        text = 'The laptop case was great and cover was rubbish'
        text_id = '0'
        spans = [Span(4, 15), Span(30, 35)]
        targets = ['laptop case', 'cover']
        target_text_0 = TargetText(text=text,
                                   text_id=text_id,
                                   spans=spans,
                                   targets=targets)
        text = 'The laptop price was awful'
        text_id = '1'
        spans = [Span(4, 16)]
        targets = ['laptop price']
        target_text_1 = TargetText(text=text,
                                   text_id=text_id,
                                   spans=spans,
                                   targets=targets)

        return [target_text_0, target_text_1]
Example #13
    def test_get_item(self):
        examples = self._regular_examples()
        example_2 = examples[2]

        last_target_text = self._target_text_example()
        assert example_2['2'] == last_target_text

        assert example_2['0'] == TargetText(
            'can be any text as long as id is correct', '0')

        with pytest.raises(KeyError):
            example_2['any key']
Example #14
    def test_target_count(self):
        # Start with an empty collection
        test_collection = TargetTextCollection()
        nothing = test_collection.target_count()
        assert len(nothing) == 0
        assert not nothing

        # Collection that contains TargetText instances but with no targets
        test_collection.add(TargetText(text='some text', text_id='1'))
        assert len(test_collection) == 1
        nothing = test_collection.target_count()
        assert len(nothing) == 0
        assert not nothing

        # Collection now contains at least one target
        test_collection.add(
            TargetText(text='another item today',
                       text_id='2',
                       spans=[Span(0, 12)],
                       targets=['another item']))
        assert len(test_collection) == 2
        one = test_collection.target_count()
        assert len(one) == 1
        assert one == {'another item': 1}

        # Collection now contains 3 targets but 2 are the same
        test_collection.add(
            TargetText(text='another item today',
                       text_id='3',
                       spans=[Span(0, 12)],
                       targets=['another item']))
        test_collection.add(
            TargetText(text='item today',
                       text_id='4',
                       spans=[Span(0, 4)],
                       targets=['item']))
        assert len(test_collection) == 4
        two = test_collection.target_count()
        assert len(two) == 2
        assert two == {'another item': 2, 'item': 1}
Example #15
    def test_samples_with_targets(self):
        # Test the case where all of the TargetTexts in the collection contain targets
        test_collection = TargetTextCollection(self._target_text_examples())
        sub_collection = test_collection.samples_with_targets()
        assert test_collection == sub_collection
        assert len(test_collection) == 3

        # Test the case where none of the TargetTexts in the collection contain targets
        for sample_id in list(test_collection.keys()):
            del test_collection[sample_id]
            test_collection.add(TargetText(text='nothing', text_id=sample_id))
        assert len(test_collection) == 3
        sub_collection = test_collection.samples_with_targets()
        assert len(sub_collection) == 0
        assert sub_collection != test_collection

        # Test the case where only 2 of the three TargetTexts in the
        # collection contain targets.
        test_collection = TargetTextCollection(self._target_text_examples())
        del test_collection['another_id']
        easy_case = TargetText(text='something else', text_id='another_id')
        test_collection.add(easy_case)
        sub_collection = test_collection.samples_with_targets()
        assert len(sub_collection) == 2
        assert sub_collection != test_collection

        # Test the case where the targets are just an empty list rather than
        # None
        test_collection = TargetTextCollection(self._target_text_examples())
        del test_collection['another_id']
        edge_case = TargetText(text='something else',
                               text_id='another_id',
                               targets=[],
                               spans=[])
        test_collection.add(edge_case)
        sub_collection = test_collection.samples_with_targets()
        assert len(sub_collection) == 2
        assert sub_collection != test_collection
Example #16
    def parse_tweet(tweet_data: Dict[str, Any], annotation_data: Dict[str,
                                                                      Any],
                    tweet_id: str) -> TargetText:
        '''
        :param tweet_data: Data containing the Tweet information
        :param annotation_data: Data containing the annotation data on the 
                                Tweet
        :param tweet_id: ID of the Tweet
        :returns: The Tweet data in 
                  :class:`target_extraction.data_types.TargetText` format
        :raises ValueError: If the Target offset cannot be found.
        '''
        def get_offsets(from_offset: int, tweet_text: str,
                        target: str) -> Span:
            offset_shifts = [0, -1, 1]
            for offset_shift in offset_shifts:
                from_offset_shift = from_offset + offset_shift
                to_offset = from_offset_shift + len(target)
                offsets = Span(from_offset_shift, to_offset)
                offset_text = tweet_text[from_offset_shift:to_offset].lower()
                if offset_text == target.lower():
                    return offsets
            raise ValueError(
                f'Offset {from_offset} does not match target text'
                f' {target}. Full text {tweet_text}\nid {tweet_id}')

        target_id = str(tweet_id)
        target_text = tweet_data['content']
        target_categories = None
        target_category_sentiments = None
        targets = []
        target_spans = []
        target_sentiments = []
        for entity in tweet_data['entities']:
            target_sentiment = annotation_data['items'][str(entity['id'])]
            if target_sentiment == 'doesnotapply':
                continue

            target = entity['entity']
            target_span = get_offsets(entity['offset'], target_text, target)
            # Take the target from the text as sometimes the original label
            # is lower cased when it should not be according to the text.
            target = target_text[target_span.start:target_span.end]

            targets.append(target)
            target_spans.append(target_span)
            target_sentiments.append(target_sentiment)
        return TargetText(target_text, target_id, targets, target_spans,
                          target_sentiments, target_categories,
                          target_category_sentiments)
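A minimal usage sketch with hypothetical input data, assuming `parse_tweet` is in scope; the dict shapes (`content`, `entities` with `id`/`entity`/`offset`, and `items`) are inferred from the parsing code above:

tweet_data = {'content': 'I love my new phone',
              'entities': [{'id': 7, 'entity': 'phone', 'offset': 14}]}
annotation_data = {'items': {'7': 'positive'}}
tweet_target = parse_tweet(tweet_data, annotation_data, tweet_id='12345')
assert tweet_target['targets'] == ['phone']
assert tweet_target['spans'] == [Span(14, 19)]
assert tweet_target['target_sentiments'] == ['positive']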
Example #17
    def test_set_item(self):
        new_collection = TargetTextCollection()

        # Full target text.
        new_collection['2'] = self._target_text_example()
        # Minimum target text
        new_collection['2'] = TargetText('minimum example', '2')
        example_args = {'text': 'minimum example', 'text_id': '2'}
        new_collection['2'] = TargetText(**example_args)

        with pytest.raises(ValueError):
            new_collection['2'] = TargetText('minimum example', '3')
        with pytest.raises(TypeError):
            new_collection['2'] = example_args

        # Ensure that if the given TargetText changes it does not change in
        # the collection
        example_instance = TargetText(**example_args)
        example_collection = TargetTextCollection()
        example_collection['2'] = example_instance

        example_instance['target_sentiments'] = [0]
        assert example_instance['target_sentiments'] is not None
        assert example_collection['2']['target_sentiments'] is None
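The last three assertions suggest that `__setitem__` stores a copy of the given TargetText, so mutating the original instance afterwards does not leak into the collection.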
Example #18
    def _target_text_examples(self) -> List[TargetText]:
        text = 'The laptop case was great and cover was rubbish'
        text_ids = ['0', 'another_id', '2']
        spans = [[Span(4, 15)], [Span(30, 35)], [Span(4, 15), Span(30, 35)]]
        target_sentiments = [[0], [1], [0, 1]]
        targets = [['laptop case'], ['cover'], ['laptop case', 'cover']]
        categories = [['LAPTOP#CASE'], ['LAPTOP'], ['LAPTOP#CASE', 'LAPTOP']]

        target_text_examples = []
        for i in range(3):
            example = TargetText(text,
                                 text_ids[i],
                                 targets=targets[i],
                                 spans=spans[i],
                                 target_sentiments=target_sentiments[i],
                                 categories=categories[i])
            target_text_examples.append(example)
        return target_text_examples
Example #19
def test_target_length_plot():
    # standard/normal case
    ax = target_length_plot([TRAIN_COLLECTION], 'targets', whitespace())
    del ax
    # cumulative percentage True
    ax = target_length_plot([TRAIN_COLLECTION],
                            'targets',
                            whitespace(),
                            cumulative_percentage=True)
    del ax
    # Max target length
    ax = target_length_plot([TRAIN_COLLECTION],
                            'targets',
                            whitespace(),
                            cumulative_percentage=True,
                            max_target_length=1)
    del ax
    # Can consume an axes
    fig, alt_ax = plt.subplots(1, 1)
    ax = target_length_plot([TRAIN_COLLECTION],
                            'targets',
                            whitespace(),
                            cumulative_percentage=True,
                            max_target_length=1,
                            ax=alt_ax)
    assert alt_ax == ax
    del ax
    plt.close(fig)

    # Can take more than one collection
    alt_collection = copy.deepcopy(list(TRAIN_COLLECTION.dict_iterator()))
    alt_collection = [TargetText(**v) for v in alt_collection]
    alt_collection = TargetTextCollection(alt_collection)
    alt_collection.name = 'Another'
    ax = target_length_plot([TRAIN_COLLECTION, alt_collection],
                            'targets',
                            whitespace(),
                            cumulative_percentage=True,
                            max_target_length=1)
    assert alt_ax != ax
    del alt_ax
    del ax
Example #20
def test_overall_metric_results(true_sentiment_key: str, 
                                include_metadata: bool,
                                strict_accuracy_metrics: bool):
    # Remember whether the key was supplied so that the default-argument
    # path can be tested further down
    use_default_key = true_sentiment_key is None
    if true_sentiment_key is None:
        true_sentiment_key = 'target_sentiments'
    model_1_collection = passable_example_multiple_preds(true_sentiment_key, 'model_1')
    model_1_collection.add(TargetText(text='a', text_id='200', spans=[Span(0,1)], targets=['a'],
                                      model_1=[['pos'], ['neg']], **{f'{true_sentiment_key}': ['pos']}))
    model_1_collection.add(TargetText(text='a', text_id='201', spans=[Span(0,1)], targets=['a'],
                                      model_1=[['pos'], ['neg']], **{f'{true_sentiment_key}': ['neg']}))
    model_1_collection.add(TargetText(text='a', text_id='202', spans=[Span(0,1)], targets=['a'],
                                      model_1=[['pos'], ['neg']], **{f'{true_sentiment_key}': ['neu']}))
    print(true_sentiment_key)
    print(model_1_collection['1'])
    print(model_1_collection['200'])
    model_2_collection = passable_example_multiple_preds(true_sentiment_key, 'model_2')
    combined_collection = TargetTextCollection()
    
    standard_columns = ['Dataset', 'Macro F1', 'Accuracy', 'run number', 
                        'prediction key']
    if strict_accuracy_metrics:
        standard_columns = standard_columns + ['STAC', 'STAC 1', 'STAC Multi']
    if include_metadata:
        metadata = {'predicted_target_sentiment_key': {'model_1': {'CWR': True},
                                                       'model_2': {'CWR': False}}}
        combined_collection.name = 'name'
        combined_collection.metadata = metadata
        standard_columns.append('CWR')
    number_df_columns = len(standard_columns)

    for key, value in model_1_collection.items():
        if key in ['200', '201', '202']:
            combined_collection.add(value)
            combined_collection[key]['model_2'] = [['neg'], ['pos']]
            continue
        combined_collection.add(value)
        combined_collection[key]['model_2'] = model_2_collection[key]['model_2']
    if use_default_key:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1', 'model_2'], 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    else:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1', 'model_2'],
                                                true_sentiment_key, 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    assert (4, number_df_columns) == result_df.shape
    assert set(standard_columns) == set(result_df.columns)
    if include_metadata:
        assert ['name'] * 4 == result_df['Dataset'].tolist()
    else:
        assert [''] * 4 == result_df['Dataset'].tolist()
    # Test the case where only one model is used
    if use_default_key:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1'], 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    else:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1'],
                                                true_sentiment_key, 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    assert (2, number_df_columns) == result_df.shape
    # Test the case where the model names come from the metadata
    if include_metadata:
        result_df = util.overall_metric_results(combined_collection, 
                                                true_sentiment_key=true_sentiment_key, 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
        assert (4, number_df_columns) == result_df.shape
    else:
        with pytest.raises(KeyError):
            util.overall_metric_results(combined_collection, 
                                        true_sentiment_key=true_sentiment_key, 
                                        strict_accuracy_metrics=strict_accuracy_metrics)
Example #21
def multi_aspect_multi_sentiment_atsa(
        dataset: str,
        cache_dir: Optional[Path] = None,
        original: bool = True) -> TargetTextCollection:
    '''
    The data for this function, when downloaded, is stored within: 
    `Path(cache_dir, 'Jiang 2019 MAMS ATSA')`

    :NOTE: As each sentence/`TargetText` object has to have a `text_id` and 
           no ids exist in this dataset, the ids are created based on where 
           the sentence occurs in the dataset, e.g. the first `train` 
           sentence/`TargetText` object has the id `train$0`.

    :param dataset: Either `train`, `val` or `test`, determines the dataset that 
                    is returned.
    :param cache_dir: The directory where all of the data is stored for 
                      this code base. If None then the cache directory is
                      `dataset_parsers.CACHE_DIRECTORY`
    :param original: This does not affect `val` or `test`. If True then it will 
                     download the original training data from the `original paper 
                     <https://www.aclweb.org/anthology/D19-1654.pdf>`_ . Else 
                     it will download the cleaned Training dataset version. The 
                     cleaned version only contains a few sample differences 
                     but these differences are with respect to overlapping 
                     targets. See this `notebook for full differences 
                     <https://github.com/apmoore1/target-extraction/blob/master/tutorials/Difference_between_MAMS_ATSA_original_and_MAMS_ATSA_cleaned.ipynb>`_.
                     
    :returns: The `train`, `val`, or `test` dataset from the 
              Multi-Aspect-Multi-Sentiment dataset (MAMS) ATSA version. 
              Dataset came from the `A Challenge Dataset and Effective Models  
              for Aspect-Based Sentiment Analysis, EMNLP 2019 
              <https://www.aclweb.org/anthology/D19-1654.pdf>`_
    :raises ValueError: If the `dataset` value is not `train`, `val`, or `test`
    '''
    accepted_datasets = {'train', 'val', 'test'}
    if dataset not in accepted_datasets:
        raise ValueError('dataset has to be one of these values '
                         f'{accepted_datasets}, not {dataset}')
    if cache_dir is None:
        cache_dir = CACHE_DIRECTORY
    data_folder = Path(cache_dir, 'Jiang 2019 MAMS ATSA')
    data_folder.mkdir(parents=True, exist_ok=True)

    dataset_url = {
        'train':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ATSA/raw/train.xml',
        'val':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ATSA/raw/val.xml',
        'test':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ATSA/raw/test.xml'
    }
    url = dataset_url[dataset]
    if dataset == 'train' and not original:
        url = 'https://raw.githubusercontent.com/apmoore1/target-extraction/master/data/MAMS/MAMS_ATSA_cleaned_train.xml'
    data_fp = Path(cached_path(url, cache_dir=data_folder))

    # Parsing the data
    target_text_collection = TargetTextCollection()
    tree = ET.parse(data_fp)
    sentences = tree.getroot()
    for sentence_id, sentence in enumerate(sentences):
        targets: List[str] = []
        target_sentiments: List[Union[str, int]] = []
        spans: List[Span] = []

        for data in sentence:
            if data.tag == 'text':
                text = data.text
                text = text.replace(u'\xa0', u' ')
            elif data.tag == 'aspectTerms':
                for target in data:
                    target_sentiment = target.attrib['polarity']
                    target_sentiments.append(target_sentiment)
                    targets.append(target.attrib['term'].replace(
                        u'\xa0', u' '))
                    span_from = int(target.attrib['from'])
                    span_to = int(target.attrib['to'])
                    spans.append(Span(span_from, span_to))
            else:
                raise ValueError(f'This tag {data.tag} should not occur '
                                 'within a sentence tag')
        target_text_kwargs = {
            'targets': targets,
            'spans': spans,
            'text_id': f'{dataset}${str(sentence_id)}',
            'target_sentiments': target_sentiments,
            'categories': None,
            'text': text,
            'category_sentiments': None
        }
        for key in target_text_kwargs:
            if not target_text_kwargs[key]:
                target_text_kwargs[key] = None
        target_text = TargetText(**target_text_kwargs)
        target_text_collection.add(target_text)
    return target_text_collection
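A usage sketch; note that the first call downloads the XML file, so it needs network access and write access to the cache directory:

train = multi_aspect_multi_sentiment_atsa('train')
val = multi_aspect_multi_sentiment_atsa('val')
print(len(train), len(val))  # number of sentences in each split
first = train['train$0']     # ids follow the f'{dataset}${index}' scheme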
Example #22
def multi_aspect_multi_sentiment_acsa(
        dataset: str,
        cache_dir: Optional[Path] = None) -> TargetTextCollection:
    '''
    The data for this function, when downloaded, is stored within: 
    `Path(cache_dir, 'Jiang 2019 MAMS ACSA')`

    :NOTE: As each sentence/`TargetText` object has to have a `text_id` and 
           no ids exist in this dataset, the ids are created based on where 
           the sentence occurs in the dataset, e.g. the first `train` 
           sentence/`TargetText` object has the id `train$0`.

    For reference this dataset has 8 different aspect categories.

    :param dataset: Either `train`, `val` or `test`, determines the dataset that 
                    is returned.
    :param cache_dir: The directory where all of the data is stored for 
                      this code base. If None then the cache directory is
                      `dataset_parsers.CACHE_DIRECTORY`
    :returns: The `train`, `val`, or `test` dataset from the 
              Multi-Aspect-Multi-Sentiment dataset (MAMS) ACSA version. 
              Dataset came from the `A Challenge Dataset and Effective Models  
              for Aspect-Based Sentiment Analysis, EMNLP 2019 
              <https://www.aclweb.org/anthology/D19-1654.pdf>`_
    :raises ValueError: If the `dataset` value is not `train`, `val`, or `test`
    '''
    accepted_datasets = {'train', 'val', 'test'}
    if dataset not in accepted_datasets:
        raise ValueError('dataset has to be one of these values '
                         f'{accepted_datasets}, not {dataset}')
    if cache_dir is None:
        cache_dir = CACHE_DIRECTORY
    data_folder = Path(cache_dir, 'Jiang 2019 MAMS ACSA')
    data_folder.mkdir(parents=True, exist_ok=True)

    dataset_url = {
        'train':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ACSA/raw/train.xml',
        'val':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ACSA/raw/val.xml',
        'test':
        'https://github.com/siat-nlp/MAMS-for-ABSA/raw/master/data/MAMS-ACSA/raw/test.xml'
    }
    url = dataset_url[dataset]
    data_fp = Path(cached_path(url, cache_dir=data_folder))

    # Parsing the data
    category_text_collection = TargetTextCollection()
    tree = ET.parse(data_fp)
    sentences = tree.getroot()
    for sentence_id, sentence in enumerate(sentences):
        categories: List[str] = []
        category_sentiments: List[Union[str, int]] = []

        for data in sentence:
            if data.tag == 'text':
                text = data.text
                text = text.replace(u'\xa0', u' ')
            elif data.tag == 'aspectCategories':
                for category in data:
                    category_sentiment = category.attrib['polarity']
                    category_sentiments.append(category_sentiment)
                    categories.append(category.attrib['category'].replace(
                        u'\xa0', u' '))
            else:
                raise ValueError(f'This tag {data.tag} should not occur '
                                 'within a sentence tag')
        category_text_kwargs = {
            'targets': None,
            'spans': None,
            'text_id': f'{dataset}${str(sentence_id)}',
            'target_sentiments': None,
            'categories': categories,
            'text': text,
            'category_sentiments': category_sentiments
        }
        for key in category_text_kwargs:
            if not category_text_kwargs[key]:
                category_text_kwargs[key] = None
        category_text = TargetText(**category_text_kwargs)
        category_text_collection.add(category_text)
    return category_text_collection
Example #23
    parser.add_argument("augmented_training_dataset", type=parse_path, 
                        help='File path to the augmented training dataset')
    parser.add_argument("expanded_targets_fp", type=parse_path, 
                        help='File path to the expanded targets json file')
    args = parser.parse_args()

    with args.expanded_targets_fp.open('r') as expanded_targets_file:
        targets_equivalents: Dict[str, str] = json.load(expanded_targets_file)
    assert len(targets_equivalents) > 1

    expanded_target_counts = Counter()
    number_training_samples = 0
    number_targets_expanded = 0
    with args.augmented_training_dataset.open('r') as training_file:
        for line in training_file:
            training_sample = TargetText.from_json(line)
            number_targets = len(training_sample['targets'])
            number_training_samples += number_targets
            for target_index in range(number_targets):
                original_target = training_sample['targets'][target_index]
                if original_target.lower() not in targets_equivalents:
                    continue 
                number_targets_expanded += 1

                expanded_target_key = f'target {target_index}'
                expanded_targets = training_sample[expanded_target_key]
                assert original_target in expanded_targets
                number_expanded_targets = len(expanded_targets) - 1
                assert len(expanded_targets) == len(set(expanded_targets))
                expanded_target_counts.update([number_expanded_targets])
Example #24
    def test_exact_match_score(self):
        # Simple case where it should get perfect score
        test_collection = TargetTextCollection([self._target_text_example()])
        test_collection.tokenize(spacy_tokenizer())
        test_collection.sequence_labels()
        measures = test_collection.exact_match_score('sequence_labels')
        for index, measure in enumerate(measures):
            if index == 3:
                assert measure['FP'] == []
                assert measure['FN'] == []
                assert measure['TP'] == [('2', Span(4, 15)),
                                         ('2', Span(30, 35))]
            else:
                assert measure == 1.0

        # Something that has perfect precision but misses one target,
        # therefore it does not have perfect recall nor F1
        test_collection = TargetTextCollection(
            self._target_text_measure_examples())
        test_collection.tokenize(str.split)
        # text = 'The laptop case was great and cover was rubbish'
        sequence_labels_0 = ['O', 'O', 'O', 'O', 'O', 'O', 'B', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        # text = 'The laptop price was awful'
        sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert precision == 1.0
        assert recall == 2.0 / 3.0
        assert f1 == 0.8
        assert error_analysis['FP'] == []
        assert error_analysis['FN'] == [('0', Span(4, 15))]
        assert error_analysis['TP'] == [('0', Span(30, 35)), ('1', Span(4,
                                                                        16))]

        # Something that has perfect recall but not precision as it over
        # predicts
        sequence_labels_0 = ['O', 'B', 'I', 'B', 'O', 'O', 'B', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert precision == 3 / 4
        assert recall == 1.0
        assert round(f1, 3) == 0.857
        assert error_analysis['FP'] == [('0', Span(16, 19))]
        assert error_analysis['FN'] == []
        assert error_analysis['TP'] == [('0', Span(4, 15)), ('0', Span(30,
                                                                       35)),
                                        ('1', Span(4, 16))]

        # Does not predict anything for a whole sentence, therefore will have
        # perfect precision but bad recall (mainly testing that predicting
        # nothing for a sentence is handled correctly)
        sequence_labels_0 = ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert precision == 1.0
        assert recall == 1 / 3
        assert f1 == 0.5
        assert error_analysis['FP'] == []
        fn_error = sorted(error_analysis['FN'], key=lambda x: x[1].start)
        assert fn_error == [('0', Span(4, 15)), ('0', Span(30, 35))]
        assert error_analysis['TP'] == [('1', Span(4, 16))]

        # Handle the edge case of not getting anything
        sequence_labels_0 = ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        sequence_labels_1 = ['O', 'O', 'O', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert precision == 0.0
        assert recall == 0.0
        assert f1 == 0.0
        assert error_analysis['FP'] == []
        fn_error = sorted(error_analysis['FN'], key=lambda x: x[1].start)
        assert fn_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                            ('0', Span(30, 35))]
        assert error_analysis['TP'] == []

        # The case where the tokens and the text do not align
        not_align_example = self._target_text_not_align_example()
        # text = 'The laptop case; was awful'
        sequence_labels_align = ['O', 'B', 'I', 'O', 'O']
        test_collection.add(not_align_example)
        test_collection.tokenize(str.split)
        test_collection['inf']['sequence_labels'] = sequence_labels_align
        sequence_labels_0 = ['O', 'B', 'I', 'O', 'O', 'O', 'B', 'O', 'O']
        test_collection['0']['sequence_labels'] = sequence_labels_0
        sequence_labels_1 = ['O', 'B', 'I', 'O', 'O']
        test_collection['1']['sequence_labels'] = sequence_labels_1
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert recall == 3 / 4
        assert precision == 3 / 4
        assert f1 == 0.75
        assert error_analysis['FP'] == [('inf', Span(4, 16))]
        assert error_analysis['FN'] == [('inf', Span(4, 15))]
        tp_error = sorted(error_analysis['TP'], key=lambda x: x[1].start)
        assert tp_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                            ('0', Span(30, 35))]

        # This time it can get a perfect score as the token alignment will be
        # perfect
        test_collection.tokenize(spacy_tokenizer())
        sequence_labels_align = ['O', 'B', 'I', 'O', 'O', 'O']
        test_collection['inf']['sequence_labels'] = sequence_labels_align
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert recall == 1.0
        assert precision == 1.0
        assert f1 == 1.0
        assert error_analysis['FP'] == []
        assert error_analysis['FN'] == []
        tp_error = sorted(error_analysis['TP'], key=lambda x: x[1].end)
        assert tp_error == [('0', Span(4, 15)), ('inf', Span(4, 15)),
                            ('1', Span(4, 16)), ('0', Span(30, 35))]

        # Handle the case where one of the samples has no spans
        test_example = TargetText(text="I've had a bad day", text_id='50')
        other_examples = self._target_text_measure_examples()
        other_examples.append(test_example)
        test_collection = TargetTextCollection(other_examples)
        test_collection.tokenize(str.split)
        test_collection.sequence_labels()
        measures = test_collection.exact_match_score('sequence_labels')
        for index, measure in enumerate(measures):
            if index == 3:
                assert measure['FP'] == []
                assert measure['FN'] == []
                tp_error = sorted(measure['TP'], key=lambda x: x[1].end)
                assert tp_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                                    ('0', Span(30, 35))]
            else:
                assert measure == 1.0
        # Handle the case where one of the samples has no spans but a span
        # has been predicted for it
        test_collection['50']['sequence_labels'] = ['B', 'I', 'O', 'O', 'O']
        recall, precision, f1, error_analysis = test_collection.exact_match_score(
            'sequence_labels')
        assert recall == 1.0
        assert precision == 3 / 4
        assert round(f1, 3) == 0.857
        assert error_analysis['FP'] == [('50', Span(start=0, end=8))]
        assert error_analysis['FN'] == []
        tp_error = sorted(error_analysis['TP'], key=lambda x: x[1].end)
        assert tp_error == [('0', Span(4, 15)), ('1', Span(4, 16)),
                            ('0', Span(30, 35))]
        # See if it can handle a collection whose samples contain no spans
        test_example = TargetText(text="I've had a bad day", text_id='50')
        test_collection = TargetTextCollection([test_example])
        test_collection.tokenize(str.split)
        test_collection.sequence_labels()
        measures = test_collection.exact_match_score('sequence_labels')
        for index, measure in enumerate(measures):
            if index == 3:
                assert measure['FP'] == []
                assert measure['FN'] == []
                assert measure['TP'] == []
            else:
                assert measure == 0.0
        # Handle the case where the collection contains one predicted span,
        # but it is a mistake (a false positive)
        test_collection['50']['sequence_labels'] = ['B', 'I', 'O', 'O', 'O']
        measures = test_collection.exact_match_score('sequence_labels')
        for index, measure in enumerate(measures):
            if index == 3:
                assert measure['FP'] == [('50', Span(0, 8))]
                assert measure['FN'] == []
                assert measure['TP'] == []
            else:
                assert measure == 0.0
        # Should raise a KeyError if one of the TargetText instances does
        # not have a `spans` key
        del test_collection['50']._storage['spans']
        with pytest.raises(KeyError):
            test_collection.exact_match_score('sequence_labels')
        # Should raise a KeyError if one of the TargetText instances does
        # not have the given predicted sequence key
        test_collection = TargetTextCollection([self._target_text_example()])
        test_collection.tokenize(spacy_tokenizer())
        test_collection.sequence_labels()
        with pytest.raises(KeyError):
            measures = test_collection.exact_match_score('nothing')

        # Should raise a ValueError if there are multiple same true spans
        a = TargetText(text='hello how are you I am good',
                       text_id='1',
                       targets=['hello', 'hello'],
                       spans=[Span(0, 5), Span(0, 5)])
        test_collection = TargetTextCollection([a])
        test_collection.tokenize(str.split)
        test_collection['1']['sequence_labels'] = [
            'B', 'O', 'O', 'O', 'O', 'O', 'O'
        ]
        with pytest.raises(ValueError):
            test_collection.exact_match_score('sequence_labels')
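For reference, the F1 values asserted throughout this test follow from F1 = 2PR / (P + R): with P = 1.0 and R = 2/3 it is 0.8, with P = 3/4 and R = 1.0 it is 6/7 ≈ 0.857, and with P = 1.0 and R = 1/3 it is 0.5.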
Example #25
def _semeval_extract_data(sentence_tree: Element,
                          conflict: bool) -> TargetTextCollection:
    '''
    :param sentence_tree: The root element of the XML tree that has come 
                          from a SemEval XML formatted XML File.
    :param conflict: Whether or not to include targets or categories that 
                     have the `conflict` sentiment value. True is to include 
                     conflict targets and categories.
    :returns: The SemEval data formatted into a 
              `target_extraction.data_types.TargetTextCollection` object.
    '''
    target_text_collection = TargetTextCollection()
    for sentence in sentence_tree:
        text_id = sentence.attrib['id']

        targets: List[str] = []
        target_sentiments: List[Union[str, int]] = []
        spans: List[Span] = []

        category_sentiments: List[Union[str, int]] = []
        categories: List[str] = []

        for data in sentence:
            if data.tag == 'text':
                text = data.text
                text = text.replace(u'\xa0', u' ')
            elif data.tag == 'aspectTerms':
                for target in data:
                    # If the sentiment is `conflict` and the conflict
                    # argument is False, skip this target
                    target_sentiment = target.attrib['polarity']
                    if not conflict and target_sentiment == 'conflict':
                        continue
                    targets.append(target.attrib['term'].replace(
                        u'\xa0', u' '))
                    target_sentiments.append(target_sentiment)
                    span_from = int(target.attrib['from'])
                    span_to = int(target.attrib['to'])
                    spans.append(Span(span_from, span_to))
            elif data.tag == 'aspectCategories':
                for category in data:
                    # If the sentiment is `conflict` and the conflict
                    # argument is False, skip this category
                    category_sentiment = category.attrib['polarity']
                    if not conflict and category_sentiment == 'conflict':
                        continue
                    categories.append(category.attrib['category'])
                    category_sentiments.append(category.attrib['polarity'])
            elif data.tag == 'Opinions':
                for opinion in data:
                    category_target_sentiment = opinion.attrib['polarity']
                    if not conflict and category_target_sentiment == 'conflict':
                        continue
                    # Handle the case where some of the SemEval 16 files do
                    # not contain targets and are only category sentiment files
                    if 'target' in opinion.attrib:
                        # Handle the case where there is a category but no
                        # target
                        target_text = opinion.attrib['target'].replace(
                            u'\xa0', u' ')
                        span_from = int(opinion.attrib['from'])
                        span_to = int(opinion.attrib['to'])
                        # Special cases for poor annotation in SemEval 2016
                        # task 5 subtask 1 Restaurant dataset
                        if text_id == 'DBG#2:15' and target_text == 'NULL':
                            span_from = 0
                            span_to = 0
                        if text_id == "en_Patsy'sPizzeria_478231878:2"\
                           and target_text == 'NULL':
                            span_to = 0
                        if text_id == "en_MercedesRestaurant_478010602:1" \
                           and target_text == 'NULL':
                            span_to = 0
                        if text_id == "en_MiopostoCaffe_479702043:9" \
                           and target_text == 'NULL':
                            span_to = 0
                        if text_id == "en_MercedesRestaurant_478010600:1" \
                           and target_text == 'NULL':
                            span_from = 0
                            span_to = 0
                        if target_text == 'NULL':
                            target_text = None
                            # Special cases for poor annotation in SemEval 2016
                            # task 5 subtask 1 Restaurant dataset
                            if text_id == '1490757:0':
                                target_text = 'restaurant'
                            if text_id == 'TR#1:0' and span_from == 27:
                                target_text = 'spot'
                            if text_id == 'TFS#5:26':
                                target_text = "environment"
                            if text_id == 'en_SchoonerOrLater_477965850:10':
                                target_text = 'Schooner or Later'
                        targets.append(target_text)
                        spans.append(Span(span_from, span_to))
                    categories.append(opinion.attrib['category'])
                    target_sentiments.append(category_target_sentiment)
        target_text_kwargs = {
            'targets': targets,
            'spans': spans,
            'text_id': text_id,
            'target_sentiments': target_sentiments,
            'categories': categories,
            'text': text,
            'category_sentiments': category_sentiments
        }
        for key in target_text_kwargs:
            if not target_text_kwargs[key]:
                target_text_kwargs[key] = None
        target_text = TargetText(**target_text_kwargs)
        target_text_collection.add(target_text)
    return target_text_collection
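A minimal sketch of driving this parser directly, assuming the module's own imports (`Span`, etc.) are in scope and using a hypothetical XML snippet in the SemEval aspect-term format handled by the `aspectTerms` branch above:

import xml.etree.ElementTree as ET

xml_snippet = '''<sentences>
  <sentence id="s1">
    <text>The laptop case was great</text>
    <aspectTerms>
      <aspectTerm term="laptop case" polarity="positive" from="4" to="15"/>
    </aspectTerms>
  </sentence>
</sentences>'''
collection = _semeval_extract_data(ET.fromstring(xml_snippet), conflict=False)
assert collection['s1']['targets'] == ['laptop case']
assert collection['s1']['spans'] == [Span(4, 15)]
assert collection['s1']['target_sentiments'] == ['positive']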
Example #26
def test_dataset_target_extraction_statistics(lower: bool,
                                              incl_sentence_statistics: bool):
    if lower is not None:
        target_stats = dataset_target_extraction_statistics(
            [TRAIN_COLLECTION],
            lower_target=lower,
            incl_sentence_statistics=incl_sentence_statistics)
    else:
        target_stats = dataset_target_extraction_statistics(
            [TRAIN_COLLECTION],
            incl_sentence_statistics=incl_sentence_statistics)
    tl_1 = round((17 / 19.0) * 100, 2)
    tl_2 = round((2 / 19.0) * 100, 2)
    true_stats = {
        'Name': 'train',
        'No. Sentences': 6,
        'No. Sentences(t)': 5,
        'No. Targets': 19,
        'No. Uniq Targets': 13,
        'ATS': round(19 / 6.0, 2),
        'ATS(t)': round(19 / 5.0, 2),
        'TL 1 %': tl_1,
        'TL 2 %': tl_2,
        'TL 3+ %': 0.0,
        'Mean Sentence Length': 15.33,
        'Mean Sentence Length(t)': 16.6
    }
    if not incl_sentence_statistics:
        del true_stats['Mean Sentence Length(t)']
        del true_stats['Mean Sentence Length']
    if lower is False:
        true_stats['No. Uniq Targets'] = 14
    assert 1 == len(target_stats)
    target_stats = target_stats[0]
    assert len(true_stats) == len(target_stats)
    for stat_name, stat in true_stats.items():
        if re.search(r'^TL', stat_name):
            assert math.isclose(stat, target_stats[stat_name], rel_tol=0.001)
        else:
            assert stat == target_stats[stat_name], stat_name

    # Multiple collections, where one collection is just a subset of the other
    subcollection = TargetTextCollection(name='sub')
    subcollection.add(TRAIN_COLLECTION["81207500773427072"])
    subcollection.add(TRAIN_COLLECTION["78522643479064576"])
    long_target = TargetText(
        text='some text that contains a long target or two',
        spans=[Span(0, 14), Span(15, 37)],
        targets=['some text that', 'contains a long target'],
        target_sentiments=['positive', 'negative'],
        text_id='100')
    subcollection.add(long_target)
    subcollection.tokenize(whitespace())
    if lower is not None:
        target_stats = dataset_target_extraction_statistics(
            [subcollection, TRAIN_COLLECTION],
            lower_target=lower,
            incl_sentence_statistics=incl_sentence_statistics)
    else:
        target_stats = dataset_target_extraction_statistics(
            [subcollection, TRAIN_COLLECTION],
            incl_sentence_statistics=incl_sentence_statistics)

    tl_1 = round((6 / 9.0) * 100, 2)
    tl_2 = round((1 / 9.0) * 100, 2)
    tl_3 = round((2 / 9.0) * 100, 2)
    sub_stats = {
        'Name': 'sub',
        'No. Sentences': 3,
        'No. Sentences(t)': 3,
        'No. Targets': 9,
        'No. Uniq Targets': 9,
        'ATS': round(9 / 3.0, 2),
        'ATS(t)': round(9 / 3.0, 2),
        'TL 1 %': tl_1,
        'TL 2 %': tl_2,
        'TL 3+ %': tl_3,
        'Mean Sentence Length': 11.67,
        'Mean Sentence Length(t)': 11.67
    }
    if not incl_sentence_statistics:
        del sub_stats['Mean Sentence Length(t)']
        del sub_stats['Mean Sentence Length']
    true_stats = [sub_stats, true_stats]
    assert len(true_stats) == len(target_stats)
    for stat_index, stat in enumerate(true_stats):
        test_stat = target_stats[stat_index]
        assert len(stat) == len(test_stat)
        for stat_name, stat_value in stat.items():
            if re.search(r'^TL', stat_name):
                assert math.isclose(stat_value,
                                    test_stat[stat_name],
                                    rel_tol=0.001)
            else:
                assert stat_value == test_stat[stat_name], stat_name
Example #27
    train_targets = set(list(train_data.target_count(lower=True).keys()))

    acceptable_confidence = args.confidence_score
    all_targets: List[TargetText] = []
    with args.predicted_target_data_fp.open('r') as predicted_file:
        for index, line in enumerate(predicted_file):
            target_data = json.loads(line)
            target_id = str(index)
            target_data_dict = {
                'text': target_data['text'],
                'text_id': target_id,
                'confidences': target_data['confidence'],
                'sequence_labels': target_data['sequence_labels'],
                'tokenized_text': target_data['tokens']
            }
            target_data = TargetText.target_text_from_prediction(
                **target_data_dict, confidence=acceptable_confidence)
            if target_data['targets']:
                all_targets.append(target_data)
    print(len(all_targets))
    all_targets: TargetTextCollection = TargetTextCollection(all_targets)
    pred_target_dict = all_targets.target_count(lower=True)
    pred_targets = set(pred_target_dict.keys())

    print(f'Number of unique targets in training dataset {len(train_targets)}')
    print(f'Number of unique predicted targets {len(pred_targets)}')
    pred_difference = pred_targets.difference(train_targets)
    print(
        f'Number of unique targets in predicted but not in training {len(pred_difference)}'
    )
    train_difference = train_targets.difference(pred_targets)
    print(
        f'Number of unique targets in training but not in predicted {len(train_difference)}'
    )
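
Note the script keeps a predicted sentence only when `target_text_from_prediction`
extracts at least one target at the given confidence score. A rough sketch of the
kind of confidence rule involved, under the assumption that a target survives only
if every one of its token confidences reaches the threshold (`keep_target` is an
illustrative helper, not the library's documented behaviour):

    from typing import List

    def keep_target(token_confidences: List[float], threshold: float) -> bool:
        # Assumed rule: keep a predicted target only if every token in its
        # span was predicted with at least `threshold` confidence.
        return all(conf >= threshold for conf in token_confidences)

    assert keep_target([0.9, 0.85], threshold=0.8)
    assert not keep_target([0.9, 0.4], threshold=0.8)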
Example #28
0
    def text_to_instance(self, text: str, 
                         targets: Optional[List[str]] = None,
                         target_sentiments: Optional[List[Union[str, int]]] = None,
                         spans: Optional[List[List[int]]] = None,
                         categories: Optional[List[str]] = None,
                         category_sentiments: Optional[List[Union[str, int]]] = None,
                         **kwargs) -> Instance:
        '''
        The original text and text tokens, as well as the targets and target 
        tokens, are stored in the MetadataField.

        :NOTE: At least targets and/or categories must be present.
        :NOTE: The left and right contexts returned in the instance are 
               a List of a List of tokens, one list per target.

        :param text: The text that contains the target(s) and/or categories.
        :param targets: The targets that are within the text
        :param target_sentiments: The sentiment of the targets. To be used if 
                                  training the classifier
        :param spans: The spans that represent the character offsets for each 
                      of the targets given in the targets list.
        :param categories: The categories that are within the text
        :param category_sentiments: The sentiment of the categories
        :returns: An Instance object with all of the above encoded for a
                  PyTorch model.
        :raises ValueError: If targets and categories are both None.
        :raises ValueError: If `self._target_sequences` is True and the passed 
                            `spans` argument is None.
        :raises ValueError: If `self._left_right_contexts` is True and the 
                            passed `spans` argument is None.
        '''
        if targets is None and categories is None:
            raise ValueError('Either targets or categories must be given if '
                             'you want to predict the sentiment of a target '
                             'or a category')

        instance_fields: Dict[str, Field] = {}

        # Metadata field
        metadata_dict = {}

        if targets is not None:
            # TODO: take into account the case where the position features are
            # requested but the target sequences are not.
            if self._target_sequences or self._position_embeddings or self._position_weights:
                if spans is None:
                    raise ValueError('To create target sequences requires `spans`')
                spans = [Span(span[0], span[1]) for span in spans]
                target_text_object = TargetText(text=text, spans=spans, 
                                                targets=targets, text_id='anything')
                target_text_object.force_targets()
                text = target_text_object['text']
                allen_tokens = self._tokenizer.tokenize(text)
                tokens = [x.text for x in allen_tokens]
                target_text_object['tokenized_text'] = tokens
                target_text_object.sequence_labels(per_target=True)
                target_sequences = target_text_object['sequence_labels']
                # Need to add the target sequences to the instances
                in_label = {'B', 'I'}
                number_targets = len(targets)
                all_target_tokens: List[List[Token]] = [[] for _ in range(number_targets)]
                target_sequence_fields = []
                target_indicators: List[List[int]] = []
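                # For each target, build a one-hot matrix (one row per token
                # belonging to the target) and a binary indicator vector over
                # all tokens; the latter drives the position features below.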
                for target_index in range(number_targets):
                    one_values = []
                    target_ones = [0] * len(allen_tokens)
                    for token_index, token in enumerate(allen_tokens):
                        target_sequence_value = target_sequences[target_index][token_index]
                        in_target = 1 if target_sequence_value in in_label else 0
                        if in_target:
                            all_target_tokens[target_index].append(allen_tokens[token_index])
                            one_value_list = [0] * len(allen_tokens)
                            one_value_list[token_index] = 1
                            one_values.append(one_value_list)
                            target_ones[token_index] = 1
                    one_values = np.array(one_values)
                    target_sequence_fields.append(ArrayField(one_values, dtype=np.int32))
                    target_indicators.append(target_ones)
                if self._position_embeddings:
                    target_distances = self._target_indicators_to_distances(target_indicators, 
                                                                            max_distance=self._max_position_distance, 
                                                                            as_string=True)
                    target_text_distances = []
                    for target_distance in target_distances:
                        token_distances = [Token(distance) for distance in target_distance]
                        token_distances = TextField(token_distances, self._position_indexers)
                        target_text_distances.append(token_distances)
                    instance_fields['position_embeddings'] = ListField(target_text_distances)
                if self._position_weights:
                    target_distances = self._target_indicators_to_distances(target_indicators, 
                                                                            max_distance=self._max_position_distance, 
                                                                            as_string=False)
                    target_distances = np.array(target_distances)
                    instance_fields['position_weights'] = ArrayField(target_distances, 
                                                                     dtype=np.int32)
                if self._target_sequences:
                    instance_fields['target_sequences'] = ListField(target_sequence_fields)
                instance_fields['tokens'] = TextField(allen_tokens, self._token_indexers)
                metadata_dict['text words'] = tokens
                metadata_dict['text'] = text
                # update target variable as the targets could have changed due 
                # to the force_targets function
                targets = target_text_object['targets']
            else:
                all_target_tokens = [self._tokenizer.tokenize(target) 
                                     for target in targets]
            target_fields = [TextField(target_tokens, self._token_indexers)  
                            for target_tokens in all_target_tokens]
            target_fields = ListField(target_fields)
            instance_fields['targets'] = target_fields
            # Add the targets and the tokenised targets to the metadata
            metadata_dict['targets'] = list(targets)
            metadata_dict['target words'] = [[x.text for x in target_tokens] 
                                             for target_tokens in all_target_tokens]

            # Target sentiment if it exists
            if target_sentiments is not None:
                target_sentiments_field = SequenceLabelField(target_sentiments, 
                                                             target_fields,
                                                             label_namespace='target-sentiment-labels')
                instance_fields['target_sentiments'] = target_sentiments_field

        if categories is not None and self._use_categories:
            category_fields = TextField([Token(category) for category in categories], 
                                        self._token_indexers)
            instance_fields['categories'] = category_fields
            # Category sentiment if it exists
            if category_sentiments is not None:
                category_sentiments_field = SequenceLabelField(category_sentiments, 
                                                               category_fields,
                                                               label_namespace='category-sentiment-labels')
                instance_fields['category_sentiments'] = category_sentiments_field
            # Add the categories to the metadata
            metadata_dict['categories'] = list(categories)

        if 'tokens' not in instance_fields:
            tokens = self._tokenizer.tokenize(text)
            instance_fields['tokens'] = TextField(tokens, self._token_indexers)
            metadata_dict['text'] = text
            metadata_dict['text words'] = [x.text for x in tokens]

        # If required processes the left and right contexts
        left_contexts = None
        right_contexts = None
        if self._left_right_contexts:
            if spans is None:
                raise ValueError('To create left, right, target contexts requires'
                                 ' the `spans` of the targets which is None')
            spans = [Span(span[0], span[1]) for span in spans]
            target_text_object = TargetText(text=text, spans=spans, 
                                            targets=targets, text_id='anything')
            # left, right, and target contexts for each target in the text
            left_right_targets = target_text_object.left_right_target_contexts(incl_target=self._incl_target)
            left_contexts: List[str] = []
            right_contexts: List[str] = []
            for left_right_target in left_right_targets:
                left, right, _ = left_right_target
                left_contexts.append(left)
                if self._reverse_right_context:
                    right_tokens = self._tokenizer.tokenize(right)
                    right = ' '.join(token.text for token in reversed(right_tokens))
                right_contexts.append(right)
        
        if left_contexts is not None:
            left_field = self._add_context_field(left_contexts)
            instance_fields["left_contexts"] = left_field
        if right_contexts is not None:
            right_field = self._add_context_field(right_contexts)
            instance_fields["right_contexts"] = right_field

        instance_fields["metadata"] = MetadataField(metadata_dict)
        
        return Instance(instance_fields)
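
The `_target_indicators_to_distances` helper used above is private to the reader;
judging by how its output feeds `position_embeddings` and `position_weights`, it
presumably maps each binary target-indicator vector to per-token distances from
the nearest target token, capped at `max_distance`. A sketch under that
assumption (`indicators_to_distances` is illustrative, not the actual method):

    from typing import List, Union

    def indicators_to_distances(indicator: List[int], max_distance: int,
                                as_string: bool = False) -> List[Union[int, str]]:
        # Assumed behaviour: distance 0 on target tokens, growing by 1 per
        # token away from the nearest target token, capped at `max_distance`.
        target_positions = [i for i, flag in enumerate(indicator) if flag]
        distances = [min(min(abs(i - p) for p in target_positions), max_distance)
                     for i in range(len(indicator))]
        return [str(d) for d in distances] if as_string else distances

    # 'laptop case' covering token positions 1-2 of a five-token sentence:
    assert indicators_to_distances([0, 1, 1, 0, 0], max_distance=3) == [1, 0, 0, 1, 2]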