def vec_to_real(vec: Array) -> Tuple[Array, Callable]:
    """
    If the input vector is complex, splits the vector into real and imaginary
    parts and concatenates them along the 0-th axis. It is equivalent to
    changing the complex storage from AOS to SOA. If the input vector is not
    complex, the output of :func:`nkjax.tree_to_real` is returned unchanged.

    Args:
        vec: a dense vector

    Returns:
        A tuple ``(out, reassemble)``. ``out`` is the real-valued representation
        of ``vec``. ``reassemble`` is a function that maps an array with the
        layout of ``out`` back to the original (possibly complex) representation.
    """
    out, reassemble = nkjax.tree_to_real(vec)

    if nkjax.is_complex(vec):
        # For a complex input tree_to_real yields a (re, im) pair; stack the two
        # halves along axis 0 so that the caller sees a single real array.
        re, im = out
        out = jnp.concatenate([re, im], axis=0)

        def reassemble_concat(x):
            # Undo the concatenation (split axis 0 back into the two halves)
            # before handing the pair to tree_to_real's inverse.
            x = tuple(jnp.split(x, 2, axis=0))
            return reassemble(x)

    else:
        reassemble_concat = reassemble

    return out, reassemble_concat
def O_mean(forward_fn, params, samples, holomorphic=True):
    r"""
    compute \langle O \rangle
    i.e. the mean of the rows of the jacobian of forward_fn

    Args:
        forward_fn: function ``forward_fn(params, samples)`` whose jacobian is averaged.
        params: parameter pytree passed to ``forward_fn``.
        samples: samples passed to ``forward_fn``; ``samples.shape[0]`` is the
            number of local samples.
        holomorphic: whether ``forward_fn`` may be treated as holomorphic when
            the parameters are complex.

    Returns:
        The result of ``O_vjp_rc`` (real parameters, complex output) or of
        ``O_vjp`` (real->real and holomorphic complex->complex) evaluated with
        uniform weights.

    Raises:
        NotImplementedError: if the parameters are not homogeneous, or if they
            are complex and ``holomorphic`` is falsy.
    """
    # Evaluate the output abstractly once: its dtype sets the weight dtype and
    # tells us whether the output is complex.
    out_struct = jax.eval_shape(forward_fn, params, samples)
    dtype = out_struct.dtype

    # Uniform weights 1 / (n_local_samples * n_nodes).
    # NOTE(review): assumes every MPI rank holds the same number of samples.
    w = jnp.ones(samples.shape[0], dtype=dtype) * (
        1.0 / (samples.shape[0] * mpi.n_nodes)
    )

    homogeneous = nkjax.tree_ishomogeneous(params)
    real_params = not nkjax.tree_leaf_iscomplex(params)
    real_out = not nkjax.is_complex(out_struct)

    if not (homogeneous and (real_params or holomorphic)):
        # R&C -> C, non-holomorphic C->C and C->R are not implemented.
        # Raise instead of `assert False`, which is stripped under `python -O`
        # and would make this function silently return None.
        raise NotImplementedError(
            "O_mean is only implemented for homogeneous parameters that are "
            "either all real, or complex with holomorphic=True."
        )

    if real_params and not real_out:
        # R->C
        return O_vjp_rc(forward_fn, params, samples, w)

    # R->R and holomorphic C->C
    return O_vjp(forward_fn, params, samples, w)
def _choose_jacobian_mode(apply_fun, pars, model_state, samples, mode, holomorphic):
    """
    Select the jacobian implementation and return its integer code.

    Args:
        apply_fun: model apply function, called as
            ``apply_fun({"params": pars, **model_state}, samples_2d)`` to probe
            the output dtype.
        pars: parameter pytree.
        model_state: extra model variables merged next to ``"params"``.
        samples: samples; reshaped to ``(-1, samples.shape[-1])`` for the probe.
        mode: NOTE(review): currently ignored; it is always overwritten by the
            automatic selection below. Kept for interface compatibility.
        holomorphic: ``True`` forces the holomorphic mode; ``None`` or ``False``
            selects between "real" and "complex" based on the model's output.

    Returns:
        0 for "real", 1 for "complex", 2 for "holomorphic".
    """
    homogeneous_vars = nkjax.tree_ishomogeneous(pars)

    if holomorphic is True:
        if not homogeneous_vars:
            warnings.warn(
                dedent(
                    """
                    The ansatz has non homogeneous variables, which might not behave well
                    with the holomorphic implementation.
                    Use `holomorphic=False` or mode='complex' for more accurate results but
                    lower performance.
                    """
                ),
                UserWarning,
            )
        mode = "holomorphic"
    else:
        leaf_iscomplex = nkjax.tree_leaf_iscomplex(pars)
        # Only the output dtype matters here, so evaluate the model abstractly.
        complex_output = nkjax.is_complex(
            jax.eval_shape(
                apply_fun,
                {"params": pars, **model_state},
                samples.reshape(-1, samples.shape[-1]),
            )
        )

        if complex_output:
            if leaf_iscomplex and holomorphic is None:
                # C->C without an explicit choice: fall back to the general
                # (non-holomorphic) implementation and tell the user.
                warnings.warn(
                    dedent(
                        """
                        Complex-to-Complex model detected. Defaulting to `holomorphic=False` for
                        the implementation of QGTJacobianDense.
                        If your model is holomorphic, specify `holomorphic=True` to use a more
                        performant implementation.
                        To suppress this warning specify `holomorphic`.
                        """
                    ),
                    UserWarning,
                )
            mode = "complex"
        else:
            mode = "real"

    mode_codes = {"real": 0, "complex": 1, "holomorphic": 2}
    try:
        return mode_codes[mode]
    except KeyError:
        raise ValueError(f"unknown mode {mode}") from None
def _to_dense(self: QGTOnTheFlyT) -> jnp.ndarray:
    """
    Materialize this lazily represented S matrix as a dense matrix.

    Returns:
        A dense matrix representation of this S matrix.
    """
    basis = jax.numpy.eye(nkjax.tree_size(self.params))

    # Row i of the stacked result is S @ e_i, i.e. column i of S, so `dense`
    # holds S transposed.
    dense = jax.vmap(lambda e_i: self @ e_i)(basis)

    # Only complex results are transposed back.
    # NOTE(review): presumably a real S is symmetric, so its transpose is already S.
    return dense.T if nkjax.is_complex(dense) else dense
def _to_dense(self: QGTOnTheFlyT) -> jnp.ndarray:
    """
    Materialize this lazily represented S matrix as a dense matrix.

    Returns:
        A dense matrix representation of this S matrix.
    """
    basis = jax.numpy.eye(nkjax.tree_size(self._params))

    def apply_s(e_i):
        # S applied to one canonical basis vector: column i of S.
        return self @ e_i

    if not self._chunking:
        dense = jax.vmap(apply_s)(basis)
    else:
        # The linear_call in mat_vec_chunked currently has no jax batching rule,
        # so it cannot be vmapped. A scan works instead and also keeps memory
        # consumption down.
        _, dense = jax.lax.scan(lambda carry, e_i: (carry, apply_s(e_i)), None, basis)

    # Rows of `dense` are columns of S; complex results are transposed back.
    if nkjax.is_complex(dense):
        dense = dense.T
    return dense
def GCNN(
    symmetries=None,
    product_table=None,
    irreps=None,
    point_group=None,
    mode="auto",
    shape=None,
    layers=None,
    features=None,
    characters=None,
    parity=None,
    param_dtype=np.float64,
    complex_output=True,
    **kwargs,
):
    r"""Implements a Group Convolutional Neural Network (G-CNN) that outputs a wavefunction
    that is invariant over a specified symmetry group.

    The G-CNN is described in `Cohen et al. <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in
    `Roth et al. <https://arxiv.org/pdf/2104.05085.pdf>`_ .

    The G-CNN alternates convolution operations with pointwise non-linearities. The first
    layer is a symmetrized linear transform given by DenseSymm, while the other layers are
    G-convolutions given by DenseEquivariant. The hidden layers of the G-CNN are related by
    the following equation:

    .. math ::

        {\bf f}^{i+1}_g = \Gamma( \sum_h W_{g^{-1} h} {\bf f}^i_h).

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            :class:`nk.graph.Graph`, a :class:`nk.utils.PermutationGroup`, or an array
            :code:`[n_symm, n_sites]` specifying the permutations corresponding to
            symmetry transformations of the lattice.
        product_table: Product table describing the algebra of the symmetry group.
            Only needs to be specified if mode='fft' and symmetries is specified as an array.
        irreps: List of 3D tensors that project onto irreducible representations of the
            symmetry group. Only needs to be specified if mode='irreps' and symmetries is
            specified as an array.
        point_group: The point group, from which the space group is built. If symmetries
            is a graph the default point group is overwritten.
        mode: string "fft", "irreps" or "auto" specifying whether to use a fast fourier
            transform over the translation group or a fourier transform using the
            irreducible representations. With "auto", "fft" is chosen for a lattice that
            has a point group, "irreps" for other graphs and for permutation groups, and,
            for an array of symmetries, "irreps" if `irreps` is given, otherwise "fft" if
            `product_table` is given.
        shape: A tuple specifying the dimensions of the translation group.
        layers: Number of layers (not including sum layer over output).
        features: Number of features in each layer starting from the input. If a single
            number is given, all layers will have the same number of features.
        characters: Array specifying the characters of the desired symmetry representation.
        parity: Optional argument with value +/-1 that specifies the eigenvalue with respect
            to parity (only use on two level systems).
        param_dtype: The dtype of the weights.
        activation: The nonlinear activation function between hidden layers. Defaults to
            :func:`nk.nn.activation.reim_selu` .
        output_activation: The nonlinear activation before the output.
        equal_amplitudes: If True forces all basis states to have equal amplitude by
            setting :math:`\Re(\psi) = 0` .
        use_bias: If True uses a bias in all layers.
        precision: Numerical precision of the computation see
            :class:`jax.lax.Precision` for details.
        kernel_init: Initializer for the kernels of all layers. Defaults to
            :code:`lecun_normal(in_axis=1, out_axis=0)` which guarantees the correct
            variance of the output.
        bias_init: Initializer for the biases of all layers.
        complex_output: If True, ensures that the network output is always complex.
            Necessary when network parameters are real but some `characters` are negative.

    Raises:
        ValueError: if the symmetry specification is incompatible with `mode`, if `mode`
            is unknown, or if real parameters and negative real characters are combined
            with `complex_output=False`.
        TypeError: if `mode` resolves to "fft" and no translation-group `shape` is known.
    """
    if isinstance(symmetries, Lattice) and (
        point_group is not None or symmetries._point_group is not None
    ):
        # A lattice with a point group: build the full space group; the
        # translation-group shape is the lattice extent.
        shape = tuple(symmetries.extent)
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        # Without a point group, default to the graph automorphisms, for which
        # no product table for the FFT implementation is built here.
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        if mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group"
            )
    elif isinstance(symmetries, PermutationGroup):
        # If we get a group, default to irrep projection.
        if mode == "auto":
            mode = "irreps"
        sg = symmetries
    else:
        # Raw permutation array: the group algebra must come from the caller.
        if irreps is not None and (mode == "irreps" or mode == "auto"):
            mode = "irreps"
            sg = symmetries
            irreps = tuple(HashableArray(irrep) for irrep in irreps)
        elif product_table is not None and (mode == "fft" or mode == "auto"):
            mode = "fft"
            sg = symmetries
            product_table = HashableArray(product_table)
        else:
            raise ValueError(
                "Specification of symmetries is wrong or incompatible with selected mode"
            )

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument."
            )
        else:
            shape = tuple(shape)

    if isinstance(features, int):
        features = (features,) * layers

    if characters is None:
        # Trivial representation: all characters equal to one.
        characters = HashableArray(np.ones(len(np.asarray(sg))))
    else:
        if (
            not is_complex(characters)
            and not is_complex_dtype(param_dtype)
            and not complex_output
            and jnp.any(characters < 0)
        ):
            raise ValueError(
                "`complex_output` must be used with real parameters and negative "
                "characters to avoid NaN errors."
            )
        characters = HashableArray(characters)

    if mode == "fft":
        sym = HashableArray(np.asarray(sg))
        if product_table is None:
            product_table = HashableArray(sg.product_table)
        if parity:
            return GCNN_Parity_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                parity=parity,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
        else:
            return GCNN_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
    elif mode in ["irreps", "auto"]:
        sym = HashableArray(np.asarray(sg))

        if irreps is None:
            irreps = tuple(HashableArray(irrep) for irrep in sg.irrep_matrices())

        if parity:
            return GCNN_Parity_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                parity=parity,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
        else:
            return GCNN_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'irreps' or 'auto'."
        )