def _preprocess_init_state(self, init_state):
    """Make sure initial state is a ChainState and has momentum."""
    # A bare position array is promoted to a full chain state with no
    # momentum; the momentum is then sampled below.
    if isinstance(init_state, np.ndarray):
        init_state = ChainState(pos=init_state, mom=None, dir=1)
    else:
        state_is_valid = (
            isinstance(init_state, ChainState) and 'mom' in init_state)
        if not state_is_valid:
            raise TypeError(
                'init_state should be an array or `ChainState` with '
                '`mom` attribute.')
    # Fill in a missing momentum by sampling from the system's
    # momentum distribution conditioned on the position.
    if init_state.mom is None:
        init_state.mom = self.system.sample_momentum(init_state, self.rng)
    return init_state
def __init__(self):
    """Build constrained leapfrog test cases on circle-constrained systems.

    For each size > 1 and generated metric, a test case is created for
    both the quasi-Newton and full Newton projection solvers.
    """
    rng = np.random.RandomState(SEED)
    cases = []
    solver_choices = (
        solvers.solve_projection_onto_manifold_quasi_newton,
        solvers.solve_projection_onto_manifold_newton,
    )
    for size in (s for s in SIZES if s > 1):
        for metric in _generate_metrics(rng, size):
            for proj_solver in solver_choices:
                # Quartic potential with the first two coordinates
                # constrained to the unit circle.
                system = systems.DenseConstrainedEuclideanMetricSystem(
                    neg_log_dens=lambda q: 0.125 * np.sum(q**4),
                    metric=metric,
                    grad_neg_log_dens=lambda q: 0.5 * q**3,
                    constr=lambda q: q[0:1]**2 + q[1:2]**2 - 1.,
                    jacob_constr=lambda q: np.concatenate(
                        [2 * q[0:1], 2 * q[1:2],
                         np.zeros(q.shape[0] - 2)])[None])
                integrator = integrators.ConstrainedLeapfrogIntegrator(
                    system, 0.1, projection_solver=proj_solver)
                # Positions satisfy the constraint by construction:
                # (cos(theta), sin(theta)) plus free remaining coords.
                angles = rng.uniform(size=N_STATE) * 2 * np.pi
                free_coords = rng.standard_normal((N_STATE, size - 2))
                state_list = []
                for theta, q in zip(angles, free_coords):
                    state = ChainState(
                        pos=np.concatenate(
                            [np.array([np.cos(theta), np.sin(theta)]), q]),
                        mom=None, dir=1)
                    state.mom = system.sample_momentum(state, rng)
                    state_list.append(state)
                cases.append((integrator, state_list))
    super().__init__(cases, h_diff_tol=1e-2)
def init_state_list(self, rng, size_more_than_one, metric):
    """Generate states with the first coordinate pinned to zero.

    Momenta are premultiplied by the metric so they lie in the
    cotangent space induced by it.
    """
    draws = rng.standard_normal((N_STATE, 2, size_more_than_one - 1))
    states = []
    for q, p in draws:
        pos = np.concatenate([np.array([0.0]), q])
        mom = metric @ np.concatenate([np.array([0.0]), p])
        states.append(ChainState(pos=pos, mom=mom, dir=1))
    return states
def init_state_list(self, rng, size_more_than_one, system):
    """Generate circle-constrained states with system-sampled momenta.

    The first two position coordinates lie on the unit circle; the
    remaining coordinates are standard normal draws.
    """
    angles = rng.uniform(size=N_STATE) * 2 * np.pi
    free_coords = rng.standard_normal((N_STATE, size_more_than_one - 2))
    states = []
    for theta, q in zip(angles, free_coords):
        state = ChainState(
            pos=np.concatenate(
                [np.array([np.cos(theta), np.sin(theta)]), q]),
            mom=None,
            dir=1,
        )
        # Momentum is drawn conditional on the (constraint-satisfying)
        # position, matching how chains are normally initialised.
        state.mom = system.sample_momentum(state, rng)
        states.append(state)
    return states
def __init__(self):
    """Build implicit leapfrog test cases on a Gaussian Riemannian system.

    The metric is the identity with a zero vector-Jacobian product, so
    the Riemannian system reduces to a standard Gaussian target.
    """
    rng = np.random.RandomState(SEED)
    cases = []
    for size in SIZES:
        system = systems.DenseRiemannianMetricSystem(
            lambda q: 0.5 * np.sum(q**2),
            grad_neg_log_dens=lambda q: q,
            metric_func=lambda q: np.identity(q.shape[0]),
            vjp_metric_func=lambda q: lambda m: np.zeros_like(q))
        integrator = integrators.ImplicitLeapfrogIntegrator(system, 0.5)
        draws = rng.standard_normal((N_STATE, 2, size))
        state_list = [ChainState(pos=q, mom=p, dir=1) for q, p in draws]
        cases.append((integrator, state_list))
    super().__init__(cases, h_diff_tol=5e-3)
def __init__(self):
    """Build leapfrog test cases on quartic Euclidean-metric systems.

    One case per (size, metric) pair, each with a list of randomly
    initialised states.
    """
    rng = np.random.RandomState(SEED)
    cases = []
    for size in SIZES:
        for metric in _generate_metrics(rng, size):
            system = systems.GaussianEuclideanMetricSystem(
                neg_log_dens=lambda q: 0.125 * np.sum(q**4),
                metric=metric,
                grad_neg_log_dens=lambda q: 0.5 * q**3)
            integrator = integrators.LeapfrogIntegrator(system, 0.1)
            draws = rng.standard_normal((N_STATE, 2, size))
            state_list = [ChainState(pos=q, mom=p, dir=1) for q, p in draws]
            cases.append((integrator, state_list))
    super().__init__(cases, h_diff_tol=1e-2)
def __init__(self):
    """Build constrained leapfrog cases on a linearly-constrained Gaussian.

    With a flat density and a linear constraint on the first coordinate
    the integrator should be exact, hence the very tight tolerance.
    """
    rng = np.random.RandomState(SEED)
    cases = []
    for size in (s for s in SIZES if s > 1):
        system = systems.GaussianDenseConstrainedEuclideanMetricSystem(
            lambda q: 0.,
            grad_neg_log_dens=lambda q: 0. * q,
            constr=lambda q: q[:1],
            jacob_constr=lambda q: np.identity(q.shape[0])[:1])
        integrator = integrators.ConstrainedLeapfrogIntegrator(system, 0.5)
        zero = np.array([0.])
        # First coordinate of both position and momentum is zeroed so
        # initial states satisfy the constraint.
        state_list = [
            ChainState(pos=np.concatenate([zero, q]),
                       mom=np.concatenate([zero, p]),
                       dir=1)
            for q, p in rng.standard_normal((N_STATE, 2, size - 1))
        ]
        cases.append((integrator, state_list))
    super().__init__(cases, h_diff_tol=1e-10)
def __init__(self):
    """Build constrained leapfrog cases on Gaussian densities with a
    linear constraint on the first coordinate, over all generated metrics.
    """
    rng = np.random.RandomState(SEED)
    cases = []
    for size in (s for s in SIZES if s > 1):
        for metric in _generate_metrics(rng, size):
            system = systems.DenseConstrainedEuclideanMetricSystem(
                neg_log_dens=lambda q: 0.5 * np.sum(q**2),
                metric=metric,
                grad_neg_log_dens=lambda q: q,
                constr=lambda q: q[:1],
                jacob_constr=lambda q: np.eye(1, q.shape[0], 0))
            integrator = integrators.ConstrainedLeapfrogIntegrator(
                system, 0.1)
            zero = np.array([0.])
            # Momenta premultiplied by the metric; constrained first
            # coordinate pinned to zero in both position and momentum.
            state_list = [
                ChainState(pos=np.concatenate([zero, q]),
                           mom=metric @ np.concatenate([zero, p]),
                           dir=1)
                for q, p in rng.standard_normal((N_STATE, 2, size - 1))
            ]
            cases.append((integrator, state_list))
    super().__init__(cases, h_diff_tol=1e-2)
def init_state_list(rng, size):
    """Draw N_STATE unconstrained states with normal position/momentum."""
    states = []
    for q, p in rng.standard_normal((N_STATE, 2, size)):
        states.append(ChainState(pos=q, mom=p, dir=1))
    return states
def _sample_chain(self, rng, n_sample, init_state, trace_funcs, chain_index,
                  parallel_chains, memmap_enabled=False, memmap_path=None,
                  monitor_stats=None):
    """Sample a single Markov chain, recording traces and statistics.

    Args:
        rng: Random number generator passed to each transition's `sample`.
        n_sample: Number of samples (chain iterations) to draw.
        init_state: Initial chain state — a `ChainState` or a dictionary of
            state variable values; must contain every variable required by
            the registered transitions.
        trace_funcs: Sequence of functions mapping a state to a dictionary
            of values to record per sample.
        chain_index: Index of this chain (used in progress-bar labels and
            memory-map filenames); may be None for a single chain.
        parallel_chains: Whether this chain runs in a worker process, which
            changes the progress-bar backend and return format.
        memmap_enabled: If True, traces and statistics are stored in
            memory-mapped files under `memmap_path`.
        memmap_path: Directory for memory-mapped files.
        monitor_stats: Optional sequence of `(trans_key, stats_key)` pairs
            whose running means are shown on the progress bar.

    Returns:
        If `parallel_chains and memmap_enabled`: tuple of
        `(trace_filenames, stats_filenames, sample_index)`, otherwise
        `(state, traces, chain_stats, sample_index)` where `sample_index`
        is the number of completed samples (== n_sample unless interrupted).

    Raises:
        ValueError: If `init_state` is missing a required state variable.
        TypeError: If `init_state` is neither an array-backed `ChainState`
            nor a dictionary.
        RuntimeError: If the arrays to return from a parallel chain exceed
            the ~2 GiB process-result (pickle) size limit.
    """
    # Validate the initial state provides every variable the transitions
    # operate on before allocating any trace storage.
    for trans_key, transition in self.transitions.items():
        for var_key in transition.state_variables:
            if var_key not in init_state:
                # Fixed garbled message ('does contain have' ->
                # 'does not contain').
                raise ValueError(
                    f'init_state does not contain {var_key} value '
                    f'required by {trans_key} transition.')
    if not isinstance(init_state, (ChainState, dict)):
        raise TypeError(
            'init_state should be a dictionary or `ChainState`.')
    state = (ChainState(**init_state)
             if isinstance(init_state, dict) else init_state)
    chain_stats = self._init_chain_stats(
        n_sample, memmap_enabled, memmap_path, chain_index)
    # Initialise chain trace arrays, shaped and typed from an initial
    # evaluation of each trace function on the starting state.
    traces = {}
    for trace_func in trace_funcs:
        for key, val in trace_func(state).items():
            val = np.array(val) if np.isscalar(val) else val
            # NaN-fill float traces so unwritten entries are detectable;
            # integer/bool traces are zero-filled.
            init = np.nan if np.issubdtype(val.dtype, np.inexact) else 0
            if memmap_enabled:
                filename = self._generate_memmap_filename(
                    memmap_path, 'trace', key, chain_index)
                traces[key] = self._open_new_memmap(
                    filename, (n_sample,) + val.shape, val.dtype, init)
            else:
                traces[key] = np.full(
                    (n_sample,) + val.shape, init, val.dtype)
    total_return_nbytes = get_size(chain_stats) + get_size(traces)
    # Check if running in parallel and if total number of bytes to be
    # returned exceeds pickle limit
    if parallel_chains and total_return_nbytes > 2**31 - 1:
        raise RuntimeError(
            f'Total number of bytes allocated for arrays to be returned '
            f'({total_return_nbytes / 2**30:.2f} GiB) exceeds size limit '
            f'for returning results of a process (2 GiB). Try rerunning '
            f'with chain memory-mapping enabled (`memmap_enabled=True`).')
    if TQDM_AVAILABLE:
        kwargs = {
            'desc': f'Chain {0 if chain_index is None else chain_index}',
            'unit': 'it',
            'dynamic_ncols': True,
        }
        if parallel_chains:
            # Position keeps per-chain bars on separate terminal rows.
            sample_range = tqdm_auto.trange(
                n_sample, **kwargs, position=chain_index)
        else:
            sample_range = tqdm.trange(n_sample, **kwargs)
    else:
        sample_range = range(n_sample)
    # Pre-bind so sample_index is defined even when n_sample == 0 or an
    # interrupt arrives before the first iteration (was a NameError).
    sample_index = -1
    try:
        for sample_index in sample_range:
            for trans_key, transition in self.transitions.items():
                state, trans_stats = transition.sample(state, rng)
                if trans_stats is not None:
                    if trans_key not in chain_stats:
                        logger.warning(
                            f'Transition {trans_key} returned statistics '
                            f'but has no `statistic_types` attribute.')
                        # Skip recording: chain_stats has no entry for
                        # this transition (previously raised KeyError).
                        continue
                    for key, val in trans_stats.items():
                        if key in chain_stats[trans_key]:
                            chain_stats[trans_key][key][sample_index] = val
            for trace_func in trace_funcs:
                for key, val in trace_func(state).items():
                    traces[key][sample_index] = val
            if TQDM_AVAILABLE and monitor_stats is not None:
                postfix_stats = {}
                for (trans_key, stats_key) in monitor_stats:
                    if (trans_key not in chain_stats
                            or stats_key not in chain_stats[trans_key]):
                        logger.warning(
                            f'Statistics key pair {(trans_key, stats_key)}'
                            f' to be monitored is not valid.')
                        # Skip invalid pair instead of indexing into a
                        # missing entry (previously raised KeyError).
                        continue
                    print_key = f'mean({stats_key})'
                    postfix_stats[print_key] = np.mean(
                        chain_stats[trans_key][stats_key][:sample_index + 1])
                sample_range.set_postfix(postfix_stats)
    except KeyboardInterrupt:
        # sample_index currently names the interrupted (incomplete)
        # iteration, which equals the number of completed samples; clamp
        # to 0 if interrupted before the loop started.
        sample_index = max(sample_index, 0)
        # Flush memory-mapped arrays so partial results are persisted.
        if memmap_enabled:
            for trace in traces.values():
                trace.flush()
            for trans_stats in chain_stats.values():
                for stat in trans_stats.values():
                    stat.flush()
    else:
        # If not interrupted increment sample_index so that it equals
        # n_sample to flag chain completed sampling
        sample_index += 1
    if parallel_chains and memmap_enabled:
        # Memmaps cannot be pickled back to the parent process; return
        # their filenames instead.
        trace_filenames = self._memmaps_to_filenames(traces)
        stats_filenames = self._memmaps_to_filenames(chain_stats)
        return trace_filenames, stats_filenames, sample_index
    return state, traces, chain_stats, sample_index