def test_batch_size(self):
        """Back-translation returns exactly one output per input for any batch size.

        Exercises, with the first English model pair:
        a batch size of 1, a batch size equal to the input size, a batch size
        larger than the input size, and an input (``self.texts`` doubled)
        larger than the batch size.
        """
        pair = self.eng_model_names[0]
        n_texts = len(self.texts)
        scenarios = [
            (1, self.texts),              # 1 per batch
            (n_texts, self.texts),        # batch size = input size
            (n_texts + 1, self.texts),    # batch size > input size
            (2, self.texts * 2),          # input size > batch size
        ]
        for size, batch_input in scenarios:
            augmenter = naw.BackTranslationAug(
                from_model_name=pair['from_model_name'],
                to_model_name=pair['to_model_name'],
                batch_size=size)
            result = augmenter.augment(batch_input)
            self.assertEqual(len(result), len(batch_input))
# Example #2
0
    def test_back_translation(self):
        """Back-translating an English sentence must change it, for every model pair.

        Each entry of ``self.eng_model_names`` supplies the from/to model
        names and checkpoints passed to ``naw.BackTranslationAug``.

        NOTE(review): another method named ``test_back_translation`` appears
        later in this listing; if both are defined in the same class, the
        later definition replaces this one and this test never runs.
        """
        # Check the precondition before the loop: with too few model pairs
        # configured the loop is (nearly) vacuous and the model loads wasted.
        self.assertGreater(len(self.eng_model_names), 1)

        # From English
        text = 'The quick brown fox jumps over the lazy dog'
        for model_name in self.eng_model_names:
            aug = naw.BackTranslationAug(
                from_model_name=model_name['from_model_name'],
                from_model_checkpt=model_name['from_model_checkpt'],
                to_model_name=model_name['to_model_name'],
                to_model_checkpt=model_name['to_model_checkpt'])
            augmented_text = aug.augment(text)
            # Clear the augmenter's cache (presumably loaded models) before
            # the next pair is constructed.
            aug.clear_cache()
            self.assertNotEqual(text, augmented_text)
    def test_load_from_local_path_inexist(self):
        """Loading models from a missing local path must raise ValueError.

        Two cases are checked, both with ``is_load_from_github=False``:
        neither model directory exists, and only the "to" directory is
        missing while the "from" directory is taken from ``$MODEL_DIR``
        (presumably an existing model directory there).
        """
        def assert_cannot_load(from_model_dir, to_model_dir):
            # Constructing the augmenter is expected to fail; the object
            # itself is never needed.
            with self.assertRaises(ValueError) as error:
                naw.BackTranslationAug(from_model_name=from_model_dir,
                                       from_model_checkpt='model1.pt',
                                       to_model_name=to_model_dir,
                                       to_model_checkpt='model1.pt',
                                       is_load_from_github=False)
            self.assertIn('Cannot load model from local path',
                          str(error.exception))

        # Case 1: neither directory exists.
        assert_cannot_load('/abc/', '/def/')

        # Case 2: only the target directory is missing.
        base_model_dir = os.environ.get("MODEL_DIR")
        # Without this check os.path.join(None, ...) raises an unhelpful
        # TypeError when the variable is unset.
        self.assertIsNotNone(base_model_dir,
                             'MODEL_DIR environment variable must be set')
        from_model_dir = os.path.join(base_model_dir, 'word', 'fairseq',
                                      'wmt19.en-de')
        assert_cannot_load(from_model_dir, '/def/')
    def test_load_from_local_path(self):
        """Back-translation models load from local directories under $MODEL_DIR.

        The "to" path is joined with a trailing '' component, so it ends in
        a path separator while the "from" path does not; both forms are
        therefore exercised. The augmented ``self.text`` must differ from
        the input.
        """
        base_model_dir = os.environ.get("MODEL_DIR")
        # Without this check os.path.join(None, ...) raises an unhelpful
        # TypeError when the variable is unset.
        self.assertIsNotNone(base_model_dir,
                             'MODEL_DIR environment variable must be set')
        from_model_dir = os.path.join(base_model_dir, 'word', 'fairseq',
                                      'wmt19.en-de')
        to_model_dir = os.path.join(base_model_dir, 'word', 'fairseq',
                                    'wmt19.de-en', '')

        aug = naw.BackTranslationAug(from_model_name=from_model_dir,
                                     from_model_checkpt='model1.pt',
                                     to_model_name=to_model_dir,
                                     to_model_checkpt='model1.pt',
                                     is_load_from_github=False)

        augmented_text = aug.augment(self.text)
        aug.clear_cache()
        self.assertNotEqual(self.text, augmented_text)
    def sample_test_case(self, device):
        """Run back-translation on ``device`` for every English model pair.

        For each pair: a single text (``self.text``) and every text of
        ``self.texts`` must be changed, and the underlying model must report
        the requested device.

        Args:
            device: device string passed to ``naw.BackTranslationAug``;
                'cpu' is compared exactly, any string containing 'cuda' is
                checked by substring.
        """
        # From English
        for model_name in self.eng_model_names:
            aug = naw.BackTranslationAug(
                from_model_name=model_name['from_model_name'],
                to_model_name=model_name['to_model_name'],
                device=device)
            augmented_text = aug.augment(self.text)
            aug.clear_cache()
            self.assertNotEqual(self.text, augmented_text)

            augmented_texts = aug.augment(self.texts)
            aug.clear_cache()
            # zip() stops at the shorter sequence, so without this check a
            # short result would silently skip comparisons.
            self.assertEqual(len(augmented_texts), len(self.texts))
            for d, a in zip(self.texts, augmented_texts):
                self.assertNotEqual(d, a)

            if device == 'cpu':
                self.assertEqual(device, aug.model.get_device())
            elif 'cuda' in device:
                self.assertIn('cuda', aug.model.get_device())
    def test_back_translation(self):
        """Back-translating English text must change it, for every model pair.

        Checks a single text (``self.text``) and a two-sentence batch with
        each entry of ``self.eng_model_names`` (from/to names and
        checkpoints).

        NOTE(review): an earlier method in this listing has the same name;
        if both are defined in the same class, this later definition is the
        one that runs.
        """
        # Check the precondition before loading any models.
        self.assertGreater(len(self.eng_model_names), 1)

        # From English
        texts = [
            self.text,
            "Seeing all of the negative reviews for this movie, I figured that it could be yet another comic masterpiece that wasn't quite meant to be."
        ]
        for model_name in self.eng_model_names:
            aug = naw.BackTranslationAug(
                from_model_name=model_name['from_model_name'],
                from_model_checkpt=model_name['from_model_checkpt'],
                to_model_name=model_name['to_model_name'],
                to_model_checkpt=model_name['to_model_checkpt'])
            augmented_text = aug.augment(self.text)
            aug.clear_cache()
            self.assertNotEqual(self.text, augmented_text)

            augmented_texts = aug.augment(texts)
            aug.clear_cache()
            # zip() stops at the shorter sequence; make sure no text is skipped.
            self.assertEqual(len(augmented_texts), len(texts))
            for d, a in zip(texts, augmented_texts):
                self.assertNotEqual(d, a)
# Example #7
0
def _append_augmented(frame, rows, transform):
    """Return ``frame`` followed by a copy of ``rows`` with "raw_text" transformed.

    ``rows`` is copied before modification, so neither ``rows`` nor
    ``frame`` (which may be the very same object) is changed in place.
    The result has a fresh 0..n-1 index.
    """
    import pandas as pd

    augmented = rows.copy()
    augmented["raw_text"] = augmented["raw_text"].apply(transform)
    return pd.concat([frame, augmented], ignore_index=True)


def augmented_all(
    use_textcomp19=False,
    use_weebit=False,
    use_dw=False,
    backtrans=False,
    lemmatization=False,
    stemming=False,
    randword_swap=False,
    randword_del=False,
    test_size=0.1,
):
    """Return the augmented training set and the test set of the selected data.

    Args:
        use_textcomp19: include the textcomp19 rows (``source == 0``). Only
            this dataset is split; it supplies the test set.
        use_weebit: include the weebit rows (``source == 1``), training only.
        use_dw: include the dw rows (``source == 2``), training only. When
            dw is the only selected dataset, a separate copy of it is also
            returned as the test set.
        backtrans: append a back-translated (de -> en -> de) copy of the
            training rows; weebit rows are not translated.
        lemmatization: lemmatize train and test text with spaCy
            ``de_core_news_sm``.
        stemming: stem every whitespace-separated token of train and test
            text with the German Snowball stemmer.
        randword_swap: append a copy of the training rows with words
            randomly swapped.
        randword_del: append a copy of the training rows with words
            randomly deleted.
        test_size: fraction of the textcomp19 rows held out for testing.

    Returns:
        ``(train, test)`` DataFrames. When no test data is available
        (weebit only, or weebit + dw) the test DataFrame is empty but has
        the same columns as the source data.

    Raises:
        ValueError: if none of the three datasets is selected.

    Example:
        train_set, test_set = augmented_all(use_textcomp19=True)
    """
    import pandas as pd

    if not (use_textcomp19 or use_weebit or use_dw):
        raise ValueError(
            "augmented_all: select at least one of use_textcomp19, "
            "use_weebit or use_dw")

    # Perform a Train-Test Split keeping dataset proportions the same
    print("perform train-test split keeping dataset proportions the same")
    all_dataset = all_data(use_textcomp19, use_weebit, use_dw)

    # Training parts are concatenated in a fixed order: textcomp19, weebit,
    # dw. Boolean-mask selections are copied so that later column
    # assignments never write through to ``all_dataset``.
    train_parts = []
    all_dataset_test = None
    if use_textcomp19:
        text_comp_train, text_comp_test = train_test_split(
            all_dataset[all_dataset["source"] == 0], test_size=test_size)
        train_parts.append(text_comp_train)
        all_dataset_test = text_comp_test
    if use_weebit:
        train_parts.append(all_dataset[all_dataset["source"] == 1].copy())
    if use_dw:
        train_parts.append(all_dataset[all_dataset["source"] == 2].copy())

    if all_dataset_test is None:
        # Only textcomp19 provides held-out data.
        if use_weebit and use_dw:
            print("No weebit and dw test set available!")
            all_dataset_test = all_dataset.iloc[0:0].copy()
        elif use_weebit:
            print("No weebit test set available!")
            all_dataset_test = all_dataset.iloc[0:0].copy()
        else:
            print("No dw test set available!")
            # Reuse dw as the test set so a dw-only dataset can be created;
            # a separate copy keeps train-side changes (augmentation,
            # lemmatization, stemming) from being applied to it twice.
            all_dataset_test = train_parts[-1].copy()

    if len(train_parts) == 1:
        all_dataset_train = train_parts[0].copy()
    else:
        all_dataset_train = pd.concat(train_parts, ignore_index=True)

    ## Augmentation of data
    print("Start augmenting Data...")

    # Back and forth translation of data
    if backtrans:
        print("Back and forth translation...")
        back_translation_aug = naw.BackTranslationAug(
            from_model_name="transformer.wmt19.de-en",
            to_model_name="transformer.wmt19.en-de",
        )
        # weebit rows (source == 1) are kept but not back-translated.
        if use_weebit:
            to_translate = all_dataset_train[all_dataset_train["source"] != 1]
        else:
            to_translate = all_dataset_train
        all_dataset_train = _append_augmented(
            all_dataset_train, to_translate,
            lambda x: back_translation_aug.augment(x))

    # Random word swap
    if randword_swap:
        print("Random word swap")
        aug1 = naw.RandomWordAug(action="swap")
        all_dataset_train = _append_augmented(
            all_dataset_train, all_dataset_train, lambda x: aug1.augment(x))

    # Random word deletion
    if randword_del:
        print("Random word deletion")
        aug2 = naw.RandomWordAug()
        all_dataset_train = _append_augmented(
            all_dataset_train, all_dataset_train, lambda x: aug2.augment(x))

    # Lemmatization using spacy
    if lemmatization:
        print("lemmatizing")
        nlp = spacy.load("de_core_news_sm")

        def lemmatize(text):
            return " ".join([y.lemma_ for y in nlp(text)])

        all_dataset_train["raw_text"] = all_dataset_train["raw_text"].apply(
            lemmatize)
        all_dataset_test["raw_text"] = all_dataset_test["raw_text"].apply(
            lemmatize)

    # Stemming using the German Snowball stemmer
    if stemming:
        print("stemming")
        stemmer = SnowballStemmer("german")

        def stem(text):
            # SnowballStemmer.stem() works on a single word; applied to a
            # whole sentence it would only strip the final word's suffix.
            return " ".join(stemmer.stem(word) for word in text.split())

        all_dataset_train["raw_text"] = all_dataset_train["raw_text"].apply(
            stem)
        all_dataset_test["raw_text"] = all_dataset_test["raw_text"].apply(
            stem)

    return all_dataset_train, all_dataset_test
# Example #8
0
def create_back_aug(language, device='cuda'):
    """Build an English <-> ``language`` back-translation augmenter.

    Translates with the model ``transformer.wmt19.en-{language}`` and back
    with ``transformer.wmt19.{language}-en``.

    Args:
        language: language code as used in the WMT'19 model names,
            e.g. 'de'.
        device: device the augmenter runs on. Defaults to 'cuda' (the
            previously hard-coded value); pass 'cpu' on machines without
            a GPU.

    Returns:
        A configured ``naw.BackTranslationAug`` instance.
    """
    return naw.BackTranslationAug(
        from_model_name=f'transformer.wmt19.en-{language}',
        to_model_name=f'transformer.wmt19.{language}-en',
        device=device)