Ejemplo n.º 1
0
def test():
    """Stylize one content image with one style image and save the results.

    Configuration comes from module-level names defined elsewhere
    (``test_content_image``, ``test_style_image``, ``gpu``,
    ``model_state_path``, ``patch_size``, ``default_output_name``).

    On success three files are written:
    ``<name>.jpg`` (stylized output), ``<name>_pair.jpg`` (content and output
    side by side) and ``<name>_style_transfer_demo.jpg`` (the pair image with
    a quarter-size copy of the style image pasted onto the output half).
    """
    if test_content_image == '' or test_style_image == '':
        print("Please input the content and style image!")
        return

    # Use the requested GPU when available, otherwise fall back to the CPU.
    if torch.cuda.is_available() and gpu >= 0:
        device = torch.device(f'cuda:{gpu}')
        # Report the device actually selected, not always GPU 0.
        print(f'# CUDA available: {torch.cuda.get_device_name(device)}')
    else:
        device = 'cpu'

    # Load the encoder and decoder. map_location allows a checkpoint saved
    # on a GPU to be loaded on a CPU-only machine.
    encoder = VGGEncoder().to(device)
    decoder = Decoder()
    decoder.load_state_dict(torch.load(model_state_path, map_location=device))
    decoder = decoder.to(device)

    try:
        content = Image.open(test_content_image)
        style = Image.open(test_style_image)
        c_tensor = trans(content).unsqueeze(0).to(device)
        s_tensor = trans(style).unsqueeze(0).to(device)
        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            cf = encoder(c_tensor)
            sf = encoder(s_tensor)
            style_swap_res = style_swap(cf, sf, patch_size, 1)
            # Drop the features and the encoder before decoding to lower
            # peak memory.
            del cf
            del sf
            del encoder
            out = decoder(style_swap_res)

        c_denorm = denorm(c_tensor, device)
        out_denorm = denorm(out, device)
        res = torch.cat([c_denorm, out_denorm], dim=0)
        res = res.to('cpu')
    except RuntimeError:
        traceback.print_exc()
        print('Images are too large to transfer.')
        # No output tensors exist, so there is nothing to save.
        return

    if default_output_name == '':
        c_name = os.path.splitext(os.path.basename(test_content_image))[0]
        s_name = os.path.splitext(os.path.basename(test_style_image))[0]
        output_name = f'{c_name}_{s_name}'
    else:
        output_name = default_output_name

    try:
        save_image(out_denorm, f'{output_name}.jpg')
        save_image(res, f'{output_name}_pair.jpg', nrow=2)
        o = Image.open(f'{output_name}_pair.jpg')
        # Shrink the style image to a quarter of the content size and paste it
        # at the bottom of the right (stylized) half of the pair image.
        style = style.resize(tuple(i // 4 for i in content.size))
        box = (o.width // 2, o.height - style.height)
        o.paste(style, box)
        o.save(f'{output_name}_style_transfer_demo.jpg', quality=95)
        print(f'result saved into files starting with {output_name}')
    except Exception as err:  # report instead of silently swallowing
        print(f'Failed to save results for {output_name}: {err}')
Ejemplo n.º 2
0
def init_model(model, args):
    """Build a VAE from the encoder/decoder pair selected by *model*.

    ``'conv'`` selects the simple convolutional pair. Any other name that
    contains ``'vgg'`` selects the VGG pair, configured with *model* and
    ``args.fea_c``. Every other name prints a warning and falls back to the
    convolutional pair.
    """
    if model != 'conv' and 'vgg' in model:
        pair = (VGGEncoder(model, args.fea_c), VGGDecoder(model, args.fea_c))
    else:
        if model != 'conv':
            print('Model not found! Use "conv" instead.')
        pair = (SimpleEncoder(), SimpleDecoder())
    return VAE(*pair)
Ejemplo n.º 3
0
 def __init__(self):
     """Initialise the base class, then build the VGG encoder and the decoder."""
     super().__init__()
     self.vgg_encoder, self.decoder = VGGEncoder(), Decoder()
Ejemplo n.º 4
0
def main():
    """Train the style-swap decoder and save checkpoints, snapshots and losses.

    Only the decoder's parameters are given to the optimizer; the VGG encoder
    produces features and the feature-reconstruction targets.

    Everything is written under ``--save_dir``:
    - ``model_state/<epoch>_epoch.pth`` after every epoch;
    - ``image/`` snapshots every ``--snapshot_interval`` iterations;
    - ``loss/train_loss.png`` and ``loss/loss_log.txt`` at the end.
    """
    parser = argparse.ArgumentParser(description='Style Swap by Pytorch')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=4,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=3,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--patch_size',
                        '-p',
                        type=int,
                        default=5,
                        help='Size of extracted patches from style features')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID(nagative value indicate CPU)')
    # Both rates are fractional values; type=int rejected e.g. "-lr 0.001".
    parser.add_argument('--learning_rate',
                        '-lr',
                        type=float,
                        default=1e-4,
                        help='learning rate for Adam')
    parser.add_argument('--tv_weight',
                        type=float,
                        default=1e-6,
                        help='weight for total variation loss')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=500,
                        help='Interval of snapshot to generate image')
    parser.add_argument('--train_content_dir',
                        type=str,
                        default='/data/chen/content',
                        help='content images directory for train')
    parser.add_argument('--train_style_dir',
                        type=str,
                        default='/data/chen/style',
                        help='style images directory for train')
    parser.add_argument('--test_content_dir',
                        type=str,
                        default='/data/chen/content',
                        help='content images directory for test')
    parser.add_argument('--test_style_dir',
                        type=str,
                        default='/data/chen/style',
                        help='style images directory for test')
    parser.add_argument('--save_dir',
                        type=str,
                        default='result',
                        help='save directory for result and loss')

    args = parser.parse_args()

    # Create the output layout. makedirs(exist_ok=True) also repairs a
    # partially created layout (e.g. `loss` exists but `image` does not) and
    # creates missing parent directories of --save_dir.
    loss_dir = f'{args.save_dir}/loss'
    model_state_dir = f'{args.save_dir}/model_state'
    image_dir = f'{args.save_dir}/image'
    for directory in (args.save_dir, loss_dir, model_state_dir, image_dir):
        os.makedirs(directory, exist_ok=True)

    # Set device on GPU if available, else CPU.
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device(f'cuda:{args.gpu}')
        print(f'# CUDA available: {torch.cuda.get_device_name(device)}')
    else:
        device = 'cpu'

    print(f'# Minibatch-size: {args.batch_size}')
    print(f'# epoch: {args.epoch}')
    print('')

    # Prepare datasets and data loaders.
    train_dataset = PreprocessDataset(args.train_content_dir,
                                      args.train_style_dir)
    test_dataset = PreprocessDataset(args.test_content_dir,
                                     args.test_style_dir)
    iters = len(train_dataset)
    print(f'Length of train image pairs: {iters}')

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True)
    test_iter = iter(test_loader)

    def next_test_batch():
        """Return the next test batch, restarting the loader when exhausted."""
        nonlocal test_iter
        try:
            return next(test_iter)
        except StopIteration:
            test_iter = iter(test_loader)
            return next(test_iter)

    def swap_batch(content_feature, style_feature):
        """Apply style_swap to each (content, style) sample pair of a batch."""
        swapped = [
            style_swap(c.unsqueeze(0), s.unsqueeze(0), args.patch_size, 1)
            for c, s in zip(content_feature, style_feature)
        ]
        return torch.cat(swapped, 0)

    # Set model and optimizer; only the decoder is trained.
    encoder = VGGEncoder().to(device)
    decoder = Decoder().to(device)
    optimizer = Adam(decoder.parameters(), lr=args.learning_rate)

    criterion = nn.MSELoss()
    loss_list = []
    total_iterations = len(train_loader)

    for e in range(1, args.epoch + 1):
        print(f'Start {e} epoch')
        for i, (content, style) in tqdm(enumerate(train_loader, 1)):
            content = content.to(device)
            style = style.to(device)
            content_feature = encoder(content)
            style_feature = encoder(style)

            style_swap_res = swap_batch(content_feature, style_feature)

            out_style_swap = decoder(style_swap_res)
            out_content = decoder(content_feature)
            out_style = decoder(style_feature)

            out_style_swap_latent = encoder(out_style_swap)
            out_content_latent = encoder(out_content)
            out_style_latent = encoder(out_style)

            # Pixel loss: images encoded then decoded should match the input.
            image_reconstruction_loss = (criterion(content, out_content) +
                                         criterion(style, out_style))

            # Feature loss: decoded images re-encoded should match the
            # features they were decoded from.
            feature_reconstruction_loss = (
                criterion(style_feature, out_style_latent) +
                criterion(content_feature, out_content_latent) +
                criterion(style_swap_res, out_style_swap_latent))

            tv_loss = (TVloss(out_style_swap, args.tv_weight) +
                       TVloss(out_content, args.tv_weight) +
                       TVloss(out_style, args.tv_weight))

            loss = image_reconstruction_loss + feature_reconstruction_loss + tv_loss

            loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'[{e}/total {args.epoch} epoch],[{i} /'
                  f'total {total_iterations} iteration]: {loss.item()}')

            if i % args.snapshot_interval == 0:
                content, style = next_test_batch()
                content = content.to(device)
                style = style.to(device)
                with torch.no_grad():
                    content_feature = encoder(content)
                    style_feature = encoder(style)
                    style_swap_res = swap_batch(content_feature, style_feature)
                    out_style_swap = decoder(style_swap_res)
                    out_content = decoder(content_feature)
                    out_style = decoder(style_feature)

                content = denorm(content, device)
                style = denorm(style, device)
                out_style_swap = denorm(out_style_swap, device)
                out_content = denorm(out_content, device)
                out_style = denorm(out_style, device)
                res = torch.cat(
                    [content, style, out_content, out_style, out_style_swap],
                    dim=0)
                res = res.to('cpu')
                # One row per image kind, one column per sample in the batch.
                save_image(res,
                           f'{image_dir}/{e}_epoch_{i}_iteration.png',
                           nrow=content_feature.shape[0])
        torch.save(decoder.state_dict(), f'{model_state_dir}/{e}_epoch.pth')

    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.title('train loss')
    plt.savefig(f'{loss_dir}/train_loss.png')
    with open(f'{loss_dir}/loss_log.txt', 'w') as f:
        f.writelines(f'{value}\n' for value in loss_list)
    print(f'Loss saved in {loss_dir}')
Ejemplo n.º 5
0
def train():
    """Train the style-swap decoder using module-level configuration.

    Reads ``save_dir``, ``gpu``, ``batch_size``, ``epoch``, ``learning_rate``,
    ``patch_size``, ``tv_weight``, ``train_content_dir`` and
    ``train_style_dir`` from names defined elsewhere in the module. Only the
    decoder is optimized. A checkpoint ``<save_dir>/model_state/<e>_epoch.pth``
    is written after every epoch.
    """
    # Folder for saved results. makedirs creates missing parents (including
    # save_dir itself) and tolerates folders that already exist.
    model_state_dir = f'{save_dir}/model_state'
    os.makedirs(model_state_dir, exist_ok=True)

    # Choose the device: the requested GPU if available, otherwise the CPU.
    if torch.cuda.is_available() and gpu >= 0:
        device = torch.device(f'cuda:{gpu}')
        print(f'# CUDA available: {torch.cuda.get_device_name(device)}')
    else:
        device = 'cpu'

    print(f'# Minibatch-size: {batch_size}')
    print(f'# epoch: {epoch}')
    print('')

    # Load the data.
    train_dataset = PreprocessDataset(train_content_dir, train_style_dir)
    iters = len(train_dataset)
    print(f'Length of train image pairs: {iters}')

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    # Exact number of batches per epoch (the last batch may be partial).
    total_iterations = len(train_loader)

    # Build the model and the optimizer; only the decoder is trained.
    encoder = VGGEncoder().to(device)
    decoder = Decoder().to(device)
    optimizer = Adam(decoder.parameters(), lr=learning_rate)

    criterion = nn.MSELoss()  # mean squared error
    loss_list = []

    for e in range(1, epoch + 1):
        print(f'Start {e} epoch')
        for i, (content, style) in tqdm(enumerate(train_loader, 1)):
            content = content.to(device)
            style = style.to(device)
            content_feature = encoder(content)
            style_feature = encoder(style)

            # Style-swap each (content, style) sample pair individually.
            style_swap_res = torch.cat([
                style_swap(c.unsqueeze(0), s.unsqueeze(0), patch_size, 1)
                for c, s in zip(content_feature, style_feature)
            ], 0)

            out_style_swap = decoder(style_swap_res)  # decoded style-swapped features
            out_content = decoder(content_feature)  # content features decoded directly
            out_style = decoder(style_feature)  # style features decoded directly

            out_style_swap_latent = encoder(out_style_swap)  # re-encode the stylized result
            out_content_latent = encoder(out_content)  # content: encode -> decode -> encode
            out_style_latent = encoder(out_style)  # style: encode -> decode -> encode

            # Image loss: encoding then decoding an image should reproduce it.
            image_reconstruction_loss = (criterion(content, out_content) +
                                         criterion(style, out_style))

            # Feature loss: decoding then re-encoding should reproduce features.
            feature_reconstruction_loss = (
                criterion(style_feature, out_style_latent) +
                criterion(content_feature, out_content_latent) +
                criterion(style_swap_res, out_style_swap_latent))

            # TV loss: penalizes differences between neighbouring pixels.
            tv_loss = (TVloss(out_style_swap, tv_weight) +
                       TVloss(out_content, tv_weight) +
                       TVloss(out_style, tv_weight))

            loss = image_reconstruction_loss + feature_reconstruction_loss + tv_loss

            loss_list.append(loss.item())

            # Back-propagate and update the decoder parameters.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'[{e}/total {epoch} epoch],[{i} /'
                  f'total {total_iterations} iteration]: {loss.item()}')
        # Save a checkpoint after every epoch.
        torch.save(decoder.state_dict(), f'{model_state_dir}/{e}_epoch.pth')
Ejemplo n.º 6
0
def style_main_temp(pics_dir=None, style_dir=''):
    """Generate temporary stylized previews for a list of content images.

    For every ``.jpg`` path in *pics_dir* whose preview path (from
    ``PathTemp.get_temp_after_jpg_path()``) does not exist yet, run style
    transfer with the style image at *style_dir* and save the result there.
    Progress messages are appended to
    ``InfoNotifier.InfoNotifier.g_progress_info``. The model is loaded only
    when at least one preview is missing.

    Args:
        pics_dir: content image paths; ``None`` is treated as an empty list.
        style_dir: path of the style image.
    """
    if pics_dir is None:
        pics_dir = []

    # If every preview already exists, do not load the model at all.
    if all(
            os.path.exists(
                PathTemp(jpg_path_=pic_dir,
                         style_path_=style_dir).get_temp_after_jpg_path())
            for pic_dir in pics_dir):
        return

    device = Device()
    d = load_model(device)
    s = Image.open(style_dir)
    s_tensor = trans(s).unsqueeze(0).to(device)
    # The VGG encoder is built lazily, once, the first time an image needs
    # it, instead of being re-created for every image.
    e = None

    for file_path in pics_dir:
        if not file_path.endswith(".jpg"):
            continue
        # Reset per image: a failed transfer must never save the previous
        # image's result under this image's path.
        tar = None
        style_output = None
        try:
            _, c_name = get_c_name_and_file_name(file_path)
            s_name = get_style_name(style_dir)
            get_path = PathTemp(jpg_path_=file_path, style_path_=style_dir)
            style_output = get_path.get_temp_after_jpg_path()
            if not os.path.exists(style_output):
                if e is None:
                    e = VGGEncoder().to(device)
                tar = get_target_img(file_path,
                                     device,
                                     e,
                                     d,
                                     s_tensor,
                                     c_name,
                                     s_name,
                                     style_outdir=os.path.dirname(file_path))
        except RuntimeError:
            print(
                'Images are too large to transfer. Size under 1000 are recommended '
                + file_path)
            InfoNotifier.InfoNotifier.g_progress_info.append(
                file_path + '太大,无法迁移风格,推荐尝试1000×1000以下图片')

        if style_output is None:
            # The output path could not be determined; nothing to save.
            continue

        try:
            if os.path.exists(style_output):
                InfoNotifier.InfoNotifier.g_progress_info.append(
                    f'{style_output}/' + ' 已存在,跳过')
            elif tar is not None:
                # Save the style transfer result.
                out_dir = os.path.dirname(style_output)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                tar.save(style_output, quality=100)
                print(f'result saved into files {style_output}')
                InfoNotifier.InfoNotifier.g_progress_info.append(
                    f'已生成 {style_output}')
        except Exception as ec:
            print(ec)
            InfoNotifier.InfoNotifier.g_progress_info.append(ec)
Ejemplo n.º 7
0
def style_main_txt(txt_path='',
                   work_='',
                   style_dir='',
                   chosen_content_file_list=None,
                   dir_dict=None,
                   seamless=False):
    """Stylize the images listed in a text file with the style image *style_dir*.

    Each line of *txt_path* is a file path. A line is processed only when its
    directory equals ``dir_dict[x]`` for some ``x`` in
    *chosen_content_file_list*. ``PathUtils(work_, style_dir, path)`` maps each
    path to its input JPEG and output location. When *seamless* is not
    ``False``, the "expanded" variants are used. Outputs that already exist are
    skipped. Progress messages go to
    ``InfoNotifier.InfoNotifier.g_progress_info``.
    """
    if chosen_content_file_list is None:
        chosen_content_file_list = []
    if dir_dict is None:
        dir_dict = {}
    # The previous check `os.path.exists(...) is not None` was always true.
    if not os.path.exists(txt_path):
        print(txt_path + " is not exist,jump from process")
        return

    device = Device()
    d = load_model(device)
    s = Image.open(style_dir)
    s_tensor = trans(s).unsqueeze(0).to(device)
    # Build the VGG encoder lazily, once, instead of once per image.
    e = None

    # `with` guarantees the list file is closed even if processing fails.
    with open(txt_path, "r", encoding='utf-8-sig') as f:
        for line in f:
            file_path = line.replace("\n", "").replace("\\", "/")
            if not file_path:
                continue
            # Only handle images inside one of the selected directories.
            parent_dir = os.path.dirname(file_path)
            if not any(dir_dict[file] == parent_dir
                       for file in chosen_content_file_list):
                continue

            get_path = PathUtils(work_, style_dir, file_path)
            if seamless is False:
                jpg_path = get_path.dds_to_jpg_path()
                style_output_path = get_path.get_style_path()
            else:
                jpg_path = get_path.get_expanded_jpg_path()
                style_output_path = get_path.get_expanded_style_path()

            if not os.path.exists(jpg_path):
                print(jpg_path + "is not exist,jump from process")
                continue

            style_outdir = os.path.dirname(os.path.dirname(style_output_path))
            if style_outdir:
                os.makedirs(style_outdir, exist_ok=True)

            if not jpg_path.endswith(".jpg"):
                continue

            # Reset per image so a failed transfer never re-saves the
            # previous image's result under this image's path.
            tar = None
            try:
                _, c_name = get_c_name_and_file_name(jpg_path)
                s_name = get_style_name(style_dir)
                if not os.path.exists(style_output_path):
                    if e is None:
                        e = VGGEncoder().to(device)
                    tar = get_target_img(jpg_path,
                                         device,
                                         e,
                                         d,
                                         s_tensor,
                                         c_name,
                                         s_name,
                                         style_outdir=style_outdir)
                else:
                    print("file exists")
                    InfoNotifier.InfoNotifier.g_progress_info.append(
                        style_output_path + '已存在,跳过')
            except RuntimeError:
                print(
                    'Images are too large to transfer. Size under 1000 are recommended '
                    + file_path)
                InfoNotifier.InfoNotifier.g_progress_info.append(
                    f"{file_path}太大,无法迁移风格,推荐尝试1000×1000以下图片")

            try:
                if os.path.exists(style_output_path):
                    print("exists")
                elif tar is not None:
                    # Save the style transfer result.
                    out_dir = os.path.dirname(style_output_path)
                    if out_dir:
                        os.makedirs(out_dir, exist_ok=True)
                    tar.save(style_output_path, quality=100)
                    print(f'result saved into files {style_output_path}/')
                    InfoNotifier.InfoNotifier.g_progress_info.append(
                        f'风格图保存到: {style_output_path}')
            except Exception as ec:
                print(ec)
                InfoNotifier.InfoNotifier.g_progress_info.append(
                    'error when saving stylized image')
Ejemplo n.º 8
0
def style_main(pics_dir=None, style_dir='', base_dir='', seamless=False):
    """Stylize the images in *pics_dir* with the style image *style_dir*.

    ``PathUtils(base_dir, style_dir, path)`` maps each path to its input JPEG
    and output location. When *seamless* is not ``False``, the "expanded"
    variants are used. Outputs that already exist are skipped. Progress
    messages go to ``InfoNotifier.InfoNotifier.g_progress_info``.
    """
    if pics_dir is None:
        pics_dir = []
    s_name = os.path.splitext(os.path.basename(style_dir))[0]

    # If every preview already exists, do not load the model at all.
    # NOTE(review): this checks `<dir>/temp/<style>/<name>`, while the outputs
    # written below come from PathUtils.get_style_path() /
    # get_expanded_style_path(); confirm the two layouts are meant to agree.
    if all(
            os.path.exists(
                f'{os.path.dirname(pic_dir)}/temp/{s_name}/{os.path.basename(pic_dir)}'
            ) for pic_dir in pics_dir):
        return

    device = Device()
    d = load_model(device)
    s = Image.open(style_dir)
    s_tensor = trans(s).unsqueeze(0).to(device)
    # Build the VGG encoder lazily, once, instead of once per image.
    e = None

    for file_path in pics_dir:
        # str.replace returns a new string; the result was previously discarded.
        file_path = file_path.replace("\\", "/")
        get_path = PathUtils(base_dir, style_dir, file_path)

        if seamless is False:
            jpg_path = get_path.dds_to_jpg_path()
            style_output = get_path.get_style_path()
        else:
            jpg_path = get_path.get_expanded_jpg_path()
            style_output = get_path.get_expanded_style_path()

        save_dir = os.path.dirname(os.path.dirname(style_output))
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        if not os.path.exists(jpg_path):
            print(jpg_path + "is not exist,jump from process")
            continue
        if not jpg_path.endswith(".jpg"):
            continue

        # Reset per image so a failed transfer never re-saves the previous
        # image's result under this image's path.
        tar = None
        try:
            _, c_name = get_c_name_and_file_name(jpg_path)
            s_name = get_style_name(style_dir)
            if not os.path.exists(style_output):
                if e is None:
                    e = VGGEncoder().to(device)
                tar = get_target_img(jpg_path,
                                     device,
                                     e,
                                     d,
                                     s_tensor,
                                     c_name,
                                     s_name,
                                     style_outdir=save_dir)
            else:
                print("file exists")
                # Report the path actually checked (seamless or not).
                InfoNotifier.InfoNotifier.g_progress_info.append(
                    style_output + '已存在,跳过')
        except RuntimeError:
            print(
                'Images are too large to transfer. Size under 1000 are recommended '
                + file_path)
            InfoNotifier.InfoNotifier.g_progress_info.append(
                f"{file_path}太大,无法迁移风格,推荐尝试1000×1000以下图片")

        try:
            if os.path.exists(style_output):
                InfoNotifier.InfoNotifier.g_progress_info.append(
                    style_output + ' 已存在,跳过')
            elif tar is not None:
                # Save the style transfer result.
                out_dir = os.path.dirname(style_output)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                tar.save(style_output, quality=100)
                print(f'result saved into files {style_output}')
                InfoNotifier.InfoNotifier.g_progress_info.append(
                    f'风格图保存到: {style_output}')
        except Exception as ec:
            print(ec)
Ejemplo n.º 9
0
def main():
    """Command-line entry point: stylize one content image with one style image.

    On success three files are written:
    ``<output_name>.jpg`` (stylized output), ``<output_name>_pair.jpg``
    (content and output side by side) and
    ``<output_name>_style_transfer_demo.jpg`` (the pair with a quarter-size
    style image pasted onto the output half). When ``--output_name`` is
    omitted, it defaults to ``<content stem>_<style stem>``.
    """
    parser = argparse.ArgumentParser(description='Style Swap by Pytorch')
    parser.add_argument('--content',
                        '-c',
                        type=str,
                        default=None,
                        help='Content image path e.g. content.jpg')
    parser.add_argument('--style',
                        '-s',
                        type=str,
                        default=None,
                        help='Style image path e.g. image.jpg')
    parser.add_argument(
        '--output_name',
        '-o',
        type=str,
        default=None,
        help='Output path for generated image, no need to add ext, e.g. out')
    parser.add_argument('--patch_size',
                        '-p',
                        type=int,
                        default=3,
                        help='Size of extracted patches from style features')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID(nagative value indicate CPU)')
    parser.add_argument('--model_state_path',
                        type=str,
                        default='model_state.pth',
                        help='save directory for result and loss')

    args = parser.parse_args()
    # Both images are mandatory; fail with a clear usage message instead of
    # an opaque error from Image.open(None).
    if args.content is None or args.style is None:
        parser.error('both --content and --style are required')

    # Set device on GPU if available, else CPU.
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device(f'cuda:{args.gpu}')
        print(f'# CUDA available: {torch.cuda.get_device_name(device)}')
    else:
        device = 'cpu'

    # Set model. map_location lets a GPU-saved checkpoint load on the CPU.
    e = VGGEncoder().to(device)
    d = Decoder()
    d.load_state_dict(torch.load(args.model_state_path, map_location=device))
    d = d.to(device)

    try:
        c = Image.open(args.content)
        s = Image.open(args.style)
        c_tensor = trans(c).unsqueeze(0).to(device)
        s_tensor = trans(s).unsqueeze(0).to(device)
        with torch.no_grad():
            cf = e(c_tensor)
            sf = e(s_tensor)
            style_swap_res = style_swap(cf, sf, args.patch_size, 1)
            # Free features and encoder before decoding to lower peak memory.
            del cf
            del sf
            del e
            out = d(style_swap_res)

        c_denorm = denorm(c_tensor, device)
        out_denorm = denorm(out, device)
        res = torch.cat([c_denorm, out_denorm], dim=0)
        res = res.to('cpu')
    except RuntimeError:
        print(
            'Images are too large to transfer. Size under 1000 are recommended '
        )
        # No output tensors exist, so there is nothing to save.
        return

    if args.output_name is None:
        c_name = os.path.splitext(os.path.basename(args.content))[0]
        s_name = os.path.splitext(os.path.basename(args.style))[0]
        args.output_name = f'{c_name}_{s_name}'

    try:
        save_image(out_denorm, f'{args.output_name}.jpg', nrow=1)
        save_image(res, f'{args.output_name}_pair.jpg', nrow=2)

        o = Image.open(f'{args.output_name}_pair.jpg')
        # Shrink the style image to a quarter of the content size and paste it
        # at the bottom of the right (stylized) half of the pair image.
        s = s.resize(tuple(i // 4 for i in c.size))
        box = (o.width // 2, o.height - s.height)
        o.paste(s, box)
        o.save(f'{args.output_name}_style_transfer_demo.jpg', quality=95)
        print(f'result saved into files starting with {args.output_name}')
    except Exception as err:  # report instead of silently swallowing
        print(f'Failed to save results for {args.output_name}: {err}')