def __call__(self, x_in):
    """Apply the symmetrized RBM to ``x_in``.

    The input goes through a ``DenseSymm`` layer in ``"matrix"`` mode, then
    ``self.activation``. The result is flattened to
    ``(-1, self.features * self.n_symm)`` and summed over its last axis.
    When ``self.use_visible_bias`` is set, a visible-bias term is added:
    a single learned scalar times the sum of ``x_in`` over its last axis.

    NOTE(review): ``self.n_symm`` is read here but not set in this method.
    It presumably comes from the module's ``setup``; confirm.
    """
    # Rank < 3 inputs get a singleton axis inserted before the last one,
    # which is the layout handed to DenseSymm in "matrix" mode.
    if x_in.ndim < 3:
        layer_input = jnp.expand_dims(x_in, -2)
    else:
        layer_input = x_in

    symm_layer = nknn.DenseSymm(
        name="Dense",
        mode="matrix",
        symmetries=self.symmetries,
        features=self.features,
        dtype=self.dtype,
        use_bias=self.use_hidden_bias,
        kernel_init=self.kernel_init,
        bias_init=self.hidden_bias_init,
        precision=self.precision,
    )
    activated = self.activation(symm_layer(layer_input))
    flat = activated.reshape(-1, self.features * self.n_symm)
    hidden_total = jnp.sum(flat, axis=-1)

    if not self.use_visible_bias:
        return hidden_total

    v_bias = self.param("visible_bias", self.visible_bias_init, (1,), self.dtype)
    return hidden_total + v_bias[0] * jnp.sum(x_in, axis=-1)
def setup(self): self.n_symm = np.asarray(self.symmetries).shape[0] if self.flattened_product_table is None and not isinstance( self.symmetries, SymmGroup ): raise AttributeError( "product table must be specified if symmetries are given as an array" ) if self.flattened_product_table is None: flat_pt = HashableArray(self.symmetries.product_table().ravel()) else: flat_pt = self.flattened_product_table if not np.asarray(flat_pt).shape[0] == np.square(self.n_symm): raise ValueError("Flattened product table must have shape [n_symm*n_symm]") if isinstance(self.features, int): feature_dim = [self.features for layer in range(self.layers)] else: if not len(self.features) == self.layers: raise ValueError( """Length of vector specifying feature dimensions must be the same as the number of layers""" ) else: feature_dim = tuple(self.features) self.dense_symm = nknn.DenseSymm( symmetries=self.symmetries, features=feature_dim[0], dtype=self.dtype, use_bias=self.use_bias, kernel_init=self.kernel_init, bias_init=self.bias_init, precision=self.precision, ) self.equivariant_layers = [ nknn.DenseEquivariant( symmetry_info=flat_pt, in_features=feature_dim[layer], out_features=feature_dim[layer + 1], use_bias=self.use_bias, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init, bias_init=self.bias_init, ) for layer in range(self.layers - 1) ]
def test_deprecated_dtype_layers():
    """Check that the ``dtype`` argument of the symmetric layers is deprecated.

    For both ``DenseSymm`` and ``DenseEquivariant``:

    * constructing the layer with ``dtype=complex`` must emit a
      ``FutureWarning``;
    * comparing ``dtype`` with ``param_dtype`` must also emit a
      ``FutureWarning``, and the two values must be equal.
    """
    lattice = nk.graph.Square(3)
    for layer_cls in (nknn.DenseSymm, nknn.DenseEquivariant):
        with pytest.warns(FutureWarning):
            layer = layer_cls(lattice, features=2, dtype=complex)
        with pytest.warns(FutureWarning):
            assert layer.dtype == layer.param_dtype
def __call__(self, x_in):
    """Apply the symmetrized RBM to ``x_in``.

    The input passes through a ``DenseSymm`` layer, then ``self.activation``.
    The result is summed over its last axis. When ``self.use_visible_bias`` is
    set, a visible-bias term is added: a single learned scalar times the sum
    of ``x_in`` over its last axis.
    """
    symm_layer = nknn.DenseSymm(
        name="Dense",
        symmetries=self.symmetries,
        features=self.features,
        dtype=self.dtype,
        use_bias=self.use_hidden_bias,
        kernel_init=self.kernel_init,
        bias_init=self.hidden_bias_init,
        precision=self.precision,
    )
    hidden_total = jnp.sum(self.activation(symm_layer(x_in)), axis=-1)

    if not self.use_visible_bias:
        return hidden_total

    v_bias = self.param("visible_bias", self.visible_bias_init, (1,), self.dtype)
    visible_term = v_bias[0] * jnp.sum(x_in, axis=-1)
    return hidden_total + visible_term