コード例 #1
0
def learn_slopes(relu_params, bounds, args, n_layers, net, inputs, targets, abs_inputs, i, j):
    """Optimize the learnable ReLU parameters and rank them by gradient size.

    Every tensor in ``relu_params`` is reset to ones and then optimized with
    Adam (lr 0.03) for ``args.num_iters`` steps; after each step the values
    are clamped back into ``[0, 1]``.  Once the optimization is finished the
    network is evaluated one more time to decide whether the property holds,
    and the parameters of every ``ReLU`` block are ranked.

    Args:
        relu_params: Parameters to optimize; iterated several times and handed
            to ``optim.Adam``.
        bounds: Forwarded unchanged to ``get_adv_loss``.
        args: Namespace providing ``num_iters`` and, for the adversarial
            objective, ``test_eps`` and ``layer_idx``.
        n_layers: Unused; kept for interface compatibility.
        net: Callable network. ``net(abs_inputs)`` must return an object with
            ``get_min_diff(i, j)`` and ``verify(targets)``; ``net.blocks`` is
            iterated and ``net.zero_grad()`` is called before returning.
        inputs: Forwarded to ``get_adv_loss``.
        targets: Forwarded to ``get_adv_loss`` and ``verify``.
        abs_inputs: Abstract input fed to ``net``.
        i: First index for ``get_min_diff`` or ``None``.
        j: Second index for ``get_min_diff`` or ``None``.  When both ``i`` and
            ``j`` are ``None`` the loss comes from ``get_adv_loss``; otherwise
            the loss is ``-get_min_diff(i, j)``.

    Returns:
        A tuple ``(relu_rnk, ret_verified, relu_priority)``:
        ``relu_rnk`` maps a block index to a NumPy array of flat parameter
        indices sorted by descending score, ``ret_verified`` is a bool, and
        ``relu_priority`` maps a block index to a list of integer priorities.

    NOTE(review): ``device`` is not a parameter; the adversarial branch reads
    it from the enclosing module scope.
    """
    for param in relu_params:
        param.data = torch.ones(param.size()).to(param.device)
    relu_opt = optim.Adam(relu_params, lr=0.03, weight_decay=0)
    use_adv_loss = i is None and j is None

    for _ in range(args.num_iters):
        if use_adv_loss:
            abs_loss, _abs_ok = get_adv_loss(device, args.test_eps, args.layer_idx, net, bounds, inputs, targets, args, detach=False)
        else:
            abs_loss = -net(abs_inputs).get_min_diff(i, j)
        # Clear gradients right before backward so that anything accumulated
        # while computing the loss does not leak into the update.
        relu_opt.zero_grad()
        abs_loss.backward()
        relu_opt.step()
        for param in relu_params:
            if param.grad is not None:
                param.data = torch.clamp(param.data, 0, 1)
    # NOTE(review): the original loop had an `if ret_verified: break` that
    # could never trigger (the flag is only computed below); it was removed.

    with torch.no_grad():
        abs_out = net(abs_inputs)
        if use_adv_loss:
            _, verified_corr = abs_out.verify(targets)
            ret_verified = bool(verified_corr)
        else:
            ret_verified = bool(-abs_out.get_min_diff(i, j) < 0)

    relu_rnk, relu_priority = {}, {}
    for lidx, layer in enumerate(net.blocks):
        if not isinstance(layer, ReLU):
            continue
        for param in layer.parameters():
            if param.grad is None:
                # Without a gradient every entry gets priority 0 and the
                # ranking falls back to the parameter magnitudes.
                relu_priority[lidx] = [0] * param.size()[0]
                scores = param.abs().view(-1)
            else:
                scores = param.grad.abs().view(-1)
                # One bulk transfer instead of one `.item()` call per element.
                relu_priority[lidx] = [int(score * 1000) for score in scores.tolist()]
            _, sorted_ids = torch.sort(scores, descending=True)
            relu_rnk[lidx] = sorted_ids.cpu().numpy()
    net.zero_grad()
    return relu_rnk, ret_verified, relu_priority
コード例 #2
0
ファイル: verify.py プロジェクト: sungyoon-lee/colt
def main():
    """Verify test samples one at a time and report running statistics.

    For every test index in ``[args.start_idx, args.end_idx)`` the function
    runs the following stages, calling ``report`` and moving on to the next
    sample as soon as one of them settles the outcome:

    1. natural classification check;
    2. PGD attack on the input (``args.attack_restarts`` restarts);
    3. reuse of a cached ``'verified'`` result loaded from the per-sample
       pickle in the ``<model>_ver`` directory (unless ``args.no_load``);
    4. early exit when ``args.no_milp`` is set;
    5. bound computation on ``'zono_iter'`` inputs, slope learning and a
       latent-space attack at ``latent_idx``;
    6. verification of the abstract output via ``verify``;
    7. optional pairwise refinement (``args.refine_lidx``) and, when the loss
       is below ``args.loss_threshold``, optional ``learn_bounds`` followed by
       ``verify_test``.

    NOTE(review): ``device`` and ``dtype`` are read from module scope.
    """
    args = get_args()

    ver_logdir = args.load_model[:-3] + '_ver'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(ver_logdir, exist_ok=True)

    num_train, _, test_loader, input_size, input_channel, n_class = get_loaders(
        args)
    net = get_network(device, args, input_size, input_channel, n_class)
    print(net)

    args.test_domains = []
    # with torch.no_grad():
    #     test(device, 0, args, net, test_loader, layers=[-1, args.layer_idx])
    # Reload the data with batch size 1 so samples are verified individually.
    args.test_batch = 1
    num_train, _, test_loader, input_size, input_channel, n_class = get_loaders(
        args)
    latent_idx = args.layer_idx if args.latent_idx is None else args.latent_idx

    # The unverified-images file is created/truncated here but nothing in this
    # function writes to it; the context manager guarantees it is closed even
    # when verification raises.
    with open(args.unverified_imgs_file, 'w'), torch.no_grad():
        tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, tot_tests = 0, 0, 0, 0, 0
        for test_idx, (inputs, targets) in enumerate(test_loader):
            if test_idx < args.start_idx or test_idx >= args.end_idx:
                continue
            tot_tests += 1
            test_file = os.path.join(ver_logdir, '{}.p'.format(test_idx))
            # NOTE(review): pickle executes arbitrary code on load; only use
            # result files produced by trusted runs.
            if (not args.no_load) and os.path.isfile(test_file):
                with open(test_file, 'rb') as cached_file:
                    test_data = pickle.load(cached_file)
            else:
                test_data = {}
            print('Verify test_idx =', test_idx)

            net.reset_bounds()

            inputs, targets = inputs.to(device), targets.to(device)
            abs_inputs = get_inputs(args.test_domain,
                                    inputs,
                                    args.test_eps,
                                    device,
                                    dtype=dtype)
            nat_out = net(inputs)
            nat_ok = targets.eq(nat_out.max(dim=1)[1]).item()
            tot_nat_ok += float(nat_ok)
            test_data['ok'] = nat_ok
            if not nat_ok:
                # Misclassified without any perturbation: nothing to verify.
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue

            # NOTE(review): with args.attack_restarts == 0 the loop body never
            # runs and `pgd_ok` is unbound below.
            for _ in range(args.attack_restarts):
                with torch.enable_grad():
                    pgd_loss, pgd_ok = get_adv_loss(device, args.test_eps, -1,
                                                    net, None, inputs, targets,
                                                    args.test_att_n_steps,
                                                    args.test_att_step_size)
                    if not pgd_ok:
                        break

            if pgd_ok:
                test_data['pgd_ok'] = 1
                tot_pgd_ok += 1
            else:
                test_data['pgd_ok'] = 0
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue

            # A cached positive result counts as verified and attack-robust.
            if 'verified' in test_data and test_data['verified']:
                tot_verified_corr += 1
                tot_attack_ok += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue
            if args.no_milp:
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue

            zono_inputs = get_inputs('zono_iter',
                                     inputs,
                                     args.test_eps,
                                     device,
                                     dtype=dtype)
            bounds = compute_bounds(net, device,
                                    len(net.blocks) - 1, args, zono_inputs)
            relu_params = reset_params(args, net, dtype)
            with torch.enable_grad():
                learn_slopes(device, relu_params, bounds, args,
                             len(net.blocks), net, inputs, targets, abs_inputs,
                             None, None)
            # Recompute the bounds after the slopes have been learned.
            bounds = compute_bounds(net, device,
                                    len(net.blocks) - 1, args, zono_inputs)

            # NOTE(review): same unbound-variable caveat as above applies to
            # `latent_ok` when args.attack_restarts == 0.
            for _ in range(args.attack_restarts):
                with torch.enable_grad():
                    latent_loss, latent_ok = get_adv_loss(
                        device, args.test_eps, latent_idx, net, bounds, inputs,
                        targets, args.test_att_n_steps,
                        args.test_att_step_size)
                    # print('-> ', latent_idx, latent_loss, latent_ok)
                    if not latent_ok:
                        break

            if latent_ok:
                tot_attack_ok += 1
            zono_out = net(zono_inputs)
            verified, verified_corr = zono_out.verify(targets)
            test_data['verified'] = int(verified_corr.item())
            if verified_corr:
                tot_verified_corr += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue

            loss_after = net(abs_inputs).ce_loss(targets)
            if args.refine_lidx is not None:
                bounds = compute_bounds(net, device,
                                        len(net.blocks) - 1, args, abs_inputs)
                for lidx in range(0, args.layer_idx + 2):
                    net.blocks[lidx].bounds = bounds[lidx]

                print('loss before refine: ', net(abs_inputs).ce_loss(targets))
                refine_dim = bounds[args.refine_lidx + 1][0].shape[2]
                pbar = tqdm(total=refine_dim * refine_dim, dynamic_ncols=True)
                # Refine every (i, j) index pair of the chosen layer.
                for refine_i in range(refine_dim):
                    for refine_j in range(refine_dim):
                        refine(args, bounds, net, refine_i, refine_j,
                               abs_inputs, input_size)
                        pbar.update(1)
                pbar.close()
                loss_after = net(abs_inputs).ce_loss(targets)
                print('loss after refine: ', loss_after)

            if loss_after < args.loss_threshold:
                if args.refine_opt is not None:
                    with torch.enable_grad():
                        learn_bounds(net, bounds, relu_params, zono_inputs,
                                     args.refine_opt)
                if verify_test(args, net, inputs, targets, abs_inputs, bounds,
                               test_data, test_idx):
                    tot_verified_corr += 1
                    test_data['verified'] = True
            report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                   tot_pgd_ok, test_idx, tot_tests, test_data)
コード例 #3
0
def _parse_args():
    """Build the command-line parser for the verifier and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description='Perform greedy layerwise training.')
    parser.add_argument('--prune_p', default=None, type=float, help='percentage of weights to prune in each layer')
    parser.add_argument('--dataset', default='cifar10', help='dataset to use')
    parser.add_argument('--net', required=True, type=str, help='network to use')
    # Required: main() slices this path unconditionally to derive its output
    # directories, so omitting it previously crashed with a TypeError.
    parser.add_argument('--load_model', required=True, type=str, help='model to load')
    parser.add_argument('--layer_idx', default=1, type=int, help='layer index of flattened vector')
    parser.add_argument('--n_valid', default=1000, type=int, help='number of test samples')
    parser.add_argument('--n_train', default=None, type=int, help='number of training samples to use')
    parser.add_argument('--train_batch', default=1, type=int, help='batch size for training')
    parser.add_argument('--test_batch', default=128, type=int, help='batch size for testing')
    parser.add_argument('--test_domain', default='zono', type=str, help='domain to test with')
    parser.add_argument('--test_eps', default=None, type=float, help='epsilon to verify')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--no_milp', action='store_true', help='no MILP mode')
    parser.add_argument('--no_load', action='store_true', help='verify from scratch')
    parser.add_argument('--no_smart', action='store_true', help='bla')
    parser.add_argument('--milp_timeout', default=10, type=int, help='timeout for MILP')
    parser.add_argument('--eval_train', action='store_true', help='evaluate on training set')
    parser.add_argument('--test_idx', default=None, type=int, help='specific index to test')
    parser.add_argument('--start_idx', default=0, type=int, help='specific index to start')
    parser.add_argument('--end_idx', default=1000, type=int, help='specific index to end')
    parser.add_argument('--max_binary', default=None, type=int, help='number of neurons to encode as binary variable in MILP (per layer)')
    parser.add_argument('--num_iters', default=50, type=int, help='number of iterations to find slopes')
    parser.add_argument('--max_refine_triples', default=0, type=int, help='number of triples to refine')
    parser.add_argument('--refine_lidx', default=None, type=int, help='layer to refine')
    parser.add_argument('--save_models', action='store_true', help='whether to only store models')
    parser.add_argument('--refine_milp', default=0, type=int, help='number of neurons to refine using MILP')
    parser.add_argument('--obj_threshold', default=None, type=float, help='threshold to consider for MILP verification')
    parser.add_argument('--attack_type', default='pgd', type=str, help='attack')
    parser.add_argument('--attack_n_steps', default=10, type=int, help='number of steps for the attack')
    parser.add_argument('--attack_step_size', default=0.25, type=float, help='step size for the attack (relative to epsilon)')
    parser.add_argument('--layers', required=False, default=None, type=int, nargs='+', help='layer indices for training')
    return parser.parse_args()


def main():
    """Parse the CLI, evaluate the network and verify test samples one by one.

    After a batch evaluation with the ``'box'`` test domain, the data is
    reloaded with batch size 1.  Each selected sample then goes through:
    natural classification check, PGD attack, reuse of a cached ``'verified'``
    result from the ``<model>_ver`` directory, bound computation (plus slope
    learning for the ``'zono_iter'`` domain), an attack at ``args.layer_idx``,
    optional pairwise refinement, verification of the abstract output and,
    unless ``args.no_milp`` is set or the attack succeeded, ``verify_test``.
    ``report`` is called once per processed sample with the running totals.

    NOTE(review): ``device`` and ``dtype`` are read from module scope.
    """
    args = _parse_args()

    ver_logdir = args.load_model[:-3] + '_ver'
    grb_modelsdir = args.load_model[:-3] + '_grb'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(ver_logdir, exist_ok=True)
    os.makedirs(grb_modelsdir, exist_ok=True)

    num_train, _, test_loader, input_size, input_channel = get_loaders(args)
    net = get_network(device, args, input_size, input_channel)
    n_layers = len(net.blocks)

    # net.to_double()

    args.test_domains = ['box']
    with torch.no_grad():
        test(device, 0, args, net, test_loader)

    # Reload the data with batch size 1 so samples are verified individually.
    args.test_batch = 1
    num_train, _, test_loader, input_size, input_channel = get_loaders(args)

    # Count the ReLU blocks that come after the flattened layer.
    num_relu = 0
    for lidx in range(args.layer_idx+1, n_layers):
        print(net.blocks[lidx])
        if isinstance(net.blocks[lidx], ReLU):
            num_relu += 1

    with torch.no_grad():
        tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, tot_tests = 0, 0, 0, 0, 0
        for test_idx, (inputs, targets) in enumerate(test_loader):
            if test_idx < args.start_idx or test_idx >= args.end_idx or test_idx >= args.n_valid:
                continue
            if args.test_idx is not None and test_idx != args.test_idx:
                continue
            tot_tests += 1
            test_file = os.path.join(ver_logdir, '{}.p'.format(test_idx))
            # NOTE(review): pickle executes arbitrary code on load; only use
            # result files produced by trusted runs.
            if (not args.no_load) and os.path.isfile(test_file):
                with open(test_file, 'rb') as cached_file:
                    test_data = pickle.load(cached_file)
            else:
                test_data = {}
            print('Verify test_idx =', test_idx)

            # Drop bounds left over from the previous sample.
            for lidx in range(n_layers):
                net.blocks[lidx].bounds = None

            inputs, targets = inputs.to(device), targets.to(device)
            abs_inputs = get_inputs(args.test_domain, inputs, args.test_eps, device, dtype=dtype)
            nat_out = net(inputs)
            nat_ok = targets.eq(nat_out.max(dim=1)[1]).item()
            tot_nat_ok += float(nat_ok)
            test_data['ok'] = nat_ok
            if not nat_ok:
                # Misclassified without any perturbation: nothing to verify.
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            with torch.enable_grad():
                pgd_loss, pgd_ok = get_adv_loss(device, args.test_eps, -1, net, None, inputs, targets, args)
            if pgd_ok:
                test_data['pgd_ok'] = 1
                tot_pgd_ok += 1
            else:
                test_data['pgd_ok'] = 0
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue
            # A cached positive result counts as verified and attack-robust.
            if 'verified' in test_data and test_data['verified']:
                tot_verified_corr += 1
                tot_attack_ok += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            relu_params = reset_params(args, net, dtype)

            bounds = compute_bounds(net, device, args.layer_idx, args, abs_inputs)
            if args.test_domain == 'zono_iter':
                with torch.enable_grad():
                    learn_slopes(relu_params, bounds, args, n_layers, net, inputs, targets, abs_inputs, None, None)

            with torch.enable_grad():
                abs_loss, abs_ok = get_adv_loss(device, args.test_eps, args.layer_idx, net, bounds, inputs, targets, args)

            refined_triples = []
            if args.refine_lidx is not None:
                bounds = compute_bounds(net, device, args.layer_idx+1, args, abs_inputs)
                for lidx in range(0, args.layer_idx+2):
                    net.blocks[lidx].bounds = bounds[lidx]
                print('loss before refine: ', abs_loss)
                refine_dim = bounds[args.refine_lidx+1][0].shape[2]
                pbar = tqdm(total=refine_dim*refine_dim, dynamic_ncols=True)
                # Refine every (i, j) index pair of the chosen layer.
                for refine_i in range(refine_dim):
                    for refine_j in range(refine_dim):
                        # refine(args, bounds, net, 0, 15, abs_inputs, input_size)
                        refine(args, bounds, net, refine_i, refine_j, abs_inputs, input_size)
                        pbar.update(1)
                pbar.close()
                with torch.enable_grad():
                    abs_loss, abs_ok = get_adv_loss(device, args.test_eps, args.layer_idx, net, bounds, inputs, targets, args)
                print('loss after refine: ', abs_loss)

            if abs_ok:
                tot_attack_ok += 1
            abs_out = net(abs_inputs)
            verified, verified_corr = abs_out.verify(targets)
            test_data['verified'] = int(verified_corr.item())
            print('abs_loss: ', abs_loss.item(), '\tabs_ok: ', abs_ok.item(), '\tverified_corr: ', verified_corr.item())
            if verified_corr:
                tot_verified_corr += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue
            # Skip the final verify_test stage when it is disabled or when the
            # attack at layer_idx already succeeded.
            if args.no_milp or (not abs_ok):
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            if verify_test(args, net, num_relu, inputs, targets, abs_inputs, bounds, refined_triples, test_data, grb_modelsdir, test_idx):
                tot_verified_corr += 1
                test_data['verified'] = True
            report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)