def evaluate_natural(args, model, test_loader, verbose=False):
    model.eval()
    with torch.no_grad():
        meter = MultiAverageMeter()
        total_test_steps = len(test_loader)

        def test_step(step, X_batch, y_batch):
            X, y = X_batch.cuda(), y_batch.cuda()
            if args.method == 'natural':
                output = model(X)
                loss = F.cross_entropy(output, y)
            else:
                raise NotImplementedError(args.method)
            meter.update('test_loss', loss.item(), y.size(0))
            meter.update('test_acc', (output.max(1)[1] == y).float().mean(), y.size(0))
            if step % args.log_interval == 0 and verbose:
                logger.info('Eval step {}/{} {}'.format(step, total_test_steps, meter))

        for step, (X_batch, y_batch) in enumerate(test_loader):
            test_step(step, X_batch, y_batch)
        logger.info('Evaluation {}'.format(meter))
def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust'):
    num_class = 10
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training; no need to compute LiRPA bounds.
        batch_method = method
        if eps < 1e-20:
            batch_method = "natural"
        if train:
            opt.zero_grad()
        # Generate specifications.
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(0)
        # Remove specifications to self.
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0), num_class - 1, num_class))
        # Bound the input (only needed for the Linf norm).
        if norm == np.inf:
            data_max = torch.reshape((1. - loader.mean) / loader.std, (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std, (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1), data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1), data_min)
        else:
            data_ub = data_lb = data
        if list(model.parameters())[0].is_cuda:
            data, labels, c = data.cuda(), labels.cuda(), c.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        # Specify the Lp norm perturbation.
        # When using Linf perturbation, we manually set element-wise bounds x_L and x_U; eps is not used for Linf norm.
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        x = BoundedTensor(data, ptb)
        output = model(x)
        regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up
        meter.update('CE', regular_ce.item(), x.size(0))
        meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).cpu().detach().numpy() / x.size(0), x.size(0))

        if batch_method == "robust":
            if bound_type == "IBP":
                lb, ub = model.compute_bounds(IBP=True, C=c, method=None)
            elif bound_type == "CROWN":
                lb, ub = model.compute_bounds(IBP=False, C=c, method="backward", bound_upper=False)
            elif bound_type == "CROWN-IBP":
                # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method="backward")  # pure IBP bound
                # We use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020).
                factor = (eps_scheduler.get_max_eps() - eps) / eps_scheduler.get_max_eps()
                ilb, iub = model.compute_bounds(IBP=True, C=c, method=None)
                if factor < 1e-5:
                    lb = ilb
                else:
                    clb, cub = model.compute_bounds(IBP=False, C=c, method="backward", bound_upper=False)
                    lb = clb * factor + ilb * (1 - factor)
            # Pad a zero at the beginning of each example and use the fake label "0" for all examples.
            lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb), dim=1)
            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)

        if batch_method == "robust":
            loss = robust_ce
        elif batch_method == "natural":
            loss = regular_ce
        if train:
            loss.backward()
            eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))
        if batch_method != "natural":
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            # For an example, if the lower bounds of all margins are > 0, the output is verifiably correct.
            # If any margin lower bound is < 0, this example is counted as an error.
            meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))
        meter.update('Time', time.time() - start)
        if i % 50 == 0 and train:
            print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))
    print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))
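# --- Illustrative sketch, not part of the training code above ---
# The specification matrix `c` built in Train() has shape (batch, num_class - 1, num_class):
# for each example it encodes the num_class - 1 margins z_true - z_other, and passing it as
# the C argument of compute_bounds makes the bounds apply directly to those margins. The
# CROWN-IBP branch then mixes the IBP and CROWN lower bounds with a factor that moves from
# 1 down to 0 as eps ramps up to max_eps. The self-contained demo below uses random tensors
# and a made-up name (demo_specification_and_mixing); it only illustrates the tensor
# manipulations, not the actual bound computation.
def demo_specification_and_mixing():
    import torch
    num_class, batch = 10, 4
    labels = torch.randint(0, num_class, (batch,))
    data = torch.randn(batch, 3, 32, 32)
    # +1 on the true class, -1 on each other class, one row per class.
    c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) \
        - torch.eye(num_class).type_as(data).unsqueeze(0)
    I = ~(labels.unsqueeze(1) == torch.arange(num_class).unsqueeze(0))
    c = c[I].view(batch, num_class - 1, num_class)            # drop the row for the true class
    logits = torch.randn(batch, num_class)
    margins = torch.bmm(c, logits.unsqueeze(-1)).squeeze(-1)  # (batch, num_class - 1): z_true - z_other
    # CROWN-IBP mixing: factor is close to 1 early in the eps schedule and 0 at max_eps.
    eps, max_eps = 2. / 255, 8. / 255
    factor = (max_eps - eps) / max_eps
    ilb = torch.randn(batch, num_class - 1)  # stand-in for the IBP lower bound
    clb = torch.randn(batch, num_class - 1)  # stand-in for the CROWN lower bound
    lb = clb * factor + ilb * (1 - factor)
    return margins, lb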
def Train(model, model_ori, t, loader, eps_scheduler, opt, loss_fusion=False, valid=False):
    train = opt is not None
    meter = MultiAverageMeter()
    meter_layer = []
    data_max, data_min, std = loader.data_max, loader.data_min, loader.std
    if args.device == 'cuda':
        data_min, data_max, std = data_min.cuda(), data_max.cuda(), std.cuda()
    if train:
        model_ori.train()
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
    else:
        model_ori.eval()
        model.eval()
        eps_scheduler.eval()

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        epoch_progress = (i + 1) * 1. / len(loader) if train else 1.0
        if train:
            eps *= args.train_eps_mul
        if eps < args.min_eps:
            eps = args.min_eps
        if args.fix_eps:
            eps = eps_scheduler.get_max_eps()
        if args.natural:
            eps = 0.
        reg = t <= args.num_reg_epochs

        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = 'natural' if (eps < 1e-50) else 'robust'
        robust = batch_method == 'robust'

        # labels = labels.to(torch.long)
        if args.device == 'cuda':
            data, labels = data.cuda().detach().requires_grad_(), labels.cuda()
        data_batch, labels_batch = data, labels
        grad_acc = args.grad_acc_steps
        assert data.shape[0] % grad_acc == 0
        bsz = data.shape[0] // grad_acc

        for k in range(grad_acc):
            if grad_acc > 1:
                data, labels = data_batch[bsz * k:bsz * (k + 1)], labels_batch[bsz * k:bsz * (k + 1)]
            if args.mode == 'cert':
                regular_ce, robust_loss, regular_err, robust_err = cert(
                    args, model, model_ori, t, epoch_progress, data, labels,
                    eps=eps, data_max=data_max, data_min=data_min, std=std,
                    robust=robust, reg=reg, loss_fusion=loss_fusion,
                    eps_scheduler=eps_scheduler, train=train, meter=meter)
            elif args.mode == 'adv':
                method = args.method if train else 'pgd'
                regular_ce, robust_loss, regular_err, robust_err = adv(
                    args, model, model_ori, t, epoch_progress, data, labels,
                    eps=eps, data_max=data_max, data_min=data_min, std=std,
                    train=train, meter=meter)
            else:
                raise NotImplementedError
            update_meter(meter, regular_ce, robust_loss, regular_err, robust_err, data.size(0))

            if reg:
                loss = compute_reg(args, model, meter, eps, eps_scheduler)
            elif args.xiao_reg:
                loss = compute_stab_reg(args, model, meter, eps, eps_scheduler) \
                    + compute_L1_reg(args, model_ori, meter, eps, eps_scheduler)
            elif args.vol_reg:  # by colt
                loss = compute_vol_reg(args, model, meter, eps, eps_scheduler)
            else:
                loss = torch.tensor(0.).to(args.device)
            if robust:
                loss += robust_loss * args.kappa + robust_loss * (1 - args.kappa)
            else:
                loss += regular_ce
            meter.update('Loss', loss.item(), data.size(0))

            if train:
                loss /= grad_acc
                loss.backward()
                if args.check_nan:
                    for p in model.parameters():
                        if torch.isnan(p.grad).any():
                            pdb.set_trace()
                            ckpt = {
                                'model_ori': model_ori,
                                'args_cert': (t, epoch_progress, data, labels, eps, data_max, data_min, std,
                                              robust, reg, loss_fusion, eps_scheduler, train, meter)
                            }
                            torch.save(ckpt, 'nan_ckpt')
                            pdb.set_trace()

        if train:
            grad_norm = torch.nn.utils.clip_grad_norm_(model_ori.parameters(), max_norm=args.grad_norm)
            meter.update('grad_norm', grad_norm)
            opt.step()
            opt.zero_grad()
        meter.update('wnorm', get_weight_norm(model_ori))
        meter.update('Time', time.time() - start)
        if (i + 1) % args.log_interval == 0 and (train or args.eval or args.verify):
            logger.info('[{:2d}:{:4d}/{:4d}]: eps={:.8f} {}'.format(t, i + 1, len(loader), eps, meter))
        if args.debug:
            print()
            pdb.set_trace()

    logger.info('[{:2d}]: eps={:.8f} {}'.format(t, eps, meter))
    if batch_method != 'natural':
        meter.update('eps', eps_scheduler.get_eps())
    if t <= args.num_reg_epochs:
        update_log_reg(writer, meter, t, train, model)
    update_log_writer(args, writer, meter, t, train, robust=(batch_method != 'natural'))
    return meter
def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type,
          method='robust', loss_fusion=True, final_node_name=None):
    num_class = 200
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    exp_module = get_exp_module(model)

    def get_bound_loss(x=None, c=None):
        if loss_fusion:
            bound_lower, bound_upper = False, True
        else:
            bound_lower, bound_upper = True, False

        if bound_type == 'IBP':
            lb, ub = model(method_opt="compute_bounds", x=x, IBP=True, C=c, method=None,
                           final_node_name=final_node_name, no_replicas=True)
        elif bound_type == 'CROWN':
            lb, ub = model(method_opt="compute_bounds", x=x, IBP=False, C=c, method='backward',
                           bound_lower=bound_lower, bound_upper=bound_upper)
        elif bound_type == 'CROWN-IBP':
            # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method='backward')  # pure IBP bound
            # We use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020).
            factor = (eps_scheduler.get_max_eps() - eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()
            ilb, iub = model(method_opt="compute_bounds", x=x, IBP=True, C=c, method=None,
                             final_node_name=final_node_name, no_replicas=True)
            if factor < 1e-50:
                lb, ub = ilb, iub
            else:
                clb, cub = model(method_opt="compute_bounds", IBP=False, C=c, method='backward',
                                 bound_lower=bound_lower, bound_upper=bound_upper,
                                 final_node_name=final_node_name, no_replicas=True)
                if loss_fusion:
                    ub = cub * factor + iub * (1 - factor)
                else:
                    lb = clb * factor + ilb * (1 - factor)

        if loss_fusion:
            if isinstance(model, BoundDataParallel):
                max_input = model(get_property=True, node_class=BoundExp, att_name='max_input')
            else:
                max_input = exp_module.max_input
            return None, torch.mean(torch.log(ub) + max_input)
        else:
            # Pad a zero at the beginning of each example and use the fake label '0' for all examples.
            lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb), dim=1)
            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
            return lb, robust_ce

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training; no need to compute LiRPA bounds.
        batch_method = method
        if eps < 1e-50:
            batch_method = "natural"
        if train:
            opt.zero_grad()
        # Bound the input (only needed for the Linf norm).
        if norm == np.inf:
            data_max = torch.reshape((1. - loader.mean) / loader.std, (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std, (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1), data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1), data_min)
        else:
            data_ub = data_lb = data
        if list(model.parameters())[0].is_cuda:
            data, labels = data.cuda(), labels.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        x = BoundedTensor(data, ptb)
        if loss_fusion:
            if batch_method == 'natural' or not train:
                output = model(x, labels)
                regular_ce = torch.mean(torch.log(output))
            else:
                model(x, labels)
                regular_ce = torch.tensor(0., device=data.device)
            meter.update('CE', regular_ce.item(), x.size(0))
            x = (x, labels)
            c = None
        else:
            c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(0)
            # Remove specifications to self.
            I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
            c = (c[I].view(data.size(0), num_class - 1, num_class))
            x = (x, labels)
            output = model(x, final_node_name=final_node_name)
            regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up
            meter.update('CE', regular_ce.item(), x[0].size(0))
            meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).item() / x[0].size(0), x[0].size(0))

        if batch_method == 'robust':
            # print(data.sum())
            lb, robust_ce = get_bound_loss(x=x, c=c)
            loss = robust_ce
        elif batch_method == 'natural':
            loss = regular_ce

        if train:
            loss.backward()
            if args.clip_grad_norm:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
                meter.update('grad_norm', grad_norm)
            if isinstance(eps_scheduler, AdaptiveScheduler):
                eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))
        if batch_method != 'natural':
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            if not loss_fusion:
                # For an example, if the lower bounds of all margins are > 0, the output is verifiably correct.
                # If any margin lower bound is < 0, this example is counted as an error.
                meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))
        meter.update('Time', time.time() - start)

        if (i + 1) % 250 == 0 and train:
            logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    return meter
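# --- Illustrative sketch, not part of the training code above ---
# Loss fusion bounds the cross-entropy itself instead of the logit margins:
# CE(z, y) = log(sum_j exp(z_j - z_y)), so an upper bound `ub` on the fused exp-sum
# (computed with the stabilizing shift `max_input`) yields log(ub) + max_input as an
# upper bound on the robust CE, which is what get_bound_loss() returns when loss_fusion
# is enabled. The check below is a hypothetical sanity test of that identity on ordinary
# logits; the function name and random data are made up for the demo.
def demo_loss_fusion_identity():
    import torch
    import torch.nn.functional as F
    torch.manual_seed(0)
    logits = torch.randn(8, 200)
    labels = torch.randint(0, 200, (8,))
    ce = F.cross_entropy(logits, labels, reduction='none')
    # log-sum-exp of z_j - z_y, written with the same max shift used by the fused model.
    neg_margin = logits - logits.gather(1, labels.unsqueeze(1))
    max_input = neg_margin.max(dim=1, keepdim=True).values
    fused = torch.exp(neg_margin - max_input).sum(dim=1)
    ce_from_fusion = torch.log(fused) + max_input.squeeze(1)
    assert torch.allclose(ce, ce_from_fusion, atol=1e-5)
    return ce, ce_from_fusion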
def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type,
          method='robust', loss_fusion=True, final_node_name=None):
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch(verbose=False)
        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    # Used for loss fusion. Get the exp operation in the computational graph.
    exp_module = get_exp_module(model)

    def get_bound_loss(x=None, c=None):
        if loss_fusion:
            # When loss fusion is used, we need the upper bound for the final loss function.
            bound_lower, bound_upper = False, True
        else:
            # When loss fusion is not used, we need the lower bound for the logit layer.
            bound_lower, bound_upper = True, False

        if bound_type == 'IBP':
            lb, ub = model(method_opt="compute_bounds", x=x, C=c, method="IBP",
                           final_node_name=final_node_name, no_replicas=True)
        elif bound_type == 'CROWN':
            lb, ub = model(method_opt="compute_bounds", x=x, C=c, method="backward",
                           bound_lower=bound_lower, bound_upper=bound_upper)
        elif bound_type == 'CROWN-IBP':
            # We use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020).
            # factor = (eps_scheduler.get_max_eps() - eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()
            ilb, iub = model(method_opt="compute_bounds", x=x, C=c, method="IBP",
                             final_node_name=final_node_name, no_replicas=True)
            lb, ub = model(method_opt="compute_bounds", C=c, method="CROWN-IBP",
                           bound_lower=bound_lower, bound_upper=bound_upper,
                           final_node_name=final_node_name, average_A=True, no_replicas=True)

        if loss_fusion:
            # When loss fusion is enabled, we need to get the common factor before softmax.
            if isinstance(model, BoundDataParallel):
                max_input = model(get_property=True, node_class=BoundExp, att_name='max_input')
            else:
                max_input = exp_module.max_input
            return None, torch.mean(torch.log(ub) + max_input)
        else:
            # Pad a zero at the beginning of each example and use the fake label '0' for all examples.
            lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb), dim=1)
            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
            return lb, robust_ce

    for i, (data, labels) in enumerate(loader):
        # For unit test. We only use a small number of batches.
        if args.truncate_data:
            if i >= args.truncate_data:
                break
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training; no need to compute LiRPA bounds.
        batch_method = method
        if eps < 1e-50:
            batch_method = "natural"
        if train:
            opt.zero_grad()
        if list(model.parameters())[0].is_cuda:
            data, labels = data.cuda(), labels.cuda()
        model.ptb.eps = eps
        x = data

        if loss_fusion:
            if batch_method == 'natural' or not train:
                output = model(x, labels)  # , disable_multi_gpu=True
                regular_ce = torch.mean(torch.log(output))
            else:
                model(x, labels)
                regular_ce = torch.tensor(0., device=data.device)
            meter.update('CE', regular_ce.item(), x.size(0))
            x = (x, labels)
            c = None
        else:
            # Generate the specification matrix (when loss fusion is not used).
            c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(0)
            # Remove specifications to self.
            I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
            c = (c[I].view(data.size(0), num_class - 1, num_class))
            x = (x, labels)
            output = model(x, final_node_name=final_node_name)
            regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up
            meter.update('CE', regular_ce.item(), x[0].size(0))
            meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).item() / x[0].size(0), x[0].size(0))

        if batch_method == 'robust':
            lb, robust_ce = get_bound_loss(x=x, c=c)
            loss = robust_ce
        elif batch_method == 'natural':
            loss = regular_ce

        if train:
            loss.backward()
            if args.clip_grad_norm:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
                meter.update('grad_norm', grad_norm)
            if isinstance(eps_scheduler, AdaptiveScheduler):
                eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))
        if batch_method != 'natural':
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            if not loss_fusion:
                # For an example, if the lower bounds of all margins are > 0, the output is verifiably correct.
                # If any margin lower bound is < 0, this example is counted as an error.
                meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))
        meter.update('Time', time.time() - start)

        if (i + 1) % 50 == 0 and train:
            logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    return meter
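# --- Illustrative sketch, not part of the training code above ---
# 'Verified_Err' above counts an example as unverified whenever any of its
# num_class - 1 margin lower bounds is negative: only if every lower bound is > 0
# does the true class provably beat all other classes over the whole eps-ball.
# A hypothetical standalone check of that bookkeeping (made-up name and data):
def demo_verified_error():
    import torch
    lb = torch.tensor([[ 0.3,  0.1,  0.2],   # all margins certified positive -> verified
                       [ 0.5, -0.2,  0.4],   # one margin possibly negative   -> not verified
                       [-0.1, -0.3, -0.6]])  # not verified
    unverified = (lb < 0).any(dim=1)
    verified_err = unverified.sum().item() / lb.size(0)
    assert abs(verified_err - 2. / 3.) < 1e-9
    return verified_err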
def train_natural(args, model, ds_train, ds_test):
    if args.optimizer == 'sgd':
        opt = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9)
    elif args.optimizer == 'adam':
        opt = torch.optim.Adam(model.parameters(), lr=args.base_lr)
    else:
        raise ValueError(args.optimizer)

    start_epoch_time = time.time()
    meter = MultiAverageMeter()
    if args.data_loader == 'torch':
        train_loader, test_loader = get_loaders(args)
    else:
        test_loader = ds_test

    def train_step(step, X_batch, y_batch, lr_repl):
        lr = float(lr_repl[0])
        opt.param_groups[0]['lr'] = lr
        model.train()
        batch_size = math.ceil(args.batch_size / args.accum_steps)
        for i in range(args.accum_steps):
            X = X_batch[i * batch_size:(i + 1) * batch_size].cuda()
            y = y_batch[i * batch_size:(i + 1) * batch_size].cuda()
            if args.method == 'natural':
                output = model(X)
                loss = F.cross_entropy(output, y)
                # Scale each micro-batch loss so the accumulated gradient matches the full-batch mean.
                (loss * (X.shape[0] / X_batch.shape[0])).backward()
            else:
                raise NotImplementedError(args.method)
            meter.update('train_loss', loss.item(), y.size(0))
            meter.update('train_acc', (output.max(1)[1] == y).float().mean(), y.size(0))
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        opt.step()
        opt.zero_grad()
        if step % args.log_interval == 0:
            logger.info('Training step {} lr {:.4f} {}'.format(step, lr, meter))
        if step % args.save_interval == 0:
            path = os.path.join(args.out_dir, 'checkpoint_{}'.format(step))
            torch.save({'state_dict': model.state_dict(), 'step': step}, path)
            logger.info('Checkpoint saved to {}'.format(path))
        if step == total_steps:
            return

    total_steps = args.epochs * len(train_loader)
    lr_fn = hyper.create_learning_rate_schedule(total_steps, args.base_lr, 'cosine', args.warmup_steps)
    lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
    step = 0
    for epoch in range(1, args.epochs + 1):
        logger.info('Epoch {}'.format(epoch))
        for (X_batch, y_batch) in train_loader:
            step += 1
            lr_repl = next(lr_iter)
            train_step(step, X_batch, y_batch, lr_repl)
        evaluate_natural(args, model, test_loader)
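# --- Illustrative sketch, not part of the training code above ---
# In train_natural() each micro-batch loss is scaled by micro_batch_size / full_batch_size
# before backward(), so the gradients accumulated over args.accum_steps micro-batches equal
# the gradient of the mean loss over the full batch. A hypothetical numerical check on a
# tiny linear model (all names below are made up for the demo):
def demo_grad_accumulation():
    import torch
    import torch.nn.functional as F
    torch.manual_seed(0)
    layer = torch.nn.Linear(5, 3)
    X_batch, y_batch = torch.randn(8, 5), torch.randint(0, 3, (8,))
    # Full-batch gradient.
    F.cross_entropy(layer(X_batch), y_batch).backward()
    full_grad = layer.weight.grad.clone()
    layer.zero_grad()
    # Accumulated gradient over two scaled micro-batches.
    for i in range(2):
        X, y = X_batch[i * 4:(i + 1) * 4], y_batch[i * 4:(i + 1) * 4]
        loss = F.cross_entropy(layer(X), y)
        (loss * (X.shape[0] / X_batch.shape[0])).backward()
    assert torch.allclose(full_grad, layer.weight.grad, atol=1e-6)
    return full_grad, layer.weight.grad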
def train(epoch, batches, type):
    meter = MultiAverageMeter()
    assert (optimizer is not None)
    train = type == 'train'
    if args.robust:
        eps_scheduler.set_epoch_length(len(batches))
        if train:
            eps_scheduler.train()
            eps_scheduler.step_epoch()
        else:
            eps_scheduler.eval()

    for i, batch in enumerate(batches):
        if args.robust:
            eps_scheduler.step_batch()
            eps = eps_scheduler.get_eps()
        else:
            eps = 0
        acc, loss, acc_robust, loss_robust = \
            step(model, ptb, batch, eps=eps, train=train)
        meter.update('acc', acc, len(batch))
        meter.update('loss', loss, len(batch))
        meter.update('acc_rob', acc_robust, len(batch))
        meter.update('loss_rob', loss_robust, len(batch))
        if train:
            if (i + 1) % args.gradient_accumulation_steps == 0 or (i + 1) == len(batches):
                scale_gradients(optimizer, i % args.gradient_accumulation_steps + 1, args.grad_clip)
                optimizer.step()
                optimizer.zero_grad()
                if lr_scheduler is not None:
                    lr_scheduler.step()
            writer.add_scalar('loss_train_{}'.format(epoch), meter.avg('loss'), i + 1)
            writer.add_scalar('loss_robust_train_{}'.format(epoch), meter.avg('loss_rob'), i + 1)
            writer.add_scalar('acc_train_{}'.format(epoch), meter.avg('acc'), i + 1)
            writer.add_scalar('acc_robust_train_{}'.format(epoch), meter.avg('acc_rob'), i + 1)
        if (i + 1) % args.log_interval == 0 or (i + 1) == len(batches):
            logger.info('Epoch {}, {} step {}/{}: eps {:.5f}, {}'.format(
                epoch, type, i + 1, len(batches), eps, meter))
            if lr_scheduler is not None:
                logger.info('lr {}'.format(lr_scheduler.get_lr()))

    writer.add_scalar('loss/{}'.format(type), meter.avg('loss'), epoch)
    writer.add_scalar('loss_robust/{}'.format(type), meter.avg('loss_rob'), epoch)
    writer.add_scalar('acc/{}'.format(type), meter.avg('acc'), epoch)
    writer.add_scalar('acc_robust/{}'.format(type), meter.avg('acc_rob'), epoch)

    if train:
        if args.loss_fusion:
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            for name in state_dict_loss:
                assert (name.startswith('model.'))
                state_dict[name[6:]] = state_dict_loss[name]
            model_ori.load_state_dict(state_dict)
            model_bound = BoundedModule(model_ori, (dummy_embeddings, dummy_mask),
                                        bound_opts=bound_opts, device=args.device)
            model.model_from_embeddings = model_bound
        model.save(epoch)
    return meter.avg('acc_rob')
def Train(model, t, loader, start_eps, end_eps, max_eps, norm, train, opt, bound_type, method='robust'):
    num_class = 10
    meter = MultiAverageMeter()
    if train:
        model.train()
    else:
        model.eval()
    # Pre-generate the index array for specifications; it will be used later for scatter.
    sa = np.zeros((num_class, num_class - 1), dtype=np.int32)
    for i in range(sa.shape[0]):
        for j in range(sa.shape[1]):
            if j < i:
                sa[i][j] = j
            else:
                sa[i][j] = j + 1
    sa = torch.LongTensor(sa)
    total = len(loader.dataset)
    batch_size = loader.batch_size
    # Increase epsilon batch by batch.
    batch_eps = np.linspace(start_eps, end_eps, (total // batch_size) + 1)
    # For small eps just use natural training; no need to compute LiRPA bounds.
    if end_eps < 1e-6:
        method = "natural"

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps = batch_eps[i]
        if train:
            opt.zero_grad()
        # Generate specifications.
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(0)
        # Remove specifications to self.
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0), num_class - 1, num_class))
        # Scatter matrix, used to avoid computing the margin to self.
        sa_labels = sa[labels]
        # Stores computed lower bounds after scatter.
        lb_s = torch.zeros(data.size(0), num_class)
        # Bound the input (only needed for the Linf norm).
        if norm == np.inf:
            data_ub = (data + eps).clamp(max=1.0)
            data_lb = (data - eps).clamp(min=0.0)
        else:
            data_ub = data_lb = data
        if list(model.parameters())[0].is_cuda:
            data, labels, sa_labels, c, lb_s = data.cuda(), labels.cuda(), sa_labels.cuda(), c.cuda(), lb_s.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        output = model(data)
        regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up
        meter.update('CE', regular_ce.cpu().detach().numpy(), data.size(0))
        meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).cpu().detach().numpy() / data.size(0), data.size(0))

        # Specify the Lp norm perturbation.
        # When using Linf perturbation, we manually set element-wise bounds x_L and x_U; eps is not used for Linf norm.
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        if method == "robust":
            if bound_type == "IBP":
                lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method=None)
            elif bound_type == "CROWN":
                lb, ub = model.compute_bounds(ptb=ptb, IBP=False, x=data, C=c, method="backward")
            elif bound_type == "CROWN-IBP":
                # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method="backward")  # pure IBP bound
                # We use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020).
                factor = (max_eps - eps) / max_eps
                ilb, iub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method=None)
                if factor < 1e-5:
                    lb = ilb
                else:
                    clb, cub = model.compute_bounds(ptb=ptb, IBP=False, x=data, C=c, method="backward")
                    lb = clb * factor + ilb * (1 - factor)
            # Fill in the missing 0 in lb: the margin from class j to itself is always 0 and is not computed.
            lb = lb_s.scatter(1, sa_labels, lb)
            # Use the robust cross entropy loss objective (Wong & Kolter, 2018).
            robust_ce = CrossEntropyLoss()(-lb, labels)

        if method == "robust":
            loss = robust_ce
        elif method == "natural":
            loss = regular_ce
        if train:
            loss.backward()
            opt.step()
        meter.update('Loss', loss.cpu().detach().numpy(), data.size(0))
        if method != "natural":
            meter.update('Robust_CE', robust_ce.cpu().detach().numpy(), data.size(0))
            # For an example, if the lower bounds of all margins are > 0, the output is verifiably correct.
            # If any margin lower bound is < 0, this example is counted as an error.
            meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).cpu().detach().numpy() / data.size(0), data.size(0))
        meter.update('Time', time.time() - start)
        if i % 50 == 0 and train:
            print('[{:2d}:{:4d}]: eps={:4f} {}'.format(t, i, eps, meter))
    print('[FINAL RESULT] epoch={:2d} eps={:.4f} {}'.format(t, eps, meter))
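# --- Illustrative sketch, not part of the training code above ---
# The scatter in Train() re-expands the (batch, num_class - 1) margin lower bounds into a
# (batch, num_class) tensor: column k receives the lower bound of the margin z_true - z_k,
# and the true-class column stays 0. Feeding -lb together with the real labels into
# CrossEntropyLoss then gives the robust cross entropy of Wong & Kolter (2018).
# A hypothetical shape check with num_class = 4 (name and numbers made up for the demo):
def demo_scatter_margins():
    import torch
    num_class = 4
    # sa[i][j] is the class index of the j-th "other" class when the true class is i.
    sa = torch.tensor([[j if j < i else j + 1 for j in range(num_class - 1)]
                       for i in range(num_class)])
    labels = torch.tensor([2, 0])
    lb = torch.tensor([[0.5, -0.1, 0.3],    # margin lower bounds of class 2 vs. classes 0, 1, 3
                       [0.2, 0.4, -0.6]])   # margin lower bounds of class 0 vs. classes 1, 2, 3
    lb_s = torch.zeros(labels.size(0), num_class)
    lb_full = lb_s.scatter(1, sa[labels], lb)
    expected = torch.tensor([[0.5, -0.1, 0.0, 0.3],
                             [0.0, 0.2, 0.4, -0.6]])
    assert torch.allclose(lb_full, expected)
    return lb_full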