Example 1
def prepare_image():
    dataset_image_path="/home/qianqianjun/CODE/DataSets/DaeDatasets"
    save_dir="/home/qianqianjun/CODE/实验结果/重建结果/改进编码器"
    os.makedirs(save_dir,exist_ok=True)
    length=5000
    dataset=Datas(dataset_image_path,length)
    loader=torch.utils.data.DataLoader(dataset,batch_size=25,shuffle=False,num_workers=4)
    encoders,decoders=Encoders(parm),Decoders(parm)
    device=torch.device("cpu")
    encoders.load_state_dict(torch.load("../CelebA_out/2020-03-28_23:13:37/checkpoints/encoders.pth",map_location=device))
    decoders.load_state_dict(torch.load("../CelebA_out/2020-03-28_23:13:37/checkpoints/decoders.pth",map_location=device))
    if parm.useCuda:
        encoders=encoders.cuda()
        decoders=decoders.cuda()
    for name,data in loader:
        batch_data=batchDataTransform(data,3)
        base=getBaseGrid(imgSize=parm.imgSize,Inbatch=True,batchSize=batch_data.shape[0])
        if parm.useCuda:
            batch_data = batch_data.cuda()
            base = base.cuda()
        z,zI,zW,intersI,intersW=encoders(batch_data)
        texture,warp,out,_=decoders(zI,zW,intersI,intersW,base)
        images=out.detach().cpu().numpy()
        images=np.transpose(images,[0,2,3,1])
        images=np.array(images*255,dtype=np.uint8)
        for index,img in enumerate(images):
            # The network works in RGB; convert back to BGR before writing with OpenCV.
            temp=cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(save_dir,name[index]),temp)
Example 2
def load_module(dataset_name: str, prefix_path):
    # Run directories are timestamped, so sorting makes paths[-1] the latest run.
    paths = sorted(os.listdir(os.path.join(prefix_path, parm.result_out,
                                           dataset_name)))
    if len(paths) == 0:
        exit("No pretrained model was found; please train the network first.")
    model_dir = os.path.abspath(
        os.path.join(prefix_path, parm.result_out, dataset_name, paths[-1]))
    print("Loading model from: {}".format(model_dir))
    encoders = Encoders(parm)
    decoders = Decoders(parm)
    if parm.useCuda:
        encoders.load_state_dict(
            torch.load(
                os.path.join(model_dir, parm.dirCheckpoints, "encoders.pth")))
        decoders.load_state_dict(
            torch.load(
                os.path.join(model_dir, parm.dirCheckpoints, "decoders.pth")))
        encoders = encoders.cuda()
        decoders = decoders.cuda()
    else:
        device = torch.device("cpu")
        # Load the pretrained weights onto the CPU.
        encoders.load_state_dict(
            torch.load(os.path.join(model_dir, parm.dirCheckpoints,
                                    "encoders.pth"),
                       map_location=device))
        decoders.load_state_dict(
            torch.load(os.path.join(model_dir, parm.dirCheckpoints,
                                    "decoders.pth"),
                       map_location=device))

    return encoders, decoders, model_dir
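
A minimal usage sketch for context; both arguments below are placeholders for illustration, not values taken from the project:

# Hypothetical call: restore the most recent checkpoint for a dataset named
# "CelebA", with the parent directory as the path prefix.
encoders, decoders, model_dir = load_module("CelebA", os.path.abspath(".."))
print("checkpoints restored from {}".format(model_dir))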
Example 3
def init_model(parm: Parameter, dataset_name: str, path_prefix: str):
    """
    训练前初始化模型。
    在Fine-tune 模式下,自动加载预训练模型进行微调训练
    Fine-tune 为 False 或者没有预训练的模型,从头开始训练
    :param parm:
    :param dataset_name:  训练使用的数据集名称
    :param path_prefix:  路径前缀,使用绝对路径
    :return:
    """
    encoders = Encoders(parm)
    decoders = Decoders(parm)
    if parm.useCuda:
        encoders = encoders.cuda()
        decoders = decoders.cuda()
    model_path = os.path.join(path_prefix, parm.result_out, dataset_name)
    if os.path.exists(model_path) and len(os.listdir(model_path)) >= 1:
        # Run directories are timestamped, so sorting picks the most recent checkpoint.
        path = os.path.join(model_path,
                            sorted(os.listdir(model_path))[-1],
                            parm.dirCheckpoints)
        encoders.load_state_dict(torch.load(os.path.join(path, "encoders.pth")))
        decoders.load_state_dict(torch.load(os.path.join(path, "decoders.pth")))
        if not os.path.exists(os.path.join(path, "info")):
            start_epoch = 0
            learning_rate = parm.learning_rate
        else:
            with open(os.path.join(path, "info"), "rb") as f:
                info = pk.load(f, encoding="utf-8")
            start_epoch = info.get("epoch")
            learning_rate = info.get("lr")
    else:
        # No checkpoint was found: initialize the weights from scratch.
        encoders.apply(weight_init)
        decoders.apply(weight_init)
        start_epoch = 0
        learning_rate = parm.learning_rate
    return start_epoch, learning_rate, encoders, decoders
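
A hypothetical call for context, showing how the returned values would seed a training loop; the dataset name and prefix are placeholders:

# Hypothetical usage: resume training if a checkpoint exists, otherwise start fresh.
start_epoch, learning_rate, encoders, decoders = init_model(
    parm, "CelebA", os.path.abspath("."))
print("starting at epoch {} with learning rate {}".format(start_epoch, learning_rate))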
Example 4
import os

import cv2
import torch

from model.UnetDAE import Encoders, Decoders
from setting.parameter import parameter as parm
from tools.utils import batchDataTransform

model_path="../CelebA_out/2020-03-28_23:13:37/checkpoints"

parm.useCuda=False

# Load the model
encoders=Encoders(parm)
decoders=Decoders(parm)
device=torch.device("cpu")
encoders.load_state_dict(torch.load(os.path.join(model_path,"encoders.pth"),map_location=device))
decoders.load_state_dict(torch.load(os.path.join(model_path,"decoders.pth"),map_location=device))

# Load the data
pos_path="/home/qianqianjun/桌面/胡子人脸"
neg_path="/home/qianqianjun/桌面/男星"
pos_imgs=[
    cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(pos_path,name)),cv2.COLOR_BGR2RGB),(parm.imgSize,parm.imgSize))
    for name in os.listdir(pos_path)
]
neg_imgs=[
    cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(neg_path,name)),cv2.COLOR_BGR2RGB),(parm.imgSize,parm.imgSize))
    for name in os.listdir(neg_path)
]
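
The script stops after loading the two image sets. Below is a hedged sketch of a plausible next step, reusing batchDataTransform (already imported above, with the channel argument 3 as in Example 1) and the encoder outputs from Example 1; averaging the texture codes into an attribute direction is an illustrative assumption, not the project's documented procedure.

# Preprocess both sets the same way Example 1 preprocesses its loader batches.
pos_batch = batchDataTransform(torch.tensor(pos_imgs, dtype=torch.float32), 3)
neg_batch = batchDataTransform(torch.tensor(neg_imgs, dtype=torch.float32), 3)

# Encode both sets; the encoder returns z, zI, zW, intersI, intersW as in Example 1.
_, pos_zI, _, _, _ = encoders(pos_batch)
_, neg_zI, _, _, _ = encoders(neg_batch)

# Illustrative assumption: a "beard" direction as the difference of mean texture codes.
texture_direction = pos_zI.mean(dim=0, keepdim=True) - neg_zI.mean(dim=0, keepdim=True)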
Example 5
import os

import cv2
import torch
import torch.nn as nn

from model.UnetDAE import Encoders, Decoders
from setting.parameter import parameter as parm
from tools.utils import getBaseGrid, batchDataTransform, saveIntermediateImage

encoders = Encoders(parm)
decoders = Decoders(parm)
model_path = "out/other/2020-04-06_14:04:27/checkpoints"
device = torch.device("cpu")
encoders.load_state_dict(
    torch.load(os.path.join(model_path, "encoders.pth"), map_location=device))
decoders.load_state_dict(
    torch.load(os.path.join(model_path, "decoders.pth"), map_location=device))

img_dir = "/home/qianqianjun/桌面/杨洋男"
names = os.listdir(img_dir)
imgs_path = [os.path.join(img_dir, name) for name in names]
# The original call is cut off here; the trailing channel argument (3) is
# assumed to match the batchDataTransform(data, 3) call in Example 1.
batch_data = batchDataTransform(
    torch.tensor([
        cv2.cvtColor(
            cv2.resize(cv2.imread(img_path), (parm.imgSize, parm.imgSize)),
            cv2.COLOR_BGR2RGB) for img_path in imgs_path[:9]
    ], dtype=torch.float32), 3)
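
The snippet ends mid-script; a minimal continuation sketch follows, mirroring the reconstruction pass of Example 1 (the base-grid, encoder, and decoder calls are copied from there; how saveIntermediateImage is invoked afterwards is not shown in the source, so it is omitted).

# Identity sampling grid for this batch, built exactly as in Example 1.
base = getBaseGrid(imgSize=parm.imgSize, Inbatch=True,
                   batchSize=batch_data.shape[0])
# Encode, then reconstruct texture, warp field, and the output images.
z, zI, zW, intersI, intersW = encoders(batch_data)
texture, warp, out, _ = decoders(zI, zW, intersI, intersW, base)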
Example 6
        warp_inter_interpolation_result.append(
            test_warp_inters[i] + lambda_warp_inter_out[i] *
            warp_inters[i].repeat(test_warp_inters[i].size()[0], 1, 1, 1))
    # Reconstruct the result images
    texture_result, warp_result, result, _ = decoders(
        texture_code_interpolation_result, warp_code_interpolation_result,
        texture_inter_interpolation_result, warp_inter_interpolation_result,
        base)
    return result


n_sample, n_row = 8, 4
data_dir = "/home/qianqianjun/下载/jaffe"
# Load the model
model_dir = "../JAFFE_out/2020-03-27_21:38:52/checkpoints"
encoders, decoders = Encoders(parm), Decoders(parm)
encoders.load_state_dict(
    torch.load(os.path.join(model_dir, "encoders.pth")))
decoders.load_state_dict(
    torch.load(os.path.join(model_dir, "decoders.pth")))
# Load the data
paths = os.listdir(data_dir)
test_imgs = []
display_index = [1, 2, 3, 4, 6, 7, 9]
for index, i in enumerate(paths):
    if not i.endswith(".tiff"):
        continue
    if i.find("NE") != -1:
        test_imgs.append(
            cv2.resize(cv2.imread(os.path.join(data_dir, i)),
                       (parm.imgSize, parm.imgSize)))