Exemple #1
0
 def actor_bound(
     self, phi_lb, phi_ub, beta=1.0, eps=None, norm=np.inf,
     upper=True, lower=True, phi=None, center=None,
 ):
     """Bound the output of ``self.fc_action`` over a perturbed input ``phi``.

     The input region is described by a ``PerturbationLpNorm`` built from
     ``norm``, ``eps`` and the element-wise limits ``phi_lb`` / ``phi_ub``.
     Interval (IBP) bounds and backward-mode bounds are blended as
     ``backward * beta + ibp * (1 - beta)``; when ``beta`` is not greater
     than ``1e-10`` only the IBP bounds are used.

     Args:
         phi_lb: Lower limit of the input region (passed as ``x_L``).
         phi_ub: Upper limit of the input region (passed as ``x_U``).
         beta: Weight of the backward-mode bounds in the blend.
         eps: Perturbation radius handed to ``PerturbationLpNorm``.
         norm: Norm of the perturbation.
         upper: Not read by this method.
         lower: Not read by this method.
         phi: Nominal input wrapped in the ``BoundedTensor``.
         center: Required when ``self.use_loss_fusion`` is set and must be
             ``None`` otherwise; it is detached and passed as the second
             argument of the forward call.

     Returns:
         With loss fusion: the upper bound only.
         Otherwise: the tuple ``(upper_bound, lower_bound)``.
     """
     if not self.use_loss_fusion:
         assert center is None
         # Invoke auto_LiRPA for convex relaxation.
         perturbation = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
         bounded_phi = BoundedTensor(phi, perturbation)

         if self.use_full_backward:
             # Backward-mode bounds only; no IBP pass and no blending.
             crown_lower, crown_upper = self.fc_action.compute_bounds(
                 x=(bounded_phi,), IBP=False, method="backward")
             return crown_upper, crown_lower

         ibp_lower, ibp_upper = self.fc_action.compute_bounds(
             x=(bounded_phi,), IBP=True, method=None)
         if not (beta > 1e-10):
             return ibp_upper, ibp_lower

         # No ``x`` here: this call follows the IBP call on the same input.
         crown_lower, crown_upper = self.fc_action.compute_bounds(
             IBP=False, method="backward")
         blended_upper = crown_upper * beta + ibp_upper * (1.0 - beta)
         blended_lower = crown_lower * beta + ibp_lower * (1.0 - beta)
         return blended_upper, blended_lower

     # Loss-fusion path (not typically enabled).
     assert center is not None
     perturbation = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
     bounded_phi = BoundedTensor(phi, perturbation)
     # The forward result is discarded; the call is kept because the
     # compute_bounds calls below are made without an explicit input.
     self.fc_action(bounded_phi, center.detach())
     _, ibp_upper = self.fc_action.compute_bounds(IBP=True, method=None)
     if not (beta > 1e-10):
         return ibp_upper

     _, crown_upper = self.fc_action.compute_bounds(
         IBP=False, method="backward", bound_lower=False, bound_upper=True)
     return crown_upper * beta + ibp_upper * (1.0 - beta)
Exemple #2
0
def build_perturbation():
    """Create the perturbation object selected by ``args.ptb``.

    Returns:
        ``PerturbationLpNorm`` with the infinity norm and ``args.eps`` for
        ``"lp_norm"``, or ``PerturbationSynonym`` with ``args.budget`` for
        ``"synonym"``.

    Raises:
        NotImplementedError: For any other value of ``args.ptb``.
    """
    kind = args.ptb
    if kind == "lp_norm":
        return PerturbationLpNorm(norm=np.inf, eps=args.eps)
    if kind == "synonym":
        return PerturbationSynonym(budget=args.budget)
    raise NotImplementedError
Exemple #3
0
def get_logits_lower_bound(model, state, state_ub, state_lb, eps, C, beta):
    """Return a lower bound on ``C``-projected outputs of ``model.features``.

    ``state`` is wrapped in a ``BoundedTensor`` with an infinity-norm
    perturbation of radius ``eps`` limited to ``[state_lb, state_ub]``.
    The IBP lower bound is always computed; when ``beta`` is at least
    ``1e-5`` it is blended with the backward-mode lower bound as
    ``beta * backward + (1 - beta) * ibp``.

    Args:
        model: Object whose ``features`` attribute is callable and exposes
            ``compute_bounds``.
        state: Nominal input.
        state_ub: Upper limit of the input region (passed as ``x_U``).
        state_lb: Lower limit of the input region (passed as ``x_L``).
        eps: Perturbation radius.
        C: Specification matrix forwarded to ``compute_bounds``.
        beta: Weight of the backward-mode bound in the blend.

    Returns:
        The (possibly blended) lower bound.
    """
    bounded_state = BoundedTensor(
        state,
        PerturbationLpNorm(norm=np.inf, eps=eps, x_L=state_lb, x_U=state_ub),
    )
    # Forward pass kept for its effect on the bounded module; the
    # compute_bounds calls below are made without an explicit input.
    model.features(bounded_state, method_opt="forward")

    ibp_lower, _ = model.features.compute_bounds(C=C, IBP=True, method=None)
    if beta < 1e-5:
        return ibp_lower

    crown_lower, _ = model.features.compute_bounds(
        IBP=False, C=C, method="backward", bound_upper=False)
    return beta * crown_lower + (1 - beta) * ibp_lower
Exemple #4
0
 def critic_bound(
     self, phi_lb, phi_ub, a_lb, a_ub, beta=1.0, eps=None,
     phi=None, action=None, norm=np.inf, upper=True, lower=True,
 ):
     """Bound the output of ``self.fc_critic`` over perturbed (phi, action).

     Feature and action limits are concatenated along ``dim=1`` to form
     the input region, and ``phi`` / ``action`` are concatenated the same
     way to form the nominal input. IBP bounds are blended with
     backward-mode bounds as ``backward * beta + ibp * (1 - beta)``; when
     ``beta`` is not greater than ``1e-10`` only the IBP bounds are used.

     Args:
         phi_lb: Lower limit of the feature part of the input.
         phi_ub: Upper limit of the feature part of the input.
         a_lb: Lower limit of the action part of the input.
         a_ub: Upper limit of the action part of the input.
         beta: Weight of the backward-mode bounds in the blend.
         eps: Perturbation radius handed to ``PerturbationLpNorm``.
         phi: Nominal features.
         action: Nominal action.
         norm: Norm of the perturbation.
         upper: Not read by this method.
         lower: Not read by this method.

     Returns:
         The tuple ``(upper_bound, lower_bound)``.
     """
     region_low = torch.cat([phi_lb, a_lb], dim=1)
     region_high = torch.cat([phi_ub, a_ub], dim=1)
     perturbation = PerturbationLpNorm(norm=norm, eps=eps, x_L=region_low, x_U=region_high)
     bounded_input = BoundedTensor(torch.cat([phi, action], dim=1), perturbation)

     ibp_lower, ibp_upper = self.fc_critic.compute_bounds(
         x=(bounded_input,), IBP=True, method=None)
     if not (beta > 1e-10):
         return ibp_upper, ibp_lower

     # No ``x`` here: this call follows the IBP call on the same input.
     crown_lower, crown_upper = self.fc_critic.compute_bounds(
         IBP=False, method="backward")
     blended_upper = crown_upper * beta + ibp_upper * (1.0 - beta)
     blended_lower = crown_lower * beta + ibp_lower * (1.0 - beta)
     return blended_upper, blended_lower
Exemple #5
0
class mynet(nn.Module):
    """Two-layer ReLU network mapping 5 input features to 3 outputs."""

    def __init__(self):
        super().__init__()
        # The container class is ``nn.Sequential`` (capital S), and it is
        # stored as ``features`` because that is the attribute forward() reads.
        self.features = nn.Sequential(
            nn.Linear(5, 10),
            nn.ReLU(),
            nn.Linear(10, 3),
        )

    def forward(self, input):
        """Apply the layer stack to ``input`` and return the result."""
        return self.features(input)


# Wrap the plain network so that output bounds can be computed for it.
raw_model = mynet()
bound_model = BoundedModule(raw_model, input_vec)
num_actions = 3
batchsize = 5
label = torch.tensor([0, 2, 1, 1, 0])
# Input with an L-infinity perturbation of radius 0.1 attached.
bnd_state = BoundedTensor(input_vec, PerturbationLpNorm(norm=np.inf, eps=0.1))

# Specification matrix: for every sample, one row (e_label - e_j) per class
# j, so that C @ logits yields the margins logit[label] - logit[j].
c = torch.eye(num_actions).type_as(input_vec)[label].unsqueeze(1) - torch.eye(
    num_actions).type_as(input_vec).unsqueeze(0)
# Mask that drops the all-zero row where j equals the sample's own label.
I = (~(label.data.unsqueeze(1) == torch.arange(num_actions).type_as(
    label.data).unsqueeze(0)))
c = (c[I].view(input_vec.size(0), num_actions - 1, num_actions))

# Forward the *bounded* tensor: the compute_bounds calls below take no input
# of their own, so forwarding the plain ``input_vec`` would leave the
# eps=0.1 perturbation unused.
pred = bound_model(bnd_state)
# Lower bounds on the raw outputs.
basic_bound, _ = bound_model.compute_bounds(IBP=False, method='backward')
# Lower bounds on the margins selected by ``c``.
advance_bound, _ = bound_model.compute_bounds(C=c,
                                              IBP=False,
                                              method='backward')
print(basic_bound.detach().numpy())
print(advance_bound.detach().numpy())
Exemple #6
0
    return acc.detach(), acc_robust.detach(), loss.detach()


# Load the train/test splits and log how many items each contains.
data_train, data_test = load_data()
logger.info("Dataset sizes: {}/{}".format(len(data_train), len(data_test)))

# Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) generators
# from args.seed before the model is constructed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Build the model, then replace model.core with a BoundGeneral wrapper that
# is constructed from the inputs of the first test batch.
model = LSTM(args).to(args.device)
test_batches = get_batches(data_test, args.batch_size)
X, y = model.get_input(test_batches[0])
model.core = BoundGeneral(model.core, (X, ))
# Perturbation with the norm and radius given on the command-line args.
ptb = PerturbationLpNorm(norm=args.norm, eps=args.eps)
# The optimizer is built after model.core has been wrapped.
optimizer = model.build_optimizer()

# Three meters, reachable both by name and through the list ``avg``.
avg_acc, avg_acc_robust, avg_loss = avg = [AverageMeter() for i in range(3)]


def train(epoch):
    model.train()
    train_batches = get_batches(data_train, args.batch_size)
    for a in avg:
        a.reset()
    eps_inc_per_step = 1.0 / (args.num_epochs_warmup * len(train_batches))
    for i, batch in enumerate(train_batches):
        eps = args.eps * min(
            eps_inc_per_step * ((epoch - 1) * len(train_batches) + i + 1), 1.0)
        acc, acc_robust, loss = res = step(model,
Exemple #7
0
    def sarsa_steps(self, saps):
        """Train the relaxed Sarsa Q-network with a robustness regularizer.

        For each of ``params.VAL_EPOCHS`` passes the transition indices are
        shuffled and split into ``params.NUM_MINIBATCHES`` minibatches. Each
        minibatch concatenates states and actions, wraps them in a
        ``BoundedTensor`` with an L-infinity perturbation whose radius comes
        from ``sarsa_eps_scheduler``, and minimizes a SmoothL1 loss between
        ``q[:-1]`` and the detached target
        ``q[1:] * GAMMA * not_dones[:-1] + rewards[:-1]``.

        When the current eps and ``params.SARSA_REG`` are both positive, a
        regularizer ``SARSA_REG * mean(max(ub - q, q - lb) ** 2)`` is added,
        where ``lb`` / ``ub`` are IBP bounds blended with backward-mode
        bounds using the weight from ``sarsa_beta_scheduler`` (a weight of 1
        or more means IBP only).

        Args:
            saps: Rollout container. ``saps.unrolled`` must be truthy, and
                ``states``, ``actions``, ``rewards`` and ``not_dones`` are
                indexed with an array of indices.

        Returns:
            ``(q_loss, q.mean())`` from the last minibatch processed.
        """
        # Begin advanced logging code
        assert saps.unrolled
        loss = torch.nn.SmoothL1Loss()
        # Only referenced by the commented-out per-action perturbation below.
        action_std = torch.exp(self.policy_model.log_stdev).detach().requires_grad_(False)  # Avoid backprop twice.
        # We treat all value epochs as one epoch.
        self.sarsa_eps_scheduler.set_epoch_length(self.params.VAL_EPOCHS * self.params.NUM_MINIBATCHES)
        self.sarsa_beta_scheduler.set_epoch_length(self.params.VAL_EPOCHS * self.params.NUM_MINIBATCHES)
        # We count from 1.
        self.sarsa_eps_scheduler.step_epoch()
        self.sarsa_beta_scheduler.step_epoch()
        # saps contains state->action->reward and not_done.
        for i in range(self.params.VAL_EPOCHS):
            # Create minibatches with shuffling
            # NOTE(review): after shuffling, adjacent rows of a minibatch are
            # presumably no longer consecutive transitions, yet q[1:] is used
            # below as the successor of q[:-1] -- confirm this is intended.
            state_indices = np.arange(saps.rewards.nelement())
            np.random.shuffle(state_indices)
            splits = np.array_split(state_indices, self.params.NUM_MINIBATCHES)

            # Minibatch SGD
            for selected in splits:
                def sel(*args):
                    # Index every given tensor with this minibatch's indices.
                    return [v[selected] for v in args]

                self.sarsa_opt.zero_grad()
                sel_states, sel_actions, sel_rewards, sel_not_dones = sel(saps.states, saps.actions, saps.rewards, saps.not_dones)
                
                # Advance both schedules once per minibatch.
                self.sarsa_eps_scheduler.step_batch()
                self.sarsa_beta_scheduler.step_batch()
                
                inputs = torch.cat((sel_states, sel_actions), dim=1)
                # action_diff = self.sarsa_eps_scheduler.get_eps() * action_std
                # inputs_lb = torch.cat((sel_states, sel_actions - action_diff), dim=1).detach().requires_grad_(False)
                # inputs_ub = torch.cat((sel_states, sel_actions + action_diff), dim=1).detach().requires_grad_(False)
                # bounded_inputs = BoundedTensor(inputs, ptb=PerturbationLpNorm(norm=np.inf, eps=None, x_L=inputs_lb, x_U=inputs_ub))
                # The same eps is applied to the whole concatenated input
                # (state and action parts alike).
                bounded_inputs = BoundedTensor(inputs, ptb=PerturbationLpNorm(norm=np.inf, eps=self.sarsa_eps_scheduler.get_eps()))

                q = self.relaxed_sarsa_model(bounded_inputs).squeeze(-1)
                q_old = q[:-1]
                # One-step target built from the next row's Q value; detached
                # so gradients only flow through q_old.
                q_next = q[1:] * self.GAMMA * sel_not_dones[:-1] + sel_rewards[:-1]
                q_next = q_next.detach()
                # q_loss = (q_old - q_next).pow(2).sum(dim=-1).mean()
                q_loss = loss(q_old, q_next)
                # Compute the robustness regularization.
                if self.sarsa_eps_scheduler.get_eps() > 0 and self.params.SARSA_REG > 0:
                    beta = self.sarsa_beta_scheduler.get_eps()
                    ilb, iub = self.relaxed_sarsa_model.compute_bounds(IBP=True, method=None)
                    if beta < 1:
                        # Here beta weights the IBP bounds and (1 - beta) the
                        # backward-mode bounds.
                        clb, cub = self.relaxed_sarsa_model.compute_bounds(IBP=False, method='backward')
                        lb = beta * ilb + (1 - beta) * clb
                        ub = beta * iub + (1 - beta) * cub
                    else:
                        lb = ilb
                        ub = iub
                    # Output dimension is 1. Remove the extra dimension and keep only the batch dimension.
                    lb = lb.squeeze(-1)
                    ub = ub.squeeze(-1)
                    # Largest deviation of either bound from the nominal Q.
                    diff = torch.max(ub - q, q - lb)
                    reg_loss = self.params.SARSA_REG * (diff * diff).mean()
                    sarsa_loss = q_loss + reg_loss
                    # Keep only the Python float for logging.
                    reg_loss = reg_loss.item()
                else:
                    reg_loss = 0.0
                    sarsa_loss = q_loss
                sarsa_loss.backward()
                self.sarsa_opt.step()
            # Logged once per epoch, using the values of the last minibatch.
            print(f'q_loss={q_loss.item():.6g}, reg_loss={reg_loss:.6g}, sarsa_loss={sarsa_loss.item():.6g}')

        if self.ANNEAL_LR:
            self.sarsa_scheduler.step()
        # print('value:', self.val_model(saps.states).mean().item())

        # NOTE(review): q_loss and q are only bound inside the loops above,
        # so this relies on at least one minibatch having been processed.
        return q_loss, q.mean()