def __init__(self, model_path=None, hw='cpu'):
    """Load a 4x ``RRDB_Net`` (23 blocks, 64 features) from a state-dict checkpoint.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``. Required,
            despite the ``None`` default.
        hw: Target device string, e.g. ``'cpu'`` or ``'cuda'``; any string
            accepted by ``torch.device`` (such as ``'cuda:1'``) also works.

    Raises:
        ValueError: If ``model_path`` is empty or ``None``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if not model_path:
        raise ValueError('model_path is required')
    # torch.device() validates the string itself. Previously any value other than
    # 'cpu'/'cuda' left self.device unset and failed later with AttributeError.
    self.device = torch.device(hw)
    self.model = architecture.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None,
                                       act_type='leakyrelu', mode='CNA', res_scale=1,
                                       upsample_mode='upconv')
    # Deserialize onto the CPU so checkpoints saved from a GPU still load on
    # CPU-only hosts; the model is moved to self.device below.
    state_dict = torch.load(model_path, map_location='cpu')
    self.model.load_state_dict(state_dict, strict=True)
    self.model.eval()
    # Inference only: freeze every weight.
    for param in self.model.parameters():
        param.requires_grad = False
    self.model = self.model.to(self.device)
    # NOTE(review): no warm-up inference is actually run here; the message is
    # kept unchanged so existing log consumers are unaffected.
    print('Model warmup complete')
def setup(opts):
    """Build an ``RRDB_Net`` whose weights blend the PSNR and ESRGAN x4 checkpoints.

    Each parameter is ``(1 - f) * psnr + f * esrgan``, where ``f`` is
    ``opts['interpolationFactor']`` (default 0.5). The checkpoints are read
    from ``./models`` relative to the current working directory.

    Args:
        opts: Mapping of options; only ``'interpolationFactor'`` is read.

    Returns:
        The blended model, in eval mode with gradients disabled, moved to the
        module-level ``device`` (defined outside this function).
    """
    net_PSNR_path = './models/RRDB_PSNR_x4.pth'
    net_ESRGAN_path = './models/RRDB_ESRGAN_x4.pth'
    # Map onto the CPU so GPU-saved checkpoints also load on CPU-only hosts;
    # the finished model is moved to `device` at the end.
    net_PSNR = torch.load(net_PSNR_path, map_location='cpu')
    net_ESRGAN = torch.load(net_ESRGAN_path, map_location='cpu')
    interpolation_factor = opts.get('interpolationFactor', 0.5)
    # Linear blend of every tensor; a KeyError here means the two checkpoints
    # do not share the same parameter names.
    net_interp = OrderedDict(
        (k, (1 - interpolation_factor) * v_PSNR + interpolation_factor * net_ESRGAN[k])
        for k, v_PSNR in net_PSNR.items()
    )
    model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
                          mode='CNA', res_scale=1, upsample_mode='upconv')
    model.load_state_dict(net_interp)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    return model
def setup(model_dir, alpha=0.5):
    """Build an ``RRDB_Net`` whose weights blend the PSNR and ESRGAN x4 checkpoints.

    Each parameter is ``(1 - alpha) * psnr + alpha * esrgan``. The resulting
    model is also stored in the module-level ``model`` global.

    Args:
        model_dir: Directory containing ``RRDB_PSNR_x4.pth`` and
            ``RRDB_ESRGAN_x4.pth``.
        alpha: Weight of the ESRGAN checkpoint in the blend. Defaults to 0.5,
            the value that was previously hard-coded.

    Returns:
        The blended model, in eval mode with gradients disabled, moved to the
        module-level ``device`` (defined outside this function).
    """
    global model
    net_PSNR_path = '%s/RRDB_PSNR_x4.pth' % model_dir
    net_ESRGAN_path = '%s/RRDB_ESRGAN_x4.pth' % model_dir
    # Map onto the CPU so GPU-saved checkpoints also load on CPU-only hosts;
    # the finished model is moved to `device` at the end.
    net_PSNR = torch.load(net_PSNR_path, map_location='cpu')
    net_ESRGAN = torch.load(net_ESRGAN_path, map_location='cpu')
    # Linear blend of every tensor; a KeyError here means the two checkpoints
    # do not share the same parameter names.
    net_interp = OrderedDict(
        (k, (1 - alpha) * v_PSNR + alpha * net_ESRGAN[k])
        for k, v_PSNR in net_PSNR.items()
    )
    model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
                          mode='CNA', res_scale=1, upsample_mode='upconv')
    model.load_state_dict(net_interp)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    return model
parts = part.split(".") n_parts = len(parts) if n_parts == 5 and parts[2] == 'sub': nb = int(parts[3]) elif n_parts == 3: part_num = int(parts[1]) if part_num > 6 and parts[2] == 'weight': scale2 += 1 if part_num > max_part: max_part = part_num out_nc = state_dict[part].shape[0] upscaleSize = 2**scale2 in_nc = state_dict['model.0.weight'].shape[1] nf = state_dict['model.0.weight'].shape[0] model = arch.RRDB_Net(in_nc, out_nc, nf, nb, gc=32, upscale=upscaleSize, norm_type=None, act_type='leakyrelu', \ mode='CNA', res_scale=1, upsample_mode='upconv') model.load_state_dict(state_dict, strict=True) del state_dict model.eval() for k, v in model.named_parameters(): v.requires_grad = False model = model.to(device) print('Model path {:s}. \nProcessing...'.format(model_path)) sys.stdout.flush() idx = 0 test_img_folder = test_img_folder.replace('*', '') for path, subdirs, files in os.walk(test_img_folder): for name in files: idx += 1
import sys import os.path import glob import cv2 import numpy as np import torch import architecture as arch model_path = sys.argv[ 1] # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth # device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu device = torch.device('cpu') test_img_folder = 'LR/*' model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \ mode='CNA', res_scale=1, upsample_mode='upconv') model.load_state_dict(torch.load(model_path), strict=True) model.eval() for k, v in model.named_parameters(): v.requires_grad = False model = model.to(device) print('Model path {:s}. \nTesting...'.format(model_path)) idx = 0 for path in glob.glob(test_img_folder): idx += 1 base = os.path.splitext(os.path.basename(path))[0] print(idx, base) # read image img = cv2.imread(path, cv2.IMREAD_COLOR)
def _infer_rrdb_config(state_dict):
    """Infer ``(in_nc, out_nc, nf, nb, upscale)`` for ``RRDB_Net`` from an old-format state dict.

    Heuristics applied to the ``model.<n>...`` keys:

    * a 5-part key whose third part is ``sub`` sets ``nb`` to its fourth part
      (the last such key wins; ``nb`` defaults to 23);
    * every 3-part ``weight`` key with layer index > 6 counts as one x2
      upsampling stage, so ``upscale = 2 ** count``;
    * ``out_nc`` is dim 0 of the first tensor seen at the highest layer index;
    * ``in_nc`` / ``nf`` come from the shape of ``model.0.weight``.
    """
    scale2 = 0
    max_part = 0
    out_nc = 3
    nb = 23
    for part in state_dict:
        parts = part.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > 6 and parts[2] == "weight":
                scale2 += 1
            if part_num > max_part:
                max_part = part_num
                out_nc = state_dict[part].shape[0]
    first_weight = state_dict["model.0.weight"]
    return first_weight.shape[1], out_nc, first_weight.shape[0], nb, 2 ** scale2


def main():
    """Compile old-format ESRGAN ``.pth`` checkpoints to TorchScript and save them.

    Checkpoints come from ``args.models`` or, if none are given, every
    ``*.pth`` under ``<script dir>/models``. Each result is written to
    ``out_dir / (stem + args.suffix)``. Existing outputs are skipped unless
    ``args.force`` is set.

    Returns:
        0 on success, 1 as soon as a new-format checkpoint is encountered.
    """
    args = parse_args()
    script_dir = Path(__file__).parent
    model_dir = script_dir / "models"
    out_dir = args.out_dir if args.out_dir else script_dir / "jit_models"
    out_dir.mkdir(parents=True, exist_ok=True)
    models = args.models if args.models else model_dir.rglob("*.pth")
    device = torch.device(args.device)

    for model_path in models:
        if not model_path.is_file():
            print(f"{model_path} does not exist, skipping...")
            continue
        out_path = out_dir / (model_path.stem + args.suffix)
        if not args.force and out_path.is_file():
            print(f"{out_path} already exists, skipping...")
            continue

        print(f"Tracing: {model_path}")
        # Map onto the CPU so GPU-saved checkpoints load regardless of
        # args.device; the network is moved to `device` below.
        state_dict = torch.load(model_path, map_location="cpu")
        if "conv_first.weight" in state_dict:
            print("Error: Attempted to load a new-format model")
            return 1

        in_nc, out_nc, nf, nb, upscale = _infer_rrdb_config(state_dict)
        net = arch.RRDB_Net(
            in_nc,
            out_nc,
            nf,
            nb,
            gc=32,
            upscale=upscale,
            norm_type=None,
            act_type="leakyrelu",
            mode="CNA",
            res_scale=1,
            upsample_mode="upconv",
        )
        net.load_state_dict(state_dict, strict=True)
        del state_dict
        net.eval()
        for param in net.parameters():
            param.requires_grad = False
        net = net.to(device)

        # torch.jit.script compiles from source, so (unlike torch.jit.trace)
        # no example input image is required.
        with torch.jit.optimized_execution(should_optimize=True):
            traced_script_module = torch.jit.script(net)

        print(f"Saving to: {out_path}")
        try:
            with out_path.open("wb") as out_file:
                torch.jit.save(traced_script_module, out_file)
        except BaseException:
            # Don't leave a truncated output behind (also on Ctrl-C). If the
            # file was never created, re-raise the original error instead of
            # masking it with FileNotFoundError.
            try:
                out_path.unlink()
            except FileNotFoundError:
                pass
            raise
    return 0
def _infer_rrdb_config(state_dict):
    """Infer ``(in_nc, out_nc, nf, nb, upscale)`` for ``RRDB_Net`` from an old-format state dict.

    Heuristics applied to the ``model.<n>...`` keys:

    * a 5-part key whose third part is ``sub`` sets ``nb`` to its fourth part
      (the last such key wins; ``nb`` defaults to 23);
    * every 3-part ``weight`` key with layer index > 6 counts as one x2
      upsampling stage, so ``upscale = 2 ** count``;
    * ``out_nc`` is dim 0 of the first tensor seen at the highest layer index;
    * ``in_nc`` / ``nf`` come from the shape of ``model.0.weight``.
    """
    scale2 = 0
    max_part = 0
    out_nc = 3
    nb = 23
    for part in state_dict:
        parts = part.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > 6 and parts[2] == "weight":
                scale2 += 1
            if part_num > max_part:
                max_part = part_num
                out_nc = state_dict[part].shape[0]
    first_weight = state_dict["model.0.weight"]
    return first_weight.shape[1], out_nc, first_weight.shape[0], nb, 2 ** scale2


def _upscale_pass(model, img, device, upscale, args):
    """Run one super-resolution pass of ``model`` over ``img``.

    ``img`` is an HxWxC float array in [0, 1] with channels in the order
    returned by ``cv2.imread``. The result has the same layout and range. When
    ``args.max_dimension`` is set, the image is processed through
    ``DataChunks`` tiles (with ``args.padding``).
    """
    # HWC -> 1xCxHxW tensor, reversing the channel order for the network.
    lr = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    lr = lr.unsqueeze(0)
    with torch.no_grad():
        if args.max_dimension:
            data_chunks = DataChunks(lr, args.max_dimension, args.padding, upscale)
            chunks = data_chunks.iter()
            # Show a progress bar only if tqdm was successfully imported.
            if "tqdm" in sys.modules:
                chunks_count = data_chunks.hlen * data_chunks.vlen
                chunks = tqdm.tqdm(chunks, total=chunks_count, unit=" chunks")
            for chunk in chunks:
                data_chunks.gather(model(chunk.to(device)))
            output = data_chunks.concatenate()
        else:
            output = model(lr.to(device))
    result = output.detach().squeeze().float().cpu().clamp_(0, 1).numpy()
    # CHW -> HWC, restoring the original channel order.
    return np.transpose(result[[2, 1, 0], :, :], (1, 2, 0))


def main():
    """Upscale every image matched by ``args.images`` with an old-format ESRGAN model.

    Each image goes through ``args.scale`` model passes and is written as PNG
    to ``args.out_dir`` (or next to the input). Unreadable inputs are skipped.

    Returns:
        0 on success, 1 if the selected checkpoint is in the new format.
    """
    start = time.perf_counter_ns()
    model_dir = Path(__file__).resolve().parent / "models"
    models = enum_models(model_dir)
    models_help = get_models_help(models)
    args = parse_args(models, models_help)
    model_path = model_dir / models[args.model]
    # Map onto the CPU so GPU-saved checkpoints load even with --cpu or on
    # hosts without CUDA; the model is moved to `device` below.
    state_dict = torch.load(model_path, map_location="cpu")
    if "conv_first.weight" in state_dict:
        print("Error: Attempted to load a new-format model")
        return 1

    in_nc, out_nc, nf, nb, upscale = _infer_rrdb_config(state_dict)

    if args.threads is not None and args.threads > 0:
        torch.set_num_threads(args.threads)
        torch.set_num_interop_threads(args.threads)

    if torch.cuda.is_available() and not args.cpu:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model = arch.RRDB_Net(
        in_nc,
        out_nc,
        nf,
        nb,
        gc=32,
        upscale=upscale,
        norm_type=None,
        act_type="leakyrelu",
        mode="CNA",
        res_scale=1,
        upsample_mode="upconv",
    )
    model.load_state_dict(state_dict, strict=True)
    del state_dict
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    image_paths = (
        Path(img_path)
        for img_glob in args.images
        for img_path in glob.glob(str(img_glob))
    )
    for i, path in enumerate(image_paths, start=1):
        print(i, path.name)
        # read image
        img = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None.
            print(f"Error: could not read {path}, skipping")
            continue
        img = img * 1.0 / 255
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for _ in range(args.scale):
                img = _upscale_pass(model, img, device, upscale, args)
        # Rescale to 0-255 once, after all passes: every pass expects [0, 1].
        img = (img * 255.0).round()
        out_dir = args.out_dir if args.out_dir is not None else path.parent
        suffix = f"_{args.model}{args.end}" if args.append_model else args.end
        out_path = out_dir / (path.stem + suffix + ".png")
        # cv2.imwrite returns False instead of raising on failure.
        if not cv2.imwrite(str(out_path), img):
            print(f"Error: could not write {out_path}")

    period = time.perf_counter_ns() - start
    print("Done in {:,}s".format(period / 1_000_000_000.0))
    return 0
# Module-level training setup: visualizer, hyperparameters, generator and
# discriminator networks, loss, optimizers and the training DataLoader.
# NOTE(review): `torch`, `optim`, `DataLoader`, `MyDataLoader` and `visualizer`
# are not imported explicitly here (`import torch.nn.functional as F` binds only
# `F`); presumably they arrive via the `vis_tools` / `data_utils` star
# imports — verify.
from tqdm import tqdm
from vis_tools import *
from data_utils import *
from loss import GeneratorLoss
from model import Generator, Discriminator
import architecture as arch
import pdb  # NOTE(review): not used in the visible code; likely a leftover debugging aid
import torch.nn.functional as F

gpu_id = 0  # device index passed to .to() below when CUDA is available
port_num = 8091  # port handed to the visualizer
display = visualizer(port=port_num)
report_feq = 10  # presumably a reporting interval ("report frequency"); not used in the visible code
NUM_EPOCHS = 40

# Generator: RRDB_Net(in_nc=4, out_nc=3, nf=64, nb=12) with upscale=1,
# presumably producing output at the input resolution.
netG = arch.RRDB_Net(4, 3, 64, 12, gc=32, upscale=1, norm_type=None, act_type='leakyrelu', \
                        mode='CNA', res_scale=1, upsample_mode='upconv')
netD = Discriminator()
generator_criterion = GeneratorLoss()

# Use the GPU only when CUDA is available; otherwise everything stays on the CPU.
if torch.cuda.is_available():
    netG.to(gpu_id)
    netD.to(gpu_id)
    generator_criterion.to(gpu_id)

# Separate Adam optimizers for generator and discriminator, both lr=2e-4.
optimizerG = optim.Adam(netG.parameters(), lr=0.0002)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002)

# Training data from the three directories below; batches of 2, shuffled,
# loaded by 4 worker processes.
train_set = MyDataLoader(hr_dir='../data/train_sample/HR/', hr_sample_dir='../data/train_sample/HR_Sample/4/', lap_dir='../data/train_sample/LAP_HR_Norm/')
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=2, shuffle=True)