Esempio n. 1
0
def test_DenseEquivariant(symmetries, use_bias, lattice):
    """Check that `nk.nn.DenseEquivariant` commutes with a symmetry operation.

    A symmetry operation is picked at random and encoded as a 0/1 matrix
    acting on the group axis. Applying that matrix to the input and then
    running the layer must give the same result as running the layer first
    and applying the matrix to its output.
    """
    g, _hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    product_table = perms.product_table
    n_symm = np.asarray(perms).shape[0]

    layer = nk.nn.DenseEquivariant(
        symmetry_info=product_table.ravel(),
        in_features=1,
        out_features=1,
        use_bias=use_bias,
        bias_init=nk.nn.initializers.uniform(),
    )

    # Dummy input only used to initialise the parameters.
    init_key = nk.jax.PRNGKey()
    init_input = np.random.normal(0, 1, [1, n_symm])
    params = layer.init(init_key, init_input)

    # The product table of the group with inverted elements yields g h^-1
    # rather than g^-1 h, which is what is needed to select `chosen_op`.
    chosen_op = np.random.randint(n_symm)
    inverted_elems = [perms.elems[idx] for idx in perms.inverse]
    inverse_group = PermutationGroup(inverted_elems, degree=g.n_nodes)
    inv_pt = inverse_group.product_table
    # 0/1 matrix with ones where the inverse product table equals chosen_op.
    sym_op = np.where(inv_pt == chosen_op, 1.0, 0.0)

    inputs = random.normal(random.PRNGKey(0), [3, n_symm])
    transformed_inputs = dot(inputs, sym_op)

    outputs = layer.apply(params, inputs)
    outputs_of_transformed = layer.apply(params, transformed_inputs)

    # Equivariance: transforming the output must match the output obtained
    # from the transformed input.
    assert jnp.allclose(dot(outputs, sym_op), outputs_of_transformed)
Esempio n. 2
0
    def _translations_along_axis(self, axis: int) -> PermutationGroup:
        """
        Build the group of valid translations along `axis` as a
        `PermutationGroup` acting on the sites of `self.lattice`.

        If the lattice is not periodic along `axis`, the group contains only
        the identity. Otherwise it contains the identity followed by the
        successive powers of the translation by one basis vector, one element
        per unit of `self.lattice.extent[axis]`.
        """
        lattice = self.lattice

        # Without periodic boundaries only the trivial translation is valid.
        if not lattice._pbc[axis]:
            return PermutationGroup([Identity()], degree=lattice.n_nodes)

        # The permutation must hold preimages, hence the basis vector is
        # subtracted from the site positions rather than added.
        shifted_positions = lattice.positions - lattice.basis_vectors[axis]
        unit_perm = lattice.id_from_position(shifted_positions)

        unit_vector = np.zeros(lattice.ndim, dtype=int)
        unit_vector[axis] = 1
        unit_translation = Translation(unit_perm, unit_vector)

        # Accumulate successive powers of the unit translation.
        current = Identity()
        translations = [current]
        for _ in range(1, lattice.extent[axis]):
            current = current @ unit_translation
            translations.append(current)

        return PermutationGroup(translations, degree=lattice.n_nodes)
Esempio n. 3
0
def test_DenseEquivariant(symmetries, use_bias, lattice, mode, mask):
    """Check that `nk.nn.DenseEquivariant` commutes with a symmetry operation.

    The layer is built either from the permutation group itself
    (``mode == "irreps"``) or from its product table together with the
    lattice extent (any other mode). When `mask` is truthy, a random half of
    the group elements is kept in the mask; otherwise the mask is all ones.
    """
    rng = nk.jax.PRNGSeq(0)

    g, _hi, perms = _setup_symm(symmetries, N=3, lattice=lattice)

    product_table = perms.product_table
    n_symm = np.asarray(perms).shape[0]

    if mask:
        # Keep a random half of the entries, zero out the rest.
        site_mask = np.zeros(n_symm)
        kept = np.random.choice(n_symm, n_symm // 2, replace=False)
        site_mask[kept] = 1
    else:
        site_mask = np.ones([n_symm])

    # Arguments shared by both construction modes.
    shared_kwargs = dict(
        mode=mode,
        features=1,
        mask=site_mask,
        use_bias=use_bias,
        bias_init=uniform(),
    )
    if mode == "irreps":
        layer = nk.nn.DenseEquivariant(symmetries=perms, **shared_kwargs)
    else:
        layer = nk.nn.DenseEquivariant(
            symmetries=product_table,
            shape=tuple(g.extent),
            **shared_kwargs,
        )

    # Dummy input only used to initialise the parameters.
    init_input = jax.random.normal(rng.next(), (1, 1, n_symm))
    params = layer.init(rng.next(), init_input)

    # The product table of the group with inverted elements yields g h^-1
    # rather than g^-1 h, which is what is needed to select `chosen_op`.
    chosen_op = np.random.randint(n_symm)
    inverted_elems = [perms.elems[idx] for idx in perms.inverse]
    inverse_group = PermutationGroup(inverted_elems, degree=g.n_nodes)
    inv_pt = inverse_group.product_table
    # 0/1 matrix with ones where the inverse product table equals chosen_op.
    sym_op = np.where(inv_pt == chosen_op, 1.0, 0.0)

    inputs = random.normal(rng.next(), [3, 1, n_symm])
    transformed_inputs = jnp.matmul(inputs, sym_op)

    outputs = layer.apply(params, inputs)
    outputs_of_transformed = layer.apply(params, transformed_inputs)

    # Equivariance: transforming the output must match the output obtained
    # from the transformed input.
    assert jnp.allclose(jnp.matmul(outputs, sym_op), outputs_of_transformed)
Esempio n. 4
0
 def rotation_group(self) -> PermutationGroup:
     """Return the rotations (point group symmetries with determinant +1)
     as a `PermutationGroup` acting on the sites of `self.lattice`."""

     def to_site_permutation(op):
         # The identity needs no lookup of site indices.
         if isinstance(op, Identity):
             return Identity()
         # The permutation must hold preimages, so the sites are mapped
         # through `op.preimage` before being converted to site ids.
         preimage_positions = op.preimage(self.lattice.positions)
         site_ids = self.lattice.id_from_position(preimage_positions)
         return Permutation(site_ids, name=str(op))

     elements = [
         to_site_permutation(op) for op in self.point_group_.rotation_group()
     ]
     return PermutationGroup(elements, degree=self.lattice.n_nodes)
Esempio n. 5
0
    def _compute_automorphisms(self):
        """
        Compute the graph automorphisms of this graph.

        The isomorphisms of the underlying igraph object onto itself are
        computed with matching edge colors, deduplicated, and wrapped into a
        `PermutationGroup` built with `self.n_nodes`.
        """
        edge_colors = self.edge_colors
        isomorphisms = self._igraph.get_isomorphisms_vf2(
            edge_color1=edge_colors, edge_color2=edge_colors
        )

        # np.unique along axis 0 drops duplicate rows and returns them in
        # sorted order, so the identity mapping ends up first.
        unique_mappings = np.unique(isomorphisms, axis=0).tolist()
        elements = [Permutation(mapping) for mapping in unique_mappings]
        return PermutationGroup(elements, self.n_nodes)