def _expect(n_chains, log_pdf, expected_fun, pars, σ, *expected_fun_args):
    """Evaluate ``expected_fun`` on the samples and reduce it to statistics.

    Args:
        n_chains: Number of chains. When not ``None``, the values returned
            by ``expected_fun`` are reshaped to ``(n_chains, -1)`` before
            the statistics are computed; when ``None`` they are used as-is.
        log_pdf: Not used inside this function body.
        expected_fun: Callable invoked as
            ``expected_fun(pars, σ, *expected_fun_args)``.
        pars: First argument forwarded to ``expected_fun``.
        σ: Second argument forwarded to ``expected_fun``.
        *expected_fun_args: Extra positional arguments forwarded to
            ``expected_fun``.

    Returns:
        A pair ``(mean, stats)`` where ``stats`` is the object returned by
        ``mpi_statistics`` and ``mean`` is its ``.mean`` attribute.
    """
    local_values = expected_fun(pars, σ, *expected_fun_args)

    if n_chains is None:
        per_chain = local_values
    else:
        per_chain = local_values.reshape((n_chains, -1))

    # The statistics routine receives the transposed array.
    stats = mpi_statistics(per_chain.T)
    return stats.mean, stats
def _expect_fwd(n_chains, log_pdf, expected_fun, pars, σ, *expected_fun_args):
    """Compute the same outputs as ``_expect`` plus values saved for later use.

    NOTE(review): the name suggests this is the forward rule of a custom
    VJP whose primal is ``_expect`` -- confirm against the registration site.

    Args:
        n_chains: Number of chains. When not ``None``, the values returned
            by ``expected_fun`` are reshaped to ``(n_chains, -1)`` for the
            statistics only; the centred values keep the original shape.
        log_pdf: Not used inside this function body.
        expected_fun: Callable invoked as
            ``expected_fun(pars, σ, *expected_fun_args)``.
        pars: First argument forwarded to ``expected_fun``.
        σ: Second argument forwarded to ``expected_fun``.
        *expected_fun_args: Extra positional arguments forwarded to
            ``expected_fun``.

    Returns:
        A pair ``(outputs, saved)``:

        * ``outputs`` is ``(mean, stats)``, with ``stats`` the object
          returned by ``mpi_statistics`` and ``mean`` its ``.mean``.
        * ``saved`` is ``(pars, σ, expected_fun_args, centred)``, where
          ``centred`` is the un-reshaped ``expected_fun`` output minus
          ``mean``.
    """
    local_values = expected_fun(pars, σ, *expected_fun_args)

    per_chain = (
        local_values
        if n_chains is None
        else local_values.reshape((n_chains, -1))
    )

    # The statistics routine receives the transposed array.
    stats = mpi_statistics(per_chain.T)
    mean_value = stats.mean

    # Use the baseline trick to reduce the variance: subtract the mean from
    # the (un-reshaped) local values.
    centred = local_values - mean_value

    outputs = (mean_value, stats)
    saved = (pars, σ, expected_fun_args, centred)
    return outputs, saved