Example #1
def main():
    """
    Training.
    """
    global start_epoch, epoch, checkpoint

    # Initialize model or load checkpoint
    if checkpoint is None:
        model = SRResNet(large_kernel_size=large_kernel_size,
                         small_kernel_size=small_kernel_size,
                         n_channels=n_channels,
                         n_blocks=n_blocks,
                         scaling_factor=scaling_factor)
        # Initialize the optimizer
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   model.parameters()),
                                     lr=lr)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to default device
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True)  # the default collate function is sufficient here

    # Total number of epochs to train for
    epochs = int(iterations // len(train_loader) + 1)

    # Epochs
    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train(train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch)

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model': model,
            'optimizer': optimizer
        }, 'checkpoint_srresnet.pth.tar')
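The train() helper called in the loop above is not shown in this example. A minimal sketch of what one epoch could look like, assuming the same (lr_imgs, hr_imgs) batches as in Example #2 and the module-level device used elsewhere in the snippet (the body and print interval are illustrative, not the original code):

def train(train_loader, model, criterion, optimizer, epoch):
    """One epoch of SRResNet training (hypothetical sketch)."""
    model.train()  # enable training-mode behaviour (e.g. batch norm updates)

    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        lr_imgs = lr_imgs.to(device)   # low-resolution inputs
        hr_imgs = hr_imgs.to(device)   # high-resolution targets

        sr_imgs = model(lr_imgs)            # forward pass
        loss = criterion(sr_imgs, hr_imgs)  # pixel-wise MSE

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('Epoch {}, batch {}/{}: loss {:.4f}'.format(
                epoch, i, len(train_loader), loss.item()))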
Example #2
def main():
    """
    Training.
    """
    global checkpoint, start_epoch

    # Initialize the model
    model = SRResNet(large_kernel_size=large_kernel_size,
                     small_kernel_size=small_kernel_size,
                     n_channels=n_channels,
                     n_blocks=n_blocks,
                     scaling_factor=scaling_factor)
    # Initialize the optimizer
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                               model.parameters()),
                                 lr=lr)

    # Move to the default device for training
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Load a pretrained checkpoint, if one was given
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Train epoch by epoch
    for epoch in range(start_epoch, epochs + 1):

        print("epoch:", epoch)
        model.train()  # training mode: enables batch normalization updates

        loss_epoch = AverageMeter()  # tracks the loss statistics

        n_iter = len(train_loader)
        loss_data = 0
        tag = 0
        # Process batch by batch
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):

            # Move data to the default device for training
            lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed
            hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96), in [-1, 1]

            # Forward pass
            sr_imgs = model(lr_imgs)

            # Compute the loss
            loss = criterion(sr_imgs, hr_imgs)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Update the model
            optimizer.step()

            # Record the loss
            loss_epoch.update(loss.item(), lr_imgs.size(0))
            loss_data = loss_data + loss.item()
            tag = tag + 1
 
        loss_data = loss_data / tag
        print("loss:", loss_data)
        # Manually free memory
        del lr_imgs, hr_imgs, sr_imgs

        # Save the trained model
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'results/checkpoint_srresnet.pth')
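AverageMeter is used here (and in Example #3) but defined in neither snippet. A common minimal implementation consistent with the update(value, n) calls above (an assumption modeled on the widely copied PyTorch ImageNet reference script, not code from this repository):

class AverageMeter(object):
    """Tracks the latest value, running sum, count, and average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count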
Example #3
                              pin_memory=True,
                              drop_last=True)

test_dataset = SRGanDataset(args.test_lr_file,
                            args.test_gt_file,
                            in_memory=True,
                            transform=transform)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1)

# Use CUDA if available
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the network
net = SRResNet(16, args.scale).to(device)
optimizer = optim.Adam(net.parameters())

# Define the loss function
mse_loss = nn.MSELoss()

# Train
best_weights = copy.deepcopy(net.state_dict())
best_psnr = 0.0
best_epoch = 0
for epoch in range(args.num_epochs):
    net.train()
    epoch_losses = AverageMeter()

    with tqdm(total=(len(train_dataset) -
                     len(train_dataset) % args.batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))
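The snippet tracks best_psnr, but the evaluation code is cut off. A minimal PSNR evaluation over test_dataloader might look like the sketch below; it assumes the dataset yields (lr, gt) tensor pairs scaled to [0, 1], which may differ from the actual SRGanDataset output:

def evaluate_psnr(net, dataloader, device):
    """Average PSNR over a dataloader (sketch; assumes tensors in [0, 1])."""
    net.eval()
    psnr_sum, n = 0.0, 0
    with torch.no_grad():
        for lr_imgs, gt_imgs in dataloader:
            lr_imgs, gt_imgs = lr_imgs.to(device), gt_imgs.to(device)
            sr_imgs = net(lr_imgs).clamp(0.0, 1.0)
            mse = torch.mean((sr_imgs - gt_imgs) ** 2)
            psnr_sum += 10 * torch.log10(1.0 / mse).item()  # PSNR for MAX = 1
            n += 1
    return psnr_sum / n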
Example #4
    lr = 0.001
    betas = (0.99, 0.999)
    TRAIN_PATH = './compress_data/voc_train.pkl'
    VALID_PATH = './compress_data/voc_valid.pkl'

    ## Set up
    train_dataset = TrainDataset(TRAIN_PATH,
                                 crop_size=crop_size,
                                 upscale_factor=upscale_factor)
    valid_dataset = ValidDataset(VALID_PATH,
                                 crop_size=crop_size,
                                 upscale_factor=upscale_factor)

    trainloader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=2)
    trainloader_v2 = DataLoader(train_dataset,
                                batch_size=1,
                                shuffle=True,
                                num_workers=2)  # batch_size=1 copy, needed to compute score metrics
    validloader = DataLoader(valid_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=2)

    model = SRResNet(scale_factor=upscale_factor, kernel_size=9, n_channels=64)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, betas=betas)

    ## Training
    trainer = SRResnet_trainer(model, optimizer, device)
    trainer.train(trainloader, trainloader_v2, validloader, 1, 100)
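SRResnet_trainer is imported from elsewhere and not shown here; the two trailing arguments of trainer.train(...) are read below as a start epoch of 1 and 100 total epochs, which is an assumption about the signature. A minimal skeleton consistent with that call:

class SRResnet_trainer:
    """Hypothetical minimal trainer; the real class is not shown in this example."""

    def __init__(self, model, optimizer, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device
        self.criterion = torch.nn.MSELoss().to(device)

    def train(self, trainloader, trainloader_v2, validloader,
              start_epoch, num_epochs):
        for epoch in range(start_epoch, num_epochs + 1):
            self.model.train()
            for lr_imgs, hr_imgs in trainloader:
                lr_imgs = lr_imgs.to(self.device)
                hr_imgs = hr_imgs.to(self.device)
                loss = self.criterion(self.model(lr_imgs), hr_imgs)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            # score metrics over trainloader_v2 / validloader would go here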