def train(train_loader,
          model,
          m,
          criterion,
          optimizer,
          attack,
          device,
          epoch,
          callback=None):
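    """Run one epoch of (optionally adversarial) training.

    `model` is the network used for forward passes; `m` is assumed to be the
    underlying module whose parameters get frozen/unfrozen around attack
    generation (e.g. the raw model inside a normalization or DataParallel
    wrapper). Once `args.adv` is set and `epoch >= args.start_adv_epoch`,
    the optimized loss switches from the clean batch to adversarial examples
    produced by `attack`.
    """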
    model.train()
    cudnn.benchmark = True
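    # cudnn.benchmark lets cuDNN auto-tune convolution algorithms, which
    # speeds training up when input shapes stay fixed across batches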
    length = len(train_loader)

    batch_time = utils.AverageMeter()
    losses = utils.AverageMeter()
    losses_adv = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    adv_acc = utils.AverageMeter()
    l2_adv = utils.AverageMeter()
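    # The meters above come from this repo's utils: values are accumulated
    # with append(), and last_avg presumably averages what was seen since the
    # last read, so the periodic prints below report per-interval statistics.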

    end = time.time()
    for i, (data, labels) in enumerate(train_loader):
        data = data.to(device)
        labels = labels.to(device, non_blocking=True)

        if args.adv and epoch >= args.start_adv_epoch:
            # craft attacks in eval mode with frozen parameters
            model.eval()
            utils.requires_grad_(m, False)

            clean_logits = model(data)
            loss = criterion(clean_logits, labels)

            adv = attack.attack(model, data, labels)
            # use data.size(0): the last batch can be smaller than args.batch_size
            l2_norms = (adv - data).view(data.size(0), -1).norm(2, 1)
            mean_norm = l2_norms.mean()
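            # Optionally project each perturbation back onto the L2 ball of
            # radius args.max_norm: torch.renorm clamps the 2-norm of every
            # slice along dim 0 (each sample) to at most maxnorm.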
            if args.max_norm:
                adv = torch.renorm(
                    adv - data, p=2, dim=0, maxnorm=args.max_norm) + data
            l2_adv.append(mean_norm.item())

            utils.requires_grad_(m, True)
            model.train()

            # detach adv so backprop only sees the model's forward pass on the
            # fixed adversarial inputs, not the attack's own graph
            adv_logits = model(adv.detach())
            loss_adv = criterion(adv_logits, labels)

            loss_to_optimize = loss_adv

            losses_adv.append(loss_adv.item())
            # the perturbation norm was already recorded above, pre-clamp
            adv_acc.append(
                (adv_logits.argmax(1) == labels).float().mean().item())
        else:
            clean_logits = model(data)
            loss = criterion(clean_logits, labels)
            loss_to_optimize = loss

        optimizer.zero_grad()
        loss_to_optimize.backward()
        optimizer.step()

        # measure accuracy and record loss
        prec1, prec5 = utils.accuracy(clean_logits, labels, topk=(1, 5))
        losses.append(loss.item())
        top1.append(prec1)
        top5.append(prec5)

        # measure elapsed time
        batch_time.append(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0 or (i + 1) == length:
            if args.adv and epoch >= args.start_adv_epoch:
                print(
                    'Epoch: [{0:>2d}][{1:>3d}/{2:>3d}] Time {batch_time.last_avg:.3f}'
                    '\tLoss {loss.last_avg:.4f}\tAdv {loss_adv.last_avg:.4f}'
                    '\tPrec@1 {top1.last_avg:.3%}\tPrec@5 {top5.last_avg:.3%}'.
                    format(epoch,
                           i + 1,
                           len(train_loader),
                           batch_time=batch_time,
                           loss=losses,
                           loss_adv=losses_adv,
                           top1=top1,
                           top5=top5))
            else:
                print(
                    'Epoch: [{0:>2d}][{1:>3d}/{2:>3d}] Time {batch_time.last_avg:.3f}\tLoss {loss.last_avg:.4f}'
                    '\tPrec@1 {top1.last_avg:.3%}\tPrec@5 {top5.last_avg:.3%}'.
                    format(epoch,
                           i + 1,
                           len(train_loader),
                           batch_time=batch_time,
                           loss=losses,
                           top1=top1,
                           top5=top5))

            if callback:
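                # callback is presumably a visdom-style logger; scalars are
                # plotted against the fractional epoch i / length + epoch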
                if args.adv and epoch >= args.start_adv_epoch:
                    callback.scalars(['train_loss', 'adv_loss'],
                                     i / length + epoch,
                                     [losses.last_avg, losses_adv.last_avg])
                    callback.scalars(
                        ['train_prec@1', 'train_prec@5', 'adv_acc'],
                        i / length + epoch, [
                            top1.last_avg * 100, top5.last_avg * 100,
                            adv_acc.last_avg * 100
                        ])
                    callback.scalar('adv_l2', i / length + epoch,
                                    l2_adv.last_avg)

                else:
                    callback.scalar('train_loss', i / length + epoch,
                                    losses.last_avg)
                    callback.scalars(
                        ['train_prec@1', 'train_prec@5'], i / length + epoch,
                        [top1.last_avg * 100, top5.last_avg * 100])

# assumes `from torch.optim import SGD, lr_scheduler`
optimizer = SGD(model.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
scheduler = lr_scheduler.StepLR(optimizer,
                                step_size=args.lr_step,
                                gamma=args.lr_decay)
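# StepLR multiplies the learning rate by gamma once every step_size epochs.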

attacker = DDN(steps=args.steps, device=DEVICE)
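# DDN (Decoupling Direction and Norm, Rony et al.) adapts the perturbation
# norm during the attack instead of searching over a fixed budget; `steps`
# sets the number of attack iterations.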

best_acc = 0
best_epoch = 0

for epoch in range(args.epochs):
    cudnn.benchmark = True
    model.train()
    requires_grad_(model, True)
    accs = AverageMeter()
    losses = AverageMeter()
    attack_norms = AverageMeter()

    scheduler.step()
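    # NOTE: stepping the scheduler at the start of the epoch follows the
    # pre-1.1.0 PyTorch convention; current releases expect scheduler.step()
    # to be called after optimizer.step().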
    length = len(train_loader)
    for i, (images, labels) in enumerate(tqdm.tqdm(train_loader, ncols=80)):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        if args.adv is not None and epoch >= args.adv:
            model.eval()
            requires_grad_(model, False)
            with torch.no_grad():
                accs.append(
                    (model(images).argmax(1) == labels).float().mean().item())
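            # A minimal sketch of the remaining loop body, mirroring the
            # adversarial branch of train() above (assumes
            # `import torch.nn.functional as F` and the same args):
            adv = attacker.attack(model, images, labels)
            attack_norms.append(
                (adv - images).view(images.size(0), -1).norm(p=2, dim=1).mean().item())
            if args.max_norm:
                adv = torch.renorm(adv - images, p=2, dim=0,
                                   maxnorm=args.max_norm) + images
            requires_grad_(model, True)
            model.train()
            logits = model(adv.detach())
        else:
            logits = model(images)

        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())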

valacc_final = 0

max_loss = torch.log(torch.tensor(1000.)).item()  # for callback
best_acc = 0
best_epoch = 0

for epoch in range(args.epochs):
    scheduler.step()
    cudnn.benchmark = True
    model.train()
    requires_grad_(m, True)
    accs = AverageMeter()
    losses = AverageMeter()
    attack_norms = AverageMeter()
    # print("len(train_loader) ", len(train_loader))
    length = len(train_loader)
    widgets = ['train :', Percentage(), ' ', Bar('#'), ' ', Timer(),
               ' ', ETA(), ' ', FileTransferSpeed()]
    pbar = ProgressBar(widgets=widgets)
    for i, batch_data in enumerate(tqdm.tqdm(train_loader, ncols=80)):
        images, labels = batch_data['image'], batch_data['label_idx'].to(DEVICE)
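        # _jpeg_compression3 presumably re-encodes the batch with JPEG as an
        # input-transformation defense before the forward pass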
        images = _jpeg_compression3(images).to(DEVICE)
        # imh_cam = images[0]
        # imh_cam = imh_cam.cpu().numpy()  # convert the FloatTensor to an ndarray
        # imh_cam = np.transpose(imh_cam, (1, 2, 0))  # move the channel dim last
        # plt.imshow(imh_cam)