def setup(opts):
    """Load the pretrained Transformer generator and prepare it for inference.

    Args:
        opts: mapping with key ``"checkpoint"`` — path to a saved state dict.

    Returns:
        The model in eval mode, on GPU when CUDA is available, else CPU float32.
    """
    model = Transformer()
    # map_location lets a checkpoint saved on a GPU machine still load on a
    # CPU-only host; without it torch.load raises on the CUDA tensors.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.load_state_dict(torch.load(opts["checkpoint"], map_location=device))
    model.eval()

    if device == "cuda":
        print("GPU Mode")
        model.cuda()
    else:
        print("CPU Mode")
        model.float()

    return model
# ----- Example #2 (scraped-snippet separator, score: 0) -----
# Only these image extensions are processed by the loop below.
valid_ext = ['.jpg', '.png']

# Create the output directory on first run.
if not os.path.exists(opt.output_dir): os.mkdir(opt.output_dir)

# load pretrained model
# NOTE(review): checkpoint file is expected at <model_path>/<style>_net_G_float.pth
# — confirm the naming convention against the pretrained-model release.
model = Transformer()
model.load_state_dict(torch.load(os.path.join(opt.model_path, opt.style + '_net_G_float.pth')))
model.eval()

# Run on GPU when opt.gpu >= 0, otherwise CPU in float32.
if opt.gpu > -1:
    print('GPU mode')
    model.cuda()
else:
    print('CPU mode')
    model.float()
for files in os.listdir(opt.input_dir):
    torch.cuda.empty_cache()
    gc.collect()
    ext = os.path.splitext(files)[1]
    if ext not in valid_ext:
        continue
    print('process file:' + files)
    # load image
    input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
    # resize image, keep aspect ratio
    h = input_image.size[0]
    w = input_image.size[1]
    ratio = h *1.0 / w
    if ratio > 1:
# ----- Example #3 (scraped-snippet separator, score: 0) -----
def imageConverter(input_dir='input_img',
                   load_size=1080,
                   model_path='./pretrained_model',
                   style='Hayao',
                   output_dir='Output_img',
                   input_file='4--24.jpg',
                   gpu=-1):
    """Stylize a single image with a pretrained CartoonGAN-style generator.

    Args:
        input_dir: directory containing the source image.
        load_size: longest-side target in pixels; -1 keeps the original size.
            Capped at 1080 when either dimension exceeds 1080.
        model_path: directory holding ``<style>_net_G_float.pth`` weights.
        style: weight-file prefix (e.g. 'Hayao').
        output_dir: created if missing; stylized image is written here.
        input_file: file name of the image inside *input_dir*.
        gpu: run on CUDA when > -1, otherwise on CPU (default -1, CPU).

    Returns:
        The output file name (``<stem>_<style>.jpg``) inside *output_dir*.
    """
    file_name = input_file
    # splitext handles any extension length (.jpg, .jpeg, ...), unlike the
    # old `file_name[:-4]` slice which silently assumed a 4-char suffix.
    stem = os.path.splitext(file_name)[0]

    if not os.path.exists(output_dir): os.mkdir(output_dir)

    # load pretrained model
    model = Transformer()
    model.load_state_dict(
        torch.load(os.path.join(model_path, style + '_net_G_float.pth')))
    model.eval()

    # check if gpu available
    if gpu > -1:
        print('GPU mode')
        model.cuda()
    else:
        model.float()

    # load image
    input_image = Image.open(os.path.join(input_dir, file_name)).convert("RGB")
    # resize image, keep aspect ratio (PIL size is (width, height))
    width, height = input_image.size
    ratio = width * 1.0 / height
    # never upscale beyond 1080 on the longest side
    if width > 1080 or height > 1080:
        load_size = 1080
    if load_size != -1:
        if ratio > 1:  # landscape: width is the long side
            width = int(load_size)
            height = int(width * 1.0 / ratio)
        else:          # portrait/square: height is the long side
            height = int(load_size)
            width = int(height * ratio)
        input_image = input_image.resize((width, height), Image.BICUBIC)
    input_image = np.asarray(input_image)
    # RGB -> BGR (the generator was trained on BGR input)
    input_image = input_image[:, :, [2, 1, 0]]
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)
    # preprocess: map pixel values from (0, 1) to (-1, 1)
    input_image = -1 + 2 * input_image
    if gpu > -1:
        input_image = Variable(input_image, requires_grad=False).cuda()
    else:
        input_image = Variable(input_image, requires_grad=False).float()
    # forward
    output_image = model(input_image)
    output_image = output_image[0]
    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]
    print(output_image.shape)
    # deprocess: map from (-1, 1) back to (0, 1)
    output_image = output_image.data.cpu().float() * 0.5 + 0.5
    # save
    final_name = stem + '_' + style + '.jpg'
    output_path = os.path.join(output_dir, final_name)
    vutils.save_image(output_image, output_path)

    return final_name
# ----- Example #4 (scraped-snippet separator, score: 0) -----
def main():
    """Stylize every supported file in opt.input_dir into opt.output_dir.

    Still images are converted directly; .gif inputs are exploded into
    per-frame jpgs under a scratch "tmp" directory and reassembled.
    """
    if not os.path.exists(opt.output_dir):
        os.mkdir(opt.output_dir)

    # load the pretrained generator weights for the chosen style
    model = Transformer()
    weights = torch.load(f"{opt.model_path}/{opt.style}_net_G_float.pth")
    model.load_state_dict(weights)
    model.eval()

    if opt.gpu > -1:
        print("GPU mode")
        model.cuda()
    else:
        print("CPU mode")
        model.float()

    for fname in os.listdir(opt.input_dir):
        ext = os.path.splitext(fname)[1]
        if ext not in valid_ext:
            continue
        print(fname)
        stem = f"{fname[:-4]}_{opt.style}"
        # load image
        if ext == ".gif":
            # start each gif with a clean scratch directory for its frames
            if os.path.exists("tmp"):
                shutil.rmtree("tmp")
            os.mkdir("tmp")

            gif = Image.open(os.path.join(opt.input_dir, fname))
            total = gif.n_frames
            for idx in range(total):
                print(f"  {idx} / {total}", end="\r")
                gif.seek(idx)
                frame = convert_image(model, gif.split()[0].convert("RGB"))
                save(image=frame, name=f"tmp/{stem}_{idx:04d}.jpg")
            jpg_to_gif(gif, fname)
            shutil.rmtree("tmp")

        else:
            still = Image.open(os.path.join(opt.input_dir,
                                            fname)).convert("RGB")
            styled = convert_image(model, still)
            # save
            save(image=styled, name=f"{opt.output_dir}/{stem}.jpg")

    print("Done!")
# Only these image extensions are processed by the loop below.
valid_ext = ['.jpg', '.png']

# Create the output directory on first run.
if not os.path.exists(opt.output_dir): os.mkdir(opt.output_dir)

# load pretrained model
# NOTE(review): weight file is expected at <model_path>/<style>_net_G_float.pth
# — confirm against the pretrained-model release naming.
model = Transformer()
model.load_state_dict(torch.load(os.path.join(opt.model_path, opt.style + '_net_G_float.pth')))
model.eval()

# Run on GPU when opt.gpu >= 0, otherwise CPU in float32.
if opt.gpu > -1:
	print('GPU mode')
	model.cuda()
else:
	print('CPU mode')
	model.float()
for files in os.listdir(opt.input_dir):
	ext = os.path.splitext(files)[1]
	if ext not in valid_ext:
		continue
	# load image
	input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
	# resize image, keep aspect ratio
	h = input_image.size[0]
	w = input_image.size[1]
	ratio = h *1.0 / w
	if ratio > 1:
		h = opt.load_size
		w = int(h*1.0/ratio)
	else: