import os

import torch

from network.Transformer import Transformer


def init_model():
    # Build the generator and load the pretrained Hayao-style weights.
    model = Transformer()
    model.load_state_dict(
        torch.load(
            os.path.join('./pretrained_model', 'Hayao' + '_net_G_float.pth')))
    model.eval()
    model.cuda(0)
    return model
def setup(opts):
    model = Transformer()
    model.load_state_dict(torch.load(opts["checkpoint"]))
    model.eval()
    if torch.cuda.is_available():
        print("GPU Mode")
        model.cuda()
    else:
        print("CPU Mode")
        model.float()
    return model
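# Hedged usage sketch for setup(): the checkpoint path below is an assumed
# example that follows the *_net_G_float.pth naming convention used elsewhere
# in this file; point it at whichever checkpoint you actually have.
opts = {"checkpoint": "./pretrained_model/Hayao_net_G_float.pth"}
hayao_model = setup(opts)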
from io import BytesIO


def load_models(s3, bucket):
    # Download all four style checkpoints from S3 and keep one generator per style.
    styles = ["Hosoda", "Hayao", "Shinkai", "Paprika"]
    models = {}
    for style in styles:
        model = Transformer()
        response = s3.get_object(Bucket=bucket, Key=f"models/{style}_net_G_float.pth")
        state = torch.load(BytesIO(response["Body"].read()))
        model.load_state_dict(state)
        model.eval()
        models[style] = model
    return models
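# Hedged usage sketch for load_models(): assumes boto3 is installed and that
# "my-cartoongan-bucket" is a placeholder bucket holding the checkpoints
# under models/.
import boto3

s3_client = boto3.client("s3")
style_models = load_models(s3_client, bucket="my-cartoongan-bucket")
shinkai = style_models["Shinkai"]  # pick a generator by style name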
import os

import cv2
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms

from network.Transformer import Transformer

model = Transformer()
model.load_state_dict(torch.load('pretrained_model/Hayao_net_G_float.pth'))
model.eval()
print('Model loaded!')

img_size = 700
img_path = 'test_img/9.jpg'

# cv2.imread returns BGR, which matches the channel order the generator expects.
img = cv2.imread(img_path)

T = transforms.Compose([
    transforms.ToPILImage(),
    # the original positional argument 2 is the PIL code for bilinear
    transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor()
])

# preprocess: scale to (-1, 1)
img_input = T(img).unsqueeze(0)
img_input = -1 + 2 * img_input

with torch.no_grad():
    img_output = model(img_input)

# deprocess: back to (0, 1), HWC; channels are still BGR
img_output = (img_output.squeeze().cpu().numpy() + 1.) / 2.
img_output = img_output.transpose([1, 2, 0])
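# Follow-up sketch (this display/save step is my addition, not part of the
# original script): flip BGR -> RGB for matplotlib, then save with OpenCV,
# which expects BGR; the output path is an assumed example.
img_rgb = img_output[:, :, ::-1]
plt.imshow(img_rgb)
plt.axis('off')
plt.show()

os.makedirs('test_output', exist_ok=True)
cv2.imwrite(os.path.join('test_output', '9_Hayao.jpg'),
            (img_output.clip(0, 1) * 255).astype('uint8'))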
# Assumes the original training script's module-level imports: torch, torch.nn
# as nn, torchvision, torchvision.transforms as transforms, torch.utils.data
# as Data, os, plus the project modules cfg, network (Discriminator, Vgg16,
# netutils), calc_acc and generate_img_batch.
class Main(object):
    def __init__(self):
        # network
        self.T = None
        self.D = None
        self.opt_T = None
        self.opt_D = None
        self.self_regularization_loss = None
        self.local_adversarial_loss = None
        self.delta = None
        self.vgg = None
        # data
        self.anime_train_loader = None
        self.animeblur_train_loader = None
        self.real_train_loader = None
        self.real_loader = None

    def build_network(self):
        print('=' * 50)
        print('Building network...')
        # self.T = Transformer(4, cfg.img_channels, nb_features=64)
        self.T = Transformer()
        self.D = Discriminator(input_features=cfg.img_channels)

        if cfg.cuda_use:
            self.T.cuda()
            self.D.cuda()

        self.opt_T = torch.optim.Adam(self.T.parameters(), lr=cfg.t_lr)
        self.opt_D = torch.optim.SGD(self.D.parameters(), lr=cfg.d_lr)
        self.self_regularization_loss = nn.L1Loss(reduction='sum')
        self.local_adversarial_loss = nn.CrossEntropyLoss(reduction='mean')
        self.delta = cfg.delta

        # frozen VGG-16 used for the perceptual (relu4_3) loss
        self.vgg = Vgg16(requires_grad=False)
        network.netutils.init_vgg16("./models/")
        self.vgg.load_state_dict(torch.load(os.path.join("./models/", "vgg16.weight")))
        self.vgg.cuda()

    def load_data(self):
        print('=' * 50)
        print('Loading data...')
        transform = transforms.Compose([
            # transforms.Grayscale,
            transforms.Resize((cfg.img_width, cfg.img_height)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        anime_train_folder = torchvision.datasets.ImageFolder(root=cfg.anime_path, transform=transform)
        animeblur_train_folder = torchvision.datasets.ImageFolder(root=cfg.animeblur_path, transform=transform)
        real_train_folder = torchvision.datasets.ImageFolder(root=cfg.real_path, transform=transform)

        self.anime_train_loader = Data.DataLoader(anime_train_folder, batch_size=cfg.batch_size,
                                                  shuffle=True, pin_memory=True)
        self.animeblur_train_loader = Data.DataLoader(animeblur_train_folder, batch_size=cfg.batch_size,
                                                      shuffle=True, pin_memory=True)
        self.real_train_loader = Data.DataLoader(real_train_folder, batch_size=cfg.batch_size,
                                                 shuffle=True, pin_memory=True)
        print('anime_train_batch %d' % len(self.anime_train_loader))
        print('animeblur_train_batch %d' % len(self.animeblur_train_loader))

        real_folder = torchvision.datasets.ImageFolder(root=cfg.real_path, transform=transform)
        # real_folder.imgs = real_folder.imgs[:2000]
        self.real_loader = Data.DataLoader(real_folder, batch_size=cfg.batch_size,
                                           shuffle=True, pin_memory=True)
        print('real_batch %d' % len(self.real_loader))

    def pre_train_t(self):
        print('=' * 50)
        # device = torch.device("cuda" if args.cuda else "cpu")
        if cfg.ref_pre_path:
            print('Loading t_pre from %s' % cfg.ref_pre_path)
            self.T.load_state_dict(torch.load(cfg.ref_pre_path))
            return

        print('pre-training the refiner network %d times...'
              % cfg.t_pretrain)
        mse_loss = torch.nn.MSELoss()
        for index in range(cfg.t_pretrain):
            anime_image_batch, _ = next(iter(self.anime_train_loader))
            anime_image_batch = anime_image_batch.cuda()
            animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
            animeblur_image_batch = animeblur_image_batch.cuda()
            real_image_batch, _ = next(iter(self.real_train_loader))
            real_image_batch = real_image_batch.cuda()

            self.T.train()
            transreal_image_batch = self.T(real_image_batch)

            # perceptual self-regularization: L1 distance between the VGG
            # relu4_3 features of the real batch and its translation
            real_features = self.vgg(real_image_batch).relu4_3
            transreal_features = self.vgg(transreal_image_batch).relu4_3
            # t_loss = self.self_regularization_loss(real_image_batch, transreal_image_batch)
            # t_loss = mse_loss(real_features, transreal_features)
            loss = torch.abs(transreal_features - real_features)
            t_loss = loss.mean()
            # t_loss = torch.div(t_loss, cfg.batch_size)
            t_loss = torch.mul(t_loss, self.delta)

            self.opt_T.zero_grad()
            t_loss.backward()
            self.opt_T.step()

            # log every cfg.t_pre_per steps
            if (index % cfg.t_pre_per == 0) or (index == cfg.t_pretrain - 1):
                print('[%d/%d] (R)reg_loss: %.4f' % (index, cfg.t_pretrain, t_loss.item()))

                with torch.no_grad():
                    anime_image_batch, _ = next(iter(self.anime_train_loader))
                    anime_image_batch = anime_image_batch.cuda()
                    animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
                    animeblur_image_batch = animeblur_image_batch.cuda()
                    real_image_batch, _ = next(iter(self.real_loader))
                    real_image_batch = real_image_batch.cuda()

                    self.T.eval()
                    ref_image_batch = self.T(real_image_batch)

                figure_path = os.path.join(cfg.train_res_path,
                                           'refined_image_batch_pre_train_%d.png' % index)
                generate_img_batch(anime_image_batch.cpu(), ref_image_batch.cpu(),
                                   real_image_batch.cpu(), figure_path)
                self.T.train()

        print('Save t_pre to models/t_pre.pkl')
        torch.save(self.T.state_dict(), 'models/t_pre.pkl')

    def pre_train_d(self):
        print('=' * 50)
        if cfg.disc_pre_path:
            print('Loading D_pre from %s' % cfg.disc_pre_path)
            self.D.load_state_dict(torch.load(cfg.disc_pre_path))
            return

        print('pre-training the discriminator network %d times...'
              % cfg.d_pretrain)
        self.D.train()
        self.T.eval()
        for index in range(cfg.d_pretrain):
            real_image_batch, _ = next(iter(self.real_loader))
            real_image_batch = real_image_batch.cuda()
            anime_image_batch, _ = next(iter(self.anime_train_loader))
            anime_image_batch = anime_image_batch.cuda()
            animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
            animeblur_image_batch = animeblur_image_batch.cuda()

            assert real_image_batch.size(0) == anime_image_batch.size(0)
            assert real_image_batch.size(0) == animeblur_image_batch.size(0)

            # label layout: class 0 = anime ("real" for D), class 1 = fake;
            # the dummy forward pass only sizes the per-patch label tensors
            d_real_pred = self.D(real_image_batch).view(-1, 2)
            d_anime_y = torch.zeros(d_real_pred.size(0), dtype=torch.long).cuda()
            d_real_y = torch.ones(d_real_pred.size(0), dtype=torch.long).cuda()  # translated real is fake
            d_blur_y = torch.ones(d_real_pred.size(0), dtype=torch.long).cuda()  # blurred anime is fake

            # ============ anime image D ====================================
            d_anime_pred = self.D(anime_image_batch).view(-1, 2)
            acc_anime = calc_acc(d_anime_pred, 'real')
            d_loss_anime = self.local_adversarial_loss(d_anime_pred, d_anime_y)

            # ============ translated real image D ==========================
            real_image_batch = self.T(real_image_batch)
            d_real_pred = self.D(real_image_batch).view(-1, 2)
            acc_real = calc_acc(d_real_pred, 'refine')
            d_loss_real = self.local_adversarial_loss(d_real_pred, d_real_y)

            # ============ blurred anime image D ============================
            d_animeblur_pred = self.D(animeblur_image_batch).view(-1, 2)
            acc_blur = calc_acc(d_animeblur_pred, 'refine')
            d_loss_animeblur = self.local_adversarial_loss(d_animeblur_pred, d_blur_y)

            d_loss = d_loss_anime + d_loss_animeblur + d_loss_real

            self.opt_D.zero_grad()
            d_loss.backward()
            self.opt_D.step()

            if (index % cfg.d_pre_per == 0) or (index == cfg.d_pretrain - 1):
                print('[%d/%d] (D)d_loss:%f acc_anime:%.2f%% acc_real:%.2f%% acc_blur:%.2f%%'
                      % (index, cfg.d_pretrain, d_loss.item(), acc_anime, acc_real, acc_blur))

        print('Save D_pre to models/D_pre.pkl')
        torch.save(self.D.state_dict(), 'models/D_pre.pkl')

    def train(self):
        print('=' * 50)
        print('Training...')
        # self.D.load_state_dict(torch.load("models/D_620.pkl"))
        # self.T.load_state_dict(torch.load("models/T_620.pkl"))
        '''
        image_history_buffer = ImageHistoryBuffer((0, cfg.img_channels, cfg.img_height, cfg.img_width),
                                                  cfg.buffet_size * 10, cfg.batch_size)
        '''
        for step in range(cfg.train_steps):
            print('Step[%d/%d]' % (step, cfg.train_steps))

            # ========= train the T =========
            self.D.eval()
            self.T.train()
            for p in self.D.parameters():
                p.requires_grad = False

            total_t_loss = 0.0
            total_t_loss_reg_scale = 0.0
            total_t_loss_adv = 0.0
            total_acc_adv = 0.0
            for index in range(cfg.k_t):
                real_image_batch, _ = next(iter(self.real_loader))
                real_image_batch = real_image_batch.cuda()

                # dummy forward pass, used only to size the label tensor
                d_real_pred = self.D(real_image_batch).view(-1, 2)
                d_real_y = torch.zeros(d_real_pred.size(0), dtype=torch.long).cuda()

                transreal_image_batch = self.T(real_image_batch)
                d_transreal_pred = self.D(transreal_image_batch).view(-1, 2)
                acc_adv = calc_acc(d_transreal_pred, 'real')
                t_loss_adv = self.local_adversarial_loss(d_transreal_pred, d_real_y)

                # perceptual self-regularization on VGG relu4_3 features
                real_features = self.vgg(real_image_batch).relu4_3
                transreal_features = self.vgg(transreal_image_batch).relu4_3
                # t_loss_reg = self.self_regularization_loss(real_features, transreal_features)
                loss = torch.abs(transreal_features - real_features)
                t_loss_reg = loss.mean()
                t_loss_reg_scale = torch.mul(t_loss_reg, self.delta)

                t_loss = t_loss_adv + t_loss_reg_scale

                self.opt_T.zero_grad()
                self.opt_D.zero_grad()
                t_loss.backward()
                self.opt_T.step()

                # accumulate scalars only, so the graph is freed each step
                total_t_loss += t_loss.item()
                total_t_loss_reg_scale += t_loss_reg_scale.item()
                total_t_loss_adv += t_loss_adv.item()
                total_acc_adv += acc_adv

            mean_t_loss = total_t_loss / cfg.k_t
            mean_t_loss_reg_scale = total_t_loss_reg_scale / cfg.k_t
            mean_t_loss_adv = total_t_loss_adv / cfg.k_t
            mean_acc_adv = total_acc_adv / cfg.k_t
            print('(R)t_loss:%.4f t_loss_reg:%.4f, t_loss_adv:%f(%.2f%%)'
                  % (mean_t_loss, mean_t_loss_reg_scale, mean_t_loss_adv, mean_acc_adv))

            # ========= train the D =========
            self.T.eval()
            self.D.train()
            for p in self.D.parameters():
                p.requires_grad = True

            for index in range(cfg.k_d):
                real_image_batch, _ = next(iter(self.real_loader))
                anime_image_batch, _ = next(iter(self.anime_train_loader))
                animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
                assert real_image_batch.size(0) == anime_image_batch.size(0)

                real_image_batch = real_image_batch.cuda()
                anime_image_batch = anime_image_batch.cuda()
                animeblur_image_batch = animeblur_image_batch.cuda()

                # size each label tensor from its own prediction
                d_anime_pred = self.D(anime_image_batch).view(-1, 2)
                d_anime_y = torch.zeros(d_anime_pred.size(0), dtype=torch.long).cuda()
                acc_anime = calc_acc(d_anime_pred, 'real')
                d_loss_anime = self.local_adversarial_loss(d_anime_pred, d_anime_y)

                real_image_batch = self.T(real_image_batch)
                d_real_pred = self.D(real_image_batch).view(-1, 2)
                d_real_y = torch.ones(d_real_pred.size(0), dtype=torch.long).cuda()
                acc_real = calc_acc(d_real_pred, 'refine')
                d_loss_real = self.local_adversarial_loss(d_real_pred, d_real_y)

                d_animeblur_pred = self.D(animeblur_image_batch).view(-1, 2)
                d_blur_y = torch.ones(d_animeblur_pred.size(0), dtype=torch.long).cuda()
                acc_blur = calc_acc(d_animeblur_pred, 'refine')
                d_loss_animeblur = self.local_adversarial_loss(d_animeblur_pred, d_blur_y)

                d_loss = d_loss_real + d_loss_anime + d_loss_animeblur

                self.opt_D.zero_grad()
                d_loss.backward()
                self.opt_D.step()

                print('(D)d_loss:%.4f anime_loss:%.4f(%.2f%%) real_loss:%.4f(%.2f%%) blur_loss:%.4f(%.2f%%)'
                      % (d_loss.item(), d_loss_anime.item(), acc_anime,
                         d_loss_real.item(), acc_real,
                         d_loss_animeblur.item(), acc_blur))

            if step % cfg.save_per == 0:
                print('Save two model dict.')
                torch.save(self.D.state_dict(), cfg.D_path % step)
                torch.save(self.T.state_dict(), cfg.T_path % step)

                with torch.no_grad():
                    real_image_batch, _ = next(iter(self.real_loader))
                    real_image_batch = real_image_batch.cuda()
                    anime_image_batch, _ = next(iter(self.anime_train_loader))
                    anime_image_batch = anime_image_batch.cuda()
                    animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
                    animeblur_image_batch = animeblur_image_batch.cuda()

                    self.T.eval()
                    realtrans_image_batch = self.T(real_image_batch)

                self.generate_batch_train_image(real_image_batch, realtrans_image_batch,
                                                animeblur_image_batch,
                                                step_index=step)

    def generate_batch_train_image(self, anime_image_batch, ref_image_batch, real_image_batch, step_index=-1):
        print('=' * 50)
        print('Generating a batch of training images...')
        self.T.eval()
        pic_path = os.path.join(cfg.train_res_path, 'step_%d.png' % step_index)
        generate_img_batch(anime_image_batch.cpu().data, ref_image_batch.cpu().data,
                           real_image_batch.cpu().data, pic_path)
        print('=' * 50)
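# Hedged driver sketch for Main (my assumption; the original entry point is
# not shown). Assumes cfg and the helpers (Discriminator, Vgg16, calc_acc,
# generate_img_batch) are importable as in the rest of this file.
if __name__ == '__main__':
    trainer = Main()
    trainer.build_network()
    trainer.load_data()
    trainer.pre_train_t()  # warm up the generator on the perceptual loss
    trainer.pre_train_d()  # warm up the discriminator
    trainer.train()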
import os

import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image

from network.Transformer import Transformer


def imageConverter(input_dir='input_img', load_size=1080, model_path='./pretrained_model',
                   style='Hayao', output_dir='Output_img', input_file='4--24.jpg'):
    gpu = -1
    file_name = input_file
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # load pretrained model
    model = Transformer()
    model.load_state_dict(
        torch.load(os.path.join(model_path, style + '_net_G_float.pth')))
    model.eval()

    # check if a GPU was requested
    if gpu > -1:
        print('GPU mode')
        model.cuda()
    else:
        print('CPU mode')
        model.float()

    # load image
    input_image = Image.open(os.path.join(input_dir, file_name)).convert("RGB")

    # resize image, keep aspect ratio; PIL reports size as (width, height)
    w = input_image.size[0]
    h = input_image.size[1]
    ratio = w * 1.0 / h
    if w > 1080 or h > 1080:
        load_size = 1080
    if load_size != -1:
        if ratio > 1:
            w = int(load_size)
            h = int(w * 1.0 / ratio)
        else:
            h = int(load_size)
            w = int(h * ratio)
    input_image = input_image.resize((w, h), Image.BICUBIC)
    input_image = np.asarray(input_image)

    # RGB -> BGR (the generator was trained on BGR input)
    input_image = input_image[:, :, [2, 1, 0]]
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)

    # preprocess, scale to (-1, 1)
    input_image = -1 + 2 * input_image
    if gpu > -1:
        input_image = input_image.cuda()
    else:
        input_image = input_image.float()

    # forward
    with torch.no_grad():
        output_image = model(input_image)
    output_image = output_image[0]

    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]

    # deprocess, back to (0, 1)
    output_image = output_image.cpu().float() * 0.5 + 0.5

    # save
    final_name = file_name[:-4] + '_' + style + '.jpg'
    output_path = os.path.join(output_dir, final_name)
    vutils.save_image(output_image, output_path)
    return final_name
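# Usage sketch for imageConverter(), relying on the function's own defaults;
# the input file is the snippet's default and assumed to exist under input_img/.
result_name = imageConverter(input_file='4--24.jpg', style='Hayao')
print('Saved stylized image as', result_name)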
import shutil


# Assumes opt, valid_ext, convert_image, save and jpg_to_gif are defined
# elsewhere in this script, as in the original.
def main():
    if not os.path.exists(opt.output_dir):
        os.mkdir(opt.output_dir)

    # load pretrained model
    model = Transformer()
    model.load_state_dict(
        torch.load("{dir}/{name}".format(
            **{
                "dir": opt.model_path,
                "name": "{}_net_G_float.pth".format(opt.style)
            })))
    model.eval()

    if opt.gpu > -1:
        print("GPU mode")
        model.cuda()
    else:
        print("CPU mode")
        model.float()

    for filename in os.listdir(opt.input_dir):
        ext = os.path.splitext(filename)[1]
        if ext not in valid_ext:
            continue
        print(filename)

        # load image
        if ext == ".gif":
            # convert frame by frame through a scratch directory
            if not os.path.exists("tmp"):
                os.mkdir("tmp")
            else:
                shutil.rmtree("tmp")
                os.mkdir("tmp")
            input_gif = Image.open(os.path.join(opt.input_dir, filename))
            for nframe in range(input_gif.n_frames):
                print(" {} / {}".format(nframe, input_gif.n_frames), end="\r")
                input_gif.seek(nframe)
                # convert the current frame directly; the original
                # input_gif.split()[0] was redundant for palette GIFs and
                # wrong for RGB ones
                output_image = convert_image(model, input_gif.convert("RGB"))
                save(image=output_image,
                     name="tmp/{name}_{nframe:04d}.jpg".format(
                         **{
                             "name": "{}_{}".format(filename[:-4], opt.style),
                             "nframe": nframe
                         }))
            jpg_to_gif(input_gif, filename)
            shutil.rmtree("tmp")
        else:
            input_image = Image.open(os.path.join(opt.input_dir, filename)).convert("RGB")
            output_image = convert_image(model, input_image)
            # save
            save(image=output_image,
                 name="{dir}/{name}.jpg".format(
                     **{
                         "dir": opt.output_dir,
                         "name": "{}_{}".format(filename[:-4], opt.style)
                     }))
    print("Done!")
import argparse

parser = argparse.ArgumentParser()
# --input_dir is read below; its default here is an assumption
parser.add_argument('--input_dir', default='test_img')
parser.add_argument('--load_size', type=int, default=450)
parser.add_argument('--model_path', default='./pretrained_model')
parser.add_argument('--style', default='Hayao')
parser.add_argument('--output_dir', default='test_output')
parser.add_argument('--gpu', type=int, default=0)
opt = parser.parse_args()

valid_ext = ['.jpg', '.png']

if not os.path.exists(opt.output_dir):
    os.mkdir(opt.output_dir)

# load pretrained model
model = Transformer()
model.load_state_dict(torch.load(os.path.join(opt.model_path, opt.style + '_net_G_float.pth')))
model.eval()

if opt.gpu > -1:
    print('GPU mode')
    model.cuda()
else:
    print('CPU mode')
    model.float()

for files in os.listdir(opt.input_dir):
    ext = os.path.splitext(files)[1]
    if ext not in valid_ext:
        continue
    # load image
    input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
    # resize image, keep aspect ratio