def init_model():
    model = Transformer()
    model.load_state_dict(
        torch.load(
            os.path.join('./pretrained_model', 'Hayao' + '_net_G_float.pth')))
    model.eval()
    model.cuda(0)
    return model


def setup(opts):
    model = Transformer()
    model.load_state_dict(torch.load(opts["checkpoint"]))
    model.eval()

    if torch.cuda.is_available():
        print("GPU Mode")
        model.cuda()
    else:
        print("CPU Mode")
        model.float()

    return model
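
A minimal usage sketch for setup(), assuming only that opts carries a "checkpoint" path as read above (the path here is hypothetical):

model = setup({"checkpoint": "./pretrained_model/Hayao_net_G_float.pth"})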
Example #3
from io import BytesIO  # lets torch.load read the checkpoint from an in-memory buffer

def load_models(s3, bucket):

    styles = ["Hosoda", "Hayao", "Shinkai", "Paprika"]
    models = {}

    for style in styles:
        model = Transformer()
        response = s3.get_object(Bucket=bucket, Key=f"models/{style}_net_G_float.pth")
        state = torch.load(BytesIO(response["Body"].read()))
        model.load_state_dict(state)
        model.eval()
        models[style] = model

    return models
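
A usage sketch, assuming a boto3 S3 client and a bucket laid out as models/<style>_net_G_float.pth (the bucket name is hypothetical):

import boto3

s3 = boto3.client("s3")
models = load_models(s3, "my-cartoongan-bucket")
hayao_model = models["Hayao"]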
Example #4
import torch
import torchvision.transforms as transforms
import cv2
import os
import matplotlib.pyplot as plt

from network.Transformer import Transformer

model = Transformer()
model.load_state_dict(torch.load('pretrained_model/Hayao_net_G_float.pth'))
model.eval()
print('Model loaded!')

img_size = 700
img_path = 'test_img/9.jpg'

# cv2.imread returns BGR, which matches the channel order the pretrained
# CartoonGAN weights expect
img = cv2.imread(img_path)

# resize the shorter side to img_size; the bare "2" in the original was the
# PIL code for bilinear interpolation, spelled out here for current torchvision
T = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor()
])

img_input = T(img).unsqueeze(0)
img_input = -1 + 2 * img_input  # map [0, 1] -> [-1, 1]

with torch.no_grad():
    img_output = model(img_input)
img_output = (img_output.squeeze().numpy() + 1.) / 2.
img_output = img_output.transpose([1, 2, 0])
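
One way to view or save the stylized result (a sketch; since cv2.imread produced BGR and no channel swap happened, the array is flipped to RGB for matplotlib but written as-is with cv2; the output path is hypothetical):

plt.imshow(img_output[:, :, ::-1])  # BGR -> RGB for display
plt.axis('off')
plt.show()
cv2.imwrite('test_output/9_Hayao.jpg', (img_output * 255).astype('uint8'))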
Example #5
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', type=str, default='test_img')
parser.add_argument('--load_size', type=int, default=450)
parser.add_argument('--model_path', type=str, default='./pretrained_model')
parser.add_argument('--style', type=str, default='Hayao')
parser.add_argument('--output_dir', type=str, default='test_output')
parser.add_argument('--gpu', type=int, default=0)

opt = parser.parse_args()

valid_ext = ['.jpg', '.png']

if not os.path.exists(opt.output_dir): os.mkdir(opt.output_dir)

# load pretrained model
model = Transformer()
model.load_state_dict(torch.load(os.path.join(opt.model_path, opt.style + '_net_G_float.pth')))
model.eval()

if opt.gpu > -1:
    print('GPU mode')
    model.cuda()
else:
    print('CPU mode')
    model.float()

for files in os.listdir(opt.input_dir):
    torch.cuda.empty_cache()
    gc.collect()
    ext = os.path.splitext(files)[1]
    if ext not in valid_ext:
        continue
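
Example #5 is cut off before the per-file conversion. A sketch of a typical loop body, based on the full pipeline shown in Example #7 (convert_image and save are helpers assumed from that example; PIL's Image must be imported):

    input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
    output_image = convert_image(model, input_image)
    save(image=output_image, name=os.path.join(opt.output_dir, files[:-4] + '_' + opt.style + '.jpg'))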
Example #6
class Main(object):
    def __init__(self):
        # network
        self.T = None
        self.D = None
        self.opt_T = None
        self.opt_D = None
        self.self_regularization_loss = None
        self.local_adversarial_loss = None
        self.delta = None

        # data
        self.anime_train_loader = None
        self.animeblur_train_loader = None
        self.real_train_loader = None
        self.real_loader = None


    def build_network(self):
        print('=' * 50)
        print('Building network...')
        #self.T = Transformer(4, cfg.img_channels, nb_features=64)

        self.T = Transformer()
        self.D = Discriminator(input_features=cfg.img_channels)

        if cfg.cuda_use:
            self.T.cuda()
            self.D.cuda()

        self.opt_T = torch.optim.Adam(self.T.parameters(), lr=cfg.t_lr)
        self.opt_D = torch.optim.SGD(self.D.parameters(), lr=cfg.d_lr)
        # size_average was removed in modern PyTorch; reduction replaces it
        self.self_regularization_loss = nn.L1Loss(reduction='sum')
        self.local_adversarial_loss = nn.CrossEntropyLoss(reduction='mean')
        self.delta = cfg.delta

        self.vgg = Vgg16(requires_grad=False)
        network.netutils.init_vgg16("./models/")

        self.vgg.load_state_dict(torch.load(os.path.join("./models/", "vgg16.weight")))

        if cfg.cuda_use:
            self.vgg.cuda()

    def load_data(self):
        print('=' * 50)
        print('Loading data...')

        transform = transforms.Compose([
            # transforms.Scale was renamed transforms.Resize
            transforms.Resize((cfg.img_width, cfg.img_height)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        anime_train_folder = torchvision.datasets.ImageFolder(root=cfg.anime_path, transform=transform)
        animeblur_train_folder = torchvision.datasets.ImageFolder(root=cfg.animeblur_path, transform=transform)
        real_train_folder = torchvision.datasets.ImageFolder(root=cfg.real_path, transform=transform)

        self.anime_train_loader = Data.DataLoader(anime_train_folder, batch_size=cfg.batch_size, shuffle=True,
                                                  pin_memory=True)
        self.animeblur_train_loader = Data.DataLoader(animeblur_train_folder, batch_size=cfg.batch_size, shuffle=True,
                                                      pin_memory=True)
        self.real_train_loader = Data.DataLoader(real_train_folder, batch_size=cfg.batch_size, shuffle=True,
                                                 pin_memory=True)

        print('anime_train_batch %d' % len(self.anime_train_loader))
        print('animeblur_train_batch %d' % len(self.animeblur_train_loader))
        # a second loader over the same real images, used for evaluation batches
        real_folder = torchvision.datasets.ImageFolder(root=cfg.real_path, transform=transform)
        # real_folder.imgs = real_folder.imgs[:2000]
        self.real_loader = Data.DataLoader(real_folder, batch_size=cfg.batch_size, shuffle=True,
                                           pin_memory=True)
        print('real_batch %d' % len(self.real_loader))


    def pre_train_t(self):
        print('=' * 50)

        #device = torch.device("cuda" if args.cuda else "cpu")

        if cfg.ref_pre_path:
            print('Loading t_pre from %s' % cfg.ref_pre_path)
            self.T.load_state_dict(torch.load(cfg.ref_pre_path))
            return

        print('pre-training the refiner network %d times...' % cfg.t_pretrain)

        for index in range(cfg.t_pretrain):

            anime_image_batch, _ = next(iter(self.anime_train_loader))
            anime_image_batch = anime_image_batch.cuda()

            animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
            animeblur_image_batch = animeblur_image_batch.cuda()

            real_image_batch, _ = next(iter(self.real_train_loader))
            real_image_batch = real_image_batch.cuda()

            #print(real_image_batch.size())

            self.T.train()
            transreal_image_batch = self.T(real_image_batch)


            # perceptual self-regularization: L1 distance between VGG relu4_3
            # features of the input photo and its translated version
            real_features = self.vgg(real_image_batch).relu4_3
            transreal_features = self.vgg(transreal_image_batch).relu4_3
            t_loss = torch.abs(transreal_features - real_features).mean()
            t_loss = t_loss * self.delta

            self.opt_T.zero_grad()
            t_loss.backward()
            self.opt_T.step()

            # log every cfg.t_pre_per steps (and on the last step)
            if (index % cfg.t_pre_per == 0) or (index == cfg.t_pretrain - 1):
                print('[%d/%d] (R)reg_loss: %.4f' % (index, cfg.t_pretrain, t_loss.item()))

                anime_image_batch, _ = next(iter(self.anime_train_loader))
                anime_image_batch = anime_image_batch.cuda()

                animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
                animeblur_image_batch = animeblur_image_batch.cuda()

                real_image_batch, _ = next(iter(self.real_loader))
                real_image_batch = real_image_batch.cuda()

                # volatile=True was removed in PyTorch 0.4; torch.no_grad() replaces it
                self.T.eval()
                with torch.no_grad():
                    ref_image_batch = self.T(real_image_batch)

                figure_path = os.path.join(cfg.train_res_path, 'refined_image_batch_pre_train_%d.png' % index)

                generate_img_batch(anime_image_batch.cpu(), ref_image_batch.cpu(),
                                   real_image_batch.cpu(), figure_path)



                self.T.train()

                print('Save t_pre to models/t_pre.pkl')
                torch.save(self.T.state_dict(), 'models/t_pre.pkl')

    def pre_train_d(self):
        print('=' * 50)
        if cfg.disc_pre_path:
            print('Loading D_pre from %s' % cfg.disc_pre_path)
            self.D.load_state_dict(torch.load(cfg.disc_pre_path))
            return


        print('pre-training the discriminator network %d times...' % cfg.d_pretrain)

        self.D.train()
        self.T.eval()
        for index in range(cfg.d_pretrain):
            real_image_batch, _ = next(iter(self.real_loader))
            real_image_batch = real_image_batch.cuda()

            anime_image_batch, _ = next(iter(self.anime_train_loader))
            anime_image_batch = anime_image_batch.cuda()

            animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
            animeblur_image_batch = animeblur_image_batch.cuda()

            assert real_image_batch.size(0) == anime_image_batch.size(0)
            assert real_image_batch.size(0) == animeblur_image_batch.size(0)

            batch_size = real_image_batch.size(0)

            # anime frames are the "real" class for the discriminator (label 0)
            d_anime_y = torch.zeros(batch_size, dtype=torch.long).cuda()
            # translated real photos are the fake class (label 1)
            d_real_y = torch.ones(batch_size, dtype=torch.long).cuda()
            # blurred anime is also treated as fake (label 1)
            d_blur_y = torch.ones(batch_size, dtype=torch.long).cuda()


            # ============ anime image D ===================================================
            d_anime_pred = self.D(anime_image_batch).view(-1, 2)
            acc_anime = calc_acc(d_anime_pred, 'real')
            d_loss_anime = self.local_adversarial_loss(d_anime_pred, d_anime_y)

            # ============ translated real image D =========================================
            # detach so D pre-training does not backprop into T
            real_image_batch = self.T(real_image_batch).detach()
            d_real_pred = self.D(real_image_batch).view(-1, 2)
            acc_real = calc_acc(d_real_pred, 'refine')
            d_loss_real = self.local_adversarial_loss(d_real_pred, d_real_y)

            # =========== blur image D ====================================================

            d_animeblur_pred = self.D(animeblur_image_batch).view(-1, 2)
            acc_blur = calc_acc(d_animeblur_pred, 'refine')
            d_loss_animeblur = self.local_adversarial_loss(d_animeblur_pred, d_blur_y)

            d_loss = d_loss_anime + d_loss_animeblur + d_loss_real

            self.opt_D.zero_grad()
            d_loss.backward()
            self.opt_D.step()

            if (index % cfg.d_pre_per == 0) or (index == cfg.d_pretrain - 1):
                print('[%d/%d] (D)d_loss:%f  acc_anime:%.2f%% acc_real:%.2f%% acc_blur:%.2f%%'
                      % (index, cfg.d_pretrain, d_loss.item(), acc_anime, acc_real, acc_blur))

        print('Save D_pre to models/D_pre.pkl')
        torch.save(self.D.state_dict(), 'models/D_pre.pkl')

    def train(self):
        print('=' * 50)
        print('Training...')

        # self.D.load_state_dict(torch.load("models/D_620.pkl"))
        # self.T.load_state_dict(torch.load("models/T_620.pkl"))

        '''
        image_history_buffer = ImageHistoryBuffer((0, cfg.img_channels, cfg.img_height, cfg.img_width),
                                                  cfg.buffet_size * 10, cfg.batch_size)
        '''

        for step in range(cfg.train_steps):
            print('Step[%d/%d]' % (step, cfg.train_steps))

            # ========= train the T =========
            self.D.eval()
            self.T.train()

            for p in self.D.parameters():
                p.requires_grad = False

            total_t_loss = 0.0
            total_t_loss_reg_scale = 0.0
            total_t_loss_adv = 0.0
            total_acc_adv = 0.0

            for index in range(cfg.k_t):

                real_image_batch, _ = next(iter(self.real_loader))
                real_image_batch = real_image_batch.cuda()

                transreal_image_batch = self.T(real_image_batch)
                d_transreal_pred = self.D(transreal_image_batch).view(-1, 2)

                # the generator wants its translated photos classified as "real" (label 0)
                d_real_y = torch.zeros(d_transreal_pred.size(0), dtype=torch.long).cuda()

                acc_adv = calc_acc(d_transreal_pred, 'real')
                t_loss_adv = self.local_adversarial_loss(d_transreal_pred, d_real_y)

                # perceptual self-regularization: L1 distance between VGG relu4_3 features
                real_features = self.vgg(real_image_batch).relu4_3
                transreal_features = self.vgg(transreal_image_batch).relu4_3
                t_loss_reg = torch.abs(transreal_features - real_features).mean()
                t_loss_reg_scale = t_loss_reg * self.delta


                t_loss = t_loss_adv + t_loss_reg_scale

                self.opt_T.zero_grad()
                self.opt_D.zero_grad()
                t_loss.backward()
                self.opt_T.step()

                total_t_loss += t_loss.item()
                total_t_loss_reg_scale += t_loss_reg_scale.item()
                total_t_loss_adv += t_loss_adv.item()
                total_acc_adv += acc_adv

            mean_t_loss = total_t_loss / cfg.k_t
            mean_t_loss_reg_scale = total_t_loss_reg_scale / cfg.k_t
            mean_t_loss_adv = total_t_loss_adv / cfg.k_t
            mean_acc_adv = total_acc_adv / cfg.k_t

            print('(R)t_loss:%.4f t_loss_reg:%.4f t_loss_adv:%.4f(%.2f%%)'
                  % (mean_t_loss, mean_t_loss_reg_scale, mean_t_loss_adv, mean_acc_adv))

            # ========= train the D =========
            self.T.eval()
            self.D.train()
            for p in self.D.parameters():
                p.requires_grad = True

            for index in range(cfg.k_d):
                real_image_batch, _ = next(iter(self.real_loader))
                anime_image_batch, _ = next(iter(self.anime_train_loader))
                animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
                assert real_image_batch.size(0) == anime_image_batch.size(0)
                assert real_image_batch.size(0) == animeblur_image_batch.size(0)

                real_image_batch = real_image_batch.cuda()
                anime_image_batch = anime_image_batch.cuda()
                animeblur_image_batch = animeblur_image_batch.cuda()

                batch_size = real_image_batch.size(0)
                d_anime_y = torch.zeros(batch_size, dtype=torch.long).cuda()  # anime = real (0)
                d_real_y = torch.ones(batch_size, dtype=torch.long).cuda()    # translated photo = fake (1)
                d_blur_y = torch.ones(batch_size, dtype=torch.long).cuda()    # blurred anime = fake (1)

                d_anime_pred = self.D(anime_image_batch).view(-1, 2)
                acc_anime = calc_acc(d_anime_pred, 'real')
                d_loss_anime = self.local_adversarial_loss(d_anime_pred, d_anime_y)

                # detach so the D update does not backprop into T
                real_image_batch = self.T(real_image_batch).detach()
                d_real_pred = self.D(real_image_batch).view(-1, 2)
                acc_real = calc_acc(d_real_pred, 'refine')
                d_loss_real = self.local_adversarial_loss(d_real_pred, d_real_y)

                d_animeblur_pred = self.D(animeblur_image_batch).view(-1, 2)
                acc_blur = calc_acc(d_animeblur_pred, 'refine')
                d_loss_animeblur = self.local_adversarial_loss(d_animeblur_pred, d_blur_y)

                d_loss = d_loss_real + d_loss_anime + d_loss_animeblur

                self.D.zero_grad()
                d_loss.backward()
                self.opt_D.step()

                print('(D)d_loss:%.4f anime_loss:%.4f(%.2f%%) real_loss:%.4f(%.2f%%) blur_loss:%.4f(%.2f%%)'
                      % (d_loss.item() / 3, d_loss_anime.item(), acc_anime, d_loss_real.item(), acc_real,
                         d_loss_animeblur.item(), acc_blur))  # /3: mean of the three loss terms

            if step % cfg.save_per == 0:
                print('Save two model dict.')
                torch.save(self.D.state_dict(), cfg.D_path % step)
                torch.save(self.T.state_dict(), cfg.T_path % step)


                real_image_batch, _ = next(iter(self.real_loader))
                real_image_batch = real_image_batch.cuda()

                anime_image_batch, _ = next(iter(self.anime_train_loader))
                anime_image_batch = anime_image_batch.cuda()

                animeblur_image_batch, _ = next(iter(self.animeblur_train_loader))
                animeblur_image_batch = animeblur_image_batch.cuda()

                self.T.eval()
                with torch.no_grad():
                    realtrans_image_batch = self.T(real_image_batch)
                self.generate_batch_train_image(real_image_batch, realtrans_image_batch, animeblur_image_batch, step_index=step)

    def generate_batch_train_image(self, anime_image_batch, ref_image_batch, real_image_batch, step_index=-1):
        print('=' * 50)
        print('Generating a batch of training images...')
        self.T.eval()

        pic_path = os.path.join(cfg.train_res_path, 'step_%d.png' % step_index)
        generate_img_batch(anime_image_batch.cpu(), ref_image_batch.cpu(), real_image_batch.cpu(), pic_path)
        print('=' * 50)
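
A sketch of a driver for this class, assuming cfg and the helpers referenced above (Discriminator, Vgg16, calc_acc, generate_img_batch) are importable:

if __name__ == '__main__':
    trainer = Main()
    trainer.build_network()
    trainer.load_data()
    trainer.pre_train_t()
    trainer.pre_train_d()
    trainer.train()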
Example #7
def imageConverter(input_dir='input_img',
                   load_size=1080,
                   model_path='./pretrained_model',
                   style='Hayao',
                   output_dir='Output_img',
                   input_file='4--24.jpg'):
    gpu = 0 if torch.cuda.is_available() else -1
    file_name = input_file
    ext = os.path.splitext(file_name)[1]

    if not os.path.exists(output_dir): os.mkdir(output_dir)

    # load pretrained model
    model = Transformer()
    model.load_state_dict(
        torch.load(os.path.join(model_path, style + '_net_G_float.pth')))
    model.eval()

    #check if gpu available
    if gpu > -1:
        print('GPU mode')
        model.cuda()
    else:
        # print('CPU mode')
        model.float()

    # load image
    input_image = Image.open(os.path.join(input_dir, file_name)).convert("RGB")
    # resize image, keep aspect ratio
    # note: PIL's Image.size is (width, height), so h and w are swapped here;
    # the resize((h, w)) call below is consistent with that swap
    h = input_image.size[0]
    w = input_image.size[1]
    ratio = h * 1.0 / w
    if w > 1080 or h > 1080:
        load_size = 1080
    if load_size != -1:
        if ratio > 1:
            h = int(load_size)
            w = int(h * 1.0 / ratio)
        else:
            w = int(load_size)
            h = int(w * ratio)
        input_image = input_image.resize((h, w), Image.BICUBIC)
    input_image = np.asarray(input_image)
    # RGB -> BGR
    input_image = input_image[:, :, [2, 1, 0]]
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)
    # preprocess, (-1, 1)
    input_image = -1 + 2 * input_image
    # Variable is deprecated in modern PyTorch; run inference without gradients
    with torch.no_grad():
        if gpu > -1:
            input_image = input_image.cuda()
        else:
            input_image = input_image.float()
        # forward
        output_image = model(input_image)[0]
    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]
    # deprocess, (0, 1)
    output_image = output_image.data.cpu().float() * 0.5 + 0.5
    # save
    final_name = file_name[:-4] + '_' + style + '.jpg'
    output_path = os.path.join(output_dir, final_name)
    vutils.save_image(output_image, output_path)

    return final_name
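
A usage sketch for imageConverter() (directories and the file name are hypothetical):

result = imageConverter(input_dir='input_img', style='Shinkai', input_file='photo.jpg')
print('saved as', result)
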
def main():
    if not os.path.exists(opt.output_dir):
        os.mkdir(opt.output_dir)

    # load pretrained model
    model = Transformer()
    model.load_state_dict(
        torch.load(os.path.join(opt.model_path,
                                "{}_net_G_float.pth".format(opt.style))))
    model.eval()

    if opt.gpu > -1:
        print("GPU mode")
        model.cuda()
    else:
        print("CPU mode")
        model.float()

    for filename in os.listdir(opt.input_dir):
        ext = os.path.splitext(filename)[1]
        if ext not in valid_ext:
            continue
        print(filename)
        # load image
        if ext == ".gif":
            if not os.path.exists("tmp"):
                os.mkdir("tmp")
            else:
                shutil.rmtree("tmp")
                os.mkdir("tmp")

            input_gif = Image.open(os.path.join(opt.input_dir, filename))
            for nframe in range(input_gif.n_frames):
                print("  {} / {}".format(nframe, input_gif.n_frames), end="\r")
                input_gif.seek(nframe)
                output_image = convert_image(
                    model,
                    input_gif.split()[0].convert("RGB"))
                save(image=output_image,
                     name="tmp/{name}_{nframe:04d}.jpg".format(
                         name="{}_{}".format(filename[:-4], opt.style),
                         nframe=nframe))
            jpg_to_gif(input_gif, filename)
            shutil.rmtree("tmp")

        else:
            input_image = Image.open(os.path.join(opt.input_dir,
                                                  filename)).convert("RGB")
            output_image = convert_image(model, input_image)
            # save
            save(image=output_image,
                 name="{dir}/{name}.jpg".format(
                     dir=opt.output_dir,
                     name="{}_{}".format(filename[:-4], opt.style)))

    print("Done!")
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', default='test_img')
parser.add_argument('--load_size', type=int, default=450)
parser.add_argument('--model_path', default='./pretrained_model')
parser.add_argument('--style', default='Hayao')
parser.add_argument('--output_dir', default='test_output')
parser.add_argument('--gpu', type=int, default=0)

opt = parser.parse_args()

valid_ext = ['.jpg', '.png']

if not os.path.exists(opt.output_dir): os.mkdir(opt.output_dir)

# load pretrained model
model = Transformer()
model.load_state_dict(torch.load(os.path.join(opt.model_path, opt.style + '_net_G_float.pth')))
model.eval()

if opt.gpu > -1:
    print('GPU mode')
    model.cuda()
else:
    print('CPU mode')
    model.float()

for files in os.listdir(opt.input_dir):
    ext = os.path.splitext(files)[1]
    if ext not in valid_ext:
        continue
    # load image
    input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
Example #10
# if not os.path.exists(opt.output_dir): os.mkdir(opt.output_dir)
GPU = torch.cuda.is_available()

pret = "./pretrained_model"

out_dir = "test_output"

style = "Hayao"

load_size = 450

# load pretrained model
model = Transformer()
model.load_state_dict(torch.load(os.path.join(pret, style + '_net_G_float.pth')))
model.eval()

open_dir = "test_img"

if GPU:
	print('GPU mode')
	model.cuda()
else:
	print('CPU mode')
	model.float()

for files in os.listdir(open_dir):
	ext = os.path.splitext(files)[1]
	if ext not in valid_ext:
		continue
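
The example is cut off at this point. A sketch of a plausible loop body, following the preprocessing used in the earlier examples (it assumes PIL's Image, numpy as np, torchvision.transforms as transforms, and torchvision.utils as vutils are imported, as in Example #7):

    input_image = Image.open(os.path.join(open_dir, files)).convert("RGB")
    # RGB -> BGR; .copy() gives ToTensor a contiguous array
    input_image = np.asarray(input_image)[:, :, ::-1].copy()
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)
    input_image = -1 + 2 * input_image  # map [0, 1] -> [-1, 1]
    with torch.no_grad():
        output_image = model(input_image.cuda() if GPU else input_image.float())[0]
    output_image = output_image[[2, 1, 0], :, :] * 0.5 + 0.5  # BGR -> RGB, deprocess
    vutils.save_image(output_image.cpu(), os.path.join(out_dir, files[:-4] + '_' + style + '.jpg'))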