示例#1
0
    def product_table(self) -> Array:
        """
        Returns a table of indices corresponding to :math:`g^{-1} h` over the group.

        That is, if :code:`g = self[idx_g]`, :code:`h = self[idx_h]`, and
        :code:`idx_u = self.product_table[idx_g, idx_h]`, then :code:`self[idx_u]`
        corresponds to :math:`u = g^{-1} h`.

        Returns:
            An integer array of shape ``(n_symm, n_symm)``.

        Raises:
            RuntimeError: if some composed permutation is not itself an element
                of the group (the group is not closed under multiplication).
        """
        perms = self.to_array()
        inverse = perms[self.inverse].squeeze()
        n_symm = len(perms)

        # Gather every pairwise composition in one fancy-indexing step:
        # row ``i * n_symm + j`` of ``products`` is ``perms[j][inverse[i]]``.
        products = perms.transpose()[inverse.transpose()]
        products = products.reshape(-1, n_symm * n_symm).transpose()

        # Maps each group element to its index (the last index wins on duplicates).
        index_of = {HashableArray(element): index for index, element in enumerate(perms)}

        try:
            indices = [index_of[HashableArray(element)] for element in products]
        except KeyError as err:
            # Skipping the entry would leave a 0 in the table, which is
            # indistinguishable from a genuine reference to element 0.
            raise RuntimeError(
                "Symmetry group is not closed under multiplication"
            ) from err

        # Flat index ``i * n_symm + j`` -> row ``i``, column ``j`` (C order).
        return np.asarray(indices, dtype=int).reshape(n_symm, n_symm)
示例#2
0
def get_true_edges(
    basis_vectors: PositionT,
    sites: Sequence[LatticeSite],
    inside: Sequence[bool],
    basis_coord_to_site,
    extent,
    distance_atol,
):
    """
    Build the edge list of the finite lattice from candidate neighbour pairs.

    Candidate pairs come from ``get_edges`` over all site positions, with a cutoff
    equal to the longest basis vector plus ``distance_atol``.

    - Pairs with both endpoints inside the lattice are kept as-is.
    - Pairs with exactly one endpoint inside are mapped back into the lattice by
      wrapping the cell part of each endpoint's ``basis_coord`` modulo ``extent``.
    - Pairs with no endpoint inside are dropped.

    Each undirected edge is reported once.

    Args:
        basis_vectors: lattice basis vectors, one per row.
        sites: all generated sites, including the padding shell outside the lattice.
        inside: ``inside[i]`` is true when ``sites[i]`` lies inside the lattice.
        basis_coord_to_site: maps ``HashableArray(basis_coord)`` to a site index.
        extent: number of unit cells along each direction.
        distance_atol: tolerance added to the neighbour cutoff distance.

    Returns:
        A list of ``(node1, node2)`` site-index pairs.
    """
    positions = _np.array([p.position for p in sites])
    cutoff = _np.linalg.norm(basis_vectors, axis=1).max() + distance_atol
    naive_edges = get_edges(positions, cutoff)

    def wrapped_node(site):
        # Work on a copy: wrapping ``site.basis_coord`` in place would mutate the
        # caller's sites (and any dict keys that share that buffer).
        cell = _np.array(site.basis_coord)
        cell[:-1] = cell[:-1] % extent
        return basis_coord_to_site[HashableArray(cell)]

    true_edges = set()
    for node1, node2 in naive_edges:
        if inside[node1] and inside[node2]:
            edge = (node1, node2)
        elif inside[node1] or inside[node2]:
            edge = (wrapped_node(sites[node1]), wrapped_node(sites[node2]))
        else:
            continue
        # De-duplicate regardless of orientation, for both kinds of edge.
        if edge not in true_edges and edge[::-1] not in true_edges:
            true_edges.add(edge)
    return list(true_edges)
示例#3
0
 def __eq__(self, other):
     """
     Compare with another PGSymmetry through ``comparable(self._affine)``.

     Non-PGSymmetry operands are never equal.
     """
     if not isinstance(other, PGSymmetry):
         return False
     own_key = HashableArray(comparable(self._affine))
     other_key = HashableArray(comparable(other._affine))
     return own_key == other_key
示例#4
0
 def _get_id_from_dict(dict: Dict[HashableArray, int],
                       key: Array) -> Union[int, Array]:
     """
     Look up the value(s) stored in ``dict`` for the coordinate(s) in ``key``.

     A rank-1 ``key`` is one coordinate and yields one value. A rank-2 ``key``
     is a stack of coordinates (one per row) and yields an array of values.
     Coordinates that are not present map to ``None`` instead of raising.

     Raises:
         ValueError: if ``key`` is neither rank 1 nor rank 2.
     """
     rank = key.ndim
     if rank == 1:
         return dict.get(HashableArray(key), None)
     if rank == 2:
         found = [dict.get(HashableArray(row), None) for row in key]
         return _np.array(found)
     raise ValueError("Input needs to be rank 1 or rank 2 array")
示例#5
0
    def setup(self):
        """
        Normalise ``symmetry_info`` to a flat product table and derive ``n_symm``.

        A ``SymmGroup`` is replaced by its raveled product table wrapped in a
        ``HashableArray``. The table length is taken to be ``n_symm ** 2``.

        Raises:
            ValueError: if ``symmetry_info`` is not one-dimensional.
        """
        if isinstance(self.symmetry_info, SymmGroup):
            table = self.symmetry_info.product_table()
            self.symmetry_info = HashableArray(table.ravel())

        flat_table = np.asarray(self.symmetry_info)
        if flat_table.ndim != 1:
            raise ValueError("Product table should be flattened")

        self.n_symm = int(np.sqrt(flat_table.shape[0]))
示例#6
0
 def _get_id_from_dict(dict: Dict[HashableArray, int],
                       key: Array) -> Union[int, Array]:
     """
     Look up the value(s) stored in ``dict`` for the coordinate(s) in ``key``.

     A rank-1 ``key`` is one coordinate; a rank-2 ``key`` is a stack of
     coordinates (one per row) and yields an array of values.

     Raises:
         ValueError: if ``key`` is neither rank 1 nor rank 2.
         InvalidSiteError: if some coordinate is missing from ``dict``.
     """
     try:
         rank = key.ndim
         if rank == 1:
             return dict[HashableArray(key)]
         if rank == 2:
             ids = [dict[HashableArray(row)] for row in key]
             return _np.array(ids)
         raise ValueError("Input needs to be rank 1 or rank 2 array")
     except KeyError as e:
         raise InvalidSiteError(
             "Some coordinates do not correspond to a valid lattice site"
         ) from e
示例#7
0
    def product_table(self) -> Array:
        r"""
        A table of indices corresponding to :math:`g^{-1} h` over the group.

        Entry ``[i, j]`` is the index (via ``self._canonical_lookup()``) of the
        permutation ``perms[j][inverse[i]]``, where ``perms = self.to_array()`` and
        ``inverse`` holds the permutations at ``self.inverse``.

        Returns:
            An integer array of shape ``(n_symm, n_symm)``.

        Raises:
            RuntimeError: if some composed permutation is not an element of the
                group (the group is not closed under multiplication).
        """
        perms = self.to_array()
        inverse = perms[self.inverse].squeeze()
        n_symm = len(perms)

        # Gather every pairwise composition in one fancy-indexing step:
        # row ``i * n_symm + j`` of ``products`` is ``perms[j][inverse[i]]``.
        products = perms.transpose()[inverse.transpose()]
        products = products.reshape(-1, n_symm * n_symm).transpose()

        lookup = self._canonical_lookup()
        # Keep the try narrow: only a failed lookup means "not closed"; any other
        # KeyError must not be misreported as a closure failure.
        try:
            indices = [lookup[HashableArray(element)] for element in products]
        except KeyError as err:
            raise RuntimeError(
                "PermutationGroup is not closed under multiplication") from err

        # Flat index ``i * n_symm + j`` -> row ``i``, column ``j`` (C order).
        return np.asarray(indices, dtype=int).reshape(n_symm, n_symm)
示例#8
0
def create_sites(
    basis_vectors, extent, apositions, pbc, order
) -> Tuple[Sequence[LatticeSite], Sequence[bool], Dict[HashableArray, int]]:
    """
    Generate the lattice sites plus a padding shell used to detect wrapped edges.

    Along each periodic direction the cell range is extended by ``order`` cells
    on both sides; non-periodic directions are not padded.

    Args:
        basis_vectors: lattice basis vectors, one per row.
        extent: number of unit cells along each direction (any array-like).
        apositions: positions of the sites within a unit cell, in units of the
            basis vectors (one row per site).
        pbc: boolean, or one boolean per direction, selecting periodic directions.
        order: thickness of the padding shell, in unit cells.

    Returns:
        ``(sites, is_inside, coord_to_site)``:

        - ``sites`` is a list of ``LatticeSite`` with ``id=None``.
        - ``is_inside`` flags the sites whose cell lies in ``[0, extent)``.
        - ``coord_to_site`` maps ``HashableArray(basis_coord)`` to the site's
          position in ``sites``.
    """
    # Accept lists/tuples as well as arrays; ``extent - 1`` below needs an array.
    extent = _np.asarray(extent)

    # note: by modifying these, the number of shells can be tuned.
    shell_vec = _np.where(pbc, 2 * order, 0)
    shift_vec = _np.where(pbc, order, 0)

    # Broadcasting against ``extent`` keeps this 1-D even for a scalar ``pbc``.
    shell_min = _np.zeros_like(extent) - shift_vec
    shell_max = extent + shell_vec - shift_vec
    # cell coordinates, followed by the site index within the unit cell
    ranges = [slice(lo, hi) for lo, hi in zip(shell_min, shell_max)]
    ranges.append(slice(0, len(apositions)))

    basis_coords = _np.vstack([_np.ravel(x) for x in _np.mgrid[tuple(ranges)]]).T
    # The last axis (site within cell) varies fastest, so tiling the in-cell
    # offsets lines them up with the rows of ``basis_coords``.
    site_coords = (
        basis_coords[:, :-1] +
        _np.tile(apositions.T, reps=len(basis_coords) // len(apositions)).T)
    positions = site_coords @ basis_vectors

    sites = []
    coord_to_site = {}
    for idx, (coord, pos) in enumerate(zip(basis_coords, positions)):
        sites.append(
            LatticeSite(
                id=None,  # to be set later, after sorting all sites
                basis_coord=coord,
                position=pos,
            ))
        coord_to_site[HashableArray(coord)] = idx
    is_inside = ~(_np.any(basis_coords[:, :-1] < 0, axis=1)
                  | _np.any(basis_coords[:, :-1] > (extent - 1), axis=1))
    return sites, is_inside, coord_to_site
示例#9
0
def create_sites(
        basis_vectors, extent, apositions,
        pbc) -> Tuple[Tuple[LatticeSite, bool], Dict[HashableArray, int]]:
    """
    Enumerate lattice sites, padding periodic directions by one cell on each side.

    Returns:
        ``(sites, cell_coord_to_site)``:

        - ``sites`` is a list of ``(LatticeSite, inside)`` tuples, where
          ``inside`` tells whether the site's cell lies in ``[0, extent)``.
        - ``cell_coord_to_site`` maps ``HashableArray(cell_coord)`` to the index
          of the matching entry in ``sites``.

        NOTE(review): the return annotation does not describe this
        list-of-tuples shape.
    """
    n_dim = extent.size
    shell_vec = _np.zeros(n_dim, dtype=int)
    shift_vec = _np.zeros(n_dim, dtype=int)
    # note: by modifying these, the number of shells can be tuned.
    shell_vec[pbc] = 2
    shift_vec[pbc] = 1

    sites = []
    cell_coord_to_site = {}
    axes = [list(range(count)) for count in extent + shell_vec]
    for raw_cell in itertools.product(*axes):
        cell = _np.asarray(raw_cell) - shift_vec
        inside = not _np.any(cell < 0) and not _np.any(cell > (extent - 1))
        for sublattice, offset in enumerate(apositions):
            position = _np.matmul(basis_vectors.T, cell + offset)
            cell_coord = _np.array((*cell, sublattice), dtype=int)
            site_index = len(sites)
            site = LatticeSite(
                id=None,  # to be set later, after sorting all sites
                position=position,
                cell_coord=cell_coord,
            )
            sites.append((site, inside))
            cell_coord_to_site[HashableArray(cell_coord)] = site_index
    return sites, cell_coord_to_site
示例#10
0
文件: _group.py 项目: yannra/netket
    def product_table(self) -> Array:
        r"""
        Index table of the products :math:`g^{-1} h` for all pairs of elements.

        With ``g = self[idx_g]``, ``h = self[idx_h]`` and
        ``idx_u = self.product_table[idx_g, idx_h]``, the element ``self[idx_u]``
        is :math:`u = g^{-1} h`. Products are identified by looking up their
        canonical form (``self._canonical``) in ``self._canonical_lookup()``.
        """
        size = len(self)
        table = np.zeros([size, size], dtype=int)

        lookup = self._canonical_lookup()
        elements = self.elems

        # Fill one row per inverted element g^{-1}, pairing it with every h.
        for row, g_inv in enumerate(elements[self.inverse]):
            table[row, :] = [
                lookup[HashableArray(self._canonical(g_inv @ h))] for h in elements
            ]

        return table
示例#11
0
    def setup(self):
        """
        Normalise ``symmetry_info`` to a flat product table and derive ``n_symm``.

        A ``PermutationGroup`` is replaced by its raveled ``product_table``
        wrapped in a ``HashableArray``. The table length is taken to be
        ``n_symm ** 2``.

        Raises:
            ValueError: if ``symmetry_info`` is not one-dimensional.
        """
        # pylint: disable=attribute-defined-outside-init
        if isinstance(self.symmetry_info, PermutationGroup):
            table = self.symmetry_info.product_table
            self.symmetry_info = HashableArray(table.ravel())

        flat_table = np.asarray(self.symmetry_info)
        if flat_table.ndim != 1:
            raise ValueError("Product table should be flattened")

        self.n_symm = int(np.sqrt(flat_table.shape[0]))
示例#12
0
 def _canonical_lookup(self) -> dict:
     r"""
     Map the canonical form of every element to its position in ``self.elems``.

     Keys are ``HashableArray(self._canonical(element))``. If two elements share
     a canonical form, the later position is kept.
     """
     lookup = {}
     for position, element in enumerate(self.elems):
         lookup[HashableArray(self._canonical(element))] = position
     return lookup
示例#13
0
def test_HashableArray(numpy):
    """HashableArray hashes and compares by content and exposes the wrapped array."""
    a = numpy.asarray(np.random.rand(256, 128))
    b = 2 * a

    wa = HashableArray(a)
    wa2 = HashableArray(a.copy())
    wb = HashableArray(b)
    wb2 = HashableArray(b.copy())

    # Equal contents held in distinct buffers must hash and compare equal.
    assert hash(wa) == hash(wa2)
    assert wa == wa2

    # Compare against a copy: comparing ``wb`` with itself would be a tautology.
    assert hash(wb) == hash(wb2)
    assert wb == wb2

    assert wa != wb

    assert_equal(wa.wrapped, np.asarray(wa))
    assert wa.wrapped is not wa
示例#14
0
    def setup(self):
        """
        Resolve the product table and feature widths, then build the layers.

        Creates ``self.dense_symm`` (the symmetrising first layer) and
        ``self.equivariant_layers`` (``self.layers - 1`` equivariant layers that
        share one flattened product table).

        Raises:
            AttributeError: if no product table is given and ``symmetries`` is
                not a ``SymmGroup``.
            ValueError: if the flattened product table does not have
                ``n_symm ** 2`` entries, or if ``features`` is a sequence whose
                length differs from ``layers``.
        """
        self.n_symm = np.asarray(self.symmetries).shape[0]

        flat_pt = self.flattened_product_table
        if flat_pt is None:
            if not isinstance(self.symmetries, SymmGroup):
                raise AttributeError(
                    "product table must be specified if symmetries are given as an array"
                )
            flat_pt = HashableArray(self.symmetries.product_table().ravel())

        if np.asarray(flat_pt).shape[0] != np.square(self.n_symm):
            raise ValueError("Flattened product table must have shape [n_symm*n_symm]")

        # One width per layer: either a repeated int or an explicit sequence.
        if isinstance(self.features, int):
            feature_dim = [self.features] * self.layers
        elif len(self.features) != self.layers:
            raise ValueError(
                "Length of vector specifying feature dimensions must be the same as the number of layers"
            )
        else:
            feature_dim = tuple(self.features)

        self.dense_symm = nknn.DenseSymm(
            symmetries=self.symmetries,
            features=feature_dim[0],
            dtype=self.dtype,
            use_bias=self.use_bias,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            precision=self.precision,
        )

        # Layer ``idx`` maps feature_dim[idx] -> feature_dim[idx + 1].
        self.equivariant_layers = [
            nknn.DenseEquivariant(
                symmetry_info=flat_pt,
                in_features=feature_dim[idx],
                out_features=feature_dim[idx + 1],
                use_bias=self.use_bias,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for idx in range(self.layers - 1)
        ]
示例#15
0
def get_true_edges(
    basis_vectors: PositionT,
    sites: Sequence[LatticeSite],
    inside: Sequence[bool],
    basis_coord_to_site,
    extent,
    distance_atol,
    order,
):
    """
    Build per-order edge lists of the finite lattice from candidate neighbour pairs.

    ``get_edges`` is called with a cutoff of ``order`` times the longest basis
    vector plus ``distance_atol``, and returns one group of candidate pairs per
    order. Within each group:

    - Pairs with both endpoints inside the lattice are kept.
    - Pairs with exactly one endpoint inside are wrapped back into the lattice
      (the cell part of ``basis_coord`` taken modulo ``extent``).
    - All other pairs are dropped.

    Each undirected edge is reported once per group.

    Args:
        basis_vectors: lattice basis vectors, one per row.
        sites: all generated sites, including the padding shell.
        inside: ``inside[i]`` is true when ``sites[i]`` lies inside the lattice.
        basis_coord_to_site: maps ``HashableArray(basis_coord)`` to a site index.
        extent: number of unit cells along each direction.
        distance_atol: tolerance added to the cutoff distance.
        order: number of neighbour orders requested from ``get_edges``.

    Returns:
        A list whose ``k``-th entry is the list of ``(node1, node2)`` edges built
        from the ``k``-th group returned by ``get_edges``.

    Raises:
        RuntimeError: if wrapping produces an edge from a site to itself.
    """
    positions = _np.array([p.position for p in sites])
    cutoff = order * _np.linalg.norm(basis_vectors, axis=1).max() + distance_atol
    naive_edges_by_order = get_edges(positions, cutoff, order)

    def wrapped_node(site):
        # Work on a copy: wrapping ``site.basis_coord`` in place would mutate the
        # caller's sites (and any dict keys that share that buffer).
        cell = _np.array(site.basis_coord)
        cell[:-1] = cell[:-1] % extent
        return basis_coord_to_site[HashableArray(cell)]

    true_edges_by_order = []
    for k, naive_edges in enumerate(naive_edges_by_order):
        true_edges = set()
        for node1, node2 in naive_edges:
            if inside[node1] and inside[node2]:
                edge = (node1, node2)
            elif inside[node1] or inside[node2]:
                edge = (wrapped_node(sites[node1]), wrapped_node(sites[node2]))
            else:
                continue
            # De-duplicate regardless of orientation, for both kinds of edge.
            if edge in true_edges or edge[::-1] in true_edges:
                continue
            if edge[0] == edge[1]:
                raise RuntimeError(
                    f"Lattice contains self-referential edge {edge} of order {k}"
                )
            true_edges.add(edge)
        true_edges_by_order.append(list(true_edges))
    return true_edges_by_order
示例#16
0
    def inverse(self) -> Array:
        """
        Return, for each element, the index of its inverse within this group.

        The inverse permutation is obtained with ``np.argsort`` and located via
        ``self._canonical_lookup()``.

        Raises:
            RuntimeError: if the inverse of some element is not in the group.
        """
        try:
            lookup = self._canonical_lookup()
            inverses = []
            for perm in self.to_array():
                # `np.argsort` returns the platform index dtype (e.g. int64 on
                # Windows even for int32 input); cast back to `perm.dtype` so the
                # HashableArray key matches the entries of `lookup`.
                invperm = np.argsort(perm).astype(perm.dtype)
                inverses.append(lookup[HashableArray(invperm)])

            return np.asarray(inverses, dtype=int)
        except KeyError as err:
            raise RuntimeError(
                "PermutationGroup does not contain the inverse of all elements"
            ) from err
示例#17
0
    def __init__(self, permutation: Array, name: Optional[str] = None):
        r"""
        Build a `Permutation` from the array of preimages of :code:`range(N)`.

        Arguments:
            permutation: 1D array whose entry :math:`x` is :math:`g^{-1}(x)` for
                :math:`0\le x < N`, so that `V[permutation]` applies the
                permutation to the elements of `V`.
            name: optional custom name for the permutation.

        Returns:
            a `Permutation` wrapping the given permutation.
        """
        as_array = np.asarray(permutation)
        self.permutation = HashableArray(as_array)
        self.__name = name
示例#18
0
    def inverse(self) -> Array:
        """
        Return, for each element, the index of its inverse within this group.

        Raises:
            RuntimeError: if the inverse of some element is not in the group.
        """
        try:
            lookup = self._canonical_lookup()
            inverse_indices = [
                # `np.argsort` changes int32 to int64 on Windows, so cast back to
                # the permutation's own dtype before building the lookup key.
                lookup[HashableArray(np.argsort(perm).astype(perm.dtype))]
                for perm in self.to_array()
            ]
            return np.asarray(inverse_indices, dtype=int)
        except KeyError as err:
            raise RuntimeError(
                "PermutationGroup does not contain the inverse of all elements"
            ) from err
示例#19
0
    def inverse(self) -> Array:
        """
        Return, for each element, the index of its inverse within this group.

        Each affine matrix is inverted with ``np.linalg.inv``, canonicalised with
        ``_canonical_from_affine_matrix`` and looked up in
        ``self._canonical_lookup()``.

        Raises:
            RuntimeError: if the inverse of some element is not in the group.
        """
        try:
            lookup = self._canonical_lookup()
            matrices = self.to_array()

            result = np.zeros(len(self.elems), dtype=int)
            for idx in range(len(self)):
                inv_matrix = np.linalg.inv(matrices[idx])
                key = HashableArray(self._canonical_from_affine_matrix(inv_matrix))
                result[idx] = lookup[key]

            return result
        except KeyError as err:
            raise RuntimeError(
                "PointGroup does not contain the inverse of all elements"
            ) from err
示例#20
0
    def product_table(self) -> Array:
        """
        Index table of the group products, built from the affine matrices.

        The table of ``M_g M_h`` is computed first. Its rows are then reordered
        by ``self.inverse``, so row ``i`` of the result holds the products of the
        inverse of element ``i`` with every element.

        Raises:
            RuntimeError: if some product is not an element of the group.
        """
        try:
            # Products of transformation matrices: products[i, j] = M_i M_j.
            matrices = self.to_array()
            products = np.einsum("iab, jbc -> ijac", matrices, matrices)

            lookup = self._canonical_lookup()
            n_symm = len(self)

            def index_of(matrix):
                # Canonicalise the matrix, then find its index in the group.
                return lookup[HashableArray(
                    self._canonical_from_affine_matrix(matrix))]

            table = np.array(
                [[index_of(products[i, j]) for j in range(n_symm)]
                 for i in range(n_symm)],
                dtype=int,
            ).reshape(n_symm, n_symm)

            return table[self.inverse]  # reshuffle rows to match specs
        except KeyError as err:
            raise RuntimeError(
                "PointGroup is not closed under multiplication") from err
示例#21
0
def DenseEquivariant(
    symmetries,
    features: int = None,
    mode="auto",
    shape=None,
    point_group=None,
    in_features=None,
    **kwargs,
):
    r"""A group convolution operation that is equivariant over a symmetry group.

    Acts on a feature map of symmetry poses of shape [num_samples, in_features, num_symm]
    and returns a feature map of poses of shape [num_samples, features, num_symm]

    G-convolutions are described in ` Cohen et. {\it al} <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in ` Roth et. {\it al} <https://arxiv.org/pdf/2104.05085.pdf>`_

    The G-convolution generalizes the convolution to non-commuting groups:

    .. math ::

        C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h

    Group elements that differ by the same symmetry operation (i.e. :math:`g = xh`
    and :math:`g' = xh'`) are connected by the same filter.

    This layer maps an input of shape `(..., in_features, num_symm)` to an
    output of shape `(..., features, num_symm)`.

    This is a factory: depending on `symmetries` and `mode` it returns a
    `DenseEquivariantFFT`, `DenseEquivariantIrrep` or `DenseEquivariantMatrix`.

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            nk.graph.Graph, an nk.utils.PermutationGroup, a list of irreducible
            representations or a product table.
        point_group: The point group, from which the space group is built.
            If symmetries is a graph the default point group is overwritten.
        mode: string "fft, irreps, matrix, auto" specifying whether to use a fast
            fourier transform over the translation group, a fourier transform using
            the irreducible representations or by constructing the full kernel matrix.
        shape: A tuple specifying the dimensions of the translation group.
        features: The number of output features. The full output shape
            is [n_batch,features,n_symm].
        in_features: Deprecated and ignored; the number of input features is
            detected from the input.
        use_bias: A bool specifying whether to add a bias to the output (default: True).
        mask: An optional array of shape [n_sites] consisting of ones and zeros
            that can be used to give the kernel a particular shape.
        dtype: The datatype of the weights. Defaults to a 64bit float.
        precision: Optional argument specifying numerical precision of the computation.
            see `jax.lax.Precision` for details.
        kernel_init: Optional kernel initialization function. Defaults to variance scaling.
        bias_init: Optional bias initialization function. Defaults to zero initialization.

    Raises:
        ValueError: if `features` is missing or given together with the deprecated
            `out_features`, if `mode` is unknown or incompatible with `symmetries`,
            or if `symmetries` cannot be interpreted.
        TypeError: if the FFT mode is selected but no translation-group `shape`
            is available.
    """
    fft_shape_error = (
        "When requesting `mode=fft`, the shape of the translation group must be specified. "
        "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
        "the symmetries keyword argument.")

    # deprecate in_features
    if in_features is not None:
        warn_deprecation((
            "`in_features` is now automatically detected from the input and deprecated. "
            "Please remove it when calling `DenseEquivariant`."))
    if "out_features" in kwargs:
        warn_deprecation(
            "`out_features` has been renamed to `features` and the old name is "
            "now deprecated. Please update your code.")
        if features is not None:
            raise ValueError(
                "You must only specify `features`. `out_features` is deprecated."
            )
        features = kwargs.pop("out_features")

    if features is None:
        raise ValueError(
            "`features` not specified (the number of output features).")

    kwargs["features"] = features

    if isinstance(symmetries,
                  Lattice) and (point_group is not None
                                or symmetries._point_group is not None):
        shape = tuple(symmetries.extent)
        # With graph try to find point group, otherwise default to automorphisms
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        elif mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group")
    elif isinstance(symmetries, PermutationGroup):
        # If we get a group and default to irrep projection
        if mode == "auto":
            mode = "irreps"
        sg = symmetries

    elif isinstance(symmetries, Sequence):
        if mode not in ["irreps", "auto"]:
            raise ValueError(
                "Specification of symmetries incompatible with mode")
        return DenseEquivariantIrrep(symmetries, **kwargs)
    else:
        # Otherwise `symmetries` must be a square product table.
        if symmetries.ndim == 2 and symmetries.shape[0] == symmetries.shape[1]:
            if mode == "irreps":
                raise ValueError(
                    "Specification of symmetries incompatible with mode")
            elif mode == "matrix":
                return DenseEquivariantMatrix(symmetries, **kwargs)
            else:
                if shape is None:
                    raise TypeError(fft_shape_error)
                return DenseEquivariantFFT(symmetries,
                                           shape=shape,
                                           **kwargs)
        # Must be raised, not returned: callers expect a module from this factory.
        raise ValueError("Invalid Specification of Symmetries")

    if mode == "fft":
        if shape is None:
            raise TypeError(fft_shape_error)
        return DenseEquivariantFFT(HashableArray(sg.product_table),
                                   shape=shape,
                                   **kwargs)
    elif mode in ["irreps", "auto"]:
        irreps = tuple(HashableArray(irrep) for irrep in sg.irrep_matrices())
        return DenseEquivariantIrrep(irreps, **kwargs)
    elif mode == "matrix":
        return DenseEquivariantMatrix(HashableArray(sg.product_table),
                                      **kwargs)
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'matrix', 'irreps' or 'auto'."
        )
示例#22
0
    def __init__(
        self,
        basis_vectors: _np.ndarray,
        extent: _np.ndarray,
        *,
        pbc: Union[bool, Sequence[bool]] = True,
        site_offsets: Optional[_np.ndarray] = None,
        atoms_coord: Optional[_np.ndarray] = None,
        distance_atol: float = 1e-5,
        point_group: Optional[PointGroup] = None,
        max_neighbor_order: Optional[int] = None,
        custom_edges: Optional[Sequence[CustomEdgeT]] = None,
    ):
        """
        Constructs a new ``Lattice`` given its side length and the features of the unit
        cell.

        Args:
            basis_vectors: The basis vectors of the lattice. Should be an array
                of shape `(ndim, ndim)` where each `row` is a basis vector.
            extent: The number of copies of the unit cell; needs to be an array
                of length `ndim`.
            pbc: If ``True`` then the constructed lattice
                will have periodic boundary conditions, otherwise
                open boundary conditions are imposed. Can also be a boolean sequence
                of length `ndim`, indicating either open or periodic boundary conditions
                separately for each direction.
            site_offsets: The position offsets of sites in the unit cell (one site at
                the origin by default).
            atoms_coord: Alternative way of passing the site positions in the unit
                cell; it is forwarded to ``_clean_site_offsets`` together with
                ``site_offsets``. (NOTE(review): presumably a deprecated alias of
                ``site_offsets`` — confirm.)
            distance_atol: Distance below which spatial points are considered equal for
                the purpose of identifying nearest neighbors.
            point_group: Default `PointGroup` object for constructing space groups
            max_neighbor_order: For :code:`max_neighbor_order == k`, edges between up
                to :math:`k`-nearest neighbor sites (measured by their Euclidean distance)
                are included in the graph. The edges can be distinguished by their color,
                which is set to :math:`k - 1` (so nearest-neighbor edges have color 0).
                By default, nearest neighbours (:code:`max_neighbor_order=1`) are autogenerated
                unless :code:`custom_edges` is passed.
            custom_edges: (Optional) Lists all edges starting in one unit cell, which
                are repeated in every unit cell of the constructed lattice.
                Should be a list of tuples; each tuple should contain the following:
                * index of the starting point in the unit cell
                * index of the endpoint in the unit cell
                * vector pointing from the former to the latter
                * color of the edge (optional)
                If colors are not supplied, they are assigned sequentially starting from 0.
                Cannot be used together with `max_neighbor_order`.

        Raises:
            ValueError: if both ``custom_edges`` and ``max_neighbor_order`` are given.

        Examples:
            Constructs a Kagome lattice with 3 × 3 unit cells:

            >>> import numpy as np
            >>> from netket.graph import Lattice
            >>> # Hexagonal lattice basis
            >>> sqrt3 = np.sqrt(3.0)
            >>> basis = np.array([
            ...     [1.0, 0.0],
            ...     [0.5, sqrt3 / 2.0],
            ... ])
            >>> # Kagome unit cell
            >>> cell = np.array([
            ...     basis[0] / 2.0,
            ...     basis[1] / 2.0,
            ...     (basis[0]+basis[1])/2.0
            ... ])
            >>> g = Lattice(basis_vectors=basis, site_offsets=cell, extent=[3, 3])
            >>> print(g.n_nodes)
            27
            >>> print(g.basis_coords[:6])
            [[0 0 0]
             [0 0 1]
             [0 0 2]
             [0 1 0]
             [0 1 1]
             [0 1 2]]
            >>> print(g.positions[:6])
            [[0.5        0.        ]
             [0.25       0.4330127 ]
             [0.75       0.4330127 ]
             [1.         0.8660254 ]
             [0.75       1.29903811]
             [1.25       1.29903811]]

            Constructs a rectangular lattice with distinct horizontal and vertical edges:

            >>> import numpy as np
            >>> from netket.graph import Lattice
            >>> basis = np.array([
            ...     [1.0,0.0],
            ...     [0.0,0.5],
            ... ])
            >>> custom_edges = [
            ...     (0, 0, [1.0,0.0], 0),
            ...     (0, 0, [0.0,0.5], 1),
            ... ]
            >>> g = Lattice(basis_vectors=basis, pbc=False, extent=[4,6],
            ...     custom_edges=custom_edges)
            >>> print(g.n_nodes)
            24
            >>> print(len(g.edges(filter_color=0)))
            18
            >>> print(len(g.edges(filter_color=1)))
            20
        """
        # Clean input parameters
        self._basis_vectors = self._clean_basis(basis_vectors)
        # The dimension is read from the number of columns of the basis.
        self._ndim = self._basis_vectors.shape[1]

        # NOTE(review): site_pos_fractional is not used in this constructor.
        self._site_offsets, site_pos_fractional = self._clean_site_offsets(
            site_offsets,
            atoms_coord,
            self._basis_vectors,
        )
        self._pbc = self._clean_pbc(pbc, self._ndim)

        self._extent = _np.asarray(extent, dtype=int)
        # Row i is basis vector i scaled by extent[i], i.e. the edge vectors of the
        # whole lattice; its inverse is stored alongside for later use.
        self._lattice_dims = _np.expand_dims(self._extent,
                                             1) * self.basis_vectors
        self._inv_dims = _np.linalg.inv(self._lattice_dims)

        self._point_group = point_group

        # Generate sites
        self._sites, self._basis_coords, self._positions = _create_sites(
            self._basis_vectors,
            self._extent,
            self._site_offsets,
        )
        # Reverse lookups: basis coordinate -> site id, and integer-converted
        # position (via _to_integer_position) -> site index.
        self._basis_coord_to_site = {
            HashableArray(p.basis_coord): p.id
            for p in self._sites
        }
        int_positions = self._to_integer_position(self._positions)
        self._int_position_to_site = {
            HashableArray(pos): index
            for index, pos in enumerate(int_positions)
        }

        # Generate edges
        # custom_edges and max_neighbor_order are mutually exclusive; with neither
        # given, nearest-neighbour edges (order 1) are generated.
        if custom_edges is not None:
            if max_neighbor_order is not None:
                raise ValueError(
                    "custom_edges and max_neighbor_order cannot be specified at the same time"
                )
            colored_edges = get_custom_edges(
                self._basis_vectors,
                self._extent,
                self._site_offsets,
                self._pbc,
                distance_atol,
                custom_edges,
            )
        else:
            if max_neighbor_order is None:
                max_neighbor_order = 1
            colored_edges = get_nn_edges(
                self._basis_vectors,
                self._extent,
                self._site_offsets,
                self._pbc,
                distance_atol,
                max_neighbor_order,
            )

        super().__init__(colored_edges, len(self._sites))
示例#23
0
 def __hash__(self):
     """Hash of the HashableArray built from ``comparable(self._affine)``."""
     wrapped = HashableArray(comparable(self._affine))
     return hash(wrapped)
示例#24
0
def GCNN(
    symmetries=None,
    product_table=None,
    irreps=None,
    point_group=None,
    mode="auto",
    shape=None,
    layers=None,
    features=None,
    characters=None,
    parity=None,
    **kwargs,
):
    r"""Implements a Group Convolutional Neural Network (G-CNN) that outputs a wavefunction
    that is invariant over a specified symmetry group.

    The G-CNN is described in `Cohen et al. <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in `Roth et al. <https://arxiv.org/pdf/2104.05085.pdf>`_.

    The G-CNN alternates convolution operations with pointwise non-linearities. The first
    layer is a symmetrized linear transform given by DenseSymm, while the other layers are
    G-convolutions given by DenseEquivariant. The hidden layers of the G-CNN are related by
    the following equation:

    .. math ::

        {\bf f}^{i+1}_g = \Gamma( \sum_h W_{g^{-1} h} {\bf f}^i_h).

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            nk.graph.Graph, a nk.utils.PermutationGroup, or an array [n_symm, n_sites]
            specifying the permutations corresponding to symmetry transformations
            of the lattice.
        product_table: Product table describing the algebra of the symmetry group.
            Only needs to be specified if mode='fft' and symmetries is specified as an array.
        irreps: List of 3D tensors that project onto irreducible representations of the symmetry group.
            Only needs to be specified if mode='irreps' and symmetries is specified as an array.
        point_group: The point group, from which the space group is built.
            If symmetries is a graph the default point group is overwritten.
        mode: One of "fft", "irreps" or "auto". "fft" uses a fast Fourier transform
            over the translation group; "irreps" uses a Fourier transform built from
            the irreducible representations. "auto" picks "fft" for a Lattice with a
            point group, "irreps" for other graphs and permutation groups, and for
            array input whichever of `irreps` / `product_table` was supplied.
        shape: A tuple specifying the dimensions of the translation group.
        layers: Number of layers (not including sum layer over output).
        features: Number of features in each layer starting from the input. If a single
            number is given, all layers will have the same number of features.
        characters: Array specifying the characters of the desired symmetry representation.
            Defaults to all ones.
        parity: Optional argument with value +/-1 that specifies the eigenvalue
            with respect to parity (only use on two level systems).
        dtype: The dtype of the weights.
        activation: The nonlinear activation function between hidden layers. Defaults to
            :ref:`nk.nn.activation.reim_selu`.
        output_activation: The nonlinear activation before the output.
        equal_amplitudes: If True forces all basis states to have equal amplitude
            by setting Re[psi] = 0.
        use_bias: If True uses a bias in all layers.
        precision: Numerical precision of the computation see `jax.lax.Precision` for details.
        kernel_init: Initializer for the kernels of all layers. Defaults to
            `lecun_normal(in_axis=1, out_axis=0)` which guarantees the correct variance of the
            output.
        bias_init: Initializer for the biases of all layers.

    Raises:
        ValueError: If the symmetry specification is incompatible with `mode`, or
            `mode` is not one of the supported values.
        TypeError: If `mode="fft"` is requested and no translation-group shape is
            available.
    """

    if isinstance(symmetries, Lattice) and (
        point_group is not None or symmetries._point_group is not None
    ):
        # A lattice with a point group: build the space group and use its extent
        # as the translation-group shape.
        shape = tuple(symmetries.extent)
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        # Without a point group, fall back to the graph automorphisms.
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        if mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group"
            )
    elif isinstance(symmetries, PermutationGroup):
        # If we get a group and default to irrep projection
        if mode == "auto":
            mode = "irreps"
        sg = symmetries
    else:
        # Raw array input: the caller must supply the group algebra explicitly.
        if irreps is not None and (mode == "irreps" or mode == "auto"):
            mode = "irreps"
            sg = symmetries
            irreps = tuple(HashableArray(irrep) for irrep in irreps)
        elif product_table is not None and (mode == "fft" or mode == "auto"):
            mode = "fft"
            sg = symmetries
            product_table = HashableArray(product_table)
        else:
            raise ValueError(
                "Specification of symmetries is wrong or incompatible with selected mode"
            )

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument."
            )
        else:
            shape = tuple(shape)

    if isinstance(features, int):
        features = (features,) * layers

    if characters is None:
        characters = HashableArray(np.ones(len(np.asarray(sg))))
    else:
        characters = HashableArray(characters)

    if mode == "fft":
        sym = HashableArray(np.asarray(sg))
        if product_table is None:
            product_table = HashableArray(sg.product_table)
        if parity:
            return GCNN_Parity_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                parity=parity,
                **kwargs,
            )
        else:
            return GCNN_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                **kwargs,
            )
    elif mode in ["irreps", "auto"]:
        sym = HashableArray(np.asarray(sg))

        if irreps is None:
            irreps = tuple(HashableArray(irrep) for irrep in sg.irrep_matrices())

        if parity:
            return GCNN_Parity_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                parity=parity,
                **kwargs,
            )
        else:
            return GCNN_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                **kwargs,
            )
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'irreps' or 'auto'."
        )
示例#25
0
class DenseEquivariant(Module):
    r"""Implements a G-convolution that acts on a feature map of symmetry
    poses of shape [batch_size,n_symm*in_features] and returns a feature
    map of poses of shape [batch_size,n_symm*out_features]

    G-convolutions are described in `Cohen et al. <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in `Roth et al. <https://arxiv.org/pdf/2104.05085.pdf>`_

    The G-convolution generalizes the convolution to non-commuting groups:

    .. math ::

        C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h

    Symmetry poses that are linked by the same symmetry element are connected
    by the same filter. The output symmetry group is an involution over the
    input symmetry group, i.e. the symmetry group is inverted by G-convolution

    .. math ::

        {\bf C}*(g) = C(g^{-1})

    The trainable kernel has shape ``(n_symm, in_features, out_features)``:
    one ``in_features x out_features`` filter per group element.
    """

    symmetry_info: Union[HashableArray, SymmGroup]
    """Flattened product table generated by SymmGroup.product_table().ravel()
    that specifies the product of the group with its involution, or the
    SymmGroup object itself"""
    in_features: int
    """The number of symmetry-reduced input features. The full input size
    is n_symm*in_features."""
    out_features: int
    """The number of symmetry-reduced output features. The full output size
    is n_symm*out_features."""
    use_bias: bool = True
    """Whether to add a bias to the output (default: True)."""
    dtype: Any = jnp.float64
    """The dtype of the weights."""
    precision: Any = None
    """numerical precision of the computation see `jax.lax.Precision` for details."""

    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
    """Initializer for the Dense layer matrix."""
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
    """Initializer for the bias."""

    def setup(self):
        """Normalise ``symmetry_info`` into a flat NumPy product table.

        Sets ``self.flat_product_table`` (1D array of length n_symm**2) and
        ``self.n_symm``.

        Raises:
            ValueError: If the product table is not one-dimensional or its
                length is not a perfect square.
        """
        # Flax does not allow reassigning dataclass fields inside setup(), so
        # the normalised table is stored under a separate attribute instead of
        # overwriting ``self.symmetry_info``.
        symmetry_info = self.symmetry_info
        if isinstance(symmetry_info, SymmGroup):
            symmetry_info = symmetry_info.product_table().ravel()
        flat_table = np.asarray(symmetry_info)
        if flat_table.ndim != 1:
            raise ValueError("Product table should be flattened")

        n_entries = flat_table.shape[0]
        n_symm = int(round(np.sqrt(n_entries)))
        # A product table is n_symm x n_symm; anything else is malformed and
        # would otherwise be silently truncated by the integer square root.
        if n_symm * n_symm != n_entries:
            raise ValueError(
                f"Flattened product table has {n_entries} entries, which is not "
                "the square of a group order"
            )

        self.flat_product_table = flat_table
        self.n_symm = n_symm

    def full_kernel(self, kernel):
        """
        Converts the symmetry-reduced kernel of shape
        (n_symm, in_features, out_features) to the full Dense kernel of shape
        (n_symm * in_features, n_symm * out_features).

        Rows of the full kernel are ordered (in_feature, group element) and
        columns (out_feature, group element), with the feature index as the
        slow axis in both cases.
        """
        # Gather one filter per product-table entry -> (n_symm**2, in, out).
        result = jnp.take(kernel, self.flat_product_table, 0)
        result = result.reshape(
            self.n_symm, self.n_symm, self.in_features, self.out_features
        )
        # (g, h, in, out) -> (in, g, out, h) so features are the slow axes.
        result = result.transpose(2, 0, 3, 1).reshape(
            self.n_symm * self.in_features, -1
        )

        return result

    def full_bias(self, bias):
        """
        Convert symmetry-reduced bias of shape (features,) to the full bias of
        shape (n_symm * features,).

        Each feature's bias is repeated n_symm times consecutively, matching
        the (out_feature, group element) column order of ``full_kernel``.
        """
        return jnp.repeat(bias, self.n_symm)

    @compact
    def __call__(self, inputs: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last dimension.
        Args:
          inputs: The nd-array to be transformed; its last dimension is
            expected to be n_symm * in_features.
        Returns:
          The transformed input, with last dimension n_symm * out_features.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)
        inputs = jnp.asarray(inputs, dtype)

        # Only n_symm filters are ever indexed by the product table, so the
        # reduced kernel has exactly n_symm rows (this also gives the
        # initializer the correct fan-in).
        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.n_symm, self.in_features, self.out_features),
            self.dtype,
        )
        kernel = self.full_kernel(kernel)
        kernel = jnp.asarray(kernel, dtype)

        y = lax.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.out_features,), self.dtype)
            bias = jnp.asarray(self.full_bias(bias), dtype)
            y += bias

        return y
示例#26
0
    def __init__(
        self,
        basis_vectors: _np.ndarray,
        extent: _np.ndarray,
        *,
        pbc: Union[bool, Sequence[bool]] = True,
        site_offsets: Optional[_np.ndarray] = None,
        atoms_coord: Optional[_np.ndarray] = None,
        distance_atol: float = 1e-5,
        point_group: Optional[PointGroup] = None,
    ):
        """
        Constructs a new ``Lattice`` given its side length and the features of the unit
        cell.

        Args:
            basis_vectors: The basis vectors of the lattice. Should be an array
                of shape `(ndim, ndim)` where each `row` is a basis vector.
            extent: The number of copies of the unit cell; needs to be an array
                of length `ndim`.
            pbc: If ``True`` then the constructed lattice
                will have periodic boundary conditions, otherwise
                open boundary conditions are imposed. Can also be a boolean sequence
                of length `ndim`, indicating either open or closed boundary conditions
                separately for each direction.
            site_offsets: The position offsets of sites in the unit cell (one site at
                the origin by default).
            atoms_coord: Passed to ``_clean_site_offsets`` together with
                ``site_offsets``; presumably an alternative (possibly legacy) way of
                giving the unit-cell site positions — TODO confirm against
                ``_clean_site_offsets``.
            distance_atol: Distance below which spatial points are considered equal for
                the purpose of identifying nearest neighbors.
            point_group: Default `PointGroup` object for constructing space groups

        Examples:
            Constructs a Kagome lattice with 3 × 3 unit cells:

            >>> import numpy as np
            >>> from netket.graph import Lattice
            >>> # Hexagonal lattice basis
            >>> sqrt3 = np.sqrt(3.0)
            >>> basis = np.array([
            ...     [1.0, 0.0],
            ...     [0.5, sqrt3 / 2.0],
            ... ])
            >>> # Kagome unit cell
            >>> cell = np.array([
            ...     basis[0] / 2.0,
            ...     basis[1] / 2.0,
            ...     (basis[0]+basis[1])/2.0
            ... ])
            >>> g = Lattice(basis_vectors=basis, site_offsets=cell, extent=[3, 3])
            >>> print(g.n_nodes)
            27
            >>> print(g.basis_coords[:6])
            [[0 0 0]
             [0 0 1]
             [0 0 2]
             [0 1 0]
             [0 1 1]
             [0 1 2]]
            >>> print(g.positions[:6])
            [[0.5        0.        ]
             [0.25       0.4330127 ]
             [0.75       0.4330127 ]
             [1.         0.8660254 ]
             [0.75       1.29903811]
             [1.25       1.29903811]]
        """

        self._basis_vectors = self._clean_basis(basis_vectors)
        # Spatial dimension is taken from the number of columns of the basis.
        self._ndim = self._basis_vectors.shape[1]

        self._site_offsets, site_pos_fractional = self._clean_site_offsets(
            site_offsets,
            atoms_coord,
            self._basis_vectors,
        )
        self._pbc = self._clean_pbc(pbc, self._ndim)

        self._extent = _np.asarray(extent, dtype=int)

        self._point_group = point_group

        # create_sites yields candidate sites, a parallel sequence of `inside`
        # flags (presumably True for sites inside the simulation box rather than
        # periodic images — TODO confirm) and a basis-coord -> site map that
        # get_true_edges consumes below.
        sites, inside, self._basis_coord_to_site = create_sites(
            self._basis_vectors, self._extent, site_pos_fractional, self._pbc
        )
        edges = get_true_edges(
            self._basis_vectors,
            sites,
            inside,
            self._basis_coord_to_site,
            self._extent,
            distance_atol,
        )

        # Keep only nodes that appear in at least one edge and renumber them
        # contiguously from 0 in ascending order of their original index.
        # NOTE(review): a site with no edges is dropped here and will not be a
        # node of the resulting graph.
        old_nodes = sorted(set(node for edge in edges for node in edge))
        new_nodes = {old_node: new_node for new_node, old_node in enumerate(old_nodes)}

        graph = igraph.Graph()
        graph.add_vertices(len(old_nodes))
        graph.add_edges([(new_nodes[edge[0]], new_nodes[edge[1]]) for edge in edges])
        # presumably drops duplicate edges / self-loops (igraph default) — confirm
        graph.simplify()

        # Re-index the surviving sites; this mutates the Site objects returned
        # by create_sites.
        self._sites = []
        for i, site in enumerate(sites[old_node] for old_node in old_nodes):
            site.id = i
            self._sites.append(site)
        # Rebuilt because site ids changed during renumbering above.
        self._basis_coord_to_site = {
            HashableArray(p.basis_coord): p.id for p in self._sites
        }
        self._positions = _np.array([p.position for p in self._sites])
        self._basis_coords = _np.array([p.basis_coord for p in self._sites])
        # Rows of _lattice_dims are the basis vectors scaled by the extent.
        self._lattice_dims = _np.expand_dims(self._extent, 1) * self.basis_vectors
        self._inv_dims = _np.linalg.inv(self._lattice_dims)
        int_positions = self._to_integer_position(self._positions)
        self._int_position_to_site = {
            HashableArray(pos): index for index, pos in enumerate(int_positions)
        }

        super().__init__(list(graph.get_edgelist()), graph.vcount())
示例#27
0
def DenseSymm(symmetries, point_group=None, mode="auto", shape=None, **kwargs):
    r"""
    Implements a projection onto a symmetry group. The output will be
    equivariant with respect to the symmetry operations in the group and can
    be averaged to produce an invariant model.

    This layer maps an input of shape `(..., in_features, n_sites)` to an
    output of shape `(..., features, num_symm)`.

    Note: The output shape has changed to separate the feature and symmetry
    dimensions. The previous shape was [num_samples, num_symm*features] and
    the new shape is [num_samples, features, num_symm]

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            :ref:`netket.graph.Graph`, a :ref:`netket.utils.group.PermutationGroup`, or an array
            of shape :code:`(n_symm, n_sites)` specifying the permutations
            corresponding to symmetry transformations of the lattice. A
            :ref:`netket.utils.HashableArray` may also be passed.
        point_group: The point group, from which the space group is built.
            If symmetries is a graph the default point group is overwritten.
        mode: string "fft, matrix, auto" specifying whether to use a fast Fourier
            transform, matrix multiplication, or to choose a sensible default
            based on the symmetry group.
        shape: A tuple specifying the dimensions of the translation group.
        features: The number of output features. The full output shape
            is [n_batch,features,n_symm].
        use_bias: A bool specifying whether to add a bias to the output (default: True).
        mask: An optional array of shape [n_sites] consisting of ones and zeros
            that can be used to give the kernel a particular shape.
        dtype: The datatype of the weights. Defaults to a 64bit float.
        precision: Optional argument specifying numerical precision of the computation.
            see `jax.lax.Precision` for details.
        kernel_init: Optional kernel initialization function. Defaults to variance scaling.
        bias_init: Optional bias initialization function. Defaults to zero initialization.

    Raises:
        ValueError: If `mode="fft"` is requested for a graph without a point
            group, or `mode` is not one of the supported values.
        TypeError: If `mode="fft"` is requested and no translation-group shape
            is available.
    """
    if isinstance(symmetries, Lattice) and (
        point_group is not None or symmetries._point_group is not None
    ):
        # A lattice with a point group: build the space group and use its
        # extent as the translation-group shape.
        shape = tuple(symmetries.extent)
        sym = HashableArray(np.asarray(symmetries.space_group(point_group)))
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        if mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group"
            )
        # Without a point group, fall back to the graph automorphisms.
        sym = HashableArray(np.asarray(symmetries.automorphisms()))
    elif isinstance(symmetries, HashableArray):
        sym = symmetries
    else:
        sym = HashableArray(np.asarray(symmetries))

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument."
            )
        else:
            return DenseSymmFFT(sym, shape=shape, **kwargs)
    elif mode in ["matrix", "auto"]:
        return DenseSymmMatrix(sym, **kwargs)
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'matrix', or 'auto'."
        )
示例#28
0
 def __init__(self, permutation: Array):
     """Store ``permutation`` as a ``HashableArray`` built from ``np.asarray(permutation)``."""
     self.permutation = HashableArray(np.asarray(permutation))