Example #1
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in models/example_feedforward.py and models/example_resnet.py
    if args.data == 'MNIST':
        model_ori = models.Models[args.model](in_ch=1, in_dim=28)
    else:
        model_ori = models.Models[args.model](in_ch=3, in_dim=32)
    if args.load:
        state_dict = torch.load(args.load)['state_dict']
        model_ori.load_state_dict(state_dict)

    ## Step 2: Prepare dataset as usual
    if args.data == 'MNIST':
        dummy_input = torch.randn(1, 1, 28, 28)
        train_data = datasets.MNIST("./data",
                                    train=True,
                                    download=True,
                                    transform=transforms.ToTensor())
        test_data = datasets.MNIST("./data",
                                   train=False,
                                   download=True,
                                   transform=transforms.ToTensor())
    elif args.data == 'CIFAR':
        dummy_input = torch.randn(1, 3, 32, 32)
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
        train_data = datasets.CIFAR10("./data",
                                      train=True,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomCrop(32, 4),
                                          transforms.ToTensor(), normalize
                                      ]))
        test_data = datasets.CIFAR10("./data",
                                     train=False,
                                     download=True,
                                     transform=transforms.Compose(
                                         [transforms.ToTensor(), normalize]))

    train_data = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             pin_memory=True,
                                             num_workers=min(
                                                 multiprocessing.cpu_count(),
                                                 4))
    test_data = torch.utils.data.DataLoader(test_data,
                                            batch_size=args.batch_size,
                                            pin_memory=True,
                                            num_workers=min(
                                                multiprocessing.cpu_count(),
                                                4))
    if args.data == 'MNIST':
        train_data.mean = test_data.mean = torch.tensor([0.0])
        train_data.std = test_data.std = torch.tensor([1.0])
    elif args.data == 'CIFAR':
        train_data.mean = test_data.mean = torch.tensor(
            [0.4914, 0.4822, 0.4465])
        train_data.std = test_data.std = torch.tensor([0.2023, 0.1994, 0.2010])

    ## Step 3: wrap model with auto_LiRPA
    # The second parameter dummy_input is for constructing the trace of the computational graph.
    model = BoundedModule(model_ori,
                          dummy_input,
                          bound_opts={'relu': args.bound_opts},
                          device=args.device)

    ## Step 4: Prepare optimizer, epsilon scheduler and learning rate scheduler
    opt = optim.Adam(model.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    print("Model structure: \n", str(model_ori))

    ## Step 5: start training
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None,
                  args.bound_type)
    else:
        timer = 0.0
        for t in range(1, args.num_epochs + 1):
            if eps_scheduler.reached_max_eps():
                # Only decay learning rate after reaching the maximum eps
                lr_scheduler.step()
            print("Epoch {}, learning rate {}".format(t,
                                                      lr_scheduler.get_lr()))
            start_time = time.time()
            Train(model, t, train_data, eps_scheduler, norm, True, opt,
                  args.bound_type)
            epoch_time = time.time() - start_time
            timer += epoch_time
            print('Epoch time: {:.4f}, Total time: {:.4f}'.format(
                epoch_time, timer))
            print("Evaluating...")
            with torch.no_grad():
                Train(model, t, test_data, eps_scheduler, norm, False, None,
                      args.bound_type)
            torch.save({
                'state_dict': model.state_dict(),
                'epoch': t
            }, args.model)
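The Train helper called above is defined elsewhere in the repository. As a rough sketch of the bound computation such a certified-training loop relies on (the function name, the eps default, and the plain IBP worst-case cross-entropy are illustrative assumptions; the actual Train routine also handles eps scheduling and the CROWN-IBP/backward bound types selected by args.bound_type):

import torch
import torch.nn.functional as F
from auto_LiRPA import BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

def robust_loss_sketch(model, x, labels, eps=0.3):
    # model: a BoundedModule as built in Step 3; x: clean inputs; labels: targets.
    # Wrap the clean input in an L-infinity ball of radius eps.
    ptb = PerturbationLpNorm(norm=float("inf"), eps=eps)
    x_bounded = BoundedTensor(x, ptb)
    # Element-wise bounds on the logits under that perturbation (pure IBP here).
    lb, ub = model.compute_bounds(x=(x_bounded,), IBP=True, method=None)
    # Worst-case logits: lower bound for the true class, upper bound elsewhere.
    one_hot = F.one_hot(labels, num_classes=lb.shape[-1]).bool()
    worst_logits = torch.where(one_hot, lb, ub)
    return F.cross_entropy(worst_logits, labels)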
Example #2
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in models/example_feedforward.py and models/example_resnet.py
    model_ori = models.Models[args.model]()
    epoch = 0
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = None
        try:
            opt_state = checkpoint['optimizer']
        except KeyError:
            print('no opt_state found')
        for k, v in state_dict.items():
            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(
                v).any().cpu().numpy() == 0
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare dataset as usual
    dummy_input = torch.randn(1, 3, 56, 56)
    normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975],
                                     std=[0.2302, 0.2265, 0.2262])
    train_data = datasets.ImageFolder(args.data_dir + '/train',
                                      transform=transforms.Compose([
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomCrop(
                                              56, padding_mode='edge'),
                                          transforms.ToTensor(),
                                          normalize,
                                      ]))
    test_data = datasets.ImageFolder(
        args.data_dir + '/val',
        transform=transforms.Compose([
            # transforms.RandomResizedCrop(64, scale=(0.875, 0.875), ratio=(1., 1.)),
            transforms.CenterCrop(56),
            transforms.ToTensor(),
            normalize
        ]))

    train_data = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             pin_memory=True,
                                             num_workers=min(
                                                 multiprocessing.cpu_count(),
                                                 4))
    test_data = torch.utils.data.DataLoader(test_data,
                                            batch_size=args.batch_size // 5,
                                            pin_memory=True,
                                            num_workers=min(
                                                multiprocessing.cpu_count(),
                                                4))
    train_data.mean = test_data.mean = torch.tensor([0.4802, 0.4481, 0.3975])
    train_data.std = test_data.std = torch.tensor([0.2302, 0.2265, 0.2262])

    ## Step 3: wrap model with auto_LiRPA
    # The second parameter dummy_input is for constructing the trace of the computational graph.
    model = BoundedModule(model_ori,
                          dummy_input,
                          bound_opts={'relu': args.bound_opts},
                          device=args.device)
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori),
                               (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts={
                                   'relu': args.bound_opts,
                                   'loss_fusion': True
                               },
                               device=args.device)
    model_loss = BoundDataParallel(model_loss)

    ## Step 4: Prepare optimizer, epsilon scheduler and learning rate scheduler
    opt = optim.Adam(model_loss.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        opt, milestones=args.lr_decay_milestones, gamma=0.1)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    logger.log(str(model_ori))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    # skip epochs
    if epoch > 0:
        epoch_length = int(
            (len(train_data.dataset) + train_data.batch_size - 1) /
            train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    ## Step 5: start training
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model,
                  1,
                  test_data,
                  eps_scheduler,
                  norm,
                  False,
                  None,
                  'IBP',
                  loss_fusion=False,
                  final_node_name=None)
    else:
        timer = 0.0
        best_err = 1e10
        # with torch.autograd.detect_anomaly():
        for t in range(epoch + 1, args.num_epochs + 1):
            logger.log("Epoch {}, learning rate {}".format(
                t, lr_scheduler.get_last_lr()))
            start_time = time.time()
            Train(model_loss,
                  t,
                  train_data,
                  eps_scheduler,
                  norm,
                  True,
                  opt,
                  args.bound_type,
                  loss_fusion=True)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(
                epoch_time, timer))

            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # remove 'model.' in state_dict
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            for name in state_dict_loss:
                assert (name.startswith('model.'))
                state_dict[name[6:]] = state_dict_loss[name]

            with torch.no_grad():
                if int(eps_scheduler.params['start']) + int(
                        eps_scheduler.params['length']) > t >= int(
                            eps_scheduler.params['start']):
                    m = Train(model_loss,
                              t,
                              test_data,
                              eps_scheduler,
                              norm,
                              False,
                              None,
                              args.bound_type,
                              loss_fusion=True)
                else:
                    model_ori.load_state_dict(state_dict)
                    model = BoundedModule(model_ori,
                                          dummy_input,
                                          bound_opts={'relu': args.bound_opts},
                                          device=args.device)
                    model = BoundDataParallel(model)
                    m = Train(model,
                              t,
                              test_data,
                              eps_scheduler,
                              norm,
                              False,
                              None,
                              'IBP',
                              loss_fusion=False)
                    del model

            save_dict = {
                'state_dict': state_dict,
                'epoch': t,
                'optimizer': opt.state_dict()
            }
            if t < int(eps_scheduler.params['start']):
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(
                    eps_scheduler.params['length']):
                current_err = m.avg('Verified_Err')
                if current_err < best_err:
                    best_err = current_err
                    torch.save(
                        save_dict, 'saved_models/' + exp_name + '_best_' +
                        str(best_err)[:6])
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()
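Example #2 builds two bounded models: model for evaluation and model_loss, which wraps the network in CrossEntropyWrapper and enables 'loss_fusion' so that bounds are propagated through to the scalar loss instead of all logits, keeping memory manageable at Tiny-ImageNet scale. Below is a minimal sketch of just that wrapping step; the toy network, input size, and class count are assumptions for illustration.

import torch
import torch.nn as nn
from auto_LiRPA import BoundedModule, CrossEntropyWrapper

# A toy classifier standing in for models.Models[args.model]() (assumption).
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 56 * 56, 200))
dummy_x = torch.randn(1, 3, 56, 56)
dummy_y = torch.zeros(1, dtype=torch.long)

# Without loss fusion, bounds are computed on the 200 logits directly.
model = BoundedModule(net, dummy_x)

# With loss fusion, the cross-entropy is folded into the bounded graph, so
# compute_bounds yields bounds on the loss itself and no per-class
# specification matrix has to be formed during training.
model_loss = BoundedModule(CrossEntropyWrapper(net), (dummy_x, dummy_y),
                           bound_opts={'loss_fusion': True})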
Example #3
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in models/example_feedforward.py and models/example_resnet.py
    if args.data == 'MNIST':
        model_ori = models.Models[args.model](in_ch=1, in_dim=28)
    else:
        model_ori = models.Models[args.model]()
    epoch = 0
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = None
        try:
            opt_state = checkpoint['optimizer']
        except KeyError:
            print('no opt_state found')
        for k, v in state_dict.items():
            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(
                v).any().cpu().numpy() == 0
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare dataset as usual
    if args.data == 'MNIST':
        dummy_input = torch.randn(1, 1, 28, 28)
        train_data = datasets.MNIST("./data",
                                    train=True,
                                    download=True,
                                    transform=transforms.ToTensor())
        test_data = datasets.MNIST("./data",
                                   train=False,
                                   download=True,
                                   transform=transforms.ToTensor())
    elif args.data == 'CIFAR':
        dummy_input = torch.randn(1, 3, 32, 32)
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
        train_data = datasets.CIFAR10("./data",
                                      train=True,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomCrop(
                                              32, 4, padding_mode='edge'),
                                          transforms.ToTensor(), normalize
                                      ]))
        test_data = datasets.CIFAR10("./data",
                                     train=False,
                                     download=True,
                                     transform=transforms.Compose(
                                         [transforms.ToTensor(), normalize]))

    train_data = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             pin_memory=True,
                                             num_workers=min(
                                                 multiprocessing.cpu_count(),
                                                 4))
    test_data = torch.utils.data.DataLoader(test_data,
                                            batch_size=args.batch_size // 2,
                                            pin_memory=True,
                                            num_workers=min(
                                                multiprocessing.cpu_count(),
                                                4))
    if args.data == 'MNIST':
        train_data.mean = test_data.mean = torch.tensor([0.0])
        train_data.std = test_data.std = torch.tensor([1.0])
    elif args.data == 'CIFAR':
        train_data.mean = test_data.mean = torch.tensor(
            [0.4914, 0.4822, 0.4465])
        train_data.std = test_data.std = torch.tensor([0.2023, 0.1994, 0.2010])

    ## Step 3: wrap model with auto_LiRPA
    # The second parameter dummy_input is for constructing the trace of the computational graph.
    model = BoundedModule(model_ori,
                          dummy_input,
                          bound_opts={'relu': args.bound_opts},
                          device=args.device)
    final_name1 = model.final_name
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori),
                               (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts={
                                   'relu': args.bound_opts,
                                   'loss_fusion': True
                               },
                               device=args.device)
    # after CrossEntropyWrapper, the final name will change because of one additional input node in CrossEntropyWrapper
    final_name2 = model_loss._modules[final_name1].output_name[0]
    assert type(model._modules[final_name1]) == type(
        model_loss._modules[final_name2])
    if args.no_loss_fusion:
        model_loss = BoundedModule(model_ori,
                                   dummy_input,
                                   bound_opts={'relu': args.bound_opts},
                                   device=args.device)
        final_name2 = None
    model_loss = BoundDataParallel(model_loss)

    macs, params = profile(model_ori, (dummy_input.cuda(), ))
    logger.log('macs: {}, params: {}'.format(macs, params))

    ## Step 4: Prepare optimizer, epsilon scheduler and learning rate scheduler
    opt = optim.Adam(model_loss.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        opt, milestones=args.lr_decay_milestones, gamma=0.1)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    logger.log(str(model_ori))

    # skip epochs
    if epoch > 0:
        epoch_length = int(
            (len(train_data.dataset) + train_data.batch_size - 1) /
            train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    ## Step 5: start training
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model,
                  1,
                  test_data,
                  eps_scheduler,
                  norm,
                  False,
                  None,
                  'IBP',
                  loss_fusion=False,
                  final_node_name=None)
    else:
        timer = 0.0
        best_err = 1e10
        # with torch.autograd.detect_anomaly():
        for t in range(epoch + 1, args.num_epochs + 1):
            logger.log("Epoch {}, learning rate {}".format(
                t, lr_scheduler.get_last_lr()))
            start_time = time.time()
            Train(model_loss,
                  t,
                  train_data,
                  eps_scheduler,
                  norm,
                  True,
                  opt,
                  args.bound_type,
                  loss_fusion=not args.no_loss_fusion)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(
                epoch_time, timer))

            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # remove 'model.' in state_dict for CrossEntropyWrapper
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            if not args.no_loss_fusion:
                for name in state_dict_loss:
                    assert (name.startswith('model.'))
                    state_dict[name[6:]] = state_dict_loss[name]
            else:
                state_dict = state_dict_loss

            with torch.no_grad():
                if t > int(eps_scheduler.params['start']) + int(
                        eps_scheduler.params['length']):
                    m = Train(model_loss,
                              t,
                              test_data,
                              FixedScheduler(8. / 255),
                              norm,
                              False,
                              None,
                              'IBP',
                              loss_fusion=False,
                              final_node_name=final_name2)
                else:
                    m = Train(model_loss,
                              t,
                              test_data,
                              eps_scheduler,
                              norm,
                              False,
                              None,
                              'IBP',
                              loss_fusion=False,
                              final_node_name=final_name2)

            save_dict = {
                'state_dict': state_dict,
                'epoch': t,
                'optimizer': opt.state_dict()
            }
            if t < int(eps_scheduler.params['start']):
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(
                    eps_scheduler.params['length']):
                current_err = m.avg('Verified_Err')
                if current_err < best_err:
                    best_err = current_err
                    torch.save(
                        save_dict, 'saved_models/' + exp_name + '_best_' +
                        str(best_err)[:6])
                else:
                    torch.save(save_dict, 'saved_models/' + exp_name)
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()
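Both of the last two examples instantiate the epsilon scheduler via eval(args.scheduler_name)(args.eps, args.scheduler_opts) and later read eps_scheduler.params['start'] and params['length'] to decide which checkpoint to save. A small sketch of how such a scheduler is driven, assuming SmoothedScheduler and the 'start=2,length=80' option string commonly used with these scripts (the exact values are assumptions):

from auto_LiRPA.eps_scheduler import SmoothedScheduler, FixedScheduler

# Ramp eps from 0 up to 8/255, starting at epoch 2 and taking 80 epochs.
eps_scheduler = SmoothedScheduler(8.0 / 255, 'start=2,length=80')
eps_scheduler.set_epoch_length(100)    # number of batches per epoch
eps_scheduler.train()
eps_scheduler.step_epoch()             # call once at the start of each epoch
eps_scheduler.step_batch()             # call once per batch
print(eps_scheduler.get_eps())         # eps used for the current batch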
Example #4
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Load the model with BoundedParameter for weight perturbation.
    model_ori = models.Models['mlp_3layer_weight_perturb']()

    epoch = 0
    ## Load a checkpoint, if requested.
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = None
        try:
            opt_state = checkpoint['optimizer']
        except KeyError:
            print('no opt_state found')
        for k, v in state_dict.items():
            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(v).any().cpu().numpy() == 0
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare dataset as usual
    dummy_input = torch.randn(1, 1, 28, 28)
    train_data, test_data = mnist_loaders(datasets.MNIST, batch_size=args.batch_size, ratio=args.ratio)
    train_data.mean = test_data.mean = torch.tensor([0.0])
    train_data.std = test_data.std = torch.tensor([1.0])

    ## Step 3: wrap model with auto_LiRPA
    # The second parameter dummy_input is for constructing the trace of the computational graph.
    model = BoundedModule(model_ori, dummy_input, bound_opts={'relu':args.bound_opts}, device=args.device)
    final_name1 = model.final_name
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori), (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts={'relu': args.bound_opts, 'loss_fusion': True}, device=args.device)

    # after CrossEntropyWrapper, the final name will change because of one more input node in CrossEntropyWrapper
    final_name2 = model_loss._modules[final_name1].output_name[0]
    assert type(model._modules[final_name1]) == type(model_loss._modules[final_name2])
    if args.multigpu:
        model_loss = BoundDataParallel(model_loss)
    model_loss.ptb = model.ptb = model_ori.ptb # Perturbation on the parameters

    ## Step 4: Prepare optimizer, epsilon scheduler and learning rate scheduler
    if args.opt == 'ADAM':
        opt = optim.Adam(model_loss.parameters(), lr=args.lr, weight_decay=0.01)
    elif args.opt == 'SGD':
        opt = optim.SGD(model_loss.parameters(), lr=args.lr, weight_decay=0.01)

    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=0.1)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    logger.log(str(model_ori))

    # Skip epochs if we continue training from a checkpoint.
    if epoch > 0:
        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    ## Step 5: start training.
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None, 'CROWN-IBP', loss_fusion=False, final_node_name=None)
    else:
        timer = 0.0
        best_loss = 1e10
        # Main training loop
        for t in range(epoch + 1, args.num_epochs+1):
            logger.log("Epoch {}, learning rate {}".format(t, lr_scheduler.get_last_lr()))
            start_time = time.time()

            # Training one epoch
            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type, loss_fusion=True)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))

            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # remove 'model.' in state_dict (hack for saving models so far...)
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            for name in state_dict_loss:
                assert (name.startswith('model.'))
                state_dict[name[6:]] = state_dict_loss[name]

            # Test one epoch.
            with torch.no_grad():
                m = Train(model_loss, t, test_data, eps_scheduler, norm, False, None, args.bound_type,
                            loss_fusion=False, final_node_name=final_name2)

            # Save checkpoints.
            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}
            if not os.path.exists('saved_models'):
                os.mkdir('saved_models')
            if t < int(eps_scheduler.params['start']):
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):
                current_loss = m.avg('Loss')
                if current_loss < best_loss:
                    best_loss = current_loss
                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_loss)[:6])
                else:
                    torch.save(save_dict, 'saved_models/' + exp_name)
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()
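Example #4 perturbs the network weights rather than the inputs: the mlp_3layer_weight_perturb model declares its weights with perturbation specifications, and model_ori.ptb is shared with the bounded models above. The following is a hypothetical sketch of how one layer of such a model might be declared with BoundedParameter; the actual model in models/ may declare its layers differently.

import torch
import torch.nn as nn
import torch.nn.functional as F
from auto_LiRPA import BoundedParameter
from auto_LiRPA.perturbations import PerturbationLpNorm

class PerturbedLinear(nn.Module):
    """Illustrative linear layer whose weight carries a perturbation spec."""
    def __init__(self, in_features, out_features, eps=0.01):
        super().__init__()
        ptb = PerturbationLpNorm(norm=float("inf"), eps=eps)
        # BoundedParameter attaches the perturbation to the weight itself, so
        # auto_LiRPA propagates bounds with respect to weight changes as well.
        self.weight = BoundedParameter(
            0.01 * torch.randn(out_features, in_features), ptb)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)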
Example #5
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in models/example_feedforward.py and models/example_resnet.py
    model_ori = PointNet(
        number_points=args.num_points,
        num_classes=40,
        pool_function=args.pooling
    )

    if args.load:
        state_dict = torch.load(args.load)
        model_ori.load_state_dict(state_dict)
        print(state_dict)

    ## Step 2: Prepare dataset as usual

    train_data = datasets.modelnet40(num_points=args.num_points, split='train', rotate='z')
    test_data = datasets.modelnet40(num_points=args.num_points, split='test', rotate='none')

    train_data = DataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4
    )
    test_data = DataLoader(
        dataset=test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4
    )
    dummy_input = torch.randn(2, args.num_points, 3)

    ## Step 3: wrap model with auto_LiRPA
    # The second parameter dummy_input is for constructing the trace of the computational graph.
    model = BoundedModule(model_ori, dummy_input, bound_opts={'relu': args.bound_opts, 'conv_mode': args.conv_mode}, device=args.device)

    ## Step 4: Prepare optimizer, epsilon scheduler and learning rate scheduler
    opt = optim.Adam(model.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    print("Model structure: \n", str(model_ori))

    ## Step 5: start training
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None, args.bound_type)
    else:
        timer = 0.0
        for t in range(1, args.num_epochs + 1):
            if eps_scheduler.reached_max_eps():
                # Only decay learning rate after reaching the maximum eps
                lr_scheduler.step()
            print("Epoch {}, learning rate {}".format(t, lr_scheduler.get_lr()))
            start_time = time.time()
            Train(model, t, train_data, eps_scheduler, norm, True, opt, args.bound_type)
            epoch_time = time.time() - start_time
            timer += epoch_time
            print('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))
            print("Evaluating...")
            with torch.no_grad():
                Train(model, t, test_data, eps_scheduler, norm, False, None, args.bound_type)
            torch.save(model.state_dict(), args.save_model if args.save_model != "" else args.model)
Example #6
class RobustDeterministicActorCriticNet(nn.Module, BaseNet):
    def __init__(self,
                 state_dim,
                 action_dim,
                 actor_network,
                 critic_network,
                 mini_batch_size,
                 actor_opt_fn,
                 critic_opt_fn,
                 robust_params=None):
        super(RobustDeterministicActorCriticNet, self).__init__()

        if robust_params is None:
            robust_params = {}
        self.use_loss_fusion = robust_params.get('use_loss_fusion', False) # Use loss fusion to reduce complexity for convex relaxation. Default is False.
        self.use_full_backward = robust_params.get('use_full_backward', False)
        if self.use_loss_fusion:
            # Use auto_LiRPA to compute the L2 norm directly.
            self.fc_action = model_mlp_any_with_loss(state_dim, actor_network, action_dim)
            modules = self.fc_action._modules
            # Auto LiRPA wrapper
            self.fc_action = BoundedModule(
                    self.fc_action, (torch.empty(size=(1, state_dim)), torch.empty(size=(1, action_dim))), device=Config.DEVICE)
            # self.fc_action._modules = modules
            for n in self.fc_action.nodes:
                # Find the tanh neuron in computational graph
                if isinstance(n, BoundTanh):
                    self.fc_action_after_tanh = n
                    self.fc_action_pre_tanh = n.inputs[0]
                    break
        else:
            # Fully connected layer with [state_dim, 400, 300, action_dim] neurons and ReLU activation function
            self.fc_action = model_mlp_any(state_dim, actor_network, action_dim)
            # auto_lirpa wrapper
            self.fc_action = BoundedModule(
                    self.fc_action, (torch.empty(size=(1, state_dim)), ), device=Config.DEVICE)

        # Fully connected layer with [state_dim + action_dim, 400, 300, 1]
        self.fc_critic = model_mlp_any(state_dim + action_dim, critic_network, 1)
        # auto_lirpa wrapper
        self.fc_critic = BoundedModule(
                self.fc_critic, (torch.empty(size=(1, state_dim + action_dim)), ), device=Config.DEVICE)

        self.actor_params = self.fc_action.parameters()
        self.critic_params = self.fc_critic.parameters()

        self.actor_opt = actor_opt_fn(self.actor_params)
        self.critic_opt = critic_opt_fn(self.critic_params)
        self.to(Config.DEVICE)
        # Create identity specification matrices
        self.actor_identity = torch.eye(action_dim).repeat(mini_batch_size,1,1).to(Config.DEVICE)
        self.critic_identity = torch.eye(1).repeat(mini_batch_size,1,1).to(Config.DEVICE)
        self.action_dim = action_dim
        self.state_dim = state_dim

    def forward(self, obs):
        phi = self.feature(obs)
        action = self.actor(phi)
        return action

    def feature(self, obs):
        # Not used; originally this was a feature extraction network
        return tensor(obs)

    def actor(self, phi):
        if self.use_loss_fusion:
            self.fc_action(phi, torch.zeros(size=phi.size()[:1] + (self.action_dim,), device=Config.DEVICE))
            return self.fc_action_after_tanh.forward_value
        else:
            return torch.tanh(self.fc_action(phi, method_opt="forward"))

    # Obtain element-wise lower and upper bounds for actor network through convex relaxations.
    def actor_bound(self, phi_lb, phi_ub, beta=1.0, eps=None, norm=np.inf, upper=True, lower=True, phi = None, center = None):
        if self.use_loss_fusion: # Use loss fusion (not typically enabled)
            assert center is not None
            ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
            x = BoundedTensor(phi, ptb)
            val = self.fc_action(x, center.detach())
            ilb, iub = self.fc_action.compute_bounds(IBP=True, method=None)
            if beta > 1e-10:
                clb, cub = self.fc_action.compute_bounds(IBP=False, method="backward", bound_lower=False, bound_upper=True)
                ub = cub * beta + iub * (1.0 - beta)
                return ub
            else:
                return iub
        else:
            assert center is None
            # Invoke auto_LiRPA for convex relaxation.
            ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
            x = BoundedTensor(phi, ptb)
            if self.use_full_backward:
                clb, cub = self.fc_action.compute_bounds(x=(x,), IBP=False, method="backward")
                return cub, clb
            else:
                ilb, iub = self.fc_action.compute_bounds(x=(x,), IBP=True, method=None)
                if beta > 1e-10:
                    clb, cub = self.fc_action.compute_bounds(IBP=False, method="backward")
                    ub = cub * beta + iub * (1.0 - beta)
                    lb = clb * beta + ilb * (1.0 - beta)
                    return ub, lb
                else:
                    return iub, ilb


    def critic(self, phi, a):
        return self.fc_critic(torch.cat([phi, a], dim=1), method_opt="forward")

    # Obtain element-wise lower and upper bounds for critic network through convex relaxations.
    def critic_bound(self, phi_lb, phi_ub, a_lb, a_ub, beta=1.0, eps=None, phi=None, action=None, norm=np.inf, upper=True, lower=True):
        x_L = torch.cat([phi_lb, a_lb], dim=1)
        x_U = torch.cat([phi_ub, a_ub], dim=1)
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=x_L, x_U=x_U)
        x = BoundedTensor(torch.cat([phi, action], dim=1), ptb)
        ilb, iub = self.fc_critic.compute_bounds(x=(x,), IBP=True, method=None)
        if beta > 1e-10:
            clb, cub = self.fc_critic.compute_bounds(IBP=False, method="backward")
            ub = cub * beta + iub * (1.0 - beta)
            lb = clb * beta + ilb * (1.0 - beta)
            return ub, lb
        else:
            return iub, ilb
        
    def load_state_dict(self, state_dict, strict=True):
        action_dict = OrderedDict()
        critic_dict = OrderedDict()
        for k in state_dict.keys():
            if 'action' in k:
                pos = k.find('.') + 1
                action_dict[k[pos:]] = state_dict[k]
            if 'critic' in k:
                pos = k.find('.') + 1
                critic_dict[k[pos:]] = state_dict[k]
        # Load actor and critic networks separately; this is required for auto_LiRPA.
        self.fc_action.load_state_dict(action_dict)
        self.fc_critic.load_state_dict(critic_dict)

    def state_dict(self):
        # Save actor and critic networks separately; this is required for auto_LiRPA.
        action_state_dict = self.fc_action.state_dict()
        critic_state_dict = self.fc_critic.state_dict()
        network_state_dict = OrderedDict()
        for k,v in action_state_dict.items():
            network_state_dict["fc_action."+k] = v
        for k,v in critic_state_dict.items():
            network_state_dict["fc_critic."+k] = v
        return network_state_dict
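actor_bound and critic_bound above blend cheap interval (IBP) bounds with tighter backward (CROWN-style) bounds using the weight beta. A self-contained sketch of that interpolation on a toy network follows; the layer sizes, eps, and beta are assumptions for illustration.

import torch
import torch.nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

# A toy 8-dimensional "state" network standing in for fc_action (assumption).
net = BoundedModule(nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2)),
                    torch.empty(1, 8))
x = BoundedTensor(torch.randn(16, 8),
                  PerturbationLpNorm(norm=float("inf"), eps=0.05))

ilb, iub = net.compute_bounds(x=(x,), IBP=True, method=None)         # cheap interval bounds
clb, cub = net.compute_bounds(x=(x,), IBP=False, method="backward")  # tighter CROWN bounds
beta = 0.5
lb = beta * clb + (1.0 - beta) * ilb
ub = beta * cub + (1.0 - beta) * iub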