def get_basis(self):
    """Build the computational basis configurations handled by this MPI rank.

    The ``self.lDim**self.N`` integer labels are split across MPI ranks with
    ``mpi.distribute_sampling``. This rank's share starts at
    ``mpi.first_sample_id()``. With pmap enabled, the share is further split
    over the local devices.

    Each device slot receives ``ceil(myNumStates / numDevices)`` labels. The
    last device's entry in ``self.numStatesPerDevice`` is reduced so that the
    counts add up to the rank's true share. The trailing labels in the
    padded range are presumably ignored by callers.

    Side effects:
        Sets ``self.numStatesPerDevice`` (an array when pmap is used, otherwise
        a scalar) and ``self.basis``.

    Returns:
        ``self.basis``, with shape ``intReps.shape + (self.N,)``.
    """
    # distribute_sampling must run before first_sample_id, which reports the
    # offset of this rank's share.
    localCount = mpi.distribute_sampling(self.lDim**self.N)
    offset = mpi.first_sample_id()

    pmapped = global_defs.usePmap
    numDevices = global_defs.device_count() if pmapped else 1

    # Ceiling division: every device slot gets the same padded chunk length.
    chunk = (localCount + numDevices - 1) // numDevices
    counts = [chunk] * numDevices
    # Trim the overshoot from the last device so the counts sum to localCount.
    counts[-1] -= numDevices * chunk - localCount
    counts = jnp.array(counts)

    paddedTotal = numDevices * counts[0]
    self.numStatesPerDevice = counts if pmapped else counts[0]

    intReps = jnp.arange(offset, offset + paddedTotal)
    if pmapped:
        # One row of integer labels per local device.
        intReps = intReps.reshape((global_defs.device_count(), -1))

    self.basis = jnp.zeros(intReps.shape + (self.N, ), dtype=np.int32)
    if self.lDim != 2:
        self.basis = self._get_basis_pmapd(self.basis, intReps, self.lDim)
    else:
        self.basis = self._get_basis_ldim2_pmapd(self.basis, intReps)
    return self.basis
def _get_samples_gen(self, net, numSamples):
    """Draw samples directly from the network's own sampler.

    The requested number of samples is first converted to this process's
    share via ``mpi.distribute_sampling``. That call is given the number of
    local devices (1 without pmap). A compiled wrapper around
    ``sampler.sample(numSamples, key)`` is cached in
    ``self._get_samples_gen_jitd``, keyed by the string form of that share.
    This way a new compilation happens only for a new sample count.

    Side effects:
        Advances ``self.key``.

    Returns:
        Whatever ``net.get_sampler_net().sample`` returns. With pmap it is
        presumably stacked along a leading device axis.
    """
    pmapped = global_defs.usePmap
    localDevices = global_defs.device_count() if pmapped else 1
    numSamples = mpi.distribute_sampling(numSamples, localDevices=localDevices)
    cacheKey = str(numSamples)

    cache = self._get_samples_gen_jitd
    if cacheKey not in cache:
        sampleFn = lambda x, y, z: x.sample(y, z)
        if pmapped:
            # The sampler net and the sample count are broadcast to all
            # devices; only the PRNG keys are mapped over the device axis.
            cache[cacheKey] = global_defs.pmap_for_my_devices(
                sampleFn, static_broadcasted_argnums=(1, ), in_axes=(None, None, 0))
        else:
            cache[cacheKey] = global_defs.jit_for_my_device(
                sampleFn, static_argnums=(1, ))

    if pmapped:
        # Split the first device key into 2 * device_count keys. The first
        # half becomes the new per-device state; the second half is consumed
        # by this call.
        numDevices = global_defs.device_count()
        freshKeys = random.split(self.key[0], 2 * numDevices)
        self.key = freshKeys[:numDevices]
        useKey = freshKeys[numDevices:]
    else:
        useKey, self.key = random.split(self.key)

    return cache[cacheKey](net.get_sampler_net(), numSamples, useKey)
def test_mean(self):
    """Check that ``mpi.global_mean`` of the per-rank chunks equals the mean
    of the full data set along axis 0."""
    numDevices = global_defs.device_count()
    data = jnp.array(
        np.arange(720 * 4 * numDevices).reshape((numDevices * 720, 4)))
    myNumSamples = mpi.distribute_sampling(numDevices * 720)
    # Each rank takes its contiguous slice. get_shape adds the device axis
    # when pmap is in use.
    myData = data[mpi.rank * myNumSamples:(mpi.rank + 1) *
                  myNumSamples].reshape(get_shape((-1, 4)))
    # Compare the largest absolute deviation. A signed sum would let any
    # negative error pass, and would let errors of opposite sign cancel.
    deviation = jnp.max(
        jnp.abs(mpi.global_mean(myData) - jnp.mean(data, axis=0)))
    self.assertTrue(deviation < 1e-10)
def __init__(self,
             key,
             updateProposer,
             sampleShape,
             numChains=1,
             updateProposerArg=None,
             numSamples=100,
             thermalizationSweeps=10,
             sweepSteps=10):
    """Set up the MCMCSampler.

    Args:
        * ``key``: A ``jax.random.PRNGKey``. When pmap is enabled it is split \
          into one key per local device.
        * ``updateProposer``: Callable that proposes MCMC updates. It is invoked as \
          ``updateProposer(key, config, **kwargs)``, where ``key`` is a \
          ``jax.random.PRNGKey``, ``config`` is a computational basis configuration, and \
          ``**kwargs`` carries optional extra arguments. It must return the proposed \
          computational basis configuration.
        * ``sampleShape``: Shape tuple of a single computational basis configuration.
        * ``numChains``: How many Markov chains run in parallel.
        * ``updateProposerArg``: Optional extra argument forwarded to ``updateProposer`` \
          as ``kwargs``.
        * ``numSamples``: Number of samples ``sample()`` returns by default.
        * ``thermalizationSweeps``: Sweeps performed to thermalize the Markov chains.
        * ``sweepSteps``: Number of update proposals per sweep.
    """
    pmapped = global_defs.usePmap

    # Chain states, with a leading device axis when pmap is used.
    leadingShape = (global_defs.device_count(), numChains) if pmapped else (numChains, )
    self.states = jnp.zeros(leadingShape + sampleShape, dtype=np.int32)

    self.updateProposer = updateProposer
    self.updateProposerArg = updateProposerArg

    self.key = jax.random.split(key, global_defs.device_count()) if pmapped else key

    self.thermalizationSweeps = thermalizationSweeps
    self.sweepSteps = sweepSteps
    self.numSamples = numSamples
    self.numChains = numChains

    # Caches of compiled sampling functions, one entry per sample count.
    self._get_samples_jitd = {}
    self._get_samples_gen_jitd = {}
def _mc_init(self, net):
    """Reset the per-chain cache and the acceptance counters.

    ``self.logPsiSq`` is set to twice ``net.real_coefficients(self.states)``.
    Judging by the name, that is presumably log|psi|^2 of the current chain
    states. ``self.numProposed`` and ``self.numAccepted`` are zeroed. They
    have shape ``(1,)``, or ``(device_count, 1)`` when pmap is enabled.
    """
    self.logPsiSq = 2. * net.real_coefficients(self.states)

    deviceAxis = (global_defs.device_count(), ) if global_defs.usePmap else ()
    counterShape = deviceAxis + (1, )

    self.numProposed = jnp.zeros(counterShape, dtype=np.int64)
    self.numAccepted = jnp.zeros(counterShape, dtype=np.int64)
def get_shape(shape):
    """Return ``shape``, prefixed with the local device count when pmap is on.

    ``shape`` must be a tuple, since it is concatenated with a tuple.
    """
    if not global_defs.usePmap:
        return shape
    return (global_defs.device_count(), ) + shape
def dc():
    """Return the number of local devices in use: ``device_count()`` with pmap, else 1."""
    return global_defs.device_count() if global_defs.usePmap else 1