def product_table(self) -> Array:
    """
    Returns a table of indices corresponding to :math:`g^{-1} h` over the group.

    That is, if :code:`g = self[idx_g]`, :code:`h = self[idx_h]`, and
    :code:`idx_u = self.product_table[idx_g, idx_h]`, then :code:`self[idx_u]`
    corresponds to :math:`u = g^{-1} h`.

    Raises:
        RuntimeError: if some product :math:`g^{-1} h` is not one of the
            permutations returned by :code:`self.to_array()` (the group is not
            closed under multiplication). Previously such entries were silently
            left as 0.
    """
    perms = self.to_array()
    inverse = perms[self.inverse].squeeze()
    n_symm = len(perms)

    # Row (i * n_symm + j) of `inv_elements` is the composed permutation
    # x -> perms[j][inverse[i][x]], computed for all pairs at once.
    inv_t = inverse.transpose()
    perms_t = perms.transpose()
    inv_elements = perms_t[inv_t].reshape(-1, n_symm * n_symm).transpose()

    element_to_index = {
        HashableArray(element): index for index, element in enumerate(perms)
    }

    try:
        flat_indices = [
            element_to_index[HashableArray(element)] for element in inv_elements
        ]
    except KeyError as err:
        raise RuntimeError(
            "Symmetry group is not closed under multiplication"
        ) from err

    # Flat index k corresponds to the entry [k // n_symm, k % n_symm], i.e. a
    # row-major reshape.
    return np.asarray(flat_indices, dtype=int).reshape(n_symm, n_symm)
def get_true_edges(
    basis_vectors: PositionT,
    sites: Sequence[LatticeSite],
    inside: Sequence[bool],
    basis_coord_to_site,
    extent,
    distance_atol,
):
    """
    Returns the edges of the lattice with periodic images folded back into the cell.

    Candidate edges come from ``get_edges`` with a cutoff equal to the longest
    basis vector plus ``distance_atol``. Each candidate is handled as follows:

    * If both endpoints are inside the cell, the edge is kept unchanged.
    * If exactly one endpoint is inside, both endpoints are mapped to in-cell
      sites. The cell part (all but the last entry) of their basis coordinates
      is reduced modulo ``extent`` and looked up in ``basis_coord_to_site``.
    * If neither endpoint is inside, the edge is dropped.

    Args:
        basis_vectors: Lattice basis vectors, one per row.
        sites: All generated sites, including padding shells.
        inside: ``inside[i]`` tells whether ``sites[i]`` lies in the simulation cell.
        basis_coord_to_site: Maps ``HashableArray(basis_coord)`` to a site index.
        extent: Number of unit cells along each direction.
        distance_atol: Tolerance added to the neighbor cutoff distance.

    Returns:
        A list of ``(node1, node2)`` index pairs. A folded edge is only added
        if neither orientation is already present.
    """
    positions = _np.array([p.position for p in sites])
    naive_edges = get_edges(
        positions, _np.linalg.norm(basis_vectors, axis=1).max() + distance_atol
    )

    true_edges = set()
    for node1, node2 in naive_edges:
        site1, inside1 = sites[node1], inside[node1]
        site2, inside2 = sites[node2], inside[node2]
        if inside1 and inside2:
            true_edges.add((node1, node2))
        elif inside1 or inside2:
            # Fold copies: `basis_coord` may be a view into a shared array, and
            # modifying it in place would rewrite the caller's site data.
            cell1 = _np.array(site1.basis_coord)
            cell2 = _np.array(site2.basis_coord)
            cell1[:-1] = cell1[:-1] % extent
            cell2[:-1] = cell2[:-1] % extent
            node1 = basis_coord_to_site[HashableArray(cell1)]
            node2 = basis_coord_to_site[HashableArray(cell2)]
            edge = (node1, node2)
            if edge not in true_edges and (node2, node1) not in true_edges:
                true_edges.add(edge)
    return list(true_edges)
def __eq__(self, other):
    """Two ``PGSymmetry`` objects are equal when their comparable affine matrices match."""
    # Anything that is not a PGSymmetry never compares equal.
    if not isinstance(other, PGSymmetry):
        return False
    this_key = HashableArray(comparable(self._affine))
    other_key = HashableArray(comparable(other._affine))
    return this_key == other_key
def _get_id_from_dict(dict: Dict[HashableArray, int], key: Array) -> Union[int, Array]:
    """
    Looks up the id of one coordinate, or of every row of a coordinate stack.

    A rank-1 ``key`` yields the stored id, or ``None`` if the coordinate is
    unknown. A rank-2 ``key`` yields a numpy array with one entry per row,
    with ``None`` for rows that are not in ``dict``.

    Raises:
        ValueError: if ``key`` is neither rank 1 nor rank 2.
    """
    rank = key.ndim
    if rank == 1:
        return dict.get(HashableArray(key), None)
    if rank == 2:
        ids = [dict.get(HashableArray(row), None) for row in key]
        return _np.array(ids)
    raise ValueError("Input needs to be rank 1 or rank 2 array")
def setup(self):
    # A SymmGroup may be passed directly; replace it by its flattened product table.
    if isinstance(self.symmetry_info, SymmGroup):
        flat_table = self.symmetry_info.product_table().ravel()
        self.symmetry_info = HashableArray(flat_table)

    table = np.asarray(self.symmetry_info)
    if table.ndim != 1:
        raise ValueError("Product table should be flattened")
    # The flattened table holds n_symm * n_symm entries.
    self.n_symm = int(np.sqrt(table.shape[0]))
def _get_id_from_dict(dict: Dict[HashableArray, int], key: Array) -> Union[int, Array]:
    """
    Looks up the id of one coordinate, or of every row of a coordinate stack.

    Raises:
        ValueError: if ``key`` is neither rank 1 nor rank 2.
        InvalidSiteError: if any coordinate is missing from ``dict``.
    """
    if key.ndim not in (1, 2):
        raise ValueError("Input needs to be rank 1 or rank 2 array")
    try:
        if key.ndim == 1:
            return dict[HashableArray(key)]
        return _np.array([dict[HashableArray(row)] for row in key])
    except KeyError as e:
        raise InvalidSiteError(
            "Some coordinates do not correspond to a valid lattice site"
        ) from e
def product_table(self) -> Array:
    """
    Computes the multiplication table of the permutation group.

    With ``perms = self.to_array()``, entry ``[i, j]`` is the index (according
    to ``self._canonical_lookup()``) of the composed permutation
    ``x -> perms[j][perms[self.inverse[i]][x]]``.

    Raises:
        RuntimeError: if one of these products is not an element of the group.
    """
    try:
        perms = self.to_array()
        inverse = perms[self.inverse].squeeze()
        n_symm = len(perms)
        table = np.zeros([n_symm, n_symm], dtype=int)

        # After the final transpose, row (i * n_symm + j) is the composed
        # permutation perms[j][inverse[i]].
        composed = perms.transpose()[inverse.transpose()]
        flat_products = composed.reshape(-1, n_symm * n_symm).transpose()

        lookup = self._canonical_lookup()
        pairs = np.asarray(
            [
                (flat, lookup[HashableArray(product)])
                for flat, product in enumerate(flat_products)
            ]
        )
        table[pairs[:, 0] // n_symm, pairs[:, 0] % n_symm] = pairs[:, 1]
        return table
    except KeyError as err:
        raise RuntimeError(
            "PermutationGroup is not closed under multiplication") from err
def create_sites(
    basis_vectors, extent, apositions, pbc, order
) -> Tuple[Sequence[LatticeSite], Sequence[bool], Dict[HashableArray, int]]:
    """
    Generates the sites of the simulation cell plus padding shells for periodic directions.

    Along every periodic direction the range of cells is widened by ``order``
    cells on both sides. Non-periodic directions are not padded.

    Args:
        basis_vectors: Lattice basis vectors, one per row.
        extent: Number of unit cells per direction (``extent - 1`` is evaluated,
            so presumably a numpy array).
        apositions: Fractional positions of the sites within the unit cell, one per row.
        pbc: Boolean flag(s) selecting the periodic directions.
        order: Number of padding cells added on each side of a periodic direction.

    Returns:
        ``(sites, is_inside, coord_to_site)``:

        * ``sites`` is a list of ``LatticeSite`` whose ``basis_coord`` is
          ``(cell..., site_index)``.
        * ``is_inside`` is a boolean array marking the sites whose cell lies in
          ``[0, extent - 1]``.
        * ``coord_to_site`` maps ``HashableArray(basis_coord)`` to the site's
          position in ``sites``.
    """
    # note: by modifying these, the number of shells can be tuned.
    shell_vec = _np.where(pbc, 2 * order, 0)
    shift_vec = _np.where(pbc, order, 0)
    lower_cells = 0 - shift_vec
    upper_cells = _np.asarray(extent) + shell_vec - shift_vec

    # One integer axis per cell direction, plus one over the sites in the unit cell.
    grid_axes = tuple(slice(lo, hi) for lo, hi in zip(lower_cells, upper_cells))
    grid_axes += (slice(0, len(apositions)),)
    grid = _np.mgrid[grid_axes]
    basis_coords = grid.reshape(len(grid_axes), -1).T

    cells = basis_coords[:, :-1]
    n_repeats = len(basis_coords) // len(apositions)
    fractional = cells + _np.tile(apositions.T, reps=n_repeats).T
    positions = fractional @ basis_vectors

    sites = [
        LatticeSite(
            id=None,  # to be set later, after sorting all sites
            basis_coord=coord,
            position=pos,
        )
        for coord, pos in zip(basis_coords, positions)
    ]
    coord_to_site = {
        HashableArray(coord): idx for idx, coord in enumerate(basis_coords)
    }

    is_inside = _np.all((cells >= 0) & (cells <= (extent - 1)), axis=1)
    return sites, is_inside, coord_to_site
def create_sites(
    basis_vectors, extent, apositions, pbc
) -> Tuple[Tuple[LatticeSite, bool], Dict[HashableArray, int]]:
    """
    Generates the sites of the simulation cell plus one padding shell per periodic direction.

    Returns:
        ``(sites, cell_coord_to_site)``:

        * ``sites`` is a list of ``(LatticeSite, inside)`` pairs, where
          ``inside`` tells whether the site's cell lies in ``[0, extent - 1]``.
        * ``cell_coord_to_site`` maps ``HashableArray((cell..., atom_index))``
          to the index of the pair in ``sites``.
    """
    n_dims = extent.size
    shell_vec = _np.zeros(n_dims, dtype=int)
    shift_vec = _np.zeros(n_dims, dtype=int)
    # note: by modifying these, the number of shells can be tuned.
    shell_vec[pbc] = 2
    shift_vec[pbc] = 1

    cell_counts = extent + shell_vec
    sites = []
    cell_coord_to_site = {}
    for raw_cell in itertools.product(*(range(count) for count in cell_counts)):
        cell = _np.asarray(raw_cell) - shift_vec
        inside = not (_np.any(cell < 0) or _np.any(cell > (extent - 1)))
        first_index = len(sites)
        for atom_index, atom_coord in enumerate(apositions):
            real_coord = _np.matmul(basis_vectors.T, cell + atom_coord)
            cell_coord = _np.array((*cell, atom_index), dtype=int)
            site = LatticeSite(
                id=None,  # to be set later, after sorting all sites
                position=real_coord,
                cell_coord=cell_coord,
            )
            sites.append((site, inside))
            cell_coord_to_site[HashableArray(cell_coord)] = first_index + atom_index
    return sites, cell_coord_to_site
def product_table(self) -> Array:
    r"""
    Index table of the products :math:`g^{-1} h` over the group.

    With

    .. code::

        g = self[idx_g]
        h = self[idx_h]
        idx_u = self.product_table[idx_g, idx_h]

    the element :code:`self[idx_u]` equals :math:`u = g^{-1} h` .
    """
    size = len(self)
    table = np.zeros([size, size], dtype=int)
    lookup = self._canonical_lookup()

    for row, left in enumerate(self.elems[self.inverse]):
        for col, right in enumerate(self.elems):
            canonical = self._canonical(left @ right)
            table[row, col] = lookup[HashableArray(canonical)]
    return table
def setup(self):
    # pylint: disable=attribute-defined-outside-init
    # A PermutationGroup may be passed directly; swap it for its flattened product table.
    if isinstance(self.symmetry_info, PermutationGroup):
        flat_table = self.symmetry_info.product_table.ravel()
        self.symmetry_info = HashableArray(flat_table)

    table = np.asarray(self.symmetry_info)
    if table.ndim != 1:
        raise ValueError("Product table should be flattened")
    # The flattened table holds n_symm * n_symm entries.
    self.n_symm = int(np.sqrt(table.shape[0]))
def _canonical_lookup(self) -> dict:
    r"""
    Builds a dictionary from the canonical form of each element to its index in `self.elems`.
    """
    lookup = {}
    for position, elem in enumerate(self.elems):
        # Later duplicates overwrite earlier ones, as in a dict comprehension.
        lookup[HashableArray(self._canonical(elem))] = position
    return lookup
def test_HashableArray(numpy):
    """HashableArray wrappers of equal arrays hash and compare equal; different arrays do not."""
    a = numpy.asarray(np.random.rand(256, 128))
    b = 2 * a

    wa = HashableArray(a)
    wa2 = HashableArray(a.copy())
    wb = HashableArray(b)
    # An independent copy of `b`; comparing `wb` with itself would be a tautology.
    wb2 = HashableArray(b.copy())

    assert hash(wa) == hash(wa2)
    assert wa == wa2
    assert hash(wb) == hash(wb2)
    assert wb == wb2

    assert wa != wb

    # The wrapper converts back to the wrapped data but is a distinct object.
    assert_equal(wa.wrapped, np.asarray(wa))
    assert wa.wrapped is not wa
def setup(self):
    self.n_symm = np.asarray(self.symmetries).shape[0]

    # Use the explicit product table if given; otherwise derive it from the group.
    if self.flattened_product_table is not None:
        flat_pt = self.flattened_product_table
    elif isinstance(self.symmetries, SymmGroup):
        flat_pt = HashableArray(self.symmetries.product_table().ravel())
    else:
        raise AttributeError(
            "product table must be specified if symmetries are given as an array"
        )

    if np.asarray(flat_pt).shape[0] != np.square(self.n_symm):
        raise ValueError("Flattened product table must have shape [n_symm*n_symm]")

    # Per-layer feature dimensions: a single int is broadcast to every layer.
    if isinstance(self.features, int):
        feature_dim = [self.features] * self.layers
    elif len(self.features) != self.layers:
        raise ValueError(
            """Length of vector specifying feature dimensions must be the same as the number of layers"""
        )
    else:
        feature_dim = tuple(self.features)

    # First layer: symmetrized dense transform from the input.
    self.dense_symm = nknn.DenseSymm(
        symmetries=self.symmetries,
        features=feature_dim[0],
        dtype=self.dtype,
        use_bias=self.use_bias,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        precision=self.precision,
    )

    # Remaining layers: G-convolutions chaining consecutive feature dimensions.
    eq_layers = []
    for idx in range(self.layers - 1):
        eq_layers.append(
            nknn.DenseEquivariant(
                symmetry_info=flat_pt,
                in_features=feature_dim[idx],
                out_features=feature_dim[idx + 1],
                use_bias=self.use_bias,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
        )
    self.equivariant_layers = eq_layers
def get_true_edges(
    basis_vectors: PositionT,
    sites: Sequence[LatticeSite],
    inside: Sequence[bool],
    basis_coord_to_site,
    extent,
    distance_atol,
    order,
):
    """
    Returns the edges of the lattice, grouped by neighbor order, with periodic images folded back.

    Candidate edges come from ``get_edges`` (one list per order) with a cutoff
    of ``order`` times the longest basis vector plus ``distance_atol``. Each
    candidate is handled as follows:

    * If both endpoints are inside the cell, the edge is kept unchanged.
    * If exactly one endpoint is inside, both endpoints are mapped to in-cell
      sites. The cell part (all but the last entry) of their basis coordinates
      is reduced modulo ``extent`` and looked up in ``basis_coord_to_site``.
    * If neither endpoint is inside, the edge is dropped.

    Args:
        basis_vectors: Lattice basis vectors, one per row.
        sites: All generated sites, including padding shells.
        inside: ``inside[i]`` tells whether ``sites[i]`` lies in the simulation cell.
        basis_coord_to_site: Maps ``HashableArray(basis_coord)`` to a site index.
        extent: Number of unit cells along each direction.
        distance_atol: Tolerance added to the neighbor cutoff distance.
        order: Maximum neighbor order passed to ``get_edges``.

    Returns:
        A list (one entry per order, indexed from 0) of lists of
        ``(node1, node2)`` pairs.

    Raises:
        RuntimeError: if folding an edge maps both endpoints to the same site.
    """
    positions = _np.array([p.position for p in sites])
    naive_edges_by_order = get_edges(
        positions,
        order * _np.linalg.norm(basis_vectors, axis=1).max() + distance_atol,
        order,
    )

    true_edges_by_order = []
    for k, naive_edges in enumerate(naive_edges_by_order):
        true_edges = set()
        for node1, node2 in naive_edges:
            site1, inside1 = sites[node1], inside[node1]
            site2, inside2 = sites[node2], inside[node2]
            if inside1 and inside2:
                true_edges.add((node1, node2))
            elif inside1 or inside2:
                # Fold copies: `basis_coord` may be a view into a shared array,
                # and modifying it in place would rewrite the caller's site data.
                cell1 = _np.array(site1.basis_coord)
                cell2 = _np.array(site2.basis_coord)
                cell1[:-1] = cell1[:-1] % extent
                cell2[:-1] = cell2[:-1] % extent
                folded1 = basis_coord_to_site[HashableArray(cell1)]
                folded2 = basis_coord_to_site[HashableArray(cell2)]
                edge = (folded1, folded2)
                if edge not in true_edges and (folded2, folded1) not in true_edges:
                    if folded1 == folded2:
                        raise RuntimeError(
                            f"Lattice contains self-referential edge {(folded1, folded2)} of order {k}"
                        )
                    true_edges.add(edge)
        true_edges_by_order.append(list(true_edges))
    return true_edges_by_order
def inverse(self) -> Array:
    """
    Returns, for each element of the group, the index of its inverse.

    The inverse of a permutation array ``perm`` is ``np.argsort(perm)``; its
    index is found through ``self._canonical_lookup()``.

    Raises:
        RuntimeError: if the inverse of some element is not in the group.
    """
    try:
        lookup = self._canonical_lookup()
        inverses = []
        for perm in self.to_array():
            # `np.argsort` may return a different integer dtype than `perm`
            # (e.g. int64 for int32 input on Windows); cast back so the hashed
            # key matches the one stored in `lookup`.
            invperm = np.argsort(perm).astype(perm.dtype)
            inverses.append(lookup[HashableArray(invperm)])
        return np.asarray(inverses, dtype=int)
    except KeyError as err:
        raise RuntimeError(
            "PermutationGroup does not contain the inverse of all elements"
        ) from err
def __init__(self, permutation: Array, name: Optional[str] = None):
    r"""
    Builds a `Permutation` from the array of preimages of :code:`range(N)`.

    Arguments:
        permutation: 1D array whose entry :math:`x` is :math:`g^{-1}(x)` for
            :math:`0\le x < N`, so that `V[permutation]` permutes the elements
            of `V` as desired.
        name: optional custom name for the permutation.
    """
    perm_array = np.asarray(permutation)
    self.permutation = HashableArray(perm_array)
    self.__name = name
def inverse(self) -> Array:
    """
    Returns, for each element of the group, the index of its inverse.

    Raises:
        RuntimeError: if the inverse of some element is not in the group.
    """
    try:
        lookup = self._canonical_lookup()
        # `np.argsort` changes int32 to int64 on Windows; the `astype` restores
        # the original dtype so the key matches the lookup table.
        inverses = [
            lookup[HashableArray(np.argsort(perm).astype(perm.dtype))]
            for perm in self.to_array()
        ]
    except KeyError as err:
        raise RuntimeError(
            "PermutationGroup does not contain the inverse of all elements"
        ) from err
    return np.asarray(inverses, dtype=int)
def inverse(self) -> Array:
    """
    Returns, for each element, the index of its inverse.

    The inverse is found by inverting the element's affine matrix and looking
    up its canonical form.

    Raises:
        RuntimeError: if the inverse of some element is not in the group.
    """
    try:
        lookup = self._canonical_lookup()
        matrices = self.to_array()
        result = np.zeros(len(self.elems), dtype=int)
        for idx in range(len(self)):
            inv_matrix = np.linalg.inv(matrices[idx])
            key = HashableArray(self._canonical_from_affine_matrix(inv_matrix))
            result[idx] = lookup[key]
        return result
    except KeyError as err:
        raise RuntimeError(
            "PointGroup does not contain the inverse of all elements"
        ) from err
def product_table(self) -> Array:
    """
    Computes the product table from the affine matrices of the group elements.

    Entry ``[g, h]`` is the index of the element whose affine matrix equals
    ``M[self.inverse[g]] @ M[h]``, where ``M = self.to_array()``.

    Raises:
        RuntimeError: if some product is not an element of the group.
    """
    try:
        matrices = self.to_array()
        # products[i, j] = M_i @ M_j for every pair of affine matrices
        products = np.einsum(
            "iab, jbc -> ijac", matrices, matrices)
        lookup = self._canonical_lookup()
        size = len(self)
        table = np.zeros((size, size), dtype=int)
        for i, j in np.ndindex(size, size):
            canonical = self._canonical_from_affine_matrix(products[i, j])
            table[i, j] = lookup[HashableArray(canonical)]
        # Reorder the rows so that row g holds the products for g^{-1}.
        return table[self.inverse]
    except KeyError as err:
        raise RuntimeError(
            "PointGroup is not closed under multiplication") from err
def DenseEquivariant(
    symmetries,
    features: int = None,
    mode="auto",
    shape=None,
    point_group=None,
    in_features=None,
    **kwargs,
):
    r"""A group convolution operation that is equivariant over a symmetry group.

    Acts on a feature map of symmetry poses of shape [num_samples, in_features, num_symm]
    and returns a feature map of poses of shape [num_samples, features, num_symm].

    G-convolutions are described in ` Cohen et. {\it al} <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in ` Roth et. {\it al} <https://arxiv.org/pdf/2104.05085.pdf>`_

    The G-convolution generalizes the convolution to non-commuting groups:

    .. math ::

        C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h

    Group elements that differ by the same symmetry operation (i.e. :math:`g = xh`
    and :math:`g' = xh'`) are connected by the same filter.

    This layer maps an input of shape `(..., in_features, num_symm)` to an output of
    shape `(..., features, num_symm)`.

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            nk.graph.Graph, an nk.utils.PermutationGroup, a list of irreducible
            representations or a product table.
        point_group: The point group, from which the space group is built.
            If symmetries is a graph the default point group is overwritten.
        mode: string "fft, irreps, matrix, auto" specifying whether to use a fast
            fourier transform over the translation group, a fourier transform using
            the irreducible representations or by constructing the full kernel matrix.
        shape: A tuple specifying the dimensions of the translation group.
        features: The number of output features. The full output shape
            is [n_batch,features,n_symm].
        in_features: Deprecated and ignored; the number of input features is
            detected from the input.
        use_bias: A bool specifying whether to add a bias to the output (default: True).
        mask: An optional array of shape [n_sites] consisting of ones and zeros
            that can be used to give the kernel a particular shape.
        dtype: The datatype of the weights. Defaults to a 64bit float.
        precision: Optional argument specifying numerical precision of the computation.
            see `jax.lax.Precision` for details.
        kernel_init: Optional kernel initialization function. Defaults to variance scaling.
        bias_init: Optional bias initialization function. Defaults to zero initialization.

    Raises:
        ValueError: if the symmetry specification is invalid or incompatible with `mode`.
        TypeError: if `mode="fft"` is requested without a translation-group shape.
    """
    # deprecate in_features
    if in_features is not None:
        warn_deprecation(
            (
                "`in_features` is now automatically detected from the input and deprecated. "
                "Please remove it when calling `DenseEquivariant`."
            )
        )
    if "out_features" in kwargs:
        warn_deprecation(
            "`out_features` has been renamed to `features` and the old name is "
            "now deprecated. Please update your code."
        )
        if features is not None:
            raise ValueError(
                "You must only specify `features`. `out_features` is deprecated."
            )
        features = kwargs.pop("out_features")

    if features is None:
        raise ValueError(
            "`features` not specified (the number of output features).")

    kwargs["features"] = features

    if isinstance(symmetries, Lattice) and (
        point_group is not None or symmetries._point_group is not None
    ):
        shape = tuple(symmetries.extent)
        # With graph try to find point group, otherwise default to automorphisms
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        elif mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group"
            )
    elif isinstance(symmetries, PermutationGroup):
        # If we get a group and default to irrep projection
        if mode == "auto":
            mode = "irreps"
        sg = symmetries
    elif isinstance(symmetries, Sequence):
        # A sequence of symmetries is interpreted as a list of irreps.
        if mode not in ["irreps", "auto"]:
            raise ValueError(
                "Specification of symmetries incompatible with mode")
        return DenseEquivariantIrrep(symmetries, **kwargs)
    else:
        # An explicit array must be a square product table.
        if symmetries.ndim == 2 and symmetries.shape[0] == symmetries.shape[1]:
            if mode == "irreps":
                raise ValueError(
                    "Specification of symmetries incompatible with mode")
            elif mode == "matrix":
                return DenseEquivariantMatrix(symmetries, **kwargs)
            else:
                if shape is None:
                    raise TypeError(
                        "When requesting `mode=fft`, the shape of the translation group must be specified. "
                        "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                        "the symmetries keyword argument."
                    )
                else:
                    return DenseEquivariantFFT(symmetries, shape=shape, **kwargs)
        # Previously this exception was *returned* instead of raised, so callers
        # silently received a ValueError instance in place of a layer.
        raise ValueError("Invalid Specification of Symmetries")

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument."
            )
        else:
            return DenseEquivariantFFT(
                HashableArray(sg.product_table), shape=shape, **kwargs
            )
    elif mode in ["irreps", "auto"]:
        irreps = tuple(HashableArray(irrep) for irrep in sg.irrep_matrices())
        return DenseEquivariantIrrep(irreps, **kwargs)
    elif mode == "matrix":
        return DenseEquivariantMatrix(HashableArray(sg.product_table), **kwargs)
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'matrix', 'irreps' or 'auto'."
        )
def __init__(
    self,
    basis_vectors: _np.ndarray,
    extent: _np.ndarray,
    *,
    pbc: Union[bool, Sequence[bool]] = True,
    site_offsets: Optional[_np.ndarray] = None,
    atoms_coord: Optional[_np.ndarray] = None,
    distance_atol: float = 1e-5,
    point_group: Optional[PointGroup] = None,
    max_neighbor_order: Optional[int] = None,
    custom_edges: Optional[Sequence[CustomEdgeT]] = None,
):
    """
    Constructs a new ``Lattice`` given its side length and the features of the unit cell.

    Args:
        basis_vectors: The basis vectors of the lattice. Should be an array
            of shape `(ndim, ndim)` where each `row` is a basis vector.
        extent: The number of copies of the unit cell; needs to be an array
            of length `ndim`.
        pbc: If ``True`` then the constructed lattice
            will have periodic boundary conditions, otherwise
            open boundary conditions are imposed. Can also be a boolean sequence
            of length `ndim`, indicating either open or closed boundary conditions
            separately for each direction.
        site_offsets: The position offsets of sites in the unit cell (one site at
            the origin by default).
        atoms_coord: Passed to the site-offset cleanup together with
            ``site_offsets``. NOTE(review): appears to be an alternative (legacy?)
            way of giving the site offsets; confirm.
        distance_atol: Distance below which spatial points are considered equal
            for the purpose of identifying nearest neighbors.
        point_group: Default `PointGroup` object for constructing space groups
        max_neighbor_order: For :code:`max_neighbor_order == k`, edges between up
            to :math:`k`-nearest neighbor sites (measured by their Euclidean distance)
            are included in the graph. The edges can be distinguished by their color,
            which is set to :math:`k - 1` (so nearest-neighbor edges have color 0).
            By default, nearest neighbors (:code:`max_neighbor_order=1`) are
            autogenerated unless :code:`custom_edges` is passed.
        custom_edges: (Optional) Lists all edges starting in one unit cell, which
            are repeated in every unit cell of the constructed lattice.
            Should be a list of tuples; each tuple should contain the following:

            * index of the starting point in the unit cell
            * index of the endpoint in the unit cell
            * vector pointing from the former to the latter
            * color of the edge (optional)

            If colors are not supplied, they are assigned sequentially starting from 0.
            Cannot be used together with `max_neighbor_order`.

    Raises:
        ValueError: if both `custom_edges` and `max_neighbor_order` are given.

    Examples:
        Constructs a Kagome lattice with 3 × 3 unit cells:

        >>> import numpy as np
        >>> from netket.graph import Lattice
        >>> # Hexagonal lattice basis
        >>> sqrt3 = np.sqrt(3.0)
        >>> basis = np.array([
        ...     [1.0, 0.0],
        ...     [0.5, sqrt3 / 2.0],
        ... ])
        >>> # Kagome unit cell
        >>> cell = np.array([
        ...     basis[0] / 2.0,
        ...     basis[1] / 2.0,
        ...     (basis[0]+basis[1])/2.0
        ... ])
        >>> g = Lattice(basis_vectors=basis, site_offsets=cell, extent=[3, 3])
        >>> print(g.n_nodes)
        27
        >>> print(g.basis_coords[:6])
        [[0 0 0]
         [0 0 1]
         [0 0 2]
         [0 1 0]
         [0 1 1]
         [0 1 2]]
        >>> print(g.positions[:6])
        [[0.5        0.        ]
         [0.25       0.4330127 ]
         [0.75       0.4330127 ]
         [1.         0.8660254 ]
         [0.75       1.29903811]
         [1.25       1.29903811]]

        Constructs a rectangular lattice with distinct horizontal and vertical edges:

        >>> import numpy as np
        >>> from netket.graph import Lattice
        >>> basis = np.array([
        ...     [1.0,0.0],
        ...     [0.0,0.5],
        ... ])
        >>> custom_edges = [
        ...     (0, 0, [1.0,0.0], 0),
        ...     (0, 0, [0.0,0.5], 1),
        ... ]
        >>> g = Lattice(basis_vectors=basis, pbc=False, extent=[4,6],
        ...     custom_edges=custom_edges)
        >>> print(g.n_nodes)
        24
        >>> print(len(g.edges(filter_color=0)))
        18
        >>> print(len(g.edges(filter_color=1)))
        20
    """
    # Clean input parameters
    self._basis_vectors = self._clean_basis(basis_vectors)
    self._ndim = self._basis_vectors.shape[1]

    # Only the cleaned site offsets are stored; the second returned value is
    # not needed by this constructor.
    self._site_offsets, _ = self._clean_site_offsets(
        site_offsets,
        atoms_coord,
        self._basis_vectors,
    )
    self._pbc = self._clean_pbc(pbc, self._ndim)

    self._extent = _np.asarray(extent, dtype=int)
    # Row i is basis vector i scaled by the number of cells along that direction.
    self._lattice_dims = _np.expand_dims(self._extent, 1) * self.basis_vectors
    self._inv_dims = _np.linalg.inv(self._lattice_dims)

    self._point_group = point_group

    # Generate sites
    self._sites, self._basis_coords, self._positions = _create_sites(
        self._basis_vectors,
        self._extent,
        self._site_offsets,
    )
    self._basis_coord_to_site = {
        HashableArray(p.basis_coord): p.id for p in self._sites
    }
    int_positions = self._to_integer_position(self._positions)
    self._int_position_to_site = {
        HashableArray(pos): index for index, pos in enumerate(int_positions)
    }

    # Generate edges
    if custom_edges is not None:
        if max_neighbor_order is not None:
            raise ValueError(
                "custom_edges and max_neighbor_order cannot be specified at the same time"
            )
        colored_edges = get_custom_edges(
            self._basis_vectors,
            self._extent,
            self._site_offsets,
            self._pbc,
            distance_atol,
            custom_edges,
        )
    else:
        if max_neighbor_order is None:
            max_neighbor_order = 1
        colored_edges = get_nn_edges(
            self._basis_vectors,
            self._extent,
            self._site_offsets,
            self._pbc,
            distance_atol,
            max_neighbor_order,
        )

    super().__init__(colored_edges, len(self._sites))
def __hash__(self):
    # Hash the same normalized affine matrix that equality is based on.
    key = HashableArray(comparable(self._affine))
    return hash(key)
def GCNN(
    symmetries=None,
    product_table=None,
    irreps=None,
    point_group=None,
    mode="auto",
    shape=None,
    layers=None,
    features=None,
    characters=None,
    parity=None,
    **kwargs,
):
    r"""Implements a Group Convolutional Neural Network (G-CNN) that outputs a wavefunction
    that is invariant over a specified symmetry group.

    The G-CNN is described in ` Cohen et. *al* <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in ` Roth et. *al* <https://arxiv.org/pdf/2104.05085.pdf>`_.

    The G-CNN alternates convolution operations with pointwise non-linearities. The first
    layer is a symmetrized linear transform given by DenseSymm, while the other layers are
    G-convolutions given by DenseEquivariant. The hidden layers of the G-CNN are related by
    the following equation:

    .. math ::

        {\bf f}^{i+1}_g = \Gamma( \sum_h W_{g^{-1} h} {\bf f}^i_h).

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            nk.graph.Graph, a nk.utils.PermutationGroup, or an array [n_symm, n_sites]
            specifying the permutations corresponding to symmetry transformations
            of the lattice.
        product_table: Product table describing the algebra of the symmetry group.
            Only needs to be specified if mode='fft' and symmetries is specified as an array.
        irreps: List of 3D tensors that project onto irreducible representations of the
            symmetry group. Only needs to be specified if mode='irreps' and symmetries is
            specified as an array.
        point_group: The point group, from which the space group is built.
            If symmetries is a graph the default point group is overwritten.
        mode: string "fft, irreps, auto" specifying whether to use a fast
            fourier transform over the translation group or a fourier transform
            using the irreducible representations.
        shape: A tuple specifying the dimensions of the translation group.
        layers: Number of layers (not including sum layer over output).
        features: Number of features in each layer starting from the input. If a single
            number is given, all layers will have the same number of features.
        characters: Array specifying the characters of the desired symmetry representation
        parity: Optional argument with value +/-1 that specifies the eigenvalue
            with respect to parity (only use on two level systems).
        dtype: The dtype of the weights.
        activation: The nonlinear activation function between hidden layers. Defaults to
            :ref:`nk.nn.activation.reim_selu`.
        output_activation: The nonlinear activation before the output.
        equal_amplitudes: If True forces all basis states to have equal amplitude
            by setting Re[psi] = 0.
        use_bias: If True uses a bias in all layers.
        precision: Numerical precision of the computation see `jax.lax.Precision` for details.
        kernel_init: Initializer for the kernels of all layers. Defaults to
            `lecun_normal(in_axis=1, out_axis=0)` which guarantees the correct variance of the
            output.
        bias_init: Initializer for the biases of all layers.

    Raises:
        ValueError: if the symmetry specification is invalid or `mode` is unsupported.
        TypeError: if `mode="fft"` is used without a translation-group shape.
    """
    if isinstance(symmetries, Lattice) and (
        point_group is not None or symmetries._point_group is not None
    ):
        # With graph try to find point group, otherwise default to automorphisms
        shape = tuple(symmetries.extent)
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        if mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group"
            )
    elif isinstance(symmetries, PermutationGroup):
        # If we get a group and default to irrep projection
        if mode == "auto":
            mode = "irreps"
        sg = symmetries
    else:
        # Symmetries given as an array: the caller must supply irreps or a product table.
        if irreps is not None and (mode == "irreps" or mode == "auto"):
            mode = "irreps"
            sg = symmetries
            irreps = tuple(HashableArray(irrep) for irrep in irreps)
        elif product_table is not None and (mode == "fft" or mode == "auto"):
            mode = "fft"
            sg = symmetries
            product_table = HashableArray(product_table)
        else:
            raise ValueError(
                "Specification of symmetries is wrong or incompatible with selected mode"
            )

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument."
            )
        else:
            shape = tuple(shape)

    if isinstance(features, int):
        features = (features,) * layers

    if characters is None:
        characters = HashableArray(np.ones(len(np.asarray(sg))))
    else:
        characters = HashableArray(characters)

    if mode == "fft":
        sym = HashableArray(np.asarray(sg))
        if product_table is None:
            product_table = HashableArray(sg.product_table)
        if parity:
            return GCNN_Parity_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                parity=parity,
                **kwargs,
            )
        else:
            return GCNN_FFT(
                symmetries=sym,
                product_table=product_table,
                layers=layers,
                features=features,
                characters=characters,
                shape=shape,
                **kwargs,
            )
    elif mode in ["irreps", "auto"]:
        sym = HashableArray(np.asarray(sg))

        if irreps is None:
            irreps = tuple(HashableArray(irrep) for irrep in sg.irrep_matrices())

        if parity:
            return GCNN_Parity_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                parity=parity,
                **kwargs,
            )
        else:
            return GCNN_Irrep(
                symmetries=sym,
                irreps=irreps,
                layers=layers,
                features=features,
                characters=characters,
                **kwargs,
            )
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'irreps' or 'auto'."
        )
class DenseEquivariant(Module):
    r"""Implements a G-convolution that acts on a feature map of symmetry poses of shape
    [batch_size,n_symm*in_features] and returns a feature map of poses of shape
    [batch_size,n_symm*out_features]

    G-convolutions are described in ` Cohen et. {\it al} <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in ` Roth et. {\it al} <https://arxiv.org/pdf/2104.05085.pdf>`_

    The G-convolution generalizes the convolution to non-commuting groups:

    .. math ::

        C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h

    Symmetry poses that are linked by the same symmetry element are connected
    by the same filter. The output symmetry group is an involution over the
    input symmetry group, i.e. the symmetry group is inverted by G-convolution

    .. math ::

        {\bf C}*(g) = C(g^{-1})
    """

    # The docstring is a raw string: in a normal string, `{\bf W}` would contain
    # a backspace character and `\i`, `\s`, `\c` would be invalid escapes.

    symmetry_info: Union[HashableArray, SymmGroup]
    """Flattened product table generated by SymmGroup.product_table().ravel()
    that specifies the product of the group with its involution, or the
    SymmGroup object itself"""
    in_features: int
    """The number of symmetry-reduced input features. The full input size
    is n_symm*in_features."""
    out_features: int
    """The number of symmetry-reduced output features. The full output size
    is n_symm*out_features."""
    use_bias: bool = True
    """Whether to add a bias to the output (default: True)."""
    dtype: Any = jnp.float64
    """The dtype of the weights."""
    precision: Any = None
    """numerical precision of the computation see `jax.lax.Precision` for details."""

    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
    """Initializer for the Dense layer matrix."""
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
    """Initializer for the bias."""

    def setup(self):
        # A SymmGroup may be passed directly; replace it by its flattened product table.
        if isinstance(self.symmetry_info, SymmGroup):
            self.symmetry_info = HashableArray(
                self.symmetry_info.product_table().ravel()
            )

        if not np.asarray(self.symmetry_info).ndim == 1:
            raise ValueError("Product table should be flattened")

        # The flattened table holds n_symm * n_symm entries.
        self.n_symm = int(np.sqrt(np.asarray(self.symmetry_info).shape[0]))

    def full_kernel(self, kernel):
        """
        Converts the symmetry-reduced kernel to the full Dense kernel.

        Rows of ``kernel`` (axis 0) are gathered with the entries of the
        flattened product table, giving shape
        (n_symm, n_symm, in_features, out_features). This is rearranged into
        the full kernel of shape (in_features * n_symm, out_features * n_symm).
        """
        result = jnp.take(kernel, self.symmetry_info, 0)
        result = result.reshape(
            self.n_symm, self.n_symm, self.in_features, self.out_features
        )
        result = result.transpose(2, 0, 3, 1).reshape(
            self.n_symm * self.in_features, -1
        )

        return result

    def full_bias(self, bias):
        """
        Convert symmetry-reduced bias of shape (features,) to the full bias of
        shape (n_symm * features,).
        """
        return jnp.repeat(bias, self.n_symm)

    @compact
    def __call__(self, inputs: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last dimension.

        Args:
            inputs: The nd-array to be transformed; its last axis is contracted
                with the full kernel, so it must have size n_symm * in_features.

        Returns:
            The transformed input, with last axis of size n_symm * out_features.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)
        inputs = jnp.asarray(inputs, dtype)

        # NOTE(review): axis 0 of this parameter has length inputs.shape[-1]
        # (= n_symm * in_features for valid inputs). However, `full_kernel` only
        # gathers rows indexed by the product table, i.e. rows < n_symm, so the
        # remaining rows are unused. If `kernel_init` is a variance-scaling
        # initializer, its fan-in is also inflated. The shape is kept unchanged
        # for parameter/checkpoint compatibility; confirm whether
        # (n_symm, in_features, out_features) was intended.
        kernel = self.param(
            "kernel",
            self.kernel_init,
            (inputs.shape[-1], self.in_features, self.out_features),
            self.dtype,
        )

        kernel = self.full_kernel(kernel)
        kernel = jnp.asarray(kernel, dtype)

        y = lax.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.out_features,), self.dtype)
            bias = jnp.asarray(self.full_bias(bias), dtype)
            y += bias

        return y
def __init__(
    self,
    basis_vectors: _np.ndarray,
    extent: _np.ndarray,
    *,
    pbc: Union[bool, Sequence[bool]] = True,
    site_offsets: Optional[_np.ndarray] = None,
    atoms_coord: Optional[_np.ndarray] = None,
    distance_atol: float = 1e-5,
    point_group: Optional[PointGroup] = None,
):
    """
    Constructs a new ``Lattice`` given its basis vectors, its extent and the
    features of the unit cell.

    Args:
        basis_vectors: The basis vectors of the lattice. Should be an array
            of shape `(ndim, ndim)` where each `row` is a basis vector.
        extent: The number of copies of the unit cell; needs to be an array
            of length `ndim`.
        pbc: If ``True`` then the constructed lattice
            will have periodic boundary conditions, otherwise
            open boundary conditions are imposed. Can also be a boolean sequence
            of length `ndim`, indicating either open or closed boundary conditions
            separately for each direction.
        site_offsets: The position offsets of sites in the unit cell (one site
            at the origin by default).
        atoms_coord: Legacy way of specifying the site positions; it is
            forwarded to `_clean_site_offsets` together with `site_offsets`.
        distance_atol: Distance below which spatial points are considered equal
            for the purpose of identifying nearest neighbors.
        point_group: Default `PointGroup` object for constructing space groups.

    Examples:
        Constructs a Kagome lattice with 3 × 3 unit cells:

        >>> import numpy as np
        >>> from netket.graph import Lattice
        >>> # Hexagonal lattice basis
        >>> sqrt3 = np.sqrt(3.0)
        >>> basis = np.array([
        ...     [1.0, 0.0],
        ...     [0.5, sqrt3 / 2.0],
        ... ])
        >>> # Kagome unit cell
        >>> cell = np.array([
        ...     basis[0] / 2.0,
        ...     basis[1] / 2.0,
        ...     (basis[0]+basis[1])/2.0
        ... ])
        >>> g = Lattice(basis_vectors=basis, site_offsets=cell, extent=[3, 3])
        >>> print(g.n_nodes)
        27
        >>> print(g.basis_coords[:6])
        [[0 0 0]
         [0 0 1]
         [0 0 2]
         [0 1 0]
         [0 1 1]
         [0 1 2]]
        >>> print(g.positions[:6])
        [[0.5        0.        ]
         [0.25       0.4330127 ]
         [0.75       0.4330127 ]
         [1.         0.8660254 ]
         [0.75       1.29903811]
         [1.25       1.29903811]]
    """

    self._basis_vectors = self._clean_basis(basis_vectors)
    self._ndim = self._basis_vectors.shape[1]

    self._site_offsets, site_pos_fractional = self._clean_site_offsets(
        site_offsets,
        atoms_coord,
        self._basis_vectors,
    )
    self._pbc = self._clean_pbc(pbc, self._ndim)

    self._extent = _np.asarray(extent, dtype=int)
    self._point_group = point_group

    sites, inside, self._basis_coord_to_site = create_sites(
        self._basis_vectors, self._extent, site_pos_fractional, self._pbc
    )

    edges = get_true_edges(
        self._basis_vectors,
        sites,
        inside,
        self._basis_coord_to_site,
        self._extent,
        distance_atol,
    )

    # Keep every site flagged as inside, even when it has no edges (e.g. a
    # single unit cell with open boundaries); previously only edge endpoints
    # survived, silently dropping isolated sites. Edge endpoints are added too
    # so that every edge can always be relabelled below.
    node_set = {index for index, is_inside in enumerate(inside) if is_inside}
    node_set.update(node for edge in edges for node in edge)
    old_nodes = sorted(node_set)
    new_nodes = {old_node: new_node for new_node, old_node in enumerate(old_nodes)}

    graph = igraph.Graph()
    graph.add_vertices(len(old_nodes))
    graph.add_edges([(new_nodes[edge[0]], new_nodes[edge[1]]) for edge in edges])
    # Remove duplicate edges and self-loops.
    graph.simplify()

    # Relabel the retained sites with consecutive ids.
    self._sites = []
    for i, site in enumerate(sites[old_node] for old_node in old_nodes):
        site.id = i
        self._sites.append(site)

    self._basis_coord_to_site = {
        HashableArray(p.basis_coord): p.id for p in self._sites
    }
    self._positions = _np.array([p.position for p in self._sites])
    self._basis_coords = _np.array([p.basis_coord for p in self._sites])

    self._lattice_dims = _np.expand_dims(self._extent, 1) * self.basis_vectors
    self._inv_dims = _np.linalg.inv(self._lattice_dims)
    int_positions = self._to_integer_position(self._positions)
    self._int_position_to_site = {
        HashableArray(pos): index for index, pos in enumerate(int_positions)
    }

    super().__init__(list(graph.get_edgelist()), graph.vcount())
def DenseSymm(symmetries, point_group=None, mode="auto", shape=None, **kwargs):
    r"""
    Implements a projection onto a symmetry group. The output will be
    equivariant with respect to the symmetry operations in the group and can
    be averaged to produce an invariant model.

    This layer maps an input of shape `(..., in_features, n_sites)` to an
    output of shape `(..., features, num_symm)`.

    Note:
        The output shape has changed to separate the feature and symmetry
        dimensions. The previous shape was [num_samples, num_symm*features] and
        the new shape is [num_samples, features, num_symm].

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            :ref:`netket.graph.Graph`, a :ref:`netket.utils.group.PermutationGroup`,
            or an array of shape :code:`(n_symm, n_sites)` specifying the
            permutations corresponding to symmetry transformations of the
            lattice. A :ref:`netket.utils.HashableArray` may also be passed.
        point_group: The point group from which the space group is built. Only
            used when `symmetries` is a :ref:`netket.graph.Lattice`, in which
            case it overrides the lattice's default point group.
        mode: One of "fft", "matrix" or "auto", selecting a fast Fourier
            transform, matrix multiplication, or a default based on the
            symmetry specification ("fft" for a lattice with a point group,
            "matrix" otherwise).
        shape: A tuple specifying the dimensions of the translation group.
            Required for `mode="fft"` unless `symmetries` is a lattice with a
            point group, in which case the lattice extent is used instead.
        features: The number of output features. The full output shape
            is [n_batch,features,n_symm].
        use_bias: A bool specifying whether to add a bias to the output (default: True).
        mask: An optional array of shape [n_sites] consisting of ones and zeros
            that can be used to give the kernel a particular shape.
        dtype: The datatype of the weights. Defaults to a 64bit float.
        precision: Optional argument specifying numerical precision of the computation.
            See `jax.lax.Precision` for details.
        kernel_init: Optional kernel initialization function. Defaults to variance scaling.
        bias_init: Optional bias initialization function. Defaults to zero initialization.

    Raises:
        ValueError: If `mode` is not a valid mode, or if `mode="fft"` is
            requested for a graph that is not a lattice with a point group.
        TypeError: If `mode="fft"` is requested but no translation-group shape
            is available.
    """
    # Validate the mode before doing any (possibly expensive) symmetry work.
    if mode not in ("fft", "matrix", "auto"):
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'matrix', or 'auto'."
        )

    if isinstance(symmetries, Lattice) and (
        point_group is not None or symmetries._point_group is not None
    ):
        # The lattice extent gives the shape of the translation group that the
        # FFT implementation needs.
        shape = tuple(symmetries.extent)
        sym = HashableArray(np.asarray(symmetries.space_group(point_group)))
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        if mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group"
            )
        sym = HashableArray(np.asarray(symmetries.automorphisms()))
    elif isinstance(symmetries, HashableArray):
        sym = symmetries
    else:
        sym = HashableArray(np.asarray(symmetries))

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Lattice` object "
                "with a point group to the symmetries keyword argument."
            )
        return DenseSymmFFT(sym, shape=shape, **kwargs)

    # mode is "matrix" or "auto" (the latter only when no point group applies).
    return DenseSymmMatrix(sym, **kwargs)
def __init__(self, permutation: Array):
    """
    Stores ``permutation`` on the instance as ``self.permutation``.

    Args:
        permutation: The permutation; it is converted with ``np.asarray`` and
            wrapped in a ``HashableArray`` before being stored.
    """
    as_array = np.asarray(permutation)
    self.permutation = HashableArray(as_array)