Example #1
def main():
    """
    Training.
    """
    global start_epoch, epoch, checkpoint

    # Initialize model or load checkpoint
    if checkpoint is None:
        model = SRResNet(large_kernel_size=large_kernel_size,
                         small_kernel_size=small_kernel_size,
                         n_channels=n_channels,
                         n_blocks=n_blocks,
                         scaling_factor=scaling_factor)
        # Initialize the optimizer
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   model.parameters()),
                                     lr=lr)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to default device
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True)

    # Total number of epochs to train for
    epochs = int(iterations // len(train_loader) + 1)

    # Epochs
    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train(train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch)

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model': model,
            'optimizer': optimizer
        }, 'checkpoint_srresnet.pth.tar')
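
Example #1 calls a train() helper that it does not define. A minimal sketch of one epoch, mirroring the loop used in Example #6 (forward pass, MSE loss, backward pass, weight update); the module-level device is assumed, as in the rest of the script:

def train(train_loader, model, criterion, optimizer, epoch):
    """One training epoch (sketch)."""
    model.train()  # enable batch-norm updates

    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        lr_imgs = lr_imgs.to(device)  # low-resolution inputs
        hr_imgs = hr_imgs.to(device)  # high-resolution targets

        sr_imgs = model(lr_imgs)            # forward pass
        loss = criterion(sr_imgs, hr_imgs)  # pixel-wise MSE

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()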
Example #2
import argparse
import os

import torch
from torch.utils.data import DataLoader

from models import SRResNet, Generator
from data_utils import TestDatasetFromFolder  # assumed local module providing the test dataset

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="SRGAN", type=str, help="Model type")
opt = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if opt.model == "SRGAN":
    model = Generator(in_channels=3, n_residual_blocks=7, up_scale=4)
    model.load_state_dict(
        torch.load('./experiments/srgan.pt',
                   map_location=device)['G_state_dict'])
else:
    model = SRResNet(scale_factor=4, kernel_size=9, n_channels=64)
    model.load_state_dict(
        torch.load('./experiments/srresnet.pt',
                   map_location=device)['model_state_dict'])


def test_for_dataset(datasetdir):
    outputdir = os.path.join(datasetdir, "tinhtoanSRResnet")
    os.makedirs(outputdir, exist_ok=True)
    testdataset = TestDatasetFromFolder(datasetdir)
    size = len(testdataset)
    testloader = DataLoader(testdataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=2)
    model.eval()
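    # (sketch) the function is truncated here; a plausible continuation,
    # assuming TestDatasetFromFolder yields (filename, lr_image) pairs and
    # using torchvision's save_image (both assumptions):
    from torchvision.utils import save_image

    with torch.no_grad():
        for i, (name, lr_img) in enumerate(testloader):
            sr_img = model(lr_img.to(device)).clamp(0.0, 1.0)
            save_image(sr_img, os.path.join(outputdir, os.path.basename(name[0])))
            print("{}/{} done".format(i + 1, size))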
Example #3
import glob
import os
import subprocess

from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import Input
from keras.models import Model

# `config`, `SRResNet`, `image_generator` and `perceptual_distance` are
# defined elsewhere in this script (`config` typically comes from wandb).

if not os.path.exists("data"):
    print("Downloading flower dataset...")
    subprocess.check_output(
        "mkdir data && curl https://storage.googleapis.com/wandb/flower-enhance.tar.gz | tar xz -C data",
        shell=True)

config.steps_per_epoch = len(
    glob.glob(config.train_dir + "/*-in.jpg")) // config.batch_size
config.val_steps_per_epoch = len(
    glob.glob(config.val_dir + "/*-in.jpg")) // config.batch_size

# Neural network
input1 = Input(shape=(config.input_height, config.input_width, 3),
               dtype='float32')
model = Model(inputs=input1,
              outputs=SRResNet(input1, config.filters, config.nBlocks))

model.summary()  # summary() prints the table itself

es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=30)
mc = ModelCheckpoint('srresnet.h5',
                     monitor='val_perceptual_distance',
                     mode='min',
                     save_best_only=True)

## DO NOT ALTER: metrics=[perceptual_distance]
model.compile(optimizer='adam', loss='mse', metrics=[perceptual_distance])

model.fit_generator(image_generator(config.batch_size, config.train_dir,
                                    config),
                    steps_per_epoch=config.steps_per_epoch,
                    validation_data=image_generator(config.batch_size,
                                                    config.val_dir, config),
                    validation_steps=config.val_steps_per_epoch,
                    callbacks=[es, mc])
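
image_generator is used above but not defined in the snippet. A hypothetical implementation, assuming the flower dataset pairs each '*-in.jpg' input with a same-named '*-out.jpg' target (the '-out.jpg' naming is an assumption):

import random

import numpy as np
from PIL import Image

def image_generator(batch_size, img_dir, config):
    """Yield (input, target) batches from paired '*-in.jpg' / '*-out.jpg' files."""
    input_filenames = glob.glob(img_dir + "/*-in.jpg")
    while True:
        random.shuffle(input_filenames)
        for i in range(0, len(input_filenames) - batch_size + 1, batch_size):
            batch = input_filenames[i:i + batch_size]
            small = np.array([np.array(Image.open(f)) for f in batch],
                             dtype=np.float32) / 255.0
            large = np.array([np.array(Image.open(f.replace("-in.jpg", "-out.jpg")))
                              for f in batch], dtype=np.float32) / 255.0
            yield small, large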
Example #4
# (reconstructed) the snippet starts mid-call; the training loader below
# mirrors the test loader, with assumed train-file argument names
train_dataset = SRGanDataset(args.train_lr_file,
                             args.train_gt_file,
                             in_memory=True,
                             transform=transform)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              drop_last=True)

test_dataset = SRGanDataset(args.test_lr_file,
                            args.test_gt_file,
                            in_memory=True,
                            transform=transform)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1)

# Use CUDA when available
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the network
net = SRResNet(16, args.scale).to(device)
optimizer = optim.Adam(net.parameters())

# Define the loss function
mse_loss = nn.MSELoss()

# Train
best_weights = copy.deepcopy(net.state_dict())
best_psnr = 0.0
best_epoch = 0
for epoch in range(args.num_epochs):
    net.train()
    epoch_losses = AverageMeter()

    with tqdm(total=(len(train_dataset) -
                     len(train_dataset) % args.batch_size)) as t:
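        # (sketch) the loop body is truncated in this snippet; a standard
        # pattern for this trainer, assuming the dataset yields (lr, gt)
        # pairs as in Example #5:
        t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))
        for lr_imgs, gt_imgs in train_dataloader:
            lr_imgs = lr_imgs.to(device)
            gt_imgs = gt_imgs.to(device)

            sr_imgs = net(lr_imgs)             # forward pass
            loss = mse_loss(sr_imgs, gt_imgs)  # pixel-wise MSE

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.update(loss.item(), len(lr_imgs))
            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(lr_imgs))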
Example #5
parser.add_argument('--scale', type=int, default=4)  # int default (assumed x4, as in the other examples)
args = parser.parse_args()

# Use CUDA when available
cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the dataset
dataset = SRGanDataset(gt_path=args.gt_file,
                       lr_path=args.lr_file,
                       in_memory=False,
                       transform=None)
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8)

# Define the net
model = SRResNet(16, args.scale)
model.load_state_dict(torch.load(args.weights_file))
model = model.to(device)
model.eval()

with torch.no_grad():
    for index, (lr, gt) in enumerate(loader, start=1):

        lr = lr.to(device)
        gt = gt.to(device)

        _, _, height, width = lr.size()
        gt = gt[:, :, :height * args.scale, :width * args.scale]
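        # (sketch) the snippet stops before the forward pass; a plausible
        # continuation, assuming images in [0, 1] and torchvision for saving:
        from torchvision.utils import save_image

        sr = model(lr)                      # super-resolved output
        mse = torch.mean((sr - gt) ** 2)
        psnr = 10 * torch.log10(1.0 / mse)  # PSNR against the cropped GT
        print('image {}: PSNR {:.2f} dB'.format(index, psnr.item()))
        save_image(sr.clamp(0.0, 1.0), 'sr_{:04d}.png'.format(index))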
Example #6
def main():
    """
    训练.
    """
    global checkpoint,start_epoch

    # 初始化
    model = SRResNet(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    # 初始化优化器
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),lr=lr)

            
    # 迁移至默认设备进行训练
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Load a pretrained checkpoint, if given
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    # Train epoch by epoch
    for epoch in range(start_epoch, epochs + 1):

        print("epoch:", epoch)
        model.train()  # training mode: enables batch-norm updates

        loss_epoch = AverageMeter()  # tracks the running loss

        # Batches
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):

            # Move data to the default device
            lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed
            hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96), in [-1, 1]

            # Forward pass
            sr_imgs = model(lr_imgs)

            # Compute the loss
            loss = criterion(sr_imgs, hr_imgs)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Update the weights
            optimizer.step()

            # Track the loss
            loss_epoch.update(loss.item(), lr_imgs.size(0))

        print("loss:", loss_epoch.avg)
        # Free memory manually
        del lr_imgs, hr_imgs, sr_imgs

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'results/checkpoint_srresnet.pth')
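
Unlike Example #1, which pickles the whole model and optimizer objects, this example stores only their state_dicts, so resuming requires rebuilding the objects before loading, as the checkpoint branch above does. A minimal resume sketch under the same assumptions:

checkpoint = torch.load('results/checkpoint_srresnet.pth', map_location=device)
start_epoch = checkpoint['epoch'] + 1
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])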
Example #7
    lr = 0.001
    betas = (0.99, 0.999)
    TRAIN_PATH = './compress_data/voc_train.pkl'
    VALID_PATH = './compress_data/voc_valid.pkl'

    ## Set up
    train_dataset = TrainDataset(TRAIN_PATH,
                                 crop_size=crop_size,
                                 upscale_factor=upscale_factor)
    valid_dataset = ValidDataset(VALID_PATH,
                                 crop_size=crop_size,
                                 upscale_factor=upscale_factor)

    trainloader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=2)
    trainloader_v2 = DataLoader(
        train_dataset, batch_size=1, shuffle=True,
        num_workers=2)  # batch_size=1 copy of the train set for score metrics
    validloader = DataLoader(valid_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=2)

    model = SRResNet(scale_factor=upscale_factor, kernel_size=9, n_channels=64)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, betas=betas)

    ## Training
    trainer = SRResnet_trainer(model, optimizer, device)
    trainer.train(trainloader, trainloader_v2, validloader, 1, 100)
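
SRResnet_trainer is not shown in this snippet. A hypothetical minimal version, assuming the positional arguments in the call above are (trainloader, trainloader_v2, validloader, start_epoch, num_epochs); a real trainer would also compute score metrics on trainloader_v2 and validloader:

import torch.nn as nn

class SRResnet_trainer:
    """Hypothetical minimal trainer matching the call above."""

    def __init__(self, model, optimizer, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device
        self.criterion = nn.MSELoss().to(device)

    def train(self, trainloader, trainloader_v2, validloader,
              start_epoch, num_epochs):
        for epoch in range(start_epoch, num_epochs + 1):
            self.model.train()
            for lr_imgs, hr_imgs in trainloader:
                lr_imgs = lr_imgs.to(self.device)
                hr_imgs = hr_imgs.to(self.device)

                loss = self.criterion(self.model(lr_imgs), hr_imgs)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            print("epoch {}: last batch loss {:.6f}".format(epoch, loss.item()))
            # evaluation on trainloader_v2 / validloader would go here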