def test_meta_gradient_with_langevin():
    """Check autodiff gradients of the Langevin marginal-likelihood estimate.

    Builds an exact (``approx=False``) Langevin sampler targeting
    ``logprob_two_moons`` in ``D = 2`` dimensions. Every sampler parameter is
    initialised at a nominal value plus a small random perturbation, then packed
    into one flat vector via ``parser``. ``check_grads`` verifies the gradient of
    the batch-mean (loglik + entropy) estimate with respect to that vector.
    """
    num_samples = 4
    num_langevin_steps = 3
    D = 2

    # Nominal initial values, each perturbed by small (scale 0.01) Gaussian noise.
    init_mean = npr.randn(D) * 0.01
    init_log_stddevs = np.log(1*np.ones(D)) + npr.randn(D) * 0.01
    init_log_stepsizes = (np.log(0.01*np.ones(num_langevin_steps))
                          + npr.randn(num_langevin_steps) * 0.01)
    init_log_noise_sizes = (np.log(.001*np.ones(num_langevin_steps))
                            + npr.randn(num_langevin_steps) * 0.01)
    init_log_gradient_scales = np.log(1*np.ones(D))
    init_gradient_power = 0.9

    sample_and_run_langevin, parser = build_langevin_sampler(
        logprob_two_moons, D, num_langevin_steps, approx=False)

    # Pack all initial values into a single flat parameter vector.
    sampler_params = np.zeros(len(parser))
    parser.put(sampler_params, 'mean', init_mean)
    parser.put(sampler_params, 'log_stddev', init_log_stddevs)
    parser.put(sampler_params, 'log_stepsizes', init_log_stepsizes)
    parser.put(sampler_params, 'log_noise_sizes', init_log_noise_sizes)
    parser.put(sampler_params, 'log_gradient_scales', init_log_gradient_scales)
    # The gradient power is stored in unconstrained (inverse-sigmoid) form.
    parser.put(sampler_params, 'invsig_gradient_power', inv_sigmoid(init_gradient_power))

    def get_batch_marginal_likelihood_estimate(sampler_params):
        """Return the mean over ``num_samples`` of (loglik + entropy) estimates."""
        # Re-seed on every call so the objective is a deterministic function of
        # sampler_params. check_grads compares against numerical (finite-difference)
        # gradients, which only make sense if repeated calls see identical noise.
        # Use the file's own `npr` rather than `np.random.npr`, which relies on a
        # private import inside autograd's numpy.random wrapper.
        rs = npr.RandomState(0)
        _, loglik_estimates, entropy_estimates = sample_and_run_langevin(
            sampler_params, rs, num_samples)
        marginal_likelihood_estimates = loglik_estimates + entropy_estimates
        return np.mean(marginal_likelihood_estimates)

    check_grads(get_batch_marginal_likelihood_estimate, sampler_params)
# Set up an exact (approx=False) Langevin sampler and evaluate the gradient of its
# batch marginal-likelihood estimate over num_sampler_optimization_steps iterations.
# NOTE(review): D, num_steps, init_init_stddev_scale, init_langevin_stepsize,
# init_langevin_noise_size, init_gradient_power, rs, num_samples and
# num_sampler_optimization_steps are not defined in this view; presumably they are
# set earlier in the script — confirm.
init_mean = np.zeros(D)
# Despite the name, these are *log* standard deviations (stored under 'log_stddev').
init_stddevs = np.log(init_init_stddev_scale * np.ones((1,D)))
init_log_stepsizes = np.log(init_langevin_stepsize * np.ones(num_steps))
init_log_noise_sizes = np.log(init_langevin_noise_size * np.ones(num_steps))
init_log_gradient_scales = np.log(np.ones((1,D)))
# NOTE(review): logprob_mvn is built here but never used below; the sampler is
# constructed from logprob_two_moons. Confirm which target density is intended.
logprob_mvn = build_logprob_mvn(mean=np.array([0.2,0.4]), cov=np.array([[1.0,0.9], [0.9,1.0]]))
sample, parser = build_langevin_sampler(logprob_two_moons, D, num_steps, approx=False)

# Pack the initial values into one flat parameter vector via the parser.
sampler_params = np.zeros(len(parser))
parser.put(sampler_params, 'mean', init_mean)
parser.put(sampler_params, 'log_stddev', init_stddevs)
parser.put(sampler_params, 'log_stepsizes', init_log_stepsizes)
parser.put(sampler_params, 'log_noise_sizes', init_log_noise_sizes)
parser.put(sampler_params, 'log_gradient_scales', init_log_gradient_scales)
# Stored in inverse-sigmoid form; sigmoid() maps it back when printed in the loop below.
parser.put(sampler_params, 'invsig_gradient_power', inv_sigmoid(init_gradient_power))

def get_batch_marginal_likelihood_estimate(sampler_params):
    """Return the batch mean of (loglik + entropy) estimates for sampler_params.

    Side effects: prints the mean loglik and mean entropy, and writes a density
    plot of the samples to "approximating_dist.png" on every call.
    """
    # NOTE(review): rs is captured from the enclosing scope and not re-seeded here,
    # so (if it is a RandomState) successive calls presumably see different draws.
    samples, likelihood_estimates, entropy_estimates = sample(sampler_params, rs, num_samples)
    # `.value` presumably unwraps autograd-traced values for printing/plotting,
    # since this function is differentiated via value_and_grad below.
    print "Mean loglik:", np.mean(likelihood_estimates.value),\
          "Mean entropy:", np.mean(entropy_estimates.value)
    plot_density(samples.value, "approximating_dist.png")
    return np.mean(likelihood_estimates + entropy_estimates)

# Returns (objective value, gradient w.r.t. sampler_params) in one call.
ml_and_grad = value_and_grad(get_batch_marginal_likelihood_estimate)

# Optimize Langevin parameters.
for i in xrange(num_sampler_optimization_steps):
    ml, dml = ml_and_grad(sampler_params)
    print "Iter:", i, "log marginal likelihood:", ml, "avg gradient size: ", np.mean(np.abs(dml))
    print "Gradient power:", sigmoid(parser.get(sampler_params, 'invsig_gradient_power'))
    # NOTE(review): as shown, sampler_params is never updated with dml, so every
    # iteration evaluates the same parameters. Presumably an update step follows
    # beyond this view — confirm.
np.ones(num_langevin_steps)) init_log_noise_sizes = np.log(init_langevin_noise_size * np.ones(num_langevin_steps)) init_log_gradient_scales = np.log(np.ones((1, D))) sample_and_run_langevin, parser = build_langevin_sampler( prior_func, D, num_langevin_steps, approx=True) sampler_params = np.zeros(len(parser)) parser.put(sampler_params, 'mean', init_mean) parser.put(sampler_params, 'log_stddev', init_stddevs) parser.put(sampler_params, 'log_stepsizes', init_log_stepsizes) parser.put(sampler_params, 'log_noise_sizes', init_log_noise_sizes) parser.put(sampler_params, 'log_gradient_scales', init_log_gradient_scales) parser.put(sampler_params, 'invsig_gradient_power', inv_sigmoid(init_gradient_power)) rs = np.random.npr.RandomState(0) def batch_marginal_likelihood_estimate(sampler_params): samples, likelihood_estimates, entropy_estimates = sample_and_run_langevin( sampler_params, rs, num_samples) print "Mean loglik:", np.mean( likelihood_estimates.value), "Mean entropy:", np.mean( entropy_estimates.value) fig = plt.figure(1) fig.clf() ax = fig.add_subplot(111) plot_images(samples.value, ax, ims_per_row=images_per_row) plt.savefig('samples.png')