@torch.no_grad()  # fix: original tracked gradients during validation, wasting memory
def evaluate(model, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` and return top-1 accuracy.

    Runs a full pass over the validation loader, counts correct
    predictions locally, then sums the counts across all distributed
    processes with ``reduce_value`` before computing accuracy.

    Args:
        model: the network to evaluate (switched to ``eval()`` mode here).
        data_loader: validation ``DataLoader``; ``len(data_loader.dataset)``
            is taken as the total sample count.
        device: ``torch.device`` the model/data live on.

    Returns:
        float: fraction of correctly classified validation samples.
    """
    model.eval()
    # Total number of validation samples.
    num_samples = len(data_loader.dataset)
    # Running count of correct predictions, kept on `device` as a tensor
    # so it can be all-reduced across processes.
    sum_num = torch.zeros(1).to(device)
    # Show a progress bar on the main process only.
    if is_main_process():
        data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        # argmax over the class dimension.
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()
    # Wait for all CUDA work to finish before the cross-process reduction.
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)
    # Sum (not average) correct counts across all processes.
    sum_num = reduce_value(sum_num, average=False)
    acc = sum_num.item() / num_samples
    return acc
def train_one_epoch(model, optimizer, data_loader, device, epoch, use_amp=False):
    """Train ``model`` for one epoch and return the mean training loss.

    Supports optional mixed-precision training via ``GradScaler`` /
    ``autocast`` (enabled only when ``use_amp`` is True AND the device is
    CUDA). The per-step loss is averaged across distributed processes with
    ``reduce_value`` before being folded into the running mean.

    Args:
        model: network to train (switched to ``train()`` mode here).
        optimizer: optimizer whose gradients are zeroed each step.
        data_loader: training ``DataLoader``.
        device: ``torch.device`` to run on.
        epoch: current epoch index, used only for the progress-bar label.
        use_amp: request automatic mixed precision (keyword, default False).

    Returns:
        float: running mean loss over the epoch.

    Exits the process with status 1 if a non-finite loss is observed.
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    # Running mean of the (reduced) loss, kept as a tensor on `device`.
    running_mean = torch.zeros(1).to(device)
    optimizer.zero_grad()
    # Show a progress bar on the main process only.
    if is_main_process():
        data_loader = tqdm(data_loader)
    # AMP is only meaningful on CUDA; GradScaler(enabled=False) is a no-op.
    amp_active = use_amp and "cuda" in device.type
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp_active)
    for step, batch in enumerate(data_loader):
        images, labels = batch
        with torch.cuda.amp.autocast(enabled=amp_active):
            output = model(images.to(device))
            loss = criterion(output, labels.to(device))
        # Scaled backward pass; step/update are no-ops w.r.t. scaling when
        # AMP is disabled.
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()
        optimizer.zero_grad()
        # Average the loss across all processes before logging.
        loss = reduce_value(loss, average=True)
        # Incremental mean: new_mean = (old_mean * n + x) / (n + 1).
        running_mean = (running_mean * step + loss.detach()) / (step + 1)
        # Update the progress-bar description on the main process.
        if is_main_process():
            data_loader.desc = "[epoch {}] mean loss {}".format(
                epoch, round(running_mean.item(), 3))
        # Abort the run on NaN/Inf loss.
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)
    # Wait for all CUDA work to finish before returning.
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)
    return running_mean.item()