def generate_video(saved_model_path, video_category=None):
    """Render ground-truth and predicted frames side by side into an mp4.

    The model stored at ``saved_model_path`` is loaded and run on 200 batches
    drawn from the test-file generator.  The resulting video shows the true
    frames on the left and the predicted frames on the right, at 5 fps.

    param saved_model_path Path handed to ``load_model``
    param video_category Forwarded to ``get_train_test_files(which=...)``; also
        used as the output file stem.  When falsy the file is ``render.mp4``.
    """
    n_batches = 200
    frame_rate = 5

    predictor = load_model(saved_model_path)

    # Only the test split is needed for rendering.
    _, held_out_files = get_train_test_files(which=video_category)
    batches = get_data_gen(
        files=held_out_files,
        timesteps=timesteps,
        batch_size=batch_size,
        im_size=(im_width, im_height),
    )

    truth_frames, predicted_frames = [], []
    for _ in range(n_batches):
        inputs, targets = next(batches)
        truth_frames.extend(targets)
        predicted_frames.extend(predictor.predict_on_batch(inputs))

    left = ImageSequenceClip(
        [denormalize(frame) for frame in truth_frames], fps=frame_rate)
    right = ImageSequenceClip(
        [denormalize(frame) for frame in predicted_frames], fps=frame_rate)
    # Shift the prediction clip one clip-width to the right of the truth clip.
    right = right.set_position((left.w, 0))

    side_by_side = CompositeVideoClip((left, right), size=(left.w * 2, left.h))
    out_name = "{}.mp4".format(video_category if video_category else "render")
    side_by_side.write_videofile(out_name, fps=frame_rate)
def plot_different_models(timesteps=(5, 10)):
    """
    Compares ssim/psnr of different models. A model file named
    ``r_p2p_gen_t{ts}.model`` must be present for each supplied timestep.

    For every timestep the matching model is loaded and run on 200 batches of
    the test split; per-frame PSNR and SSIM are then drawn as one box per
    timestep and saved to ``jigsaws_psnrs_all.png`` and
    ``jigsaws_ssims_all.png``.

    param timesteps An iterable of numbers indicating the timesteps that were
        used for training different models
    """
    from skimage.measure import compare_psnr, compare_ssim

    # Materialize once: the values are iterated several times below, and the
    # immutable tuple default avoids sharing a mutable list between calls.
    timesteps = list(timesteps)

    psnrs = {}
    ssims = {}
    for ts in timesteps:
        model_name = "r_p2p_gen_t{}.model".format(ts)
        model = load_model(model_name)
        # Score on the held-out test split, consistent with generate_video;
        # evaluating on the training files would overstate model quality.
        _, test_files = get_train_test_files()
        test_gen = get_data_gen(files=test_files,
                                timesteps=ts,
                                batch_size=batch_size,
                                im_size=(im_width, im_height))

        y_true = []
        y_pred = []

        for _ in range(200):
            x, y = next(test_gen)
            y_true.extend(y)

            predictions = model.predict_on_batch(x)
            y_pred.extend(predictions)

        # Denormalize every frame pair once and reuse it for both metrics.
        frame_pairs = [
            (denormalize(yt), denormalize(p))
            for yt, p in zip(y_true, y_pred)
        ]
        psnrs[ts] = [compare_psnr(yt, p) for yt, p in frame_pairs]
        ssims[ts] = [
            compare_ssim(yt, p, multichannel=True) for yt, p in frame_pairs
        ]

    # Start from a fresh figure so the PSNR plot never lands on top of
    # whatever figure happened to be current before this call.
    plt.figure()
    plt.boxplot([psnrs[ts] for ts in timesteps], labels=timesteps)
    plt.savefig("jigsaws_psnrs_all.png")

    plt.figure()
    plt.boxplot([ssims[ts] for ts in timesteps], labels=timesteps)
    plt.savefig("jigsaws_ssims_all.png")
Beispiel #3
0
def plot_metrics(y_true, y_pred):
    """Box-plot and print per-frame PSNR and SSIM of predictions.

    Both inputs are passed through ``denormalize`` and then compared frame by
    frame (pairs are formed with ``zip``, so extra frames in the longer input
    are ignored).  The PSNR box plot is saved to
    ``./jigsaws_psnrs_boxplot.png``, the SSIM box plot to
    ``./jigsaws_ssims_boxplot.png``, and the mean of each metric is printed.

    param y_true Ground-truth frames, iterable along the first axis
    param y_pred Predicted frames, aligned with ``y_true``
    """
    from skimage.measure import compare_psnr, compare_ssim

    # Denormalize each input once and share the result between both metrics
    # instead of recomputing it for every list.
    true_frames = denormalize(y_true)
    pred_frames = denormalize(y_pred)

    psnrs = [
        compare_psnr(yt, p)
        for yt, p in zip(true_frames, pred_frames)
    ]
    ssims = [
        compare_ssim(yt, p, multichannel=True)
        for yt, p in zip(true_frames, pred_frames)
    ]

    plt.figure(figsize=(5, 4))
    # No notch; outliers drawn as green diamonds.
    plt.boxplot(psnrs, notch=False, sym='gD')
    plt.savefig("./jigsaws_psnrs_boxplot.png")

    plt.figure(figsize=(5, 4))
    # No notch; outliers drawn as red diamonds.
    plt.boxplot(ssims, notch=False, sym='rD')
    plt.savefig("./jigsaws_ssims_boxplot.png")

    print("Mean PSNR = ", np.mean(np.array(psnrs)))
    print("Mean SSIM = ", np.mean(np.array(ssims)))
Beispiel #4
0
def main(
    model_dir: str,
    vc_src: str,
    vc_tgt: str,
    adv_tgt: str,
    output: str,
    eps: float,
    n_iters: int,
    attack_type: str,
):
    """Run the selected attack and write the resulting waveform to ``output``.

    ``vc_tgt`` and ``adv_tgt`` (and ``vc_src`` unless ``attack_type`` is
    ``"emb"``) are converted with ``file2mel``, normalized, and moved to the
    model's device as a single-item batch.  The attack output is
    denormalized, converted back with ``mel2wav`` and written with
    ``sf.write`` at ``config["preprocess"]["sample_rate"]``.

    ``attack_type`` must be ``"e2e"``, ``"emb"`` or ``"fb"``; any other value
    raises ``NotImplementedError`` (after the inputs have been loaded).
    ``vc_src`` may be ``None`` only for ``"emb"``.
    NOTE(review): the three path arguments are presumably audio files --
    confirm against ``file2mel``.
    """
    uses_source = attack_type != "emb"
    assert not uses_source or vc_src is not None

    model, config, attr, device = load_model(model_dir)
    preprocess_cfg = config["preprocess"]

    def to_model_input(mel):
        # Normalize, transpose, add a leading batch axis, move to the device.
        scaled = normalize(mel, attr)
        return torch.from_numpy(scaled).T.unsqueeze(0).to(device)

    target_mel = file2mel(vc_tgt, **preprocess_cfg)
    adversary_mel = file2mel(adv_tgt, **preprocess_cfg)
    vc_tgt = to_model_input(target_mel)
    adv_tgt = to_model_input(adversary_mel)

    if uses_source:
        vc_src = to_model_input(file2mel(vc_src, **preprocess_cfg))

    if attack_type == "emb":
        perturbed = emb_attack(model, vc_tgt, adv_tgt, eps, n_iters)
    elif attack_type == "e2e":
        perturbed = e2e_attack(model, vc_src, vc_tgt, adv_tgt, eps, n_iters)
    elif attack_type == "fb":
        perturbed = fb_attack(model, vc_src, vc_tgt, adv_tgt, eps, n_iters)
    else:
        raise NotImplementedError()

    # Undo the batching/transposition applied on the way in.
    perturbed_mel = perturbed.squeeze(0).T.data.cpu().numpy()
    waveform = mel2wav(denormalize(perturbed_mel, attr), **preprocess_cfg)

    sf.write(output, waveform, preprocess_cfg["sample_rate"])
Beispiel #5
0
 def forward(self, x):
     """Sum the outputs of ``block1`` and ``block2`` on the normalized input.

     The input is normalized once, fed to both blocks independently, and the
     element-wise sum of their outputs is denormalized before being returned.
     """
     normalized = normalize(x)
     first_branch = self.block1(normalized)
     second_branch = self.block2(normalized)
     return denormalize(first_branch + second_branch)
Beispiel #6
0
 def forward(self, x):
     """Run the normalized input through the block/residual chain.

     ``block1`` is applied to the normalized input; its output goes through
     ``residual`` and then ``block2``.  The ``block1`` output is added back to
     the ``block2`` output (a skip connection), the sum is passed to
     ``block3``, and the result is denormalized.
     """
     features = self.block1(normalize(x))
     refined = self.block2(self.residual(features))
     # Skip connection: combine block1's output with the refined features.
     merged = self.block3(features + refined)
     return denormalize(merged)