Ejemplo n.º 1
0
 def integrator(self, metric):
     """Return a leapfrog integrator (step size 0.1) for a quartic test potential.

     The underlying ``GaussianEuclideanMetricSystem`` uses the supplied
     ``metric``, the negative log density ``0.125 * sum(q**4)`` and its
     analytic gradient ``0.5 * q**3``.
     """

     def quartic_neg_log_dens(q):
         return 0.125 * np.sum(q**4)

     def quartic_grad_neg_log_dens(q):
         # Elementwise derivative of 0.125 * q**4.
         return 0.5 * q**3

     quartic_system = systems.GaussianEuclideanMetricSystem(
         neg_log_dens=quartic_neg_log_dens,
         grad_neg_log_dens=quartic_grad_neg_log_dens,
         metric=metric,
     )
     step_size = 0.1
     return integrators.LeapfrogIntegrator(quartic_system, step_size)
Ejemplo n.º 2
0
 def integrator(self, metric):
     """Return a leapfrog integrator (step size 0.25) for a quadratic test potential.

     The underlying ``EuclideanMetricSystem`` uses the supplied ``metric``,
     the negative log density ``0.5 * sum(q**2)`` and its gradient ``q``.
     """

     def quadratic_neg_log_dens(q):
         return 0.5 * np.sum(q**2)

     def quadratic_grad_neg_log_dens(q):
         # Derivative of 0.5 * q**2 is q itself.
         return q

     quadratic_system = systems.EuclideanMetricSystem(
         neg_log_dens=quadratic_neg_log_dens,
         grad_neg_log_dens=quadratic_grad_neg_log_dens,
         metric=metric,
     )
     step_size = 0.25
     return integrators.LeapfrogIntegrator(quadratic_system, step_size)
Ejemplo n.º 3
0
 def __init__(self):
     """Set up a zero-potential leapfrog integrator and random chain states.

     ``self.states`` maps each entry of ``SIZES`` to ``N_STATE`` chain states
     whose position and momentum are standard-normal draws of that size.
     ``self.h_diff_tol`` is set to 1e-10.
     """
     # NOTE(review): self.rng is used below but not set here; presumably
     # the base-class __init__ provides it — confirm against the parent.
     super().__init__()
     flat_system = systems.GaussianEuclideanMetricSystem(
         lambda q: 0, grad_neg_log_dens=lambda q: 0 * q)
     self.integrator = integrators.LeapfrogIntegrator(flat_system, 0.5)

     def random_states(size):
         # One draw of shape (N_STATE, 2, size); each row splits into pos/mom.
         draws = self.rng.standard_normal((N_STATE, 2, size))
         return [states.ChainState(pos=pos, mom=mom, dir=1)
                 for pos, mom in draws]

     self.states = {size: random_states(size) for size in SIZES}
     self.h_diff_tol = 1e-10
Ejemplo n.º 4
0
 def __init__(self):
     """Build (integrator, states) pairs for every size and generated metric.

     For each entry of ``SIZES`` and each metric from ``_generate_metrics``,
     a leapfrog integrator (step size 0.1) for the quartic potential
     ``0.125 * sum(q**4)`` is paired with ``N_STATE`` random chain states.
     The pairs are forwarded to the base class with ``h_diff_tol=1e-2``.
     """
     rng = np.random.RandomState(SEED)

     def quartic_integrator(metric):
         system = systems.GaussianEuclideanMetricSystem(
             neg_log_dens=lambda q: 0.125 * np.sum(q**4),
             metric=metric,
             grad_neg_log_dens=lambda q: 0.5 * q**3)
         return integrators.LeapfrogIntegrator(system, 0.1)

     def random_states(size):
         # Drawn after the metric for this pair, matching the original RNG
         # consumption order.
         draws = rng.standard_normal((N_STATE, 2, size))
         return [ChainState(pos=pos, mom=mom, dir=1) for pos, mom in draws]

     pairs = [
         (quartic_integrator(metric), random_states(size))
         for size in SIZES
         for metric in _generate_metrics(rng, size)
     ]
     super().__init__(pairs, h_diff_tol=1e-2)
Ejemplo n.º 5
0
 def integrator(self, metric):
     """Return a leapfrog integrator (step size 0.5) for a zero potential.

     The underlying ``GaussianEuclideanMetricSystem`` uses the supplied
     ``metric``, a negative log density that is identically 0, and a
     gradient of zeros shaped like ``q``.
     """

     def zero_neg_log_dens(q):
         return 0

     def zero_grad_neg_log_dens(q):
         # Multiplying by zero keeps the shape (and type) of q.
         return 0 * q

     flat_system = systems.GaussianEuclideanMetricSystem(
         neg_log_dens=zero_neg_log_dens,
         grad_neg_log_dens=zero_grad_neg_log_dens,
         metric=metric,
     )
     step_size = 0.5
     return integrators.LeapfrogIntegrator(flat_system, step_size)