def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in
    ## models/example_feedforward.py and models/example_resnet.py.
    if args.data == 'MNIST':
        model_ori = models.Models[args.model](in_ch=1, in_dim=28)
    else:
        model_ori = models.Models[args.model](in_ch=3, in_dim=32)
    if args.load:
        state_dict = torch.load(args.load)['state_dict']
        model_ori.load_state_dict(state_dict)

    ## Step 2: Prepare the dataset as usual.
    if args.data == 'MNIST':
        dummy_input = torch.randn(1, 1, 28, 28)
        train_data = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
        test_data = datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor())
    elif args.data == 'CIFAR':
        dummy_input = torch.randn(1, 3, 32, 32)
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        train_data = datasets.CIFAR10("./data", train=True, download=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize]))
        test_data = datasets.CIFAR10("./data", train=False, download=True,
                                     transform=transforms.Compose([transforms.ToTensor(), normalize]))
    train_data = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                             pin_memory=True, num_workers=min(multiprocessing.cpu_count(), 4))
    test_data = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size,
                                            pin_memory=True, num_workers=min(multiprocessing.cpu_count(), 4))
    if args.data == 'MNIST':
        train_data.mean = test_data.mean = torch.tensor([0.0])
        train_data.std = test_data.std = torch.tensor([1.0])
    elif args.data == 'CIFAR':
        train_data.mean = test_data.mean = torch.tensor([0.4914, 0.4822, 0.4465])
        train_data.std = test_data.std = torch.tensor([0.2023, 0.1994, 0.2010])

    ## Step 3: Wrap the model with auto_LiRPA.
    # The second parameter, dummy_input, is used to construct the trace of the computational graph.
    model = BoundedModule(model_ori, dummy_input, bound_opts={'relu': args.bound_opts}, device=args.device)

    ## Step 4: Prepare the optimizer, epsilon scheduler and learning rate scheduler.
    opt = optim.Adam(model.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    print("Model structure: \n", str(model_ori))

    ## Step 5: Start training.
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None, args.bound_type)
    else:
        timer = 0.0
        for t in range(1, args.num_epochs + 1):
            if eps_scheduler.reached_max_eps():
                # Only decay the learning rate after reaching the maximum eps.
                lr_scheduler.step()
            print("Epoch {}, learning rate {}".format(t, lr_scheduler.get_lr()))
            start_time = time.time()
            Train(model, t, train_data, eps_scheduler, norm, True, opt, args.bound_type)
            epoch_time = time.time() - start_time
            timer += epoch_time
            print('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))
            print("Evaluating...")
            with torch.no_grad():
                Train(model, t, test_data, eps_scheduler, norm, False, None, args.bound_type)
            torch.save({'state_dict': model.state_dict(), 'epoch': t}, args.model)
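# --- Illustrative sketch (not part of the example above) ---
# The Train() helper used above is defined elsewhere in this repository. The
# following self-contained sketch shows, under simplified assumptions, what one
# certified-training step with a wrapped BoundedModule looks like: attach an
# L_inf perturbation to the input batch, compute IBP bounds on the logits, and
# train on the worst-case logits. The margin construction is a common
# simplification and may differ from the actual Train() implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 10))
bounded_net = BoundedModule(net, torch.randn(1, 1, 28, 28))

x = torch.rand(8, 1, 28, 28)               # a fake MNIST batch
labels = torch.randint(0, 10, (8,))
eps = 0.1                                  # in the scripts above, eps comes from eps_scheduler
x_bounded = BoundedTensor(x, PerturbationLpNorm(norm=float('inf'), eps=eps))

# IBP bounds on the 10 logits; method="backward" would give CROWN bounds instead.
lb, ub = bounded_net.compute_bounds(x=(x_bounded,), IBP=True, method=None)

# Worst-case logits: the lower bound for the true class, upper bounds elsewhere.
one_hot = F.one_hot(labels, num_classes=10).bool()
worst_case_logits = torch.where(one_hot, lb, ub)
robust_loss = F.cross_entropy(worst_case_logits, labels)
robust_loss.backward()                     # gradients for the optimizer step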
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in
    ## models/example_feedforward.py and models/example_resnet.py.
    model_ori = models.Models[args.model]()
    epoch = 0
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = None
        try:
            opt_state = checkpoint['optimizer']
        except KeyError:
            print('no opt_state found')
        for k, v in state_dict.items():
            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(v).any().cpu().numpy() == 0
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare the dataset as usual.
    dummy_input = torch.randn(1, 3, 56, 56)
    normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975], std=[0.2302, 0.2265, 0.2262])
    train_data = datasets.ImageFolder(args.data_dir + '/train', transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(56, padding_mode='edge'),
        transforms.ToTensor(),
        normalize]))
    test_data = datasets.ImageFolder(args.data_dir + '/val', transform=transforms.Compose([
        # transforms.RandomResizedCrop(64, scale=(0.875, 0.875), ratio=(1., 1.)),
        transforms.CenterCrop(56),
        transforms.ToTensor(),
        normalize]))
    train_data = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                             pin_memory=True, num_workers=min(multiprocessing.cpu_count(), 4))
    test_data = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size // 5,
                                            pin_memory=True, num_workers=min(multiprocessing.cpu_count(), 4))
    train_data.mean = test_data.mean = torch.tensor([0.4802, 0.4481, 0.3975])
    train_data.std = test_data.std = torch.tensor([0.2302, 0.2265, 0.2262])

    ## Step 3: Wrap the model with auto_LiRPA.
    # The second parameter, dummy_input, is used to construct the trace of the computational graph.
    model = BoundedModule(model_ori, dummy_input, bound_opts={'relu': args.bound_opts}, device=args.device)
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori),
                               (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts={'relu': args.bound_opts, 'loss_fusion': True},
                               device=args.device)
    model_loss = BoundDataParallel(model_loss)

    ## Step 4: Prepare the optimizer, epsilon scheduler and learning rate scheduler.
    opt = optim.Adam(model_loss.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=0.1)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    logger.log(str(model_ori))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    # Skip epochs when resuming from a checkpoint.
    if epoch > 0:
        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    ## Step 5: Start training.
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None, 'IBP',
                  loss_fusion=False, final_node_name=None)
    else:
        timer = 0.0
        best_err = 1e10
        # with torch.autograd.detect_anomaly():
        for t in range(epoch + 1, args.num_epochs + 1):
            logger.log("Epoch {}, learning rate {}".format(t, lr_scheduler.get_last_lr()))
            start_time = time.time()
            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type, loss_fusion=True)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))
            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # Remove the 'model.' prefix (added by CrossEntropyWrapper) from the state_dict.
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            for name in state_dict_loss:
                assert name.startswith('model.')
                state_dict[name[6:]] = state_dict_loss[name]

            with torch.no_grad():
                if int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']) > t >= int(eps_scheduler.params['start']):
                    m = Train(model_loss, t, test_data, eps_scheduler, norm, False, None,
                              args.bound_type, loss_fusion=True)
                else:
                    model_ori.load_state_dict(state_dict)
                    model = BoundedModule(model_ori, dummy_input,
                                          bound_opts={'relu': args.bound_opts}, device=args.device)
                    model = BoundDataParallel(model)
                    m = Train(model, t, test_data, eps_scheduler, norm, False, None, 'IBP', loss_fusion=False)
                    del model

            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}
            if t < int(eps_scheduler.params['start']):
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):
                current_err = m.avg('Verified_Err')
                if current_err < best_err:
                    best_err = current_err
                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_err)[:6])
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()
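# --- Illustrative sketch (not part of the example above) ---
# With loss fusion, bounds are computed on (a monotonic transform of) the
# cross-entropy loss itself rather than on every logit margin, which is what
# makes certified training of a larger network like this TinyImageNet model
# tractable. CrossEntropyWrapper is provided by auto_LiRPA; the simplified
# module below is only an illustration of the idea, not the library's
# implementation: fold the label into the graph and output
# sum_j exp(y_j - y_true), so that bounding this single scalar per example
# bounds the loss log(sum_j exp(y_j - y_true)).
import torch
import torch.nn as nn

class LossFusionWrapperSketch(nn.Module):
    """Simplified, hypothetical stand-in for auto_LiRPA's CrossEntropyWrapper."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, labels):
        y = self.model(x)
        # Subtract the true-class logit from every logit, then sum the exponentials.
        y = y - torch.gather(y, dim=-1, index=labels.unsqueeze(-1))
        return torch.exp(y).sum(dim=-1, keepdim=True)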
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in
    ## models/example_feedforward.py and models/example_resnet.py.
    if args.data == 'MNIST':
        model_ori = models.Models[args.model](in_ch=1, in_dim=28)
    else:
        model_ori = models.Models[args.model]()
    epoch = 0
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = None
        try:
            opt_state = checkpoint['optimizer']
        except KeyError:
            print('no opt_state found')
        for k, v in state_dict.items():
            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(v).any().cpu().numpy() == 0
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare the dataset as usual.
    if args.data == 'MNIST':
        dummy_input = torch.randn(1, 1, 28, 28)
        train_data = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
        test_data = datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor())
    elif args.data == 'CIFAR':
        dummy_input = torch.randn(1, 3, 32, 32)
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        train_data = datasets.CIFAR10("./data", train=True, download=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4, padding_mode='edge'),
            transforms.ToTensor(),
            normalize]))
        test_data = datasets.CIFAR10("./data", train=False, download=True,
                                     transform=transforms.Compose([transforms.ToTensor(), normalize]))
    train_data = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                             pin_memory=True, num_workers=min(multiprocessing.cpu_count(), 4))
    test_data = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size // 2,
                                            pin_memory=True, num_workers=min(multiprocessing.cpu_count(), 4))
    if args.data == 'MNIST':
        train_data.mean = test_data.mean = torch.tensor([0.0])
        train_data.std = test_data.std = torch.tensor([1.0])
    elif args.data == 'CIFAR':
        train_data.mean = test_data.mean = torch.tensor([0.4914, 0.4822, 0.4465])
        train_data.std = test_data.std = torch.tensor([0.2023, 0.1994, 0.2010])

    ## Step 3: Wrap the model with auto_LiRPA.
    # The second parameter, dummy_input, is used to construct the trace of the computational graph.
    model = BoundedModule(model_ori, dummy_input, bound_opts={'relu': args.bound_opts}, device=args.device)
    final_name1 = model.final_name
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori),
                               (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts={'relu': args.bound_opts, 'loss_fusion': True},
                               device=args.device)
    # After CrossEntropyWrapper, the final node name changes because of the
    # additional input node introduced by CrossEntropyWrapper.
    final_name2 = model_loss._modules[final_name1].output_name[0]
    assert type(model._modules[final_name1]) == type(model_loss._modules[final_name2])
    if args.no_loss_fusion:
        model_loss = BoundedModule(model_ori, dummy_input, bound_opts={'relu': args.bound_opts}, device=args.device)
        final_name2 = None
    model_loss = BoundDataParallel(model_loss)

    macs, params = profile(model_ori, (dummy_input.cuda(),))
    logger.log('macs: {}, params: {}'.format(macs, params))

    ## Step 4: Prepare the optimizer, epsilon scheduler and learning rate scheduler.
    opt = optim.Adam(model_loss.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=0.1)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    logger.log(str(model_ori))

    # Skip epochs when resuming from a checkpoint.
    if epoch > 0:
        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    ## Step 5: Start training.
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None, 'IBP',
                  loss_fusion=False, final_node_name=None)
    else:
        timer = 0.0
        best_acc = 1e10
        # with torch.autograd.detect_anomaly():
        for t in range(epoch + 1, args.num_epochs + 1):
            logger.log("Epoch {}, learning rate {}".format(t, lr_scheduler.get_last_lr()))
            start_time = time.time()
            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type,
                  loss_fusion=not args.no_loss_fusion)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))
            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # Remove the 'model.' prefix (added by CrossEntropyWrapper) from the state_dict.
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            if not args.no_loss_fusion:
                for name in state_dict_loss:
                    assert name.startswith('model.')
                    state_dict[name[6:]] = state_dict_loss[name]
            else:
                state_dict = state_dict_loss

            with torch.no_grad():
                if t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):
                    m = Train(model_loss, t, test_data, FixedScheduler(8. / 255), norm, False, None,
                              'IBP', loss_fusion=False, final_node_name=final_name2)
                else:
                    m = Train(model_loss, t, test_data, eps_scheduler, norm, False, None,
                              'IBP', loss_fusion=False, final_node_name=final_name2)

            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}
            if t < int(eps_scheduler.params['start']):
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):
                current_acc = m.avg('Verified_Err')
                if current_acc < best_acc:
                    best_acc = current_acc
                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_acc)[:6])
                else:
                    torch.save(save_dict, 'saved_models/' + exp_name)
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()
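# --- Illustrative sketch (not part of the example above) ---
# The checkpoint-resume logic above replays eps_scheduler epoch/batch steps so
# that training continues at the correct perturbation size. The schedulers used
# in these scripts (FixedScheduler, SmoothedScheduler, etc.) are provided by
# auto_LiRPA; the minimal linear ramp below only illustrates the interface the
# training loops rely on (set_epoch_length, train, step_epoch, step_batch,
# get_eps, reached_max_eps, params['start'/'length']) and is not the library
# implementation.
class LinearEpsSchedulerSketch:
    def __init__(self, max_eps, start=2, length=10):
        # 'start' and 'length' play the role of eps_scheduler.params['start'/'length'] above.
        self.params = {'start': start, 'length': length}
        self.max_eps = max_eps
        self.epoch_length = 1   # batches per epoch, set via set_epoch_length()
        self.epoch = 0
        self.batch = 0

    def set_epoch_length(self, epoch_length):
        self.epoch_length = epoch_length

    def train(self):
        pass  # the real schedulers distinguish train/eval mode

    def step_epoch(self, verbose=False):
        self.epoch += 1
        self.batch = 0

    def step_batch(self):
        self.batch += 1

    def get_eps(self):
        # Zero eps before 'start' epochs, linear ramp over 'length' epochs, then max_eps.
        progress = self.epoch - 1 + self.batch / self.epoch_length - self.params['start']
        frac = min(max(progress / self.params['length'], 0.0), 1.0)
        return frac * self.max_eps

    def reached_max_eps(self):
        return self.get_eps() >= self.max_eps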
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Load the model with BoundedParameter for weight perturbation.
    model_ori = models.Models['mlp_3layer_weight_perturb']()
    epoch = 0

    ## Load a checkpoint, if requested.
    if args.load:
        checkpoint = torch.load(args.load)
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        opt_state = None
        try:
            opt_state = checkpoint['optimizer']
        except KeyError:
            print('no opt_state found')
        for k, v in state_dict.items():
            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(v).any().cpu().numpy() == 0
        model_ori.load_state_dict(state_dict)
        logger.log('Checkpoint loaded: {}'.format(args.load))

    ## Step 2: Prepare the dataset as usual.
    dummy_input = torch.randn(1, 1, 28, 28)
    train_data, test_data = mnist_loaders(datasets.MNIST, batch_size=args.batch_size, ratio=args.ratio)
    train_data.mean = test_data.mean = torch.tensor([0.0])
    train_data.std = test_data.std = torch.tensor([1.0])

    ## Step 3: Wrap the model with auto_LiRPA.
    # The second parameter, dummy_input, is used to construct the trace of the computational graph.
    model = BoundedModule(model_ori, dummy_input, bound_opts={'relu': args.bound_opts}, device=args.device)
    final_name1 = model.final_name
    model_loss = BoundedModule(CrossEntropyWrapper(model_ori),
                               (dummy_input, torch.zeros(1, dtype=torch.long)),
                               bound_opts={'relu': args.bound_opts, 'loss_fusion': True},
                               device=args.device)
    # After CrossEntropyWrapper, the final node name changes because of the
    # additional input node introduced by CrossEntropyWrapper.
    final_name2 = model_loss._modules[final_name1].output_name[0]
    assert type(model._modules[final_name1]) == type(model_loss._modules[final_name2])
    if args.multigpu:
        model_loss = BoundDataParallel(model_loss)
    model_loss.ptb = model.ptb = model_ori.ptb  # Perturbation on the parameters.

    ## Step 4: Prepare the optimizer, epsilon scheduler and learning rate scheduler.
    if args.opt == 'ADAM':
        opt = optim.Adam(model_loss.parameters(), lr=args.lr, weight_decay=0.01)
    elif args.opt == 'SGD':
        opt = optim.SGD(model_loss.parameters(), lr=args.lr, weight_decay=0.01)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=0.1)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    logger.log(str(model_ori))

    # Skip epochs if we continue training from a checkpoint.
    if epoch > 0:
        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)
        eps_scheduler.set_epoch_length(epoch_length)
        eps_scheduler.train()
        for i in range(epoch):
            lr_scheduler.step()
            eps_scheduler.step_epoch(verbose=True)
            for j in range(epoch_length):
                eps_scheduler.step_batch()
        logger.log('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))

    if args.load:
        if opt_state:
            opt.load_state_dict(opt_state)
            logger.log('resume opt_state')

    ## Step 5: Start training.
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None, 'CROWN-IBP',
                  loss_fusion=False, final_node_name=None)
    else:
        timer = 0.0
        best_loss = 1e10
        # Main training loop.
        for t in range(epoch + 1, args.num_epochs + 1):
            logger.log("Epoch {}, learning rate {}".format(t, lr_scheduler.get_last_lr()))
            start_time = time.time()
            # Train for one epoch.
            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type, loss_fusion=True)
            lr_scheduler.step()
            epoch_time = time.time() - start_time
            timer += epoch_time
            logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))
            logger.log("Evaluating...")
            torch.cuda.empty_cache()

            # Remove the 'model.' prefix from the state_dict (a workaround for saving models so far...).
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            for name in state_dict_loss:
                assert name.startswith('model.')
                state_dict[name[6:]] = state_dict_loss[name]

            # Evaluate on the test set.
            with torch.no_grad():
                m = Train(model_loss, t, test_data, eps_scheduler, norm, False, None, args.bound_type,
                          loss_fusion=False, final_node_name=final_name2)

            # Save checkpoints.
            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}
            if not os.path.exists('saved_models'):
                os.mkdir('saved_models')
            if t < int(eps_scheduler.params['start']):
                torch.save(save_dict, 'saved_models/natural_' + exp_name)
            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):
                current_loss = m.avg('Loss')
                if current_loss < best_loss:
                    best_loss = current_loss
                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_loss)[:6])
                else:
                    torch.save(save_dict, 'saved_models/' + exp_name)
            else:
                torch.save(save_dict, 'saved_models/' + exp_name)
            torch.cuda.empty_cache()
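# --- Illustrative sketch (not part of the example above) ---
# The 'Verified_Err' / 'Loss' statistics read via m.avg(...) above are
# accumulated inside Train(). A common way to compute the verified error, shown
# here as a hypothetical helper (not the repository's Train()), is to pass a
# specification matrix C to compute_bounds so that bounds are computed directly
# on the margins y_true - y_j; a sample is certified when every margin lower
# bound is positive.
import torch
import torch.nn.functional as F

def verified_error_sketch(bounded_model, x_bounded, labels, num_classes=10):
    batch_size = labels.size(0)
    true_onehot = F.one_hot(labels, num_classes).float()
    # C[i, j, :] = e_{true_i} - e_j, so that C @ logits yields the margins of sample i.
    C = true_onehot.unsqueeze(1) - torch.eye(num_classes).unsqueeze(0).repeat(batch_size, 1, 1)
    lb, _ = bounded_model.compute_bounds(x=(x_bounded,), C=C, IBP=True, method=None)
    # The row for the true class is identically zero; ignore it when checking margins.
    wrong_class = ~F.one_hot(labels, num_classes).bool()
    margins = lb[wrong_class].view(batch_size, num_classes - 1)
    certified = (margins > 0).all(dim=1)
    return 1.0 - certified.float().mean().item()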
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    ## Step 1: Initialize the original model as usual; see model details in
    ## models/example_feedforward.py and models/example_resnet.py.
    model_ori = PointNet(number_points=args.num_points, num_classes=40, pool_function=args.pooling)
    if args.load:
        state_dict = torch.load(args.load)
        model_ori.load_state_dict(state_dict)
        print(state_dict)

    ## Step 2: Prepare the dataset as usual.
    train_data = datasets.modelnet40(num_points=args.num_points, split='train', rotate='z')
    test_data = datasets.modelnet40(num_points=args.num_points, split='test', rotate='none')
    train_data = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_data = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False, num_workers=4)
    dummy_input = torch.randn(2, args.num_points, 3)

    ## Step 3: Wrap the model with auto_LiRPA.
    # The second parameter, dummy_input, is used to construct the trace of the computational graph.
    model = BoundedModule(model_ori, dummy_input,
                          bound_opts={'relu': args.bound_opts, 'conv_mode': args.conv_mode},
                          device=args.device)

    ## Step 4: Prepare the optimizer, epsilon scheduler and learning rate scheduler.
    opt = optim.Adam(model.parameters(), lr=args.lr)
    norm = float(args.norm)
    lr_scheduler = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)
    print("Model structure: \n", str(model_ori))

    ## Step 5: Start training.
    if args.verify:
        eps_scheduler = FixedScheduler(args.eps)
        with torch.no_grad():
            Train(model, 1, test_data, eps_scheduler, norm, False, None, args.bound_type)
    else:
        timer = 0.0
        for t in range(1, args.num_epochs + 1):
            if eps_scheduler.reached_max_eps():
                # Only decay the learning rate after reaching the maximum eps.
                lr_scheduler.step()
            print("Epoch {}, learning rate {}".format(t, lr_scheduler.get_lr()))
            start_time = time.time()
            Train(model, t, train_data, eps_scheduler, norm, True, opt, args.bound_type)
            epoch_time = time.time() - start_time
            timer += epoch_time
            print('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))
            print("Evaluating...")
            with torch.no_grad():
                Train(model, t, test_data, eps_scheduler, norm, False, None, args.bound_type)
            torch.save(model.state_dict(), args.save_model if args.save_model != "" else args.model)
class RobustDeterministicActorCriticNet(nn.Module, BaseNet):
    def __init__(self, state_dim, action_dim, actor_network, critic_network,
                 mini_batch_size, actor_opt_fn, critic_opt_fn, robust_params=None):
        super(RobustDeterministicActorCriticNet, self).__init__()
        if robust_params is None:
            robust_params = {}
        # Use loss fusion to reduce the complexity of convex relaxation. Default is False.
        self.use_loss_fusion = robust_params.get('use_loss_fusion', False)
        self.use_full_backward = robust_params.get('use_full_backward', False)
        if self.use_loss_fusion:
            # Use auto_LiRPA to compute the L2 norm directly.
            self.fc_action = model_mlp_any_with_loss(state_dim, actor_network, action_dim)
            modules = self.fc_action._modules
            # auto_LiRPA wrapper
            self.fc_action = BoundedModule(
                self.fc_action,
                (torch.empty(size=(1, state_dim)), torch.empty(size=(1, action_dim))),
                device=Config.DEVICE)
            # self.fc_action._modules = modules
            # Find the tanh neuron in the computational graph.
            for n in self.fc_action.nodes:
                if isinstance(n, BoundTanh):
                    self.fc_action_after_tanh = n
                    self.fc_action_pre_tanh = n.inputs[0]
                    break
        else:
            # Fully connected layers with [state_dim, 400, 300, action_dim] neurons and ReLU activations.
            self.fc_action = model_mlp_any(state_dim, actor_network, action_dim)
            # auto_LiRPA wrapper
            self.fc_action = BoundedModule(
                self.fc_action, (torch.empty(size=(1, state_dim)),), device=Config.DEVICE)

        # Fully connected layers with [state_dim + action_dim, 400, 300, 1] neurons.
        self.fc_critic = model_mlp_any(state_dim + action_dim, critic_network, 1)
        # auto_LiRPA wrapper
        self.fc_critic = BoundedModule(
            self.fc_critic, (torch.empty(size=(1, state_dim + action_dim)),), device=Config.DEVICE)

        self.actor_params = self.fc_action.parameters()
        self.critic_params = self.fc_critic.parameters()
        self.actor_opt = actor_opt_fn(self.actor_params)
        self.critic_opt = critic_opt_fn(self.critic_params)
        self.to(Config.DEVICE)
        # Create identity specification matrices.
        self.actor_identity = torch.eye(action_dim).repeat(mini_batch_size, 1, 1).to(Config.DEVICE)
        self.critic_identity = torch.eye(1).repeat(mini_batch_size, 1, 1).to(Config.DEVICE)
        self.action_dim = action_dim
        self.state_dim = state_dim

    def forward(self, obs):
        phi = self.feature(obs)
        action = self.actor(phi)
        return action

    def feature(self, obs):
        # Not used; originally this was a feature extraction network.
        return tensor(obs)

    def actor(self, phi):
        if self.use_loss_fusion:
            self.fc_action(phi, torch.zeros(size=phi.size()[:1] + (self.action_dim,), device=Config.DEVICE))
            return self.fc_action_after_tanh.forward_value
        else:
            return torch.tanh(self.fc_action(phi, method_opt="forward"))

    # Obtain element-wise lower and upper bounds of the actor network through convex relaxation.
    def actor_bound(self, phi_lb, phi_ub, beta=1.0, eps=None, norm=np.inf, upper=True, lower=True,
                    phi=None, center=None):
        if self.use_loss_fusion:
            # Use loss fusion (not typically enabled).
            assert center is not None
            ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
            x = BoundedTensor(phi, ptb)
            val = self.fc_action(x, center.detach())
            ilb, iub = self.fc_action.compute_bounds(IBP=True, method=None)
            if beta > 1e-10:
                clb, cub = self.fc_action.compute_bounds(IBP=False, method="backward",
                                                         bound_lower=False, bound_upper=True)
                ub = cub * beta + iub * (1.0 - beta)
                return ub
            else:
                return iub
        else:
            assert center is None
            # Invoke auto_LiRPA for convex relaxation.
            ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
            x = BoundedTensor(phi, ptb)
            if self.use_full_backward:
                clb, cub = self.fc_action.compute_bounds(x=(x,), IBP=False, method="backward")
                return cub, clb
            else:
                ilb, iub = self.fc_action.compute_bounds(x=(x,), IBP=True, method=None)
                if beta > 1e-10:
                    clb, cub = self.fc_action.compute_bounds(IBP=False, method="backward")
                    ub = cub * beta + iub * (1.0 - beta)
                    lb = clb * beta + ilb * (1.0 - beta)
                    return ub, lb
                else:
                    return iub, ilb

    def critic(self, phi, a):
        return self.fc_critic(torch.cat([phi, a], dim=1), method_opt="forward")

    # Obtain element-wise lower and upper bounds of the critic network through convex relaxation.
    def critic_bound(self, phi_lb, phi_ub, a_lb, a_ub, beta=1.0, eps=None, phi=None, action=None,
                     norm=np.inf, upper=True, lower=True):
        x_L = torch.cat([phi_lb, a_lb], dim=1)
        x_U = torch.cat([phi_ub, a_ub], dim=1)
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=x_L, x_U=x_U)
        x = BoundedTensor(torch.cat([phi, action], dim=1), ptb)
        ilb, iub = self.fc_critic.compute_bounds(x=(x,), IBP=True, method=None)
        if beta > 1e-10:
            clb, cub = self.fc_critic.compute_bounds(IBP=False, method="backward")
            ub = cub * beta + iub * (1.0 - beta)
            lb = clb * beta + ilb * (1.0 - beta)
            return ub, lb
        else:
            return iub, ilb

    def load_state_dict(self, state_dict, strict=True):
        action_dict = OrderedDict()
        critic_dict = OrderedDict()
        for k in state_dict.keys():
            if 'action' in k:
                pos = k.find('.') + 1
                action_dict[k[pos:]] = state_dict[k]
            if 'critic' in k:
                pos = k.find('.') + 1
                critic_dict[k[pos:]] = state_dict[k]
        # Load the actor and critic networks separately; this is required for auto_LiRPA.
        self.fc_action.load_state_dict(action_dict)
        self.fc_critic.load_state_dict(critic_dict)

    def state_dict(self):
        # Save the actor and critic networks separately; this is required for auto_LiRPA.
        action_state_dict = self.fc_action.state_dict()
        critic_state_dict = self.fc_critic.state_dict()
        network_state_dict = OrderedDict()
        for k, v in action_state_dict.items():
            network_state_dict["fc_action." + k] = v
        for k, v in critic_state_dict.items():
            network_state_dict["fc_critic." + k] = v
        return network_state_dict
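# --- Illustrative sketch (not part of the class above) ---
# actor_bound() returns element-wise lower/upper bounds on the pre-tanh actor
# output when the observed state may vary within an L_inf ball. A typical use
# during robust RL training (the helper name and weighting below are
# hypothetical, not taken from this repository) is to penalize the worst-case
# change of the action inside that ball, encouraging a policy that is smooth
# under state perturbations.
import numpy as np
import torch

def actor_smoothness_penalty(net, states, eps=0.05, beta=1.0, kappa=0.1):
    # Bounds of the input region around the observed states.
    phi_lb, phi_ub = states - eps, states + eps
    ub, lb = net.actor_bound(phi_lb, phi_ub, beta=beta, eps=eps, norm=np.inf, phi=states)
    # Bound the post-tanh actions using the monotonicity of tanh, then take the
    # largest possible per-dimension action change as the penalty.
    action_gap = torch.tanh(ub) - torch.tanh(lb)
    return kappa * action_gap.clamp(min=0).mean()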