def train_epoch(loader, model, loss_fun, optimizer, scaler, meter, cur_epoch):
    """Run one training epoch on the NPU device.

    Reshuffles ``loader`` for ``cur_epoch`` and applies the epoch learning rate.
    For every mini-batch it builds smoothed one-hot targets, applies mixup and
    runs the forward pass. It then backpropagates through the
    ``amp.scale_loss`` context and steps the optimizer, synchronising the
    current NPU stream between phases. Finally it records loss and
    top-1/top-5 error in ``meter``.

    Args:
        loader: Training loader yielding ``(inputs, labels)`` pairs.
        model: Network being trained (switched to train mode here).
        loss_fun: Callable ``loss_fun(preds, targets)`` applied to the
            smoothed / mixed one-hot targets.
        optimizer: Optimizer whose learning rate is set for this epoch.
        scaler: Not referenced by this body; loss scaling goes through
            ``amp.scale_loss`` instead. Kept for signature compatibility.
        meter: Training meter that accumulates, times and logs statistics.
        cur_epoch: Index of the epoch being run.
    """

    def _sync_current_stream():
        # Block until all work queued on the current NPU stream has finished.
        torch.npu.current_stream().synchronize()

    # Reshuffle for this epoch and apply the epoch's learning rate.
    data_loader.shuffle(loader, cur_epoch)
    lr = optim.get_epoch_lr(cur_epoch)
    optim.set_lr(optimizer, lr)
    model.train()
    meter.reset()
    meter.iter_tic()
    for cur_iter, (images, targets) in enumerate(loader):
        # Move the batch to the NPU; labels are cast to int32 before the move.
        images = images.npu()
        # The trailing .to(non_blocking=False) has no dtype/device argument and
        # appears to return the tensor unchanged.
        targets = targets.to(torch.int32).npu().to(non_blocking=False)
        # Keep the pre-mixup class indices: the error metrics below use them
        # rather than the indices returned by net.mixup.
        class_ids = targets[:]
        soft_targets = net.smooth_one_hot_labels(targets).npu()
        images, soft_targets, _ = net.mixup(images, soft_targets)
        preds = model(images)
        loss = loss_fun(preds, soft_targets)
        # Explicit NPU stream synchronisation around each optimisation phase.
        _sync_current_stream()
        optimizer.zero_grad()
        _sync_current_stream()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        _sync_current_stream()
        optimizer.step()
        _sync_current_stream()
        top1_err, top5_err = meters.topk_errors(preds, class_ids, [1, 5])
        # NOTE: cross-device reduction (dist.scaled_all_reduce) is disabled
        # here, so the logged stats come from this device only.
        # .item() copies each scalar from the device to the host (sync point).
        loss_val = loss.item()
        top1_val = top1_err.item()
        top5_val = top5_err.item()
        meter.iter_toc()
        # Mini-batch size is reported scaled by the configured device count.
        mb_size = images.size(0) * cfg.NUM_GPUS
        meter.update_stats(top1_val, top5_val, loss_val, lr, mb_size)
        meter.log_iter_stats(cur_epoch, cur_iter)
        meter.iter_tic()
    meter.log_epoch_stats(cur_epoch)
def train_epoch(loader, model, ema, loss_fun, optimizer, scaler, meter, cur_epoch):
    """Run one training epoch with optional mixed precision and EMA updates.

    Reshuffles ``loader`` for ``cur_epoch``, applies the epoch learning rate
    and puts both ``model`` and ``ema`` into train mode. For each mini-batch it:

    - builds smoothed one-hot targets and applies mixup;
    - computes predictions and loss under ``amp.autocast``, enabled by
      ``cfg.TRAIN.MIXED_PRECISION``;
    - backpropagates and steps the optimizer through ``scaler``;
    - refreshes the EMA weights;
    - logs loss and top-1/top-5 error, reduced across devices with
      ``dist.scaled_all_reduce``.

    Args:
        loader: Training loader yielding ``(inputs, labels)`` pairs.
        model: Network being trained.
        ema: Exponential-moving-average copy of ``model``.
        loss_fun: Callable ``loss_fun(preds, targets)`` applied to the
            smoothed / mixed one-hot targets.
        optimizer: Optimizer whose learning rate is set for this epoch.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.
        meter: Training meter that accumulates, times and logs statistics.
        cur_epoch: Index of the epoch being run.
    """
    # Per-epoch setup: reshuffle, set the LR, enable training mode.
    data_loader.shuffle(loader, cur_epoch)
    lr = optim.get_epoch_lr(cur_epoch)
    optim.set_lr(optimizer, lr)
    model.train()
    ema.train()
    meter.reset()
    meter.iter_tic()
    for cur_iter, (images, targets) in enumerate(loader):
        images = images.cuda()
        targets = targets.cuda(non_blocking=True)
        soft_targets = net.smooth_one_hot_labels(targets)
        # net.mixup also returns the class indices used for the error metrics.
        images, soft_targets, targets = net.mixup(images, soft_targets)
        with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
            preds = model(images)
            loss = loss_fun(preds, soft_targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        net.update_model_ema(model, ema, cur_epoch, cur_iter)
        top1_err, top5_err = meters.topk_errors(preds, targets, [1, 5])
        reduced = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # .item() copies each scalar from the device to the host (sync point).
        loss_val, top1_val, top5_val = [stat.item() for stat in reduced]
        meter.iter_toc()
        mb_size = images.size(0) * cfg.NUM_GPUS
        meter.update_stats(top1_val, top5_val, loss_val, lr, mb_size)
        meter.log_iter_stats(cur_epoch, cur_iter)
        meter.iter_tic()
    meter.log_epoch_stats(cur_epoch)
def compute_time_train(model, loss_fun):
    """Measure average forward and backward pass time of ``model`` on dummy data.

    A synthetic batch is pushed through the model in train mode
    ``cfg.PREC_TIME.WARMUP_ITER + cfg.PREC_TIME.NUM_ITER`` times. The batch
    consists of random images of size ``cfg.TRAIN.IM_SIZE`` and all-zero
    labels converted with ``net.smooth_one_hot_labels``. Both timers are reset
    when warm-up ends, so only the final ``NUM_ITER`` iterations are measured.

    BatchNorm2d running statistics are snapshotted beforehand and restored
    afterwards, so the dummy data does not leave them modified.

    NOTE: no gradients are zeroed here. The gradients produced by the timed
    backward passes accumulate in the parameters' ``.grad`` and are left there.

    Args:
        model: Network to time.
        loss_fun: Callable ``loss_fun(preds, targets)``.

    Returns:
        tuple: ``(forward_time, backward_time)`` as reported by each
        ``Timer.average_time``.
    """
    model.train()

    # Synthetic per-device mini-batch.
    im_size = cfg.TRAIN.IM_SIZE
    batch_size = int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    dummy_inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
    dummy_labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    dummy_targets = net.smooth_one_hot_labels(dummy_labels)

    # Snapshot BatchNorm2d running stats so they can be put back afterwards.
    batch_norms = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    saved_stats = [(bn.running_mean.clone(), bn.running_var.clone()) for bn in batch_norms]

    scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
    fw_timer = Timer()
    bw_timer = Timer()
    warmup_iters = cfg.PREC_TIME.WARMUP_ITER
    total_iter = cfg.PREC_TIME.NUM_ITER + warmup_iters
    for step in range(total_iter):
        if step == warmup_iters:
            # Discard everything measured during warm-up.
            fw_timer.reset()
            bw_timer.reset()
        # Forward pass; synchronize so the timer covers the device work.
        fw_timer.tic()
        with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
            loss = loss_fun(model(dummy_inputs), dummy_targets)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward pass, timed the same way.
        bw_timer.tic()
        scaler.scale(loss).backward()
        torch.cuda.synchronize()
        bw_timer.toc()

    for bn, (mean, var) in zip(batch_norms, saved_stats):
        bn.running_mean = mean
        bn.running_var = var
    return fw_timer.average_time, bw_timer.average_time
def train_epoch(loader, teacher, student, loss_fun, optimizer, scaler, cur_epoch, kd, lr, total_epochs, batch):
    """Train ``student`` for one epoch, optionally distilling from ``teacher``.

    The learning rate for ``cur_epoch`` is computed by ``update_lr`` and
    printed. Each iteration:

    - runs the student forward pass;
    - builds targets: teacher outputs when ``kd`` is true, otherwise the
      ground-truth labels passed through ``smooth_one_hot_labels``;
    - performs a scaled backward pass and optimizer step via ``scaler``.

    The epoch ends after ``steps = int(STEPS / (BATCH_SIZE / 10))`` iterations,
    computed from module-level constants, or when ``loader`` is exhausted,
    whichever comes first. The wall-clock epoch time is printed in both cases.

    Args:
        loader: Data loader yielding ``(inputs, labels)`` batches.
        teacher: Teacher network; used only when ``kd`` is true, in eval mode
            and without gradient tracking.
        student: Network being trained.
        loss_fun: Callable ``loss_fun(preds, targets)``.
        optimizer: Optimizer passed to ``update_lr`` and stepped via ``scaler``.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.
        cur_epoch: Epoch index, used for shuffling, the LR schedule and the
            global step number in log messages.
        kd: If true, use teacher outputs as targets instead of labels.
        lr: Learning rate handed to ``update_lr``; the value it returns is the
            one printed.
        total_epochs: Total number of epochs, passed to ``update_lr``.
        batch: Not used by this body (see the note on ``steps`` below).
    """
    # Function-scope import so the no_grad fix does not depend on file imports.
    import torch

    # Shuffle the data and update the learning rate.
    data_loader.shuffle(loader, cur_epoch)
    lr = update_lr(optimizer, cur_epoch, lr, total_epochs)
    print("Current learning rate: {}".format(lr))

    # Set the models' modes.
    student.train()
    if kd:
        teacher.eval()

    # NOTE(review): ``batch`` is unused; steps come from the module-level
    # BATCH_SIZE. Confirm the two always agree.
    steps = int(STEPS / (BATCH_SIZE / 10))
    start_time = time.time()
    for cur_iter, (inputs, labels) in enumerate(loader):
        inputs = inputs.cuda()
        preds = student(inputs)
        if kd:
            # The teacher is frozen. Without no_grad, autograd would record its
            # forward pass, and backward() would also compute (and accumulate)
            # gradients for the teacher's parameters, wasting memory and time.
            with torch.no_grad():
                labels = teacher(inputs)
        else:
            labels = labels.cuda(non_blocking=True)
            labels = smooth_one_hot_labels(labels)

        # Compute the loss and update the student.
        loss = loss_fun(preds, labels)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if (cur_iter + 1) % 100 == 0:
            # .item() forces a device->host sync, so only call it when logging.
            print("Loss at step {} is {:.2f}".format(cur_iter + steps * cur_epoch + 1, loss.item()))
        if cur_iter + 1 >= steps:
            break
    # Reported whether the step budget was reached or the loader ran out.
    print("Epoch time: {}s".format(round(time.time() - start_time)))