def test_DenseEquivariant(symmetries, use_bias, lattice):
    """Check that DenseEquivariant commutes with a randomly chosen group element.

    A random element of the symmetry group is encoded as a 0/1 matrix
    ``sym_op`` (built from the product table of the group of inverses).
    Transforming the input with ``sym_op`` and then applying the layer must
    give the same result as applying the layer and then transforming the output.
    """
    g, _, perms = _setup_symm(symmetries, N=3, lattice=lattice)
    product_table = perms.product_table
    n_symm = np.asarray(perms).shape[0]

    layer = nk.nn.DenseEquivariant(
        symmetry_info=product_table.ravel(),
        in_features=1,
        out_features=1,
        use_bias=use_bias,
        bias_init=nk.nn.initializers.uniform(),
    )
    params = layer.init(nk.jax.PRNGKey(), np.random.normal(0, 1, [1, n_symm]))

    # Using the product table of the inverse elements, entries equal to
    # chosen_op locate g h^-1 = chosen_op (rather than g^-1 h).
    chosen_op = np.random.randint(n_symm)
    inverse_group = PermutationGroup(
        [perms.elems[idx] for idx in perms.inverse], degree=g.n_nodes
    )
    sym_op = np.where(inverse_group.product_table == chosen_op, 1.0, 0.0)

    v = random.normal(random.PRNGKey(0), [3, n_symm])
    v_trans = dot(v, sym_op)

    out = layer.apply(params, v)
    out_trans = layer.apply(params, v_trans)

    # Equivariance: f(v @ P) == f(v) @ P.  `sym_op` is 2-D, so no transpose
    # is involved here.
    assert jnp.allclose(dot(out, sym_op), out_trans)
def _translations_along_axis(self, axis: int) -> PermutationGroup:
    """Return the valid translations along ``axis`` as a `PermutationGroup`.

    The group acts on the sites of ``self.lattice``. If the lattice is not
    periodic along ``axis``, the returned group contains only the identity.
    Otherwise it holds the identity followed by successive powers of the
    unit translation, ``self.lattice.extent[axis]`` elements in total.
    """
    lattice = self.lattice
    if not lattice._pbc[axis]:
        return PermutationGroup([Identity()], degree=lattice.n_nodes)

    # Permutations are stored as preimages: entry i is the id of the site at
    # position r_i - a_axis, i.e. the site that a shift by +a_axis moves onto i.
    preimage_ids = lattice.id_from_position(
        lattice.positions - lattice.basis_vectors[axis]
    )
    unit_shift = np.zeros(lattice.ndim, dtype=int)
    unit_shift[axis] = 1
    step = Translation(preimage_ids, unit_shift)

    elements = [Identity()]
    for _ in range(lattice.extent[axis] - 1):
        elements.append(elements[-1] @ step)
    return PermutationGroup(elements, degree=lattice.n_nodes)
def test_DenseEquivariant(symmetries, use_bias, lattice, mode, mask):
    """Check that DenseEquivariant commutes with a randomly chosen group element.

    The layer is built either from the permutation group itself
    (``mode == "irreps"``) or from its product table plus lattice shape. It is
    built with either a random half of the group elements masked in (when
    ``mask`` is truthy) or all of them. Transforming the input with the 0/1
    matrix ``sym_op`` and then applying the layer must equal applying the layer
    and then transforming the output.
    """
    rng = nk.jax.PRNGSeq(0)
    g, _, perms = _setup_symm(symmetries, N=3, lattice=lattice)
    product_table = perms.product_table
    n_symm = np.asarray(perms).shape[0]

    if not mask:
        mask_array = np.ones([n_symm])
    else:
        mask_array = np.zeros(n_symm)
        kept = np.random.choice(n_symm, n_symm // 2, replace=False)
        mask_array[kept] = 1

    common_kwargs = dict(
        mode=mode,
        features=1,
        mask=mask_array,
        use_bias=use_bias,
        bias_init=uniform(),
    )
    if mode == "irreps":
        layer = nk.nn.DenseEquivariant(symmetries=perms, **common_kwargs)
    else:
        layer = nk.nn.DenseEquivariant(
            symmetries=product_table, shape=tuple(g.extent), **common_kwargs
        )

    dummy_input = jax.random.normal(rng.next(), (1, 1, n_symm))
    params = layer.init(rng.next(), dummy_input)

    # Using the product table of the inverse elements, entries equal to
    # chosen_op locate g h^-1 = chosen_op (rather than g^-1 h).
    chosen_op = np.random.randint(n_symm)
    inverse_group = PermutationGroup(
        [perms.elems[idx] for idx in perms.inverse], degree=g.n_nodes
    )
    sym_op = np.where(inverse_group.product_table == chosen_op, 1.0, 0.0)

    v = random.normal(rng.next(), [3, 1, n_symm])
    v_trans = jnp.matmul(v, sym_op)

    out = layer.apply(params, v)
    out_trans = layer.apply(params, v_trans)

    # Equivariance: f(v @ P) == f(v) @ P.
    assert jnp.allclose(jnp.matmul(out, sym_op), out_trans)
def rotation_group(self) -> PermutationGroup:
    """Return the rotations of the point group as a `PermutationGroup`.

    Rotations are the point-group symmetries with determinant +1, taken from
    ``self.point_group_.rotation_group()``. They are expressed as permutations
    of the sites of ``self.lattice``.
    """
    lattice = self.lattice

    def _to_site_permutation(op):
        # The identity element is kept as a dedicated Identity().
        if isinstance(op, Identity):
            return Identity()
        # Permutations store preimages: entry i is the site that `op` maps onto site i.
        preimage_ids = lattice.id_from_position(op.preimage(lattice.positions))
        return Permutation(preimage_ids, name=str(op))

    site_perms = [_to_site_permutation(op) for op in self.point_group_.rotation_group()]
    return PermutationGroup(site_perms, degree=lattice.n_nodes)
def _compute_automorphisms(self):
    """Compute the automorphism group of this graph.

    The graph is matched against itself with igraph's VF2 isomorphism search,
    requiring edge colors to agree on both sides. The resulting permutations are
    deduplicated and returned as a `PermutationGroup` over ``self.n_nodes`` sites.
    """
    edge_colors = self.edge_colors
    isomorphisms = self._igraph.get_isomorphisms_vf2(
        edge_color1=edge_colors, edge_color2=edge_colors
    )
    # np.unique(axis=0) returns the rows in lexicographic order. The identity
    # permutation [0, 1, ..., n-1] is the lexicographically smallest, so it comes first.
    sorted_rows = np.unique(isomorphisms, axis=0).tolist()
    return PermutationGroup([Permutation(row) for row in sorted_rows], self.n_nodes)