def evaluate(self, pooled_outputs, targets, topk=(1, 5)):
    """Score pooled model outputs against targets for every enabled task.

    Returns a dict of metric name -> value: 'accuracy_k' entries for the
    classification tasks and, when retrieval is enabled, the metrics
    produced by ``self.retrieval_evaluation``.
    """
    metrics = {}
    # Classification: either plain or input-conditioned sketch classification
    # supplies the logits; the later task wins if both are configured.
    cls_scores = None
    for task in ('sketchcls', 'sketchclsinput'):
        if task in self.args.task_types:
            cls_scores = accuracy(pooled_outputs[0][task], targets[0], topk=topk)
    if cls_scores is not None:
        for k in topk:
            metrics['accuracy_{}'.format(k)] = cls_scores[k][0]
    # Retrieval: fold its metric dict into the result.
    if 'sketchretrieval' in self.args.task_types:
        metrics.update(
            self.retrieval_evaluation(pooled_outputs[0]['sketchretrieval'],
                                      targets[0], topk=(1,)))
    return metrics
def train(net, dataloader, epoch, opt, criterion):
    """Run one training epoch of `net` over `dataloader`.

    With probability `mask_rate` (module global) the input images are
    multiplied by (1 - mask) before the forward pass; loss/top-1/top-5 and
    timing statistics are logged every `print_freq` batches.
    """
    net.train()
    time_meter = AverageMeter()   # per-batch wall time
    load_meter = AverageMeter()   # data-loading time
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    tic = time.time()
    for step, (imgs, cls_ids) in enumerate(dataloader):
        load_meter.update(time.time() - tic)
        imgs, cls_ids = imgs.to(device), cls_ids.to(device)
        opt.zero_grad()
        # Mask is generated every batch (keeps any RNG stream unchanged),
        # but only applied with probability mask_rate.
        masks = generate_mask().to(device)
        inputs = imgs * (1 - masks) if np.random.rand() < mask_rate else imgs
        pred = net(inputs)
        loss = criterion(pred, cls_ids)
        loss.backward()
        opt.step()
        # measure
        prec1, prec5 = accuracy(pred, cls_ids, topk=(1, 5))
        n = imgs.size(0)
        loss_meter.update(loss.item(), n)
        acc1_meter.update(prec1[0], n)
        acc5_meter.update(prec5[0], n)
        time_meter.update(time.time() - tic)
        tic = time.time()
        if step % print_freq == 0:
            logger.info('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                            epoch, step, len(dataloader),
                            batch_time=time_meter, data_time=load_meter,
                            loss=loss_meter, top1=acc1_meter, top5=acc5_meter))
def validate(net, dataloader, criterion):
    """Evaluate `net` on masked inputs from `dataloader`.

    Every batch is multiplied by (1 - mask) before the forward pass.
    Returns (average top-1 accuracy, average top-5 accuracy).
    """
    time_meter = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    net.eval()
    with torch.no_grad():
        tic = time.time()
        for step, (imgs, cls_ids) in enumerate(dataloader):
            imgs, cls_ids = imgs.to(device), cls_ids.to(device)
            masks = generate_mask().to(device)
            pred = net(imgs * (1 - masks))
            loss = criterion(pred, cls_ids)
            # measure
            prec1, prec5 = accuracy(pred, cls_ids, topk=(1, 5))
            n = imgs.size(0)
            loss_meter.update(loss.item(), n)
            acc1_meter.update(prec1[0], n)
            acc5_meter.update(prec5[0], n)
            # measure elapsed time
            time_meter.update(time.time() - tic)
            tic = time.time()
            if step % print_freq == 0:
                logger.info(
                    'Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        step, len(dataloader), batch_time=time_meter,
                        loss=loss_meter, top1=acc1_meter, top5=acc5_meter))
    return acc1_meter.avg, acc5_meter.avg
# --- Script-level: optimizer selection, training loop, artifact dump. ---
if args.optim.lower() in ['sgd',]:
    optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
elif args.optim.lower() in ['adam',]:
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=1e-4)
else:
    raise ValueError('Unrecognized Optimizer: %s' % args.optim.lower())
data_loader = load_pkl(args.data, batch_size=args.batch_size)
# Snapshot of the argparse configuration, saved with the model below.
setup_config = {kwarg: value for kwarg, value in args._get_kwargs()}
for idx in range(args.total_iters):
    data_batch, label_batch = next(data_loader)
    data_batch = data_batch.cuda(device) if use_gpu else data_batch
    label_batch = label_batch.cuda(device) if use_gpu else label_batch
    logits = model(data_batch)
    loss = criterion(logits, label_batch)
    acc = accuracy(logits.data, label_batch)
    optim.zero_grad()
    loss.backward()
    optim.step()
    sys.stdout.write('iter %d: accuracy = %.2f%%\r' % (idx, acc * 100.))
# FIX: the original passed an anonymous open(...) to pickle.dump and never
# closed it; use a context manager so the handle is released deterministically.
with open(os.path.join(args.out_folder, '%s.pkl' % args.model_name), 'wb') as config_file:
    pickle.dump(setup_config, config_file)
torch.save(model.state_dict(), os.path.join(args.out_folder, '%s.ckpt' % args.model_name))
def certify_pgd(model, data_loader, out_file, eps, norm, bound_est, device, tosave, pixel_range=None, **tricks):
    '''
    >>> Accuracy under adversarial attack, providing the upper bound of certified accuracy

    model: network under evaluation (switched to eval mode)
    data_loader: yields (data_batch, label_batch) pairs
    out_file: path for pickling `tosave`, or None to skip saving
    eps / norm: PGD perturbation budget and its norm order
    bound_est: unused here; kept so all certify_* share one signature
    tricks: optional 'batch_num' > 0 limits the number of batches processed

    Returns `tosave` augmented with per-sample clean/robust success bits.
    '''
    use_gpu = device != torch.device('cpu') and torch.cuda.is_available()
    # 20 PGD steps of size eps/10 within the eps-ball.
    attacker = PGM(step_size=eps / 10., threshold=eps, iter_num=20, order=norm, pixel_range=pixel_range)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.)
    model.eval()
    acc_calculator = AverageCalculator()
    robust_acc_calculator = AverageCalculator()
    normal_success_bits = []
    robust_success_bits = []
    for idx, (data_batch, label_batch) in enumerate(data_loader, 0):
        sys.stdout.write('Batch Index = %d\r' % idx)
        if 'batch_num' in tricks and idx >= tricks['batch_num'] and tricks['batch_num'] > 0:
            print('The test process stops after %d batches' % tricks['batch_num'])
            break
        data_batch = data_batch.cuda(device) if use_gpu else data_batch
        label_batch = label_batch.cuda(device) if use_gpu else label_batch
        # Clean accuracy and per-sample correctness bits.
        logits = model(data_batch)
        acc = accuracy(logits.data, label_batch)
        acc_calculator.update(acc.item(), data_batch.size(0))
        normal_success_bits_this_batch = (torch.argmax(logits, dim=1) == label_batch).float().data.cpu().numpy()
        normal_success_bits += list(normal_success_bits_this_batch)
        # Robust accuracy on the attacked batch.
        data_batch = attacker.attack(model, optimizer, data_batch, label_batch)
        logits = model(data_batch)
        acc = accuracy(logits.data, label_batch)
        robust_acc_calculator.update(acc.item(), data_batch.size(0))
        robust_success_bits_this_batch = (torch.argmax(logits, dim=1) == label_batch).float().data.cpu().numpy()
        robust_success_bits += list(robust_success_bits_this_batch)
    print('')
    acc_this_epoch = acc_calculator.average
    robust_acc_this_epoch = robust_acc_calculator.average
    tosave['normal_success_bits'] = normal_success_bits
    tosave['robust_success_bits'] = robust_success_bits
    print('>>>>> The results of PGD <<<<<')
    print('Average Accuracy: %.2f%%' % (acc_this_epoch * 100.))
    print('Robust Accuracy: %.2f%%' % (robust_acc_this_epoch * 100.))
    if out_file is not None:
        # FIX: close the pickle file deterministically instead of leaking the handle.
        with open(out_file, 'wb') as fout:
            pickle.dump(tosave, fout)
    return tosave
def certify_kw(model, data_loader, out_file, eps, norm, bound_est, device, tosave, pixel_range=None, **tricks):
    '''
    >>> Certification function using Kolter-Wong's framework

    out_file: path for pickling `tosave`, or None to skip saving
    pixel_range: None, or must be exactly (0., 1.) (the only bounded input
                 range robust_loss supports here)
    tricks: optional 'batch_num' > 0 limits the number of batches processed

    Returns `tosave` with 'guaranteed_distances' (eps for certified samples,
    0 otherwise).
    '''
    use_gpu = device != torch.device('cpu') and torch.cuda.is_available()
    if pixel_range is None:
        bounded_input = False
    else:
        assert pixel_range[0] == 0. and pixel_range[1] == 1., 'pixel_range %s is not supported' % pixel_range
        bounded_input = True
    model.eval()
    seq_model = model.model2sequence()
    acc_calculator = AverageCalculator()
    success_bits = []
    for idx, (data_batch, label_batch) in enumerate(data_loader, 0):
        sys.stdout.write('Batch Index = %d\r' % idx)
        if 'batch_num' in tricks and idx >= tricks['batch_num'] and tricks['batch_num'] > 0:
            print('The certification process stops after %d batches' % tricks['batch_num'])
            break
        data_batch = data_batch.cuda(device) if use_gpu else data_batch
        label_batch = label_batch.cuda(device) if use_gpu else label_batch
        logits = model(data_batch)
        acc = accuracy(logits.data, label_batch)
        acc_calculator.update(acc.item(), data_batch.size(0))
        # Map the perturbation norm to robust_loss's naming; l-inf perturbations
        # use the 'l1' dual-norm mode — presumably the upstream convention,
        # TODO confirm against the robust_loss implementation.
        norm_type = {2.: 'l2', np.inf: 'l1'}[norm]
        _, robust_err = robust_loss(seq_model, eps, data_batch, label_batch, norm_type=norm_type, bounded_input=bounded_input, size_average=False)  # of shape [batch_size]
        # A sample is certified when its robust error is (numerically) zero.
        success_bits_this_batch = (robust_err.data.float() < 1e-8).data.cpu().numpy()
        success_bits += list(success_bits_this_batch)
    print('')
    acc_this_epoch = acc_calculator.average
    tosave['guaranteed_distances'] = [eps * flag for flag in success_bits]
    print('>>>>> The results of KW <<<<<')
    print('Average Accuracy: %.2f%%' % (acc_this_epoch * 100.))
    print('Average Certified Distances: %.4f' % (eps * np.mean(success_bits)))
    print('Certified Bounds over %.4f: %.2f%%' % (eps, np.mean(success_bits) * 100.))
    if out_file is not None:
        # FIX: close the pickle file deterministically instead of leaking the handle.
        with open(out_file, 'wb') as fout:
            pickle.dump(tosave, fout)
    return tosave
def certify_per(model, data_loader, out_file, eps, norm, bound_est, device, tosave, pixel_range=None, **tricks):
    '''
    >>> Certification function using per

    For every batch the certified distance is computed under each
    is_certify mode in {0, 1, 2, 3} and the element-wise maximum is kept.
    out_file: path for pickling `tosave`, or None to skip saving.
    tricks: optional 'batch_num' > 0 limits the number of batches processed.

    Returns `tosave` with per-sample 'guaranteed_distances'.
    '''
    use_gpu = device != torch.device('cpu') and torch.cuda.is_available()
    model.eval()
    guaranteed_distance_list = []
    acc_calculator = AverageCalculator()
    # FIX: the criterion is loop-invariant; build it once instead of per batch.
    criterion = nn.CrossEntropyLoss() if not use_gpu else nn.CrossEntropyLoss().cuda()
    for idx, (data_batch, label_batch) in enumerate(data_loader, 0):
        sys.stdout.write('Batch Index = %d\r' % idx)
        if 'batch_num' in tricks and idx >= tricks['batch_num'] and tricks['batch_num'] > 0:
            print('The certification process stops after %d batches' % tricks['batch_num'])
            break
        data_batch = data_batch.cuda(device) if use_gpu else data_batch
        label_batch = label_batch.cuda(device) if use_gpu else label_batch
        logits = model(data_batch)
        # Mode 0 gives the baseline distances (loss value is unused).
        _, safe_distances = per(model=model, bound_est=bound_est, x=data_batch, T=1, c=label_batch,
                                norm=norm, alpha=1., gamma=0., eps=eps, at=False,
                                pixel_range=pixel_range, criterion=criterion, is_certify=0)
        # Take the best certified distance over the remaining modes.
        for is_certify in [1, 2, 3]:
            _, safe_distances_this_mode = per(model=model, bound_est=bound_est, x=data_batch, T=1, c=label_batch,
                                              norm=norm, alpha=1., gamma=0., eps=eps, at=False,
                                              pixel_range=pixel_range, criterion=criterion, is_certify=is_certify)
            safe_distances = torch.max(safe_distances, safe_distances_this_mode)
        acc = accuracy(logits.data, label_batch)
        acc_calculator.update(acc.item(), data_batch.size(0))
        guaranteed_distance_list += list(safe_distances.data.cpu().numpy())
    print('')
    acc_this_epoch = acc_calculator.average
    tosave['guaranteed_distances'] = guaranteed_distance_list
    success_bits = [1. if d > eps - 1e-6 else 0. for d in guaranteed_distance_list]
    print('>>>>> The results of PEC <<<<<')
    print('Average Accuracy: %.2f%%' % (acc_this_epoch * 100.))
    print('Average Certified Distances: %.4f' % np.mean(guaranteed_distance_list))
    print('Certified Bounds over %.4f: %.2f%%' % (eps, np.mean(success_bits) * 100.))
    if out_file is not None:
        # FIX: close the pickle file deterministically instead of leaking the handle.
        with open(out_file, 'wb') as fout:
            pickle.dump(tosave, fout)
    return tosave
def certify_ibp(model, data_loader, out_file, eps, norm, bound_est, device, tosave, pixel_range=None, **tricks):
    '''
    >>> Certification function using IBP/CROWN-IBP

    Computes, per sample, whether IBP and CROWN-IBP certify robustness at
    radius `eps`, and records eps (certified) or 0 (not) for each method.
    out_file: path for pickling `tosave`, or None to skip saving.
    tricks: optional 'batch_num' > 0 limits the number of batches processed.
    '''
    use_gpu = device != torch.device('cpu') and torch.cuda.is_available()
    model.eval()
    success_bits_ibp = []
    success_bits_crown = []
    acc_calculator = AverageCalculator()
    for idx, (data_batch, label_batch) in enumerate(data_loader, 0):
        sys.stdout.write('Batch Index = %d\r' % idx)
        if 'batch_num' in tricks and idx >= tricks['batch_num'] and tricks['batch_num'] > 0:
            print('The certification process stops after %d batches' % tricks['batch_num'])
            break
        data_batch = data_batch.cuda(device) if use_gpu else data_batch
        label_batch = label_batch.cuda(device) if use_gpu else label_batch
        logits = model(data_batch)
        acc = accuracy(logits.data, label_batch)
        acc_calculator.update(acc.item(), data_batch.size(0))
        # Per-sample certification flags from both bound propagation methods.
        results_ibp_this_batch, results_crown_this_batch = calc_ibp_certify(
            model=model, data_batch=data_batch, label_batch=label_batch,
            perturb_norm=norm, perturb_eps=eps, pixel_range=pixel_range)
        results_ibp_this_batch = results_ibp_this_batch.data.cpu().numpy()
        results_crown_this_batch = results_crown_this_batch.data.cpu().numpy()
        success_bits_ibp += list(results_ibp_this_batch)
        success_bits_crown += list(results_crown_this_batch)
    print('')
    acc_this_epoch = acc_calculator.average
    tosave['guaranteed_distances_ibp'] = [eps * flag for flag in success_bits_ibp]
    tosave['guaranteed_distances_crownibp'] = [eps * flag for flag in success_bits_crown]
    print('>>> The results of IBP/CROWN-IBP <<<')
    print('Average Accuracy: %.2f%%' % (acc_this_epoch * 100.))
    print('Average Certified Distances by IBP: %.4f' % (eps * np.mean(success_bits_ibp)))
    print('Certified Bounds over %.4f: %.2f%%' % (eps, np.mean(success_bits_ibp) * 100.))
    print('Average Certified Distances by CROWN-IBP: %.4f' % (eps * np.mean(success_bits_crown)))
    print('Certified Bounds over %.4f: %.2f%%' % (eps, np.mean(success_bits_crown) * 100.))
    if out_file is not None:
        # FIX: close the pickle file deterministically instead of leaking the handle.
        with open(out_file, 'wb') as fout:
            pickle.dump(tosave, fout)
    return tosave
def certify_crown(model, data_loader, out_file, eps, norm, bound_est, device, tosave, pixel_range=None, **tricks):
    '''
    >>> Certification function using Fast-Lin / CROWN

    A sample counts as certified when the minimum lower-bound margin over
    all classes is (numerically) non-negative at radius `eps`.
    out_file: path for pickling `tosave`, or None to skip saving.
    tricks: optional 'batch_num' > 0 limits the number of batches processed.
    '''
    use_gpu = device != torch.device('cpu') and torch.cuda.is_available()
    model.eval()
    success_bits = []
    acc_calculator = AverageCalculator()
    for idx, (data_batch, label_batch) in enumerate(data_loader, 0):
        sys.stdout.write('Batch Index = %d\r' % idx)
        if 'batch_num' in tricks and idx >= tricks['batch_num'] and tricks['batch_num'] > 0:
            print('The certification process stops after %d batches' % tricks['batch_num'])
            break
        data_batch = data_batch.cuda(device) if use_gpu else data_batch
        label_batch = label_batch.cuda(device) if use_gpu else label_batch
        logits = model(data_batch)
        acc = accuracy(logits.data, label_batch)
        acc_calculator.update(acc.item(), data_batch.size(0))
        # Lower/upper output bounds under the eps-ball; only lower bounds matter.
        l, u = crown(model=model, bound_est=bound_est, x=data_batch, c=label_batch,
                     norm=norm, eps=eps, pixel_range=pixel_range)
        margin, _ = torch.min(l, dim=1)
        success_bits_this_batch = (margin > -1e-8).data.cpu().numpy()
        success_bits += list(success_bits_this_batch)
    print('')
    acc_this_epoch = acc_calculator.average
    tosave['guaranteed_distances'] = [eps * flag for flag in success_bits]
    print('>>> The results of Fast-Lin / CROWN <<<<')
    print('Average Accuracy: %.2f%%' % (acc_this_epoch * 100.))
    print('Average Certified Distances: %.4f' % (eps * np.mean(success_bits)))
    print('Certified Bounds over %.4f: %.2f%%' % (eps, np.mean(success_bits) * 100.))
    if out_file is not None:
        # FIX: close the pickle file deterministically instead of leaking the handle.
        with open(out_file, 'wb') as fout:
            pickle.dump(tosave, fout)
    return tosave
def accuracy_evaluation(self, prediction, target, topk=(1, 5)):
    """Return a dict of top-k accuracies, keyed 'accuracy_k' for each k in `topk`."""
    res_acc = accuracy(prediction, target, topk=topk)
    return {'accuracy_{}'.format(k): res_acc[k][0] for k in topk}
def search_eps(model, data_loader, min_idx, max_idx, min_eps, max_eps, precision_eps, certify_mode, bound_est, norm, pixel_range, device, **tricks):
    '''
    >>> Find the optimal eps

    Binary-searches, per sample (batch size must be 1), the largest eps in
    [min_eps, max_eps] certifiable under `certify_mode`, down to a gap of
    `precision_eps`. Batches outside the optional [min_idx, max_idx) window
    are skipped.

    Returns a dict: batch index -> {'eps': midpoint estimate,
    'call_time': number of certification calls, 'acc': 0/1 clean correctness}.
    '''
    use_gpu = device != torch.device('cpu') and torch.cuda.is_available()
    tosave = {}
    for idx, (data_batch, label_batch) in enumerate(data_loader, 0):
        sys.stdout.write('Processing batch: %d\r' % idx)
        sys.stdout.flush()
        # Only process the requested index window.
        if (min_idx is not None and idx < min_idx) or (max_idx is not None and idx >= max_idx):
            continue
        data_batch = data_batch.cuda(device) if use_gpu else data_batch
        label_batch = label_batch.cuda(device) if use_gpu else label_batch
        logits = model(data_batch)
        acc = accuracy(logits.data, label_batch)
        batch_size = data_batch.shape[0]
        assert batch_size == 1, 'The batch size should be 1 in this case, but %d found' % batch_size
        call_time = 0  # number of attempts to calculate certified bounds
        # Bisect [min_eps, max_eps]; each certification call also supplies a
        # certified distance that tightens the lower end.
        up_eps = max_eps
        low_eps = min_eps
        while up_eps - low_eps > precision_eps:
            attempt_eps = (up_eps + low_eps) / 2.
            eps_this_batch = certify_this_batch(model=model, data_batch=data_batch, label_batch=label_batch,
                                                certify_mode=certify_mode, eps=attempt_eps,
                                                bound_est=bound_est, norm=norm, pixel_range=pixel_range)
            call_time += 1
            low_eps = max(low_eps, eps_this_batch[0])
            if eps_this_batch[0] < attempt_eps - 1e-6:
                up_eps = attempt_eps
        tosave[idx] = {
            'eps': (low_eps + up_eps) / 2.,
            'call_time': call_time,
            'acc': 1 if acc.item() > 0.5 else 0
        }
    return tosave
def train_test_ibp(model, train_loader, test_loader, attacker, epoch_num, optimizer, out_folder, model_name, alpha_list, beta_list, eps_list, norm, device, criterion, tosave, pixel_range=None, **tricks):
    '''
    >>> General training function using IBP

    alpha_list / beta_list / eps_list: per-epoch schedules for the IBP loss
    coefficients and the perturbation budget.
    attacker: optional adversary; when given, the first half of each batch
    is replaced by adversarial examples before the forward pass.
    tricks: optional 'lr_func' (epoch -> lr), 'train_batch_per_epoch'
    (enables sub-epoch lr scheduling) and 'bound_calc_per_batch' (instance
    budget forwarded to calc_ibp_loss).
    Per-epoch losses/accuracies are accumulated in `tosave`; results and
    weights are checkpointed every epoch and once more at the end.
    '''
    acc_calculator = AverageCalculator()
    loss_calculator = AverageCalculator()
    use_gpu = device != torch.device('cpu') and torch.cuda.is_available()
    for epoch_idx in range(epoch_num):
        # Schedule values for this epoch.
        alpha = alpha_list[epoch_idx]
        beta = beta_list[epoch_idx]
        eps = eps_list[epoch_idx]
        print('alpha = %1.2e, beta = %1.2e, eps = %1.2e' % (alpha, beta, eps))
        acc_calculator.reset()
        loss_calculator.reset()
        model.train()
        for idx, (data_batch, label_batch) in enumerate(train_loader, 0):
            sys.stdout.write('Batch_idx = %d\r' % idx)
            # Optional learning-rate schedule, possibly at sub-epoch resolution.
            if 'lr_func' in tricks and tricks['lr_func'] != None:
                lr_func = tricks['lr_func']
                local_idx = epoch_idx if 'train_batch_per_epoch' not in tricks else epoch_idx + float(idx) / tricks['train_batch_per_epoch']
                local_lr = lr_func(local_idx)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = local_lr
            data_batch = data_batch.cuda(device) if use_gpu else data_batch
            label_batch = label_batch.cuda(device) if use_gpu else label_batch
            # Replace the first half of the batch with adversarial examples.
            if attacker != None:
                attack_num = data_batch.shape[0] // 2
                data_batch_attack = attacker.attack(model, optimizer, data_batch[:attack_num], label_batch[:attack_num], criterion)
                data_batch = torch.cat([data_batch_attack, data_batch[attack_num:]], dim=0)
            logits = model(data_batch)
            if 'bound_calc_per_batch' in tricks and tricks['bound_calc_per_batch'] != None:
                loss = calc_ibp_loss(model, data_batch, label_batch, norm, eps, pixel_range, alpha, beta, criterion, tricks['bound_calc_per_batch'])
            else:
                loss = calc_ibp_loss(model, data_batch, label_batch, norm, eps, pixel_range, alpha, beta, criterion)
            loss.backward()
            # Element-wise gradient clipping to [-0.5, 0.5] before the step.
            for group in optimizer.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    p.grad.data = torch.clamp(p.grad.data, min=-0.5, max=0.5)
            optimizer.step()
            optimizer.zero_grad()
            acc = accuracy(logits.data, label_batch)
            acc_calculator.update(acc.item(), data_batch.size(0))
            loss_calculator.update(loss.item(), data_batch.size(0))
        loss_this_epoch = loss_calculator.average
        acc_this_epoch = acc_calculator.average
        print('Train loss/acc after epoch %d: %.4f/%.2f%%' % (epoch_idx, loss_this_epoch, acc_this_epoch * 100.))
        tosave['train_loss'][epoch_idx] = loss_this_epoch
        tosave['train_acc'][epoch_idx] = acc_this_epoch
        acc_calculator.reset()
        loss_calculator.reset()
        model.eval()
        # Evaluation pass: same IBP loss, no parameter updates.
        # NOTE(review): no torch.no_grad() here, so graphs are still built
        # during evaluation -- presumably tolerated; confirm before changing.
        for idx, (data_batch, label_batch) in enumerate(test_loader, 0):
            data_batch = data_batch.cuda(device) if use_gpu else data_batch
            label_batch = label_batch.cuda(device) if use_gpu else label_batch
            logits = model(data_batch)
            if 'bound_calc_per_batch' in tricks and tricks['bound_calc_per_batch'] != None:
                loss = calc_ibp_loss(model, data_batch, label_batch, norm, eps, pixel_range, alpha, beta, criterion, tricks['bound_calc_per_batch'])
            else:
                loss = calc_ibp_loss(model, data_batch, label_batch, norm, eps, pixel_range, alpha, beta, criterion)
            acc = accuracy(logits.data, label_batch)
            acc_calculator.update(acc.item(), data_batch.size(0))
            loss_calculator.update(loss.item(), data_batch.size(0))
        loss_this_epoch = loss_calculator.average
        acc_this_epoch = acc_calculator.average
        print('Test loss/acc after epoch %d: %.4f/%.2f%%' % (epoch_idx, loss_this_epoch, acc_this_epoch * 100.))
        tosave['test_loss'][epoch_idx] = loss_this_epoch
        tosave['test_acc'][epoch_idx] = acc_this_epoch
        # Per-epoch checkpoint of results and weights.
        # NOTE(review): the open(...) handles passed to pickle.dump are never
        # closed explicitly; consider a `with` block.
        pickle.dump(tosave, open(os.path.join(out_folder, '%s.pkl' % model_name), 'wb'))
        torch.save(model.state_dict(), os.path.join(out_folder, '%s.ckpt' % model_name))
    # Final save after all epochs.
    pickle.dump(tosave, open(os.path.join(out_folder, '%s.pkl' % model_name), 'wb'))
    torch.save(model.state_dict(), os.path.join(out_folder, '%s.ckpt' % model_name))
def train_test(model, train_loader, test_loader, attacker, epoch_num, optimizer, out_folder, model_name, alpha_list, eps_list, gamma_list, bound_est, T, norm, device, criterion, tosave, at_per=False, pixel_range=None, update_freq=1, **tricks):
    '''
    >>> General training function without validation set

    alpha_list / eps_list / gamma_list: per-epoch schedules; gamma > 1e-6
    switches on the certified-robustness regularizer (`per` by default, or
    `crown_loss` when tricks['regularize_mode'] == 'kw').
    update_freq: gradient-accumulation factor; parameters are updated every
    `update_freq` batches with gradients averaged over that window.
    tricks: optional 'lr_func', 'train_batch_per_epoch',
    'bound_calc_per_batch', 'regularize_mode'.
    Per-epoch metrics and per-sample certified distances are accumulated in
    `tosave`; results and weights are checkpointed every epoch and at the end.
    '''
    acc_calculator = AverageCalculator()
    loss_calculator = AverageCalculator()
    safe_distance_calculator = AverageCalculator()
    use_gpu = device != torch.device('cpu') and torch.cuda.is_available()
    global_batch_idx = 0  # counts batches across epochs to drive gradient accumulation
    for epoch_idx in range(epoch_num):
        alpha = alpha_list[epoch_idx]
        eps = eps_list[epoch_idx]
        gamma = gamma_list[epoch_idx]
        acc_calculator.reset()
        loss_calculator.reset()
        safe_distance_calculator.reset()
        model.train()
        for idx, (data_batch, label_batch) in enumerate(train_loader, 0):
            sys.stdout.write('Batch_idx = %d\r' % idx)
            # Optional learning-rate schedule, possibly at sub-epoch resolution.
            if 'lr_func' in tricks and tricks['lr_func'] != None:
                lr_func = tricks['lr_func']
                local_idx = epoch_idx if 'train_batch_per_epoch' not in tricks else epoch_idx + float(idx) / tricks['train_batch_per_epoch']
                local_lr = lr_func(local_idx)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = local_lr
            data_batch = data_batch.cuda(device) if use_gpu else data_batch
            label_batch = label_batch.cuda(device) if use_gpu else label_batch
            # Play adversarial attack on the first half of the batch.
            if attacker != None:
                attack_num = data_batch.shape[0] // 2
                data_batch_attack = attacker.attack(model, optimizer, data_batch[:attack_num], label_batch[:attack_num], criterion)
                data_batch = torch.cat([data_batch_attack, data_batch[attack_num:]], dim=0)
            logits = model(data_batch)
            if gamma > 1e-6:  # Turn on per regularization
                if 'bound_calc_per_batch' in tricks and tricks['bound_calc_per_batch'] != None:
                    ins_num = tricks['bound_calc_per_batch']
                    if 'regularize_mode' not in tricks or tricks['regularize_mode'] == 'per':
                        loss, safe_distances = per(model=model, bound_est=bound_est, x=data_batch, T=T, c=label_batch, norm=norm, alpha=alpha, gamma=gamma, eps=eps, at=at_per, pixel_range=pixel_range, criterion=criterion, bound_calc_per_batch=ins_num)
                    elif tricks['regularize_mode'] == 'kw':
                        loss = crown_loss(model=model, bound_est=bound_est, x=data_batch, c=label_batch, norm=norm, gamma=gamma, eps=eps, at=at_per, pixel_range=pixel_range, criterion=criterion, bound_calc_per_batch=ins_num)
                        # crown_loss yields no distances; report zeros.
                        safe_distances = torch.zeros_like(label_batch).float()
                    else:
                        raise ValueError('Invalid regularizer_mode')
                else:
                    if 'regularize_mode' not in tricks or tricks['regularize_mode'] == 'per':
                        loss, safe_distances = per(model=model, bound_est=bound_est, x=data_batch, T=T, c=label_batch, norm=norm, alpha=alpha, gamma=gamma, eps=eps, at=at_per, pixel_range=pixel_range, criterion=criterion)
                    elif tricks['regularize_mode'] == 'kw':
                        loss = crown_loss(model=model, bound_est=bound_est, x=data_batch, c=label_batch, norm=norm, gamma=gamma, eps=eps, at=at_per, pixel_range=pixel_range, criterion=criterion)
                        safe_distances = torch.zeros_like(label_batch).float()
                    else:
                        raise ValueError('Invalid regularizer_mode')
            else:
                # Plain cross-entropy when the regularizer is off.
                loss = criterion(logits, label_batch)
                safe_distances = torch.zeros_like(label_batch).float()
            global_batch_idx += 1
            loss.backward(retain_graph=True)
            # Step every update_freq batches: average the accumulated
            # gradients, clip element-wise to [-0.5, 0.5], then update.
            if global_batch_idx % update_freq == 0:
                for group in optimizer.param_groups:
                    for p in group['params']:
                        if p.grad is None:
                            continue
                        p.grad.data = p.grad.data / update_freq
                        p.grad.data = torch.clamp(p.grad.data, min=-0.5, max=0.5)  # Gradient clipping
                optimizer.step()
                optimizer.zero_grad()
            acc = accuracy(logits.data, label_batch)
            acc_calculator.update(acc.item(), data_batch.size(0))
            loss_calculator.update(loss.item(), data_batch.size(0))
            safe_distance_calculator.update(safe_distances.mean().item(), data_batch.size(0))
        loss_this_epoch = loss_calculator.average
        acc_this_epoch = acc_calculator.average
        safe_distance_this_epoch = safe_distance_calculator.average
        print('Train loss/acc after epoch %d: %.4f/%.2f%%' % (epoch_idx, loss_this_epoch, acc_this_epoch * 100.))
        print('Train safe distance: %.4f' % safe_distance_this_epoch)
        tosave['train_loss'][epoch_idx] = loss_this_epoch
        tosave['train_acc'][epoch_idx] = acc_this_epoch
        tosave['train_safe_distance'][epoch_idx] = safe_distance_this_epoch
        acc_calculator.reset()
        loss_calculator.reset()
        safe_distance_calculator.reset()
        model.eval()
        guaranteed_distance_list = []
        # Evaluation pass mirrors the training loss computation without updates.
        # NOTE(review): no torch.no_grad() here -- graphs are still built during
        # evaluation; presumably tolerated, confirm before changing.
        for idx, (data_batch, label_batch) in enumerate(test_loader, 0):
            data_batch = data_batch.cuda(device) if use_gpu else data_batch
            label_batch = label_batch.cuda(device) if use_gpu else label_batch
            logits = model(data_batch)
            if gamma > 1e-6:  # Turn on per regularization
                if 'bound_calc_per_batch' in tricks and tricks['bound_calc_per_batch'] != None:
                    ins_num = tricks['bound_calc_per_batch']
                    if 'regularize_mode' not in tricks or tricks['regularize_mode'] == 'per':
                        loss, safe_distances = per(model=model, bound_est=bound_est, x=data_batch, T=T, c=label_batch, norm=norm, alpha=alpha, gamma=gamma, eps=eps, at=at_per, pixel_range=pixel_range, criterion=criterion, bound_calc_per_batch=ins_num)
                    elif tricks['regularize_mode'] == 'kw':
                        loss = crown_loss(model=model, bound_est=bound_est, x=data_batch, c=label_batch, norm=norm, gamma=gamma, eps=eps, at=at_per, pixel_range=pixel_range, criterion=criterion, bound_calc_per_batch=ins_num)
                        safe_distances = torch.zeros_like(label_batch).float()
                    else:
                        raise ValueError('Invalid regularizer_mode')
                else:
                    if 'regularize_mode' not in tricks or tricks['regularize_mode'] == 'per':
                        loss, safe_distances = per(model=model, bound_est=bound_est, x=data_batch, T=T, c=label_batch, norm=norm, alpha=alpha, gamma=gamma, eps=eps, at=at_per, pixel_range=pixel_range, criterion=criterion)
                    elif tricks['regularize_mode'] == 'kw':
                        loss = crown_loss(model=model, bound_est=bound_est, x=data_batch, c=label_batch, norm=norm, gamma=gamma, eps=eps, at=at_per, pixel_range=pixel_range, criterion=criterion)
                        safe_distances = torch.zeros_like(label_batch).float()
                    else:
                        raise ValueError('Invalid regularizer_mode')
            else:
                loss = criterion(logits, label_batch)
                safe_distances = torch.zeros_like(label_batch).float()
            acc = accuracy(logits.data, label_batch)
            acc_calculator.update(acc.item(), data_batch.size(0))
            loss_calculator.update(loss.item(), data_batch.size(0))
            safe_distance_calculator.update(safe_distances.mean().item(), data_batch.size(0))
            guaranteed_distance_list += list(safe_distances.data.cpu().numpy())
        loss_this_epoch = loss_calculator.average
        acc_this_epoch = acc_calculator.average
        safe_distance_this_epoch = safe_distance_calculator.average
        print('Test loss/acc after epoch %d: %.4f/%.2f%%' % (epoch_idx, loss_this_epoch, acc_this_epoch * 100.))
        print('Test safe distance: %.4f' % safe_distance_this_epoch)
        tosave['test_loss'][epoch_idx] = loss_this_epoch
        tosave['test_acc'][epoch_idx] = acc_this_epoch
        tosave['test_safe_distance'][epoch_idx] = safe_distance_this_epoch
        tosave['guaranteed_distances'] = guaranteed_distance_list
        # Per-epoch checkpoint of results and weights.
        # NOTE(review): the open(...) handles passed to pickle.dump are never
        # closed explicitly; consider a `with` block.
        pickle.dump(tosave, open(os.path.join(out_folder, '%s.pkl' % model_name), 'wb'))
        torch.save(model.state_dict(), os.path.join(out_folder, '%s.ckpt' % model_name))
    # Final save after all epochs.
    pickle.dump(tosave, open(os.path.join(out_folder, '%s.pkl' % model_name), 'wb'))
    torch.save(model.state_dict(), os.path.join(out_folder, '%s.ckpt' % model_name))
def train_test(setup_config, model, train_loader, test_loader, epoch_num, optimizer, lr_func, output_folder, model_name, device_ids, criterion=None, **tricks):
    '''
    >>> general training function without validation set

    setup_config: experiment configuration stored alongside the results
    lr_func: maps a (fractional) epoch index to a learning rate
    device_ids: 'cpu' forces CPU; anything else uses cuda:0 when available
    criterion: loss module; defaults to a fresh nn.CrossEntropyLoss()
    tricks: optional 'ema' (moving-average updater called after each batch)
            and 'snapshots' (list of 1-based epoch indices to checkpoint)

    NOTE(review): this shadows an earlier `train_test` definition in the
    same module.

    Returns the `tosave` dict of per-epoch train/test loss and accuracy.
    '''
    # FIX: `criterion=nn.CrossEntropyLoss()` was a mutable default -- a single
    # shared module instance across all calls, which was then moved to GPU in
    # place below. Build a fresh instance per call instead.
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    tosave = {
        'model_summary': str(model),
        'setup_config': setup_config,
        'train_loss': {},
        'train_acc': {},
        'test_loss': {},
        'test_acc': {}
    }
    device = torch.device('cuda:0' if not device_ids in ['cpu'] and torch.cuda.is_available() else 'cpu')
    if not device_ids in ['cpu']:
        criterion = criterion.cuda(device)
    # Count batches per epoch without materializing a throwaway index list.
    batches_per_epoch = sum(1 for _ in train_loader)
    acc_calculator = AverageCalculator()
    loss_calculator = AverageCalculator()
    for epoch_idx in range(epoch_num):
        print('Epoch %d: lr = %1.2e' % (epoch_idx, lr_func(epoch_idx)))
        acc_calculator.reset()
        loss_calculator.reset()
        model.train()  # Switch to Train Mode
        for idx, (data_batch, label_batch) in enumerate(train_loader, 0):
            # Fractional epoch position drives the learning-rate schedule.
            epoch_value = epoch_idx + float(idx) / float(batches_per_epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_func(epoch_value)
            if not device_ids in ['cpu']:  # Use of GPU
                data_batch = Variable(data_batch).cuda(device)
                # FIX: `async=True` is a SyntaxError on Python >= 3.7 (async is
                # a keyword); PyTorch renamed the argument to `non_blocking`.
                label_batch = Variable(label_batch.cuda(device, non_blocking=True))
            else:
                data_batch = Variable(data_batch)
                label_batch = Variable(label_batch)
            logits = model(data_batch)
            loss = criterion(logits, label_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc = accuracy(logits.data, label_batch)
            acc_calculator.update(acc.item(), data_batch.size(0))
            loss_calculator.update(loss.item(), data_batch.size(0))
            if 'ema' in tricks and tricks['ema'] is not None:
                tricks['ema'].update_model(model=model)
        print('Training loss after epoch %d: %.4f' % (epoch_idx + 1, loss_calculator.average))
        tosave['train_acc'][epoch_idx] = acc_calculator.average
        tosave['train_loss'][epoch_idx] = loss_calculator.average
        acc_calculator.reset()
        loss_calculator.reset()
        model.eval()  # Switch to Evaluation Mode
        for idx, (data_batch, label_batch) in enumerate(test_loader, 0):
            if not device_ids in ['cpu']:  # Use of GPU
                data_batch = Variable(data_batch).cuda(device)
                label_batch = Variable(label_batch.cuda(device, non_blocking=True))
            else:
                data_batch = Variable(data_batch)
                label_batch = Variable(label_batch)
            logits = model(data_batch)
            loss = criterion(logits, label_batch)
            acc = accuracy(logits.data, label_batch)
            acc_calculator.update(acc.item(), data_batch.size(0))
            loss_calculator.update(loss.item(), data_batch.size(0))
        print('Test loss after epoch %d: %.4f, accuracy = %.2f%%' % (epoch_idx + 1, loss_calculator.average, acc_calculator.average * 100.))
        tosave['test_acc'][epoch_idx] = acc_calculator.average
        tosave['test_loss'][epoch_idx] = loss_calculator.average
        if 'snapshots' in tricks and (epoch_idx + 1) in tricks['snapshots']:
            print('snapshot saved in epoch %d' % (epoch_idx + 1))
            torch.save(model.state_dict(), os.path.join(output_folder, '%s_%d.ckpt' % (model_name, epoch_idx + 1)))
        # FIX: close the pickle handle deterministically instead of leaking it.
        with open(os.path.join(output_folder, '%s.pkl' % model_name), 'wb') as fout:
            pickle.dump(tosave, fout)
        torch.save(model.state_dict(), os.path.join(output_folder, '%s.ckpt' % model_name))
    return tosave