示例#1
0
 def regularize_weights(self):
     """Regularize the importance weights stored on ``self``, in place.

     The strategy is selected by ``self.pareto``:

     * truthy -- Pareto-smoothed importance sampling.  The smoothed
       log-weights returned by ``az.psislw`` (element 0 of its result)
       replace ``self.log_sinf_weights``.  The other returned value (the
       Pareto k-hat diagnostic) is discarded.
     * falsy -- truncated importance sampling.  Each raw weight is capped
       at ``mean(w) * N ** self.k_trunc``, where ``N`` is the number of
       weights.  The log-weights are then re-normalized so that the
       weights sum to one.

     Reads ``self.pareto``, ``self.log_sinf_weights`` and, on the clipping
     path, ``self.k_trunc``.  Writes ``self.log_sinf_weights`` and
     ``self.sinf_weights`` (the elementwise ``exp`` of the log-weights).
     """
     if self.pareto:
         self.log_sinf_weights = az.psislw(self.log_sinf_weights)[0]
     else:
         log_w = self.log_sinf_weights
         n_weights = len(log_w)
         # log(mean(w) * N**k_trunc), written as
         # log(sum w) + (k_trunc - 1) * log N to stay in log space.
         log_cap = logsumexp(log_w) + (self.k_trunc - 1) * np.log(n_weights)
         clipped = np.minimum(log_w, log_cap)
         # Re-normalize so the truncated weights sum to one.
         self.log_sinf_weights = clipped - logsumexp(clipped)
     self.sinf_weights = np.exp(self.log_sinf_weights)
示例#2
0
"""
LOO-PIT ECDF Plot
=================

_thumb: .5, .7
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

# Example dataset bundled with ArviZ.
idata = az.load_arviz_data("radon")

# Pointwise log-likelihood restricted to chain 0, then transposed.
# Presumably this yields (observation, draw), which is the layout az.psislw
# documents for its input -- TODO confirm against the dataset.
# NOTE(review): newer ArviZ releases keep the log-likelihood in a separate
# `log_likelihood` group rather than under `sample_stats`; confirm this
# attribute path against the pinned ArviZ version.
chain0_log_like = idata.sample_stats.log_likelihood.sel(chain=0)
log_like = chain0_log_like.values.T

# Leave-one-out importance ratios are 1 / p(y_i | theta), i.e. -log_like in
# log space.  Element 0 of az.psislw's result holds the smoothed log-weights.
psis_result = az.psislw(-log_like)
log_weights = psis_result[0]

az.plot_loo_pit(
    idata,
    y="y_like",
    log_weights=log_weights,
    ecdf=True,
    color="maroon",
)

plt.show()
示例#3
0
                size=num_proposal_samples)
            samples_vb_swa = np.random.multivariate_normal(
                means_vb_swa,
                np.diag(sigmas_vb_swa),
                size=num_proposal_samples)
            samples_vb_rms = np.random.multivariate_normal(
                means_vb_rms,
                np.diag(sigmas_vb_rms),
                size=num_proposal_samples)

            q_swa = stats.norm.logpdf(samples_vb_swa, means_vb_swa,
                                      sigmas_vb_swa)
            logp_swa = np.array([fit_hmc.log_prob(s) for s in samples_vb_swa])
            #logp_swa = np.array([fit_hmc.])
            log_iw_swa = logp_swa - np.sum(q_swa, axis=1)
            psis_lw_swa, K_hat_swa = psislw(log_iw_swa.T)
            print('K hat statistic for SWA')
            print(K_hat_swa)

            # VB-CLR
            q_clr = stats.norm.logpdf(samples_vb_clr2, means_vb_clr2,
                                      sigmas_vb_clr2)
            logp_clr = np.array([fit_hmc.log_prob(s) for s in samples_vb_clr2])
            log_iw_clr = logp_clr - np.sum(q_clr, axis=1)
            psis_lw_clr, K_hat_clr = psislw(log_iw_clr.T)
            print('K hat statistic for CLR')
            print(K_hat_clr)

            q_rms = stats.norm.pdf(samples_vb_rms, means_vb_rms, sigmas_vb_rms)
            logp_rms = np.array([fit_hmc.log_prob(s) for s in samples_vb_rms])
            log_iw_rms = logp_rms - np.sum(np.log(q_rms), axis=1)
            stan_vb_mean = np.mean(stan_vb_w, axis=0)
            stan_vb_cov = np.cov(stan_vb_w[:, 0], stan_vb_w[:, 1])

            params_vb_means = np.mean(stan_vb_w, axis=0)
            params_vb_std = np.std(stan_vb_w, axis=0)
            params_vb_sq = np.mean(stan_vb_w**2, axis=0)

            logq = stats.norm.pdf(stan_vb_w, params_vb_means, params_vb_std)
            logq_sum = np.sum(np.log(logq), axis=1)
            # log_joint_density = la['log_joint_density']
            stan_vb_log_joint_density = fit_vb_samples[:, K]
            log_iw = stan_vb_log_joint_density - logq_sum
            print(np.max(log_iw))
            print(log_iw.shape)

            psis_lw, K_hat_stan = psislw(log_iw.T)
            K_hat_stan_advi_list[j, n] = K_hat_stan
            print(psis_lw.shape)
            print('K hat statistic for Stan ADVI:')
            print(K_hat_stan)

    ###################### Plotting L2 norm here #################################

# Scatter the first two coordinates of the Stan ADVI draws as magenta
# circles ('mo') and save the figure.
plt.figure()
plt.plot(stan_vb_w[:, 0], stan_vb_w[:, 1], 'mo', label='STAN-ADVI')
# The series carries a label; without a legend the label never shows up
# in the saved figure.
plt.legend()
plt.savefig('vb_w_samples_mf.pdf')

# Persist the PSIS k-hat values collected for Stan ADVI in
# K_hat_stan_advi_list.  np.save appends ".npy" to this base name.
np.save(f'K_hat_linear_{datatype}_{algo_name}_{N}N', K_hat_stan_advi_list)

plt.figure()