예제 #1
0
    def get_basis(self):
        """Build this process's share of the computational basis.

        The ``self.lDim**self.N`` integer basis labels are divided between MPI
        processes by ``mpi.distribute_sampling``; this process's labels start
        at ``mpi.first_sample_id()``. They are laid out in equally sized
        chunks, one per local device when ``global_defs.usePmap`` is set (a
        single chunk otherwise). They are then converted into length-``self.N``
        configurations by ``self._get_basis_ldim2_pmapd`` (for ``lDim == 2``)
        or by ``self._get_basis_pmapd``.

        Because all chunks have the same length, trailing chunks may contain
        padding labels beyond this process's share. ``self.numStatesPerDevice``
        records how many leading entries of each chunk are genuine.

        Sets:
            self.numStatesPerDevice: ``jnp`` array with one count per device
                under pmap, otherwise the single chunk's count.
            self.basis: the converted configurations (also returned).

        Returns:
            ``self.basis``, computed from a zero buffer of shape
            ``intReps.shape + (self.N,)``. ``intReps`` has shape
            ``(device_count, chunk)`` under pmap and ``(chunk,)`` otherwise.
        """

        myNumStates = mpi.distribute_sampling(self.lDim**self.N)
        myFirstState = mpi.first_sample_id()

        deviceCount = global_defs.device_count()
        if not global_defs.usePmap:
            deviceCount = 1

        # Every chunk gets the same (ceil-divided) length, so the labels can be
        # reshaped to (deviceCount, chunk).
        chunk = (myNumStates + deviceCount - 1) // deviceCount
        # Genuine states in device i's chunk, clamped to [0, chunk]. Putting the
        # whole shortfall on the last device is wrong when it exceeds one chunk:
        # 5 states on 4 devices must give [2, 2, 1, 0], not [2, 2, 2, -1]. The
        # latter has a negative count and counts a padding label as genuine.
        self.numStatesPerDevice = jnp.array([
            min(chunk, max(0, myNumStates - i * chunk))
            for i in range(deviceCount)
        ])

        # Padded length of this process's label range (entry 0 always equals chunk).
        totalNumStates = deviceCount * self.numStatesPerDevice[0]

        if not global_defs.usePmap:
            # Without pmap there is a single chunk; keep its count as a scalar.
            self.numStatesPerDevice = self.numStatesPerDevice[0]

        intReps = jnp.arange(myFirstState, myFirstState + totalNumStates)
        if global_defs.usePmap:
            # One row per device; presumably mapped over by the *_pmapd helpers.
            intReps = intReps.reshape((global_defs.device_count(), -1))
        self.basis = jnp.zeros(intReps.shape + (self.N, ), dtype=np.int32)
        if self.lDim == 2:
            self.basis = self._get_basis_ldim2_pmapd(self.basis, intReps)
        else:
            self.basis = self._get_basis_pmapd(self.basis, intReps, self.lDim)

        return self.basis
예제 #2
0
    def _get_samples_gen(self, net, numSamples):
        """Draw samples with the sampler net's own ``sample`` method.

        The requested ``numSamples`` is first split across MPI processes and
        local devices by ``mpi.distribute_sampling``. The resulting
        per-process count is what is passed on to
        ``net.get_sampler_net().sample(count, key)``.

        Compiled callables are cached in ``self._get_samples_gen_jitd``, keyed
        by ``str(count)``. Under ``global_defs.usePmap`` the call is pmapped,
        with one PRNG key per device. Otherwise it is jitted with a single key.
        ``self.key`` is advanced on every call.

        Returns:
            Whatever the compiled ``sample`` call returns.
        """
        localDevices = global_defs.device_count() if global_defs.usePmap else 1

        numSamples = mpi.distribute_sampling(numSamples, localDevices=localDevices)
        cacheKey = str(numSamples)

        # Compile once per distinct per-process sample count.
        if cacheKey not in self._get_samples_gen_jitd:
            drawSamples = lambda x, y, z: x.sample(y, z)
            if global_defs.usePmap:
                compiled = global_defs.pmap_for_my_devices(
                    drawSamples,
                    static_broadcasted_argnums=(1, ),
                    in_axes=(None, None, 0))
            else:
                compiled = global_defs.jit_for_my_device(
                    drawSamples, static_argnums=(1, ))
            self._get_samples_gen_jitd[cacheKey] = compiled

        if global_defs.usePmap:
            # Split device 0's key into 2*n keys: the first n become the new
            # per-device state, the last n are consumed by this draw.
            nDev = global_defs.device_count()
            freshKeys = random.split(self.key[0], 2 * nDev)
            self.key, drawKeys = freshKeys[:nDev], freshKeys[nDev:]
        else:
            drawKeys, self.key = random.split(self.key)

        sampler = self._get_samples_gen_jitd[cacheKey]
        return sampler(net.get_sampler_net(), numSamples, drawKeys)
예제 #3
0
    def test_mean(self):
        """``mpi.global_mean`` over rank-local slices must match the serial column mean."""
        numRows = global_defs.device_count() * 720

        data = jnp.array(np.arange(numRows * 4).reshape((numRows, 4)))
        myNumSamples = mpi.distribute_sampling(numRows)

        # Each rank takes a contiguous block of rows. get_shape adapts the
        # target shape to the device layout (NOTE(review): confirm its semantics).
        myData = data[mpi.rank * myNumSamples:(mpi.rank + 1) *
                      myNumSamples].reshape(get_shape((-1, 4)))

        # Compare absolute deviations. A signed sum lets errors of opposite
        # sign cancel, and it accepts any negative deviation outright.
        maxDeviation = jnp.max(
            jnp.abs(mpi.global_mean(myData) - jnp.mean(data, axis=0)))
        self.assertLess(float(maxDeviation), 1e-10)
예제 #4
0
    def _get_samples_mcmc(self, net, numSamples):
        """Generate configurations with the compiled Markov-chain sampler.

        ``self._mc_init(net)`` runs first. The requested ``numSamples`` is then
        split across MPI processes, local devices and ``self.numChains`` chains
        per device by ``mpi.distribute_sampling``.

        Compiled samplers (``self._get_samples`` with
        ``sweepFunction=self._sweep``) are cached in ``self._get_samples_jitd``,
        keyed by ``str(count)``. They are pmapped under ``global_defs.usePmap``
        and jitted otherwise. Arguments 1, 2, 3 and 9 (sample count,
        thermalization sweeps, sweep steps, update proposer) are static.

        The sampler's returned chain state overwrites ``self.states``,
        ``self.logPsiSq``, ``self.key``, ``self.numProposed`` and
        ``self.numAccepted``.

        Returns:
            ``(configs, net(configs))``.
        """
        self._mc_init(net)

        localDevices = global_defs.device_count() if global_defs.usePmap else 1

        numSamples = mpi.distribute_sampling(numSamples,
                                             localDevices=localDevices,
                                             numChainsPerDevice=self.numChains)
        cacheKey = str(numSamples)

        # Compile once per distinct per-process sample count.
        if cacheKey not in self._get_samples_jitd:
            sampleFn = partial(self._get_samples, sweepFunction=self._sweep)
            if global_defs.usePmap:
                compiled = global_defs.pmap_for_my_devices(
                    sampleFn,
                    static_broadcasted_argnums=(1, 2, 3, 9),
                    in_axes=(None, None, None, None, 0, 0, 0, 0, 0, None,
                             None))
            else:
                compiled = global_defs.jit_for_my_device(
                    sampleFn, static_argnums=(1, 2, 3, 9))
            self._get_samples_jitd[cacheKey] = compiled

        sampler = self._get_samples_jitd[cacheKey]
        chainState, configs = sampler(
            net.get_sampler_net(), numSamples, self.thermalizationSweeps,
            self.sweepSteps, self.states, self.logPsiSq, self.key,
            self.numProposed, self.numAccepted, self.updateProposer,
            self.updateProposerArg)
        (self.states, self.logPsiSq, self.key, self.numProposed,
         self.numAccepted) = chainState

        return configs, net(configs)