示例#1
0
def test_matvec_treemv_modes(e, jit, holomorphic, pardtype, outdtype):
    """Compare ``mat_vec`` with the dense reference in the appropriate QGT mode.

    The mode ("real", "holomorphic" or "complex") is chosen from the output
    dtype, the parameter dtype and the ``holomorphic`` flag. The centered
    Jacobian is then prepared in that mode. ``mat_vec`` applied to ``e.v``
    (with a diagonal shift) must match ``(S_real + diag_shift) @ v_real_flat``.
    """
    diag_shift = 0.01
    model_state = {}
    rescale_shift = False

    def apply_fun(params, samples):
        # Unwrap the "params" entry before calling the fixture's function.
        return e.f(params["params"], samples)

    homogeneous = pardtype is not None
    if not nkjax.is_complex_dtype(outdtype):
        mode = "real"
    elif homogeneous and nkjax.is_complex_dtype(pardtype) and holomorphic:
        mode = "holomorphic"
    else:
        mode = "complex"

    if mode != "holomorphic":
        # Split the vector via tree_to_real; `reassemble` maps results back.
        v, reassemble = nkjax.tree_to_real(e.v)
    else:
        v, reassemble = e.v, (lambda tree: tree)

    matvec = (
        jax.jit(qgt_jacobian_pytree_logic.mat_vec)
        if jit
        else qgt_jacobian_pytree_logic.mat_vec
    )

    oks, _ = qgt_jacobian_pytree_logic.prepare_centered_oks(
        apply_fun, e.params, e.samples, model_state, mode, rescale_shift
    )
    actual = reassemble(matvec(v, oks, diag_shift))

    shifted = e.S_real @ e.v_real_flat + diag_shift * e.v_real_flat
    expected = reassemble_complex(shifted, target=e.target)
    assert tree_allclose(actual, expected)
示例#2
0
def sigmay(hilbert: _AbstractHilbert,
           site: int,
           dtype: _DType = complex) -> _LocalOperator:
    """
    Builds the :math:`\\sigma^y` operator acting on the `site`-th site of the
    Hilbert space `hilbert`.

    If `hilbert` is a non-Spin space of local dimension M, it is considered
    as a (M-1)/2 - spin space.

    The matrix elements of :math:`\\sigma^y` are purely imaginary. A non-complex
    `dtype` is therefore promoted to a complex dtype and a ``ComplexWarning``
    is emitted.

    :param hilbert: The hilbert space
    :param site: the site on which this operator acts
    :param dtype: dtype of the matrix elements (default: ``complex``)
    :return: a nk.operator.LocalOperator
    """
    import warnings

    import numpy as np
    import netket.jax as nkjax

    if not nkjax.is_complex_dtype(dtype):
        import jax.numpy as jnp

        try:
            # NumPy >= 1.25 location; the top-level alias
            # ``np.ComplexWarning`` was removed in NumPy 2.0.
            from numpy.exceptions import ComplexWarning
        except ImportError:  # older NumPy
            ComplexWarning = np.ComplexWarning

        old_dtype = dtype
        dtype = jnp.promote_types(complex, old_dtype)
        warnings.warn(
            ComplexWarning(
                f"A complex dtype is required (dtype={old_dtype} specified). "
                f"Promoting to dtype={dtype}."),
            stacklevel=2,  # attribute the warning to the caller of sigmay
        )

    N = hilbert.size_at_index(site)
    S = (N - 1) / 2

    # Imaginary off-diagonal elements; for N == 2 the resulting matrix is
    # the Pauli matrix [[0, -1j], [1j, 0]].
    D = np.array(
        [1j * np.sqrt((S + 1) * 2 * a - a * (a + 1)) for a in np.arange(1, N)])
    mat = np.diag(D, -1) + np.diag(-D, 1)
    return _LocalOperator(hilbert, mat, [site], dtype=dtype)
示例#3
0
def test_scale_invariant_regularization(e, outdtype, pardtype):
    """Check ``_mat_vec`` on rescaled centered Jacobians against the dense reference.

    Real parameters with a complex output use the complex centered Jacobian.
    Every other combination uses the real/holomorphic one. After dividing by
    sqrt(n_samples) and rescaling, the matrix-vector product must equal
    ``e.S_real_scaled @ e.v_real_flat``.
    """
    logic = qgt_jacobian_pytree_logic

    real_params_complex_out = (
        not nkjax.is_complex_dtype(pardtype) and nkjax.is_complex_dtype(outdtype)
    )
    jacobian_fun = (
        logic.centered_jacobian_cplx
        if real_params_complex_out
        else logic.centered_jacobian_real_holo
    )

    oks = logic._divide_by_sqrt_n_samp(
        jacobian_fun(e.f, e.params, e.samples), e.samples
    )
    # Only the rescaled Jacobian is needed here; the scale itself is unused.
    oks_scaled, _ = logic._rescale(oks)

    actual = logic._mat_vec(e.v, oks_scaled)
    expected = reassemble_complex(e.S_real_scaled @ e.v_real_flat, target=e.target)
    assert tree_allclose(actual, expected)
示例#4
0
def astype_unsafe(x, dtype):
    """
    Cast ``x`` to ``dtype`` like ``x.astype(dtype)``, but without raising a
    ComplexWarning, which is treated as an error in our tests.

    For a non-complex target dtype the real part of ``x`` is taken first.
    Any imaginary part is thus dropped explicitly rather than by a lossy
    complex-to-real cast.
    """
    source = x if nkjax.is_complex_dtype(dtype) else x.real
    return source.astype(dtype)
示例#5
0
def test_matvec_treemv(e, jit, holomorphic, pardtype, outdtype, chunk_size):
    """Check ``_mat_vec`` on (optionally chunked / jitted) centered Jacobians.

    Real parameters with a complex output use the complex centered Jacobian.
    Every other combination uses the real/holomorphic one. The product with
    ``e.v`` must equal ``e.S_real @ e.v_real_flat``. ``holomorphic`` is part
    of the parametrization but is not used in this body.
    """
    logic = qgt_jacobian_pytree_logic

    use_cplx = (
        not nkjax.is_complex_dtype(pardtype) and nkjax.is_complex_dtype(outdtype)
    )
    base_jacobian = (
        logic.centered_jacobian_cplx if use_cplx else logic.centered_jacobian_real_holo
    )
    jacobian_fun = partial(base_jacobian, chunk_size=chunk_size)
    matvec = logic._mat_vec

    if jit:
        matvec = jax.jit(matvec)
        # Argument 0 (the model function ``e.f``) is passed as static.
        jacobian_fun = jax.jit(jacobian_fun, static_argnums=0)

    oks = jacobian_fun(e.f, e.params, e.samples)
    oks = logic._divide_by_sqrt_n_samp(oks, e.samples)

    actual = matvec(e.v, oks)
    expected = reassemble_complex(e.S_real @ e.v_real_flat, target=e.target)
    assert tree_allclose(actual, expected)
示例#6
0
def GCNN(
    symmetries=None,
    product_table=None,
    irreps=None,
    point_group=None,
    mode="auto",
    shape=None,
    layers=None,
    features=None,
    characters=None,
    parity=None,
    param_dtype=np.float64,
    complex_output=True,
    **kwargs,
):
    r"""Implements a Group Convolutional Neural Network (G-CNN) that outputs a wavefunction
    that is invariant over a specified symmetry group.

    The G-CNN is described in `Cohen et al. <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in `Roth et al. <https://arxiv.org/pdf/2104.05085.pdf>`_ .

    The G-CNN alternates convolution operations with pointwise non-linearities. The first
    layer is a symmetrized linear transform given by DenseSymm, while the other layers are
    G-convolutions given by DenseEquivariant. The hidden layers of the G-CNN are related by
    the following equation:

    .. math ::

        {\bf f}^{i+1}_h = \Gamma( \sum_g W_{g^{-1} h} {\bf f}^i_g).


    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            :class:`nk.graph.Graph`, a :class:`nk.utils.PermutationGroup`, or an
            array :code:`[n_symm, n_sites]` specifying the permutations
            corresponding to symmetry transformations of the lattice.
        product_table: Product table describing the algebra of the symmetry group.
            Only needs to be specified if mode='fft' and symmetries is specified as an array.
        irreps: List of 3D tensors that project onto irreducible representations of the symmetry group.
            Only needs to be specified if mode='irreps' and symmetries is specified as an array.
        point_group: The point group, from which the space group is built. If symmetries is a
            graph the default point group is overwritten.
        mode: string "fft", "irreps" or "auto" specifying whether to use a fast
            fourier transform over the translation group or a fourier transform using
            the irreducible representations of the group. With "auto", "fft" is chosen
            for a lattice with a point group (or for array symmetries given only a
            product table) and "irreps" otherwise.
        shape: A tuple specifying the dimensions of the translation group.
        layers: Number of layers (not including sum layer over output).
        features: Number of features in each layer starting from the input. If a single
            number is given, all layers will have the same number of features.
        characters: Array specifying the characters of the desired symmetry representation.
        parity: Optional argument with value +/-1 that specifies the eigenvalue
            with respect to parity (only use on two level systems).
        param_dtype: The dtype of the weights.
        activation: The nonlinear activation function between hidden layers. Defaults to
            :func:`nk.nn.activation.reim_selu` .
        output_activation: The nonlinear activation before the output.
        equal_amplitudes: If True forces all basis states to have equal amplitude
            by setting :math:`\Re(\psi) = 0` .
        use_bias: If True uses a bias in all layers.
        precision: Numerical precision of the computation see :class:`jax.lax.Precision` for details.
        kernel_init: Initializer for the kernels of all layers. Defaults to
            :code:`lecun_normal(in_axis=1, out_axis=0)` which guarantees the correct variance of the
            output.
        bias_init: Initializer for the biases of all layers.
        complex_output: If True, ensures that the network output is always complex.
            Necessary when network parameters are real but some `characters` are negative.

    Returns:
        A ``GCNN_FFT`` / ``GCNN_Parity_FFT`` module in "fft" mode, or a
        ``GCNN_Irrep`` / ``GCNN_Parity_Irrep`` module in "irreps" mode
        (the parity variants when `parity` is truthy).

    Raises:
        ValueError: if the symmetry specification is incompatible with `mode`,
            if `mode` is unknown, or if real parameters, real characters with
            negative entries and ``complex_output=False`` are combined.
        TypeError: if ``mode="fft"`` is requested and no `shape` can be determined.
    """

    if isinstance(symmetries,
                  Lattice) and (point_group is not None
                                or symmetries._point_group is not None):
        # With graph try to find point group, otherwise default to automorphisms
        shape = tuple(symmetries.extent)
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        if mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group")
    elif isinstance(symmetries, PermutationGroup):
        # If we get a group and default to irrep projection
        if mode == "auto":
            mode = "irreps"
        sg = symmetries
    else:
        # Raw array of permutations: the group structure must be supplied
        # explicitly through `irreps` or `product_table`.
        if irreps is not None and (mode == "irreps" or mode == "auto"):
            mode = "irreps"
            sg = symmetries
            irreps = tuple(HashableArray(irrep) for irrep in irreps)
        elif product_table is not None and (mode == "fft" or mode == "auto"):
            mode = "fft"
            sg = symmetries
            product_table = HashableArray(product_table)
        else:
            raise ValueError(
                "Specification of symmetries is wrong or incompatible with selected mode"
            )

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument.")
        else:
            shape = tuple(shape)

    if isinstance(features, int):
        features = (features, ) * layers

    if characters is None:
        characters = HashableArray(np.ones(len(np.asarray(sg))))
    else:
        # Real parameters + real output + negative characters would produce
        # NaNs (see error message). np.asarray makes the elementwise
        # comparison also work for characters given as a plain Python list.
        if (not is_complex(characters) and not is_complex_dtype(param_dtype)
                and not complex_output
                and jnp.any(np.asarray(characters) < 0)):
            raise ValueError(
                "`complex_output` must be used with real parameters and negative "
                "characters to avoid NaN errors.")
        characters = HashableArray(characters)

    if mode == "fft":
        sym = HashableArray(np.asarray(sg))
        if product_table is None:
            product_table = HashableArray(sg.product_table)
        if parity:
            return GCNN_Parity_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                parity=parity,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
        else:
            return GCNN_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
    elif mode in ["irreps", "auto"]:
        sym = HashableArray(np.asarray(sg))

        if irreps is None:
            irreps = tuple(
                HashableArray(irrep) for irrep in sg.irrep_matrices())

        if parity:
            return GCNN_Parity_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                parity=parity,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
        else:
            return GCNN_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'irreps' or 'auto'.")