def main(args):
    config = load_config(args.config)
    logger.info('config: {}'.format(json.dumps(config)))
    set_seed(args.seed or config['seed'])
    model_ori, checkpoint, epoch, best = prepare_model(args, logger, config)
    logger.info('Model structure: \n {}'.format(str(model_ori)))

    custom_ops = {}
    bound_config = config['bound_params']
    batch_size = (args.batch_size or config['batch_size'])
    test_batch_size = args.test_batch_size or batch_size
    dummy_input, train_data, test_data = load_data(args,
                                                   config['data'],
                                                   batch_size,
                                                   test_batch_size,
                                                   aug=not args.no_data_aug)
    lf = args.loss_fusion and args.bound_type == 'CROWN-IBP'
    bound_opts = bound_config['bound_opts']

    model_ori.train()
    model = BoundedModule(model_ori,
                          dummy_input,
                          bound_opts=bound_opts,
                          custom_ops=custom_ops,
                          device=args.device)
    model_ori.to(args.device)

    if checkpoint is None:
        if args.manual_init:
            manual_init(args, model_ori, model, train_data)
        if args.kaiming_init:
            kaiming_init(model_ori)

    if lf:
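        # With loss fusion, the cross-entropy loss is folded into the bounded model so
        # that bounds are computed on the loss itself rather than on per-class margins,
        # which typically scales better with the number of classes.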
        model_loss = BoundedModule(
            CrossEntropyWrapper(model_ori),
            (dummy_input.to(args.device),
             torch.zeros(1, dtype=torch.long, device=args.device)),
            bound_opts=get_bound_opts_lf(bound_opts),
            device=args.device)
        params = list(model_loss.parameters())
    else:
        model_loss = model
        params = list(model_ori.parameters())
    logger.info('Parameter shapes: {}'.format([p.shape for p in params]))
    if args.multi_gpu:
        raise NotImplementedError('Multi-GPU is not supported yet')
        # Unreachable for now; kept as a reference for the intended wrapping:
        # model = BoundDataParallel(model)
        # model_loss = BoundDataParallel(model_loss)

    opt = get_optimizer(args, params, checkpoint)
    max_eps = args.eps or bound_config['eps']
    eps_scheduler = get_eps_scheduler(args, max_eps, train_data)
    lr_scheduler = get_lr_scheduler(args, opt)
    if epoch > 0 and not args.plot:
        # skip epochs
        eps_scheduler.train()
        for i in range(epoch):
            # FIXME Can use `last_epoch` argument of lr_scheduler
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=False)
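        # A possible alternative to replaying the steps (untested sketch): build the
        # LR scheduler with its `last_epoch` argument instead, e.g. for a MultiStepLR:
        #   for g in opt.param_groups:
        #       g.setdefault('initial_lr', g['lr'])  # required when last_epoch != -1
        #   lr_scheduler = optim.lr_scheduler.MultiStepLR(
        #       opt, milestones=..., gamma=0.1, last_epoch=epoch - 1)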

    if args.verify:
        logger.info('Inference')
        meter = Train(model,
                      model_ori,
                      10000,
                      test_data,
                      eps_scheduler,
                      None,
                      loss_fusion=False)
        logger.info(meter)
    else:
        timer = 0.0
        for t in range(epoch + 1, args.num_epochs + 1):
            logger.info('Epoch {}, learning rate {}, dir {}'.format(
                t, lr_scheduler.get_last_lr(), args.dir))
            start_time = time.time()
            if lf:
                Train(model_loss,
                      model_ori,
                      t,
                      train_data,
                      eps_scheduler,
                      opt,
                      loss_fusion=True)
            else:
                Train(model, model_ori, t, train_data, eps_scheduler, opt)
            update_state_dict(model_ori, model_loss)
            epoch_time = time.time() - start_time
            timer += epoch_time
            lr_scheduler.step()
            logger.info('Epoch time: {:.4f}, Total time: {:.4f}'.format(
                epoch_time, timer))
            is_best = False
            if t % args.test_interval == 0:
                logger.info('Test without loss fusion')
                with torch.no_grad():
                    meter = Train(model,
                                  model_ori,
                                  t,
                                  test_data,
                                  eps_scheduler,
                                  None,
                                  loss_fusion=False)
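                # Only start tracking the best checkpoint once eps has fully ramped up
                # to its maximum value.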
                if eps_scheduler.get_eps() == eps_scheduler.get_max_eps():
                    if meter.avg('Rob_Err') < best[1]:
                        is_best, best = True, (meter.avg('Err'),
                                               meter.avg('Rob_Err'), t)
                    logger.info(
                        'Best epoch {}, error {:.4f}, robust error {:.4f}'.
                        format(best[-1], best[0], best[1]))
            save(args,
                 epoch=t,
                 best=best,
                 model=model_ori,
                 opt=opt,
                 is_best=is_best)
Example 2
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in models/example_feedforward.py and models/example_resnet.py
    if args.data == 'MNIST':
        model_ori = models.Models[args.model](in_ch=1, in_dim=28)
    else:
        model_ori = models.Models[args.model]()
    epoch = 0
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = checkpoint.get('optimizer')
        if opt_state is None:
            print('no opt_state found')
        for k, v in state_dict.items():
            assert not torch.isnan(v).any() and not torch.isinf(v).any(), \
                'NaN/Inf found in checkpoint parameter {}'.format(k)
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare dataset as usual
    if args.data == 'MNIST':
        dummy_input = torch.randn(1, 1, 28, 28)
        train_data = datasets.MNIST("./data",
                                    train=True,
                                    download=True,
                                    transform=transforms.ToTensor())
        test_data = datasets.MNIST("./data",
                                   train=False,
                                   download=True,
                                   transform=transforms.ToTensor())
    elif args.data == 'CIFAR':
        dummy_input = torch.randn(1, 3, 32, 32)
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
        train_data = datasets.CIFAR10("./data",
                                      train=True,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomCrop(
                                              32, 4, padding_mode='edge'),
                                          transforms.ToTensor(), normalize
                                      ]))
        test_data = datasets.CIFAR10("./data",
                                     train=False,
                                     download=True,
                                     transform=transforms.Compose(
                                         [transforms.ToTensor(), normalize]))

    train_data = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             pin_memory=True,
                                             num_workers=min(
                                                 multiprocessing.cpu_count(),
                                                 4))
    test_data = torch.utils.data.DataLoader(test_data,
                                            batch_size=args.batch_size // 2,
                                            pin_memory=True,
                                            num_workers=min(
                                                multiprocessing.cpu_count(),
                                                4))
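    # The mean/std attached to the loaders below are presumably used by the Train
    # helper to rescale eps and clip inputs in normalized space (an assumption based
    # on common auto_LiRPA training examples; Train itself is not shown here).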
    if args.data == 'MNIST':
        train_data.mean = test_data.mean = torch.tensor([0.0])
        train_data.std = test_data.std = torch.tensor([1.0])
    elif args.data == 'CIFAR':
        train_data.mean = test_data.mean = torch.tensor(
            [0.4914, 0.4822, 0.4465])
        train_data.std = test_data.std = torch.tensor([0.2023, 0.1994, 0.2010])

    ## Step 3: wrap model with auto_LiRPA
    # The second parameter dummy_input is for constructing the trace of the computational graph.
    model = BoundedModule(model_ori,
                          dummy_input,
                          bound_opts={'relu': args.bound_opts},
                          device=args.device)
    final_name1 = model.final_name
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori),
                               (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts={
                                   'relu': args.bound_opts,
                                   'loss_fusion': True
                               },
                               device=args.device)
    # After CrossEntropyWrapper, the final node name changes because the wrapper adds one extra input node.
    final_name2 = model_loss._modules[final_name1].output_name[0]
    assert type(model._modules[final_name1]) == type(
        model_loss._modules[final_name2])
    if args.no_loss_fusion:
        model_loss = BoundedModule(model_ori,
                                   dummy_input,
                                   bound_opts={'relu': args.bound_opts},
                                   device=args.device)
        final_name2 = None
    model_loss = BoundDataParallel(model_loss)

    macs, params = profile(model_ori, (dummy_input.cuda(), ))
    logger.log('macs: {}, params: {}'.format(macs, params))
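    # For reference, a minimal sketch of querying the wrapped model for certified
    # bounds outside the Train helper (standard auto_LiRPA usage; the eps value and
    # the input batch here are purely illustrative):
    #   from auto_LiRPA import BoundedTensor
    #   from auto_LiRPA.perturbations import PerturbationLpNorm
    #   x = BoundedTensor(dummy_input.to(args.device),
    #                     PerturbationLpNorm(norm=np.inf, eps=8. / 255))
    #   lb, ub = model.compute_bounds(x=(x,), method='IBP')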

    ## Step 4: prepare optimizer, epsilon scheduler and learning rate scheduler
    opt = optim.Adam(model_loss.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        opt, milestones=args.lr_decay_milestones, gamma=0.1)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
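    # scheduler_name/scheduler_opts are resolved to one of auto_LiRPA's eps scheduler
    # classes; illustrative values (not necessarily this script's defaults) would be
    # scheduler_name='SmoothedScheduler' with scheduler_opts='start=2,length=80'.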
    logger.log(str(model_ori))

    # skip epochs
    if epoch > 0:
        epoch_length = int(
            (len(train_data.dataset) + train_data.batch_size - 1) /
            train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    ## Step 5: start training
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model,
                  1,
                  test_data,
                  eps_scheduler,
                  norm,
                  False,
                  None,
                  'IBP',
                  loss_fusion=False,
                  final_node_name=None)
    else:
        timer = 0.0
        best_acc = 1e10
        # with torch.autograd.detect_anomaly():
        for t in range(epoch + 1, args.num_epochs + 1):
            logger.log("Epoch {}, learning rate {}".format(
                t, lr_scheduler.get_last_lr()))
            start_time = time.time()
            Train(model_loss,
                  t,
                  train_data,
                  eps_scheduler,
                  norm,
                  True,
                  opt,
                  args.bound_type,
                  loss_fusion=not args.no_loss_fusion)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(
                epoch_time, timer))

            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # remove 'model.' in state_dict for CrossEntropyWrapper
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            if not args.no_loss_fusion:
                for name in state_dict_loss:
                    assert (name.startswith('model.'))
                    state_dict[name[6:]] = state_dict_loss[name]
            else:
                state_dict = state_dict_loss

            with torch.no_grad():
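                # After the eps ramp-up (start + length epochs) has finished, evaluate
                # at a fixed eps of 8/255; before that, evaluate at the scheduled eps.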
                if t > int(eps_scheduler.params['start']) + int(
                        eps_scheduler.params['length']):
                    m = Train(model_loss,
                              t,
                              test_data,
                              FixedScheduler(8. / 255),
                              norm,
                              False,
                              None,
                              'IBP',
                              loss_fusion=False,
                              final_node_name=final_name2)
                else:
                    m = Train(model_loss,
                              t,
                              test_data,
                              eps_scheduler,
                              norm,
                              False,
                              None,
                              'IBP',
                              loss_fusion=False,
                              final_node_name=final_name2)

            save_dict = {
                'state_dict': state_dict,
                'epoch': t,
                'optimizer': opt.state_dict()
            }
            if t < int(eps_scheduler.params['start']):
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(
                    eps_scheduler.params['length']):
                current_acc = m.avg('Verified_Err')
                if current_acc < best_acc:
                    best_acc = current_acc
                    torch.save(
                        save_dict, 'saved_models/' + exp_name + '_best_' +
                        str(best_acc)[:6])
                else:
                    torch.save(save_dict, 'saved_models/' + exp_name)
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()
Example 3
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in models/example_feedforward.py and models/example_resnet.py
    model_ori = models.Models[args.model]()
    epoch = 0
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = checkpoint.get('optimizer')
        if opt_state is None:
            print('no opt_state found')
        for k, v in state_dict.items():
            assert not torch.isnan(v).any() and not torch.isinf(v).any(), \
                'NaN/Inf found in checkpoint parameter {}'.format(k)
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare dataset as usual
    dummy_input = torch.randn(1, 3, 56, 56)
    normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975],
                                     std=[0.2302, 0.2265, 0.2262])
    train_data = datasets.ImageFolder(args.data_dir + '/train',
                                      transform=transforms.Compose([
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomCrop(
                                              56, padding_mode='edge'),
                                          transforms.ToTensor(),
                                          normalize,
                                      ]))
    test_data = datasets.ImageFolder(
        args.data_dir + '/val',
        transform=transforms.Compose([
            # transforms.RandomResizedCrop(64, scale=(0.875, 0.875), ratio=(1., 1.)),
            transforms.CenterCrop(56),
            transforms.ToTensor(),
            normalize
        ]))

    train_data = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             pin_memory=True,
                                             num_workers=min(
                                                 multiprocessing.cpu_count(),
                                                 4))
    test_data = torch.utils.data.DataLoader(test_data,
                                            batch_size=args.batch_size // 5,
                                            pin_memory=True,
                                            num_workers=min(
                                                multiprocessing.cpu_count(),
                                                4))
    train_data.mean = test_data.mean = torch.tensor([0.4802, 0.4481, 0.3975])
    train_data.std = test_data.std = torch.tensor([0.2302, 0.2265, 0.2262])

    ## Step 3: wrap model with auto_LiRPA
    # The second parameter dummy_input is for constructing the trace of the computational graph.
    model = BoundedModule(model_ori,
                          dummy_input,
                          bound_opts={'relu': args.bound_opts},
                          device=args.device)
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori),
                               (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts={
                                   'relu': args.bound_opts,
                                   'loss_fusion': True
                               },
                               device=args.device)
    model_loss = BoundDataParallel(model_loss)

    ## Step 4: prepare optimizer, epsilon scheduler and learning rate scheduler
    opt = optim.Adam(model_loss.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        opt, milestones=args.lr_decay_milestones, gamma=0.1)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    logger.log(str(model_ori))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    # skip epochs
    if epoch > 0:
        epoch_length = int(
            (len(train_data.dataset) + train_data.batch_size - 1) /
            train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    ## Step 5: start training
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model,
                  1,
                  test_data,
                  eps_scheduler,
                  norm,
                  False,
                  None,
                  'IBP',
                  loss_fusion=False,
                  final_node_name=None)
    else:
        timer = 0.0
        best_err = 1e10
        # with torch.autograd.detect_anomaly():
        for t in range(epoch + 1, args.num_epochs + 1):
            logger.log("Epoch {}, learning rate {}".format(
                t, lr_scheduler.get_last_lr()))
            start_time = time.time()
            Train(model_loss,
                  t,
                  train_data,
                  eps_scheduler,
                  norm,
                  True,
                  opt,
                  args.bound_type,
                  loss_fusion=True)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(
                epoch_time, timer))

            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # remove 'model.' in state_dict
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            for name in state_dict_loss:
                assert (name.startswith('model.'))
                state_dict[name[6:]] = state_dict_loss[name]

            with torch.no_grad():
                if int(eps_scheduler.params['start']) + int(
                        eps_scheduler.params['length']) > t >= int(
                            eps_scheduler.params['start']):
                    m = Train(model_loss,
                              t,
                              test_data,
                              eps_scheduler,
                              norm,
                              False,
                              None,
                              args.bound_type,
                              loss_fusion=True)
                else:
                    model_ori.load_state_dict(state_dict)
                    model = BoundedModule(model_ori,
                                          dummy_input,
                                          bound_opts={'relu': args.bound_opts},
                                          device=args.device)
                    model = BoundDataParallel(model)
                    m = Train(model,
                              t,
                              test_data,
                              eps_scheduler,
                              norm,
                              False,
                              None,
                              'IBP',
                              loss_fusion=False)
                    del model

            save_dict = {
                'state_dict': state_dict,
                'epoch': t,
                'optimizer': opt.state_dict()
            }
            if t < int(eps_scheduler.params['start']):
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(
                    eps_scheduler.params['length']):
                current_err = m.avg('Verified_Err')
                if current_err < best_err:
                    best_err = current_err
                    torch.save(
                        save_dict, 'saved_models/' + exp_name + '_best_' +
                        str(best_err)[:6])
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()
Example 4
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Load the model with BoundedParameter for weight perturbation.
    model_ori = models.Models['mlp_3layer_weight_perturb']()

    epoch = 0
    ## Load a checkpoint, if requested.
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = checkpoint.get('optimizer')
        if opt_state is None:
            print('no opt_state found')
        for k, v in state_dict.items():
            assert not torch.isnan(v).any() and not torch.isinf(v).any(), \
                'NaN/Inf found in checkpoint parameter {}'.format(k)
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare dataset as usual
    dummy_input = torch.randn(1, 1, 28, 28)
    train_data, test_data = mnist_loaders(datasets.MNIST, batch_size=args.batch_size, ratio=args.ratio)
    train_data.mean = test_data.mean = torch.tensor([0.0])
    train_data.std = test_data.std = torch.tensor([1.0])

    ## Step 3: wrap model with auto_LiRPA
    # The second parameter dummy_input is for constructing the trace of the computational graph.
    model = BoundedModule(model_ori, dummy_input, bound_opts={'relu':args.bound_opts}, device=args.device)
    final_name1 = model.final_name
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori), (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts= { 'relu': args.bound_opts, 'loss_fusion': True }, device=args.device)

    # After CrossEntropyWrapper, the final node name changes because the wrapper adds one extra input node.
    final_name2 = model_loss._modules[final_name1].output_name[0]
    assert type(model._modules[final_name1]) == type(model_loss._modules[final_name2])
    if args.multigpu:
        model_loss = BoundDataParallel(model_loss)
    model_loss.ptb = model.ptb = model_ori.ptb # Perturbation on the parameters
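    # model_ori.ptb is the perturbation specification the model attaches to its own
    # weights; presumably an Lp-ball perturbation around the parameter values, e.g.
    # something like PerturbationLpNorm(norm=norm, eps=args.eps) (an assumption --
    # the mlp_3layer_weight_perturb definition is not shown here).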

    ## Step 4: prepare optimizer, epsilon scheduler and learning rate scheduler
    if args.opt == 'ADAM':
        opt = optim.Adam(model_loss.parameters(), lr=args.lr, weight_decay=0.01)
    elif args.opt == 'SGD':
        opt = optim.SGD(model_loss.parameters(), lr=args.lr, weight_decay=0.01)

    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=0.1)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    logger.log(str(model_ori))

    # Skip epochs if we continue training from a checkpoint.
    if epoch > 0:
        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    ## Step 5: start training.
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None, 'CROWN-IBP', loss_fusion=False, final_node_name=None)
    else:
        timer = 0.0
        best_loss = 1e10
        # Main training loop
        for t in range(epoch + 1, args.num_epochs+1):
            logger.log("Epoch {}, learning rate {}".format(t, lr_scheduler.get_last_lr()))
            start_time = time.time()

            # Training one epoch
            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type, loss_fusion=True)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))

            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # remove 'model.' in state_dict (hack for saving models so far...)
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            for name in state_dict_loss:
                assert (name.startswith('model.'))
                state_dict[name[6:]] = state_dict_loss[name]

            # Test one epoch.
            with torch.no_grad():
                m = Train(model_loss, t, test_data, eps_scheduler, norm, False, None, args.bound_type,
                            loss_fusion=False, final_node_name=final_name2)

            # Save checkpoints.
            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}
            if not os.path.exists('saved_models'):
                os.mkdir('saved_models')
            if t < int(eps_scheduler.params['start']):
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):
                current_loss = m.avg('Loss')
                if current_loss < best_loss:
                    best_loss = current_loss
                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_loss)[:6])
                else:
                    torch.save(save_dict, 'saved_models/' + exp_name)
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()