Exemple #1
0
    def test_oov(self):
        """BERT INSERT/SUBSTITUTE augmenters must cope with an out-of-vocabulary token.

        Each text contains a long run of 'a' characters, presumably absent from
        BERT's vocabulary. For every (augmenter, text) pair the test checks:
        - INSERT adds at least one space-separated word;
        - SUBSTITUTE keeps the space-separated word count unchanged;
        - the output differs from the input;
        - ``nml.Bert.SUBWORD_PREFIX`` does not leak into the output.
        """
        unknown_token = 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
        texts = [
            unknown_token,
            unknown_token + ' the'
        ]

        augmenters = [
            naw.BertAug(action=Action.INSERT),
            naw.BertAug(action=Action.SUBSTITUTE)
        ]

        for aug in augmenters:
            for text in texts:
                # subTest reports which augmenter/text pair failed instead of
                # stopping at the first failure with no context.
                with self.subTest(action=aug.action, text=text):
                    self.assertLess(0, len(text))
                    augmented_text = aug.augment(text)
                    if aug.action == Action.INSERT:
                        self.assertLess(len(text.split(' ')), len(augmented_text.split(' ')))
                    elif aug.action == Action.SUBSTITUTE:
                        self.assertEqual(len(text.split(' ')), len(augmented_text.split(' ')))
                    else:
                        # A misconfigured fixture is a test failure, not an
                        # unexpected error.
                        self.fail('Augmenter is neither INSERT nor SUBSTITUTE')

                    self.assertNotEqual(text, augmented_text)
                    # Presumably the sub-word continuation marker (e.g. '##');
                    # it must be merged back, not emitted verbatim.
                    self.assertNotIn(nml.Bert.SUBWORD_PREFIX, augmented_text)
Exemple #2
0
    def test_substitute_stopwords(self):
        """Words registered as stopwords must not be replaced by BERT SUBSTITUTE.

        The first three words of the sentence (lower-cased) are passed to the
        augmenter as stopwords. Matching tokens are then compared position by
        position between the original and the augmented text.
        """
        texts = [
            'The quick brown fox jumps over the lazy dog'
        ]

        stopwords = [t.lower() for t in texts[0].split(' ')[:3]]
        aug_n = 3

        # Pass the same aug_n the assertions below compare against, rather than
        # repeating the literal, so the two cannot drift apart.
        aug = naw.BertAug(action=Action.SUBSTITUTE, aug_n=aug_n, stopwords=stopwords)

        for text in texts:
            self.assertLess(0, len(text))
            augmented_text = aug.augment(text)

            augmented_tokens = aug.tokenizer(augmented_text)
            tokens = aug.tokenizer(text)

            augmented_cnt = 0

            # NOTE(review): zip() stops at the shorter token list, so trailing
            # tokens are not compared if the lengths differ.
            for token, augmented_token in zip(tokens, augmented_tokens):
                # NOTE(review): comparing a token's *length* with aug_n (the
                # number of words to augment) looks unintended. Assuming the
                # tokenizer splits on whitespace, it excludes 'The' (length 3)
                # from the stopword check. Confirm the intended condition.
                if token.lower() in stopwords and len(token) > aug_n:
                    self.assertEqual(token.lower(), augmented_token)
                else:
                    # Counts tokens not checked against the stopwords, not
                    # tokens that were actually changed.
                    augmented_cnt += 1

            self.assertGreater(augmented_cnt, 0)

        self.assertLess(0, len(texts))
Exemple #3
0
    def test_empty_input_for_insert(self):
        """A whitespace-only string fed to BERT INSERT must come back as ''."""
        insert_augmenter = naw.BertAug(action=Action.INSERT)
        result = insert_augmenter.augment(' ')
        self.assertEqual(result, '')
Exemple #4
0
 def __init__(self):
     """Build the word-level augmentation pipeline stored on ``self.aug``.

     The ``nafc.Sequential`` pipeline currently contains only BERT-based
     word substitution.
     """
     # Alternatives that were tried and disabled (previously kept here as
     # commented-out code):
     #   - naw.BertAug with Action.INSERT
     #   - naw.GloVeAug using 'glove.6B.50d.txt' from $MODEL_DIR with
     #     Action.SUBSTITUTE: bad results
     #   - naw.WordNetAug()
     #   - naw.RandomWordAug(): deletes words at random
     self.aug = nafc.Sequential([
         naw.BertAug(action=Action.SUBSTITUTE),
     ])
Exemple #5
0
    def test_empty_input_for_insert(self):
        """BERT INSERT must return an empty string unchanged."""
        inputs = ['']
        inserter = naw.BertAug(action=Action.INSERT)

        for original in inputs:
            result = inserter.augment(original)
            self.assertEqual(original, result)

        # Sanity-check the fixture itself: exactly one input, and it is empty.
        self.assertEqual(1, len(inputs))
        self.assertEqual(0, len(inputs[0]))
Exemple #6
0
    def test_substitute(self):
        """BERT SUBSTITUTE must change an ordinary English sentence."""
        sentences = ['The quick brown fox jumps over the lazy dog']
        substituter = naw.BertAug(action=Action.SUBSTITUTE)

        for sentence in sentences:
            self.assertLess(0, len(sentence))
            self.assertNotEqual(sentence, substituter.augment(sentence))

        self.assertLess(0, len(sentences))
Exemple #7
0
    def test_empty_input_for_insert(self):
        """Whitespace-only input must come back as '' or ' ' from each augmenter.

        Covers BERT INSERT and TF-IDF SUBSTITUTE. Both results are accepted
        because the augmenters do not yet agree on the return value (see FIXME).
        """
        text = ' '

        augs = [
            naw.BertAug(action=Action.INSERT),
            # NOTE(review): if MODEL_DIR is unset, model_path is None and
            # TfIdfAug construction presumably fails; confirm whether this
            # should skip instead.
            naw.TfIdfAug(model_path=os.environ.get("MODEL_DIR"),
                         action=Action.SUBSTITUTE)
        ]

        for aug in augs:
            with self.subTest(augmenter=type(aug).__name__):
                augmented_text = aug.augment(text)
                # FIXME: standardize return
                # assertIn reports the actual value on failure, unlike
                # assertTrue on a precomputed boolean.
                self.assertIn(augmented_text, ('', ' '))
Exemple #8
0
    def test_insert(self):
        """BERT INSERT must add at least one word and change the sentence."""
        sentences = ['The quick brown fox jumps over the lazy dog']
        inserter = naw.BertAug(action=Action.INSERT)

        for sentence in sentences:
            self.assertLess(0, len(sentence))
            result = inserter.augment(sentence)

            words_before = len(sentence.split(' '))
            words_after = len(result.split(' '))
            self.assertLess(words_before, words_after)
            self.assertNotEqual(sentence, result)

        self.assertLess(0, len(sentences))