Example #1
def vec_to_real(vec: Array) -> Tuple[Array, Callable]:
    """
    If the input vector is complex, splits it into real and
    imaginary parts and concatenates them along the 0-th axis;
    a real vector is passed through unchanged.

    This is equivalent to changing the complex storage layout
    from AoS (array of structures) to SoA (structure of arrays).

    Args:
        vec: a dense vector

    Returns:
        A pair (out, reassemble), where out is the real-valued vector
        and reassemble maps such a vector back to the original format.
    """
    out, reassemble = nkjax.tree_to_real(vec)

    if nkjax.is_complex(vec):
        re, im = out

        out = jnp.concatenate([re, im], axis=0)

        def reassemble_concat(x):
            x = tuple(jnp.split(x, 2, axis=0))
            return reassemble(x)

    else:
        reassemble_concat = reassemble

    return out, reassemble_concat
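A minimal usage sketch for the helper above, assuming NetKet is installed and
vec_to_real is in scope together with its module-level imports (nkjax for
netket.jax, jnp for jax.numpy):

import jax.numpy as jnp

v = jnp.array([1.0 + 2.0j, 3.0 - 4.0j])
flat, reassemble = vec_to_real(v)
# flat stacks the real parts before the imaginary ones: [1., 3., 2., -4.]
v_back = reassemble(flat)  # round-trips back to the original complex vector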
Example #2
def O_mean(forward_fn, params, samples, holomorphic=True):
    r"""
    compute \langle O \rangle
    i.e. the mean of the rows of the jacobian of forward_fn
    """

    # determine the output type of the forward pass
    dtype = jax.eval_shape(forward_fn, params, samples).dtype
    w = jnp.ones(samples.shape[0],
                 dtype=dtype) * (1.0 / (samples.shape[0] * mpi.n_nodes))

    homogeneous = nkjax.tree_ishomogeneous(params)
    real_params = not nkjax.tree_leaf_iscomplex(params)
    real_out = not nkjax.is_complex(jax.eval_shape(forward_fn, params,
                                                   samples))

    if homogeneous and (real_params or holomorphic):
        if real_params and not real_out:
            # R->C
            return O_vjp_rc(forward_fn, params, samples, w)
        else:
            # R->R and holomorphic C->C
            return O_vjp(forward_fn, params, samples, w)
    else:
        # R&C -> C
        # non-holomorphic C -> C
        # C -> R
        raise NotImplementedError(
            "O_mean does not support non-holomorphic functions or "
            "parameters that mix real and complex dtypes"
        )
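The helpers O_vjp and O_vjp_rc are not shown here, but the role of the uniform
weight vector w can be checked in plain JAX: the mean of the Jacobian rows equals
a single vector-Jacobian product with weights 1/N. A self-contained sketch on a
toy forward function (ignoring the MPI factor mpi.n_nodes):

import jax
import jax.numpy as jnp


def toy_forward(params, samples):
    # one scalar output per sample
    return jnp.sum(params * samples, axis=-1) ** 2


params = jnp.array([0.5, -1.0, 2.0])
samples = jnp.ones((4, 3))

# mean of the Jacobian rows, computed explicitly
jac = jax.jacrev(toy_forward)(params, samples)  # shape (4, 3)
mean_explicit = jac.mean(axis=0)

# the same quantity as a single VJP with uniform weights w_i = 1/N
_, vjp_fn = jax.vjp(lambda p: toy_forward(p, samples), params)
w = jnp.full(samples.shape[0], 1.0 / samples.shape[0])
(mean_vjp,) = vjp_fn(w)

assert jnp.allclose(mean_explicit, mean_vjp)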
Example #3
def _choose_jacobian_mode(apply_fun, pars, model_state, samples, mode, holomorphic):
    homogeneous_vars = nkjax.tree_ishomogeneous(pars)

    if holomorphic is True:
        if not homogeneous_vars:
            warnings.warn(
                dedent(
                    """The ansatz has non homogeneous variables, which might not behave well with the
                       holomorhic implemnetation.
                       Use `holomorphic=False` or mode='complex' for more accurate results but
                       lower performance.
                    """
                )
            )
        mode = "holomorphic"
    else:
        leaf_iscomplex = nkjax.tree_leaf_iscomplex(pars)
        complex_output = nkjax.is_complex(
            jax.eval_shape(
                apply_fun,
                {"params": pars, **model_state},
                samples.reshape(-1, samples.shape[-1]),
            )
        )

        if complex_output:
            if leaf_iscomplex:
                if holomorphic is None:
                    warnings.warn(
                        dedent(
                            """
                                Complex-to-Complex model detected. Defaulting to `holomorphic=False` for
                                the implementation of QGTJacobianDense.
                                If your model is holomorphic, specify `holomorphic=True` to use a more
                                performant implementation.
                                To suppress this warning specify `holomorphic`.
                                """
                        ),
                        UserWarning,
                    )
                mode = "complex"
            else:
                mode = "complex"
        else:
            mode = "real"

    if mode == "real":
        return 0
    elif mode == "complex":
        return 1
    elif mode == "holomorphic":
        return 2
    else:
        raise ValueError(f"unknown mode {mode}")
Example #4
def _to_dense(self: QGTOnTheFlyT) -> jnp.ndarray:
    """
    Convert the lazy matrix representation to a dense matrix representation.

    Returns:
        A dense matrix representation of this S matrix.
    """
    Npars = nkjax.tree_size(self.params)
    I = jax.numpy.eye(Npars)
    out = jax.vmap(lambda x: self @ x, in_axes=0)(I)

    if nkjax.is_complex(out):
        out = out.T

    return out
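The pattern above, materialising a lazy operator by applying it to the rows of
the identity with vmap, can be checked on a plain matrix-vector closure; a
sketch not tied to QGTOnTheFly:

import jax
import jax.numpy as jnp

A = jnp.arange(9.0).reshape(3, 3)
matvec = lambda x: A @ x  # stands in for `self @ x`

I = jnp.eye(A.shape[0])
dense = jax.vmap(matvec, in_axes=0)(I)

# vmap maps over the rows of I, so row i of the output is A @ e_i, i.e. the
# i-th column of A; the stacked result is therefore the transpose of A
assert jnp.allclose(dense, A.T)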
Example #5
def _to_dense(self: QGTOnTheFlyT) -> jnp.ndarray:
    """
    Convert the lazy matrix representation to a dense matrix representation.

    Returns:
        A dense matrix representation of this S matrix.
    """
    Npars = nkjax.tree_size(self._params)
    I = jax.numpy.eye(Npars)

    if self._chunking:
        # the linear_call in mat_vec_chunked currently does not have a jax batching rule,
        # so it cannot be vmapped; we use lax.scan instead, which also helps
        # reduce memory consumption
        _, out = jax.lax.scan(lambda _, x: (None, self @ x), None, I)
    else:
        out = jax.vmap(lambda x: self @ x, in_axes=0)(I)

    if nkjax.is_complex(out):
        out = out.T

    return out
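Replacing vmap with lax.scan, as in the chunked branch above, produces the same
dense matrix one column at a time and keeps only a single matrix-vector product
in flight; a sketch on the same kind of toy closure:

import jax
import jax.numpy as jnp

A = jnp.arange(9.0).reshape(3, 3)
matvec = lambda x: A @ x  # stands in for `self @ x`
I = jnp.eye(A.shape[0])

_, out_scan = jax.lax.scan(lambda carry, x: (None, matvec(x)), None, I)
out_vmap = jax.vmap(matvec, in_axes=0)(I)

assert jnp.allclose(out_scan, out_vmap)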
Example #6
def GCNN(
    symmetries=None,
    product_table=None,
    irreps=None,
    point_group=None,
    mode="auto",
    shape=None,
    layers=None,
    features=None,
    characters=None,
    parity=None,
    param_dtype=np.float64,
    complex_output=True,
    **kwargs,
):
    r"""Implements a Group Convolutional Neural Network (G-CNN) that outputs a wavefunction
    that is invariant over a specified symmetry group.

    The G-CNN is described in `Cohen et al. <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in `Roth et al. <https://arxiv.org/pdf/2104.05085.pdf>`_ .

    The G-CNN alternates convolution operations with pointwise non-linearities. The first
    layer is a symmetrized linear transform given by DenseSymm, while the other layers are
    G-convolutions given by DenseEquivariant. The hidden layers of the G-CNN are related by
    the following equation:

    .. math ::

        {\bf f}^{i+1}_g = \Gamma\left( \sum_h W_{g^{-1} h} {\bf f}^i_h \right).


    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            :class:`nk.graph.Graph`, a :class:`nk.utils.PermutationGroup`, or an
            array :code:`[n_symm, n_sites]` specifying the permutations
            corresponding to symmetry transformations of the lattice.
        product_table: Product table describing the algebra of the symmetry group.
            Only needs to be specified if mode='fft' and symmetries is specified as an array.
        irreps: List of 3D tensors that project onto irreducible representations of the symmetry group.
            Only needs to be specified if mode='irreps' and symmetries is specified as an array.
        point_group: The point group, from which the space group is built. If symmetries is a
            graph the default point group is overwritten.
        mode: string "fft, irreps, matrix, auto" specifying whether to use a fast
            Fourier transform over the translation group, a Fourier transform using
            the irreducible representations, or the full kernel matrix.
        shape: A tuple specifying the dimensions of the translation group.
        layers: Number of layers (not including sum layer over output).
        features: Number of features in each layer starting from the input. If a single
            number is given, all layers will have the same number of features.
        characters: Array specifying the characters of the desired symmetry representation.
        parity: Optional argument with value +/-1 that specifies the eigenvalue
            with respect to parity (only use on two-level systems).
        param_dtype: The dtype of the weights.
        activation: The nonlinear activation function between hidden layers. Defaults to
            :func:`nk.nn.activation.reim_selu` .
        output_activation: The nonlinear activation before the output.
        equal_amplitudes: If True, forces all basis states to have equal amplitude
            by setting :math:`\Re[\log\psi(\sigma)] = 0` .
        use_bias: If True uses a bias in all layers.
        precision: Numerical precision of the computation see :class:`jax.lax.Precision` for details.
        kernel_init: Initializer for the kernels of all layers. Defaults to
            :code:`lecun_normal(in_axis=1, out_axis=0)` which guarantees the correct variance of the
            output.
        bias_init: Initializer for the biases of all layers.
        complex_output: If True, ensures that the network output is always complex.
            Necessary when network parameters are real but some `characters` are negative.


    """

    if isinstance(symmetries,
                  Lattice) and (point_group is not None
                                or symmetries._point_group is not None):
        # With graph try to find point group, otherwise default to automorphisms
        shape = tuple(symmetries.extent)
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        if mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group")
    elif isinstance(symmetries, PermutationGroup):
        # If we get a group, default to irrep projection
        if mode == "auto":
            mode = "irreps"
        sg = symmetries
    else:
        if irreps is not None and (mode == "irreps" or mode == "auto"):
            mode = "irreps"
            sg = symmetries
            irreps = tuple(HashableArray(irrep) for irrep in irreps)
        elif product_table is not None and (mode == "fft" or mode == "auto"):
            mode = "fft"
            sg = symmetries
            product_table = HashableArray(product_table)
        else:
            raise ValueError(
                "Specification of symmetries is wrong or incompatible with selected mode"
            )

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument.")
        else:
            shape = tuple(shape)

    if isinstance(features, int):
        features = (features, ) * layers

    if characters is None:
        characters = HashableArray(np.ones(len(np.asarray(sg))))
    else:
        if (not is_complex(characters) and not is_complex_dtype(param_dtype)
                and not complex_output and jnp.any(characters < 0)):
            raise ValueError(
                "`complex_output` must be used with real parameters and negative "
                "characters to avoid NaN errors.")
        characters = HashableArray(characters)

    if mode == "fft":
        sym = HashableArray(np.asarray(sg))
        if product_table is None:
            product_table = HashableArray(sg.product_table)
        if parity:
            return GCNN_Parity_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                parity=parity,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
        else:
            return GCNN_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
    elif mode in ["irreps", "auto"]:
        sym = HashableArray(np.asarray(sg))

        if irreps is None:
            irreps = tuple(
                HashableArray(irrep) for irrep in sg.irrep_matrices())

        if parity:
            return GCNN_Parity_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                parity=parity,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
        else:
            return GCNN_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                param_dtype=param_dtype,
                complex_output=complex_output,
                **kwargs,
            )
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'irreps' or 'auto'.")