Exemplo n.º 1
0
# Test-set loader: shuffle=False keeps batches in dataset order.
testing_data_loader = DataLoader(
    dataset=test_set,
    num_workers=opt.threads,
    batch_size=opt.testBatchSize,
    shuffle=False,
)

print('===> Building model')
# Choose the network class and its ``num_stages`` from ``opt.model_type``;
# any unrecognised type falls back to plain DBPN with 7 stages.
if opt.model_type == 'DBPNLL':
    _net_class, _num_stages = DBPNLL, 10
elif opt.model_type == 'DBPN-RES-MR64-3':
    _net_class, _num_stages = DBPNITER, 3
else:
    _net_class, _num_stages = DBPN, 7
# All variants share the same channel / filter / feature configuration.
model = _net_class(num_channels=3,
                   base_filter=64,
                   feat=256,
                   num_stages=_num_stages,
                   scale_factor=opt.upscale_factor)

if cuda:
    # Wrapped before loading, so the checkpoint keys must match the wrapped
    # model's naming (DataParallel prefixes parameters with ``module.``).
    # NOTE(review): presumably the checkpoint was saved from a wrapped model;
    # if so, the non-CUDA path may fail on key mismatch — confirm.
    model = torch.nn.DataParallel(model, device_ids=gpus_list)

# The map_location callable returns each storage unchanged, i.e. tensors are
# deserialised onto the CPU regardless of the device they were saved from.
model.load_state_dict(
    torch.load(opt.model, map_location=lambda storage, loc: storage)
)
print('Pre-trained SR model is loaded.')
Exemplo n.º 2
0
import numpy as np
import torch
import utils
import os
import Config
import correction_func

#############################################

# Import the SR network here (e.g. DBPN):
from dbpn_iterative import Net as DBPNITER

# Define the desired SR model here (e.g. DBPN): iterative DBPN, 3 stages,
# upscaling by ``conf.scale``.
# NOTE(review): ``conf`` is not defined in the visible code; presumably it is a
# configuration object created from the ``Config`` module — confirm.
SR_model = DBPNITER(
    num_channels=3,
    base_filter=64,
    feat=256,
    num_stages=3,
    scale_factor=conf.scale,
)

# Load the pretrained weights through a DataParallel wrapper, then unwrap it.
# Presumably the checkpoint keys carry DataParallel's ``module.`` prefix, which
# is why the wrapper is needed only for loading — confirm.
SR_model = torch.nn.DataParallel(
    SR_model,
    device_ids=[conf.gpu],
    output_device=conf.device,
)
state_dict = torch.load('./models/DBPN-RES-MR64-3_%dx.pth' % conf.scale,
                        map_location=conf.gpu)
SR_model.load_state_dict(state_dict)
# Keep only the bare network and switch it to evaluation mode.
SR_model = SR_model.module.eval()


def R_dag(I):
    """Return ``SR_model(I)`` plus a bicubic upsampling of ``I`` by ``conf.scale``.

    Presumably the network predicts a residual that is added on top of the
    bicubic estimate — confirm against the model definition.
    """
    sr_out = SR_model(I)
    bicubic = torch.nn.functional.interpolate(
        I, scale_factor=conf.scale, mode='bicubic')
    return sr_out + bicubic


# Input / output directories for processed images.
out_dir = './'  # Set the output directory
in_dir = './'  # Set the input directory