import time

import torch

# `cfg`, `logger`, and `utils` are project-level modules assumed to be
# imported elsewhere in this file.


def train_epoch(train_loader, net, criterion, optimizer, cur_epoch, start_epoch, tic):
    """Train one epoch."""
    rank = torch.distributed.get_rank()
    batch_time, data_time, losses, top1, topk = utils.construct_meters()
    progress = utils.ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, topk],
        prefix=f"TRAIN: [{cur_epoch+1}]",
    )
    # Set the learning rate for this epoch according to the configured policy
    lr = utils.get_epoch_lr(cur_epoch)
    utils.set_lr(optimizer, lr)
    if rank == 0:
        logger.debug(
            f"CURRENT EPOCH: {cur_epoch+1:3d}, LR: {lr:.4f}, POLICY: {cfg.OPTIM.LR_POLICY}"
        )
    # Reseed the distributed sampler so each epoch sees a different shuffle
    train_loader.sampler.set_epoch(cur_epoch)
    net.train()
    end = time.time()
    for idx, (inputs, targets) in enumerate(train_loader):
        data_time.update(time.time() - end)
        inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_size = inputs.size(0)
        acc_1, acc_k = utils.accuracy(outputs, targets, topk=(1, cfg.TRAIN.TOPK))
        # Average loss/accuracy across ranks so rank 0 logs global values
        loss, acc_1, acc_k = utils.scaled_all_reduce([loss, acc_1, acc_k])
        losses.update(loss.item(), batch_size)
        top1.update(acc_1[0].item(), batch_size)
        topk.update(acc_k[0].item(), batch_size)
        batch_time.update(time.time() - end)
        end = time.time()
        if rank == 0 and (
            (idx + 1) % cfg.TRAIN.PRINT_FREQ == 0 or (idx + 1) == len(train_loader)
        ):
            progress.cal_eta(idx + 1, len(train_loader), tic, cur_epoch, start_epoch)
            progress.display(idx + 1)
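
# For reference, a minimal sketch of what `utils.scaled_all_reduce` (used in
# train_epoch and validate) is assumed to do: sum each tensor across all ranks,
# then scale by 1/world_size so the logged loss/accuracy are global averages
# rather than per-rank values. The helper name `_scaled_all_reduce_sketch` and
# the async_op pattern are illustrative assumptions, not the project's actual code.
def _scaled_all_reduce_sketch(tensors):
    """Average a list of tensors in place across all distributed ranks."""
    world_size = torch.distributed.get_world_size()
    # Kick off all reductions asynchronously so they can overlap.
    handles = [torch.distributed.all_reduce(t, async_op=True) for t in tensors]
    for handle in handles:
        handle.wait()
    # all_reduce sums by default; divide by world size to get the mean.
    for t in tensors:
        t.mul_(1.0 / world_size)
    return tensors
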
def validate(val_loader, net, criterion):
    """Validate the model."""
    rank = torch.distributed.get_rank()
    batch_time, data_time, losses, top1, topk = utils.construct_meters()
    progress = utils.ProgressMeter(
        len(val_loader),
        [batch_time, data_time, losses, top1, topk],
        prefix="VAL: ",
    )
    net.eval()
    with torch.no_grad():
        end = time.time()
        for idx, (inputs, targets) in enumerate(val_loader):
            data_time.update(time.time() - end)
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            acc_1, acc_k = utils.accuracy(outputs, targets, topk=(1, cfg.TRAIN.TOPK))
            # Average loss/accuracy across ranks so rank 0 logs global values
            loss, acc_1, acc_k = utils.scaled_all_reduce([loss, acc_1, acc_k])
            batch_size = inputs.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(acc_1[0].item(), batch_size)
            topk.update(acc_k[0].item(), batch_size)
            batch_time.update(time.time() - end)
            end = time.time()
            if rank == 0 and (
                (idx + 1) % cfg.TEST.PRINT_FREQ == 0 or (idx + 1) == len(val_loader)
            ):
                progress.display(idx + 1)
    return top1.avg, topk.avg
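
# For reference, a minimal sketch of the assumed `utils.accuracy` contract,
# following the common torchvision-style top-k helper: it returns one tensor of
# shape [1] per requested k (in percent), which matches the `acc_1[0].item()`
# indexing above. The name `_accuracy_sketch` is illustrative only.
def _accuracy_sketch(outputs, targets, topk=(1,)):
    """Compute top-k accuracy percentages for each k in `topk`."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = targets.size(0)
        # Indices of the maxk highest-scoring classes: [batch, maxk] -> [maxk, batch].
        _, pred = outputs.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(targets.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res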