def prepare_image(
        dataset_image_path="/home/qianqianjun/CODE/DataSets/DaeDatasets",
        save_dir="/home/qianqianjun/CODE/实验结果/重建结果/改进编码器",
        checkpoint_dir="../CelebA_out/2020-03-28_23:13:37/checkpoints",
        length=5000,
        batch_size=25,
        num_workers=4):
    """Reconstruct dataset images with a pretrained encoder/decoder and save them.

    Every sample yielded by ``Datas(dataset_image_path, length)`` is encoded,
    decoded, and written to ``save_dir`` under the file name the dataset
    reports for it. All parameters default to the values that used to be
    hard-coded, so ``prepare_image()`` keeps working as before.

    :param dataset_image_path: directory handed to ``Datas``.
    :param save_dir: output directory; created if it does not exist.
    :param checkpoint_dir: directory holding ``encoders.pth`` / ``decoders.pth``.
    :param length: sample count handed to ``Datas``.
    :param batch_size: DataLoader batch size.
    :param num_workers: DataLoader worker processes.
    """
    os.makedirs(save_dir, exist_ok=True)
    dataset = Datas(dataset_image_path, length)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    encoders, decoders = Encoders(parm), Decoders(parm)
    # Map the checkpoints to the CPU first so they open on CPU-only hosts;
    # the modules are moved to the GPU afterwards if requested.
    device = torch.device("cpu")
    nn.Module.load_state_dict(
        encoders,
        torch.load(os.path.join(checkpoint_dir, "encoders.pth"), map_location=device))
    nn.Module.load_state_dict(
        decoders,
        torch.load(os.path.join(checkpoint_dir, "decoders.pth"), map_location=device))
    if parm.useCuda:
        encoders = encoders.cuda()
        decoders = decoders.cuda()
    # This is inference only: put any BatchNorm/Dropout layers into evaluation mode.
    encoders.eval()
    decoders.eval()

    # No gradients are needed for reconstruction; skip building autograd graphs.
    with torch.no_grad():
        for names, data in loader:
            batch_data = batchDataTransform(data, 3)
            base = getBaseGrid(imgSize=parm.imgSize, Inbatch=True,
                               batchSize=batch_data.shape[0])
            if parm.useCuda:
                batch_data = batch_data.cuda()
                base = base.cuda()
            _, zI, zW, intersI, intersW = encoders(batch_data)
            _, _, out, _ = decoders(zI, zW, intersI, intersW, base)

            # Channels-first -> channels-last, as cv2.imwrite expects.
            images = np.transpose(out.cpu().numpy(), [0, 2, 3, 1])
            # Assumes decoder output lies in [0, 1] (the original scaled by 255).
            # Clip so out-of-range values saturate instead of wrapping in uint8.
            images = (np.clip(images, 0.0, 1.0) * 255).astype(np.uint8)
            for name, img in zip(names, images):
                # The reconstruction is RGB; OpenCV writes BGR.
                cv2.imwrite(os.path.join(save_dir, name),
                            cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
def load_module(dataset_name: str, prefix_path):
    """Load the encoder/decoder pair from the most recent training run.

    Run directories are looked up under
    ``<prefix_path>/<parm.result_out>/<dataset_name>``. The last one in sorted
    order is used; for the ``YYYY-MM-DD_HH:MM:SS`` run names this project
    writes, that is the newest run. Stray files in that folder are ignored.

    :param dataset_name: dataset sub-directory under ``parm.result_out``.
    :param prefix_path: prefix that ``parm.result_out`` is resolved against.
    :return: ``(encoders, decoders, model_dir)``, where ``model_dir`` is the
        absolute path of the run that was loaded. The modules are moved to
        the GPU when ``parm.useCuda`` is set.
    Exits the process with a message if no run directory is found.
    """
    import sys

    runs_root = os.path.join(prefix_path, parm.result_out, dataset_name)
    # os.listdir order is unspecified, so sort explicitly; a missing folder is
    # treated the same as an empty one.
    if os.path.isdir(runs_root):
        runs = sorted(entry for entry in os.listdir(runs_root)
                      if os.path.isdir(os.path.join(runs_root, entry)))
    else:
        runs = []
    if not runs:
        sys.exit("没有找到预训练的模型,请先训练网络")
    model_dir = os.path.abspath(os.path.join(runs_root, runs[-1]))
    print("加载模型地址:{}".format(model_dir))

    encoders = Encoders(parm)
    decoders = Decoders(parm)
    checkpoint_dir = os.path.join(model_dir, parm.dirCheckpoints)
    # With CUDA, keep the saved device placement (torch.load default);
    # otherwise force tensors onto the CPU.
    map_location = None if parm.useCuda else torch.device("cpu")
    torch.nn.Module.load_state_dict(
        encoders,
        torch.load(os.path.join(checkpoint_dir, "encoders.pth"),
                   map_location=map_location))
    torch.nn.Module.load_state_dict(
        decoders,
        torch.load(os.path.join(checkpoint_dir, "decoders.pth"),
                   map_location=map_location))
    if parm.useCuda:
        encoders = encoders.cuda()
        decoders = decoders.cuda()
    return encoders, decoders, model_dir
def init_model(parm: Parameter, dataset_name: str, path_prefix: str):
    """Build the encoder/decoder pair for training, resuming from a checkpoint if one exists.

    If ``<path_prefix>/<parm.result_out>/<dataset_name>`` contains at least one
    run directory, weights are loaded from the latest run's
    ``parm.dirCheckpoints`` folder. When an ``info`` pickle is present there,
    the saved epoch and learning rate are resumed. Otherwise the weights are
    initialised with ``weight_init``, and training starts at epoch 0 with
    ``parm.learning_rate``.

    NOTE(review): an earlier docstring described a "Fine-tune" switch, but no
    such flag is read here; a checkpoint is loaded whenever one exists.

    :param parm: training parameters (reads ``useCuda``, ``result_out``,
        ``dirCheckpoints``, ``learning_rate``).
    :param dataset_name: name of the dataset used for training.
    :param path_prefix: path prefix; an absolute path is expected.
    :return: ``(start_epoch, learning_rate, encoders, decoders)``.
    """
    encoders = Encoders(parm)
    decoders = Decoders(parm)
    if parm.useCuda:
        encoders = nn.Module.cuda(encoders)
        decoders = nn.Module.cuda(decoders)

    start_epoch = 0
    learning_rate = parm.learning_rate

    model_path = os.path.join(path_prefix, parm.result_out, dataset_name)
    # os.listdir order is unspecified: sort so the last entry is the newest
    # timestamp-named run, and ignore stray files.
    if os.path.isdir(model_path):
        runs = sorted(entry for entry in os.listdir(model_path)
                      if os.path.isdir(os.path.join(model_path, entry)))
    else:
        runs = []

    if not runs:
        # Nothing to resume from: train from scratch.
        nn.Module.apply(encoders, weight_init)
        nn.Module.apply(decoders, weight_init)
        return start_epoch, learning_rate, encoders, decoders

    checkpoint_dir = os.path.join(model_path, runs[-1], parm.dirCheckpoints)
    # Without CUDA, map GPU-saved tensors to the CPU so loading does not fail.
    map_location = None if parm.useCuda else torch.device("cpu")
    nn.Module.load_state_dict(
        encoders,
        torch.load(os.path.join(checkpoint_dir, "encoders.pth"),
                   map_location=map_location))
    nn.Module.load_state_dict(
        decoders,
        torch.load(os.path.join(checkpoint_dir, "decoders.pth"),
                   map_location=map_location))

    info_path = os.path.join(checkpoint_dir, "info")
    if os.path.exists(info_path):
        # NOTE: unpickling can run arbitrary code; only load info files
        # produced by trusted training runs.
        with open(info_path, "rb") as f:
            info = pk.load(f, encoding="utf-8")
        # Fall back to fresh values if the info dict lacks a key.
        start_epoch = info.get("epoch", start_epoch)
        learning_rate = info.get("lr", learning_rate)
    return start_epoch, learning_rate, encoders, decoders
import os

import cv2
import torch

from model.UnetDAE import Encoders, Decoders
from setting.parameter import parameter as parm
from tools.utils import batchDataTransform


def _load_rgb_images(directory, size):
    """Return every decodable image in *directory* as an RGB array resized to size x size.

    Entries that OpenCV cannot decode (``cv2.imread`` returns ``None``) are
    skipped instead of crashing in ``cv2.cvtColor``. Order follows
    ``os.listdir``.
    """
    images = []
    for name in os.listdir(directory):
        bgr = cv2.imread(os.path.join(directory, name))
        if bgr is None:
            continue
        images.append(cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (size, size)))
    return images


model_path = "../CelebA_out/2020-03-28_23:13:37/checkpoints"
# This script always runs on the CPU.
parm.useCuda = False

# Load the pretrained model onto the CPU.
# NOTE(review): the modules are left in training mode; call .eval() before
# inference if they contain BatchNorm/Dropout.
encoders = Encoders(parm)
decoders = Decoders(parm)
device = torch.device("cpu")
torch.nn.Module.load_state_dict(
    encoders, torch.load(os.path.join(model_path, "encoders.pth"), map_location=device))
torch.nn.Module.load_state_dict(
    decoders, torch.load(os.path.join(model_path, "decoders.pth"), map_location=device))

# Load the two face sets: the positive set ("胡子人脸", bearded faces) and the
# negative set ("男星", male celebrities).
pos_path = "/home/qianqianjun/桌面/胡子人脸"
neg_path = "/home/qianqianjun/桌面/男星"
pos_imgs = _load_rgb_images(pos_path, parm.imgSize)
neg_imgs = _load_rgb_images(neg_path, parm.imgSize)
import os import cv2 import torch import torch.nn as nn from model.UnetDAE import Encoders, Decoders from setting.parameter import parameter as parm from tools.utils import getBaseGrid, batchDataTransform, saveIntermediateImage encoders = Encoders(parm) decoders = Decoders(parm) model_path = "out/other/2020-04-06_14:04:27/checkpoints" device = torch.device("cpu") nn.Module.load_state_dict( encoders, torch.load(os.path.join(model_path, "encoders.pth"), map_location=device)) nn.Module.load_state_dict( decoders, torch.load(os.path.join(model_path, "decoders.pth"), map_location=device)) img_dir = "/home/qianqianjun/桌面/杨洋男" names = os.listdir(img_dir) imgs_path = [os.path.join(img_dir, name) for name in names] batch_data = batchDataTransform(torch.tensor([ cv2.cvtColor( cv2.resize(cv2.imread(img_path), (parm.imgSize, parm.imgSize)), cv2.COLOR_BGR2RGB) for img_path in imgs_path[:9] ], dtype=torch.float32),
warp_inter_interpolation_result.append( test_warp_inters[i] + lambda_warp_inter_out[i] * warp_inters[i].repeat(test_warp_inters[i].size()[0], 1, 1, 1)) # 结果图像重建 texture_result, warp_result, result, _ = decoders( texture_code_interpolation_result, warp_code_interpolation_result, texture_inter_interpolation_result, warp_inter_interpolation_result, base) return result n_sample, n_row = 8, 4 data_dir = "/home/qianqianjun/下载/jaffe" # 模型加载 model_dir = "../JAFFE_out/2020-03-27_21:38:52/checkpoints" encoders, decoders = Encoders(parm), Decoders(parm) torch.nn.Module.load_state_dict( encoders, torch.load(os.path.join(model_dir, "encoders.pth"))) torch.nn.Module.load_state_dict( decoders, torch.load(os.path.join(model_dir, "decoders.pth"))) # 加载数据 paths = os.listdir(data_dir) test_imgs = [] display_index = [1, 2, 3, 4, 6, 7, 9] for index, i in enumerate(paths): if not i.endswith(".tiff"): continue if i.find("NE") != -1: test_imgs.append( cv2.resize(cv2.imread(os.path.join(data_dir, i)), (parm.imgSize, parm.imgSize)))