def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):

    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            loader.mixup_enabled = False

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
            if args.mixup > 0.:
                input, target = mixup_batch(
                    input, target,
                    alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
                    disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)

        output = model(input)
        loss = loss_fn(output, target)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(
                model, optimizer, args, epoch,
                model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
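# The helpers used above (AverageMeter, reduce_tensor) are not defined in this
# section. Below is a minimal sketch of what they could look like, modeled on
# timm-style training utilities; treat names and signatures as assumptions
# rather than the definitive implementations used by this repository.
import torch.distributed as dist


class AverageMeter:
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def reduce_tensor(tensor, n):
    # Average a tensor (e.g. the loss) across all n distributed workers.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt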
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, teacher=None):

    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
            loader.mixup_enabled = False

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()
    if args.kd_ratio > 0:
        teacher.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
            if args.mixup > 0.:
                input, target = mixup_batch(
                    input, target,
                    alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
                    disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)

        if args.kd_ratio > 0:
            with torch.no_grad():
                soft_logits = teacher(input).detach()
                soft_label = F.softmax(soft_logits, dim=1)

        output = model(input)

        if args.kd_ratio > 0:
            # using the supernet at full scale to supervise the training of the subnets
            # this is taken from https://github.com/mit-han-lab/once-for-all/blob/
            # 4decacf9d85dbc948a902c6dc34ea032d711e0a9/ofa/imagenet_codebase/run_manager/run_manager.py#L426
            kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
            loss = args.kd_ratio * kd_loss + loss_fn(output, target)
            loss = loss * (2 / (args.kd_ratio + 1))
        else:
            loss = loss_fn(output, target)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                logging.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(
                model, optimizer, args, epoch,
                model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
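# `cross_entropy_loss_with_soft_target` is referenced above but not defined in
# this section. A minimal sketch is given below, following the once-for-all
# implementation linked in the comment inside the KD branch: a soft-target
# cross-entropy between the student logits and the teacher's softmax output.
# The explicit `dim` arguments are an assumption added for clarity.
import torch
import torch.nn.functional as F


def cross_entropy_loss_with_soft_target(pred, soft_target):
    # -sum_k q_k * log softmax(pred)_k per sample, averaged over the batch.
    return torch.mean(torch.sum(-soft_target * F.log_softmax(pred, dim=1), dim=1))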