def test_case(self):
    # Swap
    aug = naw.RandomWordAug(action='swap')
    self.assertEqual('bB aA', aug.augment('aA bB'))

    data = 'I love McDonalds'
    doc = Doc(data, aug.tokenizer(data))
    augmented_tokens = aug.change_case(doc, 1, 0, 1).get_augmented_tokens()
    self.assertEqual(['Love', 'I', 'McDonalds'], augmented_tokens)
    doc = Doc(data, aug.tokenizer(data))
    augmented_tokens = aug.change_case(doc, 0, 1, 1).get_augmented_tokens()
    self.assertEqual(['Love', 'I', 'McDonalds'], augmented_tokens)

    data = 'He loves McDonalds'
    doc = Doc(data, aug.tokenizer(data))
    augmented_tokens = aug.change_case(doc, 1, 0, 1).get_augmented_tokens()
    self.assertEqual(['Loves', 'he', 'McDonalds'], augmented_tokens)
    doc = Doc(data, aug.tokenizer(data))
    augmented_tokens = aug.change_case(doc, 0, 1, 1).get_augmented_tokens()
    self.assertEqual(['Loves', 'he', 'McDonalds'], augmented_tokens)
    doc = Doc(data, aug.tokenizer(data))
    augmented_tokens = aug.change_case(doc, 2, 1, 1).get_augmented_tokens()
    self.assertEqual(['He', 'McDonalds', 'loves'], augmented_tokens)

    # Insert
    aug = naw.TfIdfAug(model_path=self.tfidf_model_path, action='insert')
    expected = False
    for i in range(10):
        augmented_text = aug.augment('Good')
        if 'good' in augmented_text and aug.get_word_case(augmented_text.split(' ')[0]) == 'capitalize':
            expected = True
            break
    self.assertTrue(expected)

    # Substitute
    aug = naw.RandomWordAug(action='substitute', target_words=['abc'])
    expected = False
    for i in range(10):
        augmented_text = aug.augment('I love')
        if augmented_text == 'Abc love':
            expected = True
            break
    self.assertTrue(expected)

    aug = naw.AntonymAug()
    self.assertEqual('Unhappy', aug.augment('Happy'))

    # Do not change if target word is non-lower
    aug = naw.SpellingAug()
    self.assertEqual('RE', aug.augment('Re'))

    # Delete case
    aug = naw.RandomWordAug(action='delete')
    expected = False
    for i in range(10):
        augmented_text = aug.augment('I love')
        if augmented_text == 'Love':
            expected = True
            break
    self.assertTrue(expected)
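# A hedged usage sketch of the augmenters exercised by the test above. It assumes `naw`
# is `nlpaug.augmenter.word` and that, as in the test, augment() returns a string (newer
# nlpaug releases may return a list instead). Outputs are random and will vary per run.
import nlpaug.augmenter.word as naw

swap_aug = naw.RandomWordAug(action='swap')
print(swap_aug.augment('I love McDonalds'))    # e.g. 'Love I McDonalds', capitalization re-aligned

delete_aug = naw.RandomWordAug(action='delete')
print(delete_aug.augment('I love McDonalds'))  # e.g. 'Love McDonalds'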
def substitute(self, data):
    if not data or not data.strip():
        return data

    change_seq = 0
    data = self.preprocess(data)
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_aug_idxes(doc.get_original_tokens())
    aug_idxes.sort(reverse=True)

    tokens = doc.get_original_tokens()

    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data

    for aug_idx in aug_idxes:
        original_token = doc.get_token(aug_idx).orig_token.token
        if not self.case_sensitive:
            original_token = original_token.lower()

        if original_token in self.reserved_token_dict:
            candidate_tokens = []
            for t in self.reserved_tokens[self.reserved_token_dict[original_token]]:
                compare_token = t.lower() if not self.case_sensitive else t
                if compare_token != original_token:
                    candidate_tokens.append(t)
        elif original_token in self.reserved_phrase_concats:
            candidate_tokens = []
            for t in self.reserved_tokens[self.reserved_phrase_dict[original_token]]:
                compare_token = t.replace(' ', self.CONNECT_TOKEN)
                compare_token = compare_token.lower() if not self.case_sensitive else compare_token
                if compare_token != original_token:
                    candidate_tokens.append(t)

        new_token = self.sample(candidate_tokens, 1)[0]
        if aug_idx == 0:
            new_token = self.align_capitalization(original_token, new_token)

        change_seq += 1
        doc.add_change_log(aug_idx, new_token=new_token, action=Action.SUBSTITUTE,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
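# align_capitalization() is defined elsewhere in the library; this is a minimal sketch of
# the behavior implied by its call sites above (re-case a replacement placed at sentence
# start), offered as an assumption rather than the actual implementation.
def align_capitalization_sketch(original_token, new_token):
    # Only re-case when the original starts uppercase and the replacement is all lowercase;
    # acronyms and mixed-case replacements are left untouched.
    if original_token and original_token[0].isupper() and new_token.islower():
        return new_token.capitalize()
    return new_token

assert align_capitalization_sketch('Happy', 'unhappy') == 'Unhappy'
assert align_capitalization_sketch('Re', 'RE') == 'RE'  # non-lower replacement kept as-is, matching the test above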
def insert(self, data):
    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_word_idxes = self._get_aug_idxes(
        doc.get_original_tokens(), self.aug_word_min, self.aug_word_max, self.aug_word_p, Method.WORD)
    if aug_word_idxes is None:
        return data

    for token_i, token in enumerate(doc.get_original_tokens()):
        if token_i not in aug_word_idxes:
            continue

        chars = self.token2char(token)
        aug_char_idxes = self._get_aug_idxes(
            chars, self.aug_char_min, self.aug_char_max, self.aug_char_p, Method.CHAR)
        if aug_char_idxes is None:
            continue

        aug_char_idxes.sort(reverse=True)
        for char_i in aug_char_idxes:
            chars.insert(char_i, self.sample(self.model, 1)[0])

        # No capitalization alignment as this augmenter tries to simulate random error
        new_token = ''.join(chars)
        change_seq += 1
        doc.add_change_log(token_i, new_token=new_token, action=Action.INSERT,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def substitute(self, data):
    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_aug_idxes(doc.get_original_tokens())
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data

    for aug_idx in aug_idxes:
        original_token = doc.get_token(aug_idx).get_latest_token().token
        candidate_tokens = self.model.predict(original_token, n=1)
        substitute_token = self.sample(candidate_tokens, 1)[0]
        if aug_idx == 0:
            substitute_token = self.align_capitalization(original_token, substitute_token)

        change_seq += 1
        doc.add_change_log(aug_idx, new_token=substitute_token, action=Action.SUBSTITUTE,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def substitute(self, data):
    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_random_aug_idxes(doc.get_original_tokens())
    aug_idxes.sort(reverse=True)
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data

    for aug_idx in aug_idxes:
        original_token = doc.get_token(aug_idx).orig_token.token
        new_token = self.sample(self.target_words, 1)[0]
        if aug_idx == 0:
            new_token = self.align_capitalization(original_token, new_token)

        change_seq += 1
        doc.add_change_log(aug_idx, new_token=new_token, action=Action.SUBSTITUTE,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def substitute(self, data):
    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_word_idxes = self._get_aug_idxes(
        doc.get_original_tokens(), self.aug_word_min, self.aug_word_max, self.aug_word_p, Method.WORD)

    for token_i, token in enumerate(doc.get_original_tokens()):
        if token_i not in aug_word_idxes:
            continue

        substitute_token = ''
        chars = self.token2char(token)
        aug_char_idxes = self._get_aug_idxes(
            chars, self.aug_char_min, self.aug_char_max, self.aug_char_p, Method.CHAR)
        if aug_char_idxes is None:
            continue

        for char_i, char in enumerate(chars):
            if char_i not in aug_char_idxes:
                substitute_token += char
                continue

            substitute_token += self.sample(self.model.predict(chars[char_i]), 1)[0]

        # No capitalization alignment as this augmenter tries to simulate OCR engine error
        change_seq += 1
        doc.add_change_log(token_i, new_token=substitute_token, action=Action.SUBSTITUTE,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def substitute(self, data):
    if not data or not data.strip():
        return data

    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    pos = self.model.pos_tag(doc.get_original_tokens())

    aug_idxes = self._get_aug_idxes(pos)
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data

    for aug_idx, original_token in enumerate(doc.get_original_tokens()):
        # Skip if the word is not picked for augmentation
        if aug_idx not in aug_idxes:
            continue

        word_poses = PartOfSpeech.constituent2pos(pos[aug_idx][1])
        candidates = []
        if word_poses is None or len(word_poses) == 0:
            # Use every possible word as the POS mapping is not defined correctly
            candidates.extend(self.model.predict(pos[aug_idx][0]))
        else:
            for word_pos in word_poses:
                candidates.extend(self.model.predict(pos[aug_idx][0], pos=word_pos))

        candidates = [c for c in candidates if c.lower() != original_token.lower()]

        if len(candidates) > 0:
            candidate = self.sample(candidates, 1)[0]
            candidate = candidate.replace("_", " ").replace("-", " ").lower()
            substitute_token = self.align_capitalization(original_token, candidate)
            if aug_idx == 0:
                substitute_token = self.align_capitalization(original_token, substitute_token)

            change_seq += 1
            doc.add_change_log(aug_idx, new_token=substitute_token, action=Action.SUBSTITUTE,
                               change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def swap(self, data):
    if not data or not data.strip():
        return data

    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_random_aug_idxes(doc.get_original_tokens())

    # https://github.com/makcedward/nlpaug/issues/76
    if aug_idxes is None or len(aug_idxes) == 0 or doc.size() < 2:
        if self.include_detail:
            return data, []
        return data

    for aug_idx in aug_idxes:
        swap_idx = self._get_swap_position(aug_idx, doc.size() - 1)
        change_seq += 1
        doc = self.change_case(doc, aug_idx, swap_idx, change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
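# _get_swap_position() is defined elsewhere; a hedged sketch of the neighbor-swap behavior
# swap() appears to rely on (swap with an adjacent token, clamped to [0, last_idx]). This
# is illustrative only and may differ from the library's exact logic.
import random

def get_swap_position_sketch(pos, last_idx):
    if pos == 0:
        return 1                          # first token can only swap with its right neighbor
    if pos == last_idx:
        return last_idx - 1               # last token can only swap with its left neighbor
    return pos + random.choice([-1, 1])   # otherwise pick a random neighbor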
def insert(self, data):
    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_random_aug_idxes(doc.get_original_tokens())
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data
    aug_idxes.sort(reverse=True)

    for aug_idx in aug_idxes:
        new_token = self.sample(self.model.get_vocab(), 1)[0]
        if self.n_gram_separator in new_token:
            new_token = new_token.split(self.n_gram_separator)[0]

        change_seq += 1
        doc.add_token(aug_idx, token=new_token, action=Action.INSERT,
                      change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def swap(self, data):
    if not data or not data.strip():
        return data

    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_word_idxes = self._get_aug_idxes(
        doc.get_original_tokens(), self.aug_word_min, self.aug_word_max, self.aug_word_p, Method.WORD)
    if aug_word_idxes is None:
        return data

    for token_i, token in enumerate(doc.get_original_tokens()):
        if token_i not in aug_word_idxes:
            continue

        chars = self.token2char(token)
        aug_char_idxes = self._get_aug_idxes(
            chars, self.aug_char_min, self.aug_char_max, self.aug_char_p, Method.CHAR)
        if aug_char_idxes is None or len(aug_char_idxes) < 1:
            continue

        for char_i in aug_char_idxes:
            swap_position = self._get_swap_position(char_i, len(chars) - 1, mode=self.swap_mode)
            is_original_upper, is_swap_upper = chars[char_i].isupper(), chars[swap_position].isupper()
            original_chars = chars.copy()
            chars[char_i], chars[swap_position] = original_chars[swap_position], original_chars[char_i]

            # Swap case
            if is_original_upper:
                chars[char_i] = chars[char_i].upper()
            else:
                chars[char_i] = chars[char_i].lower()
            if is_swap_upper:
                chars[swap_position] = chars[swap_position].upper()
            else:
                chars[swap_position] = chars[swap_position].lower()

        # No capitalization alignment as this augmenter tries to simulate random error
        swap_token = ''.join(chars)
        change_seq += 1
        doc.add_change_log(token_i, new_token=swap_token, action=Action.SWAP,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def substitute(self, data):
    if not data or not data.strip():
        return data

    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    pos = self.model.pos_tag(doc.get_original_tokens())

    aug_candidates = self._get_aug_idxes(pos)
    if aug_candidates is None or len(aug_candidates) == 0:
        if self.include_detail:
            return data, []
        return data

    aug_idxes, candidates = zip(*aug_candidates)
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data

    for aug_idx, original_token in enumerate(doc.get_original_tokens()):
        # Skip if the word is not picked for augmentation
        if aug_idx not in aug_idxes:
            continue

        candidates = self.get_candidates(pos, aug_idx)

        if len(candidates) > 0:
            candidate = self.sample(candidates, 1)[0]
            candidate = candidate.replace("_", " ").replace("-", " ").lower()
            substitute_token = self.align_capitalization(original_token, candidate)
            if aug_idx == 0:
                substitute_token = self.align_capitalization(original_token, substitute_token)

            change_seq += 1
            doc.add_change_log(aug_idx, new_token=substitute_token, action=Action.SUBSTITUTE,
                               change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def insert(self, data):
    if data is None or data.strip() == '':
        if self.include_detail:
            return data, []
        return data

    max_try = 30  # On average, 30 iterations should be enough to complete a sentence
    external_memory = None
    augmented_text = ''
    new_token = ''
    doc = Doc()
    change_seq = 0
    aug_idx = 0

    for _ in range(max_try):
        if external_memory is None:  # First step, or optimization is not enabled
            text = data + augmented_text
        else:
            text = new_token

        # Mask token is needed for xlnet. No mask token for gpt2
        if self.model_type in ['xlnet']:
            text += ' ' + self.model.MASK_TOKEN

        outputs = self.model.predict(text, n=1, external_memory=external_memory)
        results = outputs[0]
        if results is None:
            continue

        if self.model.optimize['external_memory']:
            external_memory = outputs[1]

        new_token, proba = results[0]

        change_seq += 1
        doc.add_token(aug_idx, token='', action=Action.INSERT, change_seq=0)
        doc.update_change_log(aug_idx, token=self.model.clean(new_token), action=Action.INSERT,
                              change_seq=self.parent_change_seq + change_seq)
        aug_idx += 1

        if new_token in self.SENTENCE_SEPARATOR:
            augmented_text += new_token
            break

        augmented_text += ' ' + new_token

    augmented_text = data + augmented_text

    if self.include_detail:
        return augmented_text, doc.get_change_logs(start_pos=len(data) + 1)
    else:
        return augmented_text
def delete(self, data):
    if not data or not data.strip():
        return data

    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_word_idxes = self._get_aug_idxes(
        doc.get_original_tokens(), self.aug_word_min, self.aug_word_max, self.aug_word_p, Method.WORD)
    if aug_word_idxes is None:
        return data

    for token_i, token in enumerate(doc.get_original_tokens()):
        if token_i not in aug_word_idxes:
            continue

        chars = self.token2char(token)
        aug_char_idxes = self._get_aug_idxes(
            chars, self.aug_char_min, self.aug_char_max, self.aug_char_p, Method.CHAR)
        if aug_char_idxes is None or len(aug_char_idxes) < 1:
            continue

        aug_char_idxes.sort(reverse=True)
        for i in aug_char_idxes:
            del chars[i]

        # No capitalization alignment as this augmenter tries to simulate random error
        delete_token = ''.join(chars)
        change_seq += 1
        doc.add_change_log(token_i, new_token=delete_token, action=Action.DELETE,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
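# Hedged usage sketch for the character-level actions above (insert/substitute/swap/delete),
# assuming `nac` is `nlpaug.augmenter.char`. Results are random; depending on the nlpaug
# version, augment() may return a string or a list.
import nlpaug.augmenter.char as nac

for action in ['insert', 'substitute', 'swap', 'delete']:
    aug = nac.RandomCharAug(action=action)
    print(action, aug.augment('The quick brown fox'))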
def substitute(self, data):
    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_aug_idxes(doc.get_original_tokens())
    aug_idxes.sort(reverse=True)

    tokens = doc.get_original_tokens()

    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data

    for aug_idx in aug_idxes:
        original_token = doc.get_token(aug_idx).orig_token.token
        candidate_tokens = self.reserved_tokens[self.reserved_token_dict[original_token]]
        if not self.allow_original:
            candidate_tokens = [t for t in candidate_tokens if t != original_token]

        new_token = self.sample(candidate_tokens, 1)[0]
        if aug_idx == 0:
            new_token = self.align_capitalization(original_token, new_token)

        change_seq += 1
        doc.add_change_log(aug_idx, new_token=new_token, action=Action.SUBSTITUTE,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def substitute(self, data):
    if not data or not data.strip():
        return data

    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_aug_idxes(doc.get_original_tokens())
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data

    for aug_idx, original_token in enumerate(doc.get_original_tokens()):
        # Skip if the word is not picked for augmentation
        if aug_idx not in aug_idxes:
            continue

        candidate_words = self.model.predict(original_token)
        substitute_token = ''
        if candidate_words:
            substitute_token = self.sample(candidate_words, 1)[0]
        else:
            # Unexpected scenario. Keep the original token
            substitute_token = original_token

        if aug_idx == 0:
            substitute_token = self.align_capitalization(original_token, substitute_token)

        change_seq += 1
        doc.add_change_log(aug_idx, new_token=substitute_token, action=Action.SUBSTITUTE,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
def insert(self, data):
    if not data:
        return data

    if isinstance(data, list):
        all_data = data
    else:
        if data.strip() == '':
            return data
        all_data = [data]

    max_try = 30  # On average, 30 iterations should be enough to complete a sentence
    external_memories = [None] * len(all_data)
    augmented_texts = [''] * len(all_data)
    docs = [Doc()] * len(all_data)
    early_stops = [0] * len(all_data)
    change_seq = 0
    aug_idx = 0

    for _ in range(max_try):
        if sum(early_stops) == len(all_data):
            break

        aug_input_poses = []  # store which input is augmented. No augmentation once a sentence is generated
        texts = []
        for i, d in enumerate(all_data):
            if early_stops[i] == 1:
                continue

            aug_input_poses.append(i)

            augmented_text = augmented_texts[i]
            external_memory = external_memories[i]
            if external_memory is None:  # First step, or optimization is not enabled
                text = d + augmented_text
            else:
                text = ''

            # Mask token is needed for xlnet. No mask token for gpt2
            if self.model_type in ['xlnet']:
                text += ' ' + self.model.MASK_TOKEN

            texts.append(text)

        outputs = self.model.predict(texts, n=1, external_memory=external_memory,
                                     include_punctuation=True)

        for i, output in enumerate(outputs):
            aug_input_pos = aug_input_poses[i]

            # TODO:
            # if self.model.optimize['external_memory']:
            #     external_memory = outputs[1]

            # TODO: Alternative method better than dropout
            candidate = ''
            if len(output) == 1:
                candidate = output[0]
            elif len(output) > 1:
                candidate = self.sample(output, 1)[0]

            change_seq += 1
            docs[aug_input_pos].add_token(aug_idx, token='', action=Action.INSERT, change_seq=0)
            docs[aug_input_pos].update_change_log(
                aug_idx, token=self.model.clean(candidate), action=Action.INSERT,
                change_seq=self.parent_change_seq + change_seq)
            aug_idx += 1

            # Early stop once the input has generated a full sentence.
            if candidate in text_tokenizer.SENTENCE_SEPARATOR:
                if self.model_type in ['gpt2']:
                    augmented_texts[aug_input_pos] += ' '
                augmented_texts[aug_input_pos] += candidate
                early_stops[aug_input_pos] = 1
            else:
                if self.model_type in ['gpt2']:
                    augmented_texts[aug_input_pos] += ' '
                augmented_texts[aug_input_pos] += candidate

    if self.model_type in ['gpt2']:
        results = [d + a for d, a in zip(all_data, augmented_texts)]
    elif self.model_type in ['xlnet']:
        results = [d + ' ' + self.model.tokenizer.convert_tokens_to_string(a)
                   for d, a in zip(all_data, augmented_texts)]

    if isinstance(data, list):
        return results
    else:
        return results[0]
def crop(self, data):
    change_seq = 1
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_aug_range_idxes(doc.get_original_tokens())
    aug_idxes.sort(reverse=True)

    # https://github.com/makcedward/nlpaug/issues/76
    if aug_idxes is None or len(aug_idxes) == 0 or doc.size() < 2:
        if self.include_detail:
            return data, []
        return data

    for aug_idx in aug_idxes:
        original_token = doc.get_token(aug_idx).orig_token.token
        doc.add_change_log(aug_idx, new_token='', action=Action.CROP,
                           change_seq=self.parent_change_seq + change_seq)
        if aug_idx == 0:
            new_token = self.align_capitalization(original_token, doc.get_token(1).orig_token.token)
            doc.add_change_log(1, new_token=new_token, action=Action.ALIGN,
                               change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
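# Hedged usage sketch for crop(): a contiguous token range is removed and, if the first
# token is cropped, the new leading token inherits its capitalization. Assumes this nlpaug
# version exposes action='crop' on RandomWordAug.
import nlpaug.augmenter.word as naw

crop_aug = naw.RandomWordAug(action='crop')
print(crop_aug.augment('The quick brown fox jumps over the lazy dog'))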
def insert(self, data):
    if not data:
        return data

    if isinstance(data, list):
        all_data = data
    else:
        if data.strip() == '':
            return data
        all_data = [data]

    # If length of input is longer than the max allowed input, only augment the heading part
    split_results = []  # head_text, tail_text, head_tokens, tail_tokens
    reserved_stopwords = []
    for d in all_data:
        split_result, reserved_stopword = self.split_text(d)
        split_results.append(split_result)
        reserved_stopwords.append(reserved_stopword)

    change_seq = 0
    # Pick target word for augmentation
    for i, (split_result, reserved_stopword_tokens) in enumerate(zip(split_results, reserved_stopwords)):
        head_text, tail_text, head_tokens, tail_tokens = split_result

        if self.model_type in ['xlnet', 'roberta', 'bart']:
            # xlnet and roberta tokens include a prefix (e.g. ▁ or Ġ)
            cleaned_head_tokens = [t.replace(self.model.get_subword_prefix(), '') for t in head_tokens]
        else:
            cleaned_head_tokens = head_tokens

        head_doc = Doc(head_text, head_tokens)
        aug_idxes = self._get_aug_idxes(head_tokens)
        aug_idxes.sort(reverse=True)

        if reserved_stopword_tokens:
            head_doc, change_seq = self.substitute_back_reserved_stopwords(
                head_doc, reserved_stopword_tokens, change_seq)

        split_results[i] += (cleaned_head_tokens, head_doc, aug_idxes, )

    # Pad aug_idxes
    max_aug_size = max([len(split_result[6]) for split_result in split_results])
    for split_result in split_results:
        aug_idxes = split_result[6]
        for _ in range(max_aug_size - len(aug_idxes)):
            aug_idxes.append(-1)

    token_placeholder = self.model.get_mask_token()
    if self.model_type in ['xlnet', 'roberta', 'bart']:
        token_placeholder = self.model.get_subword_prefix() + token_placeholder  # Adding prefix for space

    # Augment the same augmentation index across the batch
    for i in range(max_aug_size):
        masked_texts = []
        aug_input_poses = []  # store which input is augmented. No record if padding

        change_seq += 1
        for j, split_result in enumerate(split_results):
            head_doc, aug_idx = split_result[5], split_result[6][i]

            # -1 if it is padding
            if aug_idx == -1:
                continue

            head_doc.add_token(aug_idx, token=token_placeholder, action=Action.INSERT,
                               change_seq=self.parent_change_seq + change_seq)

            aug_input_poses.append(j)

            # some tokenizers handle special characters (e.g. "don't" can merge after decode)
            if self.model_type in ['bert', 'electra']:
                ids = self.model.get_tokenizer().convert_tokens_to_ids(head_doc.get_augmented_tokens())
                masked_text = self.model.get_tokenizer().decode(ids).strip()
            elif self.model_type in ['xlnet', 'roberta', 'bart']:
                masked_text = self.model.get_tokenizer().convert_tokens_to_string(
                    head_doc.get_augmented_tokens()).strip()

            masked_texts.append(masked_text)

        if not len(masked_texts):
            continue

        outputs = self.model.predict(masked_texts, target_words=None, n=2)

        # Update doc
        for aug_input_pos, output, masked_text in zip(aug_input_poses, outputs, masked_texts):
            split_result = split_results[aug_input_pos]
            head_doc = split_result[5]
            aug_idx = split_result[6][i]  # augment position in text

            # TODO: Alternative method better than dropout
            candidate = ''
            if len(output) == 0:
                # TODO: no result?
                pass
            elif len(output) == 1:
                candidate = output[0]
            elif len(output) > 1:
                candidate = self.sample(output, 1)[0]

            # # In XLNet, it can be the first word of a sentence which does not come with space. E.g. Zombine (ID:29110)
            # if self.model_type in ['xlnet']:
            #     if candidate != '' and not candidate.startswith(self.model.get_subword_prefix()):
            #         candidate = self.model.get_subword_prefix() + candidate
            # if self.model_type in ['roberta', 'bart']:
            #     if candidate != '' and not candidate.startswith(self.model.get_subword_prefix()) and candidate.strip() != candidate:
            #         candidate = self.model.get_subword_prefix() + candidate.strip()

            # no candidate
            if candidate == '':
                head_doc.add_change_log(aug_idx, new_token='', action=Action.DELETE,
                                        change_seq=self.parent_change_seq + change_seq)
                continue

            head_doc.update_change_log(aug_idx, token=candidate)

            # Early stop if the number of tokens exceeds the max number
            if head_doc.size() > self.max_num_token:
                for j in range(i + 1, max_aug_size):
                    split_results[aug_input_pos][6][j] = -1

    augmented_texts = []
    for split_result, reserved_stopword_tokens in zip(split_results, reserved_stopwords):
        tail_text, head_doc = split_result[1], split_result[5]

        head_tokens = head_doc.get_augmented_tokens()
        # if self.model_type in ['xlnet', 'roberta']:
        #     # xlnet and roberta tokens include a prefix (e.g. ▁ or Ġ)
        #     head_tokens = [self.model.get_subword_prefix() + t if self.model.get_subword_prefix() not in t and i != 0 else t for i, t in enumerate(head_tokens)]

        ids = self.model.get_tokenizer().convert_tokens_to_ids(head_tokens)
        augmented_text = self.model.get_tokenizer().decode(ids)

        if tail_text:
            augmented_text += ' ' + tail_text
        augmented_texts.append(augmented_text)

    if isinstance(data, list):
        return augmented_texts
    else:
        return augmented_texts[0]
def substitute(self, data):
    if not data or not data.strip():
        return data

    change_seq = 0
    preprocessed_data = self.preprocess(data)
    doc = Doc(preprocessed_data, self.tokenizer(preprocessed_data))
    data_lower = data.lower()

    aug_idxes = self._get_aug_idxes(doc.get_original_tokens())
    aug_idxes.sort(reverse=True)

    tokens = doc.get_original_tokens()

    if not aug_idxes:
        if self.include_detail:
            return data, []
        return data

    if self.generate_all_combinations:
        assert self.aug_p == 1, "Augmentation probability has to be 1 to generate all combinations. Set aug_p=1 in constructor."

        candidate_token_list = []
        for aug_idx in aug_idxes:
            original_token = doc.get_token(aug_idx).orig_token.token
            if not self.case_sensitive:
                original_token = original_token.lower()

            if original_token in self.reserved_token_dict:
                candidate_tokens = []
                for t in self.reserved_tokens[self.reserved_token_dict[original_token]]:
                    candidate_tokens.append(t)
            elif original_token in self.reserved_phrase_concats:
                candidate_tokens = []
                for t in self.reserved_tokens[self.reserved_phrase_dict[original_token]]:
                    candidate_tokens.append(t)

            change_seq += 1
            candidate_token_list.append((aug_idx, change_seq, candidate_tokens))

        generated_combinations = []
        for tokens in self.generate_combinations(candidate_token_list):
            inp_doc = doc
            for token in tokens:
                aug_idx, seq, new_token = token
                if aug_idx == 0:
                    new_token = self.align_capitalization(original_token, new_token)
                inp_doc.add_change_log(aug_idx, new_token=new_token, action=Action.SUBSTITUTE,
                                       change_seq=self.parent_change_seq + seq)

            augmented_text = self.reverse_tokenizer(doc.get_augmented_tokens())

            same_as_original = False
            if self.case_sensitive:
                same_as_original = augmented_text == data
            else:
                same_as_original = augmented_text.lower() == data_lower

            if not same_as_original:
                if self.include_detail:
                    generated_combinations.append((augmented_text, doc.get_change_logs()))
                else:
                    generated_combinations.append(augmented_text)

        return generated_combinations
        # return sorted(generated_combinations)  # not sorting to speed up
    else:
        for aug_idx in aug_idxes:
            original_token = doc.get_token(aug_idx).orig_token.token
            if not self.case_sensitive:
                original_token = original_token.lower()

            if original_token in self.reserved_token_dict:
                candidate_tokens = []
                for t in self.reserved_tokens[self.reserved_token_dict[original_token]]:
                    compare_token = t.lower() if not self.case_sensitive else t
                    if compare_token != original_token:
                        candidate_tokens.append(t)
            elif original_token in self.reserved_phrase_concats:
                candidate_tokens = []
                for t in self.reserved_tokens[self.reserved_phrase_dict[original_token]]:
                    compare_token = t.replace(' ', self.CONNECT_TOKEN)
                    compare_token = compare_token.lower() if not self.case_sensitive else compare_token
                    if compare_token != original_token:
                        candidate_tokens.append(t)

            new_token = self.sample(candidate_tokens, 1)[0]
            if aug_idx == 0:
                new_token = self.align_capitalization(original_token, new_token)

            change_seq += 1
            doc.add_change_log(aug_idx, new_token=new_token, action=Action.SUBSTITUTE,
                               change_seq=self.parent_change_seq + change_seq)

        if self.include_detail:
            return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
        else:
            return self.reverse_tokenizer(doc.get_augmented_tokens())
def substitute(self, data):
    change_seq = 0

    # If length of input is longer than the max allowed input, only augment the heading part
    head_text, tail_text, head_tokens, tail_tokens = self.split_text(data)

    # Pick target word for augmentation
    if self.model_type in ['xlnet', 'roberta']:
        # xlnet and roberta tokens include a prefix (e.g. ▁ or Ġ)
        cleaned_head_tokens = [t.replace(self.model.SUBWORD_PREFIX, '') for t in head_tokens]
    else:
        cleaned_head_tokens = head_tokens

    head_doc = Doc(head_text, head_tokens)
    aug_idxes = self._get_aug_idxes(cleaned_head_tokens)
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data
    aug_idxes.sort(reverse=True)

    for i, aug_idx in enumerate(aug_idxes):
        original_word = head_doc.get_token(aug_idx).get_latest_token().token

        token_placeholder = self.model.MASK_TOKEN
        if self.model_type in ['xlnet', 'roberta']:
            token_placeholder = self.model.SUBWORD_PREFIX + token_placeholder  # Adding prefix for space

        change_seq += 1
        head_doc.add_change_log(aug_idx, new_token=token_placeholder, action=Action.SUBSTITUTE,
                                change_seq=self.parent_change_seq + change_seq)

        # remove continuous sub-words
        to_remove_idxes = []
        for j in range(aug_idx + 1, head_doc.size()):
            subword_token = head_doc.get_token(j).orig_token.token
            if self.model_type in ['bert', 'distilbert'] and self.model.SUBWORD_PREFIX in subword_token:
                to_remove_idxes.append(j)
            elif self.model_type in ['xlnet', 'roberta'] and self.model.SUBWORD_PREFIX not in subword_token:
                to_remove_idxes.append(j)
            else:
                break
        for j in reversed(to_remove_idxes):
            head_doc.add_change_log(j, new_token='', action=Action.SUBSTITUTE,
                                    change_seq=self.parent_change_seq + change_seq)

        masked_text = self.model.tokenizer.convert_tokens_to_string(
            head_doc.get_augmented_tokens()).strip()

        substitute_word, prob = None, None
        # https://github.com/makcedward/nlpaug/pull/51
        retry_cnt = 3
        for _ in range(retry_cnt):
            outputs = self.model.predict(masked_text, target_word=original_word, n=1 + _)
            candidates = outputs[0]
            if candidates is None:
                continue

            # Filter out unused candidates (transformers may return [unused123])
            candidates = [c for c in candidates if '[unused' not in c[0] and ']' not in c[0]]

            if len(candidates) > 0:
                substitute_word, prob = self.sample(candidates, 1)[0]
                break

        # TODO: Alternative method better than dropout
        if substitute_word is None:
            substitute_word = ''

        if self.model_type in ['xlnet', 'roberta']:
            substitute_word = self.model.SUBWORD_PREFIX + substitute_word  # Adding prefix for space

        head_doc.update_change_log(aug_idx, token=substitute_word, action=Action.SUBSTITUTE,
                                   change_seq=self.parent_change_seq + change_seq)

        # Early stop if the number of tokens exceeds the max number
        if head_doc.size() > self.max_num_token:
            break

    head_tokens = head_doc.get_augmented_tokens()
    if self.model_type in ['xlnet', 'roberta']:
        # xlnet and roberta tokens include a prefix (e.g. ▁ or Ġ)
        head_tokens = [self.model.SUBWORD_PREFIX + t if self.model.SUBWORD_PREFIX not in t and i != 0 else t
                       for i, t in enumerate(head_tokens)]

    augmented_text = self.model.tokenizer.convert_tokens_to_string(head_tokens)
    if tail_text is not None:
        augmented_text += ' ' + tail_text

    if self.include_detail:
        return augmented_text, head_doc.get_change_logs()
    else:
        return augmented_text
def insert(self, data):
    change_seq = 0

    # If length of input is longer than the max allowed input, only augment the heading part
    head_text, tail_text, head_tokens, tail_tokens = self.split_text(data)

    if self.model_type in ['xlnet', 'roberta']:
        # xlnet and roberta tokens include a prefix (e.g. ▁ or Ġ)
        cleaned_head_tokens = [t.replace(self.model.SUBWORD_PREFIX, '') for t in head_tokens]
    else:
        cleaned_head_tokens = head_tokens

    head_doc = Doc(head_text, head_tokens)
    aug_idxes = self._get_aug_idxes(cleaned_head_tokens)
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data
    aug_idxes.sort(reverse=True)

    for i, aug_idx in enumerate(aug_idxes):
        token_placeholder = self.model.MASK_TOKEN
        if self.model_type in ['xlnet', 'roberta']:
            token_placeholder = self.model.SUBWORD_PREFIX + token_placeholder  # Adding prefix for space

        change_seq += 1
        head_doc.add_token(aug_idx, token=token_placeholder, action=Action.INSERT,
                           change_seq=self.parent_change_seq + change_seq)

        masked_text = self.model.tokenizer.convert_tokens_to_string(
            head_doc.get_augmented_tokens()).strip()

        # https://github.com/makcedward/nlpaug/issues/68
        retry_cnt = 3
        new_word, prob = None, None
        for _ in range(retry_cnt):
            outputs = self.model.predict(masked_text, target_word=None, n=1)
            candidates = outputs[0]
            if candidates is None:
                continue

            if len(candidates) > 0:
                new_word, prob = self.sample(candidates, 1)[0]
                break

        # TODO: Alternative method better than dropout
        if new_word is None:
            new_word = ''

        if self.model_type in ['xlnet', 'roberta']:
            new_word = self.model.SUBWORD_PREFIX + new_word  # Adding prefix for space

        head_doc.update_change_log(aug_idx, token=new_word)

        # Early stop if the number of tokens exceeds the max number
        if head_doc.size() > self.max_num_token:
            break

    head_tokens = head_doc.get_augmented_tokens()
    if self.model_type in ['xlnet', 'roberta']:
        # xlnet and roberta tokens include a prefix (e.g. ▁ or Ġ)
        head_tokens = [self.model.SUBWORD_PREFIX + t if self.model.SUBWORD_PREFIX not in t and i != 0 else t
                       for i, t in enumerate(head_tokens)]

    augmented_text = self.model.tokenizer.convert_tokens_to_string(head_tokens)
    if tail_text is not None:
        augmented_text += ' ' + tail_text

    if self.include_detail:
        return augmented_text, head_doc.get_change_logs()
    else:
        return augmented_text
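# Hedged usage sketch for the masked-language-model insert/substitute methods above.
# Assumes ContextualWordEmbsAug with a BERT checkpoint; the model is downloaded on first
# use and the sampled predictions make the output non-deterministic.
import nlpaug.augmenter.word as naw

mlm_sub_aug = naw.ContextualWordEmbsAug(model_path='bert-base-uncased', action='substitute')
print(mlm_sub_aug.augment('The quick brown fox jumps over the lazy dog'))

mlm_ins_aug = naw.ContextualWordEmbsAug(model_path='bert-base-uncased', action='insert')
print(mlm_ins_aug.augment('The quick brown fox jumps over the lazy dog'))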
def substitute(self, data):
    if not data:
        return data

    if isinstance(data, list):
        all_data = data
    else:
        if data.strip() == '':
            return data
        all_data = [data]

    # If length of input is longer than the max allowed input, only augment the heading part
    split_results = [self.split_text(d) for d in all_data]  # head_text, tail_text, head_tokens, tail_tokens

    # Pick target word for augmentation
    for i, split_result in enumerate(split_results):
        head_text, tail_text, head_tokens, tail_tokens = split_result

        if self.model_type in ['xlnet', 'roberta']:
            # xlnet and roberta tokens include a prefix (e.g. ▁ or Ġ)
            cleaned_head_tokens = [t.replace(self.model.SUBWORD_PREFIX, '') for t in head_tokens]
        else:
            cleaned_head_tokens = head_tokens

        head_doc = Doc(head_text, head_tokens)
        aug_idxes = self._get_aug_idxes(head_tokens)
        aug_idxes.sort(reverse=True)
        head_tokens = head_doc.get_augmented_tokens()

        split_results[i] += (cleaned_head_tokens, head_doc, aug_idxes, )

    # Pad aug_idxes
    max_aug_size = max([len(split_result[6]) for split_result in split_results])
    for split_result in split_results:
        aug_idxes = split_result[6]
        for _ in range(max_aug_size - len(aug_idxes)):
            aug_idxes.append(-1)

    token_placeholder = self.model.MASK_TOKEN
    if self.model_type in ['xlnet', 'roberta']:
        token_placeholder = self.model.SUBWORD_PREFIX + token_placeholder  # Adding prefix for space

    # Augment the same augmentation index across the batch
    change_seq = 0
    for i in range(max_aug_size):
        original_tokens = []
        masked_texts = []
        aug_input_poses = []  # store which input is augmented. No record if padding

        change_seq += 1
        for j, split_result in enumerate(split_results):
            head_doc, aug_idx = split_result[5], split_result[6][i]

            # -1 if it is padding
            if aug_idx == -1:
                continue

            original_tokens.append(head_doc.get_token(aug_idx).get_latest_token().token)
            head_doc.add_change_log(aug_idx, new_token=token_placeholder, action=Action.SUBSTITUTE,
                                    change_seq=self.parent_change_seq + change_seq)

            # remove continuous sub-words
            to_remove_idxes = []
            for k in range(aug_idx + 1, head_doc.size()):
                subword_token = head_doc.get_token(k).orig_token.token
                if subword_token in string.punctuation:
                    break
                if self.model_type in ['bert', 'distilbert'] and self.model.SUBWORD_PREFIX in subword_token:
                    to_remove_idxes.append(k)
                elif self.model_type in ['xlnet', 'roberta'] and self.model.SUBWORD_PREFIX not in subword_token:
                    to_remove_idxes.append(k)
                else:
                    break
            for k in reversed(to_remove_idxes):
                head_doc.add_change_log(k, new_token='', action=Action.SUBSTITUTE,
                                        change_seq=self.parent_change_seq + change_seq)

            aug_input_poses.append(j)

            # some tokenizers handle special characters (e.g. "don't" can merge after decode)
            if self.model_type in ['bert', 'distilbert']:
                ids = self.model.tokenizer.convert_tokens_to_ids(head_doc.get_augmented_tokens())
                masked_text = self.model.tokenizer.decode(ids).strip()
            elif self.model_type in ['xlnet', 'roberta']:
                masked_text = self.model.tokenizer.convert_tokens_to_string(
                    head_doc.get_augmented_tokens()).strip()

            masked_texts.append(masked_text)

        if not len(masked_texts):
            continue

        outputs = self.model.predict(masked_texts, target_words=original_tokens, n=2)

        # Update doc
        for aug_input_pos, output, masked_text in zip(aug_input_poses, outputs, masked_texts):
            split_result = split_results[aug_input_pos]
            head_doc = split_result[5]
            aug_idx = split_result[6][i]  # augment position in text

            # TODO: Alternative method better than dropout
            candidate = ''
            if len(output) == 0:
                # TODO: no result?
                pass
            elif len(output) == 1:
                candidate = output[0]
            elif len(output) > 1:
                candidate = self.sample(output, 1)[0]

            # if self.model_type in ['xlnet', 'roberta']:
            #     candidate = self.model.SUBWORD_PREFIX + candidate  # Adding prefix for space

            # In XLNet, it can be the first word of a sentence which does not come with space. E.g. Zombine (ID:29110)
            if self.model_type in ['xlnet', 'roberta']:
                if candidate != '' and not candidate.startswith(self.model.SUBWORD_PREFIX):
                    candidate = self.model.SUBWORD_PREFIX + candidate

            head_doc.update_change_log(aug_idx, token=candidate, action=Action.SUBSTITUTE,
                                       change_seq=self.parent_change_seq + change_seq)

            # Early stop if the number of tokens exceeds the max number
            if head_doc.size() > self.max_num_token:
                for j in range(i + 1, max_aug_size):
                    split_results[aug_input_pos][6][j] = -1

    augmented_texts = []
    for split_result in split_results:
        tail_text, head_doc = split_result[1], split_result[5]

        head_tokens = head_doc.get_augmented_tokens()
        # if self.model_type in ['xlnet', 'roberta']:
        #     # xlnet and roberta tokens include a prefix (e.g. ▁ or Ġ)
        #     head_tokens = [self.model.SUBWORD_PREFIX + t if self.model.SUBWORD_PREFIX not in t and i != 0 else t for i, t in enumerate(head_tokens)]

        ids = self.model.tokenizer.convert_tokens_to_ids(head_tokens)
        augmented_text = self.model.tokenizer.decode(ids)

        if tail_text is not None:
            augmented_text += ' ' + tail_text
        augmented_texts.append(augmented_text)

    if isinstance(data, list):
        return augmented_texts
    else:
        return augmented_texts[0]
def insert(self, data):
    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_random_aug_idxes(doc.get_original_tokens())
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data
    aug_idxes.sort(reverse=True)

    for aug_idx in aug_idxes:
        original_token = doc.get_token(aug_idx).orig_token.token
        candidate_tokens = self.model.predict(original_token, top_k=self.top_k)
        new_token = self.sample(candidate_tokens, 1)[0]
        if aug_idx == 0:
            new_token = new_token.capitalize()

        change_seq += 1
        doc.add_token(aug_idx, token=new_token, action=Action.INSERT,
                      change_seq=self.parent_change_seq + change_seq)

    if self.get_word_case(doc.get_token(0).get_latest_token().token) == 'capitalize':
        change_token = doc.get_token(1).get_latest_token().token.lower()
        doc.add_change_log(1, new_token=change_token, action=Action.INSERT,
                           change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
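# get_word_case() is referenced by the test and by insert() above; a minimal sketch of the
# classification it presumably performs ('lower', 'upper', 'capitalize'). This is an
# assumption for illustration, not the library's implementation.
def get_word_case_sketch(word):
    if not word:
        return 'empty'
    if word.islower():
        return 'lower'
    if word.isupper():
        return 'upper'
    if word[0].isupper() and word[1:].islower():
        return 'capitalize'
    return 'unknown'

assert get_word_case_sketch('Good') == 'capitalize'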
def split(self, data):
    change_seq = 0
    doc = Doc(data, self.tokenizer(data))

    aug_idxes = self._get_aug_idxes(doc.get_original_tokens())
    aug_idxes.sort(reverse=True)
    if aug_idxes is None or len(aug_idxes) == 0:
        if self.include_detail:
            return data, []
        return data

    for aug_idx in aug_idxes:
        target_token = doc.get_token(aug_idx).get_latest_token().token
        separate_pos = self.sample(len(target_token), 1)
        prev_token = target_token[:separate_pos]
        next_token = target_token[separate_pos:]

        change_seq += 1
        doc.add_change_log(aug_idx, new_token=next_token, action=Action.SPLIT,
                           change_seq=self.parent_change_seq + change_seq)
        doc.add_token(aug_idx, token=prev_token, action=Action.SPLIT,
                      change_seq=self.parent_change_seq + change_seq)

    if self.include_detail:
        return self.reverse_tokenizer(doc.get_augmented_tokens()), doc.get_change_logs()
    else:
        return self.reverse_tokenizer(doc.get_augmented_tokens())
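# Hedged usage sketch for split(): each selected token is cut at a random position into two
# tokens. Assumes SplitAug is available as nlpaug.augmenter.word.SplitAug.
import nlpaug.augmenter.word as naw

split_aug = naw.SplitAug()
print(split_aug.augment('The quick brown fox'))  # e.g. 'The qu ick brown fox'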