Ejemplo n.º 1
0
def Train(model,
          t,
          loader,
          eps_scheduler,
          norm,
          train,
          opt,
          bound_type,
          method='robust',
          num_class=10):
    """Run one epoch of training or evaluation with optional robust bounds.

    For every batch the regular cross-entropy and the clean error rate are
    recorded. When the batch is handled with the "robust" method, lower
    bounds on the classification margins are requested from
    ``model.compute_bounds`` and converted into a robust cross-entropy loss.
    Running averages are printed every 50 batches while training and once
    more after the last batch.

    Args:
        model: Model that is callable on a ``BoundedTensor`` and provides
            ``compute_bounds``, ``parameters``, ``train`` and ``eval``.
        t: Epoch index; only used in the printed progress lines.
        loader: Iterable yielding ``(data, labels)`` batches. Its
            ``dataset`` and ``batch_size`` are read when ``train`` is true,
            and its ``mean`` / ``std`` tensors when ``norm`` is ``np.inf``.
        eps_scheduler: Scheduler providing ``train``, ``eval``,
            ``step_epoch``, ``set_epoch_length``, ``step_batch``,
            ``get_eps``, ``get_max_eps`` and ``update_loss``.
        norm: Perturbation norm. Positive values (including ``np.inf``)
            select ``PerturbationLpNorm``; ``0`` selects
            ``PerturbationL0Norm``.
        train: If true, back-propagate and step ``opt``; otherwise only
            evaluate.
        opt: Optimizer; only touched when ``train`` is true.
        bound_type: One of ``"IBP"``, ``"CROWN"``, ``"CROWN-IBP"`` or
            ``"CROWN-FAST"``; only consulted for robust batches.
        method: ``"robust"`` or ``"natural"``. Batches whose eps is below
            1e-20 are always treated as ``"natural"``.
        num_class: Number of output classes (defaults to 10).

    Raises:
        ValueError: If ``method``, ``bound_type`` or ``norm`` is not one of
            the supported values at the point where it is needed.
    """
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(
            int((len(loader.dataset) + loader.batch_size - 1) /
                loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    # Stays at -1 when the loader yields nothing, so the final summary
    # (which needs the last batch index and eps) can be skipped safely.
    i = -1
    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = "natural" if eps < 1e-20 else method
        if batch_method not in ("robust", "natural"):
            raise ValueError(
                "Unsupported method {!r}; expected 'robust' or 'natural'".
                format(batch_method))
        if train:
            opt.zero_grad()
        # generate specifications: row j of c[b] is e_{label_b} - e_j
        eye = torch.eye(num_class).type_as(data)
        c = eye[labels].unsqueeze(1) - eye.unsqueeze(0)
        # remove specifications to self (the all-zero row where j == label)
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
            labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0), num_class - 1, num_class))
        # bound input for Linf norm used only
        if norm == np.inf:
            # [0, 1] range expressed in the normalized data space
            data_max = torch.reshape((1. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1),
                                data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1),
                                data_min)
        else:
            data_ub = data_lb = data

        # Move the batch to the GPU when the model's first parameter is there.
        if next(iter(model.parameters())).is_cuda:
            data, labels, c = data.cuda(), labels.cuda(), c.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        # Specify Lp norm perturbation.
        # When using Linf perturbation, we manually set element-wise bound x_L and x_U. eps is not used for Linf norm.
        if norm > 0:
            ptb = PerturbationLpNorm(norm=norm,
                                     eps=eps,
                                     x_L=data_lb,
                                     x_U=data_ub)
        elif norm == 0:
            ptb = PerturbationL0Norm(eps=eps_scheduler.get_max_eps(),
                                     ratio=eps_scheduler.get_eps() /
                                     eps_scheduler.get_max_eps())
        else:
            raise ValueError(
                "Unsupported norm {!r}; expected a non-negative value".format(
                    norm))
        x = BoundedTensor(data, ptb)

        output = model(x)
        regular_ce = CrossEntropyLoss()(
            output, labels)  # regular CrossEntropyLoss used for warming up
        meter.update('CE', regular_ce.item(), x.size(0))
        meter.update(
            'Err',
            torch.sum(torch.argmax(output, dim=1) != labels).item() /
            x.size(0), x.size(0))

        if batch_method == "robust":
            if bound_type == "IBP":
                lb, ub = model.compute_bounds(IBP=True, C=c, method=None)
            elif bound_type == "CROWN":
                lb, ub = model.compute_bounds(IBP=False,
                                              C=c,
                                              method="backward",
                                              bound_upper=False)
            elif bound_type == "CROWN-IBP":
                # we use a mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
                factor = (eps_scheduler.get_max_eps() -
                          eps) / eps_scheduler.get_max_eps()
                ilb, iub = model.compute_bounds(IBP=True, C=c, method=None)
                if factor < 1e-5:
                    # eps has (almost) reached its maximum: pure IBP bound
                    lb = ilb
                else:
                    clb, cub = model.compute_bounds(IBP=False,
                                                    C=c,
                                                    method="backward",
                                                    bound_upper=False)
                    lb = clb * factor + ilb * (1 - factor)
            elif bound_type == "CROWN-FAST":
                # The IBP pass is still executed before the backward pass,
                # exactly as before; only its returned bounds are discarded.
                model.compute_bounds(IBP=True, C=c, method=None)
                lb, ub = model.compute_bounds(IBP=False,
                                              C=c,
                                              method="backward",
                                              bound_upper=False)
            else:
                raise ValueError(
                    "Unsupported bound_type {!r}; expected 'IBP', 'CROWN', "
                    "'CROWN-IBP' or 'CROWN-FAST'".format(bound_type))

            # Pad zero at the beginning for each example, and use fake label "0" for all examples
            lb_padded = torch.cat((torch.zeros(
                size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                                  dim=1)
            fake_labels = torch.zeros(size=(lb.size(0), ),
                                      dtype=torch.int64,
                                      device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
            loss = robust_ce
        else:
            loss = regular_ce

        if train:
            loss.backward()
            eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))
        if batch_method != "natural":
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            # For an example, if lower bounds of margins is >0 for all classes, the output is verifiably correct.
            # If any margin is < 0 this example is counted as an error
            meter.update('Verified_Err',
                         torch.sum((lb < 0).any(dim=1)).item() / data.size(0),
                         data.size(0))
        meter.update('Time', time.time() - start)
        if i % 50 == 0 and train:
            print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))

    # Final summary; skipped when no batch was processed because the last
    # batch index and eps do not exist in that case.
    if i >= 0:
        print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))
Ejemplo n.º 2
0
def step(model, ptb, batch, eps=1.0, train=False):
    """Process one batch and return natural and robust accuracy/loss.

    The batch is converted to embeddings via ``model.get_input`` and run
    through ``model.model_from_embeddings``. When ``args.robust`` is set and
    ``eps`` is large enough, bounds are computed and a robust loss is
    derived from them. When ``train`` or ``args.auto_test`` is set, the
    robust loss is back-propagated and the gradient reaching the embeddings
    is pushed further into ``model.word_embeddings.weight.grad``.

    Relies on the module-level ``args`` and, when ``args.loss_fusion`` is
    set, on the module-level ``model_loss``.

    Args:
        model: Model exposing ``model_from_embeddings``, ``get_input``,
            ``word_embeddings``, ``train`` and ``eval``.
        ptb: Perturbation object; receives ``set_eps(eps)`` and
            ``set_train(train)``.
        batch: Batch passed to ``model.get_input`` and forwarded to the
            bound computation as part of ``aux``.
        eps: Perturbation scale. In the ``'IBP+backward_train'`` path it is
            also used as the mixing weight between the IBP bound and the
            IBP+backward bound.
        train: Whether to run in training mode and back-propagate.

    Returns:
        Tuple ``(acc, loss, acc_robust, loss_robust)``. In the loss-fusion
        training path ``acc``, ``loss`` and ``acc_robust`` are the
        placeholder ``-1``. When the robust path is inactive the robust
        entries equal the natural ones.

    Raises:
        NotImplementedError: If ``args.method`` is not supported by the
            selected path.
    """
    model_bound = model.model_from_embeddings
    if train:
        model.train()
        model_bound.train()
        grad = torch.enable_grad()
        if args.loss_fusion:
            model_loss.train()
    else:
        model.eval()
        model_bound.eval()
        grad = torch.no_grad()
    if args.auto_test:
        # auto-test back-propagates below, so autograd must stay enabled
        # even when not training.
        grad = torch.enable_grad()

    with grad:
        ptb.set_eps(eps)
        ptb.set_train(train)
        embeddings_unbounded, mask, tokens, labels = model.get_input(batch)
        aux = (tokens, batch)
        if args.robust and eps > 1e-9:
            embeddings = BoundedTensor(embeddings_unbounded, ptb)
        else:
            # Detached leaf: its .grad is later propagated manually into the
            # word-embedding weights.
            embeddings = embeddings_unbounded.detach().requires_grad_(True)

        # NOTE(review): the threshold here (1e-6) differs from the one used
        # for wrapping the embeddings above (1e-9); for eps in between, a
        # BoundedTensor goes through the regular-loss path. Confirm intended.
        robust = args.robust and eps > 1e-6

        if train and robust and args.loss_fusion:
            # loss_fusion loss
            if args.method == 'IBP+backward_train':
                lb, ub = model_loss.compute_bounds(x=(labels, embeddings,
                                                      mask),
                                                   aux=aux,
                                                   C=None,
                                                   method='IBP+backward',
                                                   bound_lower=False)
            else:
                raise NotImplementedError(
                    "loss fusion is not implemented for method {!r}".format(
                        args.method))
            loss_robust = torch.log(ub).mean()
            loss = acc = acc_robust = -1  # unknown
        else:
            # regular loss
            logits = model_bound(embeddings, mask)
            loss = CrossEntropyLoss()(logits, labels)
            acc = (torch.argmax(logits, dim=1) == labels).float().mean()

            if robust:
                num_class = args.num_classes
                # Margin specifications: row j of c[b] is e_{label_b} - e_j
                eye = torch.eye(num_class).type_as(embeddings)
                c = eye[labels].unsqueeze(1) - eye.unsqueeze(0)
                # Drop the all-zero row comparing each label with itself
                I = (~(labels.data.unsqueeze(1)
                       == torch.arange(num_class).type_as(
                           labels.data).unsqueeze(0)))
                c = (c[I].view(embeddings.size(0), num_class - 1, num_class))
                if args.method in [
                        'IBP', 'IBP+backward', 'forward', 'forward+backward'
                ]:
                    lb, ub = model_bound.compute_bounds(aux=aux,
                                                        C=c,
                                                        method=args.method,
                                                        bound_upper=False)
                elif args.method == 'IBP+backward_train':
                    # CROWN-IBP
                    if 1 - eps > 1e-4:
                        lb, ub = model_bound.compute_bounds(
                            aux=aux,
                            C=c,
                            method='IBP+backward',
                            bound_upper=False)
                        ilb, iub = model_bound.compute_bounds(aux=aux,
                                                              C=c,
                                                              method='IBP',
                                                              reuse_ibp=True)
                        # Blend: weight eps on IBP, (1 - eps) on IBP+backward
                        lb = eps * ilb + (1 - eps) * lb
                    else:
                        lb, ub = model_bound.compute_bounds(aux=aux,
                                                            C=c,
                                                            method='IBP')
                else:
                    raise NotImplementedError(
                        "bound method {!r} is not implemented".format(
                            args.method))
                # Pad a zero column in front and use fake label 0, so the
                # cross-entropy is taken over the negated margin lower bounds.
                lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1),
                                                   dtype=lb.dtype,
                                                   device=lb.device), lb),
                                      dim=1)
                fake_labels = torch.zeros(size=(lb.size(0), ),
                                          dtype=torch.int64,
                                          device=lb.device)
                loss_robust = robust_ce = CrossEntropyLoss()(-lb_padded,
                                                             fake_labels)
                # An example counts as verified only if no margin bound is < 0
                acc_robust = 1 - torch.mean((lb < 0).any(dim=1).float())
            else:
                acc_robust, loss_robust = acc, loss

    if train or args.auto_test:
        loss_robust.backward()
        # Continue back-propagation from the embeddings into the
        # word-embedding weight matrix.
        grad_embed = torch.autograd.grad(embeddings_unbounded,
                                         model.word_embeddings.weight,
                                         grad_outputs=embeddings.grad)[0]
        if model.word_embeddings.weight.grad is None:
            model.word_embeddings.weight.grad = grad_embed
        else:
            model.word_embeddings.weight.grad += grad_embed

    if args.auto_test:
        with open('res_test.pkl', 'wb') as file:
            # .cpu() is required before .numpy() when the gradient lives on
            # a CUDA device; it is a no-op for CPU tensors.
            pickle.dump((float(acc), float(loss), float(acc_robust),
                         float(loss_robust),
                         grad_embed.detach().cpu().numpy()), file)

    return acc, loss, acc_robust, loss_robust
Ejemplo n.º 3
0
def Train(model,
          t,
          loader,
          eps_scheduler,
          norm,
          train,
          opt,
          bound_type,
          method='robust',
          loss_fusion=True,
          final_node_name=None,
          num_class=200):
    """Run one epoch of training or evaluation, optionally with loss fusion.

    With ``loss_fusion`` the model is called as ``model(x, labels)`` and the
    robust loss is derived from an upper bound ``ub`` as
    ``mean(log(ub) + max_input)``. Without it, margin specifications are
    built explicitly and the robust loss is a cross-entropy over the negated
    lower bounds of the margins.

    Relies on the module-level ``args`` (``args.clip_grad_norm``) and
    ``logger``.

    Args:
        model: Bounded model; called both for the forward pass and, with
            ``method_opt="compute_bounds"``, for bound computation.
        t: Epoch index; only used in the logged progress lines.
        loader: Iterable yielding ``(data, labels)`` batches. Its
            ``dataset`` and ``batch_size`` are read when ``train`` is true,
            and its ``mean`` / ``std`` tensors when ``norm`` is ``np.inf``.
        eps_scheduler: Scheduler providing ``train``, ``eval``,
            ``step_epoch``, ``set_epoch_length``, ``step_batch``,
            ``get_eps`` and ``get_max_eps``; ``update_loss`` is called only
            when it is an ``AdaptiveScheduler``.
        norm: Norm passed to ``PerturbationLpNorm``; ``np.inf`` additionally
            enables explicit element-wise input bounds.
        train: If true, back-propagate and step ``opt``.
        opt: Optimizer; only touched when ``train`` is true.
        bound_type: ``'IBP'``, ``'CROWN'`` or ``'CROWN-IBP'``; only consulted
            for robust batches.
        method: ``'robust'`` or ``'natural'``. Batches whose eps is below
            1e-50 are always treated as ``'natural'``.
        loss_fusion: Whether to use the loss-fusion formulation.
        final_node_name: Forwarded to the model calls that accept it.
        num_class: Number of output classes (defaults to 200); only used
            when ``loss_fusion`` is false.

    Returns:
        The ``MultiAverageMeter`` holding the running averages of the epoch.

    Raises:
        ValueError: If ``method`` or ``bound_type`` is not supported at the
            point where it is needed.
    """
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(
            int((len(loader.dataset) + loader.batch_size - 1) /
                loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    exp_module = get_exp_module(model)

    def get_bound_loss(x=None, c=None):
        """Compute bounds for the current batch and return ``(lb, loss)``.

        ``lb`` is ``None`` under loss fusion, where only the upper bound is
        needed.
        """
        if loss_fusion:
            bound_lower, bound_upper = False, True
        else:
            bound_lower, bound_upper = True, False

        if bound_type == 'IBP':
            lb, ub = model(method_opt="compute_bounds",
                           x=x,
                           IBP=True,
                           C=c,
                           method=None,
                           final_node_name=final_node_name,
                           no_replicas=True)
        elif bound_type == 'CROWN':
            lb, ub = model(method_opt="compute_bounds",
                           x=x,
                           IBP=False,
                           C=c,
                           method='backward',
                           bound_lower=bound_lower,
                           bound_upper=bound_upper)
        elif bound_type == 'CROWN-IBP':
            # we use a mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
            factor = (eps_scheduler.get_max_eps() -
                      eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()
            ilb, iub = model(method_opt="compute_bounds",
                             x=x,
                             IBP=True,
                             C=c,
                             method=None,
                             final_node_name=final_node_name,
                             no_replicas=True)
            if factor < 1e-50:
                # eps has reached its maximum: pure IBP bounds
                lb, ub = ilb, iub
            else:
                # NOTE(review): this call does not pass x, unlike the IBP
                # call above; presumably it reuses the input already set.
                clb, cub = model(method_opt="compute_bounds",
                                 IBP=False,
                                 C=c,
                                 method='backward',
                                 bound_lower=bound_lower,
                                 bound_upper=bound_upper,
                                 final_node_name=final_node_name,
                                 no_replicas=True)
                # Only the bound consumed below is blended.
                if loss_fusion:
                    ub = cub * factor + iub * (1 - factor)
                else:
                    lb = clb * factor + ilb * (1 - factor)
        else:
            raise ValueError(
                "Unsupported bound_type {!r}; expected 'IBP', 'CROWN' or "
                "'CROWN-IBP'".format(bound_type))

        if loss_fusion:
            if isinstance(model, BoundDataParallel):
                max_input = model(get_property=True,
                                  node_class=BoundExp,
                                  att_name='max_input')
            else:
                max_input = exp_module.max_input
            return None, torch.mean(torch.log(ub) + max_input)
        else:
            # Pad zero at the beginning for each example, and use fake label '0' for all examples
            lb_padded = torch.cat((torch.zeros(
                size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                                  dim=1)
            fake_labels = torch.zeros(size=(lb.size(0), ),
                                      dtype=torch.int64,
                                      device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
            return lb, robust_ce

    # Stays at -1 when the loader yields nothing, so the final summary
    # (which needs the last batch index and eps) can be skipped safely.
    i = -1
    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = "natural" if eps < 1e-50 else method
        if batch_method not in ('robust', 'natural'):
            raise ValueError(
                "Unsupported method {!r}; expected 'robust' or 'natural'".
                format(batch_method))
        if train:
            opt.zero_grad()
        # bound input for Linf norm used only
        if norm == np.inf:
            # [0, 1] range expressed in the normalized data space
            data_max = torch.reshape((1. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1),
                                data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1),
                                data_min)
        else:
            data_ub = data_lb = data

        # Move the batch to the GPU when the model's first parameter is there.
        if next(iter(model.parameters())).is_cuda:
            data, labels = data.cuda(), labels.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        x = BoundedTensor(data, ptb)
        if loss_fusion:
            if batch_method == 'natural' or not train:
                output = model(x, labels)
                regular_ce = torch.mean(torch.log(output))
            else:
                # Forward pass is still executed before bounding; the
                # natural loss is not needed here and is reported as 0.
                model(x, labels)
                regular_ce = torch.tensor(0., device=data.device)
            meter.update('CE', regular_ce.item(), x.size(0))
            x = (x, labels)
            c = None
        else:
            # generate specifications: row j of c[b] is e_{label_b} - e_j
            eye = torch.eye(num_class).type_as(data)
            c = eye[labels].unsqueeze(1) - eye.unsqueeze(0)
            # remove specifications to self
            I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
                labels.data).unsqueeze(0)))
            c = (c[I].view(data.size(0), num_class - 1, num_class))
            output = model(x, final_node_name=final_node_name)
            regular_ce = CrossEntropyLoss()(
                output, labels)  # regular CrossEntropyLoss used for warming up
            meter.update('CE', regular_ce.item(), x.size(0))
            meter.update(
                'Err',
                torch.sum(torch.argmax(output, dim=1) != labels).item() /
                x.size(0), x.size(0))
            x = (x, )

        if batch_method == 'robust':
            lb, robust_ce = get_bound_loss(x=x, c=c)
            loss = robust_ce
        else:
            loss = regular_ce

        if train:
            loss.backward()

            if args.clip_grad_norm:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.clip_grad_norm)
                meter.update('grad_norm', grad_norm)

            if isinstance(eps_scheduler, AdaptiveScheduler):
                eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))

        if batch_method != 'natural':
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            if not loss_fusion:
                # For an example, if lower bounds of margins is >0 for all classes, the output is verifiably correct.
                # If any margin is < 0 this example is counted as an error
                meter.update(
                    'Verified_Err',
                    torch.sum((lb < 0).any(dim=1)).item() / data.size(0),
                    data.size(0))
        meter.update('Time', time.time() - start)

        if (i + 1) % 250 == 0 and train:
            logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(
                t, i + 1, eps, meter))

    # Final summary; skipped when no batch was processed because the last
    # batch index and eps do not exist in that case.
    if i >= 0:
        logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(
            t, i + 1, eps, meter))
    return meter