import time
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Repo-local helpers assumed importable from the project's utilities:
# AvgMeter, torch_accuracy, my_torch_accuracy, load_cuda_data, reduce_loss_dict,
# RFBlock, ResRepConfig, SPEED_TEST_SAMPLE_IGNORE_RATIO, config.


def eval_one_epoch(net, batch_generator, DEVICE=torch.device('cuda:0'),
                   AttackMethod=None):
    net.eval()
    pbar = tqdm(batch_generator)
    clean_accuracy = AvgMeter()
    adv_accuracy = AvgMeter()

    pbar.set_description('Evaluating')
    for (data, label) in pbar:
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        with torch.no_grad():
            pred = net(data)
            acc = torch_accuracy(pred, label, (1,))
            clean_accuracy.update(acc[0].item())

        if AttackMethod is not None:
            adv_inp = AttackMethod.attack(net, data, label)
            with torch.no_grad():
                pred = net(adv_inp)
                acc = torch_accuracy(pred, label, (1,))
                adv_accuracy.update(acc[0].item())

        pbar_dic = OrderedDict()
        pbar_dic['CleanAcc'] = '{:.2f}'.format(clean_accuracy.mean)
        pbar_dic['AdvAcc'] = '{:.2f}'.format(adv_accuracy.mean)
        pbar.set_postfix(pbar_dic)

    adv_acc = adv_accuracy.mean if AttackMethod is not None else 0
    return clean_accuracy.mean, adv_acc
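# Usage sketch (assumptions: a standard DataLoader and an attacker object
# exposing .attack(net, data, label), e.g. a PGD wrapper; the names below
# are illustrative, not part of this repo's actual API):
#
#   val_loader = DataLoader(val_set, batch_size=128, shuffle=False)
#   clean_acc, adv_acc = eval_one_epoch(net, val_loader,
#                                       DEVICE=torch.device('cuda:0'),
#                                       AttackMethod=pgd_attacker)
#   print('clean {:.2f} / adv {:.2f}'.format(clean_acc, adv_acc))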
def train_one_epoch(net, batch_generator, optimizer, criterion,
                    DEVICE=torch.device('cuda:0'), descrip_str='Training',
                    AttackMethod=None, adv_coef=1.0):
    '''
    :param AttackMethod: the attack method; None means natural training
    :param adv_coef: weight of the adversarial loss relative to the clean loss
    :return: None
    '''
    net.train()
    pbar = tqdm(batch_generator)
    advacc = -1
    advloss = -1
    cleanacc = -1
    cleanloss = -1
    pbar.set_description(descrip_str)
    for i, (data, label) in enumerate(pbar):
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        optimizer.zero_grad()

        pbar_dic = OrderedDict()
        TotalLoss = 0

        if AttackMethod is not None:
            adv_inp = AttackMethod.attack(net, data, label)
            optimizer.zero_grad()
            pred = net(adv_inp)
            loss = criterion(pred, label)

            acc = torch_accuracy(pred, label, (1,))
            advacc = acc[0].item()
            advloss = loss.item()
            TotalLoss = TotalLoss + loss * adv_coef

        pred = net(data)
        loss = criterion(pred, label)
        TotalLoss = TotalLoss + loss
        TotalLoss.backward()
        optimizer.step()

        acc = torch_accuracy(pred, label, (1,))
        cleanacc = acc[0].item()
        cleanloss = loss.item()

        pbar_dic['Acc'] = '{:.2f}'.format(cleanacc)
        pbar_dic['loss'] = '{:.2f}'.format(cleanloss)
        pbar_dic['AdvAcc'] = '{:.2f}'.format(advacc)
        pbar_dic['Advloss'] = '{:.2f}'.format(advloss)
        pbar.set_postfix(pbar_dic)
def train_one_step(net, data, label, optimizer, criterion,
                   param_name_to_merge_matrix, param_name_to_decay_matrix):
    pred = net(data)
    loss = criterion(pred, label)
    loss.backward()

    #TODO note: C-SGD works here. Replace the raw gradient with
    # merge_matrix @ grad + decay_matrix @ param, so that filters in the
    # same cluster receive identical updates.
    for name, param in net.named_parameters():
        name = name.replace('module.', '')
        if name in param_name_to_merge_matrix:
            p_dim = param.dim()
            p_size = param.size()
            if p_dim == 4:
                # conv kernel: flatten each output filter into one row
                param_mat = param.reshape(p_size[0], -1)
                g_mat = param.grad.reshape(p_size[0], -1)
            elif p_dim == 1:
                # bias / BN vector: one-element row per channel
                param_mat = param.reshape(p_size[0], 1)
                g_mat = param.grad.reshape(p_size[0], 1)
            else:
                assert p_dim == 2
                param_mat = param
                g_mat = param.grad
            csgd_gradient = param_name_to_merge_matrix[name].matmul(g_mat) \
                            + param_name_to_decay_matrix[name].matmul(param_mat)
            param.grad.copy_(csgd_gradient.reshape(p_size))

    optimizer.step()
    optimizer.zero_grad()
    acc, acc5 = torch_accuracy(pred, label, (1, 5))
    return acc, acc5, loss
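# Minimal sketch of how the C-SGD matrices might be built for one conv layer
# (illustrative only: the clustering, coefficients, and construction here are
# assumptions, not the repo's exact code). For 4 filters clustered as
# {0, 1} and {2, 3}, the merge matrix averages gradients within each cluster:
#
#   clusters = [[0, 1], [2, 3]]
#   n = 4
#   merge = torch.zeros(n, n)
#   for c in clusters:
#       for i in c:
#           for j in c:
#               merge[i, j] = 1.0 / len(c)
#   # illustrative decay: ordinary weight decay plus a centripetal pull
#   # toward the cluster mean ((I - merge) keeps only the residual)
#   weight_decay, centri_strength = 1e-4, 3e-3
#   decay = weight_decay * torch.eye(n) + centri_strength * (torch.eye(n) - merge)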
def train_one_step(net, data, label, optimizer, criterion,
                   if_accum_grad=False, gradient_mask_tensor=None,
                   lasso_keyword_to_strength=None):
    pred = net(data)
    loss = criterion(pred, label)

    if lasso_keyword_to_strength is not None:
        assert len(lasso_keyword_to_strength) == 1    #TODO
        for lasso_key, lasso_strength in lasso_keyword_to_strength.items():
            for name, param in net.named_parameters():
                if lasso_key in name:
                    if param.ndimension() == 1:
                        # plain L1 on vectors (e.g. BN scaling factors)
                        loss += lasso_strength * param.abs().sum()
                    else:
                        assert param.ndimension() == 4
                        # group lasso on conv kernels: L2 norm per filter, summed
                        loss += lasso_strength * (
                            (param ** 2).sum(dim=(1, 2, 3)).sqrt().sum())

    loss.backward()
    if not if_accum_grad:
        if gradient_mask_tensor is not None:
            for name, param in net.named_parameters():
                if name in gradient_mask_tensor:
                    param.grad = param.grad * gradient_mask_tensor[name]
        optimizer.step()
        optimizer.zero_grad()
    acc, acc5 = torch_accuracy(pred, label, (1, 5))
    return acc, acc5, loss
def train_one_step(net, data, label, optimizer, criterion, nonzero_ratio):
    pred = net(data)
    loss = criterion(pred, label)
    loss.backward()

    # Keep only the top nonzero_ratio fraction of gradient entries, ranked by
    # |g * w| across all 2D and 4D parameters, and zero out the rest.
    to_concat_g = []
    to_concat_v = []
    for name, param in net.named_parameters():
        if param.dim() in [2, 4]:
            to_concat_g.append(param.grad.data.view(-1))
            to_concat_v.append(param.data.view(-1))
    all_g = torch.cat(to_concat_g)
    all_v = torch.cat(to_concat_v)
    metric = torch.abs(all_g * all_v)
    num_params = all_v.size(0)
    nz = int(nonzero_ratio * num_params)
    top_values, _ = torch.topk(metric, nz)
    thresh = top_values[-1]

    for name, param in net.named_parameters():
        if param.dim() in [2, 4]:
            mask = (torch.abs(param.data * param.grad.data) >= thresh).float()
            param.grad.data = mask * param.grad.data

    optimizer.step()
    optimizer.zero_grad()
    acc, acc5 = torch_accuracy(pred, label, (1, 5))
    return acc, acc5, loss
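# Worked sketch of the |g * w| thresholding above (values are made up):
#
#   g = torch.tensor([ 0.9, -0.1,  0.4, -0.7])
#   w = torch.tensor([ 1.0,  2.0, -0.5,  0.1])
#   metric = (g * w).abs()                        # [0.90, 0.20, 0.20, 0.07]
#   nz = int(0.5 * 4)                             # keep top 50% -> 2 entries
#   thresh = torch.topk(metric, nz).values[-1]    # 0.20
#   mask = (metric >= thresh).float()             # [1, 1, 1, 0]
#   g_masked = mask * g
#
# Note that ties at the threshold (both 0.20 entries pass the >= test) can
# let slightly more than nonzero_ratio of the gradients through.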
def my_eval_one_epoch(net, batch_generator, DEVICE=torch.device('cuda:0'),
                      AttackMethod=None):
    net.eval()
    pbar = tqdm(batch_generator)
    clean_accuracy = AvgMeter()
    adv_accuracy = AvgMeter()
    correct_indices = None
    natural_indices = None

    pbar.set_description('Evaluating')
    for (data, label) in pbar:
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        with torch.no_grad():
            pred = net(data)
            predictions = np.argmax(pred.cpu().numpy(), axis=1)
            correct_labels = label.cpu().numpy()
            # note: these are per-batch indices, so values repeat across batches
            if natural_indices is None:
                natural_indices = np.where(predictions == correct_labels)[0]
            else:
                natural_indices = np.append(
                    natural_indices, np.where(predictions == correct_labels)[0])
            acc = torch_accuracy(pred, label, (1,))
            clean_accuracy.update(acc[0].item())

        if AttackMethod is not None:
            adv_inp = AttackMethod.attack(net, data, label)
            with torch.no_grad():
                pred = net(adv_inp)
                predictions = np.argmax(pred.cpu().numpy(), axis=1)
                correct_labels = label.cpu().numpy()
                if correct_indices is None:
                    correct_indices = np.where(predictions == correct_labels)[0]
                else:
                    correct_indices = np.append(
                        correct_indices,
                        np.where(predictions == correct_labels)[0])
                acc = my_torch_accuracy(pred, label, (1,))
                adv_accuracy.update(acc[0].item())

        pbar_dic = OrderedDict()
        pbar_dic['CleanAcc'] = '{:.2f}'.format(clean_accuracy.mean)
        pbar_dic['AdvAcc'] = '{:.2f}'.format(adv_accuracy.mean)
        pbar.set_postfix(pbar_dic)

    adv_acc = adv_accuracy.mean if AttackMethod is not None else 0
    print('Natural Samples', natural_indices.shape)
    if correct_indices is not None:
        print('Adversarial Samples', correct_indices.shape)
    return clean_accuracy.mean, adv_acc
def run_eval(ds_val, max_iters, net, criterion, discrip_str, dataset_name):
    pbar = tqdm(range(max_iters))
    top1 = AvgMeter()
    top5 = AvgMeter()
    losses = AvgMeter()
    pbar.set_description('Validation' + discrip_str)
    with torch.no_grad():
        for i in pbar:
            start_time = time.time()
            data, label = load_cuda_data(ds_val, dataset_name=dataset_name)
            data_time = time.time() - start_time

            pred = net(data)
            loss = criterion(pred, label)
            acc, acc5 = torch_accuracy(pred, label, (1, 5))
            top1.update(acc.item())
            top5.update(acc5.item())
            losses.update(loss.item())

            pbar_dic = OrderedDict()
            pbar_dic['data-time'] = '{:.2f}'.format(data_time)
            pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
            pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
            pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
            pbar.set_postfix(pbar_dic)

    metric_dic = {
        'top1': torch.tensor(top1.mean),
        'top5': torch.tensor(top5.mean),
        'loss': torch.tensor(losses.mean)
    }
    # reduce across distributed workers
    reduced_metric_dic = reduce_loss_dict(metric_dic)
    return reduced_metric_dic
def train_one_step(net, data, label, optimizer, criterion,
                   if_accum_grad=False, gradient_mask_tensor=None):
    pred = net(data)
    loss = criterion(pred, label)
    loss.backward()
    if not if_accum_grad:
        if gradient_mask_tensor is not None:
            for name, param in net.named_parameters():
                if name in gradient_mask_tensor:
                    param.grad = param.grad * gradient_mask_tensor[name]
        optimizer.step()
        optimizer.zero_grad()
    acc, acc5 = torch_accuracy(pred, label, (1, 5))
    return acc, acc5, loss
def train_one_step(net, data, label, optimizer, criterion, if_accum_grad=False):
    pred = net(data)
    loss = criterion(pred, label)
    loss.backward()
    if not if_accum_grad:
        optimizer.step()
        optimizer.zero_grad()
    acc, acc5 = torch_accuracy(pred, label, (1, 5))
    return acc, acc5, loss
def train_one_epoch(net, batch_generator, optimizer, criterion,
                    DEVICE=torch.device('cuda:0'), descrip_str='Training',
                    **args):
    '''
    Natural (non-adversarial) training for one epoch.
    :return: None
    '''
    net.train()
    pbar = tqdm(batch_generator)
    cleanacc = -1
    cleanloss = -1
    pbar.set_description(descrip_str)
    for i, (data, label) in enumerate(pbar):
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        optimizer.zero_grad()

        pbar_dic = OrderedDict()

        pred = net(data)
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()

        acc = torch_accuracy(pred, label, (1,))
        cleanacc = acc[0].item()
        cleanloss = loss.item()

        pbar_dic['Acc'] = '{:.2f}'.format(cleanacc)
        pbar_dic['loss'] = '{:.2f}'.format(cleanloss)
        pbar.set_postfix(pbar_dic)
def train_one_step(net: torch.nn.Module, data, label, optimizer, criterion,
                   if_accum_grad=False, gradient_mask_tensor=None,
                   lasso_keyword_to_strength=None):
    pred = net(data)
    loss = criterion(pred, label)

    # RFBlock regularizer: penalize the product of the left/right branch
    # weights while rewarding their individual magnitudes, scaled by m.alpha.
    for m in net.modules():
        if isinstance(m, RFBlock):
            w_left = m.left_alter.weight.permute([2, 3, 1, 0])
            w_right = m.right_alter.weight.permute(2, 3, 0, 1)
            loss += w_left.matmul(w_right).abs().mean() * m.alpha
            loss -= (w_left.abs().mean() + w_right.abs().mean()) * m.alpha

    if lasso_keyword_to_strength is not None:
        assert len(lasso_keyword_to_strength) == 1    #TODO
        for lasso_key, lasso_strength in lasso_keyword_to_strength.items():
            for name, param in net.named_parameters():
                if lasso_key in name:
                    if param.ndimension() == 1:
                        loss += lasso_strength * param.abs().sum()
                    else:
                        assert param.ndimension() == 4
                        loss += lasso_strength * (
                            (param ** 2).sum(dim=(1, 2, 3)).sqrt().sum())

    loss.backward()
    if not if_accum_grad:
        if gradient_mask_tensor is not None:
            for name, param in net.named_parameters():
                if name in gradient_mask_tensor:
                    param.grad = param.grad * gradient_mask_tensor[name]
        optimizer.step()
        optimizer.zero_grad()
    acc, acc5 = torch_accuracy(pred, label, (1, 5))
    return acc, acc5, loss
def run_eval(val_data, max_iters, net, criterion, discrip_str, dataset_name):
    # tqdm wraps any iterator with a progress bar
    pbar = tqdm(range(max_iters))
    top1 = AvgMeter()
    top5 = AvgMeter()
    losses = AvgMeter()
    pbar.set_description('Validation' + discrip_str)   # text left of the bar
    total_net_time = 0
    with torch.no_grad():
        for iter_idx, i in enumerate(pbar):
            start_time = time.time()
            data, label = load_cuda_data(val_data, dataset_name=dataset_name)
            data_time = time.time() - start_time

            net_time_start = time.time()
            pred = net(data)
            net_time_end = time.time()
            # only time the later iterations, skipping warm-up
            if iter_idx >= SPEED_TEST_SAMPLE_IGNORE_RATIO * max_iters:
                total_net_time += net_time_end - net_time_start

            loss = criterion(pred, label)
            acc, acc5 = torch_accuracy(pred, label, (1, 5))
            top1.update(acc.item())
            top5.update(acc5.item())
            losses.update(loss.item())

            pbar_dic = OrderedDict()
            pbar_dic['data-time'] = '{:.2f}'.format(data_time)
            pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
            pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
            pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
            pbar.set_postfix(pbar_dic)   # text right of the bar

    metric_dic = {
        'top1': torch.tensor(top1.mean),
        'top5': torch.tensor(top5.mean),
        'loss': torch.tensor(losses.mean)
    }
    # reduced_metric_dic = reduce_loss_dict(metric_dic)
    reduced_metric_dic = metric_dic    #TODO note this
    # returns the metrics dict and the pure forward time
    return reduced_metric_dic, total_net_time
def train_one_step(compactor_mask_dict, resrep_config: ResRepConfig, net, data,
                   label, optimizer, criterion, if_accum_grad=False,
                   gradient_mask_tensor=None):
    pred = net(data)
    loss = criterion(pred, label)
    loss.backward()

    # ResRep: mask the gradients of pruned compactor channels, then add the
    # gradient of the group-lasso penalty lasso_strength * ||p||_2 per filter,
    # which is p / ||p||_2.
    for compactor_param, mask in compactor_mask_dict.items():
        compactor_param.grad.data = mask * compactor_param.grad.data
        lasso_grad = compactor_param.data * (
            (compactor_param.data ** 2).sum(dim=(1, 2, 3),
                                            keepdim=True) ** (-0.5))
        compactor_param.grad.data.add_(lasso_grad,
                                       alpha=resrep_config.lasso_strength)

    if not if_accum_grad:
        if gradient_mask_tensor is not None:
            for name, param in net.named_parameters():
                if name in gradient_mask_tensor:
                    param.grad = param.grad * gradient_mask_tensor[name]
        optimizer.step()
        optimizer.zero_grad()
    acc, acc5 = torch_accuracy(pred, label, (1, 5))
    return acc, acc5, loss
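# Quick standalone check that the closed-form group-lasso gradient used above
# matches autograd (shapes illustrative; the gradient of ||p||_2, taken per
# output filter, is p / ||p||_2):
#
#   p = torch.randn(8, 4, 3, 3, requires_grad=True)
#   penalty = (p ** 2).sum(dim=(1, 2, 3)).sqrt().sum()
#   penalty.backward()
#   closed_form = p.data * ((p.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** -0.5)
#   assert torch.allclose(p.grad, closed_form, atol=1e-5)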
def eval_one_epoch(net, batch_generator, DEVICE=torch.device('cuda:0')):
    net.eval()
    pbar = tqdm(batch_generator)
    clean_accuracy = AvgMeter()
    pbar.set_description('Evaluating')
    for (data, label) in pbar:
        data = data.to(DEVICE)
        label = label.to(DEVICE)
        with torch.no_grad():
            pred = net(data)
            acc = torch_accuracy(pred, label, (1,))
            clean_accuracy.update(acc[0].item())
        pbar_dic = OrderedDict()
        pbar_dic['CleanAcc'] = '{:.2f}'.format(clean_accuracy.mean)
        pbar.set_postfix(pbar_dic)
    return clean_accuracy.mean
def run_eval(ds_val, max_iters, net, criterion, discrip_str, dataset_name):
    pbar = tqdm(range(max_iters))
    top1 = AvgMeter()
    top5 = AvgMeter()
    losses = AvgMeter()
    pbar.set_description('Validation' + discrip_str)
    total_net_time = 0
    with torch.no_grad():
        for iter_idx, i in enumerate(pbar):
            start_time = time.time()
            data, label = load_cuda_data(ds_val, dataset_name=dataset_name)
            data_time = time.time() - start_time

            net_time_start = time.time()
            pred = net(data)
            net_time_end = time.time()
            # only time the later iterations, skipping warm-up
            if iter_idx >= SPEED_TEST_SAMPLE_IGNORE_RATIO * max_iters:
                total_net_time += net_time_end - net_time_start

            loss = criterion(pred, label)
            acc, acc5 = torch_accuracy(pred, label, (1, 5))
            top1.update(acc.item())
            top5.update(acc5.item())
            losses.update(loss.item())

            pbar_dic = OrderedDict()
            pbar_dic['data-time'] = '{:.2f}'.format(data_time)
            pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
            pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
            pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
            pbar.set_postfix(pbar_dic)

    metric_dic = {
        'top1': torch.tensor(top1.mean),
        'top5': torch.tensor(top5.mean),
        'loss': torch.tensor(losses.mean)
    }
    reduced_metric_dic = reduce_loss_dict(metric_dic)
    return reduced_metric_dic, total_net_time
def train_one_epoch(net, batch_generator, optimizer, criterion,
                    LayerOneTrainner, K, DEVICE=torch.device('cuda:0'),
                    descrip_str='Training'):
    '''
    YOPO-style TRADES training: K inner updates of the perturbation eta via
    the gradient at the first layer's output, then one outer weight update.
    :return: (clean_acc, yofoacc)
    '''
    net.train()
    pbar = tqdm(batch_generator)
    yofoacc = -1
    pbar.set_description(descrip_str)
    trades_criterion = torch.nn.KLDivLoss(reduction='sum')
    for i, (data, label) in enumerate(pbar):
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        net.eval()
        eta = 0.001 * torch.randn(data.shape).to(DEVICE)
        eta.requires_grad_()

        raw_soft_label = F.softmax(net(data), dim=1).detach()
        for j in range(K):
            pred = net(data + eta.detach())
            with torch.enable_grad():
                loss = trades_criterion(F.log_softmax(pred, dim=1),
                                        raw_soft_label)
            # YOPO: differentiate only up to the first layer's output
            p = -1.0 * torch.autograd.grad(loss, [net.layer_one_out, ])[0]
            yofo_inp, eta = LayerOneTrainner.step(data, p, eta)

            with torch.no_grad():
                if j == K - 1:
                    yofo_pred = net(yofo_inp)
                    yofo_loss = criterion(yofo_pred, label)
                    yofoacc = torch_accuracy(yofo_pred, label, (1,))[0].item()

        net.train()
        optimizer.zero_grad()
        LayerOneTrainner.param_optimizer.zero_grad()

        raw_pred = net(data)
        acc = torch_accuracy(raw_pred, label, (1,))
        clean_acc = acc[0].item()
        clean_loss = criterion(raw_pred, label)

        adv_pred = net(torch.clamp(data + eta.detach(), 0.0, 1.0))
        kl_loss = trades_criterion(F.log_softmax(adv_pred, dim=1),
                                   F.softmax(raw_pred, dim=1)) / data.shape[0]
        loss = clean_loss + kl_loss
        loss.backward()

        optimizer.step()
        LayerOneTrainner.param_optimizer.step()
        optimizer.zero_grad()
        LayerOneTrainner.param_optimizer.zero_grad()

        pbar_dic = OrderedDict()
        pbar_dic['Acc'] = '{:.2f}'.format(clean_acc)
        pbar_dic['cleanloss'] = '{:.3f}'.format(clean_loss.item())
        pbar_dic['loss'] = '{:.2f}'.format(loss.item())
        pbar_dic['YofoAcc'] = '{:.2f}'.format(yofoacc)
        pbar_dic['Yofoloss'] = '{:.3f}'.format(yofo_loss.item())
        pbar.set_postfix(pbar_dic)
    return clean_acc, yofoacc
def train_one_epoch(net, batch_generator, optimizer, criterion,
                    LayerOneTrainner, K, DEVICE=torch.device('cuda:0'),
                    descrip_str='Training'):
    '''
    YOPO training: K inner updates of the perturbation eta via the gradient
    at the first layer's output, then one outer weight update.
    :return: (cleanacc, yofoacc)
    '''
    net.train()
    pbar = tqdm(batch_generator)
    yofoacc = -1
    cleanacc = -1
    cleanloss = -1
    pbar.set_description(descrip_str)
    for i, (data, label) in enumerate(pbar):
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        eta = torch.FloatTensor(*data.shape).uniform_(-config.eps, config.eps)
        eta = eta.to(label.device)
        eta.requires_grad_()

        optimizer.zero_grad()
        LayerOneTrainner.param_optimizer.zero_grad()

        for j in range(K):
            pbar_dic = OrderedDict()
            pred = net(data + eta.detach())
            loss = criterion(pred, label)
            loss.backward()

            # YOPO: reuse the gradient at the first layer's output
            p = -1.0 * net.layer_one_out.grad
            yofo_inp, eta = LayerOneTrainner.step(data, p, eta)

            with torch.no_grad():
                if j == 0:
                    acc = torch_accuracy(pred, label, (1,))
                    cleanacc = acc[0].item()
                    cleanloss = loss.item()
                if j == K - 1:
                    yofo_pred = net(yofo_inp)
                    yofoacc = torch_accuracy(yofo_pred, label, (1,))[0].item()

        optimizer.step()
        LayerOneTrainner.param_optimizer.step()
        optimizer.zero_grad()
        LayerOneTrainner.param_optimizer.zero_grad()

        pbar_dic['Acc'] = '{:.2f}'.format(cleanacc)
        pbar_dic['loss'] = '{:.2f}'.format(cleanloss)
        pbar_dic['YofoAcc'] = '{:.2f}'.format(yofoacc)
        pbar.set_postfix(pbar_dic)
    return cleanacc, yofoacc
def trades_loss(model, x_natural, y, optimizer, device, step_size=0.003,
                epsilon=0.031, perturb_steps=10, beta=1.0, distance='l_inf'):
    # KL divergence between adversarial and natural predictions
    criterion_kl = nn.KLDivLoss(reduction='sum')
    model.eval()
    batch_size = len(x_natural)
    # generate adversarial example, starting from a small random perturbation
    x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape,
                                                     device=device)
    if distance == 'l_inf':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                       F.softmax(model(x_natural), dim=1))
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon),
                              x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif distance == 'l_2':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                       F.softmax(model(x_natural), dim=1))
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            for idx_batch in range(batch_size):
                grad_idx = grad[idx_batch]
                grad_idx_norm = l2_norm(grad_idx)
                grad_idx /= (grad_idx_norm + 1e-8)
                x_adv[idx_batch] = x_adv[idx_batch].detach() \
                                   + step_size * grad_idx
                eta_x_adv = x_adv[idx_batch] - x_natural[idx_batch]
                norm_eta = l2_norm(eta_x_adv)
                if norm_eta > epsilon:
                    # project back onto the epsilon ball
                    eta_x_adv = eta_x_adv * epsilon / norm_eta
                    x_adv[idx_batch] = x_natural[idx_batch] + eta_x_adv
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    else:
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    model.train()

    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
    # zero gradient
    optimizer.zero_grad()
    # calculate robust loss
    logits = model(x_natural)
    adv_logits = model(x_adv)
    loss_natural = F.cross_entropy(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(
        F.log_softmax(adv_logits, dim=1), F.softmax(logits, dim=1))
    loss = loss_natural + beta * loss_robust
    cleanacc = torch_accuracy(logits, y, (1,))[0].item()
    tradesacc = torch_accuracy(adv_logits, y, (1,))[0].item()
    return loss, loss_natural.item(), loss_robust.item(), cleanacc, tradesacc
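# Usage sketch inside a training loop (illustrative; assumes the (data, label)
# batches and optimizer used by this file's other train loops; hyperparameter
# values are common CIFAR settings, not prescribed by this function):
#
#   for data, label in loader:
#       data, label = data.to(device), label.to(device)
#       loss, nat_l, rob_l, clean_acc, trades_acc = trades_loss(
#           model, data, label, optimizer, device,
#           step_size=2/255, epsilon=8/255, perturb_steps=10, beta=6.0)
#       loss.backward()      # trades_loss already zeroed the gradients
#       optimizer.step()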
def train_one_epoch(net, batch_generator, optimizer, criterion,
                    DEVICE=torch.device('cuda:0'), descrip_str='Training',
                    AttackMethod=None, alpha=1):
    '''
    :param AttackMethod: the attack method (required here; must not be None)
    :param alpha: weight coefficient for the mig (robustness) loss
    :return: None
    '''
    net.train()
    pbar = tqdm(batch_generator)
    advacc = -1
    advloss = -1
    cleanacc = -1
    cleanloss = -1
    criterion_kl = torch.nn.KLDivLoss(reduction='sum').to(DEVICE)
    pbar.set_description(descrip_str)
    for i, (data, label) in enumerate(pbar):
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        optimizer.zero_grad()

        pbar_dic = OrderedDict()

        adv_inp = AttackMethod.attack(net, data, label)
        optimizer.zero_grad()
        pred1 = net(adv_inp)
        pred2 = net(data)
        loss_robust = criterion_kl(F.log_softmax(pred1, dim=1),
                                   F.softmax(pred2, dim=1))
        loss_natural = criterion(pred2, label)
        TotalLoss = loss_natural + alpha * loss_robust
        TotalLoss.backward()

        acc = torch_accuracy(pred1, label, (1,))
        advacc = acc[0].item()
        advloss = loss_robust.item()

        acc = torch_accuracy(pred2, label, (1,))
        cleanacc = acc[0].item()
        cleanloss = loss_natural.item()

        param = next(net.parameters())
        grad_mean = torch.mean(param.grad)
        optimizer.step()

        pbar_dic['grad'] = '{}'.format(grad_mean)
        pbar_dic['cleanAcc'] = '{:.2f}'.format(cleanacc)
        pbar_dic['cleanloss'] = '{:.2f}'.format(cleanloss)
        pbar_dic['AdvAcc'] = '{:.2f}'.format(advacc)
        pbar_dic['Robloss'] = '{:.2f}'.format(advloss)
        pbar.set_postfix(pbar_dic)
def train_one_epoch(net, batch_generator, optimizer, criterion,
                    LayerOneTrainner, K, DEVICE=torch.device('cuda:0'),
                    descrip_str='Training'):
    net.train()
    pbar = tqdm(batch_generator)
    yofoacc = -1
    pbar.set_description(descrip_str)
    trades_criterion = torch.nn.KLDivLoss(reduction='sum')
    for i, (data, label) in enumerate(pbar):
        data = data.to(DEVICE)
        label = label.to(DEVICE)

        net.eval()
        eta = 0.001 * torch.randn(data.shape, device=DEVICE)
        eta.requires_grad_()

        raw_soft_label = F.softmax(net(data), dim=1).detach()
        for j in range(K):
            pred = net(data + eta.detach())
            with torch.enable_grad():
                loss = trades_criterion(F.log_softmax(pred, dim=1),
                                        raw_soft_label)
            p = -1.0 * torch.autograd.grad(loss, [net.layer_one_out, ])[0]
            yofo_inp, eta = LayerOneTrainner.step(data, p, eta)

            with torch.no_grad():
                if j == K - 1:
                    yofo_pred = net(yofo_inp)
                    yofo_loss = criterion(yofo_pred, label)
                    yofoacc = torch_accuracy(yofo_pred, label, (1,))[0].item()

        net.train()
        optimizer.zero_grad()
        LayerOneTrainner.param_optimizer.zero_grad()

        raw_pred = net(data)
        acc = torch_accuracy(raw_pred, label, (1,))
        clean_acc = acc[0].item()
        clean_loss = criterion(raw_pred, label)

        adv_pred = net(torch.clamp(data + eta.detach(), 0.0, 1.0))
        kl_loss = trades_criterion(F.log_softmax(adv_pred, dim=1),
                                   F.softmax(raw_pred, dim=1)) / data.shape[0]
        loss = clean_loss + kl_loss
        loss.backward()

        optimizer.step()
        LayerOneTrainner.param_optimizer.step()
        optimizer.zero_grad()
        LayerOneTrainner.param_optimizer.zero_grad()

        pbar_dic = OrderedDict()
        pbar_dic['Acc'] = '{:.2f}'.format(clean_acc)
        pbar_dic['cleanloss'] = '{:.3f}'.format(clean_loss.item())
        pbar_dic['klloss'] = '{:.3f}'.format(kl_loss.item())
        pbar_dic['YofoAcc'] = '{:.2f}'.format(yofoacc)
        pbar_dic['Yofoloss'] = '{:.3f}'.format(yofo_loss.item())
        pbar.set_postfix(pbar_dic)
    return clean_acc, yofoacc