Example #1
0
def train(**kwargs):
    """Build a transformer captioning model and run scheduled training.

    Keyword arguments override fields of ``DefaultConfig``; when
    ``args.resume`` is set, training resumes from that checkpoint.
    """
    args = DefaultConfig()
    args.parse(kwargs)

    vocab = Vocab()
    loss_fn = transformer_celoss
    score_fn = rouge_func

    # Instantiate the model class selected by name in the config.
    model = getattr(Models, args.model_name)(vocab, args)

    train_loader = get_loaders('train', args.batch_size, 12)
    dev_loader = get_loaders('val', args.batch_size, 12)

    trainer = ScheduledTrainerTrans(args, model, loss_fn, score_fn,
                                    train_loader, dev_loader)
    # Forward the checkpoint path only when one was actually given.
    resume_kwargs = {} if args.resume is None else {'resume_from': args.resume}
    trainer.init_trainner(**resume_kwargs)
    trainer.train()
Example #2
0
def train_re(**kwargs):
    """Resume training of a transformer captioning model from a checkpoint.

    Unlike ``train``, the trainer is always initialised with
    ``resume_from=args.resume``.
    """
    args = DefaultConfig()
    args.parse(kwargs)

    vocab = Vocab()
    loss_fn = transformer_celoss
    score_fn = rouge_func
    model = getattr(Models, args.model_name)(vocab, args)

    trainer = ScheduledTrainerTrans(
        args,
        model,
        loss_fn,
        score_fn,
        get_loaders('train', args.batch_size, 12),
        get_loaders('val', args.batch_size, 12),
    )
    trainer.init_trainner(resume_from=args.resume)
    trainer.train()
Example #3
0
def run(args=None):
    """Train/test/certify a deepTrunk network according to ``args.train_mode``.

    :param args: parsed experiment configuration namespace; must not be
        ``None`` (the default only mirrors the original signature).
    :raises ValueError: if ``args`` is ``None``.
    """
    # Bug fix: the original dereferenced ``args`` unconditionally, so calling
    # run() with the default ``None`` crashed with AttributeError. Fail fast
    # with a clear message instead.
    if args is None:
        raise ValueError("run() requires a parsed args namespace, got None")

    device = 'cuda' if torch.cuda.is_available() and (not args.no_cuda) else 'cpu'
    num_train, train_loader, test_loader, input_size, input_channel, n_class = get_loaders(args)

    lossFn = nn.CrossEntropyLoss(reduction='none')
    def evalFn(x): return torch.max(x, dim=1)[1]

    ## initialize SpecNet
    dTNet = MyDeepTrunkNet.get_deepTrunk_net(args, device, lossFn, evalFn, input_size, input_channel, n_class)

    ## setup logging and checkpointing
    timestamp = int(time.time())
    model_signature = '%s/%s/%d/%s_%.5f/%d' % (args.dataset, args.exp_name, args.exp_id, args.net, args.train_eps, timestamp)
    # NOTE(review): assumes args.root_dir ends with a path separator — confirm.
    model_dir = args.root_dir + 'models_new/%s' % (model_signature)
    args.model_dir = model_dir

    print("Saving model to: %s" % model_dir)
    count_vars(args, dTNet)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Persist the run configuration next to the checkpoints.
    tb_writer = SummaryWriter(model_dir)
    stats = Statistics(len(train_loader), tb_writer, model_dir)
    args_file = os.path.join(model_dir, 'args.json')
    with open(args_file, 'w') as fou:
        json.dump(vars(args), fou, indent=4)
    write_config(args, os.path.join(model_dir, 'run_config.txt'))

    ## main part depending on training mode
    if 'train' in args.train_mode:
        epoch = train_deepTrunk(dTNet, args, device, stats, train_loader, test_loader)
        if args.cert:
            with torch.no_grad():
                cert_deepTrunk_net(dTNet, args, device, test_loader if args.test_set == "test" else train_loader,
                                   stats, log_ind=True, break_on_failure=False, epoch=epoch)
    elif args.train_mode == 'test':
        with torch.no_grad():
            test_deepTrunk_net(dTNet, args, device, test_loader if args.test_set == "test" else train_loader, stats,
                               log_ind=True)
    elif args.train_mode == "cert":
        with torch.no_grad():
            cert_deepTrunk_net(dTNet, args, device, test_loader if args.test_set == "test" else train_loader, stats,
                               log_ind=True, break_on_failure=False)
    else:
        assert False, 'Unknown mode: {}!'.format(args.train_mode)

    exit(0)
Example #4
0
def main():
    """Verify a trained network image-by-image on the test set.

    For each test index in ``[start_idx, end_idx)`` the pipeline is:
    natural accuracy check -> PGD attack -> (cached result / --no_milp skip)
    -> zonotope verification with learned ReLU slopes -> latent-space attack
    -> optional per-pixel bound refinement -> final verification attempt.
    Per-image results are pickled under ``<load_model>_ver`` via report().

    NOTE(review): relies on module-level ``device`` and ``dtype`` globals —
    confirm they are defined elsewhere in the full file.
    """
    args = get_args()

    # Directory for per-test-case verification results (model path minus '.pt').
    ver_logdir = args.load_model[:-3] + '_ver'
    if not os.path.exists(ver_logdir):
        os.makedirs(ver_logdir)

    num_train, _, test_loader, input_size, input_channel, n_class = get_loaders(
        args)
    net = get_network(device, args, input_size, input_channel, n_class)
    print(net)

    args.test_domains = []
    # with torch.no_grad():
    #     test(device, 0, args, net, test_loader, layers=[-1, args.layer_idx])
    # Re-create the loader with batch size 1: verification is per-image.
    args.test_batch = 1
    num_train, _, test_loader, input_size, input_channel, n_class = get_loaders(
        args)
    # Layer at which the latent-space attack runs (defaults to layer_idx).
    latent_idx = args.layer_idx if args.latent_idx is None else args.latent_idx
    img_file = open(args.unverified_imgs_file, 'w')

    with torch.no_grad():
        tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, tot_tests = 0, 0, 0, 0, 0
        for test_idx, (inputs, targets) in enumerate(test_loader):
            if test_idx < args.start_idx or test_idx >= args.end_idx:
                continue
            tot_tests += 1
            # Load cached per-image results unless --no_load is set.
            test_file = os.path.join(ver_logdir, '{}.p'.format(test_idx))
            test_data = pickle.load(open(test_file, 'rb')) if (
                not args.no_load) and os.path.isfile(test_file) else {}
            print('Verify test_idx =', test_idx)

            net.reset_bounds()

            inputs, targets = inputs.to(device), targets.to(device)
            abs_inputs = get_inputs(args.test_domain,
                                    inputs,
                                    args.test_eps,
                                    device,
                                    dtype=dtype)
            # Step 1: natural accuracy — a misclassified image can't be verified.
            nat_out = net(inputs)
            nat_ok = targets.eq(nat_out.max(dim=1)[1]).item()
            tot_nat_ok += float(nat_ok)
            test_data['ok'] = nat_ok
            if not nat_ok:
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue

            # Step 2: PGD attack in input space; stop early once it succeeds.
            for _ in range(args.attack_restarts):
                with torch.enable_grad():
                    pgd_loss, pgd_ok = get_adv_loss(device, args.test_eps, -1,
                                                    net, None, inputs, targets,
                                                    args.test_att_n_steps,
                                                    args.test_att_step_size)
                    if not pgd_ok:
                        break

            if pgd_ok:
                test_data['pgd_ok'] = 1
                tot_pgd_ok += 1
            else:
                # The attack found a counter-example: image is not robust.
                test_data['pgd_ok'] = 0
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue

            # Step 3: verified in a previous run — just recount and move on.
            if 'verified' in test_data and test_data['verified']:
                tot_verified_corr += 1
                tot_attack_ok += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue
            if args.no_milp:
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue

            # Step 4: zonotope bounds, then optimise ReLU relaxation slopes
            # and recompute the bounds with the learned slopes.
            zono_inputs = get_inputs('zono_iter',
                                     inputs,
                                     args.test_eps,
                                     device,
                                     dtype=dtype)
            bounds = compute_bounds(net, device,
                                    len(net.blocks) - 1, args, zono_inputs)
            relu_params = reset_params(args, net, dtype)
            with torch.enable_grad():
                learn_slopes(device, relu_params, bounds, args,
                             len(net.blocks), net, inputs, targets, abs_inputs,
                             None, None)
            bounds = compute_bounds(net, device,
                                    len(net.blocks) - 1, args, zono_inputs)

            # Step 5: latent-space attack at latent_idx using the bounds.
            for _ in range(args.attack_restarts):
                with torch.enable_grad():
                    latent_loss, latent_ok = get_adv_loss(
                        device, args.test_eps, latent_idx, net, bounds, inputs,
                        targets, args.test_att_n_steps,
                        args.test_att_step_size)
                    # print('-> ', latent_idx, latent_loss, latent_ok)
                    if not latent_ok:
                        break

            if latent_ok:
                tot_attack_ok += 1
            # Step 6: try to verify directly with the zonotope output.
            zono_out = net(zono_inputs)
            verified, verified_corr = zono_out.verify(targets)
            test_data['verified'] = int(verified_corr.item())
            if verified_corr:
                tot_verified_corr += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok,
                       tot_attack_ok, tot_pgd_ok, test_idx, tot_tests,
                       test_data)
                continue

            # Step 7: optional spatial refinement of bounds at refine_lidx.
            loss_after = net(abs_inputs).ce_loss(targets)
            if args.refine_lidx is not None:
                bounds = compute_bounds(net, device,
                                        len(net.blocks) - 1, args, abs_inputs)
                for lidx in range(0, args.layer_idx + 2):
                    net.blocks[lidx].bounds = bounds[lidx]

                print('loss before refine: ', net(abs_inputs).ce_loss(targets))
                refine_dim = bounds[args.refine_lidx + 1][0].shape[2]
                pbar = tqdm(total=refine_dim * refine_dim, dynamic_ncols=True)
                for refine_i in range(refine_dim):
                    for refine_j in range(refine_dim):
                        refine(args, bounds, net, refine_i, refine_j,
                               abs_inputs, input_size)
                        pbar.update(1)
                pbar.close()
                loss_after = net(abs_inputs).ce_loss(targets)
                print('loss after refine: ', loss_after)

            # Step 8: final verification attempt when the loss is promising.
            if loss_after < args.loss_threshold:
                if args.refine_opt is not None:
                    with torch.enable_grad():
                        learn_bounds(net, bounds, relu_params, zono_inputs,
                                     args.refine_opt)
                if verify_test(args, net, inputs, targets, abs_inputs, bounds,
                               test_data, test_idx):
                    tot_verified_corr += 1
                    test_data['verified'] = True
            report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok,
                   tot_pgd_ok, test_idx, tot_tests, test_data)
    img_file.close()
    def forward(self, position_feature):
        # NOTE(review): orphaned method fragment from a different class
        # (scrape artifact) — it is nested inside main() by indentation and
        # never called; presumably belongs to a positional-encoding module.
        # inputs [B, max_lenth]
        positions_encoded = self.position_encoding(position_feature)
        return positions_encoded


if __name__ == '__main__':
    import ipdb
    from loaders import get_loaders
    from configs_transformer import DefaultConfig
    from tqdm import tqdm
    from vocabulary import Vocab

    # Bug fix: instantiate the config. The original bound the class itself
    # (``args = DefaultConfig``), so ``args.batch_size = 2`` mutated the
    # class attribute globally; elsewhere the code uses DefaultConfig().
    args = DefaultConfig()
    args.batch_size = 2
    loader = get_loaders('val', args.batch_size, 2)
    vocab = Vocab()

    # Smoke test: one forward pass and greedy search per validation batch.
    for i in tqdm(loader):
        feature, caption, lenth = [j for j in i]
        batch_size, c, h, w = feature.size()
        _, n, l = caption.size()
        # Tile the image features so each of the n reference captions gets
        # its own copy: (B, C, H, W) -> (B*n, C, H, W).
        feature = feature.unsqueeze(1).expand(
            (batch_size, n, c, h, w)).contiguous().view(-1, c, h, w)
        caption = caption.long()
        caption = caption.view(-1, l)

        # NOTE(review): a fresh model is built every batch — presumably fine
        # for a smoke test, but it should be hoisted out for real use.
        model = VGGTransformerNew1(vocab, args)
        output_log_prob, output_token = model(feature, caption)
        token = model.greedy_search(feature)
        loss = output_log_prob.sum()
Example #6
0
def run(args):
    """Greedy layerwise robust training / ONNX printing / testing driver.

    :param args: parsed experiment configuration (dataset, net, layers,
        optimiser and schedule settings, ``train_mode`` in
        {'train', 'print', 'test'}).
    :return: ``(valid_nat_loss, valid_nat_acc, valid_robust_loss,
        valid_robust_acc)`` from the last periodic evaluation, or ``None``
        entries if no evaluation ever ran.
    """
    device = 'cuda' if torch.cuda.is_available() and (
        not args.no_cuda) else 'cpu'

    num_train, train_loader, test_loader, input_size, input_channel, n_class = get_loaders(
        args)
    net = get_network(device, args, input_size, input_channel, n_class)
    print(net)
    # Count trainable parameters; 'deepz_lambda' tensors are pinned to 1 and
    # frozen (they parameterise the relaxation, not the network).
    n_params = 0
    for param_name, param_value in net.named_parameters():
        if 'deepz_lambda' not in param_name:
            n_params += param_value.numel()
            param_value.requires_grad_(True)
        else:
            param_value.data = torch.ones(param_value.size()).to(device)
            param_value.requires_grad_(False)
    print('Number of parameters: ', n_params)

    # Bug fix: these were only assigned inside the 'train' branch, and only
    # on epochs hitting test_freq — the final return could raise
    # UnboundLocalError. Initialise them so the return is always defined.
    valid_nat_loss = valid_nat_acc = valid_robust_loss = valid_robust_acc = None

    n_epochs = args.n_epochs
    if args.train_mode == 'train':
        timestamp = int(time.time())
        model_dir = args.root_dir + 'models_new/%s/%s/%d/%s_%.5f/%d' % (
            args.dataset, args.exp_name, args.exp_id, args.net, args.train_eps,
            timestamp)
        print('Saving model to:', model_dir)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        args_file = os.path.join(model_dir, 'args.json')
        with open(args_file, 'w') as fou:
            json.dump(vars(args), fou, indent=4)
        writer = None

        epoch = 0
        relu_stable = args.relu_stable
        lr = args.lr
        # One training phase per consecutive pair of layer indices.
        for j in range(len(args.layers) - 1):
            if args.opt == 'adam':
                opt = optim.Adam(net.parameters(), lr=lr, weight_decay=0)
            else:
                opt = optim.SGD(net.parameters(),
                                lr=lr,
                                momentum=0.9,
                                weight_decay=0)

            if args.lr_sched == 'step_lr':
                lr_scheduler = optim.lr_scheduler.StepLR(
                    opt, step_size=args.lr_step, gamma=args.lr_factor)
            else:
                lr_scheduler = optim.lr_scheduler.OneCycleLR(
                    opt,
                    div_factor=10000,
                    max_lr=lr,
                    pct_start=args.pct_start,
                    steps_per_epoch=len(train_loader),
                    epochs=n_epochs)

            # Epsilon grows geometrically towards train_eps in later phases.
            eps = args.eps_factor**(len(args.layers) - 2 - j) * (
                args.start_eps_factor * args.train_eps)
            kappa_sched = Scheduler(0.0, 1.0, num_train * args.mix_epochs, 0)
            eps_sched = Scheduler(0 if args.anneal else eps, eps,
                                  num_train * args.mix_epochs, 0)
            prev_layer_idx, curr_layer_idx = args.layers[j], args.layers[j + 1]
            next_layer_idx = args.layers[j + 2] if j + 2 < len(
                args.layers) else None
            print(
                'new train phase: eps={}, lr={}, prev_layer={}, curr_layer={}, next_layer={}'
                .format(eps, lr, prev_layer_idx, curr_layer_idx,
                        next_layer_idx))
            layer_dir = '{}/{}'.format(model_dir, curr_layer_idx)
            if not os.path.exists(layer_dir):
                os.makedirs(layer_dir)
            for curr_epoch in range(n_epochs):
                train(device, writer, epoch, args, prev_layer_idx,
                      curr_layer_idx, next_layer_idx, net, eps_sched,
                      kappa_sched, opt, train_loader, lr_scheduler,
                      relu_stable)
                # StepLR is stepped per-epoch after the mixing phase;
                # OneCycleLR steps per-batch inside train().
                if curr_epoch >= args.mix_epochs and isinstance(
                        lr_scheduler, optim.lr_scheduler.StepLR):
                    lr_scheduler.step()
                if (epoch + 1) % args.test_freq == 0:
                    torch.save(
                        net.state_dict(),
                        os.path.join(layer_dir, 'net_%d.pt' % (epoch + 1)))
                    with torch.no_grad():
                        valid_nat_loss, valid_nat_acc, valid_robust_loss, valid_robust_acc = test(
                            device, epoch, args, net, test_loader,
                            [curr_layer_idx])
                epoch += 1
            # Relax constraints and shrink schedule for the next phase.
            relu_stable = None if relu_stable is None else relu_stable * args.relu_stable_factor
            n_epochs -= args.n_epochs_reduce
            lr = lr * args.lr_layer_dec
    elif args.train_mode == 'print':
        print('printing network to:', args.out_net_file)
        # NOTE(review): 'cuda' is hard-coded here rather than using `device`
        # — this branch will fail on CPU-only machines; confirm intent.
        dummy_input = torch.randn(1,
                                  input_channel,
                                  input_size,
                                  input_size,
                                  device='cuda')
        net.skip_norm = True
        torch.onnx.export(net, dummy_input, args.out_net_file, verbose=True)
    elif args.train_mode == 'test':
        with torch.no_grad():
            test(device, 0, args, net, test_loader, args.layers)
    else:
        assert False, 'Unknown mode: {}!'.format(args.train_mode)
    return valid_nat_loss, valid_nat_acc, valid_robust_loss, valid_robust_acc
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, 1)


def go_through_data(net, data_loader, device):
    """Stream every batch through ``net`` once (timing/warm-up helper).

    Outputs are discarded: the network is put in eval mode, gradients are
    disabled, and each batch is moved to ``device`` and forwarded.
    Always returns 1.
    """
    net.eval()
    with torch.no_grad():
        for batch_inputs, batch_targets in data_loader:
            _ = net.forward(batch_inputs.to(device))
            _ = batch_targets.to(device)
    return 1


# Benchmark: time 20 full passes of the training set through a pretrained LeNet.
device = "cuda"
train_loader, _ = get_loaders(batch_size_train, batch_size_test)
net = LeNet().to(device)
net.load_state_dict(torch.load("models/exp1.pth"))

start = time.time()
for run_idx in range(20):
    go_through_data(net, train_loader, device)

total_time = time.time() - start
# Print total wall time and the per-pass average.
print(total_time, total_time / (run_idx + 1))
# 46.204389810562134 9.240877962112426
Example #8
0
def train(ifold, data, **config):
    """Train the encoder/decoder/photo-z networks for one fold.

    :param ifold: fold identifier (tensor).
    :param data: tuple of PAUS tensors
        ``(flux, flux_err, fmes, vinv, isnan, zbin, ref_id)``.
    :param config: training configuration; uses ``use_mdn``, ``pretrain``,
        ``catnr``, ``Ntrain``, ``alpha``, ``Nexp`` and ``keep_last``.
    :return: the trained ``(enc, dec, net_pz)`` networks.
    """
    #: data
    flux, flux_err, fmes, vinv, isnan, zbin, ref_id = data

    #: Loading the pretrained networks.
    enc, dec, net_pz = utils.get_nets(config['use_mdn'], config['pretrain'])

    #: Indices of the selected sources with flux information.
    inds = inds_all[config['catnr']][:len(flux)]

    #: Train and test samples (sample_Kfold, mask, data, train set size).
    train_dl, test_dl, _ = loaders.get_loaders(ifold, inds, data, config['Ntrain'])

    #: Fixed positional arguments shared by every training stage.
    K = (enc, dec, net_pz, train_dl, test_dl, config['use_mdn'], config['alpha'], config['Nexp'], \
         config['keep_last'])

    # Bug fix: ``chain(...)`` is a single-pass iterator. The original passed
    # the same chain to four successive optim.Adam calls, so every optimizer
    # after the first saw an exhausted (empty) parameter list. Materialise
    # the joint parameter list once. (Comment was: "CAMBIO IMPORTANTE" —
    # Spanish for "important change".)
    params = list(chain(enc.parameters(), dec.parameters(), net_pz.parameters()))

    #: Weight decay used by every stage.
    wd = 1e-4

    # Warm-up stage at the highest learning rate.
    optimizer = optim.Adam(params, lr=1e-3, weight_decay=wd)
    trainer_alpha.train(optimizer, 100, *K)

    print('main train function...')
    # Main schedule: successively smaller learning rates.
    for lr, n_steps in ((1e-4, 200), (1e-5, 200), (1e-6, 200)):
        optimizer = optim.Adam(params, lr=lr, weight_decay=wd)
        trainer_alpha.train(optimizer, n_steps, *K)

    return enc, dec, net_pz
Example #9
0
def pz_fold(ifold, inds, data, **config):
    """Estimate the photo-z for one fold.

    The photo-z network receives both the encoded latent variables and the
    original input flux ratios.

    :param ifold: fold identifier (tensor).
    :param inds: indices of the selected sources with flux information
        (tensor).
    :param data: tuple of PAUS tensors
        ``(flux, flux_err, fmes, vinv, isnan, zbin, ref_id)``.
    :param config: uses ``out_fmt``, ``use_mdn`` and ``Ntrain``.
    :return: DataFrame with spec-z (``zs``), photo-z (``zb``), ``ref_id``
        and the fold id (``ifold``).
    :rtype: pandas.DataFrame
    """
    flux, flux_err, fmes, vinv, isnan, zbin, ref_id = data

    # Load the trained networks for this fold and switch them to eval mode.
    net_base_path = config['out_fmt'].format(ifold=ifold, net='{}')
    enc, dec, net_pz = utils.get_nets(str(net_base_path), config['use_mdn'])
    enc.eval()
    dec.eval()
    net_pz.eval()

    # Test loader and true redshift bins for this fold.
    _, test_dl, zbin_test = loaders.get_loaders(ifold, inds, data, config['Ntrain'])

    assert isinstance(inds, torch.Tensor)

    # Predict one redshift bin per source, batch by batch.
    batch_preds = []
    for Bflux, Bfmes, Bvinv, Bisnan, Bzbin in test_dl:
        Bcoadd, touse = trainer_sexp.get_coadd(Bflux, Bfmes, Bvinv, Bisnan, alpha=1)
        assert touse.all()

        # The photo-z net sees the coadd alongside its encoded features.
        latent = enc(Bcoadd)
        pred = net_pz(torch.cat([Bcoadd, latent], 1))

        # argmax bin index -> redshift (bin width 0.001).
        batch_preds.append(0.001 * pred.argmax(1).type(torch.float))

    zb_fold = torch.cat(batch_preds).detach().cpu().numpy()
    zs_fold = 0.001 * zbin_test.type(torch.float)

    # Reference ids of the sources belonging to this fold.
    refid_fold = ref_id[inds == ifold]
    print(zb_fold.shape, zs_fold.shape, refid_fold.shape)

    part = pd.DataFrame({'zs': zs_fold, 'zb': zb_fold, 'ref_id': refid_fold})
    part['ifold'] = ifold
    return part
Example #10
0
def run(args=None):
    """COLT / adversarial / natural / diffAI training and certification driver.

    Dispatches on ``args.train_mode`` ('*train*', 'print', 'test', 'cert').

    :param args: parsed experiment configuration namespace; must not be
        ``None`` (the default only mirrors the original signature).
    :raises ValueError: if ``args`` is ``None``.
    :return: ``(test_nat_loss, test_nat_acc, test_adv_loss, test_adv_acc)``
        from the last periodic evaluation (``None`` entries if none ran).
    """
    # Bug fix: the original dereferenced ``args`` unconditionally, so calling
    # run() with the default ``None`` crashed with AttributeError. Fail fast.
    if args is None:
        raise ValueError("run() requires a parsed args namespace, got None")

    device = 'cuda' if torch.cuda.is_available() and (
        not args.no_cuda) else 'cpu'
    num_train, train_loader, test_loader, input_size, input_channel, n_class = get_loaders(
        args)

    lossFn = nn.CrossEntropyLoss(reduction='none')
    evalFn = lambda x: torch.max(x, dim=1)[1]

    net = get_net(device,
                  args.dataset,
                  args.net,
                  input_size,
                  input_channel,
                  n_class,
                  load_model=args.load_model,
                  net_dim=args.cert_net_dim
                  )  #, feature_extract=args.core_feature_extract)

    # Model checkpoints and logs live under a timestamped signature dir.
    timestamp = int(time.time())
    model_signature = '%s/%s/%d/%s_%.5f/%d' % (args.dataset, args.exp_name,
                                               args.exp_id, args.net,
                                               args.train_eps, timestamp)
    # NOTE(review): assumes args.root_dir ends with a path separator — confirm.
    model_dir = args.root_dir + 'models_new/%s' % (model_signature)
    args.model_dir = model_dir
    count_vars(args, net)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # UpscaleNet has no relaxed twin; otherwise build one for latent attacks.
    if isinstance(net, UpscaleNet):
        relaxed_net = None
        relu_ids = None
    else:
        relaxed_net = RelaxedNetwork(net.blocks, args.n_rand_proj).to(device)
        relu_ids = relaxed_net.get_relu_ids()

    if "nat" in args.train_mode:
        cnet = CombinedNetwork(net,
                               relaxed_net,
                               lossFn=lossFn,
                               evalFn=evalFn,
                               device=device,
                               no_r_net=True).to(device)
    else:
        dummy_input = torch.rand((1, ) + net.dims[0],
                                 device=device,
                                 dtype=torch.float32)
        cnet = CombinedNetwork(net,
                               relaxed_net,
                               lossFn=lossFn,
                               evalFn=evalFn,
                               device=device,
                               dummy_input=dummy_input).to(device)

    n_epochs, test_nat_loss, test_nat_acc, test_adv_loss, test_adv_acc = args.n_epochs, None, None, None, None

    if 'train' in args.train_mode:
        tb_writer = SummaryWriter(model_dir)
        stats = Statistics(len(train_loader), tb_writer, model_dir)
        args_file = os.path.join(model_dir, 'args.json')
        with open(args_file, 'w') as fou:
            json.dump(vars(args), fou, indent=4)
        write_config(args, os.path.join(model_dir, 'run_config.txt'))

        eps = 0
        epoch = 0
        lr = args.lr
        n_epochs = args.n_epochs

        # Pick the layer schedule and ReLU-stability setting per train mode.
        if "COLT" in args.train_mode:
            relu_stable = args.relu_stable
            # if args.layers is None:
            #     args.layers = [-2, -1] + relu_ids
            layers = get_layers(args.train_mode,
                                cnet,
                                n_attack_layers=args.n_attack_layers,
                                protected_layers=args.protected_layers)
        elif "adv" in args.train_mode:
            relu_stable = None
            layers = [-1, -1]
            args.mix = False
        elif "natural" in args.train_mode:
            relu_stable = None
            layers = [-2, -2]
            args.nat_factor = 1
            args.mix = False
        elif "diffAI" in args.train_mode:
            relu_stable = None
            layers = [-2, -2]
        else:
            assert False, "Unknown train mode %s" % args.train_mode

        print('Saving model to:', model_dir)
        print('Training layers: ', layers)

        # One training phase per consecutive pair of layer indices.
        for j in range(len(layers) - 1):
            opt, lr_scheduler = get_opt(cnet.net,
                                        args.opt,
                                        lr,
                                        args.lr_step,
                                        args.lr_factor,
                                        args.n_epochs,
                                        train_loader,
                                        args.lr_sched,
                                        fixup="fixup" in args.net)

            curr_layer_idx = layers[j + 1]
            eps_old = eps
            eps = get_scaled_eps(args, layers, relu_ids, curr_layer_idx, j)

            # kappa mixes nat/robust loss; beta and eps anneal over mix_epochs.
            kappa_sched = Scheduler(0.0 if args.mix else 1.0, 1.0,
                                    num_train * args.mix_epochs, 0)
            beta_sched = Scheduler(
                args.beta_start if args.mix else args.beta_end, args.beta_end,
                args.train_batch * len(train_loader) * args.mix_epochs, 0)
            eps_sched = Scheduler(eps_old if args.anneal else eps, eps,
                                  num_train * args.anneal_epochs, 0)

            layer_dir = '{}/{}'.format(model_dir, curr_layer_idx)
            if not os.path.exists(layer_dir):
                os.makedirs(layer_dir)

            print('\nnew train phase: eps={:.5f}, lr={:.2e}, curr_layer={}\n'.
                  format(eps, lr, curr_layer_idx))

            for curr_epoch in range(n_epochs):
                train(device,
                      epoch,
                      args,
                      j + 1,
                      layers,
                      cnet,
                      eps_sched,
                      kappa_sched,
                      opt,
                      train_loader,
                      lr_scheduler,
                      relu_ids,
                      stats,
                      relu_stable,
                      relu_stable_protected=args.relu_stable_protected,
                      beta_sched=beta_sched)

                # StepLR is stepped per-epoch after the mixing phase; other
                # schedulers step inside train().
                if isinstance(lr_scheduler, optim.lr_scheduler.StepLR
                              ) and curr_epoch >= args.mix_epochs:
                    lr_scheduler.step()

                if (epoch + 1) % args.test_freq == 0:
                    with torch.no_grad():
                        test_nat_loss, test_nat_acc, test_adv_loss, test_adv_acc = test(
                            device,
                            args,
                            cnet,
                            test_loader if args.test_set == "test" else
                            train_loader, [curr_layer_idx],
                            stats=stats,
                            log_ind=(epoch + 1) % n_epochs == 0)

                if (epoch + 1) % args.test_freq == 0 or (epoch +
                                                         1) % n_epochs == 0:
                    torch.save(
                        net.state_dict(),
                        os.path.join(layer_dir, 'net_%d.pt' % (epoch + 1)))
                    torch.save(
                        opt.state_dict(),
                        os.path.join(layer_dir, 'opt_%d.pt' % (epoch + 1)))

                stats.update_tb(epoch)
                epoch += 1
            # Relax constraints and shrink the learning rate for next phase.
            relu_stable = None if relu_stable is None else relu_stable * args.relu_stable_layer_dec
            lr = lr * args.lr_layer_dec
        if args.cert:
            with torch.no_grad():
                diffAI_cert(
                    device,
                    args,
                    cnet,
                    test_loader if args.test_set == "test" else train_loader,
                    stats=stats,
                    log_ind=True,
                    epoch=epoch,
                    domains=args.cert_domain)
    elif args.train_mode == 'print':
        print('printing network to:', args.out_net_file)
        # NOTE(review): 'cuda' is hard-coded here rather than using `device`
        # — this branch will fail on CPU-only machines; confirm intent.
        dummy_input = torch.randn(1,
                                  input_channel,
                                  input_size,
                                  input_size,
                                  device='cuda')
        net.skip_norm = True
        torch.onnx.export(net, dummy_input, args.out_net_file, verbose=True)
    elif args.train_mode == 'test':
        with torch.no_grad():
            test(device,
                 args,
                 cnet,
                 test_loader if args.test_set == "test" else train_loader,
                 [-1],
                 log_ind=True)
    elif args.train_mode == "cert":
        tb_writer = SummaryWriter(model_dir)
        stats = Statistics(len(train_loader), tb_writer, model_dir)
        args_file = os.path.join(model_dir, 'args.json')
        with open(args_file, 'w') as fou:
            json.dump(vars(args), fou, indent=4)
        write_config(args, os.path.join(model_dir, 'run_config.txt'))
        print('Saving results to:', model_dir)
        with torch.no_grad():
            diffAI_cert(
                device,
                args,
                cnet,
                test_loader if args.test_set == "test" else train_loader,
                stats=stats,
                log_ind=True,
                domains=args.cert_domain)
        exit(0)
    else:
        assert False, 'Unknown mode: {}!'.format(args.train_mode)

    return test_nat_loss, test_nat_acc, test_adv_loss, test_adv_acc
Example #11
0
def main():
    """Per-sample robustness verification driver.

    Parses CLI arguments, loads a trained network, and then verifies the test
    set one sample at a time: natural accuracy -> PGD attack -> abstract-domain
    bound propagation (with optional slope learning and neuron refinement) ->
    MILP fallback.  Per-sample results are cached as pickles in ``<model>_ver``
    so interrupted runs can resume; Gurobi models are cached in ``<model>_grb``.

    NOTE(review): ``device`` and ``dtype`` are read here but defined at module
    level outside this function — confirm they are set before ``main()`` runs.
    """
    parser = argparse.ArgumentParser(description='Perform greedy layerwise training.')
    parser.add_argument('--prune_p', default=None, type=float, help='percentage of weights to prune in each layer')
    parser.add_argument('--dataset', default='cifar10', help='dataset to use')
    parser.add_argument('--net', required=True, type=str, help='network to use')
    parser.add_argument('--load_model', type=str, help='model to load')
    parser.add_argument('--layer_idx', default=1, type=int, help='layer index of flattened vector')
    parser.add_argument('--n_valid', default=1000, type=int, help='number of test samples')
    parser.add_argument('--n_train', default=None, type=int, help='number of training samples to use')
    parser.add_argument('--train_batch', default=1, type=int, help='batch size for training')
    parser.add_argument('--test_batch', default=128, type=int, help='batch size for testing')
    parser.add_argument('--test_domain', default='zono', type=str, help='domain to test with')
    parser.add_argument('--test_eps', default=None, type=float, help='epsilon to verify')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--no_milp', action='store_true', help='no MILP mode')
    parser.add_argument('--no_load', action='store_true', help='verify from scratch')
    parser.add_argument('--no_smart', action='store_true', help='bla')
    parser.add_argument('--milp_timeout', default=10, type=int, help='timeout for MILP')
    parser.add_argument('--eval_train', action='store_true', help='evaluate on training set')
    parser.add_argument('--test_idx', default=None, type=int, help='specific index to test')
    parser.add_argument('--start_idx', default=0, type=int, help='specific index to start')
    parser.add_argument('--end_idx', default=1000, type=int, help='specific index to end')
    parser.add_argument('--max_binary', default=None, type=int, help='number of neurons to encode as binary variable in MILP (per layer)')
    parser.add_argument('--num_iters', default=50, type=int, help='number of iterations to find slopes')
    parser.add_argument('--max_refine_triples', default=0, type=int, help='number of triples to refine')
    parser.add_argument('--refine_lidx', default=None, type=int, help='layer to refine')
    parser.add_argument('--save_models', action='store_true', help='whether to only store models')
    parser.add_argument('--refine_milp', default=0, type=int, help='number of neurons to refine using MILP')
    parser.add_argument('--obj_threshold', default=None, type=float, help='threshold to consider for MILP verification')
    parser.add_argument('--attack_type', default='pgd', type=str, help='attack')
    parser.add_argument('--attack_n_steps', default=10, type=int, help='number of steps for the attack')
    parser.add_argument('--attack_step_size', default=0.25, type=float, help='step size for the attack (relative to epsilon)')
    parser.add_argument('--layers', required=False, default=None, type=int, nargs='+', help='layer indices for training')
    args = parser.parse_args()

    # Derive cache directories from the checkpoint path ([:-3] strips a
    # 3-char extension such as '.pt'); exist_ok avoids the racy LBYL check.
    ver_logdir = args.load_model[:-3] + '_ver'
    os.makedirs(ver_logdir, exist_ok=True)
    grb_modelsdir = args.load_model[:-3] + '_grb'
    os.makedirs(grb_modelsdir, exist_ok=True)

    num_train, _, test_loader, input_size, input_channel = get_loaders(args)
    net = get_network(device, args, input_size, input_channel)
    n_layers = len(net.blocks)

    # net.to_double()

    # Quick sanity pass: interval-box accuracy on the batched loader.
    args.test_domains = ['box']
    with torch.no_grad():
        test(device, 0, args, net, test_loader)

    # Verification below is strictly per-sample, so rebuild the loader
    # with batch size 1.
    args.test_batch = 1
    num_train, _, test_loader, input_size, input_channel = get_loaders(args)

    # Count ReLU layers above layer_idx (used to size the MILP encoding).
    num_relu = 0
    for lidx in range(args.layer_idx + 1, n_layers):
        print(net.blocks[lidx])
        if isinstance(net.blocks[lidx], ReLU):
            num_relu += 1

    with torch.no_grad():
        tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, tot_tests = 0, 0, 0, 0, 0
        for test_idx, (inputs, targets) in enumerate(test_loader):
            # Restrict to the requested index window.
            if test_idx < args.start_idx or test_idx >= args.end_idx or test_idx >= args.n_valid:
                continue
            if args.test_idx is not None and test_idx != args.test_idx:
                continue
            tot_tests += 1
            test_file = os.path.join(ver_logdir, '{}.p'.format(test_idx))
            # Resume from the cached per-sample record unless --no_load.
            # Context manager closes the handle (the original leaked it).
            if (not args.no_load) and os.path.isfile(test_file):
                with open(test_file, 'rb') as fin:
                    test_data = pickle.load(fin)
            else:
                test_data = {}
            print('Verify test_idx =', test_idx)

            # Clear stale bounds from any previous sample.
            for lidx in range(n_layers):
                net.blocks[lidx].bounds = None

            inputs, targets = inputs.to(device), targets.to(device)
            abs_inputs = get_inputs(args.test_domain, inputs, args.test_eps, device, dtype=dtype)

            # Stage 1: natural accuracy — a misclassified sample can never verify.
            nat_out = net(inputs)
            nat_ok = targets.eq(nat_out.max(dim=1)[1]).item()
            tot_nat_ok += float(nat_ok)
            test_data['ok'] = nat_ok
            if not nat_ok:
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Stage 2: PGD attack — if the attack succeeds, verification is moot.
            with torch.enable_grad():
                pgd_loss, pgd_ok = get_adv_loss(device, args.test_eps, -1, net, None, inputs, targets, args)
            if pgd_ok:
                test_data['pgd_ok'] = 1
                tot_pgd_ok += 1
            else:
                test_data['pgd_ok'] = 0
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Already verified in a previous run: just re-tally and move on.
            if 'verified' in test_data and test_data['verified']:
                tot_verified_corr += 1
                tot_attack_ok += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Stage 3: abstract-domain analysis.
            relu_params = reset_params(args, net, dtype)
            bounds = compute_bounds(net, device, args.layer_idx, args, abs_inputs)
            if args.test_domain == 'zono_iter':
                # Optimize ReLU relaxation slopes to tighten the zonotope.
                with torch.enable_grad():
                    learn_slopes(relu_params, bounds, args, n_layers, net, inputs, targets, abs_inputs, None, None)

            with torch.enable_grad():
                abs_loss, abs_ok = get_adv_loss(device, args.test_eps, args.layer_idx, net, bounds, inputs, targets, args)

            # Optional spatial refinement of one layer's bounds.
            refined_triples = []
            if args.refine_lidx is not None:
                bounds = compute_bounds(net, device, args.layer_idx + 1, args, abs_inputs)
                for lidx in range(0, args.layer_idx + 2):
                    net.blocks[lidx].bounds = bounds[lidx]
                print('loss before refine: ', abs_loss)
                refine_dim = bounds[args.refine_lidx + 1][0].shape[2]
                pbar = tqdm(total=refine_dim * refine_dim, dynamic_ncols=True)
                for refine_i in range(refine_dim):
                    for refine_j in range(refine_dim):
                        refine(args, bounds, net, refine_i, refine_j, abs_inputs, input_size)
                        pbar.update(1)
                pbar.close()
                with torch.enable_grad():
                    abs_loss, abs_ok = get_adv_loss(device, args.test_eps, args.layer_idx, net, bounds, inputs, targets, args)
                print('loss after refine: ', abs_loss)

            if abs_ok:
                tot_attack_ok += 1
            abs_out = net(abs_inputs)
            verified, verified_corr = abs_out.verify(targets)
            test_data['verified'] = int(verified_corr.item())
            print('abs_loss: ', abs_loss.item(), '\tabs_ok: ', abs_ok.item(), '\tverified_corr: ', verified_corr.item())
            if verified_corr:
                tot_verified_corr += 1
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue
            # Without MILP (or if even the abstract attack failed) we are done.
            if args.no_milp or (not abs_ok):
                report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)
                continue

            # Stage 4: complete MILP verification as the last resort.
            if verify_test(args, net, num_relu, inputs, targets, abs_inputs, bounds, refined_triples, test_data, grb_modelsdir, test_idx):
                tot_verified_corr += 1
                # Store 1 (not True) so 'verified' is consistently int-encoded.
                test_data['verified'] = 1
            report(ver_logdir, tot_verified_corr, tot_nat_ok, tot_attack_ok, tot_pgd_ok, test_idx, tot_tests, test_data)