def _bae_replacement_words(self, current_text, indices_to_modify):
    """Get replacement words for the words we want to replace using the BAE method.

    Args:
        current_text (AttackedText): Text we want to get replacements for.
        indices_to_modify (list[int]): indices of the words we want to replace
    """
    masked_texts = []
    for index in indices_to_modify:
        masked_texts.append(
            current_text.replace_word_at_index(
                index, self._lm_tokenizer.mask_token
            ).text
        )

    i = 0
    # 2-D list where for each index to modify we have a list of replacement words
    replacement_words = []
    while i < len(masked_texts):
        inputs = self._encode_text(masked_texts[i : i + self.batch_size])
        ids = inputs["input_ids"].tolist()
        with torch.no_grad():
            preds = self._language_model(**inputs)[0]

        for j in range(len(ids)):
            try:
                # Need try-except b/c mask-token located past max_length might be truncated by tokenizer
                masked_index = ids[j].index(self._lm_tokenizer.mask_token_id)
            except ValueError:
                replacement_words.append([])
                continue

            mask_token_logits = preds[j, masked_index]
            mask_token_probs = torch.softmax(mask_token_logits, dim=0)
            # Iterate candidates from most to least probable
            ranked_indices = torch.argsort(mask_token_probs, descending=True)
            top_words = []
            for _id in ranked_indices:
                _id = _id.item()
                token = self._lm_tokenizer.convert_ids_to_tokens(_id)
                if utils.check_if_subword(
                    token,
                    self._language_model.config.model_type,
                    (masked_index == 1),
                ):
                    word = utils.strip_BPE_artifacts(
                        token, self._language_model.config.model_type
                    )
                else:
                    word = token
                if (
                    mask_token_probs[_id] >= self.min_confidence
                    and utils.is_one_word(word)
                    and not utils.check_if_punctuations(word)
                ):
                    top_words.append(token)

                if (
                    len(top_words) >= self.max_candidates
                    or mask_token_probs[_id] < self.min_confidence
                ):
                    break

            replacement_words.append(top_words)

        i += self.batch_size

    return replacement_words
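# Hedged usage sketch (the names below are illustrative, not part of this module):
# given `swap`, an instance of this transformation wrapping a masked LM, and `text`,
# an AttackedText, a call such as
#     candidate_lists = swap._bae_replacement_words(text, [1, 4])
# returns one list of candidate replacement tokens per modified index, each capped
# at `self.max_candidates` entries and filtered by `self.min_confidence`.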
def _bert_attack_replacement_words(
    self,
    current_text,
    index,
    id_preds,
    masked_lm_logits,
):
    """Get replacement words for the word we want to replace using the
    BERT-Attack method.

    Args:
        current_text (AttackedText): Text we want to get replacements for.
        index (int): index of word we want to replace
        id_preds (torch.Tensor): N x K tensor of top-K ids for each token-position
            predicted by the masked language model. N is equivalent to
            `self.max_length`.
        masked_lm_logits (torch.Tensor): N x V tensor of the raw logits outputted by
            the masked language model. N is equivalent to `self.max_length` and V is
            the vocabulary size of the masked language model.
    """
    # We need to find which BPE tokens belong to the word we want to replace
    masked_text = current_text.replace_word_at_index(
        index, self._lm_tokenizer.mask_token
    )
    current_inputs = self._encode_text(masked_text.text)
    current_ids = current_inputs["input_ids"].tolist()[0]
    word_tokens = self._lm_tokenizer.encode(
        current_text.words[index], add_special_tokens=False
    )

    try:
        # Need try-except b/c mask-token located past max_length might be truncated by tokenizer
        masked_index = current_ids.index(self._lm_tokenizer.mask_token_id)
    except ValueError:
        return []

    # List of indices of tokens that are part of the target word
    target_ids_pos = list(
        range(masked_index, min(masked_index + len(word_tokens), self.max_length))
    )

    if not len(target_ids_pos):
        return []
    elif len(target_ids_pos) == 1:
        # Word to replace is tokenized as a single token
        top_preds = id_preds[target_ids_pos[0]].tolist()
        replacement_words = []
        for id in top_preds:
            token = self._lm_tokenizer.convert_ids_to_tokens(id)
            if utils.is_one_word(token) and not utils.check_if_subword(
                token, self._language_model.config.model_type, index == 0
            ):
                replacement_words.append(token)
        return replacement_words
    else:
        # Word to replace is tokenized as multiple sub-words
        top_preds = [id_preds[i] for i in target_ids_pos]
        products = itertools.product(*top_preds)
        combination_results = []
        # Original BERT-Attack implementation uses cross-entropy loss
        cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction="none")
        target_ids_pos_tensor = torch.tensor(target_ids_pos)
        word_tensor = torch.zeros(len(target_ids_pos), dtype=torch.long)
        for bpe_tokens in products:
            for i in range(len(bpe_tokens)):
                word_tensor[i] = bpe_tokens[i]

            logits = torch.index_select(masked_lm_logits, 0, target_ids_pos_tensor)
            loss = cross_entropy_loss(logits, word_tensor)
            perplexity = torch.exp(torch.mean(loss, dim=0)).item()
            word = "".join(
                self._lm_tokenizer.convert_ids_to_tokens(word_tensor)
            ).replace("##", "")
            if utils.is_one_word(word):
                combination_results.append((word, perplexity))
        # Sort by perplexity (ascending) to get the top-K results
        combination_results.sort(key=lambda x: x[1])
        top_replacements = [
            x[0] for x in combination_results[: self.max_candidates]
        ]
        return top_replacements
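# Hedged sketch of the multi-sub-word scoring step above (illustrative only): for each
# candidate combination of sub-word ids,
#     logits = torch.index_select(masked_lm_logits, 0, target_ids_pos_tensor)  # (num_subwords, V)
#     loss = cross_entropy_loss(logits, word_tensor)                           # (num_subwords,)
#     perplexity = torch.exp(loss.mean()).item()
# Lower perplexity means the masked LM considers that sub-word combination more
# plausible, so candidates are ranked ascending by this value before taking the
# top `self.max_candidates`.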