# Shared imports for the training routines below. Module-level globals used
# here (`args`, `model_loss`, `logger`, `get_exp_module`) are defined in the
# original example scripts.
import time
import pickle

import numpy as np
import torch
from torch.nn import CrossEntropyLoss

from auto_LiRPA import BoundedTensor, BoundDataParallel
from auto_LiRPA.bound_ops import BoundExp
from auto_LiRPA.eps_scheduler import AdaptiveScheduler
from auto_LiRPA.perturbations import PerturbationLpNorm, PerturbationL0Norm
from auto_LiRPA.utils import MultiAverageMeter


def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust'):
    num_class = 10
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(
            int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For very small eps, fall back to natural training; no need to compute LiRPA bounds.
        batch_method = method
        if eps < 1e-20:
            batch_method = "natural"
        if train:
            opt.zero_grad()

        # Generate specifications: row j of c encodes the margin
        # logit[true class] - logit[j] for each example.
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) \
            - torch.eye(num_class).type_as(data).unsqueeze(0)
        # Remove the trivial specification of the true class against itself.
        I = (~(labels.data.unsqueeze(1)
               == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
        c = c[I].view(data.size(0), num_class - 1, num_class)

        # Element-wise input bounds, used only for the Linf norm.
        if norm == np.inf:
            data_max = torch.reshape((1. - loader.mean) / loader.std, (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std, (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1), data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1), data_min)
        else:
            data_ub = data_lb = data
        if list(model.parameters())[0].is_cuda:
            data, labels, c = data.cuda(), labels.cuda(), c.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        # Specify the Lp norm perturbation. For Linf we set the element-wise
        # bounds x_L and x_U manually, so eps is not used in that case.
        if norm > 0:
            ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        elif norm == 0:
            ptb = PerturbationL0Norm(
                eps=eps_scheduler.get_max_eps(),
                ratio=eps_scheduler.get_eps() / eps_scheduler.get_max_eps())
        x = BoundedTensor(data, ptb)

        output = model(x)
        # Regular CrossEntropyLoss, used for warming up.
        regular_ce = CrossEntropyLoss()(output, labels)
        meter.update('CE', regular_ce.item(), x.size(0))
        meter.update('Err',
                     torch.sum(torch.argmax(output, dim=1) != labels).item() / x.size(0),
                     x.size(0))

        if batch_method == "robust":
            if bound_type == "IBP":
                lb, ub = model.compute_bounds(IBP=True, C=c, method=None)
            elif bound_type == "CROWN":
                lb, ub = model.compute_bounds(IBP=False, C=c, method="backward",
                                              bound_upper=False)
            elif bound_type == "CROWN-IBP":
                # Mix IBP and CROWN-IBP bounds, which performs better than
                # either alone (Zhang et al., ICLR 2020).
                factor = (eps_scheduler.get_max_eps() - eps) / eps_scheduler.get_max_eps()
                ilb, iub = model.compute_bounds(IBP=True, C=c, method=None)
                if factor < 1e-5:
                    lb = ilb
                else:
                    clb, cub = model.compute_bounds(IBP=False, C=c, method="backward",
                                                    bound_upper=False)
                    lb = clb * factor + ilb * (1 - factor)
            elif bound_type == "CROWN-FAST":
                # Run IBP first to initialize intermediate layer bounds, then a
                # backward (CROWN) pass that reuses them for the final bound.
                lb, ub = model.compute_bounds(IBP=True, C=c, method=None)
                lb, ub = model.compute_bounds(IBP=False, C=c, method="backward",
                                              bound_upper=False)

            # Pad a zero margin in front of each example and use fake label 0,
            # so that CrossEntropyLoss on -lb_padded yields the robust loss.
            lb_padded = torch.cat(
                (torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                dim=1)
            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)

        if batch_method == "robust":
            loss = robust_ce
        elif batch_method == "natural":
            loss = regular_ce
        if train:
            loss.backward()
            eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))
        if batch_method != "natural":
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            # An example is verifiably correct if the lower bounds of all its
            # margins are > 0; if any margin bound is < 0 it counts as an error.
            meter.update('Verified_Err',
                         torch.sum((lb < 0).any(dim=1)).item() / data.size(0),
                         data.size(0))
        meter.update('Time', time.time() - start)
        if i % 50 == 0 and train:
            print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))
    print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))
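
# A minimal usage sketch for the Train() loop above, not part of the original
# script: the model, loaders, input shape, learning rate, and scheduler option
# string are illustrative assumptions. It uses auto_LiRPA's BoundedModule and
# SmoothedScheduler as in the library examples; the loaders are expected to
# carry `.mean` and `.std` attributes for input normalization.
def example_train_loop(cnn, train_loader, test_loader):
    from torch.optim import Adam
    from auto_LiRPA import BoundedModule
    from auto_LiRPA.eps_scheduler import SmoothedScheduler

    # Wrap the plain nn.Module so compute_bounds() becomes available.
    dummy_input = torch.zeros(1, 1, 28, 28)  # assumed input shape
    model = BoundedModule(cnn, dummy_input)
    opt = Adam(model.parameters(), lr=5e-4)
    # Ramp eps from 0 up to 0.3 over the schedule given in the option string.
    eps_scheduler = SmoothedScheduler(0.3, "start=3,length=60")
    for t in range(1, 101):
        Train(model, t, train_loader, eps_scheduler, np.inf, True, opt, "CROWN-IBP")
        with torch.no_grad():
            Train(model, t, test_loader, eps_scheduler, np.inf, False, None, "CROWN-IBP")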
def step(model, ptb, batch, eps=1.0, train=False):
    model_bound = model.model_from_embeddings
    if train:
        model.train()
        model_bound.train()
        grad = torch.enable_grad()
        if args.loss_fusion:
            model_loss.train()
    else:
        model.eval()
        model_bound.eval()
        grad = torch.no_grad()
    if args.auto_test:
        grad = torch.enable_grad()

    with grad:
        ptb.set_eps(eps)
        ptb.set_train(train)
        embeddings_unbounded, mask, tokens, labels = model.get_input(batch)
        aux = (tokens, batch)
        if args.robust and eps > 1e-9:
            embeddings = BoundedTensor(embeddings_unbounded, ptb)
        else:
            embeddings = embeddings_unbounded.detach().requires_grad_(True)
        robust = args.robust and eps > 1e-6

        if train and robust and args.loss_fusion:
            # Loss fusion: bound the cross-entropy loss directly instead of
            # bounding per-class margins.
            if args.method == 'IBP+backward_train':
                lb, ub = model_loss.compute_bounds(
                    x=(labels, embeddings, mask), aux=aux, C=None,
                    method='IBP+backward', bound_lower=False)
            else:
                raise NotImplementedError
            loss_robust = torch.log(ub).mean()
            loss = acc = acc_robust = -1  # unknown; only the fused bound is computed
        else:
            # Regular loss.
            logits = model_bound(embeddings, mask)
            loss = CrossEntropyLoss()(logits, labels)
            acc = (torch.argmax(logits, dim=1) == labels).float().mean()
            if robust:
                num_class = args.num_classes
                # Margin specifications, one per wrong class (see Train above).
                c = torch.eye(num_class).type_as(embeddings)[labels].unsqueeze(1) \
                    - torch.eye(num_class).type_as(embeddings).unsqueeze(0)
                I = (~(labels.data.unsqueeze(1)
                       == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
                c = c[I].view(embeddings.size(0), num_class - 1, num_class)
                if args.method in ['IBP', 'IBP+backward', 'forward', 'forward+backward']:
                    lb, ub = model_bound.compute_bounds(aux=aux, C=c, method=args.method,
                                                        bound_upper=False)
                elif args.method == 'IBP+backward_train':
                    # CROWN-IBP: interpolate between the IBP and IBP+backward bounds.
                    if 1 - eps > 1e-4:
                        lb, ub = model_bound.compute_bounds(
                            aux=aux, C=c, method='IBP+backward', bound_upper=False)
                        ilb, iub = model_bound.compute_bounds(
                            aux=aux, C=c, method='IBP', reuse_ibp=True)
                        lb = eps * ilb + (1 - eps) * lb
                    else:
                        lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP')
                else:
                    raise NotImplementedError
                # Zero-padding trick: robust CE from the margin lower bounds.
                lb_padded = torch.cat(
                    (torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                    dim=1)
                fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64,
                                          device=lb.device)
                loss_robust = robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
                acc_robust = 1 - torch.mean((lb < 0).any(dim=1).float())
            else:
                acc_robust, loss_robust = acc, loss

        if train or args.auto_test:
            loss_robust.backward()
            # The bounded graph starts at the embeddings, so propagate their
            # gradient back to the word embedding matrix manually.
            grad_embed = torch.autograd.grad(embeddings_unbounded,
                                             model.word_embeddings.weight,
                                             grad_outputs=embeddings.grad)[0]
            if model.word_embeddings.weight.grad is None:
                model.word_embeddings.weight.grad = grad_embed
            else:
                model.word_embeddings.weight.grad += grad_embed

    if args.auto_test:
        with open('res_test.pkl', 'wb') as file:
            pickle.dump((float(acc), float(loss), float(acc_robust),
                         float(loss_robust), grad_embed.detach().numpy()), file)
    return acc, loss, acc_robust, loss_robust
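
# Standalone sketch of the specification-matrix construction shared by the
# routines above, on a hypothetical 3-class problem, to make the eye()-based
# indexing concrete. Each row of c, dotted with the logits, is the margin
# logit[true] - logit[j] for one wrong class j.
def demo_specification_matrix():
    num_class = 3
    labels = torch.tensor([2])  # one example, ground-truth class 2
    c = torch.eye(num_class)[labels].unsqueeze(1) - torch.eye(num_class).unsqueeze(0)
    # Drop the trivial true-vs-true row, leaving num_class - 1 specifications.
    I = ~(labels.unsqueeze(1) == torch.arange(num_class).unsqueeze(0))
    c = c[I].view(labels.size(0), num_class - 1, num_class)
    print(c)
    # tensor([[[-1.,  0.,  1.],
    #          [ 0., -1.,  1.]]])
    # If the lower bound of every margin (every row) is > 0, the example is
    # verifiably classified as class 2 under the given perturbation.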
def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type,
          method='robust', loss_fusion=True, final_node_name=None):
    num_class = 200
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(
            int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()
    exp_module = get_exp_module(model)

    def get_bound_loss(x=None, c=None):
        # With loss fusion we need an upper bound on the fused loss; otherwise
        # we need a lower bound on the margins.
        if loss_fusion:
            bound_lower, bound_upper = False, True
        else:
            bound_lower, bound_upper = True, False
        if bound_type == 'IBP':
            lb, ub = model(method_opt="compute_bounds", x=x, IBP=True, C=c, method=None,
                           final_node_name=final_node_name, no_replicas=True)
        elif bound_type == 'CROWN':
            lb, ub = model(method_opt="compute_bounds", x=x, IBP=False, C=c,
                           method='backward', bound_lower=bound_lower,
                           bound_upper=bound_upper)
        elif bound_type == 'CROWN-IBP':
            # Mix IBP and CROWN-IBP bounds, which performs better than either
            # alone (Zhang et al., ICLR 2020).
            factor = (eps_scheduler.get_max_eps()
                      - eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()
            ilb, iub = model(method_opt="compute_bounds", x=x, IBP=True, C=c, method=None,
                             final_node_name=final_node_name, no_replicas=True)
            if factor < 1e-50:
                lb, ub = ilb, iub
            else:
                clb, cub = model(method_opt="compute_bounds", IBP=False, C=c,
                                 method='backward', bound_lower=bound_lower,
                                 bound_upper=bound_upper,
                                 final_node_name=final_node_name, no_replicas=True)
                if loss_fusion:
                    ub = cub * factor + iub * (1 - factor)
                else:
                    lb = clb * factor + ilb * (1 - factor)

        if loss_fusion:
            # Fused robust loss: log of the loss upper bound, plus the maximum
            # subtracted inside BoundExp for numerical stability.
            if isinstance(model, BoundDataParallel):
                max_input = model(get_property=True, node_class=BoundExp,
                                  att_name='max_input')
            else:
                max_input = exp_module.max_input
            return None, torch.mean(torch.log(ub) + max_input)
        else:
            # Pad a zero margin in front of each example and use fake label 0,
            # so that CrossEntropyLoss on -lb_padded yields the robust loss.
            lb_padded = torch.cat(
                (torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                dim=1)
            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64,
                                      device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
            return lb, robust_ce

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For very small eps, fall back to natural training; no need to compute LiRPA bounds.
        batch_method = method
        if eps < 1e-50:
            batch_method = "natural"
        if train:
            opt.zero_grad()

        # Element-wise input bounds, used only for the Linf norm.
        if norm == np.inf:
            data_max = torch.reshape((1. - loader.mean) / loader.std, (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std, (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1), data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1), data_min)
        else:
            data_ub = data_lb = data
        if list(model.parameters())[0].is_cuda:
            data, labels = data.cuda(), labels.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        x = BoundedTensor(data, ptb)
        if loss_fusion:
            if batch_method == 'natural' or not train:
                # The fused model outputs exp of the per-example CE loss; take
                # the log to recover the regular CE.
                output = model(x, labels)
                regular_ce = torch.mean(torch.log(output))
            else:
                model(x, labels)
                regular_ce = torch.tensor(0., device=data.device)
            meter.update('CE', regular_ce.item(), x.size(0))
            x = (x, labels)
            c = None
        else:
            # Margin specifications, one per wrong class (see Train above).
            c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) \
                - torch.eye(num_class).type_as(data).unsqueeze(0)
            # Remove the trivial specification of the true class against itself.
            I = (~(labels.data.unsqueeze(1)
                   == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
            c = c[I].view(data.size(0), num_class - 1, num_class)
            output = model(x, final_node_name=final_node_name)
            # Regular CrossEntropyLoss, used for warming up.
            regular_ce = CrossEntropyLoss()(output, labels)
            meter.update('CE', regular_ce.item(), x.size(0))
            meter.update('Err',
                         torch.sum(torch.argmax(output, dim=1) != labels).item() / x.size(0),
                         x.size(0))
            x = (x,)

        if batch_method == 'robust':
            lb, robust_ce = get_bound_loss(x=x, c=c)
            loss = robust_ce
        elif batch_method == 'natural':
            loss = regular_ce
        if train:
            loss.backward()
            if args.clip_grad_norm:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                           max_norm=args.clip_grad_norm)
                meter.update('grad_norm', grad_norm)
            if isinstance(eps_scheduler, AdaptiveScheduler):
                eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))
        if batch_method != 'natural':
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            if not loss_fusion:
                # An example is verifiably correct if the lower bounds of all
                # its margins are > 0; if any margin bound is < 0 it counts as
                # an error.
                meter.update('Verified_Err',
                             torch.sum((lb < 0).any(dim=1)).item() / data.size(0),
                             data.size(0))
        meter.update('Time', time.time() - start)
        if (i + 1) % 250 == 0 and train:
            logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    return meter
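
# Sketch verifying the zero-padding trick shared by all three routines above:
# CrossEntropyLoss on -lb_padded with fake label 0 equals
# log(1 + sum_j exp(-lb_j)), the robust cross-entropy computed from the margin
# lower bounds. The lb values below are made up.
def demo_robust_ce_padding():
    lb = torch.tensor([[2.0, -0.5, 1.0]])  # hypothetical margin lower bounds
    lb_padded = torch.cat((torch.zeros(lb.size(0), 1), lb), dim=1)
    fake_labels = torch.zeros(lb.size(0), dtype=torch.int64)
    robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
    # Equivalent closed form: the padded zero contributes the constant 1 inside
    # the log-sum-exp, and the fake label 0 selects it as the "correct" class.
    manual = torch.log(1 + torch.exp(-lb).sum(dim=1)).mean()
    assert torch.allclose(robust_ce, manual)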