toy_init_particles = sample_gmm(num_samples,
                                    [1.],
                                    [np.array([-10])],
                                    [np.eye(1)])

    true_samples = sample_gmm(num_samples, weights, mus, sigmas)

    plotfunc = get_plotfunc(true_samples)

    toy_score = gmm_gld(weights, mus, sigmas)

    svgd = SVGD(gld=toy_score)

    particles= svgd.do_svgd_iterations_optimized(init_particles=toy_init_particles,
                                         num_iterations=num_iterations,
                                                  learning_rate=0.1,
                                                  plotfunc=plotfunc)


    x = np.linspace(-15, 10, 1000)
    fig, ax = plt.subplots(1,1)

    sns.kdeplot(true_samples.flatten(), ax=ax, label='Sampler')
    sns.kdeplot(particles.flatten(), ax=ax, label='SVGD')

    ax.legend()

    #plt.show()
    plt.close()
if __name__ == '__main__':
    num_iterations = 5000
    num_samples = 100

    # Target distribution: a one-component GMM, i.e. a standard 2-D Gaussian
    # centred at the origin.
    weights = [1]
    mus = [np.array([0, 0])]
    sigmas = [np.eye(2)]

    def _run_experiment(init_particles, learning_rate, **svgd_kwargs):
        """Run SVGD from ``init_particles`` towards the GMM target above.

        Draws a fresh reference sample from the target for the plotting
        callback, builds the target's ``gmm_gld`` score and log-likelihood
        metrics, then runs ``num_iterations`` SVGD iterations.  Any extra
        keyword arguments are forwarded unchanged to
        ``SVGD.do_svgd_iterations_optimized``.

        Returns whatever ``do_svgd_iterations_optimized`` returns
        (presumably the final particles -- confirm against ``SVGD``).
        """
        # Drawn after the caller sampled ``init_particles`` so the RNG
        # consumption order matches the original inline script.
        true_samples = sample_gmm(num_samples, weights, mus, sigmas)
        plotfunc = get_plotfunc(true_samples)
        svgd = SVGD(gld=gmm_gld(weights, mus, sigmas))
        report_metrics = metrics(log_lik_gmm(weights, mus, sigmas))
        return svgd.do_svgd_iterations_optimized(
            init_particles=init_particles,
            num_iterations=num_iterations,
            learning_rate=learning_rate,
            plotfunc=plotfunc,
            metrics=report_metrics,
            **svgd_kwargs)

    # Run 1: num_samples particles initialised around (-7, -7).
    init_particles = sample_gmm(num_samples, [1.], [np.array([-7, -7])],
                                [np.eye(2)])
    particles = _run_experiment(init_particles, learning_rate=1e-2)

    # NOTE(review): a dangling fragment (`[corr2, 1]` / `])]`) used to sit
    # here and made this module unparseable. It looks like the remains of a
    # correlated-covariance target (`corr2` is not defined anywhere visible).
    # As written, run 2 therefore targets the same distribution as run 1 --
    # confirm this is intended or restore the correlated target.

    # Run 2: only 5 particles initialised around (-10, 0), smaller step size.
    init_particles = sample_gmm(5,
                                [1.],
                                [np.array([-10, 0])],
                                [np.eye(2)])
    particles = _run_experiment(init_particles, learning_rate=1e-3,
                                progress_freq=100)