Example #1
    def get_basis(self):

        # Number of basis states owned by this MPI rank, and the index of its first state
        myNumStates = mpi.distribute_sampling(self.lDim**self.N)
        myFirstState = mpi.first_sample_id()

        deviceCount = global_defs.device_count()
        if not global_defs.usePmap:
            deviceCount = 1

        # Split the local states evenly across devices (ceil division);
        # the last device absorbs the remainder.
        statesPerDevice = (myNumStates + deviceCount - 1) // deviceCount
        self.numStatesPerDevice = [statesPerDevice] * deviceCount
        self.numStatesPerDevice[-1] += myNumStates - deviceCount * statesPerDevice
        self.numStatesPerDevice = jnp.array(self.numStatesPerDevice)

        totalNumStates = deviceCount * statesPerDevice

        if not global_defs.usePmap:
            self.numStatesPerDevice = self.numStatesPerDevice[0]

        # Integer representations of this rank's basis states
        intReps = jnp.arange(myFirstState, myFirstState + totalNumStates)
        if global_defs.usePmap:
            intReps = intReps.reshape((global_defs.device_count(), -1))
        self.basis = jnp.zeros(intReps.shape + (self.N,), dtype=np.int32)
        if self.lDim == 2:
            self.basis = self._get_basis_ldim2_pmapd(self.basis, intReps)
        else:
            self.basis = self._get_basis_pmapd(self.basis, intReps, self.lDim)

        return self.basis
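The per-device split above is plain ceil division, with the last device handing back the overshoot. A minimal standalone sketch of the same arithmetic (the helper name split_across_devices is ours, not part of the library):

def split_across_devices(num_states, device_count):
    # Every device gets the ceil-divided share ...
    per_device = (num_states + device_count - 1) // device_count
    counts = [per_device] * device_count
    # ... and the last device gives back the overshoot.
    counts[-1] += num_states - device_count * per_device
    return counts

# 10 states over 4 devices -> [3, 3, 3, 1]
assert split_across_devices(10, 4) == [3, 3, 3, 1]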
Example #2
    def _get_samples_gen(self, net, numSamples):
        def dc():
            if global_defs.usePmap:
                return global_defs.device_count()
            return 1

        numSamples = mpi.distribute_sampling(numSamples, localDevices=dc())
        numSamplesStr = str(numSamples)

        # Check whether _get_samples is already compiled for the given number of samples
        if numSamplesStr not in self._get_samples_gen_jitd:
            if global_defs.usePmap:
                self._get_samples_gen_jitd[numSamplesStr] = global_defs.pmap_for_my_devices(
                    lambda x, y, z: x.sample(y, z),
                    static_broadcasted_argnums=(1,),
                    in_axes=(None, None, 0))
            else:
                self._get_samples_gen_jitd[numSamplesStr] = global_defs.jit_for_my_device(
                    lambda x, y, z: x.sample(y, z), static_argnums=(1,))

        # Split off fresh PRNG keys: one per device under pmap, a single key otherwise
        tmpKey = None
        if global_defs.usePmap:
            tmpKey = random.split(self.key[0], 2 * global_defs.device_count())
            self.key = tmpKey[:global_defs.device_count()]
            tmpKey = tmpKey[global_defs.device_count():]
        else:
            tmpKey, self.key = random.split(self.key)

        return self._get_samples_gen_jitd[numSamplesStr](net.get_sampler_net(),
                                                         numSamples, tmpKey)
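A note on the cache above: JAX keys its compilation cache on the wrapped function object, so building a fresh pmap/jit wrapper around a new lambda on every call would retrace and recompile each time. The dictionary keyed by the (static) sample count keeps one wrapped callable per value. A minimal sketch of the same memoization pattern in plain JAX (draw_uniform and _compiled are our names):

import jax
import jax.random as random

_compiled = {}  # one jitted function per sample count

def draw_uniform(key, num_samples):
    # The output shape depends on num_samples, so each value needs its
    # own trace; caching the wrapper avoids rebuilding it per call.
    if num_samples not in _compiled:
        _compiled[num_samples] = jax.jit(
            lambda k: random.uniform(k, (num_samples,)))
    return _compiled[num_samples](key)

print(draw_uniform(random.PRNGKey(0), 8).shape)  # (8,)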
Example #3
    def test_mean(self):

        data = jnp.array(
            np.arange(720 * 4 * global_defs.device_count()).reshape(
                (global_defs.device_count() * 720, 4)))
        myNumSamples = mpi.distribute_sampling(global_defs.device_count() * 720)

        # Slice out this rank's share of the samples
        myData = data[mpi.rank * myNumSamples:(mpi.rank + 1) * myNumSamples].reshape(
            get_shape((-1, 4)))

        # Compare against the mean of the full, undistributed data set; summing
        # absolute deviations prevents errors of opposite sign from cancelling.
        self.assertTrue(
            jnp.sum(jnp.abs(mpi.global_mean(myData) - jnp.mean(data, axis=0))) < 1e-10)
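Assuming mpi.global_mean averages over the local sample axis and then over all MPI ranks (the test above relies on exactly that), the operation reduces to a sum-allreduce divided by the global sample count. A minimal mpi4py sketch of such a reduction, not the library's actual implementation, and ignoring the extra device axis that get_shape may add:

import numpy as np
from mpi4py import MPI

def global_mean(local_samples, comm=MPI.COMM_WORLD):
    # Sum over the local sample axis, then allreduce sums and counts.
    local_sum = np.asarray(local_samples).sum(axis=0)
    total_sum = np.empty_like(local_sum)
    comm.Allreduce(local_sum, total_sum, op=MPI.SUM)
    total_count = comm.allreduce(len(local_samples), op=MPI.SUM)
    return total_sum / total_count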
Example #4
    def __init__(self,
                 key,
                 updateProposer,
                 sampleShape,
                 numChains=1,
                 updateProposerArg=None,
                 numSamples=100,
                 thermalizationSweeps=10,
                 sweepSteps=10):
        """Initializes the MCMCSampler class.

        Args:

        * ``key``: An instance of ``jax.random.PRNGKey``.
        * ``updateProposer``: A function to propose updates for the MCMC algorithm. \
        It is called as ``updateProposer(key, config, **kwargs)``, where ``key`` is an instance of \
        ``jax.random.PRNGKey``, ``config`` is a computational basis configuration, and ``**kwargs`` \
        are optional additional arguments. The function is expected to return a computational basis \
        state that is used as the update proposal in the MCMC algorithm.
        * ``sampleShape``: Shape of computational basis configurations.
        * ``numChains``: Number of Markov chains run in parallel.
        * ``updateProposerArg``: An optional argument that will be passed to the ``updateProposer`` \
        as ``kwargs``.
        * ``numSamples``: Default number of samples to be returned by the ``sample()`` member function.
        * ``thermalizationSweeps``: Number of sweeps to perform for thermalization of the Markov chain.
        * ``sweepSteps``: Number of proposed updates per sweep.
        """

        stateShape = (numChains, ) + sampleShape
        if global_defs.usePmap:
            stateShape = (global_defs.device_count(), ) + stateShape
        self.states = jnp.zeros(stateShape, dtype=np.int32)

        self.updateProposer = updateProposer
        self.updateProposerArg = updateProposerArg

        self.key = key
        if global_defs.usePmap:
            self.key = jax.random.split(self.key, global_defs.device_count())
        self.thermalizationSweeps = thermalizationSweeps
        self.sweepSteps = sweepSteps
        self.numSamples = numSamples

        self.numChains = numChains

        # jit'd member functions; one compiled version per number of samples
        self._get_samples_jitd = {}
        self._get_samples_gen_jitd = {}
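A hedged usage sketch: the spin-flip proposer below is our illustration, not one of the library's built-in proposers; it assumes configurations are 0/1 integer arrays of length sampleShape[0]:

import jax

def propose_spin_flip(key, config, **kwargs):
    # Flip one randomly chosen site of a 0/1 configuration.
    idx = jax.random.randint(key, (), 0, config.shape[-1])
    return config.at[idx].set(1 - config[idx])

sampler = MCMCSampler(jax.random.PRNGKey(0),
                      propose_spin_flip,
                      sampleShape=(16,),
                      numChains=4,
                      numSamples=1000,
                      thermalizationSweeps=20,
                      sweepSteps=16)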
Example #5
    def _mc_init(self, net):

        # Initialize log|psi|^2 for the current Markov chain states
        self.logPsiSq = 2. * net.real_coefficients(self.states)

        shape = (1,)
        if global_defs.usePmap:
            shape = (global_defs.device_count(),) + shape

        # Acceptance statistics, tracked per device
        self.numProposed = jnp.zeros(shape, dtype=np.int64)
        self.numAccepted = jnp.zeros(shape, dtype=np.int64)
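The two counters above are the raw material for an acceptance-rate diagnostic; a minimal sketch of how such a rate could be aggregated over the leading device axis (our helper, not the library's):

import jax.numpy as jnp

def acceptance_rate(num_accepted, num_proposed):
    # Aggregate over all (device) axes before dividing.
    return jnp.sum(num_accepted) / jnp.sum(num_proposed)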
Example #6
def get_shape(shape):
    # Prepend a device axis when computations are pmap'd across local devices
    if global_defs.usePmap:
        return (global_defs.device_count(),) + shape
    return shape
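As used in Example #3: under pmap the reshape target grows a leading device axis, so with, say, two local devices a flat (1440, 4) sample array becomes one (720, 4) block per device:

myData = data.reshape(get_shape((-1, 4)))  # (1440, 4) -> (2, 720, 4) when usePmap is set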
Example #7
def dc():
    # Number of local devices to distribute work over; 1 when pmap is disabled
    if global_defs.usePmap:
        return global_defs.device_count()
    return 1
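This is the same helper that appears inline in Example #2: it tells mpi.distribute_sampling how many local devices share one rank's load, as in

numSamples = mpi.distribute_sampling(numSamples, localDevices=dc())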