# NOTE(review): this fragment appears to be a second, independent SRResNet
# training script that was collapsed onto one line; the inner loop body ends
# right after the loss computation, so the backward pass / optimizer step
# presumably following it is not visible here — confirm against the full file.

cudnn.benchmark = True  # let cuDNN benchmark conv algorithms (fast for fixed input shapes)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the network (first arg 16 is presumably the residual-block count — verify
# against SRResNet's signature) and move it to the training device.
net = SRResNet(16, args.scale).to(device)
optimizer = optim.Adam(net.parameters())

# Pixel-wise MSE between the super-resolved prediction and the HR target.
mse_loss = nn.MSELoss()

# Train, tracking the best weights / PSNR / epoch seen so far.
best_weights = copy.deepcopy(net.state_dict())
best_psnr = 0.0
best_epoch = 0
for epoch in range(args.num_epochs):
    net.train()
    epoch_losses = AverageMeter()  # per-epoch running loss statistics
    # Progress bar sized to the samples covered by full batches only
    # (drops the remainder of len(train_dataset) % batch_size).
    with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))
        for data in train_dataloader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = net(inputs)
            loss = mse_loss(preds, labels)
            # NOTE(review): fragment ends here — optimizer.zero_grad(),
            # loss.backward(), optimizer.step() are expected to follow.
def main():
    """Train SRResNet, checkpointing to results/checkpoint_srresnet.pth each epoch.

    Relies on module-level configuration globals (large_kernel_size,
    small_kernel_size, n_channels, n_blocks, scaling_factor, lr, device,
    ngpu, data_folder, crop_size, batch_size, workers, epochs, checkpoint,
    start_epoch) and on the project's SRResNet / SRDataset / AverageMeter
    helpers. Resumes from `checkpoint` when it is not None.
    """
    global checkpoint, start_epoch

    # Build the model and an Adam optimizer over its trainable parameters only.
    model = SRResNet(large_kernel_size=large_kernel_size,
                     small_kernel_size=small_kernel_size,
                     n_channels=n_channels,
                     n_blocks=n_blocks,
                     scaling_factor=scaling_factor)
    optimizer = torch.optim.Adam(
        params=filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

    # Move model and loss criterion to the training device.
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Resume from a pretrained checkpoint if one was given.  map_location keeps
    # the load working when the checkpoint was written on a different device
    # (e.g. a GPU-trained checkpoint resumed on CPU).
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    # Wrap in DataParallel only AFTER the (unwrapped) state dict was loaded,
    # so the key names in the checkpoint match the bare model above.
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # Data pipeline: LR inputs are imagenet-normed, HR targets live in [-1, 1].
    train_dataset = SRDataset(data_folder, split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Train epoch by epoch.
    for epoch in range(start_epoch, epochs + 1):
        print("epoch:", epoch)
        model.train()  # training mode: batch-norm uses batch statistics
        loss_epoch = AverageMeter()  # running loss statistics for this epoch

        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
            # Move the batch to the training device.
            lr_imgs = lr_imgs.to(device)  # (N, 3, crop/scale, crop/scale), imagenet-normed
            hr_imgs = hr_imgs.to(device)  # (N, 3, crop, crop), in [-1, 1]

            # Forward pass and pixel-wise MSE loss.
            sr_imgs = model(lr_imgs)
            loss = criterion(sr_imgs, hr_imgs)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track the running loss, weighted by the batch size.
            loss_epoch.update(loss.item(), lr_imgs.size(0))

        # AverageMeter already maintains the mean, so no manual sum/count tally
        # (the old loss_data/tag also raised ZeroDivisionError on an empty loader).
        print("loss:", loss_epoch.avg)

        # Release the last batch's tensors before checkpointing to cut peak memory.
        del lr_imgs, hr_imgs, sr_imgs

        # Save the UN-wrapped state dict: with DataParallel, model.state_dict()
        # would emit 'module.'-prefixed keys that the bare-model resume path
        # above could not load.
        model_to_save = model.module if isinstance(model, nn.DataParallel) else model
        torch.save({'epoch': epoch,
                    'model': model_to_save.state_dict(),
                    'optimizer': optimizer.state_dict()},
                   'results/checkpoint_srresnet.pth')