def salience_val_one_epoch(net_f, net, optimizer, batch_generator, criterion,
                           SalenceGenerator, AttackMethod, clock,
                           attack_freq=5, use_adv=True,
                           DEVICE=torch.device('cuda:0')):
    '''
    Validate a classifier trained on salience maps generated by ``net_f``.

    One in every ``attack_freq + 1`` minibatches (selected via
    ``clock.minibatch``) is evaluated on adversarial inputs when ``use_adv``
    is set; the rest are evaluated on clean inputs.

    :param net_f: feature network used to generate salience maps (and as the
        attack target on adversarial batches)
    :param net: classifier evaluated on the salience maps
    :param optimizer: unused here; kept for signature symmetry with the
        training counterpart
    :param batch_generator: iterable yielding (data, label) batches
    :param criterion: unused here; kept for signature symmetry with the
        training counterpart
    :param SalenceGenerator: callable ``(net_f, inp, label) -> salience map``
    :param AttackMethod: attack object exposing ``attack(net, inp, label)``
    :param clock: TrainClock; ``clock.minibatch`` and ``clock.epoch`` are read
        (NOTE(review): ``clock.tick()`` is not called here — assumes the
        caller advances the clock, otherwise every batch takes the same
        branch; confirm against the training loop)
    :param attack_freq: adversarial-batch frequency divisor
    :param use_adv: disable adversarial batches entirely when False
    :param DEVICE: device the batches are moved to
    :return: dict with mean 'clean_acc' and 'adv_acc' over the epoch
    '''
    clean_acc = AvgMeter()
    adv_acc = AvgMeter()
    net_f.eval()
    net.eval()
    pbar = tqdm(batch_generator)
    for (data, label) in pbar:
        data = data.to(DEVICE)
        label = label.to(DEVICE)
        if clock.minibatch % (attack_freq + 1) == 1 and use_adv:
            # The attack and the salience generator may need input gradients,
            # so they run outside no_grad; only the classifier forward used
            # for accuracy is wrapped.
            adv_inp = AttackMethod.attack(net_f, data, label)
            salience_data = SalenceGenerator(net_f, adv_inp, label)
            with torch.no_grad():
                pred = net(salience_data)
            acc = torch_accuracy(pred, label, (1, ))
            adv_acc.update(acc[0].item())
        else:
            salience_data = SalenceGenerator(net_f, data, label)
            with torch.no_grad():
                pred = net(salience_data)
            acc = torch_accuracy(pred, label, (1, ))
            clean_acc.update(acc[0].item())
        pbar.set_description("Validation Epoch: {}".format(clock.epoch))
        pbar.set_postfix({
            'clean_acc': clean_acc.mean,
            'adv_acc': adv_acc.mean
        })
    return {'clean_acc': clean_acc.mean, 'adv_acc': adv_acc.mean}
def val_one_epoch(net, batch_generator, CrossEntropyCriterion, clock, tv_loss_weight=1.0):
    '''
    Validate ``net`` on clean data while also measuring the total variation
    of the loss gradient w.r.t. the input (a smoothness diagnostic).

    :param net: network to evaluate
    :param batch_generator: pytorch dataloader or other generator
    :param CrossEntropyCriterion: used for calculating the cross-entropy loss
    :param clock: TrainClock from utils (only ``clock.epoch`` is read)
    :param tv_loss_weight: unused here; kept for signature compatibility with
        the training counterpart
    :return: (mean accuracy, mean cross-entropy loss, mean gradient-TV loss)
    '''
    Acc = AvgMeter()
    CrossLoss = AvgMeter()
    GradTvLoss = AvgMeter()
    net.eval()
    pbar = tqdm(batch_generator)
    for (data, label) in pbar:
        data = data.cuda()
        label = label.cuda()
        # Input gradients are needed for the gradient-TV diagnostic.
        data.requires_grad = True
        pred = net(data)
        cross_entropy_loss = CrossEntropyCriterion(pred, label)
        # The gradient map is only inspected (reduced to a scalar and read via
        # .item()), never back-propagated through, so create_graph=True is
        # unnecessary; ``only_inputs`` is deprecated and ignored by
        # torch.autograd.grad.
        grad_map = torch.autograd.grad(cross_entropy_loss, data)[0]
        grad_tv_loss = TvLoss(grad_map)
        acc = torch_accuracy(pred, label, topk=(1, ))[0].item()
        Acc.update(acc)
        CrossLoss.update(cross_entropy_loss.item())
        GradTvLoss.update(grad_tv_loss.item())
        pbar.set_description("Validation Epoch: {}".format(clock.epoch))
        pbar.set_postfix({
            'Acc': '{:.3f}'.format(Acc.mean),
            'CrossLoss': "{:.2f}".format(CrossLoss.mean),
            'GradTvLoss': '{:.3f}'.format(GradTvLoss.mean)
        })
    return Acc.mean, CrossLoss.mean, GradTvLoss.mean
def get_batch_accuracy(self, net, inp, label):
    """Craft adversarial examples for one batch and return ``net``'s top-1
    accuracy (as a float) on them."""
    perturbed = self.attack(net, inp, label)
    logits = net(perturbed)
    return torch_accuracy(logits, label, (1, ))[0].item()
def eval_one_epoch(net, batch_generator, device=torch.device('cuda:0'), val_attack=None, logger=None):
    """Evaluate ``net`` for one epoch on clean data and, when ``val_attack``
    is supplied, on adversarial examples produced by ``val_attack.forward``.

    :return: (mean clean accuracy, mean adversarial accuracy); the latter is
        0 when no attack is supplied.
    """
    net.eval()
    clean_meter = AvgMeter()
    adv_meter = AvgMeter()
    pbar = tqdm(batch_generator)
    pbar.set_description('Evaluating')
    for data, label in pbar:
        data = data.to(device)
        label = label.to(device)
        # Clean forward pass — inference only.
        with torch.no_grad():
            clean_pred = net(data)
        clean_meter.update(torch_accuracy(clean_pred, label, (1, ))[0].item())
        if val_attack is not None:
            # Attack runs outside no_grad (it may need input gradients).
            adv_inp = val_attack.forward(data, label)
            with torch.no_grad():
                adv_pred = net(adv_inp)
            adv_meter.update(torch_accuracy(adv_pred, label, (1, ))[0].item())
        stats = OrderedDict()
        stats['CleanAcc'] = '{:.2f}'.format(clean_meter.mean)
        stats['AdvAcc'] = '{:.2f}'.format(adv_meter.mean)
        pbar.set_postfix(stats)
    robust_acc = adv_meter.mean if val_attack is not None else 0
    if logger is not None:
        logger.info(
            f'standard acc: {clean_meter.mean:.3f}%, robustness acc: {adv_meter.mean:.3f}%'
        )
    return clean_meter.mean, robust_acc
def adversarial_val(net, batch_generator, criterion, AttackMethod, clock,
                    attack_freq=1, DEVICE=torch.device('cuda:0')):
    """
    Validate both on clean data and on adversarial examples.

    :param net: network to evaluate
    :param batch_generator: iterable yielding (data, label) batches
    :param criterion: classification loss
    :param AttackMethod: the attack method (exposes ``attack(net, data, label)``)
    :param clock: clock object from my_snip.clock import TrainClock (only
        ``clock.epoch`` is read)
    :param attack_freq: one in every (attack_freq + 1) batches is adversarial
    :param DEVICE: device the batches are moved to
    :return: OrderedDict of mean stats keyed by
        ['loss', 'acc', 'clean_loss', 'clean_acc', 'adv_loss', 'adv_acc']
    """
    training_losses = AvgMeter()
    training_accs = AvgMeter()
    clean_losses = AvgMeter()
    clean_accs = AvgMeter()
    defense_losses = AvgMeter()
    defense_accs = AvgMeter()
    names = ['loss', 'acc', 'clean_loss', 'clean_acc', 'adv_loss', 'adv_acc']
    clean_batch_times = AvgMeter()
    ad_batch_times = AvgMeter()
    net.eval()
    pbar = tqdm(batch_generator)
    start_time = time.time()
    i = 0
    for (data, label) in pbar:
        i += 1
        data = data.to(DEVICE)
        label = label.to(DEVICE)
        if i % (attack_freq + 1) == 1:
            # Attack needs gradients; only the evaluation forward is no_grad.
            adv_inp = AttackMethod.attack(net, data, label)
            with torch.no_grad():
                pred = net(adv_inp)
                loss = criterion(pred, label)
            defense_losses.update(loss.item())
            acc = torch_accuracy(pred, label, (1, ))
            defense_accs.update(acc[0].item())
            ad_batch_times.update(time.time() - start_time)
        else:
            with torch.no_grad():
                pred = net(data)
                loss = criterion(pred, label)
            acc = torch_accuracy(pred, label, (1, ))
            clean_losses.update(loss.item())
            clean_accs.update(acc[0].item())
            clean_batch_times.update(time.time() - start_time)
        # Overall meters aggregate whichever branch ran this batch.
        training_losses.update(loss.item())
        training_accs.update(acc[0].item())
        pbar.set_description("Validation Epoch: {}".format(clock.epoch))
        values = [
            training_losses.mean, training_accs.mean, clean_losses.mean,
            clean_accs.mean, defense_losses.mean, defense_accs.mean
        ]
        pbar_dic = OrderedDict()
        for n, v in zip(names, values):
            if n not in ['acc', 'clean_acc', 'adv_acc']:
                continue
            pbar_dic[n] = v
        # BUGFIX: timing entries were previously added AFTER set_postfix and
        # therefore never displayed; add them before, matching the training
        # counterpart.
        pbar_dic['clean_time'] = "{:.2f}".format(clean_batch_times.mean)
        pbar_dic['ad_time'] = "{:.2f}".format(ad_batch_times.mean)
        pbar.set_postfix(pbar_dic)
        start_time = time.time()
    values = [
        training_losses.mean, training_accs.mean, clean_losses.mean,
        clean_accs.mean, defense_losses.mean, defense_accs.mean
    ]
    return OrderedDict(zip(names, values))
def adversairal_train_one_epoch(net, optimizer, batch_generator, criterion,
                                AttackMethod, clock, attack_freq=1,
                                use_adv=True, DEVICE=torch.device('cuda:0'),
                                act_loss_coef=0):
    """
    Adversarial training for one epoch.  (The misspelled name is kept because
    external callers reference it.)

    :param net: network being trained
    :param optimizer: optimizer stepping ``net``'s parameters
    :param batch_generator: iterable yielding (data, label) batches
    :param criterion: classification loss
    :param AttackMethod: the attack method (exposes ``attack(net, data, label)``)
    :param clock: clock object from my_snip.clock import TrainClock; advanced
        via tock()/tick(), and ``clock.minibatch`` selects adv vs clean batches
    :param attack_freq: one in every (attack_freq + 1) minibatches is trained
        on adversarial examples
    :param use_adv: disable adversarial batches entirely when False
    :param DEVICE: device the batches are moved to
    :param act_loss_coef: weight of the activation regularizer returned by
        ``net(data, if_return_activation=True)`` on clean batches
    :return: OrderedDict of mean stats keyed by
        ['loss', 'acc', 'clean_loss', 'clean_acc', 'adv_loss', 'adv_acc']
    """
    training_losses = AvgMeter()
    training_accs = AvgMeter()
    clean_losses = AvgMeter()
    clean_accs = AvgMeter()
    defense_losses = AvgMeter()
    defense_accs = AvgMeter()
    names = ['loss', 'acc', 'clean_loss', 'clean_acc', 'adv_loss', 'adv_acc']
    clean_batch_times = AvgMeter()
    ad_batch_times = AvgMeter()
    net.train()
    clock.tock()
    pbar = tqdm(batch_generator)
    start_time = time.time()
    for (data, label) in pbar:
        clock.tick()
        data = data.to(DEVICE)
        label = label.to(DEVICE)
        if clock.minibatch % (attack_freq + 1) == 1 and use_adv:
            adv_inp = AttackMethod.attack(net, data, label)
            # The attack may have put the net in eval mode; restore training.
            net.train()
            optimizer.zero_grad()
            pred = net(adv_inp)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
            defense_losses.update(loss.item())
            acc = torch_accuracy(pred, label, (1, ))
            defense_accs.update(acc[0].item())
            ad_batch_times.update(time.time() - start_time)
        else:
            optimizer.zero_grad()
            pred, act_loss = net(data, if_return_activation=True)
            loss = criterion(pred, label) + act_loss_coef * act_loss
            # BUGFIX: removed leftover debug `print(loss, act_loss)` that
            # fired on every clean batch.
            loss.backward()
            optimizer.step()
            acc = torch_accuracy(pred, label, (1, ))
            clean_losses.update(loss.item())
            clean_accs.update(acc[0].item())
            clean_batch_times.update(time.time() - start_time)
        training_losses.update(loss.item())
        training_accs.update(acc[0].item())
        pbar.set_description("Training Epoch: {}".format(clock.epoch))
        values = [
            training_losses.mean, training_accs.mean, clean_losses.mean,
            clean_accs.mean, defense_losses.mean, defense_accs.mean
        ]
        pbar_dic = OrderedDict()
        for n, v in zip(names, values):
            if n not in ['acc', 'clean_acc', 'adv_acc']:
                continue
            pbar_dic[n] = v
        pbar_dic['clean_time'] = "{:.1f}".format(clean_batch_times.mean)
        pbar_dic['ad_time'] = "{:.1f}".format(ad_batch_times.mean)
        pbar.set_postfix(pbar_dic)
        start_time = time.time()
    values = [
        training_losses.mean, training_accs.mean, clean_losses.mean,
        clean_accs.mean, defense_losses.mean, defense_accs.mean
    ]
    return OrderedDict(zip(names, values))
def train_one_epoch(net, batch_generator, optimizer, criterion,
                    device=torch.device('cuda:0'), descrip_str='Training',
                    attack_method=None, adv_coef=1.0, logger=None):
    """Train ``net`` for one epoch, optionally mixing in adversarial examples.

    When ``attack_method`` is given, each batch first back-propagates the
    adversarial loss scaled by ``adv_coef``, then the clean loss; gradients
    from both accumulate before a single optimizer step.
    """
    net.train()
    pbar = tqdm(batch_generator)
    pbar.set_description(descrip_str)
    # Last-batch stats; -1 means "not yet computed".
    adv_acc, adv_loss = -1, -1
    clean_acc, clean_loss = -1, -1
    for batch in pbar:
        data, label = batch
        data = data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        pbar_dic = OrderedDict()
        if attack_method is not None:
            # Build adversarial examples for this batch.
            adv_inp = attack_method.forward(data, label)
            # The attack may have dirtied grads / mode; reset both.
            optimizer.zero_grad()
            net.train()
            adv_pred = net(adv_inp)
            adv_criterion_out = criterion(adv_pred, label)
            adv_acc = torch_accuracy(adv_pred, label, (1, ))[0].item()
            adv_loss = adv_criterion_out.item()
            # Accumulate the (scaled) adversarial gradient; no step yet.
            (adv_criterion_out * adv_coef).backward()
        clean_pred = net(data)
        clean_criterion_out = criterion(clean_pred, label)
        # Clean gradient accumulates on top of the adversarial one.
        clean_criterion_out.backward()
        optimizer.step()
        clean_acc = torch_accuracy(clean_pred, label, (1, ))[0].item()
        clean_loss = clean_criterion_out.item()
        pbar_dic['Acc'] = '{:.2f}'.format(clean_acc)
        pbar_dic['loss'] = '{:.2f}'.format(clean_loss)
        pbar_dic['Adv Acc'] = '{:.2f}'.format(adv_acc)
        pbar_dic['Adv loss'] = '{:.2f}'.format(adv_loss)
        pbar.set_postfix(pbar_dic)
        if logger is not None:
            logger.info(
                f'standard loss: {clean_loss:.3f}, Adv loss: {adv_loss:.3f}')
            logger.info(
                f'standard acc: {clean_acc:.3f}%, robustness acc: {adv_acc:.3f}%'
            )