def execute_by_device(self, device):
    """Run the empty-input and substitution checks for every model path on *device*.

    :param device: device string forwarded unchanged to ``nas.AbstSummAug``.
    """
    for path in self.model_paths:
        augmenter = nas.AbstSummAug(model_path=path, device=device)
        # Each augmenter goes through both shared checks before the next is built.
        self.empty_input(augmenter)
        self.substitute(augmenter)

    # With no configured model paths the loop above is a no-op; fail instead
    # of letting the test pass vacuously.
    self.assertLess(0, len(self.model_paths))
def test_batch_size(self):
    """Check that AbstSummAug yields exactly one output per input across batch sizes.

    Cases: one text per batch, batch size equal to the input size, batch size
    larger than the input size, and input larger than the batch size
    (``self.texts`` repeated twice with a batch size of 2).
    """
    # (description, batch_size, inputs): each case builds a fresh 't5-small'
    # augmenter with the given batch size and augments ``inputs``.
    cases = [
        ('1 per batch', 1, self.texts),
        ('batch size = input size', len(self.texts), self.texts),
        ('batch size > input size', len(self.texts) + 1, self.texts),
        ('input size > batch size', 2, self.texts * 2),
    ]
    for description, batch_size, inputs in cases:
        # subTest labels the failing case and lets the remaining cases still run.
        with self.subTest(description, batch_size=batch_size):
            aug = nas.AbstSummAug(model_path='t5-small', batch_size=batch_size)
            aug_data = aug.augment(inputs)
            self.assertEqual(len(aug_data), len(inputs))
def execute_by_device(self, device):
    """Exercise AbstSummAug on *device* for every configured model path.

    For each model path: build an augmenter on ``device``, run the empty-input
    check, then run the substitution check on a single text and on a list of
    texts. After each substitution, verify that the underlying model reports
    the requested device.

    :param device: device string forwarded to ``nas.AbstSummAug``. Placement
        is only verified for ``'cpu'`` and for strings containing ``'cuda'``.
    """
    for model_path in self.model_paths:
        aug = nas.AbstSummAug(model_path=model_path, device=device)
        self.empty_input(aug)
        for data in [self.text, self.texts]:
            self.substitute(aug, data)
            if device == 'cpu':
                # assertEqual reports both values on failure, unlike assertTrue(a == b).
                self.assertEqual(device, aug.model.get_device())
            elif 'cuda' in device:
                # Substring match rather than equality, presumably because the
                # reported device may carry an index (e.g. 'cuda:0'); confirm.
                self.assertIn('cuda', aug.model.get_device())

    # With no configured model paths the loop above is a no-op; fail instead
    # of letting the test pass vacuously.
    self.assertLess(0, len(self.model_paths))
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.sentence as nas
import nlpaug.flow as nafc
from nlpaug.util import Action
import os

# Point MODEL_DIR at the local model directory. Presumably read by model-loading
# code later in the pipeline; confirm.
os.environ["MODEL_DIR"] = '../model'

# NOTE(review): `progressbar`, `stopwords` and `word_tokenize` are used below but
# are not imported in the visible source. They look like the `progressbar`
# package and NLTK's `nltk.corpus.stopwords` / `nltk.tokenize.word_tokenize`.
# Confirm the imports exist elsewhere; otherwise this script raises NameError.

# Initialize the progress bar.
# NOTE(review): this binds the `progressbar` module itself, not a progress-bar
# instance. Confirm that is intended.
bar = progressbar

# Stop-word set: NLTK's English stop words plus extra tokens to drop
# (punctuation, '</s>', and a few author-chosen tokens).
stop_words = set(stopwords.words('english'))
stop_words.update(['!', ',', '.', '?', '-s', '-ly', '</s>', 's', 'nan', 'mac'])

# Alternative augmenter previously tried: naw.RandomWordAug(action='crop').
# Abstractive-summarization augmenter using 't5-base' on device 'cuda:2'.
# It is constructed when this module is executed.
aug = nas.AbstSummAug(model_path='t5-base', num_beam=3, device='cuda:2')


def filter_stopwords(text):
    """Tokenize *text* and return it with every token in ``stop_words`` removed.

    :param text: input string, tokenized with ``word_tokenize``.
    :returns: the remaining tokens joined by single spaces.
    """
    word_tokens = word_tokenize(text)
    return " ".join(token for token in word_tokens if token not in stop_words)


# Paths to the DeepMatcher Textual/Company train, validation and test splits.
path_train = '/ssd/zhouhcData/deepmatcherData/Textual/Company/train.csv'
path_valid = '/ssd/zhouhcData/deepmatcherData/Textual/Company/valid.csv'
path_test = '/ssd/zhouhcData/deepmatcherData/Textual/Company/test.csv'