def test_batch_size(self):
    model_name = self.eng_model_names[0]

    # 1 per batch
    aug = naw.BackTranslationAug(
        from_model_name=model_name['from_model_name'],
        to_model_name=model_name['to_model_name'],
        batch_size=1)
    aug_data = aug.augment(self.texts)
    self.assertEqual(len(aug_data), len(self.texts))

    # batch size = input size
    aug = naw.BackTranslationAug(
        from_model_name=model_name['from_model_name'],
        to_model_name=model_name['to_model_name'],
        batch_size=len(self.texts))
    aug_data = aug.augment(self.texts)
    self.assertEqual(len(aug_data), len(self.texts))

    # batch size > input size
    aug = naw.BackTranslationAug(
        from_model_name=model_name['from_model_name'],
        to_model_name=model_name['to_model_name'],
        batch_size=len(self.texts) + 1)
    aug_data = aug.augment(self.texts)
    self.assertEqual(len(aug_data), len(self.texts))

    # input size > batch size
    aug = naw.BackTranslationAug(
        from_model_name=model_name['from_model_name'],
        to_model_name=model_name['to_model_name'],
        batch_size=2)
    aug_data = aug.augment(self.texts * 2)
    self.assertEqual(len(aug_data), len(self.texts) * 2)
def test_back_translation(self):
    # From English
    text = 'The quick brown fox jumps over the lazy dog'
    for model_name in self.eng_model_names:
        aug = naw.BackTranslationAug(
            from_model_name=model_name['from_model_name'],
            from_model_checkpt=model_name['from_model_checkpt'],
            to_model_name=model_name['to_model_name'],
            to_model_checkpt=model_name['to_model_checkpt'])
        augmented_text = aug.augment(text)
        aug.clear_cache()

        self.assertNotEqual(text, augmented_text)

    self.assertTrue(len(self.eng_model_names) > 1)
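# A minimal sketch of the fixture the tests above assume. The fairseq WMT19
# hub identifiers and 'model1.pt' checkpoints are one known-working choice,
# not necessarily the ones used in this suite; note that the assertion above
# expects more than one model pair.
@classmethod
def setUpClass(cls):
    cls.text = 'The quick brown fox jumps over the lazy dog'
    cls.texts = [
        cls.text,
        'Jumping over the lazy dog, the quick brown fox',
    ]
    cls.eng_model_names = [
        {
            'from_model_name': 'transformer.wmt19.en-de',
            'from_model_checkpt': 'model1.pt',
            'to_model_name': 'transformer.wmt19.de-en',
            'to_model_checkpt': 'model1.pt',
        },
        {
            'from_model_name': 'transformer.wmt19.en-ru',
            'from_model_checkpt': 'model1.pt',
            'to_model_name': 'transformer.wmt19.ru-en',
            'to_model_checkpt': 'model1.pt',
        },
    ]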
def test_load_from_local_path_inexist(self):
    from_model_dir = '/abc/'
    to_model_dir = '/def/'

    with self.assertRaises(ValueError) as error:
        aug = naw.BackTranslationAug(
            from_model_name=from_model_dir, from_model_checkpt='model1.pt',
            to_model_name=to_model_dir, to_model_checkpt='model1.pt',
            is_load_from_github=False)
    self.assertTrue('Cannot load model from local path' in str(error.exception))

    base_model_dir = os.environ.get("MODEL_DIR")
    from_model_dir = os.path.join(base_model_dir, 'word', 'fairseq', 'wmt19.en-de')
    to_model_dir = '/def/'

    with self.assertRaises(ValueError) as error:
        aug = naw.BackTranslationAug(
            from_model_name=from_model_dir, from_model_checkpt='model1.pt',
            to_model_name=to_model_dir, to_model_checkpt='model1.pt',
            is_load_from_github=False)
    self.assertTrue('Cannot load model from local path' in str(error.exception))
def test_load_from_local_path(self):
    base_model_dir = os.environ.get("MODEL_DIR")
    from_model_dir = os.path.join(base_model_dir, 'word', 'fairseq', 'wmt19.en-de')
    to_model_dir = os.path.join(base_model_dir, 'word', 'fairseq', 'wmt19.de-en', '')

    aug = naw.BackTranslationAug(
        from_model_name=from_model_dir, from_model_checkpt='model1.pt',
        to_model_name=to_model_dir, to_model_checkpt='model1.pt',
        is_load_from_github=False)
    augmented_text = aug.augment(self.text)
    aug.clear_cache()

    self.assertNotEqual(self.text, augmented_text)
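# The two local-path tests above assume MODEL_DIR points at a pre-downloaded
# fairseq model tree laid out roughly as follows (inferred from the paths and
# checkpoint names used in the tests):
#
#   $MODEL_DIR/word/fairseq/wmt19.en-de/model1.pt
#   $MODEL_DIR/word/fairseq/wmt19.de-en/model1.pt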
def sample_test_case(self, device):
    # From English
    for model_name in self.eng_model_names:
        aug = naw.BackTranslationAug(
            from_model_name=model_name['from_model_name'],
            to_model_name=model_name['to_model_name'],
            device=device)
        augmented_text = aug.augment(self.text)
        aug.clear_cache()
        self.assertNotEqual(self.text, augmented_text)

        augmented_texts = aug.augment(self.texts)
        aug.clear_cache()
        for d, a in zip(self.texts, augmented_texts):
            self.assertNotEqual(d, a)

        if device == 'cpu':
            self.assertTrue(device == aug.model.get_device())
        elif 'cuda' in device:
            self.assertTrue('cuda' in aug.model.get_device())
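# Hypothetical drivers for the parametrized helper above. nlpaug accepts
# torch-style device strings, so 'cpu' and 'cuda' exercise both branches of
# the device check in sample_test_case (assumes `import torch` at module
# level).
def test_cpu(self):
    self.sample_test_case('cpu')

def test_cuda(self):
    if torch.cuda.is_available():
        self.sample_test_case('cuda')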
def test_back_translation(self):
    # From English
    texts = [
        self.text,
        "Seeing all of the negative reviews for this movie, I figured that it "
        "could be yet another comic masterpiece that wasn't quite meant to be."
    ]
    for model_name in self.eng_model_names:
        aug = naw.BackTranslationAug(
            from_model_name=model_name['from_model_name'],
            from_model_checkpt=model_name['from_model_checkpt'],
            to_model_name=model_name['to_model_name'],
            to_model_checkpt=model_name['to_model_checkpt'])
        augmented_text = aug.augment(self.text)
        aug.clear_cache()
        self.assertNotEqual(self.text, augmented_text)

        augmented_texts = aug.augment(texts)
        aug.clear_cache()
        for d, a in zip(texts, augmented_texts):
            self.assertNotEqual(d, a)

    self.assertTrue(len(self.eng_model_names) > 1)
def augmented_all(
    use_textcomp19=False,
    use_weebit=False,
    use_dw=False,
    backtrans=False,
    lemmatization=False,
    stemming=False,
    randword_swap=False,
    randword_del=False,
    test_size=0.1,
):
    """
    Returns the augmented training dataset and the test dataset of all
    specified data.

    backtrans     : enables back-and-forth translation of the data
    lemmatization : enables lemmatization of the data
    stemming      : enables stemming of the data
    randword_swap : enables randomly swapping words within sentences
    randword_del  : enables randomly deleting words from sentences
    test_size     : ratio of the test set to the whole dataset

    Usage: train_set, test_set = augmented_all()
    """
    # Perform a train-test split keeping dataset proportions the same
    print("perform train-test split keeping dataset proportions the same")
    all_dataset = all_data(use_textcomp19, use_weebit, use_dw)

    if use_textcomp19:
        text_comp_train, text_comp_test = train_test_split(
            all_dataset[all_dataset["source"] == 0], test_size=test_size)

    if use_weebit:
        weebit_train = all_dataset[all_dataset["source"] == 1]

    if use_dw:
        dw_train = all_dataset[all_dataset["source"] == 2]

    if use_textcomp19 and not use_weebit and not use_dw:  # 0
        all_dataset_train = text_comp_train
        all_dataset_test = text_comp_test

    if use_weebit and not use_textcomp19 and not use_dw:  # 1
        print("No weebit test set available!")
        all_dataset_train = weebit_train

    if use_dw and not use_weebit and not use_textcomp19:  # 2
        print("No dw test set available!")
        all_dataset_train = dw_train
        # added so that a dataset with only dw can be created
        all_dataset_test = dw_train

    if use_textcomp19 and use_weebit and not use_dw:  # 01
        all_dataset_train = text_comp_train.append(weebit_train, ignore_index=True)
        all_dataset_test = text_comp_test

    if use_weebit and use_dw and not use_textcomp19:  # 12
        print("No weebit and dw test set available!")
        all_dataset_train = weebit_train.append(dw_train, ignore_index=True)

    if use_textcomp19 and use_dw and not use_weebit:  # 02
        all_dataset_train = text_comp_train.append(dw_train, ignore_index=True)
        all_dataset_test = text_comp_test

    if use_textcomp19 and use_weebit and use_dw:  # 012
        all_dataset_train = text_comp_train.append(weebit_train, ignore_index=True)
        all_dataset_train = all_dataset_train.append(dw_train, ignore_index=True)
        all_dataset_test = text_comp_test

    ## Augmentation of data
    print("Start augmenting Data...")

    # Back-and-forth translation of the data
    if backtrans:
        print("Back and forth translation...")
        back_translation_aug = naw.BackTranslationAug(
            from_model_name="transformer.wmt19.de-en",
            to_model_name="transformer.wmt19.en-de",
        )
        # Exclude the weebit rows (source == 1) from back-translation;
        # copy so the original rows are kept unchanged.
        if use_weebit:
            translated = all_dataset_train[all_dataset_train["source"] != 1].copy()
        else:
            translated = all_dataset_train.copy()
        translated["raw_text"] = translated["raw_text"].apply(
            lambda x: back_translation_aug.augment(x))
        all_dataset_train = all_dataset_train.append(translated, ignore_index=True)

    # Random word swap
    if randword_swap:
        print("Random word swap")
        aug1 = naw.RandomWordAug(action="swap")
        # Copy first so the original rows are kept unchanged
        swapped_data = all_dataset_train.copy()
        swapped_data["raw_text"] = all_dataset_train["raw_text"].apply(
            lambda x: aug1.augment(x))
        all_dataset_train = all_dataset_train.append(swapped_data, ignore_index=True)

    # Random word deletion
    if randword_del:
        print("Random word deletion")
        aug2 = naw.RandomWordAug()
        rand_deleted_data = all_dataset_train.copy()
        rand_deleted_data["raw_text"] = all_dataset_train["raw_text"].apply(
            lambda x: aug2.augment(x))
        all_dataset_train = all_dataset_train.append(rand_deleted_data, ignore_index=True)
    # Lemmatization using spacy
    if lemmatization:
        print("lemmatizing")
        nlp = spacy.load("de_core_news_sm")
        all_dataset_train["raw_text"] = all_dataset_train["raw_text"].apply(
            lambda x: " ".join([y.lemma_ for y in nlp(x)]))
        all_dataset_test["raw_text"] = all_dataset_test["raw_text"].apply(
            lambda x: " ".join([y.lemma_ for y in nlp(x)]))

    # Stemming using NLTK's SnowballStemmer
    if stemming:
        print("stemming")
        stemmer = SnowballStemmer("german")
        all_dataset_train["raw_text"] = all_dataset_train["raw_text"].apply(
            lambda x: stemmer.stem(x))
        all_dataset_test["raw_text"] = all_dataset_test["raw_text"].apply(
            lambda x: stemmer.stem(x))

    return all_dataset_train, all_dataset_test
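# A minimal usage sketch of augmented_all(), following the docstring above;
# the flag combination is illustrative and assumes the textcomp19 and dw
# corpora are available to all_data().
train_set, test_set = augmented_all(
    use_textcomp19=True,
    use_dw=True,
    backtrans=True,
    randword_swap=True,
    test_size=0.1,
)
print(len(train_set), "training rows,", len(test_set), "test rows")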
def create_back_aug(language):
    return naw.BackTranslationAug(
        from_model_name=f'transformer.wmt19.en-{language}',
        to_model_name=f'transformer.wmt19.{language}-en',
        device='cuda')
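# Hypothetical usage of the factory above; 'de' and 'ru' are the language
# codes for which fairseq ships WMT19 transformer pairs, and a CUDA-capable
# GPU is assumed since the device is hard-coded to 'cuda'.
back_aug_de = create_back_aug('de')
augmented = back_aug_de.augment('The quick brown fox jumps over the lazy dog')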