Ejemplo n.º 1
0
    def attack(self, example, ground_truth_output):
        """Attack a single example.

        Args:
            example (:obj:`str`, :obj:`OrderedDict[str, str]` or :class:`~textattack.shared.AttackedText`):
                Example to attack. It can be a single string or an `OrderedDict` where
                keys represent the input fields (e.g. "premise", "hypothesis") and the values are the actual input texts.
                Also accepts :class:`~textattack.shared.AttackedText` that wraps around the input.
            ground_truth_output(:obj:`int`, :obj:`float` or :obj:`str`):
                Ground truth output of `example`.
                For classification tasks, it should be an integer representing the ground truth label.
                For regression tasks (e.g. STS), it should be the target value.
                For seq2seq tasks (e.g. translation), it should be the target string.
        Returns:
            :class:`~textattack.attack_results.AttackResult` that represents the result of the attack.
            A :class:`SkippedAttackResult` is returned when the goal function marks the
            example as skipped; otherwise the result of ``self._attack`` is returned.
        Raises:
            AssertionError: if `example` or `ground_truth_output` has an unsupported type.
        """
        assert isinstance(
            example, (str, OrderedDict, AttackedText)
        ), "`example` must either be `str`, `collections.OrderedDict` or `textattack.shared.AttackedText`."
        # Raw inputs are wrapped so the rest of the attack only sees AttackedText.
        if isinstance(example, (str, OrderedDict)):
            example = AttackedText(example)

        # Floats are accepted because regression targets (e.g. STS scores) are
        # documented above as valid ground-truth outputs.
        assert isinstance(
            ground_truth_output, (int, float, str)
        ), "`ground_truth_output` must either be `str`, `int` or `float`."
        goal_function_result, _ = self.goal_function.init_attack_example(
            example, ground_truth_output)
        if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
            return SkippedAttackResult(goal_function_result)
        return self._attack(goal_function_result)
Ejemplo n.º 2
0
 def augment(self, text):
     """Return augmentations of ``text`` produced by ``self.transformation``.

     Runs ``self.transformations_per_example`` rounds. Each round starts from
     the original text, visits its word indices in a random order and, for
     each index, calls ``self.transformation`` with that single index. It then
     drops candidates that were already collected or that fail
     ``self._filter_transformations``, and continues from a randomly chosen
     survivor. A round stops once ``num_words_to_swap`` swaps have been made.

     Returns:
         A sorted list of the printable texts of the distinct final results.
     """
     attacked_text = AttackedText(text)
     original_text = attacked_text
     all_transformed_texts = set()
     # Always target at least one swap: int() truncation yields 0 for short
     # texts or small pct_words_to_swap, and a zero target would never match
     # the stop condition below, so every index would end up swapped.
     num_words_to_swap = max(
         int(self.pct_words_to_swap * len(attacked_text.words)), 1)
     for _ in range(self.transformations_per_example):
         index_order = list(range(len(attacked_text.words)))
         random.shuffle(index_order)
         current_text = attacked_text
         words_swapped = 0
         for i in index_order:
             transformed_texts = self.transformation(
                 current_text, self.pre_transformation_constraints, [i])
             # Get rid of transformations we already have
             transformed_texts = [
                 t for t in transformed_texts
                 if t not in all_transformed_texts
             ]
             # Filter out transformations that don't match the constraints.
             transformed_texts = self._filter_transformations(
                 transformed_texts, current_text, original_text)
             if not len(transformed_texts):
                 continue
             current_text = random.choice(transformed_texts)
             words_swapped += 1
             if words_swapped >= num_words_to_swap:
                 break
         all_transformed_texts.add(current_text)
     return sorted(at.printable_text() for at in all_transformed_texts)
Ejemplo n.º 3
0
    def _get_examples_from_dataset(self, dataset, indices=None):
        """Gets examples from a dataset and tokenizes them.

        Args:
            dataset: An indexable of (text, ground_truth_output) pairs.
            indices: An iterable of indices of the dataset that we want to attack.
                If None, attack all samples in dataset. An empty iterable yields
                nothing. If a ``deque`` is passed it is consumed in place.

        Yields:
            GoalFunctionResult for each requested example. If the goal is already
            met on the original example, its ``output`` is overwritten with the
            true output so a skipped result reports the correct value.

        Raises:
            IndexError: if an index is out of bounds of ``dataset``.
        """
        # Only None means "everything"; an explicitly empty selection stays empty.
        if indices is None:
            indices = range(len(dataset))
        if not isinstance(indices, deque):
            indices = deque(indices)

        # Datasets without label names fall back to None.
        label_names = getattr(dataset, "label_names", None)

        while indices:
            i = indices.popleft()
            # Guard only the dataset access, so IndexErrors raised by the goal
            # function are not misreported as out-of-bounds dataset reads.
            try:
                text, ground_truth_output = dataset[i]
            except IndexError as err:
                raise IndexError(
                    f"Out of bounds access of dataset. Size of data is {len(dataset)} but tried to access index {i}"
                ) from err

            attacked_text = AttackedText(
                text, attack_attrs={"label_names": label_names}
            )
            self.goal_function.num_queries = 0
            goal_function_result, _ = self.goal_function.get_result(
                attacked_text, ground_truth_output
            )
            if goal_function_result.succeeded:
                # Store the true output on the goal function so that the
                # SkippedAttackResult has the correct output, not the incorrect.
                goal_function_result.output = ground_truth_output
            yield goal_function_result
Ejemplo n.º 4
0
    def _get_examples_from_dataset(self, dataset, indices=None):
        """Gets examples from a dataset and tokenizes them.

        Args:
            dataset: An iterable of (text_input, ground_truth_output) pairs
            indices: An iterable of indices of the dataset that we want to attack.
                If None, attack all samples in dataset. Non-deque iterables are
                sorted first; a ``deque`` is consumed in place, in its own order.

        Yields:
            GoalFunctionResult from ``self.goal_function.init_attack_example``
            for each example. Iteration ends early, with a logged warning, at
            the first index that is out of bounds of ``dataset``.
        """
        if indices is None:
            indices = range(len(dataset))

        if not isinstance(indices, deque):
            indices = deque(sorted(indices))

        # Datasets without label names fall back to None.
        label_names = getattr(dataset, "label_names", None)

        while indices:
            i = indices.popleft()
            try:
                text_input, ground_truth_output = dataset[i]
            except IndexError:
                # `warning` is the supported spelling; `Logger.warn` is a
                # deprecated alias. NOTE(review): assumes utils.logger is a
                # stdlib logging.Logger — confirm.
                utils.logger.warning(
                    f"Dataset has {len(dataset)} samples but tried to access index {i}. Ending attack early."
                )
                break

            attacked_text = AttackedText(
                text_input, attack_attrs={"label_names": label_names})
            goal_function_result, _ = self.goal_function.init_attack_example(
                attacked_text, ground_truth_output)
            yield goal_function_result
Ejemplo n.º 5
0
    def augment(self, text):
        """Generate augmented variants of ``text`` via ``self.transformation``.

        Performs ``self.transformations_per_example`` independent rounds. Each
        round restarts from the original text and repeatedly moves to a randomly
        chosen transformation that was not collected before and passes
        ``self._filter_transformations``. A round stops when the swap target is
        reached or no candidate survives filtering.

        Returns:
            A sorted list of the printable texts of the distinct results.
        """
        source = AttackedText(text)
        collected = set()
        # Target at least one swap even when the percentage rounds down to 0.
        target_swaps = max(
            int(self.pct_words_to_swap * len(source.words)), 1)

        for _ in range(self.transformations_per_example):
            candidate = source
            swaps = len(candidate.attack_attrs["modified_indices"])

            while swaps < target_swaps:
                fresh = [
                    option
                    for option in self.transformation(
                        candidate, self.pre_transformation_constraints)
                    if option not in collected
                ]
                survivors = self._filter_transformations(
                    fresh, candidate, source)

                # Stop this round once nothing valid is left to apply.
                if len(survivors) == 0:
                    break

                candidate = random.choice(survivors)

                # A transformation may modify several words at once; count them
                # all, but always advance by at least one to guarantee progress.
                swaps = max(
                    len(candidate.attack_attrs["modified_indices"]),
                    swaps + 1,
                )
            collected.add(candidate)

        return sorted(item.printable_text() for item in collected)
Ejemplo n.º 6
0
def _build_attack(params, model_wrapper, dataset, start_index, end_index):
    """Build the attack named by ``params.attack`` for ``model_wrapper``.

    Supported names are "custom", "bae", "bert-attack" and "clare".

    Raises:
        ValueError: if ``params.attack`` is not one of the supported names.
    """
    print(f"Building {params.attack} attack")
    if params.attack == "custom":
        current_label = -1
        if params.targeted:
            # A targeted custom attack uses the chunk's first label, so every
            # row in the chunk must carry that same label.
            current_label = dataset[start_index]['label']
            assert all(
                dataset[i]['label'] == current_label
                for i in range(start_index, end_index)
            )
        return build_attack(model_wrapper, current_label)
    if params.attack == "bae":
        print(f"Building BAE method with threshold={params.bae_threshold:.2f}")
        return build_baegarg2019(model_wrapper,
                                 threshold_cosine=params.bae_threshold,
                                 query_budget=params.query_budget)
    if params.attack == "bert-attack":
        assert params.query_budget is None
        return BERTAttackLi2020.build(model_wrapper)
    if params.attack == "clare":
        assert params.query_budget is None
        return CLARE2020.build(model_wrapper)
    # Fail loudly instead of hitting an unbound `attack` name later on.
    raise ValueError(
        f"Unknown attack {params.attack!r}; expected one of "
        "'custom', 'bae', 'bert-attack', 'clare'")


def main(params):
    """Attack one chunk of the training split and report statistics.

    Loads ``params.dataset`` and a HuggingFace sequence classifier, optionally
    restored from ``params.ckpt`` and/or given "radioactive" classifier weights.
    It then attacks rows ``[chunk_id * chunk_size, (chunk_id + 1) * chunk_size)``
    and prints per-sample diagnostics (BERTScore, logits, USE similarity)
    followed by a ``__logs:`` JSON summary. If ``params.target_dir`` is set,
    the perturbed texts are also written to ``<target_dir>/<chunk_id>.csv``.

    Raises:
        ValueError: if ``params.attack`` names an unsupported attack.
    """
    # Loading data
    dataset, num_labels = load_data(params)
    dataset = dataset["train"]
    text_key = 'content' if params.dataset == "dbpedia14" else 'text'
    print(f"Loaded dataset {params.dataset}, that has {len(dataset)} rows")

    # Load model and tokenizer from HuggingFace
    model_class = transformers.AutoModelForSequenceClassification
    model = model_class.from_pretrained(params.model,
                                        num_labels=num_labels).cuda()

    if params.ckpt is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(params.ckpt)
        model.load_state_dict(state_dict)
    tokenizer = textattack.models.tokenizers.AutoTokenizer(params.model)
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
        model, tokenizer, batch_size=params.batch_size)

    # Replace the classifier weights with fixed (seed 0) random unit-norm
    # directions and a zero bias.
    if params.radioactive:
        torch.manual_seed(0)
        # Use the classifier's actual input width instead of assuming 768.
        hidden_size = model.classifier.weight.shape[1]
        radioactive_directions = torch.randn(num_labels, hidden_size)
        radioactive_directions /= torch.norm(radioactive_directions,
                                             dim=1,
                                             keepdim=True)
        print(radioactive_directions)
        model.classifier.weight.data = radioactive_directions.cuda()
        model.classifier.bias.data = torch.zeros(num_labels).cuda()

    start_index = params.chunk_id * params.chunk_size
    end_index = start_index + params.chunk_size

    # The output file is opened before the (long) attack so that a bad
    # target_dir fails fast. The file object is kept apart from the csv writer
    # because csv writers have no close() method.
    csv_file = None
    writer = None
    if params.target_dir is not None:
        target_file = join(params.target_dir, f"{params.chunk_id}.csv")
        # The csv module requires files it writes to be opened with newline="".
        csv_file = open(target_file, "w", newline="", encoding="utf-8")
        writer = csv.writer(csv_file,
                            delimiter=',',
                            quotechar='"',
                            quoting=csv.QUOTE_NONNUMERIC)

    try:
        attack = _build_attack(params, model_wrapper, dataset, start_index,
                               end_index)

        # Launching attack
        begin_time = time.time()
        # The reference output for each sample is whatever the goal function
        # returns for the clean text (presumably the model's own prediction).
        samples = [
            (dataset[i][text_key],
             attack.goal_function.get_output(AttackedText(dataset[i][text_key])))
            for i in range(start_index, end_index)
        ]
        results = list(attack.attack_dataset(samples))

        # Storing attacked text
        bert_scorer = BERTScorer(model_type="bert-base-uncased", idf=False)

        n_success = 0
        similarities = []
        queries = []
        use = USE()

        for i_data, result in enumerate(results, start=start_index):
            print()
            print(50 * "*")
            print()
            row = dataset[i_data]
            text = row[text_key]
            ptext = result.perturbed_text()
            if writer is not None:
                # Labels are written shifted by one.
                if params.dataset == 'dbpedia14':
                    writer.writerow([row['label'] + 1, row['title'], ptext])
                else:
                    writer.writerow([row['label'] + 1, ptext])

            print("True label ", row['label'])
            print(f"CLEAN TEXT\n {text}")
            print(f"ADV TEXT\n {ptext}")

            if type(result) not in [SuccessfulAttackResult, FailedAttackResult]:
                print("WARNING: Attack neither succeeded nor failed...")
            print(result.goal_function_result_str())
            precision, recall, f1 = [
                r.item() for r in bert_scorer.score([ptext], [text])
            ]
            print(
                f"Bert scores: precision {precision:.2f}, recall: {recall:.2f}, f1: {f1:.2f}"
            )
            initial_logits = model_wrapper([text])
            final_logits = model_wrapper([ptext])
            print("Initial logits", initial_logits)
            print("Final logits", final_logits)
            print("Logits difference", final_logits - initial_logits)

            # Statistics
            if type(result) is SuccessfulAttackResult:
                n_success += 1
            queries.append(result.num_queries)
            similarities.append(use.compute_sim([text], [ptext]))
    finally:
        if csv_file is not None:
            csv_file.close()

    print(f"Processing all samples took {time.time() - begin_time:.2f}")
    n_results = len(results)
    print(f"Total success: {n_success}/{n_results}")
    # An empty chunk (chunk_size == 0) reports null averages instead of
    # crashing with ZeroDivisionError.
    logs = {
        "success_rate": n_success / n_results if n_results else None,
        "avg_queries": sum(queries) / n_results if n_results else None,
        "queries": queries,
        "avg_similarity": sum(similarities) / n_results if n_results else None,
        "similarities": similarities,
    }
    print("__logs:" + json.dumps(logs))