def _bae_replacement_words(self, current_text, indices_to_modify): """Get replacement words for the word we want to replace using BAE method. Args: current_text (AttackedText): Text we want to get replacements for. index (int): index of word we want to replace """ masked_texts = [] for index in indices_to_modify: masked_texts.append( current_text.replace_word_at_index( index, self._lm_tokenizer.mask_token).text) i = 0 # 2-D list where for each index to modify we have a list of replacement words replacement_words = [] while i < len(masked_texts): inputs = self._encode_text(masked_texts[i:i + self.batch_size]) ids = inputs["input_ids"].tolist() with torch.no_grad(): preds = self._language_model(**inputs)[0] for j in range(len(ids)): try: # Need try-except b/c mask-token located past max_length might be truncated by tokenizer masked_index = ids[j].index( self._lm_tokenizer.mask_token_id) except ValueError: replacement_words.append([]) continue mask_token_logits = preds[j, masked_index] mask_token_probs = torch.softmax(mask_token_logits, dim=0) ranked_indices = torch.argsort(mask_token_probs) top_words = [] for _id in ranked_indices: _id = _id.item() token = self._lm_tokenizer.convert_ids_to_tokens(_id) if utils.check_if_subword( token, self._language_model.config.model_type, (masked_index == 1), ): word = utils.strip_BPE_artifacts( token, self._language_model.config.model_type) if (mask_token_probs[_id] >= self.min_confidence and utils.is_one_word(word) and not utils.check_if_punctuations(word)): top_words.append(token) if (len(top_words) >= self.max_candidates or mask_token_probs[_id] < self.min_confidence): break replacement_words.append(top_words) i += self.batch_size return replacement_words
def _get_transformations(self, current_text, indices_to_modify):
    """Build every single-word substitution of ``current_text``.

    The words of the text are joined with spaces and POS-tagged, then each
    index in ``indices_to_modify`` is replaced by every candidate returned
    by ``self._get_replacement_words`` for that word and its POS prefix.

    Args:
        current_text (AttackedText): Text to transform.
        indices_to_modify (iterable of int): Indices of the words that may
            be replaced.

    Returns:
        list: Texts produced by ``current_text.replace_word_at_index``, one
        per accepted candidate, grouped by index in iteration order.

    Raises:
        AssertionError: If the tagger yields a different number of words
            than ``current_text.words``.
    """
    words = current_text.words
    sentence = Sentence(" ".join(words))
    # The tagger annotates `sentence` in place; nothing is returned.
    self._flair_pos_tagger.predict(sentence)
    tagged_words, pos_tags = zip_flair_result(sentence)
    assert len(words) == len(tagged_words), (
        "Part-of-speech tagger returned incorrect number of tags")

    transformed_texts = []
    for position in indices_to_modify:
        original_word = words[position]
        # Only the first two characters of the tag are used (the root POS).
        root_pos = pos_tags[position][:2]
        candidates = self._get_replacement_words(original_word, root_pos)
        transformed_texts.extend([
            current_text.replace_word_at_index(position, candidate)
            for candidate in candidates
            if candidate != original_word and utils.is_one_word(candidate)
        ])
    return transformed_texts
def _bae_replacement_words(self, current_text, index):
    """Get replacement words for the word we want to replace using BAE method.

    The word at ``index`` is swapped for the tokenizer's mask token, the
    masked text is run through the masked language model, and the
    ``self.max_candidates`` highest-scoring token ids at the mask position
    are converted to tokens and filtered.

    Args:
        current_text (AttackedText): Text we want to get replacements for.
        index (int): Index of the word we want to replace.

    Returns:
        list[str]: Tokens that pass ``utils.is_one_word`` and are not
        sub-words, ordered from highest to lowest score. Empty when the mask
        token is missing from the encoded input.
    """
    tokenizer = self._lm_tokenizer
    masked_text = current_text.replace_word_at_index(
        index, tokenizer.mask_token).text
    inputs = self._encode_text(masked_text)
    token_ids = inputs["input_ids"].tolist()[0]

    mask_id = tokenizer.mask_token_id
    if mask_id not in token_ids:
        # A mask token located past max_length might be truncated by the
        # tokenizer, in which case there is nothing to predict.
        return []
    masked_index = token_ids.index(mask_id)

    with torch.no_grad():
        preds = self._language_model(**inputs)[0]

    # Raw model scores at the mask position of the only batch element.
    mask_scores = preds[0, masked_index]
    _, best_ids = torch.topk(mask_scores, self.max_candidates)

    candidates = (
        tokenizer.convert_ids_to_tokens(token_id)
        for token_id in best_ids.tolist()
    )
    return [
        token for token in candidates
        if utils.is_one_word(token) and not check_if_subword(token)
    ]
def _get_transformations(self, current_text, indices_to_modify): transformed_texts = [] for i in indices_to_modify: word_to_replace = current_text.words[i] word_to_replace_pos = current_text.pos_of_word_index(i) replacement_words = self._get_replacement_words( word_to_replace, word_to_replace_pos) transformed_texts_idx = [] for r in replacement_words: if r != word_to_replace and utils.is_one_word(r): transformed_texts_idx.append( current_text.replace_word_at_index(i, r)) transformed_texts.extend(transformed_texts_idx) return transformed_texts
def _bert_attack_replacement_words(
    self,
    current_text,
    index,
    id_preds,
    masked_lm_logits,
):
    """Get replacement words for the word we want to replace using BERT-
    Attack method.

    Args:
        current_text (AttackedText): Text we want to get replacements for.
        index (int): index of word we want to replace
        id_preds (torch.Tensor): N x K tensor of top-K ids for each token-position predicted by the masked language model.
            N is equivalent to `self.max_length`.
        masked_lm_logits (torch.Tensor): N x V tensor of the raw logits outputted by the masked language model.
            N is equivalent to `self.max_length` and V is dictionary size of masked language model.

    Returns:
        list[str]: Candidate replacement words. Empty when the mask token is
        missing from the encoded input or no token position is left for the
        target word. For a word tokenized into several sub-words, at most
        `self.max_candidates` words are returned, lowest perplexity first.
    """
    # We need to find which BPE tokens belong to the word we want to replace
    masked_text = current_text.replace_word_at_index(
        index, self._lm_tokenizer.mask_token)
    current_inputs = self._encode_text(masked_text.text)
    current_ids = current_inputs["input_ids"].tolist()[0]
    word_tokens = self._lm_tokenizer.encode(current_text.words[index],
                                            add_special_tokens=False)

    try:
        # Need try-except b/c mask-token located past max_length might be truncated by tokenizer
        masked_index = current_ids.index(self._lm_tokenizer.mask_token_id)
    except ValueError:
        return []

    # List of indices of tokens that are part of the target word
    target_ids_pos = list(
        range(masked_index,
              min(masked_index + len(word_tokens), self.max_length)))

    if not target_ids_pos:
        return []

    if len(target_ids_pos) == 1:
        # Word to replace is tokenized as a single word
        model_type = self._language_model.config.model_type
        replacement_words = []
        for token_id in id_preds[target_ids_pos[0]].tolist():
            token = self._lm_tokenizer.convert_ids_to_tokens(token_id)
            if utils.is_one_word(token) and not utils.check_if_subword(
                    token, model_type, index == 0):
                replacement_words.append(token)
        return replacement_words

    # Word to replace is tokenized as multiple sub-words
    top_preds = [id_preds[i] for i in target_ids_pos]
    products = itertools.product(*top_preds)
    combination_results = []
    # Original BERT-Attack implementation uses cross-entropy loss
    cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction="none")
    target_ids_pos_tensor = torch.tensor(target_ids_pos)
    # The logits of the target positions are the same for every
    # combination, so select them once instead of once per combination.
    logits = torch.index_select(masked_lm_logits, 0, target_ids_pos_tensor)
    # Buffer reused for every combination; fully overwritten each time.
    word_tensor = torch.zeros(len(target_ids_pos), dtype=torch.long)

    for bpe_tokens in products:
        for position, bpe_token in enumerate(bpe_tokens):
            word_tensor[position] = bpe_token

        loss = cross_entropy_loss(logits, word_tensor)
        perplexity = torch.exp(torch.mean(loss, dim=0)).item()
        word = "".join(
            self._lm_tokenizer.convert_ids_to_tokens(
                word_tensor)).replace("##", "")
        if utils.is_one_word(word):
            combination_results.append((word, perplexity))

    # Sort to get top-K results. This must be an in-place sort: `sorted()`
    # returns a new list, and discarding it would leave the candidates in
    # product order instead of ascending perplexity.
    combination_results.sort(key=lambda result: result[1])
    return [
        word for word, _ in combination_results[:self.max_candidates]
    ]