Example #1
def evaluate_natural(args, model, test_loader, verbose=False):
    model.eval()

    with torch.no_grad():
        meter = MultiAverageMeter()
        test_loss = test_acc = test_n = 0

        def test_step(step, X_batch, y_batch):
            X, y = X_batch.cuda(), y_batch.cuda()
            if args.method == 'natural':
                output = model(X)
                loss = F.cross_entropy(output, y)
            else:
                raise NotImplementedError(args.method)
            meter.update('test_loss', loss.item(), y.size(0))
            meter.update('test_acc', (output.max(1)[1] == y).float().mean(), y.size(0))
            if step % args.log_interval == 0 and verbose:
                logger.info('Eval step {}/{} {}'.format(
                    step, total_test_steps, meter))               

        for step, (X_batch, y_batch) in enumerate(test_loader):
            test_step(step, X_batch, y_batch)

        logger.info('Evaluation {}'.format(meter))  
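
MultiAverageMeter is presumably a small utility from the surrounding project (the auto_LiRPA examples ship one); a minimal, hypothetical stand-in that matches the update calls and string formatting used above might look like this (class and method bodies are assumptions, not the project's actual implementation):

from collections import defaultdict

class SimpleMultiAverageMeter:
    """Hypothetical minimal stand-in: running weighted averages keyed by metric name."""

    def __init__(self):
        self.sums = defaultdict(float)
        self.counts = defaultdict(float)

    def update(self, key, value, n=1):
        # value may be a Python float or a 0-dim tensor; both convert via float()
        self.sums[key] += float(value) * n
        self.counts[key] += n

    def avg(self, key):
        return self.sums[key] / max(self.counts[key], 1)

    def __str__(self):
        return ' '.join('{}={:.4f}'.format(k, self.avg(k)) for k in self.sums)
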
Example #2
def Train(model,
          t,
          loader,
          eps_scheduler,
          norm,
          train,
          opt,
          bound_type,
          method='robust'):
    num_class = 10
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(
            int((len(loader.dataset) + loader.batch_size - 1) /
                loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = method
        if eps < 1e-20:
            batch_method = "natural"
        if train:
            opt.zero_grad()
        # generate specifications
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
            1) - torch.eye(num_class).type_as(data).unsqueeze(0)
        # remove specifications to self
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
            labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0), num_class - 1, num_class))
        # element-wise input bounds, used only for the Linf norm
        if norm == np.inf:
            data_max = torch.reshape((1. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1),
                                data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1),
                                data_min)
        else:
            data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data, labels, c = data.cuda(), labels.cuda(), c.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        # Specify Lp norm perturbation.
        # When using an Linf perturbation, we manually set the element-wise bounds x_L and x_U; eps is not used for the Linf norm.
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        x = BoundedTensor(data, ptb)

        output = model(x)
        regular_ce = CrossEntropyLoss()(
            output, labels)  # regular CrossEntropyLoss used for warming up
        meter.update('CE', regular_ce.item(), x.size(0))
        meter.update(
            'Err',
            torch.sum(
                torch.argmax(output, dim=1) != labels).cpu().detach().numpy() /
            x.size(0), x.size(0))

        if batch_method == "robust":
            if bound_type == "IBP":
                lb, ub = model.compute_bounds(IBP=True, C=c, method=None)
            elif bound_type == "CROWN":
                lb, ub = model.compute_bounds(IBP=False,
                                              C=c,
                                              method="backward",
                                              bound_upper=False)
            elif bound_type == "CROWN-IBP":
                # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method="backward")  # pure IBP bound
                # we use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
                factor = (eps_scheduler.get_max_eps() -
                          eps) / eps_scheduler.get_max_eps()
                ilb, iub = model.compute_bounds(IBP=True, C=c, method=None)
                if factor < 1e-5:
                    lb = ilb
                else:
                    clb, cub = model.compute_bounds(IBP=False,
                                                    C=c,
                                                    method="backward",
                                                    bound_upper=False)
                    lb = clb * factor + ilb * (1 - factor)

            # Pad zero at the beginning for each example, and use fake label "0" for all examples
            lb_padded = torch.cat((torch.zeros(
                size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                                  dim=1)
            fake_labels = torch.zeros(size=(lb.size(0), ),
                                      dtype=torch.int64,
                                      device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
        if batch_method == "robust":
            loss = robust_ce
        elif batch_method == "natural":
            loss = regular_ce
        if train:
            loss.backward()
            eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))
        if batch_method != "natural":
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            # For an example, if the lower bounds of the margins are > 0 for all classes, the output is verifiably correct.
            # If any margin is < 0, this example is counted as an error.
            meter.update('Verified_Err',
                         torch.sum((lb < 0).any(dim=1)).item() / data.size(0),
                         data.size(0))
        meter.update('Time', time.time() - start)
        if i % 50 == 0 and train:
            print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))
    print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))
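
The specification matrix c built in the loop above encodes, for each example, the num_class - 1 margins z_true - z_other as rows of the form e_y - e_j; a toy reproduction of that construction with num_class = 3 and made-up labels:

import torch

num_class = 3
labels = torch.tensor([2, 0])
eye = torch.eye(num_class)
# row y of the identity minus every one-hot row: c[k, j] = e_{labels[k]} - e_j
c = eye[labels].unsqueeze(1) - eye.unsqueeze(0)
# drop the all-zero row corresponding to the true class itself
I = ~(labels.unsqueeze(1) == torch.arange(num_class).unsqueeze(0))
c = c[I].view(labels.size(0), num_class - 1, num_class)
# first example (label 2): rows are e_2 - e_0 and e_2 - e_1,
# so c[0] @ logits gives the margins z_2 - z_0 and z_2 - z_1
print(c[0])
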
Example #3
def Train(model,
          model_ori,
          t,
          loader,
          eps_scheduler,
          opt,
          loss_fusion=False,
          valid=False):
    train = opt is not None
    meter = MultiAverageMeter()
    meter_layer = []

    data_max, data_min, std = loader.data_max, loader.data_min, loader.std
    if args.device == 'cuda':
        data_min, data_max, std = data_min.cuda(), data_max.cuda(), std.cuda()

    if train:
        model_ori.train()
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
    else:
        model_ori.eval()
        model.eval()
        eps_scheduler.eval()

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        epoch_progress = (i + 1) * 1. / len(loader) if train else 1.0

        if train:
            eps *= args.train_eps_mul
        if eps < args.min_eps:
            eps = args.min_eps
        if args.fix_eps:
            eps = eps_scheduler.get_max_eps()
        if args.natural:
            eps = 0.

        reg = t <= args.num_reg_epochs

        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = 'natural' if (eps < 1e-50) else 'robust'
        robust = batch_method == 'robust'

        # labels = labels.to(torch.long)
        if args.device == 'cuda':
            data, labels = data.cuda().detach().requires_grad_(), labels.cuda()

        data_batch, labels_batch = data, labels
        grad_acc = args.grad_acc_steps
        assert data.shape[0] % grad_acc == 0
        bsz = data.shape[0] // grad_acc

        for k in range(grad_acc):
            if grad_acc > 1:
                data, labels = data_batch[bsz * k:bsz *
                                          (k + 1)], labels_batch[bsz * k:bsz *
                                                                 (k + 1)]

            if args.mode == 'cert':
                regular_ce, robust_loss, regular_err, robust_err = cert(
                    args,
                    model,
                    model_ori,
                    t,
                    epoch_progress,
                    data,
                    labels,
                    eps=eps,
                    data_max=data_max,
                    data_min=data_min,
                    std=std,
                    robust=robust,
                    reg=reg,
                    loss_fusion=loss_fusion,
                    eps_scheduler=eps_scheduler,
                    train=train,
                    meter=meter)
            elif args.mode == 'adv':
                method = args.method if train else 'pgd'
                regular_ce, robust_loss, regular_err, robust_err = adv(
                    args,
                    model,
                    model_ori,
                    t,
                    epoch_progress,
                    data,
                    labels,
                    eps=eps,
                    data_max=data_max,
                    data_min=data_min,
                    std=std,
                    train=train,
                    meter=meter)
            else:
                raise NotImplementedError
            update_meter(meter, regular_ce, robust_loss, regular_err,
                         robust_err, data.size(0))

            if reg:
                loss = compute_reg(args, model, meter, eps, eps_scheduler)
            elif args.xiao_reg:
                loss = compute_stab_reg(
                    args, model, meter, eps, eps_scheduler) + compute_L1_reg(
                        args, model_ori, meter, eps, eps_scheduler)
            elif args.vol_reg:  # by colt
                loss = compute_vol_reg(args, model, meter, eps, eps_scheduler)
            else:
                loss = torch.tensor(0.).to(args.device)
            if robust:
                # note: the two kappa-weighted terms sum back to robust_loss itself
                loss += robust_loss * args.kappa + robust_loss * (1 - args.kappa)
            else:
                loss += regular_ce
            meter.update('Loss', loss.item(), data.size(0))

            if train:
                loss /= grad_acc
                loss.backward()

                if args.check_nan:
                    for p in model.parameters():
                        if torch.isnan(p.grad).any():
                            pdb.set_trace()
                            ckpt = {
                                'model_ori':
                                model_ori,
                                'args_cert':
                                (t, epoch_progress, data, labels, eps,
                                 data_max, data_min, std, robust, reg,
                                 loss_fusion, eps_scheduler, train, meter)
                            }
                            torch.save(ckpt, 'nan_ckpt')
                            pdb.set_trace()

        if train:
            grad_norm = torch.nn.utils.clip_grad_norm_(model_ori.parameters(),
                                                       max_norm=args.grad_norm)
            meter.update('grad_norm', grad_norm)
            opt.step()
            opt.zero_grad()

        meter.update('wnorm', get_weight_norm(model_ori))
        meter.update('Time', time.time() - start)

        if (i + 1) % args.log_interval == 0 and (train or args.eval
                                                 or args.verify):
            logger.info('[{:2d}:{:4d}/{:4d}]: eps={:.8f} {}'.format(
                t, i + 1, len(loader), eps, meter))
            if args.debug:
                print()
                pdb.set_trace()
    logger.info('[{:2d}]: eps={:.8f} {}'.format(t, eps, meter))

    if batch_method != 'natural':
        meter.update('eps', eps_scheduler.get_eps())

    if t <= args.num_reg_epochs:
        update_log_reg(writer, meter, t, train, model)
    update_log_writer(args,
                      writer,
                      meter,
                      t,
                      train,
                      robust=(batch_method != 'natural'))

    return meter
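
The inner loop over grad_acc in this Train splits each batch into args.grad_acc_steps micro-batches and accumulates gradients before a single optimizer step; a generic, self-contained sketch of that pattern (the tiny model and data below are placeholders, not the objects above):

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
data, labels = torch.randn(8, 10), torch.randint(0, 2, (8,))

grad_acc = 4
assert data.shape[0] % grad_acc == 0
bsz = data.shape[0] // grad_acc

opt.zero_grad()
for k in range(grad_acc):
    xb = data[bsz * k:bsz * (k + 1)]
    yb = labels[bsz * k:bsz * (k + 1)]
    loss = F.cross_entropy(model(xb), yb)
    # scale so the accumulated gradient averages over the full batch
    (loss / grad_acc).backward()
opt.step()
opt.zero_grad()
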
Example #4
def Train(model,
          t,
          loader,
          eps_scheduler,
          norm,
          train,
          opt,
          bound_type,
          method='robust',
          loss_fusion=True,
          final_node_name=None):
    num_class = 200
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(
            int((len(loader.dataset) + loader.batch_size - 1) /
                loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    exp_module = get_exp_module(model)

    def get_bound_loss(x=None, c=None):
        if loss_fusion:
            bound_lower, bound_upper = False, True
        else:
            bound_lower, bound_upper = True, False

        if bound_type == 'IBP':
            lb, ub = model(method_opt="compute_bounds",
                           x=x,
                           IBP=True,
                           C=c,
                           method=None,
                           final_node_name=final_node_name,
                           no_replicas=True)
        elif bound_type == 'CROWN':
            lb, ub = model(method_opt="compute_bounds",
                           x=x,
                           IBP=False,
                           C=c,
                           method='backward',
                           bound_lower=bound_lower,
                           bound_upper=bound_upper)
        elif bound_type == 'CROWN-IBP':
            # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method='backward')  # pure IBP bound
            # we use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
            factor = (eps_scheduler.get_max_eps() -
                      eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()
            ilb, iub = model(method_opt="compute_bounds",
                             x=x,
                             IBP=True,
                             C=c,
                             method=None,
                             final_node_name=final_node_name,
                             no_replicas=True)
            if factor < 1e-50:
                lb, ub = ilb, iub
            else:
                clb, cub = model(method_opt="compute_bounds",
                                 IBP=False,
                                 C=c,
                                 method='backward',
                                 bound_lower=bound_lower,
                                 bound_upper=bound_upper,
                                 final_node_name=final_node_name,
                                 no_replicas=True)
                if loss_fusion:
                    ub = cub * factor + iub * (1 - factor)
                else:
                    lb = clb * factor + ilb * (1 - factor)

        if loss_fusion:
            if isinstance(model, BoundDataParallel):
                max_input = model(get_property=True,
                                  node_class=BoundExp,
                                  att_name='max_input')
            else:
                max_input = exp_module.max_input
            return None, torch.mean(torch.log(ub) + max_input)
        else:
            # Pad zero at the beginning for each example, and use fake label '0' for all examples
            lb_padded = torch.cat((torch.zeros(
                size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                                  dim=1)
            fake_labels = torch.zeros(size=(lb.size(0), ),
                                      dtype=torch.int64,
                                      device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
            return lb, robust_ce

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = method
        if eps < 1e-50:
            batch_method = "natural"
        if train:
            opt.zero_grad()
        # element-wise input bounds, used only for the Linf norm
        if norm == np.inf:
            data_max = torch.reshape((1. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1),
                                data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1),
                                data_min)
        else:
            data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data, labels = data.cuda(), labels.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        x = BoundedTensor(data, ptb)
        if loss_fusion:
            if batch_method == 'natural' or not train:
                output = model(x, labels)
                regular_ce = torch.mean(torch.log(output))
            else:
                model(x, labels)
                regular_ce = torch.tensor(0., device=data.device)
            meter.update('CE', regular_ce.item(), x.size(0))
            x = (x, labels)
            c = None
        else:
            c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
                1) - torch.eye(num_class).type_as(data).unsqueeze(0)
            # remove specifications to self
            I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
                labels.data).unsqueeze(0)))
            c = (c[I].view(data.size(0), num_class - 1, num_class))
            x = (x, labels)
            output = model(x, final_node_name=final_node_name)
            regular_ce = CrossEntropyLoss()(
                output, labels)  # regular CrossEntropyLoss used for warming up
            meter.update('CE', regular_ce.item(), x[0].size(0))
            meter.update(
                'Err',
                torch.sum(torch.argmax(output, dim=1) != labels).item() /
                x[0].size(0), x[0].size(0))

        if batch_method == 'robust':
            # print(data.sum())
            lb, robust_ce = get_bound_loss(x=x, c=c)
            loss = robust_ce
        elif batch_method == 'natural':
            loss = regular_ce

        if train:
            loss.backward()

            if args.clip_grad_norm:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.clip_grad_norm)
                meter.update('grad_norm', grad_norm)

            if isinstance(eps_scheduler, AdaptiveScheduler):
                eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))

        if batch_method != 'natural':
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            if not loss_fusion:
                # For an example, if the lower bounds of the margins are > 0 for all classes, the output is verifiably correct.
                # If any margin is < 0, this example is counted as an error.
                meter.update(
                    'Verified_Err',
                    torch.sum((lb < 0).any(dim=1)).item() / data.size(0),
                    data.size(0))
        meter.update('Time', time.time() - start)

        if (i + 1) % 250 == 0 and train:
            logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(
                t, i + 1, eps, meter))

    logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    return meter
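
The zero padding plus fake label '0' used in get_bound_loss turns the margin lower bounds into a cross-entropy objective; a quick numerical check of the identity CE(-[0, lb], 0) = log(1 + sum_j exp(-lb_j)) with toy lb values:

import torch
import torch.nn.functional as F

lb = torch.tensor([[0.3, -0.2, 1.1]])                       # toy margin lower bounds
lb_padded = torch.cat([torch.zeros(lb.size(0), 1), lb], dim=1)
fake_labels = torch.zeros(lb.size(0), dtype=torch.long)
robust_ce = F.cross_entropy(-lb_padded, fake_labels)
closed_form = torch.log(1 + torch.exp(-lb).sum(dim=1)).mean()
assert torch.allclose(robust_ce, closed_form, atol=1e-6)
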
Example #5
def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True, final_node_name=None):
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch(verbose=False)
        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()
    
    # Used for loss fusion: get the exp operation in the computational graph.
    exp_module = get_exp_module(model)

    def get_bound_loss(x=None, c=None):
        if loss_fusion:
            # When loss fusion is used, we need the upper bound for the final loss function.
            bound_lower, bound_upper = False, True
        else:
            # When loss fusion is not used, we need the lower bound for the logit layer.
            bound_lower, bound_upper = True, False

        if bound_type == 'IBP':
            lb, ub = model(method_opt="compute_bounds", x=x, C=c, method="IBP", final_node_name=final_node_name, no_replicas=True)
        elif bound_type == 'CROWN':
            lb, ub = model(method_opt="compute_bounds", x=x, C=c, method="backward",
                                          bound_lower=bound_lower, bound_upper=bound_upper)
        elif bound_type == 'CROWN-IBP':
            # we use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
            # factor = (eps_scheduler.get_max_eps() - eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()
            ilb, iub = model(method_opt="compute_bounds", x=x, C=c, method="IBP", final_node_name=final_node_name, no_replicas=True)
            lb, ub = model(method_opt="compute_bounds", C=c, method="CROWN-IBP",
                         bound_lower=bound_lower, bound_upper=bound_upper, final_node_name=final_node_name, average_A=True, no_replicas=True)
        if loss_fusion:
            # When loss fusion is enabled, we need to get the common factor before softmax.
            if isinstance(model, BoundDataParallel):
                max_input = model(get_property=True, node_class=BoundExp, att_name='max_input')
            else:
                max_input = exp_module.max_input
            return None, torch.mean(torch.log(ub) + max_input)
        else:
            # Pad zero at the beginning for each example, and use fake label '0' for all examples
            lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb), dim=1)
            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
            return lb, robust_ce

    for i, (data, labels) in enumerate(loader):
        # For unit tests, we only use a small number of batches
        if args.truncate_data:
            if i >= args.truncate_data:
                break

        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = method
        if eps < 1e-50:
            batch_method = "natural"
        if train:
            opt.zero_grad()

        if list(model.parameters())[0].is_cuda:
            data, labels = data.cuda(), labels.cuda()

        model.ptb.eps = eps
        x = data
        if loss_fusion:
            if batch_method == 'natural' or not train:
                output = model(x, labels)  # , disable_multi_gpu=True
                regular_ce = torch.mean(torch.log(output))
            else:
                model(x, labels)
                regular_ce = torch.tensor(0., device=data.device)
            meter.update('CE', regular_ce.item(), x.size(0))
            x = (x, labels)
            c = None
        else:
            # Generate specification matrix (when loss fusion is not used).
            c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(
                0)
            # remove specifications to self.
            I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
            c = (c[I].view(data.size(0), num_class - 1, num_class))
            x = (x, labels)
            output = model(x, final_node_name=final_node_name)
            regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up
            meter.update('CE', regular_ce.item(), x[0].size(0))
            meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).item() / x[0].size(0), x[0].size(0))

        if batch_method == 'robust':
            lb, robust_ce = get_bound_loss(x=x, c=c)
            loss = robust_ce
        elif batch_method == 'natural':
            loss = regular_ce

        if train:
            loss.backward()

            if args.clip_grad_norm:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
                meter.update('grad_norm', grad_norm)

            if isinstance(eps_scheduler, AdaptiveScheduler):
                eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))

        if batch_method != 'natural':
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            if not loss_fusion:
                # For an example, if the lower bounds of the margins are > 0 for all classes, the output is verifiably correct.
                # If any margin is < 0, this example is counted as an error.
                meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))
        meter.update('Time', time.time() - start)

        if (i + 1) % 50 == 0 and train:
            logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))

    logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    return meter
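
The eps_scheduler objects driving these loops come from the surrounding project; for intuition only, a simplified linear ramp exposing the same train/eval/set_epoch_length/step_epoch/step_batch/get_eps interface could be sketched as follows (an assumption-laden stand-in, not the actual scheduler):

class LinearEpsScheduler:
    """Illustrative stand-in: ramp eps linearly from 0 to max_eps over ramp_epochs epochs."""

    def __init__(self, max_eps, ramp_epochs, epoch_length=1):
        self.max_eps = max_eps
        self.ramp_epochs = ramp_epochs
        self.epoch_length = epoch_length
        self.epoch = 0
        self.batch = 0
        self.is_training = True

    def train(self):
        self.is_training = True

    def eval(self):
        self.is_training = False

    def set_epoch_length(self, epoch_length):
        self.epoch_length = epoch_length

    def step_epoch(self):
        self.epoch += 1
        self.batch = 0

    def step_batch(self):
        if self.is_training:
            self.batch += 1

    def get_max_eps(self):
        return self.max_eps

    def get_eps(self):
        if not self.is_training:
            return self.max_eps
        progress = ((self.epoch - 1) + self.batch / self.epoch_length) / self.ramp_epochs
        return self.max_eps * min(1.0, max(0.0, progress))
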
Example #6
def train_natural(args, model, ds_train, ds_test):
    if args.optimizer == 'sgd':
        opt = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9)
    elif args.optimizer == 'adam':
        opt = torch.optim.Adam(model.parameters(), lr=args.base_lr)
    else:
        raise ValueError(args.optimizer)

    start_epoch_time = time.time()
    meter = MultiAverageMeter()

    if args.data_loader == 'torch':
        train_loader, test_loader = get_loaders(args)
    else:
        test_loader = ds_test

    def train_step(step, X_batch, y_batch, lr_repl):
        lr = float(lr_repl[0])
        opt.param_groups[0]['lr'] = lr
        model.train()

        batch_size = math.ceil(args.batch_size / args.accum_steps)
        for i in range(args.accum_steps):
            X = X_batch[i*batch_size:(i+1)*batch_size].cuda()
            y = y_batch[i*batch_size:(i+1)*batch_size].cuda()

            if args.method == 'natural':
                output = model(X)
                loss = F.cross_entropy(output, y)
                (loss*(X.shape[0]/X_batch.shape[0])).backward() 
            else:
                raise NotImplementedError(args.method)

            meter.update('train_loss', loss.item(), y.size(0))
            meter.update('train_acc', (output.max(1)[1] == y).float().mean(), y.size(0))

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

        opt.step()
        opt.zero_grad()

        if step % args.log_interval == 0:
            logger.info('Training step {} lr {:.4f} {}'.format(
                step, lr, meter))   
            train_loss = 0
            train_acc = 0
            train_n = 0

        if step % args.save_interval == 0:
            path = os.path.join(args.out_dir, 'checkpoint_{}'.format(step))
            torch.save({ 'state_dict': model.state_dict(), 'step': step }, path)
            logger.info('Checkpoint saved to {}'.format(path))

        if step == total_steps:
            return

    total_steps = args.epochs * len(train_loader) 
    lr_fn = hyper.create_learning_rate_schedule(total_steps, args.base_lr, 'cosine', args.warmup_steps)
    lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)

    step = 0
    for epoch in range(1, args.epochs+1):
        logger.info('Epoch {}'.format(epoch))
        for (X_batch, y_batch) in train_loader:
            step += 1
            lr_repl = next(lr_iter)
            train_step(step, X_batch, y_batch, lr_repl)
        evaluate_natural(args, model, test_loader)            
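
hyper.create_learning_rate_schedule and hyper.lr_prefetch_iter above are utilities from the surrounding project; assuming the standard linear-warmup-then-cosine-decay shape that the 'cosine' schedule with args.warmup_steps suggests, the per-step learning rate can be sketched with a hypothetical helper like this:

import math

def cosine_lr_with_warmup(step, total_steps, base_lr, warmup_steps):
    """Illustrative sketch: linear warmup followed by cosine decay to zero."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# e.g. lrs = [cosine_lr_with_warmup(s, 10000, 0.1, 500) for s in range(10000)]
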
Example #7
def train(epoch, batches, type):
    meter = MultiAverageMeter()
    assert (optimizer is not None)
    train = type == 'train'
    if args.robust:
        eps_scheduler.set_epoch_length(len(batches))
        if train:
            eps_scheduler.train()
            eps_scheduler.step_epoch()
        else:
            eps_scheduler.eval()
    for i, batch in enumerate(batches):
        if args.robust:
            eps_scheduler.step_batch()
            eps = eps_scheduler.get_eps()
        else:
            eps = 0
        acc, loss, acc_robust, loss_robust = \
            step(model, ptb, batch, eps=eps, train=train)
        meter.update('acc', acc, len(batch))
        meter.update('loss', loss, len(batch))
        meter.update('acc_rob', acc_robust, len(batch))
        meter.update('loss_rob', loss_robust, len(batch))
        if train:
            if (i + 1) % args.gradient_accumulation_steps == 0 or (
                    i + 1) == len(batches):
                scale_gradients(optimizer,
                                i % args.gradient_accumulation_steps + 1,
                                args.grad_clip)
                optimizer.step()
                optimizer.zero_grad()
            if lr_scheduler is not None:
                lr_scheduler.step()
            writer.add_scalar('loss_train_{}'.format(epoch), meter.avg('loss'),
                              i + 1)
            writer.add_scalar('loss_robust_train_{}'.format(epoch),
                              meter.avg('loss_rob'), i + 1)
            writer.add_scalar('acc_train_{}'.format(epoch), meter.avg('acc'),
                              i + 1)
            writer.add_scalar('acc_robust_train_{}'.format(epoch),
                              meter.avg('acc_rob'), i + 1)
        if (i + 1) % args.log_interval == 0 or (i + 1) == len(batches):
            logger.info('Epoch {}, {} step {}/{}: eps {:.5f}, {}'.format(
                epoch, type, i + 1, len(batches), eps, meter))
            if lr_scheduler is not None:
                logger.info('lr {}'.format(lr_scheduler.get_lr()))
    writer.add_scalar('loss/{}'.format(type), meter.avg('loss'), epoch)
    writer.add_scalar('loss_robust/{}'.format(type), meter.avg('loss_rob'),
                      epoch)
    writer.add_scalar('acc/{}'.format(type), meter.avg('acc'), epoch)
    writer.add_scalar('acc_robust/{}'.format(type), meter.avg('acc_rob'),
                      epoch)

    if train:
        if args.loss_fusion:
            state_dict_loss = model_loss.state_dict()
            state_dict = {}
            for name in state_dict_loss:
                assert (name.startswith('model.'))
                state_dict[name[6:]] = state_dict_loss[name]
            model_ori.load_state_dict(state_dict)
            model_bound = BoundedModule(model_ori,
                                        (dummy_embeddings, dummy_mask),
                                        bound_opts=bound_opts,
                                        device=args.device)
            model.model_from_embeddings = model_bound
        model.save(epoch)

    return meter.avg('acc_rob')
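
scale_gradients is defined elsewhere in this project; judging only from its call site (optimizer, number of accumulated micro-batches, clip value), a plausible but purely hypothetical sketch is to average the accumulated gradients and then clip their norm:

import torch

def scale_gradients_sketch(optimizer, accumulated_steps, grad_clip):
    """Hypothetical: average gradients accumulated over several micro-batches, then clip."""
    parameters = [p for group in optimizer.param_groups for p in group['params']]
    for p in parameters:
        if p.grad is not None:
            p.grad.div_(accumulated_steps)
    if grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=grad_clip)
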
Example #8
def Train(model,
          t,
          loader,
          start_eps,
          end_eps,
          max_eps,
          norm,
          train,
          opt,
          bound_type,
          method='robust'):
    num_class = 10
    meter = MultiAverageMeter()
    if train:
        model.train()
    else:
        model.eval()
    # Pre-generate the array for specifications; it will be used later for scatter
    sa = np.zeros((num_class, num_class - 1), dtype=np.int32)
    for i in range(sa.shape[0]):
        for j in range(sa.shape[1]):
            if j < i:
                sa[i][j] = j
            else:
                sa[i][j] = j + 1
    sa = torch.LongTensor(sa)
    total = len(loader.dataset)
    batch_size = loader.batch_size

    # Increase epsilon batch by batch
    batch_eps = np.linspace(start_eps, end_eps, (total // batch_size) + 1)
    # For small eps just use natural training, no need to compute LiRPA bounds
    if end_eps < 1e-6: method = "natural"

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps = batch_eps[i]
        if train:
            opt.zero_grad()
        # generate specifications
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
            1) - torch.eye(num_class).type_as(data).unsqueeze(0)
        # remove specifications to self
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
            labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0), num_class - 1, num_class))
        # scatter matrix to avoid computing the margin to self
        sa_labels = sa[labels]
        # storing computed lower bounds after scatter
        lb_s = torch.zeros(data.size(0), num_class)
        # element-wise input bounds, used only for the Linf norm
        if norm == np.inf:
            data_ub = (data + eps).clamp(max=1.0)
            data_lb = (data - eps).clamp(min=0.0)
        else:
            data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data, labels, sa_labels, c, lb_s = data.cuda(), labels.cuda(
            ), sa_labels.cuda(), c.cuda(), lb_s.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        output = model(data)
        regular_ce = CrossEntropyLoss()(
            output, labels)  # regular CrossEntropyLoss used for warming up
        meter.update('CE', regular_ce.cpu().detach().numpy(), data.size(0))
        meter.update(
            'Err',
            torch.sum(
                torch.argmax(output, dim=1) != labels).cpu().detach().numpy() /
            data.size(0), data.size(0))

        # Specify Lp norm perturbation.
        # When using an Linf perturbation, we manually set the element-wise bounds x_L and x_U; eps is not used for the Linf norm.
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        if method == "robust":
            if bound_type == "IBP":
                lb, ub = model.compute_bounds(ptb=ptb,
                                              IBP=True,
                                              x=data,
                                              C=c,
                                              method=None)
            elif bound_type == "CROWN":
                lb, ub = model.compute_bounds(ptb=ptb,
                                              IBP=False,
                                              x=data,
                                              C=c,
                                              method="backward")
            elif bound_type == "CROWN-IBP":
                # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method="backward")  # pure IBP bound
                # we use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
                factor = (max_eps - eps) / max_eps
                ilb, iub = model.compute_bounds(ptb=ptb,
                                                IBP=True,
                                                x=data,
                                                C=c,
                                                method=None)
                if factor < 1e-5:
                    lb = ilb
                else:
                    clb, cub = model.compute_bounds(ptb=ptb,
                                                    IBP=False,
                                                    x=data,
                                                    C=c,
                                                    method="backward")
                    lb = clb * factor + ilb * (1 - factor)

            # Fill in the missing 0 in lb: the margin from a class to itself is always 0 and is not computed.
            lb = lb_s.scatter(1, sa_labels, lb)
            # Use the robust cross entropy loss objective (Wong & Kolter, 2018)
            robust_ce = CrossEntropyLoss()(-lb, labels)
        if method == "robust":
            loss = robust_ce
        elif method == "natural":
            loss = regular_ce
        if train:
            loss.backward()
            opt.step()
        meter.update('Loss', loss.cpu().detach().numpy(), data.size(0))
        if method != "natural":
            meter.update('Robust_CE',
                         robust_ce.cpu().detach().numpy(), data.size(0))
            # For an example, if the lower bounds of the margins are > 0 for all classes, the output is verifiably correct.
            # If any margin is < 0, this example is counted as an error.
            meter.update(
                'Verified_Err',
                torch.sum(
                    (lb < 0).any(dim=1)).cpu().detach().numpy() / data.size(0),
                data.size(0))
        meter.update('Time', time.time() - start)
        if i % 50 == 0 and train:
            print('[{:2d}:{:4d}]: eps={:.4f} {}'.format(t, i, eps, meter))

    print('[FINAL RESULT] epoch={:2d} eps={:.4f} {}'.format(t, eps, meter))
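
The sa index matrix and the scatter call above place the num_class - 1 computed margin lower bounds back into a full num_class-column tensor, leaving the true-class column at its padded value of 0; a toy check with num_class = 3:

import torch

num_class = 3
# sa[i] lists the classes other than i, in increasing order (same construction as above)
sa = torch.tensor([[1, 2], [0, 2], [0, 1]])
labels = torch.tensor([2, 0])
lb = torch.tensor([[0.5, -0.1],    # margins of class 2 vs. classes 0 and 1
                   [1.2,  0.3]])   # margins of class 0 vs. classes 1 and 2
lb_s = torch.zeros(2, num_class)
lb_full = lb_s.scatter(1, sa[labels], lb)
print(lb_full)
# lb_full == [[0.5, -0.1, 0.0],
#             [0.0,  1.2, 0.3]]   -- the true-class column stays at 0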