def perturb_batch(self, instances: List[InputInstance]) -> List[InputInstance]:
    """Expand every instance into its augmentor-generated perturbed variants.

    For each input instance the perturbable sentence is lower-cased and handed
    to ``self.augmentor.get_perturbed_batch``; every sentence that comes back
    is wrapped into a new ``InputInstance`` derived from the source instance.

    Args:
        instances: Instances to perturb.

    Returns:
        A flat list with the perturbed instances of all inputs, grouped in
        the order of ``instances``.
    """
    perturbed_instances = []
    for source in instances:
        # The augmentor is always fed the lower-cased sentence.
        lowered_sentence = source.perturbable_sentence().lower()
        candidates = self.augmentor.get_perturbed_batch(lowered_sentence)
        perturbed_instances += [
            InputInstance.from_instance_and_perturb_sentence(source, candidate)
            for candidate in candidates
        ]
    return perturbed_instances
def augmentation(self, args: ClassifierArgs, **kwargs):
    """Create an adversarially augmented training set and save it.

    Steps:
      1. Load the 'best' checkpoint and switch the model to eval mode.
      2. Load the training instances and keep those whose ``length()`` is
         at least 3 as attack candidates.
      3. Sample (without replacement) up to half the training-set size from
         the candidates to attack, and independently sample half of the full
         training set to keep unchanged.
      4. Attack the sampled candidates and turn each perturbed text into a
         new instance; instances whose result cannot be converted are dropped.
      5. Write the kept + augmented instances through ``self.data_reader``
         under the name ``aug_<attack_method>``.

    Args:
        args: Run configuration (saving dir, dataset name/dir, attack method).
        **kwargs: Unused; accepted for interface compatibility.
    """
    self.loading_model_from_file(
        args.saving_dir, args.build_saving_file_name(description='best'))
    self.model.eval()

    train_instances, _ = self.build_data_loader(args, 'train', tokenizer=False)
    train_dataset_len = len(train_instances.data)
    print('Training Set: {} sentences. '.format(train_dataset_len))

    # delete instance whose length is smaller than 3
    train_instances_deleted = [
        instance for instance in train_instances.data
        if instance.length() >= 3
    ]

    half_size = int(train_dataset_len * 0.5)
    # Sampling without replacement cannot draw more items than the population
    # holds; the filtered list may be smaller than half of the full set, so
    # clamp the request instead of letting np.random.choice raise ValueError.
    aug_size = min(half_size, len(train_instances_deleted))
    dataset_to_aug = np.random.choice(train_instances_deleted,
                                      size=(aug_size, ),
                                      replace=False)
    # Drawn independently of dataset_to_aug, so the two samples may overlap.
    dataset_to_write = np.random.choice(train_instances.data,
                                        size=(half_size, ),
                                        replace=False).tolist()

    attacker = self.build_attacker(args)
    attacker_log_manager = AttackLogManager()
    dataset = CustomTextAttackDataset.from_instances(
        args.dataset_name, dataset_to_aug, self.data_reader.get_labels())
    results_iterable = attacker.attack_dataset(dataset)

    aug_instances = []
    for result, instance in tqdm(zip(results_iterable, dataset_to_aug),
                                 total=len(dataset)):
        try:
            adv_sentence = result.perturbed_text()
            aug_instances.append(
                InputInstance.from_instance_and_perturb_sentence(
                    instance, adv_sentence))
        except Exception:
            # Best effort: skip this instance, but let KeyboardInterrupt and
            # SystemExit propagate so the run can still be aborted.
            print('one error happend, delete one instance')

    dataset_to_write.extend(aug_instances)
    self.data_reader.saving_instances(dataset_to_write, args.dataset_dir,
                                      'aug_{}'.format(args.attack_method))
    print('Writing {} Sentence. '.format(len(dataset_to_write)))

    # NOTE(review): no attack result is ever handed to attacker_log_manager
    # in this method, so the summary below presumably reports nothing --
    # confirm whether results were meant to be logged inside the loop.
    attacker_log_manager.enable_stdout()
    attacker_log_manager.log_summary()
def mask_instance(instance: InputInstance,
                  rate: float,
                  token: str,
                  nums: int = 1,
                  return_indexes: bool = False,
                  forbidden_indexes: List[int] = None,
                  random_probs: List[float] = None) -> List[InputInstance]:
    """Build masked copies of ``instance`` from its perturbable sentence.

    The sentence is passed to ``mask_sentence`` together with all masking
    options, and every masked sentence it yields is wrapped into a new
    ``InputInstance`` derived from ``instance``.

    Args:
        instance: Source instance whose perturbable sentence is masked.
        rate: Forwarded to ``mask_sentence``.
        token: Forwarded to ``mask_sentence``.
        nums: Forwarded to ``mask_sentence``.
        return_indexes: When true, the second element of the
            ``mask_sentence`` result is returned alongside the instances.
        forbidden_indexes: Forwarded to ``mask_sentence``.
        random_probs: Forwarded to ``mask_sentence``.

    Returns:
        The list of masked instances, or the tuple
        ``(masked_instances, mask_sentence_result[1])`` when
        ``return_indexes`` is true.
    """
    masked = mask_sentence(instance.perturbable_sentence(), rate, token, nums,
                           return_indexes, forbidden_indexes, random_probs)
    # With return_indexes the helper's result is indexable: sentences first.
    masked_sentences = masked[0] if return_indexes else masked

    masked_instances = []
    for text in masked_sentences:
        masked_instances.append(
            InputInstance.from_instance_and_perturb_sentence(instance, text))

    if not return_indexes:
        return masked_instances
    return masked_instances, masked[1]