Example #1
def symm_input_warning(x_shape, new_x_shape, name):
    warn_deprecation(
        (f"{len(x_shape)}-dimensional input to {name} layer is deprecated.\n"
         f"Input shape {x_shape} has been reshaped to {new_x_shape}, where "
         "the middle dimension encodes different input channels.\n"
         "Please provide a 3-dimensional input.\nThis warning will become an "
         "error in the future."))
Example #2
def graph_to_N_depwarn(N, graph):

    if graph is not None:
        warn_deprecation(
            r"""
            The ``graph`` argument for hilbert spaces has been deprecated in v3.0.
            It has been replaced by the argument ``N`` accepting an integer, with
            the number of nodes in the graph.

            You can update your code by passing `N=_your_graph.n_nodes`.
            If you are also using `Ising`, `Heisenberg`, `BoseHubbard` or `GraphOperator`
            Hamiltonians you must now provide them with the extra argument
            ``graph=_your_graph``, as they no longer take it from the Hilbert space.
            """
        )

        if N == 1:
            return graph.n_nodes
        else:
            raise ValueError(
                "You can only specify one among the arguments `N` and `graph` "
                "(the latter is deprecated)."
            )

    return N
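A migration sketch for this deprecation (the Hypercube/Ising setup is illustrative; `graph=_your_graph` follows the warning text):

import netket as nk

g = nk.graph.Hypercube(length=4, n_dim=1)

# Old (deprecated): hi = nk.hilbert.Spin(s=1 / 2, graph=g)
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)

# Operators such as Ising now take the graph explicitly instead of
# reading it from the Hilbert space:
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)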
Example #3
def QGTOnTheFly(vstate=None, **kwargs) -> "QGTOnTheFlyT":
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.

    Args:
        vstate: The variational State.
    """
    if vstate is None:
        return partial(QGTOnTheFly, **kwargs)

    if "centered" in kwargs:
        warn_deprecation(
            "The argument `centered` is deprecated. The implementation now always behaves as if centered=False."
        )

    return QGTOnTheFlyT(
        apply_fun=vstate._apply_fun,
        params=vstate.parameters,
        samples=vstate.samples,
        model_state=vstate.model_state,
        **kwargs,
    )
Example #4
def QGTOnTheFly(vstate=None, **kwargs) -> "QGTOnTheFlyT":
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.

    Args:
        vstate: The variational State.
    """
    if vstate is None:
        return partial(QGTOnTheFly, **kwargs)

    if "centered" in kwargs:
        warn_deprecation(
            "The argument `centered` is deprecated. The implementation now always behaves as if centered=False."
        )
        kwargs.pop("centered")

    # TODO: Find a better way to handle this case
    from netket.vqs import ExactState

    if isinstance(vstate, ExactState):
        raise TypeError("Only QGTJacobianPyTree works with ExactState.")

    if jnp.ndim(vstate.samples) == 2:
        samples = vstate.samples
    else:
        samples = vstate.samples.reshape((-1, vstate.samples.shape[-1]))

    chunk_size = vstate.chunk_size
    n_samples = samples.shape[0]

    if chunk_size is None or chunk_size >= n_samples:
        mv_factory = mat_vec_factory
        chunking = False
    else:
        samples, _ = nkjax.chunk(samples, chunk_size)
        mv_factory = mat_vec_chunked_factory
        chunking = True

    mat_vec = mv_factory(
        forward_fn=vstate._apply_fun,
        params=vstate.parameters,
        model_state=vstate.model_state,
        samples=samples,
    )
    return QGTOnTheFlyT(
        _mat_vec=mat_vec,
        _params=vstate.parameters,
        _chunking=chunking,
        **kwargs,
    )
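A usage sketch for the lazy QGT (assuming a standard NetKet MCState; the Hamiltonian, gradient `F`, and solver choice are illustrative):

import jax
import netket as nk

g = nk.graph.Chain(8)
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)
vstate = nk.vqs.MCState(
    nk.sampler.MetropolisLocal(hi), nk.models.RBM(alpha=1), n_samples=512
)

# Lazy curvature matrix: nothing is computed at construction time.
S = nk.optimizer.qgt.QGTOnTheFly(vstate, diag_shift=0.01)

# Materialise the dense matrix, or solve ⟨S⟩x = ⟨F⟩ against a gradient PyTree.
S_dense = S.to_dense()
_, F = vstate.expect_and_grad(nk.operator.Ising(hilbert=hi, graph=g, h=1.0))
x, info = S.solve(jax.scipy.sparse.linalg.cg, F)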
Example #5
    def random_state(
        self,
        key=NoneType(),
        size: Optional[int] = NoneType(),
        dtype=np.float32,
        out: Optional[np.ndarray] = None,
        rgen=None,
    ) -> jnp.ndarray:
        r"""Generates either a single or a batch of uniformly distributed random states.
        Runs as :code:`random_state(self, key, size=None, dtype=np.float32)` by default.

        Args:
            key: rng state from a jax-style functional generator.
            size: If provided, returns a batch of configurations of the form :code:`(size, N)` if size
                  is an integer or :code:`(*size, N)` if it is a tuple and where :math:`N` is the Hilbert space size.
                  By default, a single random configuration with shape :code:`(#,)` is returned.
            dtype: DType of the resulting vector.
            out: Deprecated. Will be removed in v3.1
            rgen: Deprecated. Will be removed in v3.1

        Returns:
            A state or batch of states sampled from the uniform distribution on the Hilbert space.

        Example:

            >>> import netket, jax
            >>> hi = netket.hilbert.Qubit(N=2)
            >>> k1, k2 = jax.random.split(jax.random.PRNGKey(1))
            >>> print(hi.random_state(key=k1))
            [1. 0.]
            >>> print(hi.random_state(key=k2, size=2))
            [[0. 0.]
             [0. 1.]]
        """
        # legacy support
        # TODO: Remove in 3.1
        # if no positional arguments, and key is unspecified -> legacy

        if isinstance(key, NoneType):
            warn_deprecation(legacy_warn_str)
            # definitely a legacy call: dispatch on whether `size` was given
            if isinstance(size, NoneType):
                return self._random_state_legacy(size=None, out=out, rgen=rgen)
            else:
                return self._random_state_legacy(size=size, out=out, rgen=rgen)
        elif isinstance(key, tuple) or (
            isinstance(key, int) and isinstance(size, NoneType)
        ):
            # a single positional argument of a legacy type (the old `size`)
            warn_deprecation(legacy_warn_str)
            return self._random_state_legacy(size=key, out=out, rgen=rgen)
        else:
            from netket.hilbert import random

            size = size if not isinstance(size, NoneType) else None

            return random.random_state(self, key, size, dtype=dtype)
Example #6
    def n_discard(self) -> int:
        """
        DEPRECATED: Use `n_discard_per_chain` instead.

        Number of discarded samples at the beginning of the Markov chain.
        """
        warn_deprecation(
            "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
            "Please update your code to use `n_discard_per_chain`.")

        return self.n_discard_per_chain
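This getter (presumably wrapped in `@property` in the full class) is an instance of a common alias-deprecation pattern; a self-contained sketch of the same idea, using the standard library instead of NetKet's `warn_deprecation` helper:

import warnings

class Sampler:
    def __init__(self, n_discard_per_chain: int = 100):
        self.n_discard_per_chain = n_discard_per_chain

    @property
    def n_discard(self) -> int:
        # Deprecated alias: warn on access, then delegate to the new name.
        warnings.warn(
            "`n_discard` has been renamed to `n_discard_per_chain`.",
            FutureWarning,
        )
        return self.n_discard_per_chain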
Example #7
    def setup(self):
        # TODO: eventually remove this warning
        # supports a deprecated attribute
        if self.extra_bias:
            warn_deprecation(
                (
                    "`extra_bias` is detrimental for performance and is deprecated. "
                    "Please switch to the default `extra_bias=False`. Previously saved "
                    "parameters can be migrated using `nk.models.update_GCNN_parity`."
                )
            )

        self.n_symm = np.asarray(self.symmetries).shape[0]

        self.dense_symm = DenseSymmFFT(
            space_group=self.symmetries,
            shape=self.shape,
            features=self.features[0],
            dtype=self.dtype,
            use_bias=self.use_bias,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            precision=self.precision,
        )

        self.equivariant_layers = [
            DenseEquivariantFFT(
                product_table=self.product_table,
                shape=self.shape,
                features=self.features[layer + 1],
                use_bias=self.use_bias,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for layer in range(self.layers - 1)
        ]

        self.equivariant_layers_flip = [
            DenseEquivariantFFT(
                product_table=self.product_table,
                shape=self.shape,
                features=self.features[layer + 1],
                # this would bias the same outputs as self.equivariant
                use_bias=self.extra_bias and self.use_bias,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for layer in range(self.layers - 1)
        ]
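For callers, the supported configuration is the default `extra_bias=False`; a construction sketch (lattice and hyperparameters are a guess at typical values, not taken from the source; parameters saved with `extra_bias=True` would additionally need `nk.models.update_GCNN_parity`, whose signature is not shown here):

import netket as nk

g = nk.graph.Square(4)

# Default extra_bias=False; symmetries/features/layers values are illustrative.
model = nk.models.GCNN(symmetries=g, features=4, layers=2)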
Example #8
def SRLazyGMRES(diag_shift: float = 0.01, centered: bool = None, **kwargs):

    if centered is not None:
        warn_deprecation(
            "The argument `centered` is deprecated. The implementation now always behaves as if centered=False."
        )

    return SR(
        qgt.QGTOnTheFly,
        # kwargs configure the GMRES solver only; forwarding them to SR as
        # well would pass solver options (e.g. tol, maxiter) twice.
        solver=partial(jax.scipy.sparse.linalg.gmres, **kwargs),
        diag_shift=diag_shift,
    )
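The non-deprecated equivalent is to assemble the SR preconditioner directly (a sketch; the solver options are illustrative):

from functools import partial

import jax
import netket as nk

sr = nk.optimizer.SR(
    qgt=nk.optimizer.qgt.QGTOnTheFly,
    solver=partial(jax.scipy.sparse.linalg.gmres, tol=1e-5),
    diag_shift=0.01,
)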
Example #9
    def __init__(
        self,
        sampler,
        model=None,
        *,
        sampler_diag: Sampler = None,
        n_samples_diag: int = None,
        n_samples_per_rank_diag: Optional[int] = None,
        n_discard_per_chain_diag: Optional[int] = None,
        n_discard_diag: Optional[int] = None,  # deprecated
        seed=None,
        sampler_seed: Optional[int] = None,
        variables=None,
        **kwargs,
    ):
        """
        Constructs the MCMixedState.
        Arguments are the same as :class:`MCState`.

        Arguments:
            sampler: The sampler
            model: (Optional) The model. If not provided, you must provide init_fun and apply_fun.
            n_samples: the total number of samples across chains and processes when sampling (default=1000).
            n_samples_per_rank: the total number of samples across chains on one process when sampling. Cannot be
                specified together with n_samples (default=None).
            n_discard_per_chain: number of discarded samples at the beginning of each Monte Carlo chain (default=n_samples/10).
            n_samples_diag: the total number of samples across chains and processes when sampling the diagonal
                of the density matrix (default=1000).
            n_samples_per_rank_diag: the total number of samples across chains on one process when sampling the diagonal.
                Cannot be specified together with `n_samples_diag` (default=None).
            n_discard_per_chain_diag: number of discarded samples at the beginning of each Monte Carlo chain used when sampling
                the diagonal of the density matrix for observables (default=n_samples_diag/10).
            parameters: Optional PyTree of weights from which to start.
            seed: rng seed used to generate a set of parameters (only if parameters is not passed). Defaults to a random one.
            sampler_seed: rng seed used to initialise the sampler. Defaults to a random one.
            mutable: Dict specifying mutable arguments. Use it to specify if the model has a state that can change
                during evaluation, but that should not be optimised. See also flax.linen.module.apply documentation
                (default=False)
            init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to
                initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has
                a non-standard init method.
            apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to
                `model.apply(variables, σ)`. specify only if your network has a non-standard apply method.
            training_kwargs: a dict containing the optional keyword arguments to be passed to the apply_fun during training.
                Useful for example when you have a batchnorm layer that constructs the average/mean only during training.

        """

        seed, seed_diag = jax.random.split(nkjax.PRNGKey(seed))
        if sampler_seed is None:
            sampler_seed_diag = None
        else:
            sampler_seed, sampler_seed_diag = jax.random.split(
                nkjax.PRNGKey(sampler_seed)
            )

        self._diagonal = None

        hilbert_physical = sampler.hilbert.physical

        super().__init__(
            hilbert_physical,
            sampler,
            model,
            **kwargs,
            seed=seed,
            sampler_seed=sampler_seed,
            variables=variables,
        )

        if sampler_diag is None:
            sampler_diag = sampler.replace(hilbert=hilbert_physical)

        sampler_diag = sampler_diag.replace(machine_pow=1)

        diagonal_apply_fun = nkjax.HashablePartial(apply_diagonal, self._apply_fun)

        for kw in [
            "n_samples",
            "n_discard",
            "n_discard_per_chain",
        ]:  # TODO remove n_discard after deprecation.
            if kw in kwargs:
                kwargs.pop(kw)

        # TODO: remove deprecation.
        if n_discard_diag is not None and n_discard_per_chain_diag is not None:
            raise ValueError(
                "`n_discard_diag` has been renamed to `n_discard_per_chain_diag` and deprecated."
                "Specify only `n_discard_per_chain_diag`."
            )
        elif n_discard_diag is not None:
            warn_deprecation(
                "`n_discard_diag` has been renamed to `n_discard_per_chain_diag` and deprecated."
                "Please update your code to `n_discard_per_chain_diag`."
            )
            n_discard_per_chain_diag = n_discard_diag

        self._diagonal = MCState(
            sampler_diag,
            apply_fun=diagonal_apply_fun,
            n_samples=n_samples_diag,
            n_samples_per_rank=n_samples_per_rank_diag,
            n_discard_per_chain=n_discard_per_chain_diag,
            variables=self.variables,
            seed=seed_diag,
            sampler_seed=sampler_seed_diag,
            **kwargs,
        )
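A construction sketch for MCMixedState (sampler and model choices are illustrative; argument names follow the docstring above):

import netket as nk

hi = nk.hilbert.Spin(s=1 / 2, N=4)
# The sampler acts on the doubled Hilbert space of the density matrix.
sa = nk.sampler.MetropolisLocal(nk.hilbert.DoubledHilbert(hi))

vstate = nk.vqs.MCMixedState(
    sa,
    nk.models.NDM(),
    n_samples=1000,
    n_samples_diag=500,
    n_discard_per_chain_diag=50,
)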
Example #10
    def n_discard_diag(self, val) -> None:
        warn_deprecation(
            "`n_discard_diag` has been renamed to `n_discard_per_chain_diag` and deprecated. "
            "Please update your code to use `n_discard_per_chain_diag`."
        )
        self.n_discard_per_chain_diag = val
Example #11
    def __init__(
        self,
        hamiltonian: AbstractOperator,
        optimizer,
        *args,
        variational_state=None,
        preconditioner: PreconditionerT = None,
        sr: PreconditionerT = None,
        sr_restart: bool = None,
        **kwargs,
    ):
        """
        Initializes the driver class.

        Args:
            hamiltonian: The Hamiltonian of the system.
            optimizer: Determines how optimization steps are performed given the
                bare energy gradient.
            preconditioner: Determines which preconditioner to use for the loss gradient.
                This must be a tuple of `(object, solver)` as documented in the section
                `preconditioners` in the documentation. The standard preconditioner
                included with NetKet is Stochastic Reconfiguration. By default, no
                preconditioner is used and the bare gradient is passed to the optimizer.
        """
        if variational_state is None:
            variational_state = MCState(*args, **kwargs)

        if variational_state.hilbert != hamiltonian.hilbert:
            raise TypeError(
                dedent(
                    f"""the variational_state has hilbert space {variational_state.hilbert}
                    (this is normally defined by the hilbert space in the sampler), but
                    the hamiltonian has hilbert space {hamiltonian.hilbert}.
                    The two should match.
                    """))

        if sr is not None:
            if preconditioner is not None:
                raise ValueError(
                    "sr is deprecated in favour of preconditioner kwarg. You should not pass both"
                )
            else:
                preconditioner = sr
                warn_deprecation(
                    "The `sr` keyword argument is deprecated in favour of `preconditioner`. "
                    "Please update your code to `VMC(.., preconditioner=your_sr)`."
                )
        if sr_restart is not None:
            if preconditioner is None:
                raise ValueError(
                    "sr_restart only makes sense if you have a preconditioner/SR."
                )
            else:
                preconditioner.solver_restart = sr_restart
                warn_deprecation(
                    "The `sr_restart` keyword argument is deprecated in favour of specifying "
                    "`solver_restart` in the constructor of the SR object. "
                    "Please update your code to `VMC(.., preconditioner=nk.optimizer.SR(..., solver_restart=True/False))`."
                )

        # move as kwarg once deprecations are removed
        if preconditioner is None:
            preconditioner = identity_preconditioner

        super().__init__(variational_state,
                         optimizer,
                         minimized_quantity_name="Energy")

        self._ham = hamiltonian.collect()  # type: AbstractOperator

        self.preconditioner = preconditioner

        self._dp = None  # type: PyTree
        self._S = None
        self._sr_info = None
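A sketch of the construction the deprecation warnings point to (hyperparameters are illustrative):

import netket as nk

g = nk.graph.Chain(8)
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)

vstate = nk.vqs.MCState(nk.sampler.MetropolisLocal(hi), nk.models.RBM(alpha=1))
opt = nk.optimizer.Sgd(learning_rate=0.01)

# Instead of the deprecated `sr=...` and `sr_restart=...` keyword arguments:
sr = nk.optimizer.SR(diag_shift=0.01, solver_restart=True)
gs = nk.VMC(ha, opt, variational_state=vstate, preconditioner=sr)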
Example #12
    def __init__(
        self,
        sampler: Sampler,
        model=None,
        *,
        n_samples: int = None,
        n_samples_per_rank: Optional[int] = None,
        n_discard: Optional[int] = None,  # deprecated
        n_discard_per_chain: Optional[int] = None,
        variables: Optional[PyTree] = None,
        init_fun: NNInitFunc = None,
        apply_fun: Callable = None,
        sample_fun: Callable = None,
        seed: Optional[SeedT] = None,
        sampler_seed: Optional[SeedT] = None,
        mutable: bool = False,
        training_kwargs: Dict = {},
    ):
        """
        Constructs the MCState.

        Args:
            sampler: The sampler
            model: (Optional) The model. If not provided, you must provide init_fun and apply_fun.
            n_samples: the total number of samples across chains and processes when sampling (default=1000).
            n_samples_per_rank: the total number of samples across chains on one process when sampling. Cannot be
                specified together with n_samples (default=None).
            n_discard_per_chain: number of discarded samples at the beginning of each Monte Carlo chain (default=0 for exact sampler,
                and n_samples/10 for approximate sampler).
            parameters: Optional PyTree of weights from which to start.
            seed: rng seed used to generate a set of parameters (only if parameters is not passed). Defaults to a random one.
            sampler_seed: rng seed used to initialise the sampler. Defaults to a random one.
            mutable: Dict specifying mutable arguments. Use it to specify if the model has a state that can change
                during evaluation, but that should not be optimised. See also flax.linen.module.apply documentation
                (default=False)
            init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to
                initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has
                a non-standard init method.
            variables: Optional initial value for the variables (parameters and model state) of the model.
            apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to
                `model.apply(variables, σ)`. Specify only if your network has a non-standard apply method.
            sample_fun: Optional function used to sample the state, if it is not the same as `apply_fun`.
            training_kwargs: a dict containing the optional keyword arguments to be passed to the apply_fun during training.
                Useful for example when you have a batchnorm layer that constructs the average/mean only during training.
            n_discard: DEPRECATED. Please use `n_discard_per_chain` which has the same behaviour.
        """
        super().__init__(sampler.hilbert)

        # Init type 1: pass in a model
        if model is not None:
            # extract init and apply functions
            # Wrap it in an HashablePartial because if two instances of the same model are provided,
            # model.apply and model2.apply will be different methods forcing recompilation, but
            # model and model2 will have the same hash.
            _, model = maybe_wrap_module(model)

            self._model = model

            self._init_fun = nkjax.HashablePartial(
                lambda model, *args, **kwargs: model.init(*args, **kwargs),
                model)
            self._apply_fun = nkjax.HashablePartial(
                lambda model, *args, **kwargs: model.apply(*args, **kwargs),
                model)

        elif apply_fun is not None:
            self._apply_fun = apply_fun

            if init_fun is not None:
                self._init_fun = init_fun
            elif variables is None:
                raise ValueError(
                    "If you don't provide variables, you must pass a valid init_fun."
                )

            self._model = wrap_afun(apply_fun)

        else:
            raise ValueError(
                "Must either pass the model or apply_fun, otherwise how do you think "
                "we are going to evaluate the model?")

        # default argument for n_samples/n_samples_per_rank
        if n_samples is None and n_samples_per_rank is None:
            n_samples = 1000
        elif n_samples is not None and n_samples_per_rank is not None:
            raise ValueError(
                "Only one argument between `n_samples` and `n_samples_per_rank`"
                "can be specified at the same time.")

        if n_discard is not None and n_discard_per_chain is not None:
            raise ValueError(
                "`n_discard` has been renamed to `n_discard_per_chain` and deprecated."
                "Specify only `n_discard_per_chain`.")
        elif n_discard is not None:
            warn_deprecation(
                "`n_discard` has been renamed to `n_discard_per_chain` and deprecated."
                "Please update your code to use `n_discard_per_chain`.")
            n_discard_per_chain = n_discard

        if sample_fun is not None:
            self._sample_fun = sample_fun
        else:
            self._sample_fun = self._apply_fun

        self.mutable = mutable
        self.training_kwargs = flax.core.freeze(training_kwargs)

        if variables is not None:
            self.variables = variables
        else:
            self.init(seed, dtype=sampler.dtype)

        if sampler_seed is None and seed is not None:
            key, key2 = jax.random.split(nkjax.PRNGKey(seed), 2)
            sampler_seed = key2

        self._sampler_seed = sampler_seed
        self.sampler = sampler

        if n_samples is not None:
            self.n_samples = n_samples
        else:
            self.n_samples_per_rank = n_samples_per_rank

        self.n_discard_per_chain = n_discard_per_chain
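A construction sketch for the model-based init path (model and sizes are illustrative):

import netket as nk

hi = nk.hilbert.Spin(s=1 / 2, N=8)
sa = nk.sampler.MetropolisLocal(hi)

# Init type 1: pass a model; parameters are initialised from `seed`.
vs = nk.vqs.MCState(sa, nk.models.RBM(alpha=1), n_samples=1000, seed=0)

# The deprecated `n_discard` maps directly onto the new name:
vs.n_discard_per_chain = 100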
Example #13
    def __init__(
        self,
        lindbladian,
        optimizer,
        *args,
        variational_state=None,
        preconditioner=None,
        sr=None,
        sr_restart=None,
        **kwargs,
    ):
        """
        Initializes the driver class.

        Args:
            lindbladian: The Lindbladian of the system.
            optimizer: Determines how optimization steps are performed given the
                bare energy gradient.
            preconditioner: Determines which preconditioner to use for the loss gradient.
                This must be a tuple of `(object, solver)` as documented in the section
                `preconditioners` in the documentation. The standard preconditioner
                included with NetKet is Stochastic Reconfiguration. By default, no preconditioner
                is used and the bare gradient is passed to the optimizer.
        """
        if variational_state is None:
            variational_state = MCMixedState(*args, **kwargs)

        if not isinstance(lindbladian, AbstractSuperOperator):
            raise TypeError("The first argument must be a super-operator")

        if sr is not None:
            if preconditioner is not None:
                raise ValueError(
                    "sr is deprecated in favour of preconditioner kwarg. You should not pass both"
                )
            else:
                preconditioner = sr
                warn_deprecation(
                    "The `sr` keyword argument is deprecated in favour of `preconditioner`. "
                    "Please update your code to `SteadyState(.., preconditioner=your_sr)`."
                )

        if sr_restart is not None:
            if preconditioner is None:
                raise ValueError(
                    "sr_restart only makes sense if you have a preconditioner/SR."
                )
            else:
                preconditioner.solver_restart = sr_restart
                warn_deprecation(
                    "The `sr_restart` keyword argument is deprecated in favour of specifying "
                    "`solver_restart` in the constructor of the SR object. "
                    "Please update your code to `SteadyState(.., preconditioner=nk.optimizer.SR(..., solver_restart=True/False))`."
                )

        # move as kwarg once deprecations are removed
        if preconditioner is None:
            preconditioner = identity_preconditioner

        super().__init__(variational_state, optimizer, minimized_quantity_name="LdagL")

        self._lind = lindbladian
        self._ldag_l = Squared(lindbladian)

        self.preconditioner = preconditioner

        self._dp = None
        self._S = None
        self._sr_info = None
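A sketch of the corresponding non-deprecated driver construction (the Lindbladian setup is illustrative):

import netket as nk

g = nk.graph.Chain(4)
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)
jump_ops = [nk.operator.spin.sigmam(hi, i) for i in range(g.n_nodes)]
lind = nk.operator.LocalLiouvillian(ha, jump_ops)

sa = nk.sampler.MetropolisLocal(lind.hilbert)
vstate = nk.vqs.MCMixedState(sa, nk.models.NDM(), n_samples=1000)

ss = nk.SteadyState(
    lind,
    nk.optimizer.Sgd(learning_rate=0.01),
    variational_state=vstate,
    preconditioner=nk.optimizer.SR(diag_shift=0.01),
)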
Example #14
def DenseEquivariant(
    symmetries,
    features: int = None,
    mode="auto",
    shape=None,
    point_group=None,
    in_features=None,
    **kwargs,
):
    r"""A group convolution operation that is equivariant over a symmetry group.

    Acts on a feature map of symmetry poses of shape [num_samples, in_features, num_symm]
    and returns a feature map of poses of shape [num_samples, features, num_symm].

    G-convolutions are described in `Cohen et al. <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in `Roth et al. <https://arxiv.org/pdf/2104.05085.pdf>`_.

    The G-convolution generalizes the convolution to non-commuting groups:

    .. math ::

        C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h

    Group elements that differ by the same symmetry operation (i.e. :math:`g = xh`
    and :math:`g' = xh'`) are connected by the same filter.

    This layer maps an input of shape `(..., in_features, n_sites)` to an
    output of shape `(..., features, num_symm)`.

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            nk.graph.Graph, an nk.utils.PermutationGroup, a list of irreducible
            representations, or a product table.
        point_group: The point group, from which the space group is built.
            If symmetries is a graph the default point group is overwritten.
        mode: one of "fft", "irreps", "matrix", "auto", specifying whether to use a fast
            Fourier transform over the translation group, a Fourier transform using
            the irreducible representations, or the full kernel matrix.
        shape: A tuple specifying the dimensions of the translation group.
        features: The number of output features. The full output shape
            is [n_batch,features,n_symm].
        use_bias: A bool specifying whether to add a bias to the output (default: True).
        mask: An optional array of shape [n_sites] consisting of ones and zeros
            that can be used to give the kernel a particular shape.
        dtype: The datatype of the weights. Defaults to a 64bit float.
        precision: Optional argument specifying the numerical precision of the computation.
            See `jax.lax.Precision` for details.
        kernel_init: Optional kernel initialization function. Defaults to variance scaling.
        bias_init: Optional bias initialization function. Defaults to zero initialization.
    """
    # deprecate in_features
    if in_features is not None:
        warn_deprecation(
            "`in_features` is now automatically detected from the input and deprecated. "
            "Please remove it when calling `DenseEquivariant`.")
    if "out_features" in kwargs:
        warn_deprecation(
            "`out_features` has been renamed to `features` and the old name is "
            "now deprecated. Please update your code.")
        if features is not None:
            raise ValueError(
                "You must only specify `features`. `out_features` is deprecated."
            )
        features = kwargs.pop("out_features")

    if features is None:
        raise ValueError(
            "`features` not specified (the number of output features).")

    kwargs["features"] = features

    if isinstance(symmetries,
                  Lattice) and (point_group is not None
                                or symmetries._point_group is not None):
        shape = tuple(symmetries.extent)
        # With graph try to find point group, otherwise default to automorphisms
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        elif mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group.")
    elif isinstance(symmetries, PermutationGroup):
        # If we get a group and default to irrep projection
        if mode == "auto":
            mode = "irreps"
        sg = symmetries

    elif isinstance(symmetries, Sequence):
        if mode not in ["irreps", "auto"]:
            raise ValueError(
                "Specification of symmetries incompatible with mode")
        return DenseEquivariantIrrep(symmetries, **kwargs)
    else:
        if symmetries.ndim == 2 and symmetries.shape[0] == symmetries.shape[1]:
            if mode == "irreps":
                raise ValueError(
                    "Specification of symmetries incompatible with mode")
            elif mode == "matrix":
                return DenseEquivariantMatrix(symmetries, **kwargs)
            else:
                if shape is None:
                    raise TypeError(
                        "When requesting `mode=fft`, the shape of the translation group must be specified. "
                        "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                        "the symmetries keyword argument.")
                else:
                    return DenseEquivariantFFT(symmetries,
                                               shape=shape,
                                               **kwargs)
        raise ValueError("Invalid specification of symmetries.")

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument.")
        else:
            return DenseEquivariantFFT(HashableArray(sg.product_table),
                                       shape=shape,
                                       **kwargs)
    elif mode in ["irreps", "auto"]:
        irreps = tuple(HashableArray(irrep) for irrep in sg.irrep_matrices())
        return DenseEquivariantIrrep(irreps, **kwargs)
    elif mode == "matrix":
        return DenseEquivariantMatrix(HashableArray(sg.product_table),
                                      **kwargs)
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'matrix', 'irreps' or 'auto'."
        )
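A usage sketch of the dispatch above (lattice and feature sizes are illustrative; `mode` resolution follows the branches in the code, and the `point_group()` accessor is assumed to be available on preset lattices):

import netket as nk

g = nk.graph.Square(3)

# Lattice plus point group -> space group product table, mode resolves to "fft".
layer_fft = nk.nn.DenseEquivariant(g, features=4, point_group=g.point_group())

# Plain permutation group -> irreducible representations, mode resolves to "irreps".
layer_irreps = nk.nn.DenseEquivariant(g.automorphisms(), features=4)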