Exemplo n.º 1
0
class Inference:
    """
    Base64-in / base64-out wrapper around a fast-neural-style ``TransformerNet``.

    A single network instance lives on ``self.device``; ``load_model`` swaps in
    the weights for the requested style before every prediction.
    """

    # ``predict`` also writes a copy of the most recent stylized image here.
    result_path = "/home/mostafax/Desktop/Style-Transfer-App/Fast-Neural-Style-Transfer/images/result.jpeg"

    def __init__(self):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): ``mospath`` is a module-level prefix defined elsewhere,
        # presumably the app root; all three weight files now share it (the
        # mosaic path previously omitted it and only worked from that cwd).
        self.mosaic_WeightsPath = mospath + "Fast-Neural-Style-Transfer/weights/mosaic_10000.pth"
        self.cuphead_WeightsPath = mospath + "Fast-Neural-Style-Transfer/weights/cuphead_10000.pth"
        self.starry_night_WeightsPath = mospath + "Fast-Neural-Style-Transfer/weights/starry_night_10000.pth"

        # Instance of the style-transfer model and its input preprocessing.
        self.transformer = TransformerNet().to(self.device)
        self.transform = style_transform()

    def load_model(self, modelType):
        """
        Load the weights for ``modelType`` into ``self.transformer`` and switch it to eval mode.

        Args:
            modelType: style name, matched case-insensitively. "mosaic" and
                "cuphead" select those weights; any other value falls back to
                Starry Night.
        """
        weights_by_style = {
            "mosaic": self.mosaic_WeightsPath,
            "cuphead": self.cuphead_WeightsPath,
        }
        weights_path = weights_by_style.get(
            modelType.lower(), self.starry_night_WeightsPath)
        # Load on CPU first; load_state_dict copies into the params wherever the model lives.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        self.transformer.load_state_dict(state_dict)
        self.transformer.eval()

    def predict(self, modelType, Base64Img):
        """
        Stylize a base64-encoded image.

        Args:
            modelType: style name passed to ``load_model``.
            Base64Img: base64-encoded bytes of any image format PIL can open.

        Returns:
            The stylized image as base64-encoded JPEG bytes. The same JPEG is
            also written to ``self.result_path``.
        """
        raw_bytes = base64.b64decode(Base64Img)
        image = Image.open(BytesIO(raw_bytes)).convert('RGB')
        self.load_model(modelType)

        image_tensor = self.transform(image).to(self.device).unsqueeze(0)

        # Stylize image
        with torch.no_grad():
            stylized_image = denormalize(self.transformer(image_tensor)).cpu()

        # torchvision's save_image returns None, so encode into an in-memory
        # buffer and reuse those bytes for both the file copy and the response.
        buffer = BytesIO()
        save_image(stylized_image, buffer, format="JPEG")
        image_bytes = buffer.getvalue()
        with open(self.result_path, "wb") as result_file:
            result_file.write(image_bytes)

        return base64.b64encode(image_bytes)
Exemplo n.º 2
0
def test():
    """Stylize the image at ``image_path`` with the checkpoint at ``ckpt``.

    Both paths are placeholders to be edited before running. The result is
    written to ``images/outputs/stylized-<input file name>``.
    """
    import os

    image_path = './image_folder'
    ckpt = './path_to_ckpt'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = test_transform()

    transformer = TransformerNet().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    transformer.load_state_dict(torch.load(ckpt, map_location=device))
    transformer.eval()

    # Force 3 channels so grayscale / RGBA inputs still produce an RGB tensor.
    image = transform(Image.open(image_path).convert('RGB')).to(device)
    image = image.unsqueeze(0)

    with torch.no_grad():
        style_trans_img = denormalize(transformer(image)).cpu()

    # save_image does not create missing directories.
    os.makedirs("images/outputs", exist_ok=True)
    fn = os.path.basename(image_path)
    save_image(style_trans_img, f"images/outputs/stylized-{fn}")
Exemplo n.º 3
0
def style_image(image_path, model_path):
    """Stylize one image with a TransformerNet checkpoint while keeping its transparency.

    Transparent regions are first filled with a flat colour, the RGB image is
    run through the network, and the source alpha channel is put back. The
    PNG result goes to ``images/outputs/<checkpoint stem>/<input file name>``.
    """
    source = Image.open(image_path)
    width, height = source.size
    alpha_mask = source.convert('RGBA').split()[-1]

    # @TODO - import the mean color...
    # Paste a flat colour wherever the inverted alpha mask is set, i.e. over transparent pixels.
    backdrop = Image.new("RGB", source.size, (124, 116, 103))
    rgb_source = source.convert('RGB')
    rgb_source.paste(backdrop, mask=invert(alpha_mask))

    use_cuda = torch.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    checkpoint = torch.load(model_path, map_location=device)

    input_name = os.path.basename(image_path)
    checkpoint_stem = os.path.splitext(os.path.basename(model_path))[0]
    output_dir = f"images/outputs/{checkpoint_stem}"
    os.makedirs(output_dir, exist_ok=True)

    preprocess = style_transform()

    # Build the network and restore the checkpoint weights.
    net = TransformerNet().to(device)
    net.load_state_dict(checkpoint)
    net.eval()

    batch = Variable(preprocess(rgb_source)).to(device).unsqueeze(0)

    with torch.no_grad():
        result = depro(net(batch))

    # Back to PIL, cropped to the source size, with the original alpha restored.
    styled = F.to_pil_image(result).convert('RGBA').crop((0, 0, width, height))
    styled.putalpha(alpha_mask)
    styled.save(f"{output_dir}/{input_name}", 'PNG')
def transfer_img(usr_img_path, style_model_path, new_img_path):
    """Stylize ``usr_img_path`` with the checkpoint ``style_model_path`` and save it to ``new_img_path``."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define model and load model checkpoint; map_location lets GPU-saved
    # checkpoints load on CPU-only machines.
    transformer = TransformerNet().to(device)
    transformer.load_state_dict(torch.load(style_model_path, map_location=device))
    transformer.eval()

    # Prepare input. convert('RGB') so grayscale / RGBA files still yield 3 channels.
    transform = style_transform()
    image_tensor = transform(Image.open(usr_img_path).convert('RGB')).to(device)
    image_tensor = image_tensor.unsqueeze(0)

    # Stylize image
    with torch.no_grad():
        stylized_image = denormalize(transformer(image_tensor)).cpu()

    # Save image
    save_image(stylized_image, new_img_path)
def exportStyleTransfer(
        image_path,
        style,
        model_dir='/home/KLTN_TheFaceOfArtFaceParsing/Updates/StyleTransfer/models/',
        output_path='/home/KLTN_TheFaceOfArtFaceParsing/result.jpg'):
    """Stylize ``image_path`` with one of the numbered style checkpoints.

    Args:
        image_path: path of the input image.
        style: 1-based index into the list of available styles (1..13).
        model_dir: directory holding the ``<id>_4000.pth`` checkpoints.
        output_path: where the stylized image is written.

    Raises:
        IndexError: if ``style`` is outside 1..13. Without this check,
            ``styles[style - 1]`` would silently pick the wrong model for
            ``style <= 0``.
    """
    import os

    styles = ['9', '32', '51', '55', '56', '58', '108', '140', '150', '153', '154', '155', '156']
    if not 1 <= style <= len(styles):
        raise IndexError(f"style must be between 1 and {len(styles)}, got {style!r}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_model = os.path.join(model_dir, styles[style - 1] + '_4000.pth')
    transform = style_transform()

    # Define model and load model checkpoint (map_location: GPU checkpoints on CPU hosts)
    transformer = TransformerNet().to(device)
    transformer.load_state_dict(torch.load(checkpoint_model, map_location=device))
    transformer.eval()

    # Prepare input; convert('RGB') so grayscale / RGBA files still yield 3 channels.
    image_tensor = transform(Image.open(image_path).convert('RGB')).to(device)
    image_tensor = image_tensor.unsqueeze(0)

    # Stylize image
    with torch.no_grad():
        stylized_image = denormalize(transformer(image_tensor)).cpu()

    # Save image
    save_image(stylized_image, output_path)
Exemplo n.º 6
0
                        type=str,
                        required=True,
                        help="Path to checkpoint model")
    args = parser.parse_args()
    print(args)

    os.makedirs("images/outputs", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = style_transform()

    # Define model and load model checkpoint
    transformer = TransformerNet().to(device)
    transformer.load_state_dict(torch.load(args.checkpoint_model, 'cpu'))
    transformer.eval()

    stylized_frames = []
    for frame in tqdm.tqdm(extract_frames(args.video_path),
                           desc="Processing frames"):
        # Prepare input frame
        image_tensor = Variable(transform(frame)).to(device).unsqueeze(0)
        # Stylize image
        with torch.no_grad():
            stylized_image = transformer(image_tensor)
        # Add to frames
        stylized_frames += [deprocess(stylized_image)]

    # Create video from frames
    video_name = args.video_path.split("/")[-1].split(".")[0]
    writer = skvideo.io.FFmpegWriter(
Exemplo n.º 7
0
def get_transformer(checkpoint_path: str) -> torch.nn.Module:
    """Return a TransformerNet in eval mode with the weights from *checkpoint_path*.

    Both the network and the loaded tensors are placed on the module-level
    ``device`` (defined elsewhere in the file).
    """
    weights = torch.load(checkpoint_path, map_location=device)
    net = TransformerNet().to(device)
    net.load_state_dict(weights)
    # nn.Module.eval() returns the module itself.
    return net.eval()
Exemplo n.º 8
0
def index():
    """Flask view: upload an image, stylize it with the chosen model, render the result.

    Non-POST requests render the empty upload page. A POST expects an ``image``
    file field and a ``style`` form field (``mosaic``, ``mona`` or ``starry``).
    The upload is saved in ``UPLOAD_FOLDER`` and the result in
    ``./static/predict/result_<style><filename>``. An empty filename or an
    unknown style re-renders the empty upload page.
    """
    if request.method != 'POST':
        return render_template("index.html", fileupload=False)

    # Weights for each style offered by the form; the key also prefixes the output name.
    style_checkpoints = {
        'mosaic': "static/model/mosaic_10000.pth",
        'mona': "static/model/mona_24000.pth",
        'starry': "static/model/starry_night_10000.pth",
    }

    f = request.files['image']
    # The client controls this name: keep only the last path component so a
    # value such as "../../app.py" cannot escape UPLOAD_FOLDER.
    filename = os.path.basename(f.filename or "")
    model = request.form['style']
    checkpoint = style_checkpoints.get(model)
    if not filename or checkpoint is None:
        # An unknown style would otherwise run the untrained, randomly initialised network.
        return render_template("index.html", fileupload=False)

    path = os.path.join(UPLOAD_FOLDER, filename)
    f.save(path)

    device = torch.device("cpu")
    transform = style_transform()

    transformer = TransformerNet().to(device)
    transformer.load_state_dict(torch.load(checkpoint, map_location=device))
    transformer.eval()
    filename = "{}{}".format(model, filename)

    # Prepare input: open once (closed by the with-block), then shrink large
    # uploads to bound CPU inference cost.
    with Image.open(path) as uploaded:
        image = uploaded.convert("RGB")
    a, b = image.size
    if a * b >= 1960000:
        image = image.resize((int(a / 2), int(b / 2)))
    elif a * b >= 800000:
        image = image.resize((int(a * 2 / 3), int(b * 2 / 3)))
    image_tensor = transform(image).to(device).unsqueeze(0)

    with torch.no_grad():
        stylized_image = denormalize(transformer(image_tensor)).cpu()

    save_image(stylized_image,
               "./static/predict/result_{}".format(filename))

    return render_template("index.html",
                           fileupload=True,
                           img_name="result_" + filename)
Exemplo n.º 9
0
def main():
    """Streamlit page: stylize an uploaded picture with a checkpoint from ``./models``.

    Every ``*.pth`` file in ``<cwd>/models`` is offered in a select box. The
    chosen TransformerNet is run on the upload, and the result is saved under
    ``images/outputs`` and shown on the page.
    """
    import uuid

    uploaded_file = st.file_uploader(
        "Choose a picc", type=['jpg', 'png', 'webm', 'mp4', 'gif', 'jpeg'])
    if uploaded_file is not None:
        st.image(uploaded_file, width=200)

    folder = os.path.join(os.path.abspath(os.getcwd()), 'models')
    try:
        fnames = [
            os.path.join(folder, basename)
            for basename in os.listdir(folder)
            if basename.endswith('.pth')
        ]
    except FileNotFoundError:
        fnames = []
    checkpoint = st.selectbox('Select a pretrained model', fnames)

    os.makedirs("images/outputs", exist_ok=True)

    if uploaded_file is None:
        st.write('Choose an image')
        return
    if checkpoint is None:
        st.write('No pretrained models (*.pth) found in ' + folder)
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = style_transform()
    try:
        # Define model and load checkpoint; map_location lets GPU checkpoints load on CPU.
        transformer = TransformerNet().to(device)
        transformer.load_state_dict(torch.load(checkpoint, map_location=device))
        transformer.eval()

        # Prepare input
        image_tensor = transform(Image.open(uploaded_file).convert('RGB')).to(device)
        image_tensor = image_tensor.unsqueeze(0)

        # Stylize image
        with torch.no_grad():
            stylized_image = denormalize(transformer(image_tensor)).cpu()
    except Exception as err:  # UI boundary: report the failure instead of hiding it
        st.write(f'Could not stylize the image: {err}')
        return

    # A random UUID avoids one session overwriting another's output file.
    output_path = f"images/outputs/stylized-{uuid.uuid4().hex}image.jpg"
    save_image(stylized_image, output_path)
    st.image(output_path)
Exemplo n.º 10
0
import torch
import torchvision
from torchvision import datasets, models, transforms
from models import TransformerNet, VGG16
import coremltools as ct
import urllib
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import numpy as np

# Export the Starry Night TransformerNet to Core ML: load CPU weights, trace the
# network with an example batch, convert the trace and save it.
net = TransformerNet()
state_dict = torch.load("starry_night_10000.pth", map_location=torch.device('cpu'))
net.load_state_dict(state_dict)
net.eval()

# Example input used to record the trace; its shape (1, 3, 512, 512) also
# becomes the declared Core ML image input shape below.
x = torch.rand(1, 3, 512, 512)
traced_model = torch.jit.trace(net, x)

image_input = ct.ImageType(name="input_image", shape=x.shape)
model = ct.convert(traced_model, inputs=[image_input])

model.save("starry_night.mlmodel")
Exemplo n.º 11
0
def main():
    """Streamlit page: show two uploaded pictures and stylize the first one.

    Both uploads are required. The first picture is run through the Starry
    Night TransformerNet, and the result is saved to
    ``images/outputs/stylized-<uploaded file name>``.
    """
    # Upload images
    uploaded_file = st.file_uploader("Choose a picture", type=['jpg', 'png'])
    second_uploaded_file = st.file_uploader("Choose another picture",
                                            type=['jpg', 'png'])

    # Stop here until both pictures are present; continuing would crash on Image.open(None).
    if uploaded_file is None or second_uploaded_file is None:
        st.empty().info('Please upload a file')
        return
    try:
        image1 = Image.open(uploaded_file)
        Image.open(second_uploaded_file)  # only validates that it is an image
    except OSError:  # PIL's UnidentifiedImageError subclasses OSError
        st.empty().info('Please upload a file')
        return

    show_file = st.empty()
    show_file1 = st.empty()
    show_file2 = st.empty()
    show_file3 = st.empty()
    show_file4 = st.empty()
    show_file5 = st.empty()

    show_file.title('Input Images')
    show_file1.image(uploaded_file, width=100)
    show_file2.title('+')
    show_file3.image(second_uploaded_file, width=100)
    show_file4.title('=')
    show_file5.image(image1, width=300)

    os.makedirs("images/outputs", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = style_transform()

    # Define model and load checkpoint; map_location lets GPU checkpoints load on CPU.
    transformer = TransformerNet().to(device)
    transformer.load_state_dict(
        torch.load(
            '/home/nick/Downloads/faceswap_app/models/starry_night_10000.pth',
            map_location=device))
    transformer.eval()

    # Prepare input; convert('RGB') so RGBA PNGs still yield 3 channels.
    image_tensor = transform(Image.open(uploaded_file).convert('RGB')).to(device)
    image_tensor = image_tensor.unsqueeze(0)

    # Stylize image
    with torch.no_grad():
        stylized_image = denormalize(transformer(image_tensor)).cpu()

    # Save image. The name comes from the upload (there is no `args` here);
    # basename() drops any client-supplied directory parts.
    fn = os.path.basename(uploaded_file.name)
    save_image(stylized_image, f"images/outputs/stylized-{fn}")
def _make_noise_image(args, device):
    """Build the noise image added to inputs for the ``L_pop`` term.

    Returns a float32 tensor of shape (3, args.image_size, args.image_size) on
    ``device``. It is zero except at ``args.noise_count`` randomly chosen pixel
    positions. Each of the three channels at such a position gets an offset
    drawn from ``random.randrange(-args.noise, args.noise)``.
    """
    noise = np.zeros((3, args.image_size, args.image_size), dtype=np.float32)
    for _ in range(args.noise_count):
        row = random.randrange(args.image_size)
        col = random.randrange(args.image_size)
        for channel in range(3):
            noise[channel][row][col] += random.randrange(-args.noise, args.noise)
    return torch.from_numpy(noise).to(device)


def train(args):
    """Train a TransformerNet on ``load_data(args)`` towards ``args.style_image``.

    The optimized loss is the sum of four terms:

    * a VGG relu2_2 feature term weighted by ``lambda_feat``;
    * a Gram-matrix style term over every VGG feature map, weighted by ``lambda_style``;
    * a total-variation term weighted by ``lambda_tv``;
    * when ``args.noise_count`` is set, an MSE between outputs for clean and
      noise-perturbed inputs, weighted by ``lambda_noise``.

    The final weights are saved (on CPU) to
    ``<args.save_model_dir>/epoch_<epochs>_<ctime>.model``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    data_train = load_data(args)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(weights=args.vgg16, requires_grad=False).to(device)
    # Local name distinct from the module-level style_transform() helper.
    style_to_tensor = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_to_tensor(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        # One noise image per epoch, shared by every batch of that epoch.
        noiseimg = _make_noise_image(args, device) if args.noise_count else None
        for batch_id, sample in enumerate(data_train):
            x = sample['image']
            n_batch = len(x)
            optimizer.zero_grad()

            x = x.to(device)
            if args.noise_count:
                # Adding the noise image to the source image.
                noisy_y = utils.normalize_batch(transformer(x + noiseimg))

            y = utils.normalize_batch(transformer(x))
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            L_feat = args.lambda_feat * mse_loss(features_y.relu2_2,
                                                 features_x.relu2_2)

            L_style = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # The last batch may be smaller than args.batch_size.
                L_style += mse_loss(gm_y, gm_s[:n_batch, :, :])
            L_style *= args.lambda_style

            L_tv = args.lambda_tv * (
                torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            if args.noise_count:
                L_pop = args.lambda_noise * F.mse_loss(y, noisy_y)
                L = L_feat + L_style + L_tv + L_pop
                print(
                    'Epoch {},{}/{}. Total loss: {}. Loss distribution: feat {}, style {}, tv {}, pop {}'
                    .format(e, batch_id, len(data_train), L.item(),
                            L_feat.item() / L.item(), L_style.item() / L.item(),
                            L_tv.item() / L.item(), L_pop.item() / L.item()))
            else:
                L = L_feat + L_style + L_tv
                print(
                    'Epoch {},{}/{}. Total loss: {}. Loss distribution: feat {}, style {}, tv {}'
                    .format(e, batch_id, len(data_train), L.item(),
                            L_feat.item() / L.item(), L_style.item() / L.item(),
                            L_tv.item() / L.item()))
            # Back-propagate the same weighted total that is reported above.
            # The previous `L = L_style * 1e10 + L_feat * 1e5` overwrite dropped
            # the TV and noise terms from the gradient entirely.
            L.backward()
            optimizer.step()

    transformer.eval().cpu()
    os.makedirs(args.save_model_dir, exist_ok=True)
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)