def main():
    """
    Train SRResNet, resuming from a checkpoint when one is provided.

    Reads module-level config (hyperparameters, paths, `device`, and the
    globals `start_epoch`, `epoch`, `checkpoint`); saves a checkpoint file
    after every epoch.
    """
    global start_epoch, epoch, checkpoint

    # Initialize model or load checkpoint
    if checkpoint is None:
        model = SRResNet(large_kernel_size=large_kernel_size,
                         small_kernel_size=small_kernel_size,
                         n_channels=n_channels,
                         n_blocks=n_blocks,
                         scaling_factor=scaling_factor)
        # Initialize the optimizer; filter keeps only trainable parameters
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                                     lr=lr)
    else:
        # NOTE(review): the checkpoint stores whole model/optimizer objects
        # (see torch.save below), so torch.load here unpickles them directly.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to default device
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Total number of epochs to train for (derived from a fixed iteration budget)
    epochs = int(iterations // len(train_loader) + 1)

    # Epochs
    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train(train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch)

        # Save checkpoint (whole objects, matching the load branch above)
        torch.save({'epoch': epoch,
                    'model': model,
                    'optimizer': optimizer},
                   'checkpoint_srresnet.pth.tar')
def main():
    """
    Train SRResNet with an explicit per-epoch training loop.

    Reads module-level config (hyperparameters, paths, `device`, `ngpu`,
    `epochs`, and the globals `checkpoint`, `start_epoch`); writes a
    state_dict checkpoint to 'results/checkpoint_srresnet.pth' each epoch.
    """
    global checkpoint, start_epoch

    # Initialize model
    model = SRResNet(large_kernel_size=large_kernel_size,
                     small_kernel_size=small_kernel_size,
                     n_channels=n_channels,
                     n_blocks=n_blocks,
                     scaling_factor=scaling_factor)

    # Initialize the optimizer; filter keeps only trainable parameters
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=lr)

    # Move to the default device for training
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Resume from a pretrained checkpoint when one is given
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    # Multi-GPU data parallelism when available.
    # NOTE(review): once wrapped, model.state_dict() below is saved with
    # 'module.'-prefixed keys, which the plain load_state_dict above will not
    # accept on resume — confirm whether multi-GPU checkpoints are reloaded.
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Train epoch by epoch
    for epoch in range(start_epoch, epochs + 1):
        print("epoch:", epoch)

        model.train()  # training mode: enables batch normalization updates

        loss_epoch = AverageMeter()  # running loss statistics
        n_iter = len(train_loader)
        loss_data = 0
        tag = 0

        # Process one batch at a time
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
            # Move batch to the default device
            lr_imgs = lr_imgs.to(device)  # (N, 3, 24, 24), imagenet-normed
            hr_imgs = hr_imgs.to(device)  # (N, 3, 96, 96), in [-1, 1]

            # Forward pass
            sr_imgs = model(lr_imgs)

            # Compute loss
            loss = criterion(sr_imgs, hr_imgs)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Update model parameters
            optimizer.step()

            # Record loss values
            loss_epoch.update(loss.item(), lr_imgs.size(0))
            loss_data = loss_data + loss.item()
            tag = tag + 1

        # Report mean loss over the epoch
        loss_data = loss_data / tag
        print("loss:", loss_data)

        # Manually release the last batch's tensors
        del lr_imgs, hr_imgs, sr_imgs

        # Save checkpoint (state_dicts, matching the resume branch above)
        torch.save({'epoch': epoch,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()},
                   'results/checkpoint_srresnet.pth')
pin_memory=True, drop_last=True) test_dataset = SRGanDataset(args.test_lr_file, args.test_gt_file, in_memory=True, transform=transform) test_dataloader = DataLoader(dataset=test_dataset, batch_size=1) # Use in the cuda cudnn.benchmark = True device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Defined the network net = SRResNet(16, args.scale).to(device) optimizer = optim.Adam(net.parameters()) # Define the loss func mse_loss = nn.MSELoss() # Train best_weights = copy.deepcopy(net.state_dict()) best_psnr = 0.0 best_epoch = 0 for epoch in range(args.num_epochs): net.train() epoch_losses = AverageMeter() with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t: t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))
# Hyperparameters and data paths
lr = 0.001
betas = (0.99, 0.999)  # Adam momentum coefficients
TRAIN_PATH = './compress_data/voc_train.pkl'
VALID_PATH = './compress_data/voc_valid.pkl'

## Set up datasets and loaders
train_dataset = TrainDataset(TRAIN_PATH, crop_size=crop_size, upscale_factor=upscale_factor)
valid_dataset = ValidDataset(VALID_PATH, crop_size=crop_size, upscale_factor=upscale_factor)

trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# Second view of the training set with batch_size=1, needed to calculate
# score metrics on individual images
trainloader_v2 = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2)
validloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=2)

# Model and optimizer
model = SRResNet(scale_factor=upscale_factor, kernel_size=9, n_channels=64)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, betas=betas)

## Training — delegate the loop to the trainer helper
trainer = SRResnet_trainer(model, optimizer, device)
trainer.train(trainloader, trainloader_v2, validloader, 1, 100)