import logging
import os
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
import torchvision

try:
    from apex import amp  # NVIDIA apex AMP; only needed when use_amp is True
except ImportError:
    amp = None


def loss_mixup_without_prefetcher(model, loss_fn, input, target, epoch, args):
    """Compute the training loss with (optional) host-side mixup applied to the batch."""
    if args.mixup > 0.:
        lam = 1.
        if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
            lam = np.random.beta(args.mixup, args.mixup)
        # Mix each sample with the batch-reversed batch: x <- lam * x + (1 - lam) * flip(x)
        input.mul_(lam).add_(input.flip(0), alpha=1 - lam)
        target = mixup_target(target, args.num_classes, lam, args.smoothing)
    output = model(input)
    loss = loss_fn(output, target)
    return loss
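# A minimal sketch of the `mixup_target` helper assumed above (modeled on the
# timm version): build label-smoothed one-hot targets and mix them with the
# batch-reversed targets using the same lam as the inputs. Treat this as an
# assumption about the helper, not the exact upstream implementation.
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value

    def one_hot(x):
        x = x.long().view(-1, 1)
        return torch.full((x.size(0), num_classes), off_value,
                          device=x.device).scatter_(1, x, on_value)

    y1 = one_hot(target)
    y2 = one_hot(target.flip(0))
    return y1 * lam + y2 * (1. - lam)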
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):

    # When the prefetcher applies mixup on the GPU, turn it off once the
    # mixup-off epoch is reached.
    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            loader.mixup_enabled = False

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input = input.cuda()
            target = target.cuda()
            if args.mixup > 0.:
                lam = 1.
                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                    lam = np.random.beta(args.mixup, args.mixup)
                input.mul_(lam).add_(input.flip(0), alpha=1 - lam)
                target = mixup_target(target, args.num_classes, lam, args.smoothing)

        output = model(input)
        loss = loss_fn(output, target)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(
                model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('loss', losses_m.avg)])
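# Minimal sketches of two helpers the loop above assumes (both modeled on the
# timm utilities); treat them as assumptions, not the exact upstream code.
class AverageMeter:
    """Track the most recent value and the running average."""

    def __init__(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def reduce_tensor(tensor, n):
    """Average a tensor across all distributed workers."""
    import torch.distributed as dist
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt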
# Variant of train_epoch with CutMix and knowledge-distillation (KD) support.
# If both definitions live in the same module, this one shadows the one above.
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False,
        model_ema=None, teacher_model=None):

    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            loader.mixup_enabled = False

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()
    if args.KD_train:
        teacher_model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input = input.cuda()
            target = target.cuda()
            if args.mixup > 0.:
                lam = 1.
                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                    lam = np.random.beta(args.mixup, args.mixup)
                input.mul_(lam).add_(input.flip(0), alpha=1 - lam)
                target = mixup_target(target, args.num_classes, lam, args.smoothing)

        r = np.random.rand()
        if args.beta > 0 and r < args.cutmix_prob:
            # CutMix: paste a random patch from a shuffled copy of the batch.
            lam = np.random.beta(args.beta, args.beta)
            rand_index = torch.randperm(input.size(0)).cuda()
            target_a = target
            target_b = target[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
            input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
            # Adjust lambda to exactly match the pixel ratio of the pasted patch.
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                       (input.size(-1) * input.size(-2)))
            # torch.autograd.Variable is a deprecated no-op in modern PyTorch,
            # so the tensors are used directly here.
            output = model(input)
            loss = loss_fn(output, target_a) * lam + loss_fn(output, target_b) * (1. - lam)
        else:
            # NOTE: KD training is currently exclusive with CutMix; fix later.
            output = model(input)
            if args.KD_train:
                # Run the teacher in args.teacher_step chunks to bound peak memory.
                teacher_outputs_tmp = []
                assert input.shape[0] % args.teacher_step == 0
                step_size = input.shape[0] // args.teacher_step
                with torch.no_grad():
                    for k in range(0, input.shape[0], step_size):
                        input_tmp = input[k:k + step_size]
                        teacher_outputs_tmp.append(teacher_model(input_tmp))
                teacher_outputs = torch.cat(teacher_outputs_tmp)
                alpha = args.KD_alpha
                T = args.KD_temperature
                # Hinton-style distillation: loss_fn is expected to take
                # log-probabilities vs. probabilities here (e.g. nn.KLDivLoss).
                loss = loss_fn(F.log_softmax(output / T, dim=1),
                               F.softmax(teacher_outputs / T, dim=1)) * (alpha * T * T) + \
                    F.cross_entropy(output, target) * (1. - alpha)
            else:
                loss = loss_fn(output, target)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            save_epoch = epoch + 1 if last_batch else epoch
            saver.save_recovery(
                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()

    return OrderedDict([('loss', losses_m.avg)])
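# A minimal sketch of the `rand_bbox` helper assumed above, following the
# reference CutMix implementation: sample a patch whose area fraction is
# roughly (1 - lam), centered at a uniformly random location and clipped to
# the image bounds. Treat it as an assumption about the helper, not the
# exact upstream code.
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # Uniformly sampled patch center.
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2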