import time

import torch

# AverageMeter, accuracy, mixup_data, mixup_criterion and the admm module are
# provided by the surrounding repository (illustrative sketches of some helpers
# follow the functions below).


def train(train_loader, config, ADMM, criterion, optimizer, scheduler, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    config.model.train()

    # admm.admm_update below needs a device handle; take it from the model
    device = next(config.model.parameters()).device

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # adjust learning rate
        if config.admm:
            admm.admm_adjust_learning_rate(optimizer, epoch, config)
        else:
            scheduler.step()

        input = input.cuda(config.gpu, non_blocking=True)
        target = target.cuda(config.gpu)
        data = input  # keep a reference to the unmixed batch for the ADMM update

        if config.mixup:
            input, target_a, target_b, lam = mixup_data(input, target, config.alpha)

        # compute output
        output = config.model(input)

        if config.mixup:
            ce_loss = mixup_criterion(criterion, output, target_a, target_b, lam,
                                      config.smooth)
        else:
            ce_loss = criterion(output, target, smooth=config.smooth)

        if config.admm:
            admm.admm_update(config, ADMM, device, train_loader, optimizer,
                             epoch, data, i)  # update Z and U
            ce_loss, admm_loss, mixed_loss = admm.append_admm_loss(
                config, ADMM, ce_loss)  # append admm loss

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(ce_loss.item(), input.size(0))
        top1.update(acc1[0], input.size(0))
        top5.update(acc5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()

        if config.admm:
            mixed_loss.backward()
        else:
            ce_loss.backward()

        # zero out gradients of pruned weights so they stay pruned
        if config.masked_progressive:
            with torch.no_grad():
                for name, W in config.model.named_parameters():
                    if name in config.zero_masks:
                        W.grad *= config.zero_masks[name]

        if config.masked_retrain:
            with torch.no_grad():
                for name, W in config.model.named_parameters():
                    if name in config.masks:
                        W.grad *= config.masks[name]

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
            print("cross_entropy loss: {}".format(ce_loss))
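# The meters and the accuracy() helper used above come from the repository's
# utilities; the sketch below is one plausible implementation consistent with
# how they are called (it follows the common PyTorch ImageNet-example pattern).
# The repository's exact versions may differ.
import torch


class AverageMeter(object):
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # val is the batch statistic, n the batch size it was computed over
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the top-k accuracy (in percent) for each k in topk."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # indices of the maxk largest logits per sample, shape (maxk, batch)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res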
def train(config, ADMM, device, train_loader, criterion, optimizer, scheduler, epoch):
    config.model.train()
    ce_loss = None
    for batch_idx, (data, target) in enumerate(train_loader):
        # adjust learning rate
        if config.admm:
            admm.admm_adjust_learning_rate(optimizer, epoch, config)
        else:
            if scheduler is not None:
                scheduler.step()

        data, target = data.to(device), target.to(device)
        if config.gpu is not None:
            data = data.cuda(config.gpu, non_blocking=True)
            target = target.cuda(config.gpu, non_blocking=True)

        if config.mixup:
            data, target_a, target_b, lam = mixup_data(data, target, config.alpha)

        optimizer.zero_grad()
        output = config.model(data)

        if config.mixup:
            ce_loss = mixup_criterion(criterion, output, target_a, target_b, lam,
                                      config.smooth)
        else:
            ce_loss = criterion(output, target, smooth=config.smooth)

        if config.admm:
            admm.admm_update(config, ADMM, device, train_loader, optimizer,
                             epoch, data, batch_idx)  # update Z and U
            ce_loss, admm_loss, mixed_loss = admm.append_admm_loss(
                config, ADMM, ce_loss)  # append admm loss

        if config.admm:
            mixed_loss.backward()
        else:
            ce_loss.backward()

        # zero out gradients of pruned weights so they stay pruned
        if config.masked_progressive:
            with torch.no_grad():
                for name, W in config.model.named_parameters():
                    if name in config.zero_masks:
                        W.grad *= config.zero_masks[name]

        if config.masked_retrain:
            with torch.no_grad():
                for name, W in config.model.named_parameters():
                    if name in config.masks:
                        W.grad *= config.masks[name]

        optimizer.step()

        if batch_idx % config.print_freq == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), ce_loss.item()))
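# mixup_data and mixup_criterion are assumed by the train loops above but not
# defined here; the sketch below follows the standard mixup recipe (Zhang et
# al.), adapted to a criterion that takes a `smooth` keyword as in the calls
# above. Treat it as an illustration, not the repository's exact implementation.
import numpy as np
import torch


def mixup_data(x, y, alpha=1.0):
    """Returns mixed inputs, the two target sets, and the mixing weight lam."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam, smooth):
    """Interpolates the loss between the two target sets with weight lam."""
    return (lam * criterion(pred, y_a, smooth=smooth)
            + (1 - lam) * criterion(pred, y_b, smooth=smooth))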
def train(train_loader, criterion, optimizer, scheduler, epoch, config, ADMM):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    nat_losses = AverageMeter()
    adv_losses = AverageMeter()
    nat_loss = 0
    adv_loss = 0
    nat_top1 = AverageMeter()
    adv_top1 = AverageMeter()

    # switch to train mode
    config.model.train()

    # admm.admm_update below needs a device handle; take it from the model
    device = next(config.model.parameters()).device

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # adjust learning rate
        if config.admm:
            admm.admm_adjust_learning_rate(optimizer, epoch, config)
        else:
            scheduler.step()

        if config.gpu is not None:
            input = input.cuda(config.gpu, non_blocking=True)
            target = target.cuda(config.gpu, non_blocking=True)

        if config.mixup:
            input, target_a, target_b, lam = mixup_data(input, target, config.alpha)

        # compute output: the model returns natural and adversarial logits
        # plus the perturbed inputs
        nat_output, adv_output, pert_inputs = config.model(input, target)

        if config.mixup:
            adv_loss = mixup_criterion(criterion, adv_output, target_a, target_b,
                                       lam, config.smooth)
            nat_loss = mixup_criterion(criterion, nat_output, target_a, target_b,
                                       lam, config.smooth)
        else:
            adv_loss = criterion(adv_output, target, smooth=config.smooth)
            nat_loss = criterion(nat_output, target, smooth=config.smooth)

        if config.admm:
            admm.admm_update(config, ADMM, device, train_loader, optimizer,
                             epoch, input, i)  # update Z and U
            adv_loss, admm_loss, mixed_loss = admm.append_admm_loss(
                config, ADMM, adv_loss)  # append admm loss

        # measure accuracy and record loss
        nat_acc1, _ = accuracy(nat_output, target, topk=(1, 5))
        adv_acc1, _ = accuracy(adv_output, target, topk=(1, 5))
        nat_losses.update(nat_loss.item(), input.size(0))
        adv_losses.update(adv_loss.item(), input.size(0))
        adv_top1.update(adv_acc1[0], input.size(0))
        nat_top1.update(nat_acc1[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()

        if config.admm:
            mixed_loss.backward()
        else:
            adv_loss.backward()

        # zero out gradients of pruned weights so they stay pruned
        if config.masked_progressive:
            with torch.no_grad():
                for name, W in config.model.named_parameters():
                    if name in config.zero_masks:
                        W.grad *= config.zero_masks[name]

        if config.masked_retrain:
            with torch.no_grad():
                for name, W in config.model.named_parameters():
                    if name in config.masks:
                        # masks are binary tensors that are 1 where a weight is
                        # above the pruning threshold and 0 otherwise
                        W.grad *= config.masks[name]

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Nat_Loss {nat_loss.val:.4f} ({nat_loss.avg:.4f})\t'
                  'Nat_Acc@1 {nat_top1.val:.3f} ({nat_top1.avg:.3f})\t'
                  'Adv_Loss {adv_loss.val:.4f} ({adv_loss.avg:.4f})\t'
                  'Adv_Acc@1 {adv_top1.val:.3f} ({adv_top1.avg:.3f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, nat_loss=nat_losses, nat_top1=nat_top1,
                      adv_loss=adv_losses, adv_top1=adv_top1))
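# A minimal sketch of how the gradient masks consumed during masked retraining
# above could be produced after ADMM hard-pruning: each mask is 1 where a weight
# survived pruning and 0 where it was zeroed, so multiplying W.grad by the mask
# keeps pruned weights at exactly zero while the remaining weights are retrained.
# The helper name build_retrain_masks and the plain magnitude threshold are
# assumptions for illustration, not the repository's API.
import torch


def build_retrain_masks(model, threshold=1e-8):
    """Hypothetical helper: one {param_name: 0/1 tensor} mask per weight tensor."""
    masks = {}
    with torch.no_grad():
        for name, W in model.named_parameters():
            if W.dim() > 1:  # only weight matrices / conv kernels are pruned
                masks[name] = (W.abs() > threshold).float()
    return masks


# Hypothetical usage: config.masks = build_retrain_masks(config.model)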