Esempio n. 1
0
 def _preprocess_init_state(self, init_state):
     """Coerce the initial state to a `ChainState` that has a momentum.

     An array is used as the position of a new `ChainState`; any other
     input must already be a `ChainState` containing a `mom` entry. When
     the momentum is `None` it is drawn with
     `self.system.sample_momentum` using `self.rng`.

     Raises:
         TypeError: If `init_state` is neither an array nor a `ChainState`
             containing `mom`.
     """
     if isinstance(init_state, np.ndarray):
         # Bare array: treat it as the position component of a new state
         state = ChainState(pos=init_state, mom=None, dir=1)
     elif isinstance(init_state, ChainState) and 'mom' in init_state:
         state = init_state
     else:
         raise TypeError(
             'init_state should be an array or `ChainState` with '
             '`mom` attribute.')
     if state.mom is None:
         state.mom = self.system.sample_momentum(state, self.rng)
     return state
Esempio n. 2
0
 def __init__(self):
     """Build constrained leapfrog integrators with matching start states.

     For every size in `SIZES` greater than one, every metric produced by
     `_generate_metrics` and each of the two projection solvers, a system
     with a quartic negative log density and a constraint placing the
     first two position components on the unit circle is created. Each
     integrator (step size 0.1) is paired with `N_STATE` states whose
     positions satisfy the constraint and whose momenta are sampled via
     the system. The pairs are passed to the parent initialiser with
     `h_diff_tol=1e-2`.
     """

     def neg_log_dens(q):
         return 0.125 * np.sum(q**4)

     def grad_neg_log_dens(q):
         return 0.5 * q**3

     def constr(q):
         return q[0:1]**2 + q[1:2]**2 - 1.

     def jacob_constr(q):
         row = np.concatenate(
             [2 * q[0:1], 2 * q[1:2], np.zeros(q.shape[0] - 2)])
         return row[None]

     projection_solvers = (
         solvers.solve_projection_onto_manifold_quasi_newton,
         solvers.solve_projection_onto_manifold_newton,
     )
     rng = np.random.RandomState(SEED)
     integrators_and_state_lists = []
     for size in SIZES:
         if not size > 1:
             # Constraint involves two position components
             continue
         for metric in _generate_metrics(rng, size):
             for projection_solver in projection_solvers:
                 system = systems.DenseConstrainedEuclideanMetricSystem(
                     neg_log_dens=neg_log_dens,
                     metric=metric,
                     grad_neg_log_dens=grad_neg_log_dens,
                     constr=constr,
                     jacob_constr=jacob_constr)
                 integrator = integrators.ConstrainedLeapfrogIntegrator(
                     system, 0.1, projection_solver=projection_solver)
                 # Draw all angles first, then the remaining coordinates
                 angles = rng.uniform(size=N_STATE) * 2 * np.pi
                 free_coords = rng.standard_normal((N_STATE, size - 2))
                 state_list = []
                 for theta, q in zip(angles, free_coords):
                     on_circle = np.array([np.cos(theta), np.sin(theta)])
                     state_list.append(
                         ChainState(pos=np.concatenate([on_circle, q]),
                                    mom=None,
                                    dir=1))
                 # Momenta are sampled only after all positions exist
                 for state in state_list:
                     state.mom = system.sample_momentum(state, rng)
                 integrators_and_state_lists.append(
                     (integrator, state_list))
     super().__init__(integrators_and_state_lists, h_diff_tol=1e-2)
Esempio n. 3
0
 def init_state_list(self, rng, size_more_than_one, metric):
     """Return `N_STATE` states whose first position component is zero.

     Standard normal draws of length `size_more_than_one - 1` fill the
     remaining position components. The momentum is `metric` applied to a
     vector built the same way (leading zero followed by normal draws).
     """
     draws = rng.standard_normal((N_STATE, 2, size_more_than_one - 1))
     states = []
     for q, p in draws:
         pos = np.concatenate([np.array([0.0]), q])
         vel = np.concatenate([np.array([0.0]), p])
         states.append(ChainState(pos=pos, mom=metric @ vel, dir=1))
     return states
Esempio n. 4
0
 def init_state_list(self, rng, size_more_than_one, system):
     """Return `N_STATE` states with positions starting on the unit circle.

     The first two position components are `(cos(theta), sin(theta))` for
     a uniformly drawn angle and the remaining
     `size_more_than_one - 2` components are standard normal draws.
     Momenta are sampled with `system.sample_momentum`.
     """
     # Draw all angles first, then the remaining coordinates
     angles = rng.uniform(size=N_STATE) * 2 * np.pi
     free_coords = rng.standard_normal((N_STATE, size_more_than_one - 2))
     states = []
     for theta, q in zip(angles, free_coords):
         on_circle = np.array([np.cos(theta), np.sin(theta)])
         states.append(
             ChainState(
                 pos=np.concatenate([on_circle, q]),
                 mom=None,
                 dir=1,
             ))
     # Momenta are sampled only after all positions have been created
     for state in states:
         state.mom = system.sample_momentum(state, rng)
     return states
Esempio n. 5
0
 def __init__(self):
     """Build implicit leapfrog integrators for an identity-metric system.

     For each size in `SIZES` a `DenseRiemannianMetricSystem` with a
     quadratic negative log density, identity metric function and zero
     metric vector-Jacobian product is created. Each integrator (step
     size 0.5) is paired with `N_STATE` states with standard normal
     position and momentum. The pairs are passed to the parent
     initialiser with `h_diff_tol=5e-3`.
     """

     def neg_log_dens(q):
         return 0.5 * np.sum(q**2)

     def grad_neg_log_dens(q):
         return q

     def metric_func(q):
         return np.identity(q.shape[0])

     def vjp_metric_func(q):
         def vjp(m):
             return np.zeros_like(q)
         return vjp

     rng = np.random.RandomState(SEED)
     integrators_and_state_lists = []
     for size in SIZES:
         system = systems.DenseRiemannianMetricSystem(
             neg_log_dens,
             grad_neg_log_dens=grad_neg_log_dens,
             metric_func=metric_func,
             vjp_metric_func=vjp_metric_func)
         integrator = integrators.ImplicitLeapfrogIntegrator(system, 0.5)
         state_list = []
         for q, p in rng.standard_normal((N_STATE, 2, size)):
             state_list.append(ChainState(pos=q, mom=p, dir=1))
         integrators_and_state_lists.append((integrator, state_list))
     super().__init__(integrators_and_state_lists, h_diff_tol=5e-3)
Esempio n. 6
0
 def __init__(self):
     """Build leapfrog integrators for a quartic-density system.

     For each size in `SIZES` and each metric produced by
     `_generate_metrics`, a `GaussianEuclideanMetricSystem` with a quartic
     negative log density is created. Each integrator (step size 0.1) is
     paired with `N_STATE` states with standard normal position and
     momentum. The pairs are passed to the parent initialiser with
     `h_diff_tol=1e-2`.
     """

     def neg_log_dens(q):
         return 0.125 * np.sum(q**4)

     def grad_neg_log_dens(q):
         return 0.5 * q**3

     rng = np.random.RandomState(SEED)
     integrators_and_state_lists = []
     for size in SIZES:
         for metric in _generate_metrics(rng, size):
             system = systems.GaussianEuclideanMetricSystem(
                 neg_log_dens=neg_log_dens,
                 metric=metric,
                 grad_neg_log_dens=grad_neg_log_dens)
             integrator = integrators.LeapfrogIntegrator(system, 0.1)
             state_list = []
             for q, p in rng.standard_normal((N_STATE, 2, size)):
                 state_list.append(ChainState(pos=q, mom=p, dir=1))
             integrators_and_state_lists.append((integrator, state_list))
     super().__init__(integrators_and_state_lists, h_diff_tol=1e-2)
Esempio n. 7
0
 def __init__(self):
     """Build constrained leapfrog integrators for a flat-density system.

     For each size in `SIZES` greater than one, a
     `GaussianDenseConstrainedEuclideanMetricSystem` with zero negative
     log density and a constraint returning the first position component
     is created. Each integrator (step size 0.5) is paired with `N_STATE`
     states whose first position and momentum components are zero and
     whose remaining components are standard normal draws. The pairs are
     passed to the parent initialiser with `h_diff_tol=1e-10`.
     """

     def neg_log_dens(q):
         return 0.

     def grad_neg_log_dens(q):
         return 0. * q

     def constr(q):
         return q[:1]

     def jacob_constr(q):
         return np.identity(q.shape[0])[:1]

     rng = np.random.RandomState(SEED)
     integrators_and_state_lists = []
     for size in SIZES:
         if not size > 1:
             continue
         system = systems.GaussianDenseConstrainedEuclideanMetricSystem(
             neg_log_dens,
             grad_neg_log_dens=grad_neg_log_dens,
             constr=constr,
             jacob_constr=jacob_constr)
         integrator = integrators.ConstrainedLeapfrogIntegrator(system, 0.5)
         state_list = []
         for q, p in rng.standard_normal((N_STATE, 2, size - 1)):
             pos = np.concatenate([np.array([0.]), q])
             mom = np.concatenate([np.array([0.]), p])
             state_list.append(ChainState(pos=pos, mom=mom, dir=1))
         integrators_and_state_lists.append((integrator, state_list))
     super().__init__(integrators_and_state_lists, h_diff_tol=1e-10)
Esempio n. 8
0
 def __init__(self):
     """Build constrained leapfrog integrators for a quadratic density.

     For each size in `SIZES` greater than one and each metric produced
     by `_generate_metrics`, a `DenseConstrainedEuclideanMetricSystem`
     with a quadratic negative log density and a constraint returning the
     first position component is created. Each integrator (step size 0.1)
     is paired with `N_STATE` states whose first position component is
     zero and whose momentum is `metric` applied to a vector with a
     leading zero. The pairs are passed to the parent initialiser with
     `h_diff_tol=1e-2`.
     """

     def neg_log_dens(q):
         return 0.5 * np.sum(q**2)

     def grad_neg_log_dens(q):
         return q

     def constr(q):
         return q[:1]

     def jacob_constr(q):
         return np.eye(1, q.shape[0], 0)

     rng = np.random.RandomState(SEED)
     integrators_and_state_lists = []
     for size in SIZES:
         if not size > 1:
             continue
         for metric in _generate_metrics(rng, size):
             system = systems.DenseConstrainedEuclideanMetricSystem(
                 neg_log_dens=neg_log_dens,
                 metric=metric,
                 grad_neg_log_dens=grad_neg_log_dens,
                 constr=constr,
                 jacob_constr=jacob_constr)
             integrator = integrators.ConstrainedLeapfrogIntegrator(
                 system, 0.1)
             state_list = []
             for q, p in rng.standard_normal((N_STATE, 2, size - 1)):
                 pos = np.concatenate([np.array([0.]), q])
                 vel = np.concatenate([np.array([0.]), p])
                 state_list.append(
                     ChainState(pos=pos, mom=metric @ vel, dir=1))
             integrators_and_state_lists.append((integrator, state_list))
     super().__init__(integrators_and_state_lists, h_diff_tol=1e-2)
Esempio n. 9
0
def init_state_list(rng, size):
    """Return `N_STATE` chain states with standard normal draws.

    A single `(N_STATE, 2, size)` array is drawn from `rng`; for each
    entry the first row is used as the position and the second as the
    momentum of a `ChainState` with `dir=1`.
    """
    draws = rng.standard_normal((N_STATE, 2, size))
    states = []
    for q, p in draws:
        states.append(ChainState(pos=q, mom=p, dir=1))
    return states
Esempio n. 10
0
 def _sample_chain(self,
                   rng,
                   n_sample,
                   init_state,
                   trace_funcs,
                   chain_index,
                   parallel_chains,
                   memmap_enabled=False,
                   memmap_path=None,
                   monitor_stats=None):
     """Sample a single chain by repeatedly applying `self.transitions`.

     Args:
         rng: Random number generator passed to each transition's
             `sample` method.
         n_sample: Number of chain iterations to perform; also the
             leading dimension of the allocated trace arrays.
         init_state: Initial chain state, either a `ChainState` or a
             dictionary of keyword arguments used to construct one. It
             must contain every variable listed in the `state_variables`
             of each transition.
         trace_funcs: Iterable of callables which are passed the chain
             state and return a dictionary of values to record.
         chain_index: Index of the chain, used when generating
             memory-map file names and to label (and, for parallel
             chains, position) the progress bar. `None` is labelled as
             chain 0.
         parallel_chains: Whether the chain is being run as one of
             several parallel chains.
         memmap_enabled: Whether to store traces in memory-mapped files
             rather than in-memory arrays.
         memmap_path: Path passed on when generating memory-map file
             names.
         monitor_stats: Optional iterable of
             `(transition_key, statistic_key)` pairs whose running means
             are shown on the progress bar (when `tqdm` is available).
             Pairs not present in the chain statistics are reported with
             a warning and ignored.

     Returns:
         `(trace_filenames, stats_filenames, sample_index)` if both
         `parallel_chains` and `memmap_enabled` are true, otherwise
         `(state, traces, chain_stats, sample_index)`. `sample_index`
         equals `n_sample` when sampling ran to completion and is the
         index of the interrupted iteration otherwise.

     Raises:
         TypeError: If `init_state` is not a dictionary or `ChainState`.
         ValueError: If `init_state` lacks a variable required by a
             transition.
         RuntimeError: If running in parallel and the arrays to be
             returned exceed the 2 GiB limit for returning results of a
             process.
     """
     # Check the container type first so that unsupported inputs get the
     # intended error rather than failing inside the membership tests
     if not isinstance(init_state, (ChainState, dict)):
         raise TypeError(
             'init_state should be a dictionary or `ChainState`.')
     for trans_key, transition in self.transitions.items():
         for var_key in transition.state_variables:
             if var_key not in init_state:
                 raise ValueError(
                     f'init_state does not contain {var_key} value '
                     f'required by {trans_key} transition.')
     state = (ChainState(
         **init_state) if isinstance(init_state, dict) else init_state)
     chain_stats = self._init_chain_stats(n_sample, memmap_enabled,
                                          memmap_path, chain_index)
     # Initialise chain trace arrays
     traces = {}
     for trace_func in trace_funcs:
         for key, val in trace_func(state).items():
             val = np.array(val) if np.isscalar(val) else val
             # NaN marks unfilled entries of floating point traces
             init = np.nan if np.issubdtype(val.dtype, np.inexact) else 0
             if memmap_enabled:
                 filename = self._generate_memmap_filename(
                     memmap_path, 'trace', key, chain_index)
                 traces[key] = self._open_new_memmap(
                     filename, (n_sample, ) + val.shape, val.dtype, init)
             else:
                 traces[key] = np.full((n_sample, ) + val.shape, init,
                                       val.dtype)
     total_return_nbytes = get_size(chain_stats) + get_size(traces)
     # Check if running in parallel and if total number of bytes to be
     # returned exceeds pickle limit
     if parallel_chains and total_return_nbytes > 2**31 - 1:
         raise RuntimeError(
             f'Total number of bytes allocated for arrays to be returned '
             f'({total_return_nbytes / 2**30:.2f} GiB) exceeds size limit '
             f'for returning results of a process (2 GiB). Try rerunning '
             f'with chain memory-mapping enabled (`memmap_enabled=True`).')
     if TQDM_AVAILABLE:
         kwargs = {
             'desc': f'Chain {0 if chain_index is None else chain_index}',
             'unit': 'it',
             'dynamic_ncols': True,
         }
         if parallel_chains:
             sample_range = tqdm_auto.trange(n_sample,
                                             **kwargs,
                                             position=chain_index)
         else:
             sample_range = tqdm.trange(n_sample, **kwargs)
     else:
         sample_range = range(n_sample)
     # Validate the monitored statistics once up front: invalid pairs are
     # reported and dropped rather than causing a KeyError when indexed
     valid_monitor_stats = []
     if TQDM_AVAILABLE and monitor_stats is not None:
         for trans_key, stats_key in monitor_stats:
             if (trans_key in chain_stats
                     and stats_key in chain_stats[trans_key]):
                 valid_monitor_stats.append((trans_key, stats_key))
             else:
                 logger.warning(
                     f'Statistics key pair {(trans_key, stats_key)}'
                     f' to be monitored is not valid.')
     # Transitions already warned about, to avoid repeating the warning
     # on every iteration
     warned_trans_keys = set()
     # Defined before the loop so that it is bound even if no iteration
     # is started
     sample_index = 0
     try:
         for sample_index in sample_range:
             for trans_key, transition in self.transitions.items():
                 state, trans_stats = transition.sample(state, rng)
                 if trans_stats is None:
                     continue
                 if trans_key not in chain_stats:
                     # No arrays were allocated for this transition so
                     # its statistics cannot be recorded
                     if trans_key not in warned_trans_keys:
                         warned_trans_keys.add(trans_key)
                         logger.warning(
                             f'Transition {trans_key} returned statistics '
                             f'but has no `statistic_types` attribute.')
                     continue
                 for key, val in trans_stats.items():
                     if key in chain_stats[trans_key]:
                         chain_stats[trans_key][key][sample_index] = val
             for trace_func in trace_funcs:
                 for key, val in trace_func(state).items():
                     traces[key][sample_index] = val
             if TQDM_AVAILABLE and monitor_stats is not None:
                 postfix_stats = {}
                 for (trans_key, stats_key) in valid_monitor_stats:
                     print_key = f'mean({stats_key})'
                     postfix_stats[print_key] = np.mean(
                         chain_stats[trans_key][stats_key][:sample_index +
                                                           1])
                 sample_range.set_postfix(postfix_stats)
     except KeyboardInterrupt:
         # Persist whatever was recorded before the interruption
         if memmap_enabled:
             for trace in traces.values():
                 trace.flush()
             for trans_stats in chain_stats.values():
                 for stat in trans_stats.values():
                     stat.flush()
     else:
         # If not interrupted set sample_index to n_sample to flag chain
         # completed sampling (also correct when n_sample is zero)
         sample_index = n_sample
     if parallel_chains and memmap_enabled:
         trace_filenames = self._memmaps_to_filenames(traces)
         stats_filenames = self._memmaps_to_filenames(chain_stats)
         return trace_filenames, stats_filenames, sample_index
     return state, traces, chain_stats, sample_index