def test_matvec_treemv_modes(e, jit, holomorphic, pardtype, outdtype):
    """Check ``mat_vec`` against the dense reference for every QGT mode.

    The mode is chosen from ``outdtype``, ``pardtype`` and ``holomorphic``:
    "real" for a real output dtype, "holomorphic" for complex (non-None)
    parameters with ``holomorphic`` set, and "complex" otherwise. The result
    of ``mat_vec`` (optionally jitted) must match
    ``S_real @ v + diag_shift * v`` reassembled to ``e.target``.
    """
    diag_shift = 0.01
    model_state = {}
    rescale_shift = False

    def apply_fun(params, samples):
        return e.f(params["params"], samples)

    # Choose the QGT mode exactly as the Jacobian logic expects it.
    if not nkjax.is_complex_dtype(outdtype):
        mode = "real"
    elif (
        pardtype is not None
        and nkjax.is_complex_dtype(pardtype)
        and holomorphic
    ):
        mode = "holomorphic"
    else:
        mode = "complex"

    # Only the holomorphic mode works on the complex vector directly; the
    # other modes operate on its real representation.
    if mode == "holomorphic":
        vec = e.v
        back_to_original = lambda tree: tree
    else:
        vec, back_to_original = nkjax.tree_to_real(e.v)

    matvec = qgt_jacobian_pytree_logic.mat_vec
    if jit:
        matvec = jax.jit(matvec)

    oks, _ = qgt_jacobian_pytree_logic.prepare_centered_oks(
        apply_fun, e.params, e.samples, model_state, mode, rescale_shift
    )

    result = back_to_original(matvec(vec, oks, diag_shift))
    reference = reassemble_complex(
        e.S_real @ e.v_real_flat + diag_shift * e.v_real_flat,
        target=e.target,
    )
    assert tree_allclose(result, reference)
def sigmay(hilbert: _AbstractHilbert, site: int, dtype: _DType = complex) -> _LocalOperator:
    """
    Builds the :math:`\\sigma^y` operator acting on the `site`-th of the Hilbert
    space `hilbert`.

    If `hilbert` is a non-Spin space of local dimension M, it is considered
    as a (M-1)/2 - spin space.

    The matrix elements of :math:`\\sigma^y` are purely imaginary, so a real
    `dtype` is promoted to a complex one and a ``ComplexWarning`` is emitted.

    :param hilbert: The hilbert space
    :param site: the site on which this operator acts
    :param dtype: dtype of the operator's matrix elements (default: complex)
    :return: a nk.operator.LocalOperator
    """
    import numpy as np

    import netket.jax as nkjax

    if not nkjax.is_complex_dtype(dtype):
        import warnings

        import jax.numpy as jnp

        try:
            # NumPy >= 1.25 exposes the warning in `numpy.exceptions`; the
            # top-level alias `np.ComplexWarning` was removed in NumPy 2.0.
            from numpy.exceptions import ComplexWarning
        except ImportError:  # older NumPy
            ComplexWarning = np.ComplexWarning

        old_dtype = dtype
        dtype = jnp.promote_types(complex, old_dtype)
        warnings.warn(
            ComplexWarning(
                f"A complex dtype is required (dtype={old_dtype} specified). "
                f"Promoting to dtype={dtype}."
            ),
            stacklevel=2,
        )

    N = hilbert.size_at_index(site)
    S = (N - 1) / 2

    # Off-diagonal magnitudes sqrt((S+1)*2a - a(a+1)) for a = 1..N-1; for N=2
    # this yields the Pauli matrix [[0, -1j], [1j, 0]].
    a = np.arange(1, N)
    D = 1j * np.sqrt((S + 1) * 2 * a - a * (a + 1))
    mat = np.diag(D, -1) + np.diag(-D, 1)
    return _LocalOperator(hilbert, mat, [site], dtype=dtype)
def test_scale_invariant_regularization(e, outdtype, pardtype):
    """Check ``_mat_vec`` on a rescaled centered Jacobian.

    The Jacobian is centered, divided by sqrt(n_samples) and passed through
    ``_rescale``; multiplying it with ``e.v`` must reproduce the dense
    reference ``e.S_real_scaled @ v`` reassembled to ``e.target``.
    """
    logic = qgt_jacobian_pytree_logic

    # Real parameters with a complex output need the split-complex Jacobian.
    real_params_complex_output = not nkjax.is_complex_dtype(
        pardtype
    ) and nkjax.is_complex_dtype(outdtype)
    jacobian_fn = (
        logic.centered_jacobian_cplx
        if real_params_complex_output
        else logic.centered_jacobian_real_holo
    )

    oks = logic._divide_by_sqrt_n_samp(
        jacobian_fn(e.f, e.params, e.samples), e.samples
    )
    oks_scaled, _ = logic._rescale(oks)

    result = logic._mat_vec(e.v, oks_scaled)
    reference = reassemble_complex(
        e.S_real_scaled @ e.v_real_flat, target=e.target
    )
    assert tree_allclose(result, reference)
def astype_unsafe(x, dtype):
    """Cast ``x`` to ``dtype`` without emitting a ``ComplexWarning``.

    Equivalent to ``x.astype(dtype)``, except that when ``dtype`` is not
    complex the real part of ``x`` is taken first, so the complex-to-real
    cast warning (treated as an error in the tests) never fires.
    """
    source = x if nkjax.is_complex_dtype(dtype) else x.real
    return source.astype(dtype)
def test_matvec_treemv(e, jit, holomorphic, pardtype, outdtype, chunk_size):
    """Check ``_mat_vec`` with a (possibly chunked / jitted) centered Jacobian.

    The centered Jacobian is computed with ``chunk_size`` and divided by
    sqrt(n_samples); its product with ``e.v`` must equal the dense reference
    ``e.S_real @ v`` reassembled to ``e.target``.
    """
    logic = qgt_jacobian_pytree_logic

    # Real parameters with a complex output need the split-complex Jacobian.
    if not nkjax.is_complex_dtype(pardtype) and nkjax.is_complex_dtype(outdtype):
        base_jacobian = logic.centered_jacobian_cplx
    else:
        base_jacobian = logic.centered_jacobian_real_holo
    jacobian_fn = partial(base_jacobian, chunk_size=chunk_size)

    matvec = logic._mat_vec
    if jit:
        matvec = jax.jit(matvec)
        # The model function (first argument) must be static under jit.
        jacobian_fn = jax.jit(jacobian_fn, static_argnums=0)

    oks = jacobian_fn(e.f, e.params, e.samples)
    oks = logic._divide_by_sqrt_n_samp(oks, e.samples)

    result = matvec(e.v, oks)
    reference = reassemble_complex(e.S_real @ e.v_real_flat, target=e.target)
    assert tree_allclose(result, reference)
def GCNN(
    symmetries=None,
    product_table=None,
    irreps=None,
    point_group=None,
    mode="auto",
    shape=None,
    layers=None,
    features=None,
    characters=None,
    parity=None,
    param_dtype=np.float64,
    complex_output=True,
    **kwargs,
):
    r"""Implements a Group Convolutional Neural Network (G-CNN) that outputs a wavefunction
    that is invariant over a specified symmetry group.

    The G-CNN is described in `Cohen et al. <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in `Roth et al. <https://arxiv.org/pdf/2104.05085.pdf>`_ .

    The G-CNN alternates convolution operations with pointwise non-linearities. The first
    layer is symmetrized linear transform given by DenseSymm, while the other layers are
    G-convolutions given by DenseEquivariant. The hidden layers of the G-CNN are related by
    the following equation:

    .. math ::

        {\bf f}^{i+1}_g = \Gamma( \sum_h W_{g^{-1} h} {\bf f}^i_h).

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            :class:`nk.graph.Graph`, a :class:`nk.utils.PermutationGroup`, or an array
            :code:`[n_symm, n_sites]` specifying the permutations corresponding to
            symmetry transformations of the lattice.
        product_table: Product table describing the algebra of the symmetry group.
            Only needs to be specified if mode='fft' and symmetries is specified as an array.
        irreps: List of 3D tensors that project onto irreducible representations of the
            symmetry group. Only needs to be specified if mode='irreps' and symmetries is
            specified as an array.
        point_group: The point group, from which the space group is built. If symmetries
            is a graph the default point group is overwritten.
        mode: One of "fft", "irreps" or "auto". "fft" uses a fast fourier transform over
            the translation group, "irreps" a fourier transform using the irreducible
            representations; "auto" picks one based on the type of `symmetries`.
        shape: A tuple specifying the dimensions of the translation group.
        layers: Number of layers (not including sum layer over output).
        features: Number of features in each layer starting from the input. If a single
            number is given, all layers will have the same number of features.
        characters: Array specifying the characters of the desired symmetry representation.
        parity: Optional argument with value +/-1 that specifies the eigenvalue
            with respect to parity (only use on two level systems).
        param_dtype: The dtype of the weights.
        activation: The nonlinear activation function between hidden layers. Defaults to
            :func:`nk.nn.activation.reim_selu` .
        output_activation: The nonlinear activation before the output.
        equal_amplitudes: If True forces all basis states to have equal amplitude
            by setting :math:`\Re(\psi) = 0` .
        use_bias: If True uses a bias in all layers.
        precision: Numerical precision of the computation see :class:`jax.lax.Precision`
            for details.
        kernel_init: Initializer for the kernels of all layers. Defaults to
            :code:`lecun_normal(in_axis=1, out_axis=0)` which guarantees the correct
            variance of the output.
        bias_init: Initializer for the biases of all layers.
        complex_output: If True, ensures that the network output is always complex.
            Necessary when network parameters are real but some `characters` are negative.

    Raises:
        ValueError: if the symmetry specification is incompatible with `mode`, if
            `mode` is unknown, or if negative real characters are used with real
            parameters and `complex_output=False`.
        TypeError: if `mode` resolves to "fft" but no translation `shape` is available.
    """
    # --- Resolve the symmetry group and the mode ---------------------------
    if isinstance(symmetries, Lattice) and (
        point_group is not None or symmetries._point_group is not None
    ):
        # A lattice with a point group yields a space group; its extent gives
        # the translation shape needed by the FFT implementation.
        shape = tuple(symmetries.extent)
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        # Without a point group, fall back to the graph automorphisms.
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        if mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group"
            )
    elif isinstance(symmetries, PermutationGroup):
        # A bare group defaults to irrep projection.
        if mode == "auto":
            mode = "irreps"
        sg = symmetries
    else:
        # Raw permutation array: irreps or a product table must be supplied.
        if irreps is not None and mode in ("irreps", "auto"):
            mode = "irreps"
            sg = symmetries
            irreps = tuple(HashableArray(irrep) for irrep in irreps)
        elif product_table is not None and mode in ("fft", "auto"):
            mode = "fft"
            sg = symmetries
            product_table = HashableArray(product_table)
        else:
            raise ValueError(
                "Specification of symmetries is wrong or incompatible with selected mode"
            )

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument."
            )
        shape = tuple(shape)

    if isinstance(features, int):
        features = (features,) * layers

    # --- Characters ----------------------------------------------------------
    if characters is None:
        characters = HashableArray(np.ones(len(np.asarray(sg))))
    else:
        # np.asarray lets list inputs be compared elementwise as well.
        if (
            not is_complex(characters)
            and not is_complex_dtype(param_dtype)
            and not complex_output
            and np.any(np.asarray(characters) < 0)
        ):
            raise ValueError(
                "`complex_output` must be used with real parameters and negative "
                "characters to avoid NaN errors."
            )
        characters = HashableArray(characters)

    # --- Select the concrete model -----------------------------------------
    if mode == "fft":
        sym = HashableArray(np.asarray(sg))
        if product_table is None:
            product_table = HashableArray(sg.product_table)
        model_cls = GCNN_Parity_FFT if parity else GCNN_FFT
        mode_kwargs = {"product_table": product_table, "shape": shape}
    elif mode in ("irreps", "auto"):
        sym = HashableArray(np.asarray(sg))
        if irreps is None:
            irreps = tuple(HashableArray(irrep) for irrep in sg.irrep_matrices())
        model_cls = GCNN_Parity_Irrep if parity else GCNN_Irrep
        mode_kwargs = {"irreps": irreps}
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'irreps' or 'auto'."
        )

    if parity:
        mode_kwargs["parity"] = parity

    return model_cls(
        symmetries=sym,
        layers=layers,
        features=features,
        characters=characters,
        param_dtype=param_dtype,
        complex_output=complex_output,
        **mode_kwargs,
        **kwargs,
    )