示例#1
0
文件: safer.py 项目: zjiehang/RanMASK
 def perturb_batch(self,
                   instances: List[InputInstance]) -> List[InputInstance]:
     """Expand every instance into the variants produced by the augmentor.

     For each instance, its perturbable sentence is lower-cased and passed to
     ``self.augmentor.get_perturbed_batch``. Every sentence that call yields is
     wrapped into a new ``InputInstance`` with
     ``InputInstance.from_instance_and_perturb_sentence``.

     Args:
         instances: the instances to perturb.

     Returns:
         A flat list of all generated instances. They are grouped by source
         instance, in input order, and each group follows the augmentor's
         output order.
     """
     return [
         InputInstance.from_instance_and_perturb_sentence(instance, sentence)
         for instance in instances
         # The augmentor is called once per source instance.
         for sentence in self.augmentor.get_perturbed_batch(
             instance.perturbable_sentence().lower())
     ]
示例#2
0
    def augmentation(self, args: ClassifierArgs, **kwargs):
        """Augment the training set with adversarial examples and save it.

        Steps:
          1. Load the ``'best'`` checkpoint from ``args.saving_dir`` and put
             the model in eval mode.
          2. Draw two random samples without replacement, each up to half the
             training set:
             - one from the whole training set, to be written out unchanged;
             - one only from instances with ``length() >= 3``, to be attacked.
          3. Run the attacker built by ``self.build_attacker(args)`` over the
             second sample.
          4. Convert each attack result into an ``InputInstance``.
          5. Write the kept instances plus the converted ones with
             ``self.data_reader.saving_instances`` under the name
             ``'aug_<attack_method>'`` in ``args.dataset_dir``.
          6. Print an attack summary to stdout.

        Args:
            args: run configuration. This method reads ``saving_dir``,
                ``dataset_name``, ``dataset_dir`` and ``attack_method``, and
                calls ``build_saving_file_name``.
            **kwargs: accepted for interface compatibility; unused.
        """
        self.loading_model_from_file(
            args.saving_dir, args.build_saving_file_name(description='best'))
        self.model.eval()

        train_instances, _ = self.build_data_loader(args,
                                                    'train',
                                                    tokenizer=False)
        train_dataset_len = len(train_instances.data)
        print('Training Set: {} sentences. '.format(train_dataset_len))
        half_size = int(train_dataset_len * 0.5)

        # Only instances of length >= 3 are candidates for the attack.
        train_instances_deleted = [
            instance for instance in train_instances.data
            if instance.length() >= 3
        ]
        # np.random.choice(..., replace=False) raises ValueError when asked for
        # more items than the population holds. If fewer than half of the
        # training instances pass the length filter, attack all of them.
        aug_size = min(half_size, len(train_instances_deleted))
        dataset_to_aug = np.random.choice(train_instances_deleted,
                                          size=(aug_size, ),
                                          replace=False)

        dataset_to_write = np.random.choice(train_instances.data,
                                            size=(half_size, ),
                                            replace=False).tolist()
        attacker = self.build_attacker(args)
        attacker_log_manager = AttackLogManager()
        dataset = CustomTextAttackDataset.from_instances(
            args.dataset_name, dataset_to_aug, self.data_reader.get_labels())
        results_iterable = attacker.attack_dataset(dataset)
        aug_instances = []
        for result, instance in tqdm(zip(results_iterable, dataset_to_aug),
                                     total=len(dataset)):
            # Record every result so log_summary() below has data to report.
            # No logger is attached yet, so this only stores the result.
            attacker_log_manager.log_result(result)
            try:
                adv_sentence = result.perturbed_text()
                aug_instances.append(
                    InputInstance.from_instance_and_perturb_sentence(
                        instance, adv_sentence))
            except Exception:
                # Best effort: a result that cannot be converted is dropped
                # instead of aborting the whole augmentation run.
                print('one error happend, delete one instance')

        dataset_to_write.extend(aug_instances)
        self.data_reader.saving_instances(dataset_to_write, args.dataset_dir,
                                          'aug_{}'.format(args.attack_method))
        print('Writing {} Sentence. '.format(len(dataset_to_write)))
        attacker_log_manager.enable_stdout()
        attacker_log_manager.log_summary()
示例#3
0
def mask_instance(instance: InputInstance,
                  rate: float,
                  token: str,
                  nums: int = 1,
                  return_indexes: bool = False,
                  forbidden_indexes: List[int] = None,
                  random_probs: List[float] = None) -> List[InputInstance]:
    """Create masked copies of ``instance`` using ``mask_sentence``.

    The instance's perturbable sentence is passed to ``mask_sentence``. All
    other arguments are forwarded unchanged, in the same order. Each masked
    sentence it produces is wrapped into a new ``InputInstance`` with
    ``InputInstance.from_instance_and_perturb_sentence``.

    Args:
        instance: the instance to mask.
        rate, token, nums, forbidden_indexes, random_probs: forwarded
            unchanged to ``mask_sentence``. See that function for their
            meaning.
        return_indexes: if truthy, ``mask_sentence``'s result is treated as
            a pair whose first item is the list of masked sentences.

    Returns:
        The list of masked instances. When ``return_indexes`` is truthy, a
        2-tuple ``(instances, extra)`` is returned instead. ``extra`` is the
        second item of ``mask_sentence``'s result; it is presumably the
        masked positions, so confirm against ``mask_sentence``.
        NOTE(review): the ``List[InputInstance]`` return annotation does not
        describe the tuple case.
    """
    masked = mask_sentence(instance.perturbable_sentence(), rate, token, nums,
                           return_indexes, forbidden_indexes, random_probs)
    sentences = masked[0] if return_indexes else masked
    new_instances = [
        InputInstance.from_instance_and_perturb_sentence(instance, masked_sent)
        for masked_sent in sentences
    ]
    if not return_indexes:
        return new_instances
    return new_instances, masked[1]