def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=True,
             attacks_to_run=('apgd-ce', 'apgd-dlr', 'fab', 'square'),
             plus=False, is_tf_model=False, device='cuda', log_path=None):
    """Configure the attack ensemble and build the individual attack objects.

    Args:
        model: classifier. PyTorch models are passed directly to the attacks;
            for TF models (``is_tf_model=True``) SquareAttack receives
            ``model.predict``.
        norm: threat model, 'Linf' or 'L2'.
        eps: perturbation budget shared by every attack.
        seed: seed stored on the instance and passed to every attack.
        verbose: stored on the instance.
        attacks_to_run: attack names. The sequence is copied, so neither the
            default nor the caller's object is ever modified.
        plus: append 'apgd-t' and 'fab-t' to the attack list (if missing).
        is_tf_model: select the TensorFlow (``*_tf``) attack implementations.
        device: device forwarded to every attack.
        log_path: forwarded to ``utils.Logger``.
    """
    self.model = model
    self.norm = norm
    assert norm in ['Linf', 'L2']
    self.epsilon = eps
    self.seed = seed
    self.verbose = verbose
    # Work on a private copy: extending the (shared) default list in place
    # would leak the plus attacks into every later instance.
    attacks = list(attacks_to_run)
    if plus:
        for name in ('apgd-t', 'fab-t'):
            if name not in attacks:
                attacks.append(name)
    self.attacks_to_run = attacks
    self.plus = plus
    self.is_tf_model = is_tf_model
    self.device = device
    self.logger = utils.Logger(log_path)

    # Only the backend modules and the model handle given to Square differ
    # between the PyTorch and TensorFlow paths.
    if not self.is_tf_model:
        from autopgd_pt import APGDAttack, APGDAttack_targeted
        from fab_pt import FABAttack
        square_model = self.model
    else:
        from autopgd_tf import APGDAttack, APGDAttack_targeted
        from fab_tf import FABAttack
        square_model = self.model.predict
    from square import SquareAttack

    self.apgd = APGDAttack(self.model, n_restarts=5, n_iter=100, verbose=False,
                           eps=self.epsilon, norm=self.norm, eot_iter=1, rho=.75,
                           seed=self.seed, device=self.device)
    self.fab = FABAttack(self.model, n_restarts=5, n_iter=100, eps=self.epsilon,
                         seed=self.seed, norm=self.norm, verbose=False,
                         device=self.device)
    self.square = SquareAttack(square_model, p_init=.8, n_queries=5000,
                               eps=self.epsilon, norm=self.norm, early_stop=True,
                               n_restarts=1, seed=self.seed, verbose=False,
                               device=self.device)
    self.apgd_targeted = APGDAttack_targeted(self.model, n_restarts=1, n_iter=100,
                                             verbose=False, eps=self.epsilon,
                                             norm=self.norm, eot_iter=1, rho=.75,
                                             seed=self.seed, device=self.device)
class AutoAttack():
    """Sequential ensemble evaluation (APGD, FAB, Square) of a PyTorch or TF model.

    ``run_standard_evaluation`` applies the attacks named in ``attacks_to_run``
    in order. Each attack only sees the points that are still classified
    correctly, and the first successful adversarial example found for a point
    is the one kept.
    """

    def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=True,
                 attacks_to_run=('apgd-ce', 'apgd-dlr', 'fab', 'square'),
                 plus=False, is_tf_model=False, device='cuda', log_path=None):
        """Configure the attack ensemble and build the individual attack objects.

        Args:
            model: classifier; called as ``model(x)`` for PyTorch models and
                via ``model.predict(x)`` when ``is_tf_model`` is True.
            norm: threat model, 'Linf' or 'L2'.
            eps: perturbation budget shared by every attack.
            seed: attack seed; when None a time-based seed is drawn per attack.
            verbose: log progress and results.
            attacks_to_run: attack names, run in order. The sequence is
                copied, so neither the default nor the caller's object is
                ever modified.
            plus: also run 'apgd-t' and 'fab-t' (appended if missing).
            is_tf_model: select the TensorFlow (``*_tf``) attack implementations.
            device: device that batches are moved to before attacking.
            log_path: forwarded to ``utils.Logger``.
        """
        self.model = model
        self.norm = norm
        assert norm in ['Linf', 'L2']
        self.epsilon = eps
        self.seed = seed
        self.verbose = verbose
        # Private copy: appending the plus attacks to the shared default list
        # in place would leak them into every later instance.
        self.attacks_to_run = list(attacks_to_run)
        self.plus = plus
        if plus:
            self._add_plus_attacks()
        self.is_tf_model = is_tf_model
        self.device = device
        self.logger = utils.Logger(log_path)

        # Only the backend modules and the model handle given to Square differ
        # between the PyTorch and TensorFlow paths.
        if not self.is_tf_model:
            from autopgd_pt import APGDAttack, APGDAttack_targeted
            from fab_pt import FABAttack
            square_model = self.model
        else:
            from autopgd_tf import APGDAttack, APGDAttack_targeted
            from fab_tf import FABAttack
            square_model = self.model.predict
        from square import SquareAttack

        self.apgd = APGDAttack(self.model, n_restarts=5, n_iter=100, verbose=False,
                               eps=self.epsilon, norm=self.norm, eot_iter=1, rho=.75,
                               seed=self.seed, device=self.device)
        self.fab = FABAttack(self.model, n_restarts=5, n_iter=100, eps=self.epsilon,
                             seed=self.seed, norm=self.norm, verbose=False,
                             device=self.device)
        self.square = SquareAttack(square_model, p_init=.8, n_queries=5000,
                                   eps=self.epsilon, norm=self.norm, early_stop=True,
                                   n_restarts=1, seed=self.seed, verbose=False,
                                   device=self.device)
        self.apgd_targeted = APGDAttack_targeted(self.model, n_restarts=1, n_iter=100,
                                                 verbose=False, eps=self.epsilon,
                                                 norm=self.norm, eot_iter=1, rho=.75,
                                                 seed=self.seed, device=self.device)

    def _add_plus_attacks(self):
        """Append 'apgd-t' and 'fab-t' to ``attacks_to_run`` unless already present."""
        for name in ('apgd-t', 'fab-t'):
            if name not in self.attacks_to_run:
                self.attacks_to_run.append(name)

    def get_logits(self, x):
        """Return the model output for ``x`` (``model.predict`` for TF models)."""
        if not self.is_tf_model:
            return self.model(x)
        else:
            return self.model.predict(x)

    def get_seed(self):
        """Return the fixed seed, or the current time when no seed was given."""
        return time.time() if self.seed is None else self.seed

    def _run_single_attack(self, attack, x, y):
        """Return adversarial examples for ``(x, y)`` from the attack named ``attack``.

        Raises:
            ValueError: for an unknown attack name.
        """
        if attack == 'apgd-ce':
            # apgd on cross-entropy loss
            self.apgd.loss = 'ce'
            self.apgd.seed = self.get_seed()
            _, adv_curr = self.apgd.perturb(x, y, cheap=True)
            return adv_curr
        if attack == 'apgd-dlr':
            # apgd on dlr loss
            self.apgd.loss = 'dlr'
            self.apgd.seed = self.get_seed()
            _, adv_curr = self.apgd.perturb(x, y, cheap=True)
            return adv_curr
        if attack == 'fab':
            self.fab.targeted = False
            self.fab.seed = self.get_seed()
            return self.fab.perturb(x, y)
        if attack == 'square':
            self.square.seed = self.get_seed()
            _, adv_curr = self.square.perturb(x, y)
            return adv_curr
        if attack == 'apgd-t':
            self.apgd_targeted.seed = self.get_seed()
            _, adv_curr = self.apgd_targeted.perturb(x, y, cheap=True)
            return adv_curr
        if attack == 'fab-t':
            # NOTE: this also leaves self.fab targeted with a single restart.
            self.fab.targeted = True
            self.fab.n_restarts = 1
            self.fab.seed = self.get_seed()
            return self.fab.perturb(x, y)
        raise ValueError('Attack not supported')

    def run_standard_evaluation(self, x_orig, y_orig, bs=250):
        """Run every attack in ``attacks_to_run`` and return the adversarial batch.

        Args:
            x_orig: inputs; the single-point guard below assumes 4-d batches.
            y_orig: integer labels, one per row of ``x_orig``.
            bs: batch size used for model and attack calls.

        Returns:
            A copy of ``x_orig`` in which every point fooled by some attack is
            replaced by that attack's adversarial example.
        """
        # Honour plus=True even if attacks_to_run was changed after __init__.
        if self.plus:
            self._add_plus_attacks()

        with torch.no_grad():
            # Clean accuracy: a point starts out robust iff it is classified correctly.
            n_batches = int(np.ceil(x_orig.shape[0] / bs))
            robust_flags = torch.zeros(x_orig.shape[0], dtype=torch.bool,
                                       device=x_orig.device)
            for batch_idx in range(n_batches):
                start_idx = batch_idx * bs
                end_idx = min((batch_idx + 1) * bs, x_orig.shape[0])
                x = x_orig[start_idx:end_idx, :].clone().to(self.device)
                y = y_orig[start_idx:end_idx].clone().to(self.device)
                output = self.get_logits(x)
                correct_batch = y.eq(output.max(dim=1)[1])
                robust_flags[start_idx:end_idx] = correct_batch.detach().to(
                    robust_flags.device)

            robust_accuracy = torch.sum(robust_flags).item() / x_orig.shape[0]
            if self.verbose:
                self.logger.log('initial accuracy: {:.2%}'.format(robust_accuracy))

            x_adv = x_orig.clone().detach()
            startt = time.time()
            for attack in self.attacks_to_run:
                # item() is important: integer tensor division would floor
                num_robust = torch.sum(robust_flags).item()
                if num_robust == 0:
                    break
                n_batches = int(np.ceil(num_robust / bs))

                robust_lin_idcs = torch.nonzero(robust_flags, as_tuple=False)
                if num_robust > 1:
                    robust_lin_idcs.squeeze_()

                for batch_idx in range(n_batches):
                    start_idx = batch_idx * bs
                    end_idx = min((batch_idx + 1) * bs, num_robust)

                    batch_datapoint_idcs = robust_lin_idcs[start_idx:end_idx]
                    if len(batch_datapoint_idcs.shape) > 1:
                        batch_datapoint_idcs.squeeze_(-1)
                    x = x_orig[batch_datapoint_idcs, :].clone().to(self.device)
                    y = y_orig[batch_datapoint_idcs].clone().to(self.device)

                    # keep x 4-d even when a single datapoint is left
                    if len(x.shape) == 3:
                        x.unsqueeze_(dim=0)

                    adv_curr = self._run_single_attack(attack, x, y)

                    output = self.get_logits(adv_curr)
                    false_batch = ~y.eq(output.max(dim=1)[1]).to(robust_flags.device)
                    non_robust_lin_idcs = batch_datapoint_idcs[false_batch]
                    robust_flags[non_robust_lin_idcs] = False

                    x_adv[non_robust_lin_idcs] = adv_curr[false_batch].detach().to(
                        x_adv.device)

                    if self.verbose:
                        num_non_robust_batch = torch.sum(false_batch)
                        self.logger.log('{} - {}/{} - {} out of {} successfully perturbed'.format(
                            attack, batch_idx + 1, n_batches, num_non_robust_batch, x.shape[0]))

                robust_accuracy = torch.sum(robust_flags).item() / x_orig.shape[0]
                if self.verbose:
                    print('robust accuracy after {}: {:.2%} (total time {:.1f} s)'.format(
                        attack.upper(), robust_accuracy, time.time() - startt))

            # final check
            if self.verbose:
                if self.norm == 'Linf':
                    res = (x_adv - x_orig).abs().view(x_orig.shape[0], -1).max(1)[0]
                elif self.norm == 'L2':
                    res = ((x_adv - x_orig) ** 2).view(x_orig.shape[0], -1).sum(-1).sqrt()
                self.logger.log('max {} perturbation: {:.5f}, nan in tensor: {}, max: {:.5f}, min: {:.5f}'.format(
                    self.norm, res.max(), (x_adv != x_adv).sum(), x_adv.max(), x_adv.min()))
                self.logger.log('robust accuracy: {:.2%}'.format(robust_accuracy))

        return x_adv

    def clean_accuracy(self, x_orig, y_orig, bs=250):
        """Return the fraction of ``x_orig`` the model labels as ``y_orig``.

        Every sample is counted, including those in a trailing partial batch.
        """
        n_examples = x_orig.shape[0]
        # ceil, not floor: floor division silently dropped the last partial
        # batch (and crashed when n_examples < bs).
        n_batches = int(np.ceil(n_examples / bs))
        n_correct = 0.
        with torch.no_grad():
            for counter in range(n_batches):
                start_idx = counter * bs
                end_idx = min((counter + 1) * bs, n_examples)
                x = x_orig[start_idx:end_idx].clone().to(self.device)
                y = y_orig[start_idx:end_idx].clone().to(self.device)
                output = self.get_logits(x)
                n_correct += (output.max(1)[1] == y).float().sum().item()

        accuracy = n_correct / n_examples
        if self.verbose:
            print('clean accuracy: {:.2%}'.format(accuracy))
        return accuracy

    def run_standard_evaluation_individual(self, x_orig, y_orig, bs=250):
        """Run each attack on its own (on the full input) and return all results.

        Returns:
            dict mapping attack name to the adversarial batch from
            ``run_standard_evaluation`` with only that attack enabled.

        ``attacks_to_run``, ``plus`` and ``verbose`` are overridden while the
        attacks run and restored afterwards, even if an attack raises.
        """
        # update attacks list if plus activated after initialization
        if self.plus:
            self._add_plus_attacks()

        l_attacks = list(self.attacks_to_run)
        adv = {}
        saved_state = (self.attacks_to_run, self.plus, self.verbose)
        verbose_indiv = self.verbose
        self.plus = False
        self.verbose = False
        try:
            for c in l_attacks:
                startt = time.time()
                self.attacks_to_run = [c]
                adv[c] = self.run_standard_evaluation(x_orig, y_orig, bs=bs)
                if verbose_indiv:
                    acc_indiv = self.clean_accuracy(adv[c], y_orig, bs=bs)
                    space = '\t \t' if c == 'fab' else '\t'
                    self.logger.log('robust accuracy by {} {} {:.2%} \t (time attack: {:.1f} s)'.format(
                        c.upper(), space, acc_indiv, time.time() - startt))
        finally:
            self.attacks_to_run, self.plus, self.verbose = saved_state

        return adv

    def cheap(self):
        """Switch to a cheaper configuration: single restarts, 1000 Square queries, no plus."""
        self.apgd.n_restarts = 1
        self.fab.n_restarts = 1
        self.apgd_targeted.n_restarts = 1
        self.square.n_queries = 1000
        self.plus = False
class Attack():
    """Attack-then-vote evaluation for a model that takes a noise input.

    ``run_standard_evaluation`` perturbs the data with one attack. It then
    labels every adversarial example by majority vote over noise draws
    (``model(x, noi=..., noi_sample=0)``) and can reject points whose winning
    label got too few votes.
    """

    def __init__(self, model, eot_iter, norm='Linf', eps=.3, restarts=5, seed=None,
                 verbose=True,
                 attacks_to_run=('apgd-ce', 'apgd-dlr', 'fab', 'square', 'MM'),
                 plus=False, is_tf_model=False, device='cuda'):
        """Configure the attacks.

        Args:
            model: classifier (PyTorch implementations of the attacks are used).
            eot_iter: gradient samples per step for the 'pgd' attack. The APGD
                attacks are built with ``eot_iter=1``.
            norm: threat model, 'Linf' or 'L2'.
            eps: perturbation budget.
            restarts: restarts for APGD and FAB.
            seed: attack seed; when None a time-based seed is used.
            verbose: print per-batch progress.
            attacks_to_run: attack names. The sequence is copied, so the
                default and the caller's object are never modified.
            plus: append 'apgd-t' and 'fab-t' to the attack list (if missing).
            is_tf_model: stored only.
            device: device that batches, noise and intermediate tensors use.
        """
        self.model = model
        self.norm = norm
        self.eot_iter = eot_iter
        assert norm in ['Linf', 'L2']
        self.epsilon = eps
        self.restarts = restarts
        self.seed = seed
        self.verbose = verbose
        # Private copy: extending the shared default list in place would leak
        # the plus attacks into every later instance.
        attacks = list(attacks_to_run)
        if plus:
            for name in ('apgd-t', 'fab-t'):
                if name not in attacks:
                    attacks.append(name)
        self.attacks_to_run = attacks
        self.plus = plus
        self.is_tf_model = is_tf_model
        self.device = device

        from autopgd_pt import APGDAttack, APGDAttack_targeted
        from fab_pt import FABAttack
        from square import SquareAttack
        self.apgd = APGDAttack(self.model, n_restarts=self.restarts, n_iter=100,
                               verbose=False, eps=self.epsilon, norm=self.norm,
                               eot_iter=1, rho=.75, seed=self.seed, device=self.device)
        self.fab = FABAttack(self.model, n_restarts=self.restarts, n_iter=100,
                             eps=self.epsilon, seed=self.seed, norm=self.norm,
                             verbose=False, device=self.device)
        self.square = SquareAttack(self.model, p_init=.8, n_queries=5000, eps=self.epsilon,
                                   norm=self.norm, n_restarts=1, seed=self.seed,
                                   verbose=False, device=self.device, resc_schedule=False)
        self.apgd_targeted = APGDAttack_targeted(self.model, n_restarts=1, n_iter=100,
                                                 verbose=False, eps=self.epsilon,
                                                 norm=self.norm, eot_iter=1, rho=.75,
                                                 seed=self.seed, device=self.device)

    def get_logits(self, x):
        """Return ``model(x)``."""
        return self.model(x)

    def _pgd_whitebox(self, X, y, epsilon=0.03137254, num_steps=100,
                      step_size=0.007843137, eot_iter=1):
        """L-inf PGD with a random start, using the sign of gradients summed over
        ``eot_iter`` cross-entropy backward passes per step.

        Returns the adversarial batch, projected onto the epsilon-ball around
        ``X`` and clipped to [0, 1] after every step.
        """
        device = X.device
        # The random start is drawn with the CPU generator (as before) and then
        # moved to X's device instead of a hard-coded .cuda().
        random_noise = torch.FloatTensor(*X.shape).uniform_(-epsilon, epsilon).to(device)
        X_pgd = Variable(X.data + random_noise, requires_grad=True)
        for _ in range(num_steps):
            summed_grad = torch.zeros(X_pgd.shape, device=device)
            for _ in range(eot_iter):
                X_pgd.grad = None
                with torch.enable_grad():
                    loss = nn.CrossEntropyLoss()(self.model(X_pgd), y)
                    loss.backward()
                summed_grad = summed_grad + X_pgd.grad.data
            X_step = X_pgd.data + step_size * summed_grad.sign()
            X_step = X.data + torch.clamp(X_step - X.data, -epsilon, epsilon)
            X_pgd = Variable(torch.clamp(X_step, 0, 1.0), requires_grad=True)
        return X_pgd

    def get_seed(self):
        """Return the fixed seed, or the current time when no seed was given."""
        return time.time() if self.seed is None else self.seed

    def _run_attack(self, attack, x, y, RR):
        """Return adversarial examples for ``(x, y)`` from the attack named ``attack``.

        'apgd-ce' / 'apgd-dlr' (the spelling used in ``attacks_to_run``) are
        accepted as aliases of 'apgd_ce' / 'apgd_dlr'.

        Raises:
            ValueError: for an unknown attack name.
        """
        bounds = np.array([[0, 1], [0, 1], [0, 1]])
        if attack in ('apgd_ce', 'apgd-ce'):
            print("running apgd_ce")
            self.apgd.loss = 'ce'
            self.apgd.seed = self.get_seed()
            _, adv_curr = self.apgd.perturb(x, y, cheap=True)
            return adv_curr
        if attack in ('apgd_dlr', 'apgd-dlr'):
            print("running apgd_dlr")
            self.apgd.loss = 'dlr'
            self.apgd.seed = self.get_seed()
            _, adv_curr = self.apgd.perturb(x, y, cheap=True)
            return adv_curr
        if attack == 'fab':
            print("running fab")
            self.fab.targeted = False
            self.fab.seed = self.get_seed()
            return self.fab.perturb(x, y)
        if attack == 'square':
            print("running square")
            self.square.seed = self.get_seed()
            return self.square.perturb(x, y)
        if attack == 'clean':
            print("running clean")
            return x
        if attack == 'Gama_pgd':
            print("running Gama_pgd")
            with torch.enable_grad():
                adv_curr = GAMA_PGD(self.model, x, y, eps=self.epsilon,
                                    eps_iter=2 * self.epsilon, bounds=bounds, steps=100,
                                    w_reg=50, lin=25, SCHED=[60, 85], drop=10)
            return adv_curr.detach().to(self.device)
        if attack == 'Gama_MT':
            print("running GAMA_MT")
            with torch.enable_grad():
                out_clean = self.model(x)
                # new_targ reads ranked column RR + 1, so at least RR + 2 classes
                # must be ranked. A fixed k of 6 raised IndexError for RR > 4.
                topk = torch.topk(out_clean, max(5 + 1, RR + 2))[1]
                adv_curr = GAMA_MT(self.model, x, y, eps=self.epsilon,
                                   eps_iter=2 * self.epsilon, bounds=bounds, steps=100,
                                   w_reg=50, lin=25, SCHED=[60, 85], drop=10,
                                   rr=RR + 1, new_targ=topk[range(len(y)), RR + 1])
            return adv_curr
        if attack == 'MM':
            print("running MM")
            with torch.enable_grad():
                return MT(self.model, x, y, eps=self.epsilon, eps_iter=2 * self.epsilon,
                          bounds=bounds, steps=100, w_reg=0, lin=0, SCHED=[50, 75],
                          drop=10, multi_tar=10)
        if attack == 'MT':
            print("running MT")
            with torch.enable_grad():
                return MT(self.model, x, y, eps=self.epsilon, eps_iter=2 * self.epsilon,
                          bounds=bounds, steps=100, w_reg=0, lin=0, SCHED=[50, 75],
                          drop=10, multi_tar=RR + 1)
        if attack == 'pgd':
            print("running pgd")
            return self._pgd_whitebox(x, y, eot_iter=self.eot_iter, epsilon=self.epsilon,
                                      step_size=self.epsilon / 4)
        raise ValueError('Attack not supported')

    def _majority_vote(self, adv_curr, noise_mat, n_draws=100):
        """Label ``adv_curr`` by majority vote over ``n_draws`` noise draws.

        Returns:
            (output, maxa): per-example voted label and its vote count, as
            float numpy arrays. On ties the label that first reached the
            winning count is kept.
        """
        n = len(adv_curr)
        maxa = np.zeros(n)
        output = np.zeros(n)
        votes = None
        for i in range(n_draws):
            answer = self.model(adv_curr, noi=noise_mat[i], noi_sample=0)
            if votes is None:
                # One counter per model output column (was a fixed 100).
                votes = np.zeros([n, answer.shape[1]])
            # One device->host copy per draw instead of one sync per element.
            preds = answer.max(dim=1)[1].cpu().numpy()
            for j, label in enumerate(preds):
                votes[j, label] += 1
                if votes[j, label] > maxa[j]:
                    maxa[j] = votes[j, label]
                    output[j] = label
        return output, maxa

    def run_standard_evaluation(self, x_orig, y_orig, lister_attack, threshold, bs=1000,
                                type_of_thresholder="original", RR=10, noise_mat=None):
        """Attack every point with ``lister_attack[0]`` and bucket the voted labels.

        Args:
            x_orig: inputs; every point is attacked.
            y_orig: integer labels for ``x_orig``.
            lister_attack: attack names; only the first one is run.
            threshold: with ``type_of_thresholder == "original"``, a point whose
                winning label received ``<= int(threshold)`` votes is rejected.
            bs: batch size.
            type_of_thresholder: only "original" fills the last three buckets.
            RR: rank parameter for the 'Gama_MT' and 'MT' attacks.
            noise_mat: noise draws indexed along dim 0 (the first 100 are used).
                Defaults to the array stored in './mat.npy'.

        Returns:
            Five numpy arrays of indices into ``x_orig``: voted-correct and
            voted-incorrect over all points, then rejected, accepted-correct
            and accepted-incorrect.
        """
        if noise_mat is None:
            noise_mat = np.load('./mat.npy')
        noise_mat = torch.as_tensor(noise_mat, dtype=torch.float32).to(self.device)

        attack = lister_attack[0]
        accepted_true_0, accepted_false_0 = [], []
        rejected, accepted_true, accepted_false = [], [], []

        n_examples = x_orig.shape[0]
        n_batches = int(np.ceil(n_examples / bs))
        with torch.no_grad():
            for batch_idx in range(n_batches):
                start_idx = batch_idx * bs
                end_idx = min(start_idx + bs, n_examples)
                x = x_orig[start_idx:end_idx].clone().to(self.device)
                y = y_orig[start_idx:end_idx].clone().to(self.device)

                adv_curr = self._run_attack(attack, x, y, RR)
                output, maxa = self._majority_vote(adv_curr, noise_mat)
                y_label = y.detach().cpu().long().numpy()

                for i in range(len(output)):
                    idx = start_idx + i
                    is_correct = output[i] == y_label[i]
                    # Correct / incorrect without any rejection.
                    if is_correct:
                        accepted_true_0.append(idx)
                    else:
                        accepted_false_0.append(idx)
                    # The original rejection rule: too few agreeing votes.
                    if type_of_thresholder == "original":
                        if maxa[i] <= int(threshold):
                            rejected.append(idx)
                        elif is_correct:
                            accepted_true.append(idx)
                        else:
                            accepted_false.append(idx)

                if self.verbose:
                    num_non_robust_batch = int(np.sum(output != y_label))
                    print('{} - {}/{} - {} out of {} successfully perturbed'.format(
                        attack, batch_idx + 1, n_batches, num_non_robust_batch, x.shape[0]))

        return (np.array(accepted_true_0), np.array(accepted_false_0),
                np.array(rejected), np.array(accepted_true), np.array(accepted_false))

    def cheap(self):
        """Switch to a cheaper configuration: single restarts, 1000 Square queries."""
        self.apgd.n_restarts = 1
        self.fab.n_restarts = 1
        self.apgd_targeted.n_restarts = 1
        self.square.n_queries = 1000
        self.square.resc_schedule = True
        self.plus = False
def __init__(self, model, eot_iter, norm='Linf', eps=.3, restarts=5, seed=None,
             verbose=True,
             attacks_to_run=('apgd-ce', 'apgd-dlr', 'fab', 'square', 'MM'),
             plus=False, is_tf_model=False, device='cuda'):
    """Configure the attacks (PyTorch implementations).

    Args:
        model: classifier passed to every attack.
        eot_iter: stored on the instance. The APGD attacks below are built
            with ``eot_iter=1``.
        norm: threat model, 'Linf' or 'L2'.
        eps: perturbation budget.
        restarts: restarts for APGD and FAB.
        seed: seed passed to every attack.
        verbose: stored on the instance.
        attacks_to_run: attack names. The sequence is copied, so the default
            and the caller's object are never modified.
        plus: append 'apgd-t' and 'fab-t' to the attack list (if missing).
        is_tf_model: stored only.
        device: device forwarded to every attack.
    """
    self.model = model
    self.norm = norm
    self.eot_iter = eot_iter
    assert norm in ['Linf', 'L2']
    self.epsilon = eps
    self.restarts = restarts
    self.seed = seed
    self.verbose = verbose
    # Private copy: extending the shared default list in place would leak the
    # plus attacks into every later instance.
    attacks = list(attacks_to_run)
    if plus:
        for name in ('apgd-t', 'fab-t'):
            if name not in attacks:
                attacks.append(name)
    self.attacks_to_run = attacks
    self.plus = plus
    self.is_tf_model = is_tf_model
    self.device = device

    from autopgd_pt import APGDAttack, APGDAttack_targeted
    from fab_pt import FABAttack
    from square import SquareAttack
    self.apgd = APGDAttack(self.model, n_restarts=self.restarts, n_iter=100,
                           verbose=False, eps=self.epsilon, norm=self.norm,
                           eot_iter=1, rho=.75, seed=self.seed, device=self.device)
    self.fab = FABAttack(self.model, n_restarts=self.restarts, n_iter=100,
                         eps=self.epsilon, seed=self.seed, norm=self.norm,
                         verbose=False, device=self.device)
    self.square = SquareAttack(self.model, p_init=.8, n_queries=5000, eps=self.epsilon,
                               norm=self.norm, n_restarts=1, seed=self.seed,
                               verbose=False, device=self.device, resc_schedule=False)
    self.apgd_targeted = APGDAttack_targeted(self.model, n_restarts=1, n_iter=100,
                                             verbose=False, eps=self.epsilon,
                                             norm=self.norm, eot_iter=1, rho=.75,
                                             seed=self.seed, device=self.device)
def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=True,
             attacks_to_run=(), version='standard', is_tf_model=False,
             device='cuda', log_path=None):
    """Configure the attack ensemble and build the individual attack objects.

    Args:
        model: classifier. PyTorch models are passed directly to the attacks;
            for TF models (``is_tf_model=True``) SquareAttack receives
            ``model.predict``.
        norm: threat model, 'Linf' or 'L2'.
        eps: perturbation budget shared by every attack.
        seed: seed passed to every attack.
        verbose: stored on the instance.
        attacks_to_run: attack names; copied into a fresh list so the shared
            default is never stored on (and later mutated through) an instance.
        version: 'standard', 'plus' or 'rand' are forwarded to
            ``self.set_version`` (defined outside this view); any other value
            leaves the configuration as given.
        is_tf_model: select the TensorFlow (``*_tf``) attack implementations.
        device: device forwarded to every attack.
        log_path: forwarded to ``Logger``.
    """
    self.model = model
    self.norm = norm
    assert norm in ['Linf', 'L2']
    self.epsilon = eps
    self.seed = seed
    self.verbose = verbose
    self.attacks_to_run = list(attacks_to_run)
    self.version = version
    self.is_tf_model = is_tf_model
    self.device = device
    self.logger = Logger(log_path)

    # Only the backend modules and the model handle given to Square differ
    # between the PyTorch and TensorFlow paths.
    if not self.is_tf_model:
        from autopgd_pt import APGDAttack, APGDAttack_targeted
        from fab_pt import FABAttack
        square_model = self.model
    else:
        from autopgd_tf import APGDAttack, APGDAttack_targeted
        from fab_tf import FABAttack
        square_model = self.model.predict
    from square import SquareAttack

    self.apgd = APGDAttack(self.model, n_restarts=5, n_iter=100, verbose=False,
                           eps=self.epsilon, norm=self.norm, eot_iter=1, rho=.75,
                           seed=self.seed, device=self.device)
    self.fab = FABAttack(self.model, n_restarts=5, n_iter=100, eps=self.epsilon,
                         seed=self.seed, norm=self.norm, verbose=False,
                         device=self.device)
    self.square = SquareAttack(square_model, p_init=.8, n_queries=5000,
                               eps=self.epsilon, norm=self.norm, n_restarts=1,
                               seed=self.seed, verbose=False, device=self.device,
                               resc_schedule=False)
    self.apgd_targeted = APGDAttack_targeted(self.model, n_restarts=1, n_iter=100,
                                             verbose=False, eps=self.epsilon,
                                             norm=self.norm, eot_iter=1, rho=.75,
                                             seed=self.seed, device=self.device)

    if version in ['standard', 'plus', 'rand']:
        self.set_version(version)