def get_basis(self):
    # Distribute the full Hilbert space of lDim**N states across MPI ranks
    myNumStates = mpi.distribute_sampling(self.lDim**self.N)
    myFirstState = mpi.first_sample_id()

    deviceCount = global_defs.device_count()
    if not global_defs.usePmap:
        deviceCount = 1

    # Split the local states evenly across devices; the last device absorbs the remainder
    self.numStatesPerDevice = [(myNumStates + deviceCount - 1) // deviceCount] * deviceCount
    self.numStatesPerDevice[-1] += myNumStates - deviceCount * self.numStatesPerDevice[0]
    self.numStatesPerDevice = jnp.array(self.numStatesPerDevice)

    totalNumStates = deviceCount * self.numStatesPerDevice[0]
    if not global_defs.usePmap:
        self.numStatesPerDevice = self.numStatesPerDevice[0]

    # Integer representations of the locally held basis states
    intReps = jnp.arange(myFirstState, myFirstState + totalNumStates)
    if global_defs.usePmap:
        intReps = intReps.reshape((global_defs.device_count(), -1))

    # Decode the integer representations into configurations of N local degrees of freedom
    self.basis = jnp.zeros(intReps.shape + (self.N,), dtype=np.int32)
    if self.lDim == 2:
        self.basis = self._get_basis_ldim2_pmapd(self.basis, intReps)
    else:
        self.basis = self._get_basis_pmapd(self.basis, intReps, self.lDim)

    return self.basis
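# --- Illustration only (not part of the library) ------------------------------
# The helpers _get_basis_ldim2_pmapd / _get_basis_pmapd are assumed to decode each
# integer representation into its base-lDim digit string of length N. A minimal,
# self-contained sketch of that decoding (int_to_config is a hypothetical name):

import jax.numpy as jnp

def int_to_config(intRep, N, lDim):
    """Base-lDim digits of intRep, most significant digit first, length N."""
    exponents = jnp.arange(N - 1, -1, -1)
    return (intRep // lDim**exponents) % lDim

# Example: int_to_config(6, N=4, lDim=2) -> [0, 1, 1, 0]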
def _get_samples_gen(self, net, numSamples):

    def dc():
        if global_defs.usePmap:
            return global_defs.device_count()
        return 1

    numSamples = mpi.distribute_sampling(numSamples, localDevices=dc())
    numSamplesStr = str(numSamples)

    # Check whether _get_samples is already compiled for the given number of samples
    if numSamplesStr not in self._get_samples_gen_jitd:
        if global_defs.usePmap:
            self._get_samples_gen_jitd[numSamplesStr] = global_defs.pmap_for_my_devices(
                lambda x, y, z: x.sample(y, z),
                static_broadcasted_argnums=(1,),
                in_axes=(None, None, 0))
        else:
            self._get_samples_gen_jitd[numSamplesStr] = global_defs.jit_for_my_device(
                lambda x, y, z: x.sample(y, z),
                static_argnums=(1,))

    # Split the PRNG key: keep one half for the next call, pass the other half to the sampler
    tmpKey = None
    if global_defs.usePmap:
        tmpKey = random.split(self.key[0], 2 * global_defs.device_count())
        self.key = tmpKey[:global_defs.device_count()]
        tmpKey = tmpKey[global_defs.device_count():]
    else:
        tmpKey, self.key = random.split(self.key)

    return self._get_samples_gen_jitd[numSamplesStr](net.get_sampler_net(), numSamples, tmpKey)
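# --- Illustration only (not part of the library) ------------------------------
# The dictionary self._get_samples_gen_jitd caches one compiled sampling function per
# sample count: numSamples enters as a static argument, so every new value triggers a
# fresh compilation, and reusing the cached version avoids recompiling on later calls
# with the same count. A minimal sketch of this caching pattern (names hypothetical):

import jax

_compiled_cache = {}

def cached_sample(f, numSamples, key):
    cacheKey = str(numSamples)
    if cacheKey not in _compiled_cache:
        # numSamples is static, so each distinct value gets its own compiled function
        _compiled_cache[cacheKey] = jax.jit(f, static_argnums=(0,))
    return _compiled_cache[cacheKey](numSamples, key)

# Usage: cached_sample(lambda n, k: jax.random.normal(k, (n,)), 100, jax.random.PRNGKey(0))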
def test_mean(self):
    data = jnp.array(
        np.arange(720 * 4 * global_defs.device_count()).reshape(
            (global_defs.device_count() * 720, 4)))

    # Each rank keeps only its own slice of the globally defined data
    myNumSamples = mpi.distribute_sampling(global_defs.device_count() * 720)
    myData = data[mpi.rank * myNumSamples:(mpi.rank + 1) * myNumSamples].reshape(get_shape((-1, 4)))

    # The distributed mean must reproduce the mean over the full data set
    self.assertTrue(
        jnp.sum(jnp.abs(mpi.global_mean(myData) - jnp.mean(data, axis=0))) < 1e-10)
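# --- Illustration only (not part of the library) ------------------------------
# mpi.global_mean is assumed to average the locally held samples over all MPI ranks
# (and devices), reproducing jnp.mean over the full, undistributed data set. A minimal
# sketch of such a reduction, assuming mpi4py as the communication backend:

import numpy as np
from mpi4py import MPI

def global_mean(localData, comm=MPI.COMM_WORLD):
    """Mean over the leading (sample) axis, accumulated across all ranks."""
    localSum = np.ascontiguousarray(np.sum(localData, axis=0))
    globalSum = np.empty_like(localSum)
    comm.Allreduce(localSum, globalSum, op=MPI.SUM)
    globalCount = comm.allreduce(localData.shape[0], op=MPI.SUM)
    return globalSum / globalCount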
def _get_samples_mcmc(self, net, numSamples):

    # Initialize sampling stuff
    self._mc_init(net)

    def dc():
        if global_defs.usePmap:
            return global_defs.device_count()
        return 1

    numSamples = mpi.distribute_sampling(numSamples, localDevices=dc(),
                                         numChainsPerDevice=self.numChains)
    numSamplesStr = str(numSamples)

    # Check whether _get_samples is already compiled for the given number of samples
    if numSamplesStr not in self._get_samples_jitd:
        if global_defs.usePmap:
            self._get_samples_jitd[numSamplesStr] = global_defs.pmap_for_my_devices(
                partial(self._get_samples, sweepFunction=self._sweep),
                static_broadcasted_argnums=(1, 2, 3, 9),
                in_axes=(None, None, None, None, 0, 0, 0, 0, 0, None, None))
        else:
            self._get_samples_jitd[numSamplesStr] = global_defs.jit_for_my_device(
                partial(self._get_samples, sweepFunction=self._sweep),
                static_argnums=(1, 2, 3, 9))

    # Run the MCMC sweeps and carry the chain state (configurations, log|psi|^2,
    # PRNG key, and acceptance counters) over to the next call
    (self.states, self.logPsiSq, self.key, self.numProposed, self.numAccepted), configs = \
        self._get_samples_jitd[numSamplesStr](net.get_sampler_net(), numSamples,
                                              self.thermalizationSweeps, self.sweepSteps,
                                              self.states, self.logPsiSq, self.key,
                                              self.numProposed, self.numAccepted,
                                              self.updateProposer, self.updateProposerArg)

    return configs, net(configs)
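# --- Illustration only (not part of the library) ------------------------------
# The jitted self._get_samples routine (with sweepFunction=self._sweep) is assumed to
# perform Metropolis-Hastings sweeps in which a proposed configuration s' is accepted
# with probability min(1, |psi(s')|^2 / |psi(s)|^2). A minimal single-chain sketch of
# one such update step (names hypothetical):

import jax
import jax.numpy as jnp

def metropolis_step(key, state, logPsiSq, log_psi_sq_fn, propose_fn):
    keyProp, keyAccept = jax.random.split(key)
    newState = propose_fn(keyProp, state)            # propose an updated configuration
    newLogPsiSq = log_psi_sq_fn(newState)
    # Accept with probability |psi(s')|^2 / |psi(s)|^2 (evaluated in log space)
    accept = jnp.log(jax.random.uniform(keyAccept)) < newLogPsiSq - logPsiSq
    state = jnp.where(accept, newState, state)
    logPsiSq = jnp.where(accept, newLogPsiSq, logPsiSq)
    return state, logPsiSq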