import argparse
import os
import pickle

import torch
import torch.optim as optim
from tqdm import tqdm

# Repository-local helpers referenced below (get_args, get_loaders,
# get_network, get_inputs, get_adv_loss, compute_bounds, reset_params,
# refine, learn_bounds, verify_test, report, test, the ReLU layer class,
# and the module-level `device` / `dtype`) are assumed to be imported
# from this project's own modules.


def learn_slopes(device, relu_params, bounds, args, n_layers, net, inputs,
                 targets, abs_inputs, i, j):
    """Optimize the ReLU relaxation slopes via projected gradient descent.

    If i and j are None, the full specification is verified; otherwise only
    the difference between logits i and j is bounded. Returns a per-layer
    ranking of ReLU neurons by slope-gradient magnitude, a flag indicating
    whether verification already succeeded, and the raw priorities.
    """
    # Start from slope 1 for every ReLU and optimize with Adam.
    for param in relu_params:
        param.data = torch.ones(param.size()).to(param.device)
    relu_opt = optim.Adam(relu_params, lr=0.03, weight_decay=0)
    ret_verified = False

    for _ in range(args.num_iters):
        relu_opt.zero_grad()
        if i is None and j is None:
            abs_loss, abs_ok = get_adv_loss(
                device, args.test_eps, args.layer_idx, net, bounds,
                inputs, targets, args, detach=False)
        else:
            abs_out = net(abs_inputs)
            abs_loss = -abs_out.get_min_diff(i, j)
        abs_loss.backward()
        relu_opt.step()
        # Project the slopes back into their valid range [0, 1].
        for param in relu_params:
            if param.grad is not None:
                param.data = torch.clamp(param.data, 0, 1)
        with torch.no_grad():
            abs_out = net(abs_inputs)
            if i is None and j is None:
                _, verified_corr = abs_out.verify(targets)
                if verified_corr:
                    ret_verified = True
            else:
                abs_loss = -abs_out.get_min_diff(i, j)
                if abs_loss < 0:
                    ret_verified = True
        if ret_verified:
            # Stop as soon as the property is verified.
            break

    # Rank the ReLU neurons of each layer by the magnitude of the gradient
    # with respect to their slope; larger gradients get refined first.
    relu_rnk, relu_priority = {}, {}
    for lidx, layer in enumerate(net.blocks):
        if not isinstance(layer, ReLU):
            continue
        for param in layer.parameters():
            relu_priority[lidx] = []
            if param.grad is None:
                for _ in range(param.size()[0]):
                    relu_priority[lidx].append(0)
                _, sorted_ids = torch.sort(param.abs().view(-1), descending=True)
            else:
                g_abs = param.grad.abs().view(-1)
                for neuron_idx in range(g_abs.size()[0]):
                    relu_priority[lidx].append(int(g_abs[neuron_idx].item() * 1000))
                _, sorted_ids = torch.sort(g_abs, descending=True)
            relu_rnk[lidx] = sorted_ids.cpu().numpy()
    net.zero_grad()
    return relu_rnk, ret_verified, relu_priority
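
# The projected-gradient pattern above (take an Adam step, then clamp the
# slopes back into [0, 1]) is easy to miss inside the verification loop.
# Below is a minimal stand-alone sketch of the same pattern on a dummy
# parameter; `_slope_learning_sketch` and its toy objective are illustrative
# assumptions, not part of this codebase.


def _slope_learning_sketch():
    slopes = torch.ones(8, requires_grad=True)
    opt = optim.Adam([slopes], lr=0.03)
    for _ in range(50):
        opt.zero_grad()
        toy_loss = ((slopes - 0.3) ** 2).sum()  # placeholder objective
        toy_loss.backward()
        opt.step()
        with torch.no_grad():
            slopes.clamp_(0, 1)  # projection step back into the valid range
    return slopes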
def main():
    args = get_args()
    ver_logdir = args.load_model[:-3] + '_ver'
    if not os.path.exists(ver_logdir):
        os.makedirs(ver_logdir)
    num_train, _, test_loader, input_size, input_channel, n_class = get_loaders(args)
    net = get_network(device, args, input_size, input_channel, n_class)
    print(net)

    args.test_domains = []
    # with torch.no_grad():
    #     test(device, 0, args, net, test_loader, layers=[-1, args.layer_idx])

    # Re-create the test loader with batch size 1 for per-example verification.
    args.test_batch = 1
    num_train, _, test_loader, input_size, input_channel, n_class = get_loaders(args)
    latent_idx = args.layer_idx if args.latent_idx is None else args.latent_idx
    img_file = open(args.unverified_imgs_file, 'w')

    with torch.no_grad():
        tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, tot_tests = 0, 0, 0, 0, 0
        for test_idx, (inputs, targets) in enumerate(test_loader):
            if test_idx < args.start_idx or test_idx >= args.end_idx:
                continue
            tot_tests += 1
            test_file = os.path.join(ver_logdir, '{}.p'.format(test_idx))
            test_data = pickle.load(open(test_file, 'rb')) if (not args.no_load) and os.path.isfile(test_file) else {}
            print('Verify test_idx =', test_idx)
            net.reset_bounds()
            inputs, targets = inputs.to(device), targets.to(device)
            abs_inputs = get_inputs(args.test_domain, inputs, args.test_eps, device, dtype=dtype)

            # Step 1: natural accuracy -- skip examples the network already
            # misclassifies.
            nat_out = net(inputs)
            nat_ok = targets.eq(nat_out.max(dim=1)[1]).item()
            tot_nat_ok += float(nat_ok)
            test_data['ok'] = nat_ok
            if not nat_ok:
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Step 2: input-space PGD attack; a successful attack means the
            # example cannot be verified.
            for _ in range(args.attack_restarts):
                with torch.enable_grad():
                    pgd_loss, pgd_ok = get_adv_loss(
                        device, args.test_eps, -1, net, None, inputs, targets,
                        args.test_att_n_steps, args.test_att_step_size)
                if not pgd_ok:
                    break
            if pgd_ok:
                test_data['pgd_ok'] = 1
                tot_pgd_ok += 1
            else:
                test_data['pgd_ok'] = 0
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Skip examples that were already verified in a previous run.
            if 'verified' in test_data and test_data['verified']:
                tot_verified_corr += 1
                tot_attack_ok += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue
            if args.no_milp:
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Step 3: optimize the ReLU relaxation slopes on the iterative
            # zonotope domain, then recompute the bounds.
            zono_inputs = get_inputs('zono_iter', inputs, args.test_eps, device, dtype=dtype)
            bounds = compute_bounds(net, device, len(net.blocks) - 1, args, zono_inputs)
            relu_params = reset_params(args, net, dtype)
            with torch.enable_grad():
                learn_slopes(device, relu_params, bounds, args, len(net.blocks),
                             net, inputs, targets, abs_inputs, None, None)
            bounds = compute_bounds(net, device, len(net.blocks) - 1, args, zono_inputs)

            # Step 4: latent-space attack against the refined bounds.
            for _ in range(args.attack_restarts):
                with torch.enable_grad():
                    latent_loss, latent_ok = get_adv_loss(
                        device, args.test_eps, latent_idx, net, bounds, inputs,
                        targets, args.test_att_n_steps, args.test_att_step_size)
                    # print('-> ', latent_idx, latent_loss, latent_ok)
                if not latent_ok:
                    break
            if latent_ok:
                tot_attack_ok += 1

            # Step 5: try to verify directly on the zonotope output.
            zono_out = net(zono_inputs)
            verified, verified_corr = zono_out.verify(targets)
            test_data['verified'] = int(verified_corr.item())
            if verified_corr:
                tot_verified_corr += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Step 6: optionally refine the intermediate bounds spatially,
            # position by position, at layer refine_lidx.
            loss_after = net(abs_inputs).ce_loss(targets)
            if args.refine_lidx is not None:
                bounds = compute_bounds(net, device, len(net.blocks) - 1, args, abs_inputs)
                for lidx in range(0, args.layer_idx + 2):
                    net.blocks[lidx].bounds = bounds[lidx]
                print('loss before refine: ', net(abs_inputs).ce_loss(targets))
                refine_dim = bounds[args.refine_lidx + 1][0].shape[2]
                pbar = tqdm(total=refine_dim * refine_dim, dynamic_ncols=True)
                for refine_i in range(refine_dim):
                    for refine_j in range(refine_dim):
                        refine(args, bounds, net, refine_i, refine_j,
                               abs_inputs, input_size)
                        pbar.update(1)
                pbar.close()
                loss_after = net(abs_inputs).ce_loss(targets)
                print('loss after refine: ', loss_after)

            # Step 7: fall back to the complete verifier, optionally after
            # another round of bound optimization.
            if loss_after < args.loss_threshold:
                if args.refine_opt is not None:
                    with torch.enable_grad():
                        learn_bounds(net, bounds, relu_params, zono_inputs,
                                     args.refine_opt)
                if verify_test(args, net, inputs, targets, abs_inputs, bounds,
                               test_data, test_idx):
                    tot_verified_corr += 1
                    test_data['verified'] = True
            report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                   tot_pgd_ok, test_idx, tot_tests, test_data)
    img_file.close()
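
# `zono_out.verify(targets)` above succeeds when, for every wrong class j,
# the lower bound of (logit[target] - logit[j]) over the abstract output is
# positive -- the same quantity that `get_min_diff(i, j)` exposes for a
# single pair. A minimal interval-arithmetic illustration of that check,
# assuming per-class lower/upper logit bounds `lb` / `ub` (made-up inputs,
# not produced by this code):


def _interval_verified(lb, ub, target):
    # lb, ub: 1-D tensors of per-class lower/upper logit bounds.
    for j in range(lb.numel()):
        if j != target and lb[target] - ub[j] <= 0:
            return False  # class j cannot be ruled out
    return True


# e.g. _interval_verified(torch.tensor([2., 5., 1.]),
#                         torch.tensor([3., 6., 4.]), target=1) -> True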
# Alternative driver: box-domain testing with a Gurobi MILP fallback.
# Note that this second main() shadows the one defined above.
def main():
    parser = argparse.ArgumentParser(description='Verify a trained network.')
    parser.add_argument('--prune_p', default=None, type=float, help='percentage of weights to prune in each layer')
    parser.add_argument('--dataset', default='cifar10', help='dataset to use')
    parser.add_argument('--net', required=True, type=str, help='network to use')
    parser.add_argument('--load_model', type=str, help='model to load')
    parser.add_argument('--layer_idx', default=1, type=int, help='layer index of flattened vector')
    parser.add_argument('--n_valid', default=1000, type=int, help='number of test samples')
    parser.add_argument('--n_train', default=None, type=int, help='number of training samples to use')
    parser.add_argument('--train_batch', default=1, type=int, help='batch size for training')
    parser.add_argument('--test_batch', default=128, type=int, help='batch size for testing')
    parser.add_argument('--test_domain', default='zono', type=str, help='domain to test with')
    parser.add_argument('--test_eps', default=None, type=float, help='epsilon to verify')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--no_milp', action='store_true', help='no MILP mode')
    parser.add_argument('--no_load', action='store_true', help='verify from scratch')
    parser.add_argument('--no_smart', action='store_true', help='disable smart mode')
    parser.add_argument('--milp_timeout', default=10, type=int, help='timeout for MILP')
    parser.add_argument('--eval_train', action='store_true', help='evaluate on training set')
    parser.add_argument('--test_idx', default=None, type=int, help='specific index to test')
    parser.add_argument('--start_idx', default=0, type=int, help='specific index to start')
    parser.add_argument('--end_idx', default=1000, type=int, help='specific index to end')
    parser.add_argument('--max_binary', default=None, type=int, help='number of neurons to encode as binary variables in MILP (per layer)')
    parser.add_argument('--num_iters', default=50, type=int, help='number of iterations to find slopes')
    parser.add_argument('--max_refine_triples', default=0, type=int, help='number of triples to refine')
    parser.add_argument('--refine_lidx', default=None, type=int, help='layer to refine')
    parser.add_argument('--save_models', action='store_true', help='whether to only store models')
    parser.add_argument('--refine_milp', default=0, type=int, help='number of neurons to refine using MILP')
    parser.add_argument('--obj_threshold', default=None, type=float, help='threshold to consider for MILP verification')
    parser.add_argument('--attack_type', default='pgd', type=str, help='attack')
    parser.add_argument('--attack_n_steps', default=10, type=int, help='number of steps for the attack')
    parser.add_argument('--attack_step_size', default=0.25, type=float, help='step size for the attack (relative to epsilon)')
    parser.add_argument('--layers', required=False, default=None, type=int, nargs='+', help='layer indices for training')
    args = parser.parse_args()

    ver_logdir = args.load_model[:-3] + '_ver'
    if not os.path.exists(ver_logdir):
        os.makedirs(ver_logdir)
    grb_modelsdir = args.load_model[:-3] + '_grb'
    if not os.path.exists(grb_modelsdir):
        os.makedirs(grb_modelsdir)

    num_train, _, test_loader, input_size, input_channel = get_loaders(args)
    net = get_network(device, args, input_size, input_channel)
    n_layers = len(net.blocks)
    # net.to_double()

    args.test_domains = ['box']
    with torch.no_grad():
        test(device, 0, args, net, test_loader)

    # Re-create the test loader with batch size 1 for per-example verification.
    args.test_batch = 1
    num_train, _, test_loader, input_size, input_channel = get_loaders(args)

    # Count the ReLU layers after the split layer.
    num_relu = 0
    for lidx in range(args.layer_idx + 1, n_layers):
        print(net.blocks[lidx])
        if isinstance(net.blocks[lidx], ReLU):
            num_relu += 1

    with torch.no_grad():
        tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, tot_tests = 0, 0, 0, 0, 0
        for test_idx, (inputs, targets) in enumerate(test_loader):
            if test_idx < args.start_idx or test_idx >= args.end_idx or test_idx >= args.n_valid:
                continue
            if args.test_idx is not None and test_idx != args.test_idx:
                continue
            tot_tests += 1
            test_file = os.path.join(ver_logdir, '{}.p'.format(test_idx))
            test_data = pickle.load(open(test_file, 'rb')) if (not args.no_load) and os.path.isfile(test_file) else {}
            print('Verify test_idx =', test_idx)

            for lidx in range(n_layers):
                net.blocks[lidx].bounds = None
            inputs, targets = inputs.to(device), targets.to(device)
            abs_inputs = get_inputs(args.test_domain, inputs, args.test_eps, device, dtype=dtype)

            # Natural accuracy check.
            nat_out = net(inputs)
            nat_ok = targets.eq(nat_out.max(dim=1)[1]).item()
            tot_nat_ok += float(nat_ok)
            test_data['ok'] = nat_ok
            if not nat_ok:
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # PGD attack in the input space.
            with torch.enable_grad():
                pgd_loss, pgd_ok = get_adv_loss(device, args.test_eps, -1, net,
                                                None, inputs, targets, args)
            if pgd_ok:
                test_data['pgd_ok'] = 1
                tot_pgd_ok += 1
            else:
                test_data['pgd_ok'] = 0
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Skip examples that were already verified in a previous run.
            if 'verified' in test_data and test_data['verified']:
                tot_verified_corr += 1
                tot_attack_ok += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Compute bounds up to the split layer, optionally optimizing the
            # relaxation slopes first, then run the latent-space attack.
            relu_params = reset_params(args, net, dtype)
            bounds = compute_bounds(net, device, args.layer_idx, args, abs_inputs)
            if args.test_domain == 'zono_iter':
                with torch.enable_grad():
                    learn_slopes(device, relu_params, bounds, args, n_layers,
                                 net, inputs, targets, abs_inputs, None, None)
            with torch.enable_grad():
                abs_loss, abs_ok = get_adv_loss(device, args.test_eps, args.layer_idx,
                                                net, bounds, inputs, targets, args)

            # Optional spatial refinement of the bounds at layer refine_lidx.
            refined_triples = []
            if args.refine_lidx is not None:
                bounds = compute_bounds(net, device, args.layer_idx + 1, args, abs_inputs)
                for lidx in range(0, args.layer_idx + 2):
                    net.blocks[lidx].bounds = bounds[lidx]
                print('loss before refine: ', abs_loss)
                refine_dim = bounds[args.refine_lidx + 1][0].shape[2]
                pbar = tqdm(total=refine_dim * refine_dim, dynamic_ncols=True)
                for refine_i in range(refine_dim):
                    for refine_j in range(refine_dim):
                        # refine(args, bounds, net, 0, 15, abs_inputs, input_size)
                        refine(args, bounds, net, refine_i, refine_j,
                               abs_inputs, input_size)
                        pbar.update(1)
                pbar.close()
                with torch.enable_grad():
                    abs_loss, abs_ok = get_adv_loss(device, args.test_eps, args.layer_idx,
                                                    net, bounds, inputs, targets, args)
                print('loss after refine: ', abs_loss)

            if abs_ok:
                tot_attack_ok += 1

            # Try to verify on the abstract output directly.
            abs_out = net(abs_inputs)
            verified, verified_corr = abs_out.verify(targets)
            test_data['verified'] = int(verified_corr.item())
            print('abs_loss: ', abs_loss.item(), '\tabs_ok: ', abs_ok.item(),
                  '\tverified_corr: ', verified_corr.item())
            if verified_corr:
                tot_verified_corr += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue
            if args.no_milp or (not abs_ok):
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                       tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Fall back to the complete MILP verifier.
            if verify_test(args, net, num_relu, inputs, targets, abs_inputs, bounds,
                           refined_triples, test_data, grb_modelsdir, test_idx):
                tot_verified_corr += 1
                test_data['verified'] = True
            report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                   tot_pgd_ok, test_idx, tot_tests, test_data)
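
# Entry point. An example invocation of the MILP driver above (the script
# name, model path, and network name are illustrative; the flags are the
# ones defined in the parser):
#
#   python verify.py --net convsmall --load_model models/convsmall.pt \
#       --test_domain zono_iter --test_eps 0.03 --start_idx 0 --end_idx 100
if __name__ == '__main__':
    main()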