def get_acc_img(self):
    """Narrow the remaining samples to those the model still gets right unperturbed.

    Runs ``NoAttack`` on ``self.adv_data[self.remain_idx]`` and scores the
    output with ``misclassification_criterion(..., detector=True)``. A sample
    counts as a failure when it is flagged as adversarial or its returned mask
    entry is False. For each failure:

    * its entry in ``self.adv_data`` is overwritten with the attack output;
    * unless ``self.is_random`` is set, it is dropped from ``self.remain_idx``.

    Side effects: updates ``self.adv_data``, possibly ``self.remain_idx``,
    ``self.timer`` (accumulated wall-clock seconds) and ``self.n_clean``.

    Returns:
        tuple: cloned ``(images, labels)`` of the samples that did not fail.
    """
    tic = time.time()
    batch = self.adv_data[self.remain_idx]
    targets = self.labels[self.remain_idx]
    local_idx = np.arange(batch.shape[0])

    # NOTE(review): NoAttack presumably returns its input unchanged; confirm.
    perturbed = NoAttack()(self.g_net_model, batch[local_idx],
                           targets[local_idx], self.epsilon)
    flagged, valid = misclassification_criterion(self.f_net,
                                                 perturbed,
                                                 targets,
                                                 detector=True)
    # Failure = misclassified OR excluded by the returned mask.
    failed = flagged | ~valid

    drop = [k for k in range(failed.shape[0]) if failed[k]]
    # Record the failing outputs before remain_idx is (possibly) shrunk below.
    self.adv_data[self.remain_idx[drop]] = perturbed[drop]
    if not self.is_random:
        # In random mode, misclassified samples stay in the pool so they are attacked too.
        self.remain_idx = np.delete(self.remain_idx, drop, axis=0)
    keep = np.delete(local_idx, drop, axis=0)

    result = (batch[keep].clone(), targets[keep].clone())
    self.timer += (time.time() - tic)
    self.n_clean = len(result[0])
    return result
def get_robust_img(attack_list, f_net, g_net_model, data, epsilon, norm):
    """Return the samples of ``data`` that survive every attack in ``attack_list``.

    Attacks are applied in sequence. After each one, samples that
    ``misclassification_criterion`` reports as fooled are discarded, so later
    attacks only see the samples still standing.

    Args:
        attack_list: a single attack or a list of attacks. Each is called as
            ``attack(g_net_model, images, labels, epsilon)`` and its result is
            unpacked as ``(adv_img, _)``.
        f_net: model handed to ``misclassification_criterion`` to judge the
            attacked images.
        g_net_model: model handed to each attack.
        data: sequence whose items 0 and 1 are the images and labels; both
            must support indexing with an integer numpy array.
        epsilon: perturbation budget forwarded to each attack.
        norm: not used by this function; kept for interface compatibility.

    Returns:
        tuple: ``(images, labels)`` restricted to the surviving samples.
    """
    if not isinstance(attack_list, list):
        attack_list = [attack_list]
    images = data[0]
    labels = data[1]
    remain_idx = np.arange(images.shape[0])
    for attack in attack_list:
        if len(remain_idx) == 0:
            # Nothing left to attack; avoid calling attacks on an empty batch.
            break
        batch_images = images[remain_idx]
        batch_labels = labels[remain_idx]
        # NOTE(review): the methods in this file use the attack's return value
        # directly as the image, while this unpacks a 2-tuple; confirm which
        # convention the attacks follow.
        adv_img, _ = attack(g_net_model, batch_images, batch_labels, epsilon)
        # Judge against the labels of the samples actually attacked. Passing the
        # full ``labels`` misaligns them once an earlier attack removed samples.
        is_adv = misclassification_criterion(f_net, adv_img, batch_labels)
        fooled = [i for i in range(is_adv.shape[0]) if is_adv[i]]
        remain_idx = np.delete(remain_idx, fooled, axis=0)
    return images[remain_idx], labels[remain_idx]
 def get_robust_img(self, attack_list):
     """Keep only the remaining samples that resist every attack in ``attack_list``.

     Starting from ``self.raw_data[self.remain_idx]``, each attack is applied
     in turn to the samples still standing. The attack output is written to
     ``self.adv_data`` for every sample attacked in that round. Samples that
     ``misclassification_criterion`` flags are then dropped from
     ``self.remain_idx``, so later attacks skip them.

     Args:
         attack_list: a single attack or an iterable of attacks. Each is
             called as ``attack(self.g_net_model, images, labels, self.epsilon)``
             and its return value is used directly as the attacked batch.
             ``self.acc_attack`` is extended with all of them.

     Returns:
         tuple: cloned ``(images, labels)`` from ``self.raw_data`` /
         ``self.labels`` for the samples that survived all attacks.

     Side effects: updates ``self.acc_attack``, ``self.adv_data``,
     ``self.remain_idx`` and ``self.timer``.
     """
     start_time = time.time()
     # Accept a single attack as well as any iterable of attacks.
     try:
         attack_list = list(attack_list)
     except TypeError:
         attack_list = [attack_list]
     self.acc_attack += attack_list
     images = self.raw_data[self.remain_idx]
     labels = self.labels[self.remain_idx]
     relative_idx = np.arange(images.shape[0])
     for attack in attack_list:
         if len(relative_idx) == 0:
             # Nothing left to attack; avoid calling attacks on an empty batch.
             break
         batch_labels = labels[relative_idx]
         adv_img = attack(self.g_net_model, images[relative_idx], batch_labels, self.epsilon)
         # Judge against the labels of the samples actually attacked. The full
         # ``labels`` array is misaligned once an earlier attack removed samples.
         is_adv = misclassification_criterion(self.f_net, adv_img, batch_labels)
         remove_list = [i for i in range(is_adv.shape[0]) if is_adv[i]]
         # self.remain_idx and relative_idx shrink in lockstep, so both stay
         # aligned row-for-row with adv_img.
         self.adv_data[self.remain_idx] = adv_img
         self.remain_idx = np.delete(self.remain_idx, remove_list, axis=0)
         relative_idx = np.delete(relative_idx, remove_list, axis=0)
     newdata = (images[relative_idx].clone(), labels[relative_idx].clone())
     self.timer += (time.time() - start_time)
     return newdata