def to_array(self, normalize: bool = True, allgather: bool = True) -> jnp.ndarray:
    """Return the dense vector representation of this state.

    Args:
        normalize: Forwarded to ``nn.to_array``; requests a normalized vector.
        allgather: Forwarded to ``nn.to_array``. Presumably selects whether the
            result is gathered across processes or left as the local
            portion (NOTE(review): confirm against ``nn.to_array``).

    Returns:
        The array produced by ``nn.to_array`` for ``self.hilbert``,
        ``self._apply_fun`` and ``self.variables``.

    Only the default configuration (``normalize=True, allgather=True``) is
    cached in ``self._array``. Every other combination is recomputed on each
    call.
    """
    # The cache used to be keyed on ``normalize`` alone, so an array computed
    # with ``allgather=False`` could be returned to a caller asking for
    # ``allgather=True`` (and vice versa). Restrict caching to one flavour.
    if normalize and allgather:
        if self._array is None:
            self._array = nn.to_array(
                self.hilbert,
                self._apply_fun,
                self.variables,
                normalize=True,
                allgather=True,
            )
        return self._array

    return nn.to_array(
        self.hilbert,
        self._apply_fun,
        self.variables,
        normalize=normalize,
        allgather=allgather,
    )
def _reset(sampler, machine, parameters, state):
    """Recompute the exact sampling distribution and store it in ``state``.

    Args:
        sampler: Provides ``hilbert`` and ``machine_pow``.
        machine: Model passed through to ``to_array``.
        parameters: Model parameters passed through to ``to_array``.
        state: Sampler state; a copy with ``pdf`` replaced is returned.

    Returns:
        ``state.replace(pdf=pdf)``, where ``pdf`` is ``|psi| ** machine_pow``
        rescaled to sum to one.
    """
    amplitudes = to_array(sampler.hilbert, machine, parameters)
    # Take the modulus *before* exponentiating. |psi|**p equals |psi**p| for
    # complex psi and real p, but raising first returns NaN for negative real
    # amplitudes when p is fractional, and needlessly performs a complex power.
    pdf = jnp.absolute(amplitudes) ** sampler.machine_pow
    pdf = pdf / pdf.sum()
    return state.replace(pdf=pdf)
def to_array(self, normalize: bool = True) -> jnp.ndarray:
    """Materialize this state as a dense vector via ``nn.to_array``.

    Args:
        normalize: Forwarded unchanged to ``nn.to_array``.

    Returns:
        Whatever ``nn.to_array`` returns for this state's Hilbert space,
        apply function and variables.
    """
    hilbert = self.hilbert
    apply_fun = self._apply_fun
    variables = self.variables
    return nn.to_array(hilbert, apply_fun, variables, normalize=normalize)