Code Example #1
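The listings below are shown without their module headers. A minimal sketch of the imports and globals Example #1 relies on, inferred from usage; the project-local names (AverageMeter, PatchAttacker, PGDAttacker, ParallelBound, ParallelBoundPool, args) and the module paths given for them are assumptions:

import time

import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# project-local helpers, locations assumed:
# from utils import AverageMeter                    # running value/average meter
# from attacks import PatchAttacker, PGDAttacker    # adversarial perturbations
# from bound_layers import ParallelBound, ParallelBoundPool
# 'args' is a module-level namespace; only args.grad_acc_steps is used below
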
def Train(model,
          model_id,
          t,
          loader,
          start_eps,
          end_eps,
          max_eps,
          norm,
          logger,
          verbose,
          train,
          opt,
          method,
          adv_net=None,
          unetopt=None,
          **kwargs):
    # if train=True, use training mode
    # if train=False, use test mode, no back prop
    num_class = 17
    losses = AverageMeter()
    unetlosses = AverageMeter()
    unetloss = None
    errors = AverageMeter()
    adv_errors = AverageMeter()
    robust_errors = AverageMeter()
    regular_ce_losses = AverageMeter()
    adv_ce_losses = AverageMeter()
    robust_ce_losses = AverageMeter()
    batch_time = AverageMeter()
    # initial
    kappa = 1
    factor = 1
    if train:
        model.train()
        if adv_net is not None:
            adv_net.train()
    else:
        model.eval()
        if adv_net is not None:
            adv_net.eval()
    # pregenerate the array for specifications, will be used for scatter
    if method == "robust":
        sa = np.zeros((num_class, num_class - 1), dtype=np.int32)
        for i in range(sa.shape[0]):
            for j in range(sa.shape[1]):
                if j < i:
                    sa[i][j] = j
                else:
                    sa[i][j] = j + 1
        sa = torch.LongTensor(sa)
    elif method == "adv":
        if kwargs["attack_type"] == "patch-random":
            attacker = PatchAttacker(model, loader.mean, loader.std, kwargs)
        elif kwargs["attack_type"] == "patch-strong":
            attacker = PatchAttacker(model, loader.mean, loader.std, kwargs)
        elif kwargs["attack_type"] == "PGD":
            attacker = PGDAttacker(model, loader.mean, loader.std, kwargs)

    # total = len(loader.dataset)
    # dataset sizes are hardcoded for this project's train/test splits
    if train:
        total = 352366
    else:
        total = 24119
    batch_size = loader.batch_size
    if train:
        batch_eps = np.linspace(
            start_eps, end_eps,
            total // (batch_size * args.grad_acc_steps) + 1)
        batch_eps = batch_eps.repeat(args.grad_acc_steps)
    else:
        batch_eps = np.linspace(start_eps, end_eps, total // (batch_size) + 1)

    if end_eps < 1e-6:
        logger.log('eps {} close to 0, using natural training'.format(end_eps))
        method = "natural"

    if train:
        iterator = enumerate(loader)
    else:
        iterator = tqdm(enumerate(loader))

    if train:
        opt.zero_grad()
        if unetopt is not None:
            unetopt.zero_grad()

    for i, (data, labels) in iterator:
        # convert to tensors and reorder NHWC -> NCHW
        data = torch.tensor(data)
        data = data.permute(0, 3, 1, 2).contiguous()
        labels = torch.tensor(labels)

        if "sample_limit" in kwargs and i * loader.batch_size > kwargs[
                "sample_limit"]:
            break
        start = time.time()
        eps = batch_eps[i]

        if method == "robust":
            # generate specifications
            c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
                1) - torch.eye(num_class).type_as(data).unsqueeze(0)
            # remove specifications to self
            I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
                labels.data).unsqueeze(0)))
            c = (c[I].view(data.size(0), num_class - 1, num_class))
            # scatter matrix to avoid computing margin to self
            sa_labels = sa[labels]
            # storing computed lower bounds after scatter
            lb_s = torch.zeros(data.size(0), num_class)

            # compute elementwise upper and lower bounds of the input
            if len(loader.std) == 1:
                std = torch.tensor([loader.std], dtype=torch.float)[:, None,
                                                                    None]
                mean = torch.tensor([loader.mean], dtype=torch.float)[:, None,
                                                                      None]
            elif len(loader.std) in (3, 14):
                std = torch.tensor(loader.std, dtype=torch.float)[None, :,
                                                                  None, None]
                mean = torch.tensor(loader.mean, dtype=torch.float)[None, :,
                                                                    None, None]
            else:
                raise ValueError('unexpected loader std/mean length')
            if kwargs["bound_type"] == "sparse-interval":
                data_ub = data
                data_lb = data
                eps = (eps / std).max()
            else:
                data_ub = (data + eps / std)
                data_lb = (data - eps / std)
                ub = ((1 - mean) / std)
                lb = (-mean / std)
                data_ub = torch.min(data_ub, ub)
                data_lb = torch.max(data_lb, lb)

            if list(model.parameters())[0].is_cuda:
                data_ub = data_ub.cuda()
                data_lb = data_lb.cuda()
                c = c.cuda()
                sa_labels = sa_labels.cuda()
                lb_s = lb_s.cuda()

        if list(model.parameters())[0].is_cuda:
            data = data.cuda()
            labels = labels.cuda()
        # the regular cross entropy
        if torch.cuda.device_count() > 1:
            output = nn.DataParallel(model)(data)
        else:
            output = model(data)

        regular_ce = CrossEntropyLoss()(output, labels)
        regular_ce_losses.update(regular_ce.cpu().detach().numpy(),
                                 data.size(0))
        errors.update(
            torch.sum(
                torch.argmax(output, dim=1) != labels).cpu().detach().numpy() /
            data.size(0), data.size(0))

        # the adversarial cross entropy
        if method == "adv":
            if kwargs["attack_type"] == "PGD":
                data_adv = attacker.perturb(data, labels, norm)
            elif kwargs["attack_type"] == "patch-random":
                data_adv = attacker.perturb(
                    data,
                    labels,
                    norm,
                    random_count=kwargs["random_mask_count"])
            else:
                raise RuntimeError("Unknown attack_type " +
                                   kwargs["bound_type"])
            output_adv = model(data_adv)
            adv_ce = CrossEntropyLoss()(output_adv, labels)
            adv_ce_losses.update(adv_ce.cpu().detach().numpy(), data.size(0))
            adv_errors.update(
                torch.sum(torch.argmax(output_adv, dim=1) != labels).cpu().
                detach().numpy() / data.size(0), data.size(0))

        if verbose or method == "robust":
            if kwargs["bound_type"] == "interval":
                ub, lb = model.interval_range(x_U=data_ub,
                                              x_L=data_lb,
                                              eps=eps,
                                              C=c)
            elif kwargs["bound_type"] == "sparse-interval":
                ub, lb = model.interval_range(x_U=data_ub,
                                              x_L=data_lb,
                                              eps=eps,
                                              C=c,
                                              k=kwargs["k"],
                                              Sparse=True)
            elif kwargs["bound_type"] == "patch-interval":
                if kwargs["attack_type"] == "patch-all" or kwargs[
                        "attack_type"] == "patch-all-pool":
                    if kwargs["attack_type"] == "patch-all":
                        width = data.shape[2] - kwargs["patch_w"] + 1
                        length = data.shape[3] - kwargs["patch_l"] + 1
                        pos_patch_count = width * length
                        final_bound_count = pos_patch_count
                    elif kwargs["attack_type"] == "patch-all-pool":
                        width = data.shape[2] - kwargs["patch_w"] + 1
                        length = data.shape[3] - kwargs["patch_l"] + 1
                        pos_patch_count = width * length
                        final_width = width
                        final_length = length
                        for neighbor in kwargs["neighbor"]:
                            final_width = ((final_width - 1) // neighbor + 1)
                            final_length = ((final_length - 1) // neighbor + 1)
                        final_bound_count = final_width * final_length

                    patch_idx = torch.arange(pos_patch_count,
                                             dtype=torch.long)[None, :]
                    if kwargs["attack_type"] == "patch-all" or kwargs[
                            "attack_type"] == "patch-all-pool":
                        x_cord = torch.zeros((1, pos_patch_count),
                                             dtype=torch.long)
                        y_cord = torch.zeros((1, pos_patch_count),
                                             dtype=torch.long)
                        idx = 0
                        for w in range(width):
                            for l in range(length):
                                x_cord[0, idx] = w
                                y_cord[0, idx] = l
                                idx = idx + 1

                    # expand the list to include coordinates from the complete patch
                    patch_idx = [patch_idx.flatten()]
                    x_cord = [x_cord.flatten()]
                    y_cord = [y_cord.flatten()]
                    for w in range(kwargs["patch_w"]):
                        for l in range(kwargs["patch_l"]):
                            patch_idx.append(patch_idx[0])
                            x_cord.append(x_cord[0] + w)
                            y_cord.append(y_cord[0] + l)

                    patch_idx = torch.cat(patch_idx, dim=0)
                    x_cord = torch.cat(x_cord, dim=0)
                    y_cord = torch.cat(y_cord, dim=0)

                    # create masks for each data point
                    # (torch.where requires a bool mask in recent PyTorch)
                    mask = torch.zeros(
                        [1, pos_patch_count, data.shape[2], data.shape[3]],
                        dtype=torch.bool)
                    mask[:, patch_idx, x_cord, y_cord] = True
                    mask = mask[:, :, None, :, :]
                    mask = mask.cuda()
                    data_ub = torch.where(mask, data_ub[:, None, :, :, :],
                                          data[:, None, :, :, :])
                    data_lb = torch.where(mask, data_lb[:, None, :, :, :],
                                          data[:, None, :, :, :])

                    # data_ub size (#data*#possible patches, #channels, width, length)
                    data_ub = data_ub.view(-1, *data_ub.shape[2:])
                    data_lb = data_lb.view(-1, *data_lb.shape[2:])

                    c = c.repeat_interleave(final_bound_count, dim=0)

                elif kwargs["attack_type"] == "patch-random" or kwargs[
                        "attack_type"] == "patch-nn":
                    # First calculate the number of considered patches
                    if kwargs["attack_type"] == "patch-random":
                        pos_patch_count = kwargs["patch_count"]
                        final_bound_count = pos_patch_count
                        c = c.repeat_interleave(pos_patch_count, dim=0)
                    elif kwargs["attack_type"] == "patch-nn":
                        class_count = 10
                        pos_patch_count = kwargs["patch_count"] * class_count
                        final_bound_count = pos_patch_count
                        c = c.repeat_interleave(pos_patch_count, dim=0)

                    # Create four lists that enumerate the coordinates of the top-left corner of each patch
                    # patch_idx, data_idx, x_cord, y_cord shape = (# of datapoints, # of possible patches)
                    patch_idx = torch.arange(pos_patch_count,
                                             dtype=torch.long)[None, :].repeat(
                                                 data.shape[0], 1)
                    data_idx = torch.arange(data.shape[0],
                                            dtype=torch.long)[:, None].repeat(
                                                1, pos_patch_count)
                    if kwargs["attack_type"] == "patch-random":
                        x_cord = torch.randint(
                            0, data.shape[2] - kwargs["patch_w"] + 1,
                            (data.shape[0], pos_patch_count))
                        y_cord = torch.randint(
                            0, data.shape[3] - kwargs["patch_l"] + 1,
                            (data.shape[0], pos_patch_count))
                    elif kwargs["attack_type"] == "patch-nn":
                        lbs_pred = adv_net(data)
                        # Take only the feasible location
                        lbs_pred = lbs_pred[:, :, 0:lbs_pred.size(2) -
                                            kwargs["patch_l"] + 1,
                                            0:lbs_pred.size(3) -
                                            kwargs["patch_w"] + 1]

                        lbs_pred = lbs_pred.reshape(
                            lbs_pred.size(0) * lbs_pred.size(1), -1)
                        # lbs_pred (# datapoints*# of classes, #flattened image dim)
                        select_prob = nn.Softmax(1)(-lbs_pred * kwargs["T"])
                        # select_prob (# datapoints*# of classes, #flattened image dim)
                        random_loc = torch.multinomial(select_prob,
                                                       kwargs["patch_count"],
                                                       replacement=False)
                        # random_loc (# datapoints*# of classes, patch_count)
                        random_loc = random_loc.view(data.size(0), -1)
                        # random_loc (# datapoints, # of classes*patch_count)

                        x_cord = random_loc % (data.size(3) -
                                               kwargs["patch_w"] + 1)
                        y_cord = random_loc // (data.size(2) -
                                                kwargs["patch_l"] + 1)

                    # expand the list to include coordinates from the complete patch
                    patch_idx = [patch_idx.flatten()]
                    data_idx = [data_idx.flatten()]
                    x_cord = [x_cord.flatten()]
                    y_cord = [y_cord.flatten()]
                    for w in range(kwargs["patch_w"]):
                        for l in range(kwargs["patch_l"]):
                            patch_idx.append(patch_idx[0])
                            data_idx.append(data_idx[0])
                            x_cord.append(x_cord[0] + w)
                            y_cord.append(y_cord[0] + l)

                    patch_idx = torch.cat(patch_idx, dim=0)
                    data_idx = torch.cat(data_idx, dim=0)
                    x_cord = torch.cat(x_cord, dim=0)
                    y_cord = torch.cat(y_cord, dim=0)

                    # create masks for each data point
                    # (torch.where requires a bool mask in recent PyTorch)
                    mask = torch.zeros([
                        data.shape[0], pos_patch_count, data.shape[2],
                        data.shape[3]
                    ],
                                       dtype=torch.bool)
                    mask[data_idx, patch_idx, x_cord, y_cord] = True
                    mask = mask[:, :, None, :, :]
                    mask = mask.cuda()
                    data_ub = torch.where(mask, data_ub[:, None, :, :, :],
                                          data[:, None, :, :, :])
                    data_lb = torch.where(mask, data_lb[:, None, :, :, :],
                                          data[:, None, :, :, :])

                    # data_ub size (#data*#possible patches, #channels, width, length)
                    data_ub = data_ub.view(-1, *data_ub.shape[2:])
                    data_lb = data_lb.view(-1, *data_lb.shape[2:])

                # forward pass all bounds
                if torch.cuda.device_count() > 1:
                    if kwargs["attack_type"] == "patch-all-pool":
                        ub, lb = nn.DataParallel(ParallelBoundPool(model))(
                            x_U=data_ub,
                            x_L=data_lb,
                            eps=eps,
                            C=c,
                            neighbor=kwargs["neighbor"],
                            pos_patch_width=width,
                            pos_patch_length=length)
                    else:
                        ub, lb = nn.DataParallel(ParallelBound(model))(
                            x_U=data_ub, x_L=data_lb, eps=eps, C=c)
                else:
                    if kwargs["attack_type"] == "patch-all-pool":
                        ub, lb = model.interval_range_pool(
                            x_U=data_ub,
                            x_L=data_lb,
                            eps=eps,
                            C=c,
                            neighbor=kwargs["neighbor"],
                            pos_patch_width=width,
                            pos_patch_length=length)
                    else:
                        ub, lb = model.interval_range(x_U=data_ub,
                                                      x_L=data_lb,
                                                      eps=eps,
                                                      C=c)

                # calculate unet loss
                if kwargs["attack_type"] == "patch-nn":
                    labels_mod = labels.repeat_interleave(pos_patch_count,
                                                          dim=0)
                    sa_labels_mod = sa[labels_mod]
                    sa_labels_mod = sa_labels_mod.cuda()
                    # storing computed lower bounds after scatter
                    lb_s_mod = torch.zeros(
                        data.size(0) * pos_patch_count, num_class).cuda()
                    lbs_actual = lb_s_mod.scatter(1, sa_labels_mod, lb)
                    # lbs_actual (# data * # of logits * # of classes, # of classes)

                    # lbs_pred (# datapoints*# of logits, #flattened image dim)
                    lbs_pred = lbs_pred.view(data.shape[0], num_class, -1)
                    # lbs_pred (# datapoints, # of logits, #flattened image dim)
                    lbs_pred = lbs_pred.permute(0, 2, 1)
                    # lbs_pred (# datapoints, #flattened image dim, # of logits)

                    # random_loc (# datapoints, # of logits*patch_count)
                    random_loc = random_loc.unsqueeze(2)
                    # repeat along the logit dimension so gather returns all
                    # num_class logits per sampled location
                    random_loc = random_loc.repeat_interleave(num_class, dim=2)
                    lbs_pred = lbs_pred.gather(1, random_loc)
                    # lbs_pred (# datapoints, # of logits*patch_count, # of logits)
                    lbs_pred = lbs_pred.view(-1, num_class)
                    # lbs_pred (# datapoints*# of logits*patch_count, # of logits)
                    unetloss = nn.MSELoss()(lbs_actual.detach(), lbs_pred)

                lb = lb.reshape(-1, final_bound_count, lb.shape[1])
                lb = torch.min(lb, dim=1)[0]
            else:
                raise RuntimeError("Unknown bound_type " +
                                   kwargs["bound_type"])
            lb = lb_s.scatter(1, sa_labels, lb)
            robust_ce = CrossEntropyLoss()(-lb, labels)

        if method == "robust":
            loss = robust_ce
        elif method == "natural":
            loss = regular_ce
        elif method == "adv":
            loss = adv_ce
        elif method == "robust_natural":
            natural_final_factor = kwargs["final-kappa"]
            kappa = (max_eps - eps * (1.0 - natural_final_factor)) / max_eps
            loss = (1 - kappa) * robust_ce + kappa * regular_ce
        else:
            raise ValueError("Unknown method " + method)

        if train:
            if unetloss is not None:
                unetloss.backward()
                unetlosses.update(unetloss.cpu().detach().numpy(),
                                  data.size(0))
            loss.backward()
            if (i + 1) % args.grad_acc_steps == 0 or i == len(loader) - 1:
                if unetloss is not None:
                    for p in adv_net.parameters():
                        p.grad /= args.grad_acc_steps
                    unetopt.step()
                for p in model.parameters():
                    p.grad /= args.grad_acc_steps
                opt.step()
                opt.zero_grad()

        batch_time.update(time.time() - start)

        losses.update(loss.cpu().detach().numpy(), data.size(0))

        if verbose or method == "robust":
            robust_ce_losses.update(robust_ce.cpu().detach().numpy(),
                                    data.size(0))
            robust_errors.update(
                torch.sum(
                    (lb < 0).any(dim=1)).cpu().detach().numpy() / data.size(0),
                data.size(0))
        if i % 50 == 0 and train:
            logger.log(
                '[{:2d}:{:4d}]: eps {:4f}  '
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                'Total Loss {loss.val:.4f} ({loss.avg:.4f})  '
                'Unet Loss {unetloss.val:.4f} ({unetloss.avg:.4f})  '
                'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f})  '
                'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f})  '
                'ACE {adv_ce_loss.val:.4f} ({adv_ce_loss.avg:.4f})  '
                'Err {errors.val:.4f} ({errors.avg:.4f})  '
                'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f})  '
                'Adv Err {adv_errors.val:.4f} ({adv_errors.avg:.4f})  '
                'beta {factor:.3f} ({factor:.3f})  '
                'kappa {kappa:.3f} ({kappa:.3f})  '.format(
                    t,
                    i,
                    eps,
                    batch_time=batch_time,
                    loss=losses,
                    unetloss=unetlosses,
                    errors=errors,
                    robust_errors=robust_errors,
                    adv_errors=adv_errors,
                    regular_ce_loss=regular_ce_losses,
                    robust_ce_loss=robust_ce_losses,
                    adv_ce_loss=adv_ce_losses,
                    factor=factor,
                    kappa=kappa))


    if method == "natural":
        return errors.avg, errors.avg
    else:
        return robust_errors.avg, errors.avg
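
All four variants share one mechanism for turning certified margin bounds into a loss: a specification matrix c whose rows compute logit[y] - logit[k] for every class k != y, and a scatter matrix sa that places each margin back into its class column (the margin to the true class stays 0, so CrossEntropyLoss()(-lb, labels) penalizes any class whose margin can go negative). A self-contained sketch of just that mechanism, with tiny shapes chosen for illustration:

import torch

num_class = 4
labels = torch.tensor([2, 0])          # true classes for a batch of 2
logits = torch.randn(2, num_class)

# specification matrix: row k of c[b] encodes e_y - e_k for the k-th other class
c = torch.eye(num_class)[labels].unsqueeze(1) - torch.eye(num_class).unsqueeze(0)
I = ~(labels.unsqueeze(1) == torch.arange(num_class).unsqueeze(0))
c = c[I].view(2, num_class - 1, num_class)    # drop the row for the true class

# applying c to the logits yields the margins logit[y] - logit[k] for all k != y
margins = torch.einsum('bkc,bc->bk', c, logits)

# sa[i] lists the class indices 0..num_class-1 with the true class i skipped
sa = torch.tensor([[j if j < i else j + 1 for j in range(num_class - 1)]
                   for i in range(num_class)])
lb_s = torch.zeros(2, num_class).scatter(1, sa[labels], margins)

# sanity check: for sample 0, column j holds logits[0, y0] - logits[0, j]
y0 = labels[0]
for j in range(num_class):
    if j != y0:
        assert torch.isclose(lb_s[0, j], logits[0, y0] - logits[0, j])
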
Code Example #2
File: train.py Project: ANazaret/CROWN-IBP
def Train(model, t, loader, eps_scheduler, max_eps, norm, logger, verbose,
          train, opt, method, **kwargs):
    # if train=True, use training mode
    # if train=False, use test mode, no back prop

    num_class = 10
    losses = AverageMeter()
    l1_losses = AverageMeter()
    errors = AverageMeter()
    robust_errors = AverageMeter()
    regular_ce_losses = AverageMeter()
    robust_ce_losses = AverageMeter()
    relu_activities = AverageMeter()
    bound_bias = AverageMeter()
    bound_diff = AverageMeter()
    unstable_neurons = AverageMeter()
    dead_neurons = AverageMeter()
    alive_neurons = AverageMeter()
    batch_time = AverageMeter()
    batch_multiplier = kwargs.get("batch_multiplier", 1)
    kappa = 1
    beta = 1
    if train:
        model.train()
    else:
        model.eval()
    # pregenerate the array for specifications, will be used for scatter
    sa = np.zeros((num_class, num_class - 1), dtype=np.int32)
    for i in range(sa.shape[0]):
        for j in range(sa.shape[1]):
            if j < i:
                sa[i][j] = j
            else:
                sa[i][j] = j + 1
    sa = torch.LongTensor(sa)
    batch_size = loader.batch_size * batch_multiplier
    if batch_multiplier > 1 and train:
        logger.log(
            'Warning: Large batch training. The equivalent batch size is {} * {} = {}.'
            .format(batch_multiplier, loader.batch_size, batch_size))
    # per-channel std and mean
    std = torch.tensor(loader.std).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    mean = torch.tensor(loader.mean).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    model_range = 0.0
    end_eps = eps_scheduler.get_eps(t + 1, 0)
    if end_eps < np.finfo(np.float32).tiny:
        logger.log('eps {} close to 0, using natural training'.format(end_eps))
        method = "natural"
    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps = eps_scheduler.get_eps(t, int(i // batch_multiplier))
        if train and i % batch_multiplier == 0:
            opt.zero_grad()
        # generate specifications
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
            1) - torch.eye(num_class).type_as(data).unsqueeze(0)
        # remove specifications to self
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
            labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0), num_class - 1, num_class))
        # scatter matrix to avoid computing the margin to self
        sa_labels = sa[labels]
        # storing computed lower bounds after scatter
        lb_s = torch.zeros(data.size(0), num_class)
        ub_s = torch.zeros(data.size(0), num_class)

        # FIXME: Assume unnormalized data is from range 0 - 1
        if kwargs["bounded_input"]:
            if norm != np.inf:
                raise ValueError(
                    "bounded input only makes sense for Linf perturbation. "
                    "Please set the bounded_input option to false.")
            data_max = torch.reshape((1. - mean) / std, (1, -1, 1, 1))
            data_min = torch.reshape((0. - mean) / std, (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / std), data_max)
            data_lb = torch.max(data - (eps / std), data_min)
        else:
            if norm == np.inf:
                data_ub = data + (eps / std)
                data_lb = data - (eps / std)
            else:
                # For other norms, eps will be used instead.
                data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data = data.cuda()
            data_ub = data_ub.cuda()
            data_lb = data_lb.cuda()
            labels = labels.cuda()
            c = c.cuda()
            sa_labels = sa_labels.cuda()
            lb_s = lb_s.cuda()
            ub_s = ub_s.cuda()
        # convert epsilon to a tensor
        eps_tensor = data.new(1)
        eps_tensor[0] = eps

        # omit the regular cross entropy, since we use robust error
        output = model(data,
                       method_opt="forward",
                       disable_multi_gpu=(method == "natural"))
        regular_ce = CrossEntropyLoss()(output, labels)
        regular_ce_losses.update(regular_ce.cpu().detach().numpy(),
                                 data.size(0))
        errors.update(
            torch.sum(
                torch.argmax(output, dim=1) != labels).cpu().detach().numpy() /
            data.size(0), data.size(0))
        # get range statistic
        model_range = output.max().detach().cpu().item() - output.min().detach(
        ).cpu().item()
        '''
        torch.set_printoptions(threshold=5000)
        print('prediction:  ', output)
        ub, lb, _, _, _, _ = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c, method_opt="interval_range")
        lb = lb_s.scatter(1, sa_labels, lb)
        ub = ub_s.scatter(1, sa_labels, ub)
        print('interval ub: ', ub)
        print('interval lb: ', lb)
        ub, _, lb, _ = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c, upper=True, lower=True, method_opt="backward_range")
        lb = lb_s.scatter(1, sa_labels, lb)
        ub = ub_s.scatter(1, sa_labels, ub)
        print('crown-ibp ub: ', ub)
        print('crown-ibp lb: ', lb) 
        ub, _, lb, _ = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c, upper=True, lower=True, method_opt="full_backward_range")
        lb = lb_s.scatter(1, sa_labels, lb)
        ub = ub_s.scatter(1, sa_labels, ub)
        print('full-crown ub: ', ub)
        print('full-crown lb: ', lb)
        input()
        '''

        if verbose or method != "natural":
            if kwargs["bound_type"] == "convex-adv":
                # Wong and Kolter's bound, or equivalently Fast-Lin
                if kwargs["convex-proj"] is not None:
                    proj = kwargs["convex-proj"]
                    if norm == np.inf:
                        norm_type = "l1_median"
                    elif norm == 2:
                        norm_type = "l2_normal"
                    else:
                        raise (ValueError(
                            "Unsupported norm {} for convex-adv".format(norm)))
                else:
                    proj = None
                    if norm == np.inf:
                        norm_type = "l1"
                    elif norm == 2:
                        norm_type = "l2"
                    else:
                        raise (ValueError(
                            "Unsupported norm {} for convex-adv".format(norm)))
                if loader.std == [1] or loader.std == [1, 1, 1]:
                    convex_eps = eps
                else:
                    convex_eps = eps / np.mean(loader.std)
                    # for CIFAR we are roughly / 0.2
                    # FIXME this is due to a bug in convex_adversarial, we cannot use per-channel eps
                if norm == np.inf:
                    # bounded input is only for Linf
                    if kwargs["bounded_input"]:
                        # FIXME the bounded projection in convex_adversarial has a bug, data range must be positive
                        assert loader.std == [1, 1, 1] or loader.std == [1]
                        data_l = 0.0
                        data_u = 1.0
                    else:
                        data_l = -np.inf
                        data_u = np.inf
                else:
                    data_l = data_u = None
                f = DualNetwork(model,
                                data,
                                convex_eps,
                                proj=proj,
                                norm_type=norm_type,
                                bounded_input=kwargs["bounded_input"],
                                data_l=data_l,
                                data_u=data_u)
                lb = f(c)
            elif kwargs["bound_type"] == "interval":
                ub, lb, relu_activity, unstable, dead, alive = model(
                    norm=norm,
                    x_U=data_ub,
                    x_L=data_lb,
                    eps=eps,
                    C=c,
                    method_opt="interval_range")
            elif kwargs["bound_type"] == "crown-full":
                _, _, lb, _ = model(norm=norm,
                                    x_U=data_ub,
                                    x_L=data_lb,
                                    eps=eps,
                                    C=c,
                                    upper=False,
                                    lower=True,
                                    method_opt="full_backward_range")
                unstable = dead = alive = relu_activity = torch.tensor([0])
            elif kwargs["bound_type"] == "crown-interval":
                # Enable multi-GPU only for the computationally expensive CROWN-IBP bounds,
                # not for regular forward propagation and IBP, where communication overhead can outweigh the benefit and give little speedup.
                ub, ilb, relu_activity, unstable, dead, alive = model(
                    norm=norm,
                    x_U=data_ub,
                    x_L=data_lb,
                    eps=eps,
                    C=c,
                    method_opt="interval_range")
                crown_final_beta = kwargs['final-beta']
                beta = (max_eps - eps * (1.0 - crown_final_beta)) / max_eps
                if beta < 1e-5:
                    lb = ilb
                else:
                    if kwargs["runnerup_only"]:
                        # regenerate a smaller c, with just the runner-up prediction
                        # mask ground truthlabel output, select the second largest class
                        # print(output)
                        # torch.set_printoptions(threshold=5000)
                        masked_output = output.detach().scatter(
                            1, labels.unsqueeze(-1), -100)
                        # print(masked_output)
                        # location of the runner up prediction
                        runner_up = masked_output.max(1)[1]
                        # print(runner_up)
                        # print(labels)
                        # get margin from the groud-truth to runner-up only
                        runnerup_c = torch.eye(num_class).type_as(data)[labels]
                        # print(runnerup_c)
                        # set the runner up location to -
                        runnerup_c.scatter_(1, runner_up.unsqueeze(-1), -1)
                        runnerup_c = runnerup_c.unsqueeze(1).detach()
                        # print(runnerup_c)
                        # get the bound for runnerup_c
                        _, _, clb, bias = model(norm=norm,
                                                x_U=data_ub,
                                                x_L=data_lb,
                                                eps=eps,
                                                C=c,
                                                method_opt="backward_range")
                        clb = clb.expand(clb.size(0), num_class - 1)
                    else:
                        # get the CROWN bound using interval bounds
                        _, _, clb, bias = model(norm=norm,
                                                x_U=data_ub,
                                                x_L=data_lb,
                                                eps=eps,
                                                C=c,
                                                method_opt="backward_range")
                        bound_bias.update(bias.sum() / data.size(0))
                    # how much better is crown-ibp than ibp?
                    diff = (clb - ilb).sum().item()
                    bound_diff.update(diff / data.size(0), data.size(0))
                    # lb = torch.max(lb, clb)
                    lb = clb * beta + ilb * (1 - beta)
            else:
                raise RuntimeError("Unknown bound_type " +
                                   kwargs["bound_type"])
            lb = lb_s.scatter(1, sa_labels, lb)
            robust_ce = CrossEntropyLoss()(-lb, labels)
            if kwargs["bound_type"] != "convex-adv":

                relu_activities.update(
                    relu_activity.sum().detach().cpu().item() / data.size(0),
                    data.size(0))
                unstable_neurons.update(
                    unstable.sum().detach().cpu().item() / data.size(0),
                    data.size(0))
                dead_neurons.update(
                    dead.sum().detach().cpu().item() / data.size(0),
                    data.size(0))
                alive_neurons.update(
                    alive.sum().detach().cpu().item() / data.size(0),
                    data.size(0))

        if method == "robust":
            loss = robust_ce
        elif method == "robust_activity":
            loss = robust_ce + kwargs["activity_reg"] * relu_activity.sum()
        elif method == "natural":
            loss = regular_ce
        elif method == "robust_natural":
            natural_final_factor = kwargs["final-kappa"]
            kappa = (max_eps - eps * (1.0 - natural_final_factor)) / max_eps
            loss = (1 - kappa) * robust_ce + kappa * regular_ce
        else:
            raise ValueError("Unknown method " + method)

        if train and kwargs["l1_reg"] > np.finfo(np.float32).tiny:
            reg = kwargs["l1_reg"]
            l1_loss = 0.0
            for name, param in model.named_parameters():
                if 'bias' not in name:
                    l1_loss = l1_loss + torch.sum(torch.abs(param))
            l1_loss = reg * l1_loss
            loss = loss + l1_loss
            l1_losses.update(l1_loss.cpu().detach().numpy(), data.size(0))
        if train:
            loss.backward()
            if i % batch_multiplier == 0 or i == len(loader) - 1:
                opt.step()

        losses.update(loss.cpu().detach().numpy(), data.size(0))

        if verbose or method != "natural":
            robust_ce_losses.update(robust_ce.cpu().detach().numpy(),
                                    data.size(0))
            # robust_ce_losses.update(robust_ce, data.size(0))
            robust_errors.update(
                torch.sum(
                    (lb < 0).any(dim=1)).cpu().detach().numpy() / data.size(0),
                data.size(0))

        batch_time.update(time.time() - start)
        if i % 50 == 0 and train:
            logger.log(
                '[{:2d}:{:4d}]: eps {:4f}  '
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                'Total Loss {loss.val:.4f} ({loss.avg:.4f})  '
                'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})  '
                'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f})  '
                'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f})  '
                'Err {errors.val:.4f} ({errors.avg:.4f})  '
                'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f})  '
                'Uns {unstable.val:.1f} ({unstable.avg:.1f})  '
                'Dead {dead.val:.1f} ({dead.avg:.1f})  '
                'Alive {alive.val:.1f} ({alive.avg:.1f})  '
                'Tightness {tight.val:.5f} ({tight.avg:.5f})  '
                'Bias {bias.val:.5f} ({bias.avg:.5f})  '
                'Diff {diff.val:.5f} ({diff.avg:.5f})  '
                'R {model_range:.3f}  '
                'beta {beta:.3f} ({beta:.3f})  '
                'kappa {kappa:.3f} ({kappa:.3f})  '.format(
                    t,
                    i,
                    eps,
                    batch_time=batch_time,
                    loss=losses,
                    errors=errors,
                    robust_errors=robust_errors,
                    l1_loss=l1_losses,
                    regular_ce_loss=regular_ce_losses,
                    robust_ce_loss=robust_ce_losses,
                    unstable=unstable_neurons,
                    dead=dead_neurons,
                    alive=alive_neurons,
                    tight=relu_activities,
                    bias=bound_bias,
                    diff=bound_diff,
                    model_range=model_range,
                    beta=beta,
                    kappa=kappa))

    logger.log('[FINAL RESULT epoch:{:2d} eps:{:.4f}]: '
               'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
               'Total Loss {loss.val:.4f} ({loss.avg:.4f})  '
               'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})  '
               'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f})  '
               'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f})  '
               'Uns {unstable.val:.3f} ({unstable.avg:.3f})  '
               'Dead {dead.val:.1f} ({dead.avg:.1f})  '
               'Alive {alive.val:.1f} ({alive.avg:.1f})  '
               'Tight {tight.val:.5f} ({tight.avg:.5f})  '
               'Bias {bias.val:.5f} ({bias.avg:.5f})  '
               'Diff {diff.val:.5f} ({diff.avg:.5f})  '
               'Err {errors.val:.4f} ({errors.avg:.4f})  '
               'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f})  '
               'R {model_range:.3f}  '
               'beta {beta:.3f} ({beta:.3f})  '
               'kappa {kappa:.3f} ({kappa:.3f})  \n'.format(
                   t,
                   eps,
                   batch_time=batch_time,
                   loss=losses,
                   errors=errors,
                   robust_errors=robust_errors,
                   l1_loss=l1_losses,
                   regular_ce_loss=regular_ce_losses,
                   robust_ce_loss=robust_ce_losses,
                   unstable=unstable_neurons,
                   dead=dead_neurons,
                   alive=alive_neurons,
                   tight=relu_activities,
                   bias=bound_bias,
                   diff=bound_diff,
                   model_range=model_range,
                   kappa=kappa,
                   beta=beta))
    for i, l in enumerate(
            model if isinstance(model, BoundSequential) else model.module):
        if isinstance(l, BoundLinear) or isinstance(l, BoundConv2d):
            norm = l.weight.data.detach().view(l.weight.size(0),
                                               -1).abs().sum(1).max().cpu()
            logger.log('layer {} norm {}'.format(i, norm))
    if method == "natural":
        return errors.avg, errors.avg
    else:
        return robust_errors.avg, errors.avg
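
Examples #2 and #4 drive the perturbation radius through an eps_scheduler object whose definition is not shown; only its get_eps(epoch, batch) interface appears above. A minimal linear-ramp sketch consistent with that interface (the ramp shape and the constructor parameters are assumptions):

class EpsilonScheduler(object):
    """Linear ramp: eps stays 0 until schedule_start, then grows to max_eps
    over schedule_length epochs. Only get_eps(epoch, batch) is taken from
    the calls above; everything else is an assumption."""

    def __init__(self, max_eps, schedule_start, schedule_length, num_batches):
        self.max_eps = max_eps
        self.schedule_start = schedule_start
        self.schedule_length = schedule_length
        self.num_batches = num_batches        # batches per epoch

    def get_eps(self, epoch, batch):
        # progress through the ramp, measured in fractional epochs
        progress = (epoch - self.schedule_start) + batch / self.num_batches
        frac = min(max(progress / self.schedule_length, 0.0), 1.0)
        return frac * self.max_eps

# usage sketch: eps = EpsilonScheduler(8 / 255, 1, 60, len(loader)).get_eps(t, i)
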
Code Example #3
def Train(model, t, loader, start_eps, end_eps, max_eps, logger, verbose, train, opt, method, **kwargs):
    # if train=True, use training mode
    # if train=False, use test mode, no back prop
    num_class = 10
    losses = AverageMeter()
    l1_losses = AverageMeter()
    errors = AverageMeter()
    robust_errors = AverageMeter()
    regular_ce_losses = AverageMeter()
    robust_ce_losses = AverageMeter()
    relu_activities = AverageMeter()
    bound_bias = AverageMeter()
    bound_diff = AverageMeter()
    unstable_neurons = AverageMeter()
    dead_neurons = AverageMeter()
    alive_neurons = AverageMeter()
    batch_time = AverageMeter()
    # initial 
    kappa = 1
    factor = 1
    if train:
        model.train()
    else:
        model.eval()
    # pregenerate the array for specifications, will be used for scatter
    sa = np.zeros((num_class, num_class - 1), dtype = np.int32)
    for i in range(sa.shape[0]):
        for j in range(sa.shape[1]):
            if j < i:
                sa[i][j] = j
            else:
                sa[i][j] = j + 1
    sa = torch.LongTensor(sa)
    total = len(loader.dataset)
    batch_size = loader.batch_size
    std = torch.tensor(loader.std).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    batch_eps = np.linspace(start_eps, end_eps, (total // batch_size) + 1)
    model_range = 0.0
    if end_eps < 1e-6:
        logger.log('eps {} close to 0, using natural training'.format(end_eps))
        method = "natural"
    for i, (data, labels) in enumerate(loader): 
        start = time.time()
        eps = batch_eps[i]
        if train:   
            opt.zero_grad()
        # generate specifications
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(0) 
        # remove specifications to self
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0),num_class-1,num_class))
        # scatter matrix to avoid computing the margin to self
        sa_labels = sa[labels]
        # storing computed lower bounds after scatter
        lb_s = torch.zeros(data.size(0), num_class)

        # FIXME: Assume data is from range 0 - 1
        if kwargs["bounded_input"]:
            assert loader.std == [1,1,1] or loader.std == [1]
            data_ub = (data + eps).clamp(max=1.0)
            data_lb = (data - eps).clamp(min=0.0)
        else:
            data_ub = data + (eps / std)
            data_lb = data - (eps / std)

        if list(model.parameters())[0].is_cuda:
            data = data.cuda()
            data_ub = data_ub.cuda()
            data_lb = data_lb.cuda()
            labels = labels.cuda()
            c = c.cuda()
            sa_labels = sa_labels.cuda()
            lb_s = lb_s.cuda()
        # convert epsilon to a tensor
        eps_tensor = data.new(1)
        eps_tensor[0] = eps

        # omit the regular cross entropy, since we use robust error
        output = model(data)
        regular_ce = CrossEntropyLoss()(output, labels)
        regular_ce_losses.update(regular_ce.cpu().detach().numpy(), data.size(0))
        errors.update(torch.sum(torch.argmax(output, dim=1)!=labels).cpu().detach().numpy()/data.size(0), data.size(0))
        # get range statistic
        model_range = output.max().detach().cpu().item() - output.min().detach().cpu().item()
        
        """
        ub, lb, _, _, _, _ = model.interval_range(data_lb, data_ub, c)
        lb = lb_s.scatter(1, sa_labels, lb)
        print('interval ub: ', ub)
        print('interval lb: ', lb)
        lb, _ = model.backward_range(data_lb, data_ub, c)
        lb = lb_s.scatter(1, sa_labels, lb)
        print('full lb: ', lb)
        input()
        """

        if verbose or method != "natural":
            if kwargs["bound_type"] == "convex-adv":
                # Wong and Kolter's bound, or equivalently Fast-Lin
                if kwargs["convex-proj"] is not None:
                    proj = kwargs["convex-proj"]
                    norm_type = "l1_median"
                else:
                    proj = None
                    norm_type = "l1"
                if loader.std == [1] or loader.std == [1, 1, 1]:
                    convex_eps = eps
                else:
                    convex_eps = eps / np.mean(loader.std)
                    # for CIFAR we are roughly / 0.2
                    # FIXME this is due to a bug in convex_adversarial, we cannot use per-channel eps
                if kwargs["bounded_input"]:
                    # FIXME the bounded projection in convex_adversarial has a bug, data range must be positive
                    data_l = 0.0
                    data_u = 1.0
                else:
                    data_l = -np.inf
                    data_u = np.inf
                f = DualNetwork(model, data, convex_eps, proj = proj, norm_type = norm_type, bounded_input = kwargs["bounded_input"], data_l = data_l, data_u = data_u)
                lb = f(c)
            elif kwargs["bound_type"] == "interval":
                ub, lb, relu_activity, unstable, dead, alive = model.interval_range(data_lb, data_ub, c)
            elif kwargs["bound_type"] == "crown-interval":
                ub, ilb, relu_activity, unstable, dead, alive = model.interval_range(data_lb, data_ub, c)
                crown_final_factor = kwargs['final-beta']
                factor = (max_eps - eps * (1.0 - crown_final_factor)) / max_eps
                if factor < 1e-5:
                    lb = ilb
                else:
                    if kwargs["runnerup_only"]:
                        # regenerate a smaller c, with just the runner-up prediction
                        # mask ground truthlabel output, select the second largest class
                        # print(output)
                        # torch.set_printoptions(threshold=5000)
                        masked_output = output.detach().scatter(1, labels.unsqueeze(-1), -100)
                        # print(masked_output)
                        # location of the runner up prediction
                        runner_up = masked_output.max(1)[1]
                        # print(runner_up)
                        # print(labels)
                        # get margin from the groud-truth to runner-up only
                        runnerup_c = torch.eye(num_class).type_as(data)[labels]
                        # print(runnerup_c)
                        # set the runner up location to -
                        runnerup_c.scatter_(1, runner_up.unsqueeze(-1), -1)
                        runnerup_c = runnerup_c.unsqueeze(1).detach()
                        # print(runnerup_c)
                        # get the bound for runnerup_c
                        clb, bias = model.backward_range(data_lb, data_ub, runnerup_c)
                        clb = clb.expand(clb.size(0), num_class - 1)
                    else:
                        # get the CROWN bound using interval bounds
                        clb, bias = model.backward_range(data_lb, data_ub, c)
                        bound_bias.update(bias.sum() / data.size(0))
                    # how much better is crown-ibp than ibp?
                    diff = (clb - ilb).sum().item()
                    bound_diff.update(diff / data.size(0), data.size(0))
                    # lb = torch.max(lb, clb)
                    lb = clb * factor + ilb * (1 - factor)
            else:
                raise RuntimeError("Unknown bound_type " + kwargs["bound_type"])

            lb = lb_s.scatter(1, sa_labels, lb)
            robust_ce = CrossEntropyLoss()(-lb, labels)
            if kwargs["bound_type"] != "convex-adv":
                relu_activities.update(relu_activity.detach().cpu().item() / data.size(0), data.size(0))
                unstable_neurons.update(unstable / data.size(0), data.size(0))
                dead_neurons.update(dead / data.size(0), data.size(0))
                alive_neurons.update(alive / data.size(0), data.size(0))

        if method == "robust":
            loss = robust_ce
        elif method == "robust_activity":
            loss = robust_ce + kwargs["activity_reg"] * relu_activity
        elif method == "natural":
            loss = regular_ce
        elif method == "robust_natural":
            natural_final_factor = kwargs["final-kappa"]
            kappa = (max_eps - eps * (1.0 - natural_final_factor)) / max_eps
            loss = (1-kappa) * robust_ce + kappa * regular_ce
        else:
            raise ValueError("Unknown method " + method)

        if "l1_reg" in kwargs:
            reg = kwargs["l1_reg"]
            l1_loss = 0.0
            for name, param in model.named_parameters():
                if 'bias' not in name:
                    l1_loss = l1_loss + (reg * torch.sum(torch.abs(param)))
            loss = loss + l1_loss
            l1_losses.update(l1_loss.cpu().detach().numpy(), data.size(0))
        if train:
            loss.backward()
            opt.step()

        batch_time.update(time.time() - start)
        losses.update(loss.cpu().detach().numpy(), data.size(0))

        if verbose or method != "natural":
            robust_ce_losses.update(robust_ce.cpu().detach().numpy(), data.size(0))
            # robust_ce_losses.update(robust_ce, data.size(0))
            robust_errors.update(torch.sum((lb<0).any(dim=1)).cpu().detach().numpy() / data.size(0), data.size(0))
        if i % 50 == 0 and train:
            logger.log(  '[{:2d}:{:4d}]: eps {:4f}  '
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Total Loss {loss.val:.4f} ({loss.avg:.4f})  '
                    'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})  '
                    'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f})  '
                    'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f})  '
                    'Err {errors.val:.4f} ({errors.avg:.4f})  '
                    'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f})  '
                    'Uns {unstable.val:.1f} ({unstable.avg:.1f})  '
                    'Dead {dead.val:.1f} ({dead.avg:.1f})  '
                    'Alive {alive.val:.1f} ({alive.avg:.1f})  '
                    'Tightness {tight.val:.5f} ({tight.avg:.5f})  '
                    'Bias {bias.val:.5f} ({bias.avg:.5f})  '
                    'Diff {diff.val:.5f} ({diff.avg:.5f})  '
                    'R {model_range:.3f}  '
                    'beta {factor:.3f} ({factor:.3f})  '
                    'kappa {kappa:.3f} ({kappa:.3f})  '.format(
                    t, i, eps, batch_time=batch_time,
                    loss=losses, errors=errors, robust_errors = robust_errors, l1_loss = l1_losses,
                    regular_ce_loss = regular_ce_losses, robust_ce_loss = robust_ce_losses, 
                    unstable = unstable_neurons, dead = dead_neurons, alive = alive_neurons,
                    tight = relu_activities, bias = bound_bias, diff = bound_diff,
                    model_range = model_range, 
                    factor=factor, kappa = kappa))
    
                    
    logger.log(  '[FINAL RESULT epoch:{:2d} eps:{:.4f}]: '
        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
        'Total Loss {loss.val:.4f} ({loss.avg:.4f})  '
        'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})  '
        'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f})  '
        'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f})  '
        'Uns {unstable.val:.3f} ({unstable.avg:.3f})  '
        'Dead {dead.val:.1f} ({dead.avg:.1f})  '
        'Alive {alive.val:.1f} ({alive.avg:.1f})  '
        'Tight {tight.val:.5f} ({tight.avg:.5f})  '
        'Bias {bias.val:.5f} ({bias.avg:.5f})  '
        'Diff {diff.val:.5f} ({diff.avg:.5f})  '
        'Err {errors.val:.4f} ({errors.avg:.4f})  '
        'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f})  '
        'R {model_range:.3f}  '
        'beta {factor:.3f} ({factor:.3f})  '
        'kappa {kappa:.3f} ({kappa:.3f})  \n'.format(
        t, eps, batch_time=batch_time,
        loss=losses, errors=errors, robust_errors = robust_errors, l1_loss = l1_losses,
        regular_ce_loss = regular_ce_losses, robust_ce_loss = robust_ce_losses, 
        unstable = unstable_neurons, dead = dead_neurons, alive = alive_neurons,
        tight = relu_activities, bias = bound_bias, diff = bound_diff,
        model_range = model_range, 
        kappa = kappa, factor=factor))
    for i, l in enumerate(model):
        if isinstance(l, BoundLinear) or isinstance(l, BoundConv2d):
            norm = l.weight.data.detach().view(l.weight.size(0), -1).abs().sum(1).max().cpu()
            logger.log('layer {} norm {}'.format(i, norm))
    if method == "natural":
        return errors.avg, errors.avg
    else:
        return robust_errors.avg, errors.avg
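
Every example accumulates statistics with AverageMeter objects, calling update(val, n) and reading .val and .avg inside the log format strings, but the class itself is never shown. A minimal implementation consistent with that usage (a sketch, not the projects' actual helper):

class AverageMeter(object):
    """Tracks the most recent value and a running, count-weighted average."""

    def __init__(self):
        self.val = 0.0     # last value passed to update()
        self.avg = 0.0     # running average, weighted by n
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
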
Code Example #4
File: warm_up_training.py Project: JmfanBU/AdvIBP
def epoch_train(model,
                t,
                loader,
                eps_scheduler,
                max_eps,
                norm,
                logger,
                verbose,
                train,
                opt,
                method,
                layer_idx=0,
                c_t=None,
                post_warm_up_scheduler=None,
                moment_grad=None,
                **kwargs):
    # if train=True, use training mode
    # if train=False, use test mode, no back prop
    num_class = 10
    losses = AverageMeter()
    errors = AverageMeter()
    adv_errors = AverageMeter()
    robust_errors = AverageMeter()
    regular_ce_losses = AverageMeter()
    robust_ce_losses = AverageMeter()
    batch_time = AverageMeter()
    batch_multiplier = kwargs.get("batch_multiplier", 1)
    coeff1, coeff2 = 1, 0
    beta = 1
    optimal = False
    c_eval = None
    g1_norm = 0

    if train:
        model.train()
    else:
        model.eval()
    # pregenerate the array for specifications, will be used for scatter
    sa = np.zeros((num_class, num_class - 1), dtype=np.int32)
    for i in range(sa.shape[0]):
        for j in range(sa.shape[1]):
            if j < i:
                sa[i][j] = j
            else:
                sa[i][j] = j + 1
    sa = torch.LongTensor(sa)
    batch_size = loader.batch_size * batch_multiplier
    if batch_multiplier > 1 and train:
        logger.log("Warning: Large batch training. The equivalent batch size "
                   "is {} * {} = {}.".format(batch_multiplier,
                                             loader.batch_size, batch_size))
    # per-channel std and mean
    std = torch.tensor(loader.std).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    mean = torch.tensor(loader.mean).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    model_range = 0.0
    end_eps = eps_scheduler.get_eps(t + 1, 0)
    end_post_warm_up_eps = post_warm_up_scheduler.get_eps(t + 1, 0)
    if end_eps < np.finfo(np.float32).tiny and \
            end_post_warm_up_eps < np.finfo(np.float32).tiny:
        logger.log("eps {} close to 0, using natural training".format(end_eps))
        method = "natural"
    elif end_post_warm_up_eps < np.finfo(np.float32).tiny:
        logger.log("adversarial training warm up phase")
        method = "warm_up"
    if kwargs["adversarial_training"]:
        attack = LinfPGDAttack(model,
                               kwargs.get("epsilon", max_eps),
                               kwargs["attack_steps"],
                               kwargs["attack_stepsize"],
                               kwargs["random_start"],
                               kwargs["loss_func"],
                               mean=mean,
                               std=std)

    pbar = tqdm(loader)
    for i, (data, labels) in enumerate(pbar):
        start = time.time()
        eps = eps_scheduler.get_eps(t, int(i // batch_multiplier))
        post_warm_up_eps = post_warm_up_scheduler.get_eps(
            t, int(i // batch_multiplier))
        if train and i % batch_multiplier == 0:
            opt.zero_grad()
        # upper bound matrix mask
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - \
            torch.eye(num_class).type_as(data).unsqueeze(0)
        # remove ground truth itself
        I_c = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
            labels.data).unsqueeze(0)))
        c = c[I_c].view(data.size(0), num_class - 1, num_class)
        # scatter matrix to avoid compute margin to itself
        sa_labels = sa[labels]
        # storing computed upper and lower bounds after scatter
        lb_s = torch.zeros(data.size(0), num_class)
        ub_s = torch.zeros(data.size(0), num_class)

        if kwargs["bounded_input"]:
            # provided data is from range 0 - 1
            if norm != np.inf:
                raise ValueError(
                    "Bounded input only makes sense for Linf perturbation. "
                    "Please set the bounded_input option to false.")
            data_max = torch.reshape((1. - mean) / std, (1, -1, 1, 1))
            data_min = torch.reshape((0. - mean) / std, (1, -1, 1, 1))
            data_ub = torch.min(data + (post_warm_up_eps / std), data_max)
            data_lb = torch.max(data - (post_warm_up_eps / std), data_min)
        else:
            if norm == np.inf:
                data_ub = data + (post_warm_up_eps / std)
                data_lb = data - (post_warm_up_eps / std)
            else:
                # For other norms, eps will be used instead
                data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data = data.cuda(device)
            data_ub = data_ub.cuda(device)
            data_lb = data_lb.cuda(device)
            labels = labels.cuda(device)
            c = c.cuda(device)
            sa_labels = sa_labels.cuda(device)
            lb_s = lb_s.cuda(device)
            ub_s = ub_s.cuda(device)

        # omit the regular cross entropy, since we use robust error
        if kwargs["adversarial_training"] and method != "natural" and \
                method != "warm_up":
            output = model(data,
                           method_opt="forward",
                           disable_multi_gpu=(method == "natural"))
            if layer_idx != 0 and train:
                layer_ub, layer_lb = model(norm=norm,
                                           x_U=data_ub,
                                           x_L=data_lb,
                                           eps=post_warm_up_eps,
                                           layer_idx=layer_idx,
                                           method_opt="interval_range",
                                           intermediate=True)
                layer_center, epsilon = intermediate_eps(layer_ub, layer_lb)
                layer_eps = epsilon
                data_adv, c_eval = attack.perturb(layer_center,
                                                  labels,
                                                  epsilon=layer_eps,
                                                  layer_idx=layer_idx,
                                                  c_t=c_t)
                output_adv = model(data_adv,
                                   method_opt="forward",
                                   layer_idx=layer_idx,
                                   disable_multi_gpu=(method == "natural"))
            else:
                data_adv, c_eval = attack.perturb(
                    data,
                    labels,
                    epsilon=eps,
                    layer_idx=layer_idx if train else 0,
                    c_t=c_t)
                output_adv = model(data_adv,
                                   method_opt="forward",
                                   disable_multi_gpu=(method == "natural"))
            # lower bound for adv training
            regular_ce = CrossEntropyLoss()(output_adv, labels)
        elif method == "warm_up":
            output = model(data,
                           method_opt="forward",
                           disable_multi_gpu=(method == "natural"))
            data_adv, c_eval = attack.perturb(data,
                                              labels,
                                              epsilon=post_warm_up_eps,
                                              layer_idx=layer_idx,
                                              c_t=c_t)
            output_adv = model(data_adv,
                               method_opt="forward",
                               disable_multi_gpu=(method == "natural"))
            regular_ce = CrossEntropyLoss()(output_adv, labels)
        else:
            output = model(data,
                           method_opt="forward",
                           disable_multi_gpu=(method == "natural"))
            regular_ce = CrossEntropyLoss()(output, labels)
        regular_ce_losses.update(regular_ce.cpu().detach().numpy(),
                                 data.size(0))
        errors.update(
            torch.sum(
                torch.argmax(output, dim=1) != labels).cpu().detach().numpy() /
            data.size(0), data.size(0))
        if kwargs["adversarial_training"] and method != "natural":
            adv_errors.update(
                torch.sum(torch.argmax(output_adv, dim=1) != labels).cpu().
                detach().numpy() / data.size(0), data.size(0))
        # get range statistics
        model_range = output.max().detach().cpu().item() - \
            output.min().detach().cpu().item()

        if verbose or (method != "natural" and method != "warm_up"):
            if kwargs["bound_type"] == "interval":
                ub, lb = model(norm=norm,
                               x_U=data_ub,
                               x_L=data_lb,
                               eps=post_warm_up_eps,
                               C=c,
                               layer_idx=0,
                               method_opt="interval_range")
            elif kwargs["bound_type"] == "crown-interval":
                ub, ilb = model(norm=norm,
                                x_U=data_ub,
                                x_L=data_lb,
                                eps=post_warm_up_eps,
                                C=c,
                                layer_idx=0,
                                method_opt="interval_range")
                crown_final_beta = kwargs['final-beta']
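                # beta decays linearly from 1 (pure CROWN bound) at eps = 0
                # down to final-beta once post_warm_up_eps reaches max_eps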
                beta = (max_eps - post_warm_up_eps *
                        (1. - crown_final_beta)) / max_eps if train else 0.
                if beta < 1e-5:
                    lb = ilb
                else:
                    # get the CROWN bound using interval bounds
                    _, _, clb, bias = model(norm=norm,
                                            x_U=data_ub,
                                            x_L=data_lb,
                                            eps=post_warm_up_eps,
                                            C=c,
                                            method_opt='backward_range')
                    lb = clb * beta + ilb * (1 - beta)
            else:
                raise RuntimeError("Unknown bound_type " +
                                   kwargs["bound_type"])
            lb = lb_s.scatter(1, sa_labels, lb)
            # upper bound for adv training
            robust_ce = CrossEntropyLoss()(-lb, labels)

        if method == "robust":
            if train:
                regular_grads = flat_grad(model, regular_ce)
                robust_grads = flat_grad(model, robust_ce)
                # post_warm_up flags that the eps schedule has reached
                # max_eps; it is needed in both branches below and again
                # when selecting the loss
                post_warm_up = (post_warm_up_eps == max_eps)
                if moment_grad is None:
                    coeff1, coeff2, optimal = two_obj_gradient(regular_grads,
                                                               robust_grads,
                                                               c_eval=c_eval,
                                                               c_t=c_t)
                else:
                    coeff1, coeff2, optimal, g1_norm = moment_grad.compute_coeffs(
                        regular_grads,
                        robust_grads,
                        c_eval=c_eval,
                        c_t=c_t,
                        post_warm_up=post_warm_up)
                if post_warm_up and optimal == 'opposite dir':
                    loss = coeff1 * regular_ce + coeff2 * robust_ce \
                        + 0.5 * robust_ce.pow(2)
                elif post_warm_up:
                    loss = coeff1 * regular_ce + coeff2 * robust_ce
                else:
                    # warm up with the crown-ibp bounds
                    loss = robust_ce
                model.zero_grad()
            else:
                loss = coeff1 * regular_ce + coeff2 * robust_ce
        elif method == "natural" or method == "warm_up":
            loss = regular_ce
        elif method == 'baseline':
            loss = regular_ce + robust_ce
        else:
            raise ValueError("Unknown method " + method)

        if train:
            loss.backward()
            if i % batch_multiplier == 0 or i == len(loader) - 1:
                opt.step()

        losses.update(loss.cpu().detach().numpy(), data.size(0))

        if verbose or (method != "natural" and method != "warm_up"):
            robust_ce_losses.update(robust_ce.cpu().detach().numpy(),
                                    data.size(0))
            robust_errors.update(
                torch.sum(
                    (lb < 0).any(dim=1)).cpu().detach().numpy() / data.size(0),
                data.size(0))

        batch_time.update(time.time() - start)

        if train:
            if c_eval is not None:
                pbar.set_description(
                    'Epoch: {}, eps: {:.3g}, c_eval: {:.3g}, '
                    'grad1_norm: {:.4g}, '
                    'coeff1: {:.2g}, coeff2: {:.2g}, '
                    'optimal: {}, R: {model_range:.2f}'.format(
                        t,
                        eps,
                        c_eval,
                        g1_norm,
                        coeff1,
                        coeff2,
                        optimal,
                        model_range=model_range,
                    ))
            else:
                pbar.set_description(
                    'Epoch: {}, eps: {:.3g}, '
                    'coeff1: {:.2g}, coeff2: {:.2g}, '
                    'optimal: {}, R: {model_range:.2f}'.format(
                        t,
                        eps,
                        coeff1,
                        coeff2,
                        optimal,
                        model_range=model_range,
                    ))
        else:
            pbar.set_description('Epoch: {}, eps: {:.3g}, '
                                 'Robust loss: {rb_loss.val:.2f}, '
                                 'Err: {errors.val:.3f}, '
                                 'Rob Err: {robust_errors.val:.3f}'.format(
                                     t,
                                     eps,
                                     model_range=model_range,
                                     rb_loss=robust_ce_losses,
                                     errors=errors,
                                     adv_errors=adv_errors,
                                     robust_errors=robust_errors))
    if kwargs["bound_type"] == "crown-interval":
        logger.log('----------Summary----------\n'
                   'Regular loss: {re_loss.avg:.2f}, '
                   'Robust loss: {rb_loss.avg:.2f}, '
                   'Beta: {beta:.2f}, '
                   'Err: {errors.avg:.3f}, '
                   'Adv Err: {adv_errors.avg:.3f}, '
                   'Rob Err: {robust_errors.avg:.3f},  '
                   'R: {model_range:.2f}'.format(loss=losses,
                                                 errors=errors,
                                                 model_range=model_range,
                                                 robust_errors=robust_errors,
                                                 adv_errors=adv_errors,
                                                 re_loss=regular_ce_losses,
                                                 rb_loss=robust_ce_losses,
                                                 beta=beta))
    else:
        logger.log('----------Summary----------\n'
                   'Regular loss: {re_loss.avg:.2f}, '
                   'Robust loss: {rb_loss.avg:.2f}, '
                   'Err: {errors.avg:.3f}, '
                   'Adv Err: {adv_errors.avg:.3f}, '
                   'Rob Err: {robust_errors.avg:.3f},  '
                   'R: {model_range:.2f}'.format(loss=losses,
                                                 errors=errors,
                                                 model_range=model_range,
                                                 robust_errors=robust_errors,
                                                 adv_errors=adv_errors,
                                                 re_loss=regular_ce_losses,
                                                 rb_loss=robust_ce_losses))
    if method == "natural" or method == "warm_up":
        return errors.avg, errors.avg
    else:
        return robust_errors.avg, errors.avg
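
Note that the bounded-input handling above does all of its interval arithmetic in normalized space: after (x - mean) / std normalization, the valid pixel range [0, 1] maps to [(0 - mean) / std, (1 - mean) / std], and a pixel-space radius eps becomes eps / std per channel. A minimal sketch of that conversion, using made-up CIFAR-like statistics rather than values from any of these repositories:

import torch

mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1)  # toy stats
std = torch.tensor([0.2470, 0.2435, 0.2616]).view(1, 3, 1, 1)
eps = 8.0 / 255.0                      # pixel-space Linf radius

x = torch.rand(2, 3, 32, 32)           # raw pixels in [0, 1]
x_norm = (x - mean) / std              # what the model actually sees

# valid normalized range and per-channel normalized radius
data_max = (1.0 - mean) / std
data_min = (0.0 - mean) / std
data_ub = torch.min(x_norm + eps / std, data_max)
data_lb = torch.max(x_norm - eps / std, data_min)

assert (data_lb <= data_ub).all()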
Code Example #5
File: train_general.py  Project: yeshaokai/auto_LiRPA
def Train(model, t, loader, start_eps, end_eps, max_eps, weights_eps_start,
          weights_eps_end, norm, logger, verbose, train, opt, method,
          **kwargs):
    # if train=True, use training mode
    # if train=False, use test mode, no back prop
    num_class = 10
    losses = AverageMeter()
    l1_losses = AverageMeter()
    errors = AverageMeter()
    robust_errors = AverageMeter()
    regular_ce_losses = AverageMeter()
    robust_ce_losses = AverageMeter()
    batch_time = AverageMeter()
    # initial
    if train:
        model.train()
    else:
        model.eval()
    # pregenerate the array for specifications, will be used for scatter
    sa = np.zeros((num_class, num_class - 1), dtype=np.int32)
    for i in range(sa.shape[0]):
        for j in range(sa.shape[1]):
            if j < i:
                sa[i][j] = j
            else:
                sa[i][j] = j + 1
    sa = torch.LongTensor(sa)
    total = len(loader.dataset)
    batch_size = loader.batch_size
    std = torch.tensor(loader.std).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    batch_eps = np.linspace(start_eps, end_eps, (total // batch_size) + 1)
    batch_weights_eps = np.zeros(
        ((total // batch_size) + 1, len(weights_eps_start)))
    for _i in range(len(weights_eps_start)):
        batch_weights_eps[:, _i] = np.linspace(weights_eps_start[_i],
                                               weights_eps_end[_i],
                                               (total // batch_size) + 1)
    model_range = 0.0
    if batch_weights_eps[-1, 0] == 0:
        logger.log('eps {} close to 0, using natural training'.format(end_eps))
        method = "natural"
    if train:
        opt.zero_grad()
    for i, (data, labels) in enumerate(loader):
        torch.cuda.empty_cache()
        if kwargs["bound_type"] == "weights-crown":
            data = data.reshape(data.shape[0], -1)
        start = time.time()
        eps = batch_eps[i]
        weights_eps = batch_weights_eps[i]
        # print(i, weights_eps, batch_eps)
        # if train:
        #     if not Grad_accum or i % Grad_accum_step == 0:
        #         # print('normal training without grad accsum')
        #         opt.zero_grad()
        # generate specifications
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
            1) - torch.eye(num_class).type_as(data).unsqueeze(0)
        # remove specifications to self
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
            labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0), num_class - 1, num_class))
        # scatter matrix to avoid compute margin to self
        sa_labels = sa[labels]
        # storing computed lower bounds after scatter
        lb_s = torch.zeros(data.size(0), num_class)
        ub_s = torch.zeros(data.size(0), num_class)

        # FIXME: Assume data is from range 0 - 1
        if kwargs["bounded_input"]:
            assert loader.std == [1, 1, 1] or loader.std == [1]
            if norm != np.inf:
                raise ValueError(
                    "bounded input only makes sense for Linf perturbation. "
                    "Please set the bounded_input option to false.")
            data_ub = (data + eps).clamp(max=1.0)
            data_lb = (data - eps).clamp(min=0.0)
        else:
            if norm == np.inf:
                data_ub = data + (eps / std)
                data_lb = data - (eps / std)
            else:
                data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data = data.cuda()
            data_ub = data_ub.cuda()
            data_lb = data_lb.cuda()
            labels = labels.cuda()
            c = c.cuda()
            sa_labels = sa_labels.cuda()
            lb_s = lb_s.cuda()
            ub_s = ub_s.cuda()
        # convert epsilon to a tensor
        eps_tensor = data.new(1)
        eps_tensor[0] = eps

        # omit the regular cross entropy, since we use robust error
        output = model(data)
        regular_ce = CrossEntropyLoss()(output, labels)
        regular_ce_losses.update(regular_ce.cpu().detach().numpy(),
                                 data.size(0))
        errors.update(
            torch.sum(
                torch.argmax(output, dim=1) != labels).cpu().detach().numpy() /
            data.size(0), data.size(0))
        # get range statistics
        model_range = (output.max().detach().cpu().item() -
                       output.min().detach().cpu().item())

        if kwargs["bound_type"] == "weights-crown":
            ptb = PerturbationLpNorm_2bounds(norm=norm, eps=eps)
        else:
            ptb = PerturbationLpNorm(norm=norm, eps=eps)

        if verbose or method != "natural":
            if kwargs["bound_type"] == "interval":
                lb, ub = model.compute_bounds(ptb=ptb,
                                              IBP=True,
                                              x=data,
                                              C=c,
                                              method=None)
            elif kwargs["bound_type"] == "crown-full":
                lb, ub = model.compute_bounds(ptb=ptb,
                                              IBP=False,
                                              x=data,
                                              C=c,
                                              method="backward")
            elif kwargs["bound_type"] == "weights-crown":
                lb, ub = model.weights_full_backward_range(ptb=ptb,
                                                           norm=norm,
                                                           x=data,
                                                           C=c,
                                                           eps=eps,
                                                           w_eps=weights_eps)
            elif kwargs["bound_type"] == "crown-interval":
                lb, ub = model.compute_bounds(ptb=ptb,
                                              IBP=True,
                                              x=data,
                                              C=c,
                                              method="backward")
            else:
                raise RuntimeError("Unknown bound_type " +
                                   kwargs["bound_type"])

            lb = lb_s.scatter(1, sa_labels, lb)
            robust_ce = CrossEntropyLoss()(-lb, labels)

        if method == "robust":
            loss = robust_ce
        elif method == "natural":
            loss = regular_ce
        else:
            raise ValueError("Unknown method " + method)

        if "l1_reg" in kwargs:
            reg = kwargs["l1_reg"]
            l1_loss = 0.0
            for name, param in model.named_parameters():
                if 'bias' not in name:
                    l1_loss = l1_loss + (reg * torch.sum(torch.abs(param)))
            loss = loss + l1_loss
            l1_losses.update(l1_loss.cpu().detach().numpy(), data.size(0))
        if train:
            loss.backward()
            if not Grad_accum or (
                    i + 1) % Grad_accum_step == 0 or i == len(loader) - 1:
                opt.step()
                opt.zero_grad()

        losses.update(loss.cpu().detach().numpy(), data.size(0))

        if verbose or method != "natural":
            robust_ce_losses.update(robust_ce.cpu().detach().numpy(),
                                    data.size(0))
            # robust_ce_losses.update(robust_ce, data.size(0))
            robust_errors.update(
                torch.sum(
                    (lb < 0).any(dim=1)).cpu().detach().numpy() / data.size(0),
                data.size(0))

        avg_weights = model.choices[0].weight.data.cpu().numpy()

        batch_time.update(time.time() - start)
        if i % 50 == 0 and train:
            logger.log(
                '[{:2d}:{:4d}]: eps {:4f}  '
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                'Total Loss {loss.val:.4f} ({loss.avg:.4f})  '
                'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})  '
                'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f})  '
                'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f})  '
                'Err {errors.val:.4f} ({errors.avg:.4f})  '
                'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f})  '
                'R {model_range:.3f}  '
                'layer1 {avg_weights:.3f} ({range_weights:.3f})  '.format(
                    t,
                    i,
                    eps,
                    batch_time=batch_time,
                    loss=losses,
                    errors=errors,
                    robust_errors=robust_errors,
                    l1_loss=l1_losses,
                    regular_ce_loss=regular_ce_losses,
                    robust_ce_loss=robust_ce_losses,
                    model_range=model_range,
                    avg_weights=np.abs(avg_weights).mean(),
                    range_weights=np.ptp(avg_weights)))

    # if Grad_accum and train:
    #     opt.step()

    logger.log('[FINAL RESULT epoch:{:2d} eps:{:.4f}]: '
               'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
               'Total Loss {loss.val:.4f} ({loss.avg:.4f})  '
               'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})  '
               'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f})  '
               'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f})  '
               'Err {errors.val:.4f} ({errors.avg:.4f})  '
               'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f})  '
               'R {model_range:.3f}  '
               'layer1 {avg_weights:.3f} ({range_weights:.3f})  \n'.format(
                   t,
                   eps,
                   batch_time=batch_time,
                   loss=losses,
                   errors=errors,
                   robust_errors=robust_errors,
                   l1_loss=l1_losses,
                   regular_ce_loss=regular_ce_losses,
                   robust_ce_loss=robust_ce_losses,
                   model_range=model_range,
                   avg_weights=np.abs(avg_weights).mean(),
                   range_weights=np.ptp(avg_weights)))
    # for i, l in enumerate(model.module()):
    #     if isinstance(l, BoundLinear) or isinstance(l, BoundConv2d):
    #         norm = l.weight.data.detach().view(l.weight.size(0), -1).abs().sum(1).max().cpu()
    #         logger.log('layer {} norm {}'.format(i, norm))
    if method == "natural":
        return errors.avg, errors.avg
    else:
        return robust_errors.avg, errors.avg
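
Examples #5 and #6 both blend the tighter but more expensive CROWN-style lower bound with the cheap IBP bound through a linearly decaying coefficient (lb = clb * factor + ilb * (1 - factor)). A worked sketch of that schedule, with fabricated bound values standing in for real network output:

import numpy as np

max_eps = 8.0 / 255.0
clb, ilb = -1.2, -3.5        # pretend CROWN and IBP lower bounds for one margin

for eps in np.linspace(0.0, max_eps, 5):
    factor = (max_eps - eps) / max_eps   # 1 at eps = 0, 0 at eps = max_eps
    lb = clb * factor + ilb * (1.0 - factor)
    print(f"eps={eps:.4f}  factor={factor:.2f}  mixed lb={lb:.3f}")
# early in the eps ramp the tighter CROWN bound dominates; at max_eps the
# loss is driven purely by the cheaper IBP bound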
Code Example #6
def Train(model,
          t,
          loader,
          start_eps,
          end_eps,
          max_eps,
          norm,
          train,
          opt,
          bound_type,
          method='robust'):
    num_class = 10
    meter = MultiAverageMeter()
    if train:
        model.train()
    else:
        model.eval()
    # Pre-generate the array for specifications, will be used later for scatter
    sa = np.zeros((num_class, num_class - 1), dtype=np.int32)
    for i in range(sa.shape[0]):
        for j in range(sa.shape[1]):
            if j < i:
                sa[i][j] = j
            else:
                sa[i][j] = j + 1
    sa = torch.LongTensor(sa)
    total = len(loader.dataset)
    batch_size = loader.batch_size

    # Increase epsilon batch by batch
    batch_eps = np.linspace(start_eps, end_eps, (total // batch_size) + 1)
    # For small eps just use natural training, no need to compute LiRPA bounds
    if end_eps < 1e-6: method = "natural"

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps = batch_eps[i]
        if train:
            opt.zero_grad()
        # generate specifications
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
            1) - torch.eye(num_class).type_as(data).unsqueeze(0)
        # remove specifications to self
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
            labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0), num_class - 1, num_class))
        # scatter matrix to avoid compute margin to self
        sa_labels = sa[labels]
        # storing computed lower bounds after scatter
        lb_s = torch.zeros(data.size(0), num_class)
        # bound input for Linf norm used only
        if norm == np.inf:
            data_ub = (data + eps).clamp(max=1.0)
            data_lb = (data - eps).clamp(min=0.0)
        else:
            data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data, labels, sa_labels, c, lb_s = data.cuda(), labels.cuda(
            ), sa_labels.cuda(), c.cuda(), lb_s.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        output = model(data)
        regular_ce = CrossEntropyLoss()(
            output, labels)  # regular CrossEntropyLoss used for warming up
        meter.update('CE', regular_ce.cpu().detach().numpy(), data.size(0))
        meter.update(
            'Err',
            torch.sum(
                torch.argmax(output, dim=1) != labels).cpu().detach().numpy() /
            data.size(0), data.size(0))

        # Specify Lp norm perturbation.
        # When using Linf perturbation, we manually set element-wise bound x_L and x_U. eps is not used for Linf norm.
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        if method == "robust":
            if bound_type == "IBP":
                lb, ub = model.compute_bounds(ptb=ptb,
                                              IBP=True,
                                              x=data,
                                              C=c,
                                              method=None)
            elif bound_type == "CROWN":
                lb, ub = model.compute_bounds(ptb=ptb,
                                              IBP=False,
                                              x=data,
                                              C=c,
                                              method="backward")
            elif bound_type == "CROWN-IBP":
                # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method="backward")  # pure IBP bound
                # we use a mix of IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
                factor = (max_eps - eps) / max_eps
                ilb, iub = model.compute_bounds(ptb=ptb,
                                                IBP=True,
                                                x=data,
                                                C=c,
                                                method=None)
                if factor < 1e-5:
                    lb = ilb
                else:
                    clb, cub = model.compute_bounds(ptb=ptb,
                                                    IBP=False,
                                                    x=data,
                                                    C=c,
                                                    method="backward")
                    lb = clb * factor + ilb * (1 - factor)

            # Fill in the missing 0 in lb. The margin from class j to itself is always 0 and is not computed.
            lb = lb_s.scatter(1, sa_labels, lb)
            # Use the robust cross entropy loss objective (Wong & Kolter, 2018)
            robust_ce = CrossEntropyLoss()(-lb, labels)
        if method == "robust":
            loss = robust_ce
        elif method == "natural":
            loss = regular_ce
        if train:
            loss.backward()
            opt.step()
        meter.update('Loss', loss.cpu().detach().numpy(), data.size(0))
        if method != "natural":
            meter.update('Robust_CE',
                         robust_ce.cpu().detach().numpy(), data.size(0))
            # For example, if the lower bounds of the margins are > 0 for all
            # classes, the output is verifiably correct.
            # If any margin is < 0, this example is counted as an error.
            meter.update(
                'Verified_Err',
                torch.sum(
                    (lb < 0).any(dim=1)).cpu().detach().numpy() / data.size(0),
                data.size(0))
        meter.update('Time', time.time() - start)
        if i % 50 == 0 and train:
            print('[{:2d}:{:4d}]: eps={:4f} {}'.format(t, i, eps, meter))

    print('[FINAL RESULT] epoch={:2d} eps={:.4f} {}'.format(t, eps, meter))
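
The loss shared by all of these loops is the robust cross-entropy of Wong & Kolter (2018): the scattered margin lower bounds are negated and passed to CrossEntropyLoss, so minimizing the loss pushes every verified margin above zero. A minimal sketch with fabricated bounds:

import torch
from torch.nn import CrossEntropyLoss

# lb[b, j] = verified lower bound of (logit_label - logit_j); the column for
# the true class itself is 0 after the scatter step
lb = torch.tensor([[0.0, 0.8, -0.4],
                   [1.2, 0.0, 0.6]])
labels = torch.tensor([0, 1])

robust_ce = CrossEntropyLoss()(-lb, labels)
# -lb treats "how negative each margin can get" as the logits of the wrong
# classes; an example is verified iff (lb < 0).any(dim=1) is False
verified_err = (lb < 0).any(dim=1).float().mean()
print(robust_ce.item(), verified_err.item())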