def attack(self, example, ground_truth_output):
    """Attack a single example.

    Args:
        example (:obj:`str`, :obj:`OrderedDict[str, str]` or :class:`~textattack.shared.AttackedText`):
            Example to attack. It can be a single string or an `OrderedDict` where
            keys represent the input fields (e.g. "premise", "hypothesis") and
            the values are the actual input texts.
            Also accepts :class:`~textattack.shared.AttackedText` that wraps around the input.
        ground_truth_output(:obj:`int`, :obj:`float` or :obj:`str`):
            Ground truth output of `example`.
            For classification tasks, it should be an integer representing the ground truth label.
            For regression tasks (e.g. STS), it should be the target value.
            For seq2seq tasks (e.g. translation), it should be the target string.

    Returns:
        :class:`~textattack.attack_results.AttackResult` that represents the result of the attack.
        A :class:`SkippedAttackResult` is returned when the goal function marks the
        example as skipped.

    Raises:
        AssertionError: if `example` or `ground_truth_output` has an unsupported type.
    """
    assert isinstance(
        example, (str, OrderedDict, AttackedText)
    ), "`example` must either be `str`, `collections.OrderedDict`, `textattack.shared.AttackedText`."
    if isinstance(example, (str, OrderedDict)):
        example = AttackedText(example)

    # Regression targets (e.g. STS scores) are floats, so float must be accepted
    # alongside int labels and str targets.
    assert isinstance(
        ground_truth_output, (int, float, str)
    ), "`ground_truth_output` must either be `str`, `int` or `float`."

    goal_function_result, _ = self.goal_function.init_attack_example(
        example, ground_truth_output
    )
    if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
        return SkippedAttackResult(goal_function_result)
    return self._attack(goal_function_result)
def augment(self, text):
    """Return random augmentations of ``text`` built from per-word transformations.

    Runs ``self.transformations_per_example`` independent rounds. Each round
    starts from the unmodified text, visits word indices in a random order and
    asks ``self.transformation`` for candidates that modify only that index.
    Candidates already collected in an earlier round are dropped. The rest are
    checked by ``self._filter_transformations`` against the unmodified text.
    A random survivor becomes the new current text. A round stops once
    ``max(int(self.pct_words_to_swap * n_words), 1)`` words have been swapped
    or the indices run out. The text each round ends on is collected; this is
    the unmodified input if nothing was swapped.

    Args:
        text: input to augment; it is wrapped with ``AttackedText``.

    Returns:
        list[str]: sorted ``printable_text()`` of the collected texts.
    """
    attacked_text = AttackedText(text)
    original_text = attacked_text
    all_transformed_texts = set()
    num_words = len(attacked_text.words)
    # int() truncates, so short texts (or a small pct) would otherwise ask for
    # 0 swaps. The stop check below would then never fire and every index would
    # be swapped. Always require at least one swap.
    num_words_to_swap = max(int(self.pct_words_to_swap * num_words), 1)
    for _ in range(self.transformations_per_example):
        index_order = list(range(num_words))
        random.shuffle(index_order)
        current_text = attacked_text
        words_swapped = 0
        for i in index_order:
            # Restrict the transformation to index ``i`` (presumably its
            # indices-to-modify argument).
            transformed_texts = self.transformation(
                current_text, self.pre_transformation_constraints, [i]
            )
            # Get rid of transformations we already have
            transformed_texts = [
                t for t in transformed_texts if t not in all_transformed_texts
            ]
            # Filter out transformations that don't match the constraints.
            transformed_texts = self._filter_transformations(
                transformed_texts, current_text, original_text
            )
            if not transformed_texts:
                continue
            current_text = random.choice(transformed_texts)
            words_swapped += 1
            if words_swapped >= num_words_to_swap:
                break
        all_transformed_texts.add(current_text)
    return sorted(at.printable_text() for at in all_transformed_texts)
def _get_examples_from_dataset(self, dataset, indices=None):
    """Yield goal-function results for the original (unperturbed) examples of ``dataset``.

    Args:
        dataset: An indexable of (text, ground_truth_output) pairs. If it has a
            ``label_names`` attribute, that value is attached to every
            AttackedText as the ``label_names`` attack attribute.
        indices: An iterable of indices of the dataset that we want to attack,
            visited in the given order. If None, attack all samples in the
            dataset. An empty iterable attacks nothing. A ``deque`` passed in
            is consumed in place.

    Yields:
        GoalFunctionResult: one result per index. When the result reports
        ``succeeded``, its ``output`` is replaced by the ground-truth output.

    Raises:
        IndexError: if an index is out of bounds for ``dataset``.
    """
    # Compare against None explicitly: an empty ``indices`` means "attack
    # nothing", not "attack everything".
    if indices is None:
        indices = range(len(dataset))
    if not isinstance(indices, deque):
        indices = deque(indices)
    if not indices:
        return

    # The label names do not depend on the index, so look them up once.
    try:
        # get label names from dataset, if possible
        label_names = dataset.label_names
    except AttributeError:
        label_names = None

    while indices:
        i = indices.popleft()
        # Only the dataset access is guarded, so IndexErrors raised elsewhere
        # are not misreported as out-of-bounds dataset accesses.
        try:
            text, ground_truth_output = dataset[i]
        except IndexError as err:
            raise IndexError(
                f"Out of bounds access of dataset. Size of data is {len(dataset)} but tried to access index {i}"
            ) from err
        attacked_text = AttackedText(
            text, attack_attrs={"label_names": label_names}
        )
        self.goal_function.num_queries = 0
        goal_function_result, _ = self.goal_function.get_result(
            attacked_text, ground_truth_output
        )
        if goal_function_result.succeeded:
            # Replace the result's output with the ground truth, so a
            # SkippedAttackResult built from it reports the true output
            # rather than the incorrect one.
            goal_function_result.output = ground_truth_output
        yield goal_function_result
def _get_examples_from_dataset(self, dataset, indices=None):
    """Yield initial goal-function results for examples of ``dataset``.

    Args:
        dataset: An indexable of (text_input, ground_truth_output) pairs. If it
            has a ``label_names`` attribute, that value is attached to every
            AttackedText as the ``label_names`` attack attribute.
        indices: An iterable of indices of the dataset that we want to attack.
            If None, attack all samples in the dataset. A ``deque`` is consumed
            in place in its own order; any other iterable is copied and visited
            in ascending order.

    Yields:
        GoalFunctionResult: the result of ``self.goal_function.init_attack_example``
        for each visited example. Iteration stops early, with a logged warning,
        at the first out-of-bounds index.
    """
    if indices is None:
        indices = range(len(dataset))
    if not isinstance(indices, deque):
        indices = deque(sorted(indices))
    if not indices:
        return

    # The label names do not depend on the index, so look them up once.
    try:
        # get label names from dataset, if possible
        label_names = dataset.label_names
    except AttributeError:
        label_names = None

    while indices:
        i = indices.popleft()
        try:
            text_input, ground_truth_output = dataset[i]
        except IndexError:
            # Logger.warn is a deprecated alias of Logger.warning.
            utils.logger.warning(
                f"Dataset has {len(dataset)} samples but tried to access index {i}. Ending attack early."
            )
            break
        attacked_text = AttackedText(
            text_input, attack_attrs={"label_names": label_names}
        )
        goal_function_result, _ = self.goal_function.init_attack_example(
            attacked_text, ground_truth_output
        )
        yield goal_function_result
def augment(self, text):
    """Generate random augmentations of ``text``.

    Performs ``self.transformations_per_example`` rounds. Every round starts
    from the unmodified text and repeatedly applies ``self.transformation``.
    Candidates already collected in an earlier round are discarded. The rest
    go through ``self._filter_transformations`` (checked against the unmodified
    text), and one survivor is picked uniformly at random. A round ends once at
    least ``max(int(self.pct_words_to_swap * n_words), 1)`` words count as
    swapped, or as soon as no candidate survives. Whatever text the round ended
    on is collected; this may be the unmodified input.

    Args:
        text: input to augment; it is wrapped with ``AttackedText``.

    Returns:
        list[str]: sorted ``printable_text()`` of the collected texts.
    """
    source = AttackedText(text)
    collected = set()
    target_swaps = max(int(self.pct_words_to_swap * len(source.words)), 1)

    for _round in range(self.transformations_per_example):
        candidate = source
        swapped = len(candidate.attack_attrs["modified_indices"])
        while swapped < target_swaps:
            fresh = [
                option
                for option in self.transformation(
                    candidate, self.pre_transformation_constraints
                )
                if option not in collected
            ]
            survivors = self._filter_transformations(fresh, candidate, source)
            if len(survivors) == 0:
                # Nothing acceptable left from here; keep what this round has.
                break
            candidate = random.choice(survivors)
            # Measure progress by modified indices, but always advance by at
            # least one so the round is guaranteed to finish.
            swapped = max(
                len(candidate.attack_attrs["modified_indices"]),
                swapped + 1,
            )
        collected.add(candidate)

    return sorted(item.printable_text() for item in collected)
def main(params):
    """Attack one chunk of a dataset's train split and report statistics.

    Loads the ``train`` split and a HuggingFace sequence-classification model
    (moved to CUDA). It optionally loads a checkpoint and/or replaces the
    classifier with fixed random unit ("radioactive") directions. It then
    attacks rows ``[chunk_id * chunk_size, chunk_id * chunk_size + chunk_size)``,
    clamped to the dataset size. Perturbed texts are optionally written to
    ``<target_dir>/<chunk_id>.csv``. Per-example diagnostics are printed, plus
    a JSON summary prefixed with ``__logs:``.

    Args:
        params: parsed options; reads ``dataset``, ``model``, ``ckpt``,
            ``batch_size``, ``radioactive``, ``chunk_id``, ``chunk_size``,
            ``target_dir``, ``attack``, ``targeted``, ``bae_threshold`` and
            ``query_budget``.

    Raises:
        ValueError: if ``params.attack`` is not a supported attack name.
    """
    # Loading data
    dataset, num_labels = load_data(params)
    dataset = dataset["train"]
    text_key = 'content' if params.dataset == "dbpedia14" else 'text'
    print(f"Loaded dataset {params.dataset}, that has {len(dataset)} rows")

    # Load model and tokenizer from HuggingFace
    model_class = transformers.AutoModelForSequenceClassification
    model = model_class.from_pretrained(params.model, num_labels=num_labels).cuda()
    if params.ckpt is not None:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state_dict = torch.load(params.ckpt)
        model.load_state_dict(state_dict)
    tokenizer = textattack.models.tokenizers.AutoTokenizer(params.model)
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
        model, tokenizer, batch_size=params.batch_size)

    # Create radioactive directions and modify classification layer to use those
    if params.radioactive:
        torch.manual_seed(0)
        # Match the classifier's input width instead of assuming 768 (BERT-base).
        hidden_size = model.classifier.weight.shape[1]
        radioactive_directions = torch.randn(num_labels, hidden_size)
        radioactive_directions /= torch.norm(radioactive_directions, dim=1, keepdim=True)
        print(radioactive_directions)
        model.classifier.weight.data = radioactive_directions.cuda()
        model.classifier.bias.data = torch.zeros(num_labels).cuda()

    start_index = params.chunk_id * params.chunk_size
    # Clamp so a final, partial chunk does not index past the end of the dataset.
    end_index = min(start_index + params.chunk_size, len(dataset))

    csv_file = None
    csv_writer = None
    if params.target_dir is not None:
        target_file = join(params.target_dir, f"{params.chunk_id}.csv")
        # newline="" is what the csv module expects for files it writes to.
        csv_file = open(target_file, "w", newline="")
        csv_writer = csv.writer(
            csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_NONNUMERIC)

    try:
        # Creating attack
        attack = _build_chunk_attack(
            params, model_wrapper, dataset, start_index, end_index)

        # Launching attack
        begin_time = time.time()
        # Each text is paired with the goal function's output on the clean
        # text, which is used as the ground truth for the attack.
        samples = [
            (dataset[i][text_key],
             attack.goal_function.get_output(AttackedText(dataset[i][text_key])))
            for i in range(start_index, end_index)
        ]
        results = list(attack.attack_dataset(samples))

        # Storing attacked text
        bert_scorer = BERTScorer(model_type="bert-base-uncased", idf=False)
        use = USE()
        n_success = 0
        similarities = []
        queries = []
        for i_result, result in enumerate(results):
            print("")
            print(50 * "*")
            print("")
            i_data = start_index + i_result
            row = dataset[i_data]
            text = row[text_key]
            ptext = result.perturbed_text()

            if csv_writer is not None:
                # Labels are written shifted by one.
                if params.dataset == 'dbpedia14':
                    csv_writer.writerow([row['label'] + 1, row['title'], ptext])
                else:
                    csv_writer.writerow([row['label'] + 1, ptext])

            print("True label ", row['label'])
            print(f"CLEAN TEXT\n {text}")
            print(f"ADV TEXT\n {ptext}")
            if not isinstance(result, (SuccessfulAttackResult, FailedAttackResult)):
                print("WARNING: Attack neither succeeded nor failed...")
            print(result.goal_function_result_str())
            precision, recall, f1 = [
                r.item() for r in bert_scorer.score([ptext], [text])
            ]
            print(
                f"Bert scores: precision {precision:.2f}, recall: {recall:.2f}, f1: {f1:.2f}"
            )
            initial_logits = model_wrapper([text])
            final_logits = model_wrapper([ptext])
            print("Initial logits", initial_logits)
            print("Final logits", final_logits)
            print("Logits difference", final_logits - initial_logits)

            # Statistics
            if isinstance(result, SuccessfulAttackResult):
                n_success += 1
            queries.append(result.num_queries)
            similarities.append(use.compute_sim([text], [ptext]))

        print("Processing all samples took %.2f" % (time.time() - begin_time))
        n_results = len(results)
        print(f"Total success: {n_success}/{n_results}")
        # Guard the averages: an empty chunk would otherwise divide by zero.
        logs = {
            "success_rate": n_success / n_results if n_results else 0.0,
            "avg_queries": sum(queries) / len(queries) if queries else 0.0,
            "queries": queries,
            "avg_similarity": (
                sum(similarities) / len(similarities) if similarities else 0.0
            ),
            "similarities": similarities,
        }
        print("__logs:" + json.dumps(logs))
    finally:
        # Close the underlying file object (a csv.writer has no close()).
        if csv_file is not None:
            csv_file.close()


def _build_chunk_attack(params, model_wrapper, dataset, start_index, end_index):
    """Build the attack selected by ``params.attack`` for rows [start_index, end_index).

    Raises:
        ValueError: if ``params.attack`` is not one of "custom", "bae",
            "bert-attack" or "clare".
    """
    print(f"Building {params.attack} attack")
    if params.attack == "custom":
        current_label = -1
        if params.targeted:
            # A targeted chunk must be single-label; that label is passed to
            # build_attack.
            current_label = dataset[start_index]['label']
            assert all(
                dataset[i]['label'] == current_label
                for i in range(start_index, end_index)
            )
        return build_attack(model_wrapper, current_label)
    if params.attack == "bae":
        print(f"Building BAE method with threshold={params.bae_threshold:.2f}")
        return build_baegarg2019(model_wrapper,
                                 threshold_cosine=params.bae_threshold,
                                 query_budget=params.query_budget)
    if params.attack == "bert-attack":
        # This attack is built without a query budget.
        assert params.query_budget is None
        return BERTAttackLi2020.build(model_wrapper)
    if params.attack == "clare":
        # This attack is built without a query budget.
        assert params.query_budget is None
        return CLARE2020.build(model_wrapper)
    raise ValueError(f"Unknown attack: {params.attack!r}")