def actor_bound(self, phi_lb, phi_ub, beta=1.0, eps=None, norm=np.inf, upper=True, lower=True, phi=None, center=None):
    if self.use_loss_fusion:  # Use loss fusion (not typically enabled).
        assert center is not None
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
        x = BoundedTensor(phi, ptb)
        val = self.fc_action(x, center.detach())
        ilb, iub = self.fc_action.compute_bounds(IBP=True, method=None)
        if beta > 1e-10:
            clb, cub = self.fc_action.compute_bounds(IBP=False, method="backward", bound_lower=False, bound_upper=True)
            ub = cub * beta + iub * (1.0 - beta)
            return ub
        else:
            return iub
    else:
        assert center is None
        # Invoke auto_LiRPA for convex relaxation.
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
        x = BoundedTensor(phi, ptb)
        if self.use_full_backward:
            clb, cub = self.fc_action.compute_bounds(x=(x,), IBP=False, method="backward")
            return cub, clb
        else:
            ilb, iub = self.fc_action.compute_bounds(x=(x,), IBP=True, method=None)
            if beta > 1e-10:
                clb, cub = self.fc_action.compute_bounds(IBP=False, method="backward")
                ub = cub * beta + iub * (1.0 - beta)
                lb = clb * beta + ilb * (1.0 - beta)
                return ub, lb
            else:
                return iub, ilb
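# A minimal standalone sketch (hypothetical toy model and sizes, not part of the original
# code) of the IBP/backward interpolation used in actor_bound: compute_bounds is called once
# with IBP and once with the backward (CROWN) method, and the two bounds are mixed by beta.
import numpy as np
import torch
import torch.nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

toy_actor = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
phi = torch.randn(4, 8)                       # hypothetical batch of state features
fc_action = BoundedModule(toy_actor, (phi,))  # wrap once; reused for all bound queries

beta, eps = 0.5, 0.1
x = BoundedTensor(phi, PerturbationLpNorm(norm=np.inf, eps=eps))
ilb, iub = fc_action.compute_bounds(x=(x,), IBP=True, method=None)         # interval bounds
clb, cub = fc_action.compute_bounds(x=(x,), IBP=False, method="backward")  # CROWN bounds
ub = cub * beta + iub * (1.0 - beta)
lb = clb * beta + ilb * (1.0 - beta)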
def build_perturbation():
    if args.ptb == "lp_norm":
        return PerturbationLpNorm(norm=np.inf, eps=args.eps)
    elif args.ptb == "synonym":
        return PerturbationSynonym(budget=args.budget)
    else:
        raise NotImplementedError
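# Hypothetical usage of build_perturbation() for the lp_norm case: attach the returned
# perturbation to the model input wrapped as a BoundedTensor. `bound_model` and
# `embeddings` are placeholder names, not from the original code.
ptb = build_perturbation()
x = BoundedTensor(embeddings, ptb)
prediction = bound_model(x)
lb, ub = bound_model.compute_bounds(x=(x,), IBP=True, method=None)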
def get_logits_lower_bound(model, state, state_ub, state_lb, eps, C, beta):
    ptb = PerturbationLpNorm(norm=np.inf, eps=eps, x_L=state_lb, x_U=state_ub)
    bnd_state = BoundedTensor(state, ptb)
    pred = model.features(bnd_state, method_opt="forward")
    logits_ilb, _ = model.features.compute_bounds(C=C, IBP=True, method=None)
    if beta < 1e-5:
        logits_lb = logits_ilb
    else:
        logits_clb, _ = model.features.compute_bounds(IBP=False, C=C, method="backward", bound_upper=False)
        logits_lb = beta * logits_clb + (1 - beta) * logits_ilb
    return logits_lb
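# Sketch (assumed shapes and names) of building the specification matrix C and calling
# get_logits_lower_bound. Row j of C[i] encodes "logit of the chosen action minus logit of
# action j", so a positive lower bound certifies that no perturbation within eps can make
# another action outrank the chosen one. `model`, `state` and `actions` are placeholders.
num_actions = 3
eye = torch.eye(num_actions).type_as(state)
C = eye[actions].unsqueeze(1) - eye.unsqueeze(0)
I = ~(actions.unsqueeze(1) == torch.arange(num_actions).type_as(actions).unsqueeze(0))
C = C[I].view(state.size(0), num_actions - 1, num_actions)

eps = 0.05
state_lb, state_ub = state - eps, state + eps
logits_lb = get_logits_lower_bound(model, state, state_ub, state_lb, eps, C, beta=0.5)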
def critic_bound(self, phi_lb, phi_ub, a_lb, a_ub, beta=1.0, eps=None, phi=None, action=None, norm=np.inf, upper=True, lower=True):
    x_L = torch.cat([phi_lb, a_lb], dim=1)
    x_U = torch.cat([phi_ub, a_ub], dim=1)
    ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=x_L, x_U=x_U)
    x = BoundedTensor(torch.cat([phi, action], dim=1), ptb)
    ilb, iub = self.fc_critic.compute_bounds(x=(x,), IBP=True, method=None)
    if beta > 1e-10:
        clb, cub = self.fc_critic.compute_bounds(IBP=False, method="backward")
        ub = cub * beta + iub * (1.0 - beta)
        lb = clb * beta + ilb * (1.0 - beta)
        return ub, lb
    else:
        return iub, ilb
import numpy as np
import torch
import torch.nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm


class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.features = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 3))

    def forward(self, input):
        return self.features(input)


num_actions = 3
batchsize = 5
input_vec = torch.randn(batchsize, 5)
label = torch.tensor([0, 2, 1, 1, 0])

raw_model = mynet()
bound_model = BoundedModule(raw_model, input_vec)
bnd_state = BoundedTensor(input_vec, PerturbationLpNorm(norm=np.inf, eps=0.1))

# Specification matrix: each row is (logit of the labeled class) - (logit of another class).
c = torch.eye(3).type_as(input_vec)[label].unsqueeze(1) - torch.eye(3).type_as(input_vec).unsqueeze(0)
I = (~(label.data.unsqueeze(1) == torch.arange(3).type_as(label.data).unsqueeze(0)))
c = (c[I].view(input_vec.size(0), 2, 3))

pred = bound_model(bnd_state)
basic_bound, _ = bound_model.compute_bounds(IBP=False, method='backward')
advance_bound, _ = bound_model.compute_bounds(C=c, IBP=False, method='backward')
print(basic_bound.detach().numpy())
print(advance_bound.detach().numpy())
    return acc.detach(), acc_robust.detach(), loss.detach()

data_train, data_test = load_data()
logger.info("Dataset sizes: {}/{}".format(len(data_train), len(data_test)))

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

model = LSTM(args).to(args.device)
test_batches = get_batches(data_test, args.batch_size)
X, y = model.get_input(test_batches[0])
model.core = BoundGeneral(model.core, (X,))
ptb = PerturbationLpNorm(norm=args.norm, eps=args.eps)
optimizer = model.build_optimizer()
avg_acc, avg_acc_robust, avg_loss = avg = [AverageMeter() for i in range(3)]

def train(epoch):
    model.train()
    train_batches = get_batches(data_train, args.batch_size)
    for a in avg:
        a.reset()
    eps_inc_per_step = 1.0 / (args.num_epochs_warmup * len(train_batches))
    for i, batch in enumerate(train_batches):
        eps = args.eps * min(eps_inc_per_step * ((epoch - 1) * len(train_batches) + i + 1), 1.0)
        acc, acc_robust, loss = res = step(model,
def sarsa_steps(self, saps):
    # Begin advanced logging code
    assert saps.unrolled
    loss = torch.nn.SmoothL1Loss()
    action_std = torch.exp(self.policy_model.log_stdev).detach().requires_grad_(False)  # Avoid backprop twice.
    # We treat all value epochs as one epoch.
    self.sarsa_eps_scheduler.set_epoch_length(self.params.VAL_EPOCHS * self.params.NUM_MINIBATCHES)
    self.sarsa_beta_scheduler.set_epoch_length(self.params.VAL_EPOCHS * self.params.NUM_MINIBATCHES)
    # We count from 1.
    self.sarsa_eps_scheduler.step_epoch()
    self.sarsa_beta_scheduler.step_epoch()
    # saps contains state->action->reward and not_done.
    for i in range(self.params.VAL_EPOCHS):
        # Create minibatches with shuffling.
        state_indices = np.arange(saps.rewards.nelement())
        np.random.shuffle(state_indices)
        splits = np.array_split(state_indices, self.params.NUM_MINIBATCHES)
        # Minibatch SGD
        for selected in splits:
            def sel(*args):
                return [v[selected] for v in args]
            self.sarsa_opt.zero_grad()
            sel_states, sel_actions, sel_rewards, sel_not_dones = sel(saps.states, saps.actions, saps.rewards, saps.not_dones)
            self.sarsa_eps_scheduler.step_batch()
            self.sarsa_beta_scheduler.step_batch()
            inputs = torch.cat((sel_states, sel_actions), dim=1)
            # action_diff = self.sarsa_eps_scheduler.get_eps() * action_std
            # inputs_lb = torch.cat((sel_states, sel_actions - action_diff), dim=1).detach().requires_grad_(False)
            # inputs_ub = torch.cat((sel_states, sel_actions + action_diff), dim=1).detach().requires_grad_(False)
            # bounded_inputs = BoundedTensor(inputs, ptb=PerturbationLpNorm(norm=np.inf, eps=None, x_L=inputs_lb, x_U=inputs_ub))
            bounded_inputs = BoundedTensor(inputs, ptb=PerturbationLpNorm(norm=np.inf, eps=self.sarsa_eps_scheduler.get_eps()))
            q = self.relaxed_sarsa_model(bounded_inputs).squeeze(-1)
            q_old = q[:-1]
            q_next = q[1:] * self.GAMMA * sel_not_dones[:-1] + sel_rewards[:-1]
            q_next = q_next.detach()
            # q_loss = (q_old - q_next).pow(2).sum(dim=-1).mean()
            q_loss = loss(q_old, q_next)
            # Compute the robustness regularization.
            if self.sarsa_eps_scheduler.get_eps() > 0 and self.params.SARSA_REG > 0:
                beta = self.sarsa_beta_scheduler.get_eps()
                ilb, iub = self.relaxed_sarsa_model.compute_bounds(IBP=True, method=None)
                if beta < 1:
                    clb, cub = self.relaxed_sarsa_model.compute_bounds(IBP=False, method='backward')
                    lb = beta * ilb + (1 - beta) * clb
                    ub = beta * iub + (1 - beta) * cub
                else:
                    lb = ilb
                    ub = iub
                # Output dimension is 1. Remove the extra dimension and keep only the batch dimension.
                lb = lb.squeeze(-1)
                ub = ub.squeeze(-1)
                diff = torch.max(ub - q, q - lb)
                reg_loss = self.params.SARSA_REG * (diff * diff).mean()
                sarsa_loss = q_loss + reg_loss
                reg_loss = reg_loss.item()
            else:
                reg_loss = 0.0
                sarsa_loss = q_loss
            sarsa_loss.backward()
            self.sarsa_opt.step()
        print(f'q_loss={q_loss.item():.6g}, reg_loss={reg_loss:.6g}, sarsa_loss={sarsa_loss.item():.6g}')
    if self.ANNEAL_LR:
        self.sarsa_scheduler.step()
    # print('value:', self.val_model(saps.states).mean().item())
    return q_loss, q.mean()
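# Hypothetical setup (e.g., in the trainer's initialization, not shown in the original code)
# for the relaxed model used above as self.relaxed_sarsa_model: the SARSA Q network wrapped
# with auto_LiRPA so that the compute_bounds() calls have a bound-propagation graph to work
# on. `sarsa_q_net`, `state_dim` and `action_dim` are placeholders, not from the original code.
from auto_LiRPA import BoundedModule

dummy_q_input = torch.randn(1, state_dim + action_dim)
relaxed_sarsa_model = BoundedModule(sarsa_q_net, (dummy_q_input,))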