コード例 #1
0
ファイル: test.py プロジェクト: tirstm01/hent-AI
 def __init__(self, model_path=None, hw='cpu'):
     """Load an RRDB_Net checkpoint and prepare it for inference.

     Args:
         model_path: Path to the saved state dict. Required, despite the
             ``None`` default kept for signature compatibility.
         hw: Target hardware, either ``'cpu'`` or ``'cuda'``.

     Raises:
         ValueError: If ``model_path`` is empty or ``hw`` is not one of
             the supported values.
     """
     # Explicit checks instead of ``assert``: asserts vanish under ``-O``.
     if not model_path:
         raise ValueError('model_path is required')
     # Reject unknown hardware up front; otherwise ``self.device`` would
     # never be set and the failure would surface later as AttributeError.
     if hw not in ('cpu', 'cuda'):
         raise ValueError("hw must be 'cpu' or 'cuda', got {!r}".format(hw))
     self.device = torch.device(hw)

     self.model = architecture.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
                             mode='CNA', res_scale=1, upsample_mode='upconv')
     # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
     self.model.load_state_dict(
         torch.load(model_path, map_location=self.device), strict=True)
     self.model.eval()
     # Inference only: no parameter needs gradients.
     for k, v in self.model.named_parameters():
         v.requires_grad = False
     self.model = self.model.to(self.device)
     print('Model warmup complete')
コード例 #2
0
def setup(opts):
    """Build an RRDB_Net whose weights blend the PSNR and ESRGAN checkpoints.

    Every tensor is mixed as ``(1 - f) * PSNR + f * ESRGAN`` where ``f`` is
    ``opts['interpolationFactor']`` (0.5 when the key is absent). The
    resulting model has all parameters frozen and is moved to ``device``
    (a name defined outside this function) before being returned.
    """
    psnr_weights = torch.load('./models/RRDB_PSNR_x4.pth')
    esrgan_weights = torch.load('./models/RRDB_ESRGAN_x4.pth')
    blend = opts.get('interpolationFactor', 0.5)

    # Keys are taken from the PSNR checkpoint; the ESRGAN checkpoint must
    # contain every one of them.
    net_interp = OrderedDict(
        (name, (1 - blend) * psnr_tensor + blend * esrgan_weights[name])
        for name, psnr_tensor in psnr_weights.items()
    )

    model = arch.RRDB_Net(
        3, 3, 64, 23,
        gc=32,
        upscale=4,
        norm_type=None,
        act_type='leakyrelu',
        mode='CNA',
        res_scale=1,
        upsample_mode='upconv',
    )
    model.load_state_dict(net_interp)

    # Freeze every parameter.
    for _, param in model.named_parameters():
        param.requires_grad = False

    return model.to(device)
コード例 #3
0
ファイル: esrgan_processing.py プロジェクト: wjnbreu/ml4a
def setup(model_dir):
    """Create the global ``model`` from a 50/50 blend of two checkpoints.

    Loads ``RRDB_PSNR_x4.pth`` and ``RRDB_ESRGAN_x4.pth`` from ``model_dir``,
    averages them tensor by tensor with ``alpha = 0.5``, loads the result
    into a fresh RRDB_Net, freezes its parameters and moves it to ``device``
    (a name defined outside this function).

    Returns:
        The same model object that is stored in the module-level ``model``.
    """
    global model

    alpha = 0.5
    psnr_weights = torch.load('%s/RRDB_PSNR_x4.pth' % model_dir)
    esrgan_weights = torch.load('%s/RRDB_ESRGAN_x4.pth' % model_dir)

    # Keys come from the PSNR checkpoint; each must exist in the ESRGAN one.
    net_interp = OrderedDict(
        (name, (1 - alpha) * psnr_tensor + alpha * esrgan_weights[name])
        for name, psnr_tensor in psnr_weights.items()
    )

    model = arch.RRDB_Net(
        3, 3, 64, 23,
        gc=32,
        upscale=4,
        norm_type=None,
        act_type='leakyrelu',
        mode='CNA',
        res_scale=1,
        upsample_mode='upconv',
    )
    model.load_state_dict(net_interp)

    # Freeze every parameter.
    for _, param in model.named_parameters():
        param.requires_grad = False

    model = model.to(device)
    return model
コード例 #4
0
    parts = part.split(".")
    n_parts = len(parts)
    if n_parts == 5 and parts[2] == 'sub':
        nb = int(parts[3])
    elif n_parts == 3:
        part_num = int(parts[1])
        if part_num > 6 and parts[2] == 'weight':
            scale2 += 1
        if part_num > max_part:
            max_part = part_num
            out_nc = state_dict[part].shape[0]
# Build the network from hyper-parameters recovered from the checkpoint.
# scale2, out_nc and nb are produced by the key-scanning loop that precedes
# this excerpt; the upscale factor is 2 raised to scale2.
upscaleSize = 2**scale2
# Channel counts are read off the 'model.0.weight' tensor: dimension 1 is
# used as the input channel count and dimension 0 as the feature count.
# NOTE(review): presumably this is the first conv layer's weight — confirm.
in_nc = state_dict['model.0.weight'].shape[1]
nf = state_dict['model.0.weight'].shape[0]

model = arch.RRDB_Net(in_nc, out_nc, nf, nb, gc=32, upscale=upscaleSize, norm_type=None, act_type='leakyrelu', \
                        mode='CNA', res_scale=1, upsample_mode='upconv')
# strict=True: checkpoint keys must match the model's keys exactly.
model.load_state_dict(state_dict, strict=True)
# The weights now live in the model; drop the extra reference to the dict.
del state_dict
model.eval()
# Freeze every parameter.
for k, v in model.named_parameters():
    v.requires_grad = False
model = model.to(device)

print('Model path {:s}. \nProcessing...'.format(model_path))
# Push the message out immediately rather than waiting for the buffer to fill.
sys.stdout.flush()

# Running count of files visited by the os.walk loop that follows.
idx = 0
# Remove every '*' so the value can be handed to os.walk as a directory path.
test_img_folder = test_img_folder.replace('*', '')
for path, subdirs, files in os.walk(test_img_folder):
    for name in files:
        idx += 1
コード例 #5
0
import sys
import os.path
import glob
import cv2
import numpy as np
import torch
import architecture as arch

# Checkpoint to load, taken from the first command-line argument,
# e.g. models/RRDB_ESRGAN_x4.pth or models/RRDB_PSNR_x4.pth.
model_path = sys.argv[1]
# Runs on the CPU; switch to torch.device('cuda') to run on a GPU.
device = torch.device('cpu')

# Glob pattern selecting the input images.
test_img_folder = 'LR/*'

model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
                        mode='CNA', res_scale=1, upsample_mode='upconv')
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU still loads on a machine without CUDA.
model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
model.eval()
# Freeze every parameter.
for k, v in model.named_parameters():
    v.requires_grad = False
model = model.to(device)

print('Model path {:s}. \nTesting...'.format(model_path))

# Visit every path matched by the test_img_folder glob pattern, numbering
# the files from 1 for the progress output.
idx = 0
for path in glob.glob(test_img_folder):
    idx += 1
    # File name without directory or extension.
    base = os.path.splitext(os.path.basename(path))[0]
    print(idx, base)
    # Read the image in colour mode.
    # NOTE(review): cv2.imread returns None for unreadable files — confirm
    # the code that follows handles that.
    img = cv2.imread(path, cv2.IMREAD_COLOR)
コード例 #6
0
def _infer_rrdb_config(state_dict):
    """Derive the RRDB_Net constructor arguments from an old-format state dict.

    Keys are split on ".". A five-part key whose third part is "sub" sets the
    block count from its fourth part. A three-part key contributes its
    numeric second part: each "weight" key numbered above 6 doubles the
    upscale factor, and the highest-numbered key supplies the output channel
    count.

    Returns:
        Tuple ``(in_nc, out_nc, nf, nb, upscale)``.
    """
    scale2 = 0
    max_part = 0
    # Fallbacks used when the state dict has no matching keys.
    out_nc = 3
    nb = 23
    for key in state_dict:
        parts = key.split(".")
        if len(parts) == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif len(parts) == 3:
            part_num = int(parts[1])
            if part_num > 6 and parts[2] == "weight":
                scale2 += 1
            if part_num > max_part:
                max_part = part_num
                out_nc = state_dict[key].shape[0]
    first_weight = state_dict["model.0.weight"]
    in_nc = first_weight.shape[1]
    nf = first_weight.shape[0]
    return in_nc, out_nc, nf, nb, 2 ** scale2


def main():
    """Convert old-format RRDB checkpoints into saved TorchScript modules.

    Processes the models given on the command line, or every ``*.pth`` file
    under ``models/`` next to this script when none are given. Each one is
    scripted with ``torch.jit.script`` and written to the output directory
    as ``<stem><suffix>``. Missing inputs and already existing outputs
    (unless ``--force``) are skipped.

    Returns:
        0 on success, 1 as soon as a new-format checkpoint is encountered.
    """
    args = parse_args()
    script_dir = Path(__file__).parent
    model_dir = script_dir / "models"
    # Path() is a no-op for Path inputs and also tolerates plain strings.
    out_dir = Path(args.out_dir) if args.out_dir else script_dir / "jit_models"
    out_dir.mkdir(parents=True, exist_ok=True)

    models = args.models if args.models else model_dir.rglob("*.pth")

    # The target device does not change between models; build it once.
    device = torch.device(args.device)

    for model_path in map(Path, models):
        if not model_path.is_file():
            print(f"{str(model_path)} does not exist, skipping...")
            continue

        out_path = out_dir / (model_path.stem + args.suffix)
        if not args.force and out_path.is_file():
            print(f"{str(out_path)} already exists, skipping...")
            continue

        print(f"Tracing: {str(model_path)}")

        # Load onto the CPU first so checkpoints saved from a GPU also work
        # on CUDA-less machines; the network is moved to `device` below.
        state_dict = torch.load(model_path, map_location="cpu")
        if "conv_first.weight" in state_dict:
            print("Error: Attempted to load a new-format model")
            return 1

        in_nc, out_nc, nf, nb, upscale = _infer_rrdb_config(state_dict)

        net = arch.RRDB_Net(
            in_nc,
            out_nc,
            nf,
            nb,
            gc=32,
            upscale=upscale,
            norm_type=None,
            act_type="leakyrelu",
            mode="CNA",
            res_scale=1,
            upsample_mode="upconv",
        )
        net.load_state_dict(state_dict, strict=True)
        del state_dict
        net.eval()

        # Freeze every parameter.
        for param in net.parameters():
            param.requires_grad = False
        net = net.to(device)

        with torch.jit.optimized_execution(should_optimize=True):
            scripted_module = torch.jit.script(net)
            print(f"Saving to: {str(out_path)}")
            try:
                with out_path.open("wb") as out_file:
                    torch.jit.save(scripted_module, out_file)
            except BaseException:
                # Do not leave a partial file behind (also on Ctrl-C). The
                # existence check keeps a failed open() from turning into a
                # FileNotFoundError that would hide the original error.
                if out_path.exists():
                    out_path.unlink()
                raise

    return 0
コード例 #7
0
ファイル: esrgan.py プロジェクト: green-s/ESRGAN
def _infer_rrdb_config(state_dict):
    """Derive the RRDB_Net constructor arguments from an old-format state dict.

    Keys are split on ".". A five-part key whose third part is "sub" sets the
    block count from its fourth part. A three-part key contributes its
    numeric second part: each "weight" key numbered above 6 doubles the
    upscale factor, and the highest-numbered key supplies the output channel
    count.

    Returns:
        Tuple ``(in_nc, out_nc, nf, nb, upscale)``.
    """
    scale2 = 0
    max_part = 0
    # Fallbacks used when the state dict has no matching keys.
    out_nc = 3
    nb = 23
    for key in state_dict:
        parts = key.split(".")
        if len(parts) == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif len(parts) == 3:
            part_num = int(parts[1])
            if part_num > 6 and parts[2] == "weight":
                scale2 += 1
            if part_num > max_part:
                max_part = part_num
                out_nc = state_dict[key].shape[0]
    first_weight = state_dict["model.0.weight"]
    in_nc = first_weight.shape[1]
    nf = first_weight.shape[0]
    return in_nc, out_nc, nf, nb, 2**scale2


def main():
    """Upscale the images named on the command line with an RRDB model.

    Loads the selected old-format checkpoint from ``models/`` next to this
    script, runs every image matched by the ``args.images`` globs through the
    network ``args.scale`` times (optionally in chunks bounded by
    ``args.max_dimension``) and writes each result as a PNG.

    Returns:
        0 on success, 1 when the selected checkpoint is in the new format.
    """
    start = time.perf_counter_ns()
    model_dir = Path(__file__).resolve().parent / "models"
    models = enum_models(model_dir)
    models_help = get_models_help(models)
    args = parse_args(models, models_help)
    model_path = model_dir / models[args.model]

    # Load onto the CPU first so a checkpoint saved from a GPU still works
    # with --cpu or without CUDA; the model is moved to `device` below.
    state_dict = torch.load(model_path, map_location="cpu")
    if "conv_first.weight" in state_dict:
        print("Error: Attempted to load a new-format model")
        return 1

    in_nc, out_nc, nf, nb, upscale = _infer_rrdb_config(state_dict)

    if args.threads is not None and args.threads > 0:
        torch.set_num_threads(args.threads)
        torch.set_num_interop_threads(args.threads)

    if torch.cuda.is_available() and not args.cpu:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model = arch.RRDB_Net(
        in_nc,
        out_nc,
        nf,
        nb,
        gc=32,
        upscale=upscale,
        norm_type=None,
        act_type="leakyrelu",
        mode="CNA",
        res_scale=1,
        upsample_mode="upconv",
    )
    model.load_state_dict(state_dict, strict=True)
    del state_dict
    model.eval()

    # Freeze every parameter.
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    image_paths = (Path(img_path) for img_glob in args.images
                   for img_path in glob.glob(str(img_glob)))
    for i, path in enumerate(image_paths):
        print(i + 1, path.name)
        # read image
        img = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals an unreadable or non-image file with None;
            # skip it instead of crashing on the arithmetic below.
            print(f"Could not read {str(path)} as an image, skipping...")
            continue
        img = img * 1.0 / 255

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for _ in range(args.scale):
                # Reverse the channel order, move channels first and add a
                # leading batch dimension.
                img = torch.from_numpy(
                    np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
                img = img.unsqueeze(0)

                if args.max_dimension:
                    data_chunks = DataChunks(img, args.max_dimension,
                                             args.padding, upscale)
                    chunks = data_chunks.iter()
                    # Show a progress bar only when tqdm has been imported.
                    if "tqdm" in sys.modules:
                        chunks_count = data_chunks.hlen * data_chunks.vlen
                        chunks = tqdm.tqdm(chunks,
                                           total=chunks_count,
                                           unit=" chunks")
                    for chunk in chunks:
                        output = model(chunk.to(device))
                        data_chunks.gather(output)
                    output = data_chunks.concatenate()
                else:
                    output = model(img.to(device))

                # Drop only the batch dimension: a bare squeeze() would also
                # remove a height or width of 1 and break the indexing below.
                img = output.data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
                img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))

        img = (img * 255.0).round()

        out_dir = args.out_dir if args.out_dir is not None else path.parent
        suffix = f"_{args.model}{args.end}" if args.append_model else args.end
        out_path = out_dir / (path.stem + suffix + ".png")
        cv2.imwrite(str(out_path), img)

    period = time.perf_counter_ns() - start
    print("Done in {:,}s".format(period / 1_000_000_000.0))
    return 0
コード例 #8
0
from tqdm import tqdm
from vis_tools import *
from data_utils import *
from loss import GeneratorLoss
from model import Generator, Discriminator
import architecture as arch
import pdb 
import torch.nn.functional as F

# Run configuration.
gpu_id = 0
port_num = 8091
display = visualizer(port=port_num)
report_feq = 10
NUM_EPOCHS = 40

# Generator and discriminator networks.
# NOTE(review): the leading positional arguments look like input channels,
# output channels, feature count and block count — confirm against arch.
netG = arch.RRDB_Net(
    4,
    3,
    64,
    12,
    gc=32,
    upscale=1,
    norm_type=None,
    act_type='leakyrelu',
    mode='CNA',
    res_scale=1,
    upsample_mode='upconv',
)
netD = Discriminator()

generator_criterion = GeneratorLoss()

# Move both networks and the generator loss to the GPU when one is present.
if torch.cuda.is_available():
    netG.to(gpu_id)
    netD.to(gpu_id)
    generator_criterion.to(gpu_id)

# One Adam optimizer per network, both with the same learning rate.
optimizerG = optim.Adam(netG.parameters(), lr=0.0002)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002)

# Training data.
train_set = MyDataLoader(
    hr_dir='../data/train_sample/HR/',
    hr_sample_dir='../data/train_sample/HR_Sample/4/',
    lap_dir='../data/train_sample/LAP_HR_Norm/',
)
train_loader = DataLoader(
    dataset=train_set,
    num_workers=4,
    batch_size=2,
    shuffle=True,
)