Code example #1
0
File: rbm.py  Project: netket/netket
    def __call__(self, x_in):
        """Evaluate the symmetric RBM: DenseSymm -> activation -> sum over
        hidden units, plus an optional visible bias contribution.
        """
        samples = x_in
        # DenseSymm in "matrix" mode expects an extra channel axis;
        # insert one when the input does not already carry it.
        if samples.ndim < 3:
            samples = jnp.expand_dims(samples, -2)

        dense = nknn.DenseSymm(
            name="Dense",
            mode="matrix",
            symmetries=self.symmetries,
            features=self.features,
            dtype=self.dtype,
            use_bias=self.use_hidden_bias,
            kernel_init=self.kernel_init,
            bias_init=self.hidden_bias_init,
            precision=self.precision,
        )
        hidden = self.activation(dense(samples))

        # Flatten the (features, n_symm) axes and pool over all hidden units.
        hidden = hidden.reshape(-1, self.features * self.n_symm)
        pooled = jnp.sum(hidden, axis=-1)

        if not self.use_visible_bias:
            return pooled

        v_bias = self.param(
            "visible_bias", self.visible_bias_init, (1, ), self.dtype
        )
        # The single visible-bias scalar multiplies the summed raw input.
        return pooled + v_bias[0] * jnp.sum(x_in, axis=-1)
Code example #2
0
    def setup(self):
        """Validate the symmetry configuration and build the symmetric
        input layer plus the stack of equivariant hidden layers.
        """
        self.n_symm = np.asarray(self.symmetries).shape[0]

        # A flattened product table must be given explicitly unless the
        # symmetries object is a SymmGroup, which can compute its own.
        if self.flattened_product_table is None:
            if not isinstance(self.symmetries, SymmGroup):
                raise AttributeError(
                    "product table must be specified if symmetries are given as an array"
                )
            flat_pt = HashableArray(self.symmetries.product_table().ravel())
        else:
            flat_pt = self.flattened_product_table

        if np.asarray(flat_pt).shape[0] != np.square(self.n_symm):
            raise ValueError("Flattened product table must have shape [n_symm*n_symm]")

        # Per-layer feature counts: broadcast a scalar to every layer,
        # otherwise the given sequence must match the number of layers.
        if isinstance(self.features, int):
            feature_dim = [self.features] * self.layers
        else:
            if len(self.features) != self.layers:
                raise ValueError(
                    """Length of vector specifying feature dimensions must be the same as the number of layers"""
                )
            feature_dim = tuple(self.features)

        self.dense_symm = nknn.DenseSymm(
            symmetries=self.symmetries,
            features=feature_dim[0],
            dtype=self.dtype,
            use_bias=self.use_bias,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            precision=self.precision,
        )

        self.equivariant_layers = [
            nknn.DenseEquivariant(
                symmetry_info=flat_pt,
                in_features=feature_dim[i],
                out_features=feature_dim[i + 1],
                use_bias=self.use_bias,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for i in range(self.layers - 1)
        ]
Code example #3
0
File: test_deprecated.py  Project: yannra/netket
def test_deprecated_dtype_layers():
    """Both symmetric layer types must warn about the deprecated ``dtype``
    argument and alias it to ``param_dtype``.
    """
    g = nk.graph.Square(3)

    for layer_cls in (nknn.DenseSymm, nknn.DenseEquivariant):
        # Passing the legacy ``dtype`` keyword should emit a FutureWarning.
        with pytest.warns(FutureWarning):
            module = layer_cls(g, features=2, dtype=complex)

        # Reading the deprecated attribute warns too, and it must mirror
        # the new ``param_dtype`` field.
        with pytest.warns(FutureWarning):
            assert module.dtype == module.param_dtype
Code example #4
0
    def __call__(self, x_in):
        """Evaluate the symmetric RBM: DenseSymm -> activation -> sum over
        the last axis, plus an optional visible bias contribution.
        """
        dense = nknn.DenseSymm(
            name="Dense",
            symmetries=self.symmetries,
            features=self.features,
            dtype=self.dtype,
            use_bias=self.use_hidden_bias,
            kernel_init=self.kernel_init,
            bias_init=self.hidden_bias_init,
            precision=self.precision,
        )
        pooled = jnp.sum(self.activation(dense(x_in)), axis=-1)

        if not self.use_visible_bias:
            return pooled

        v_bias = self.param(
            "visible_bias", self.visible_bias_init, (1, ), self.dtype
        )
        # The single visible-bias scalar multiplies the summed raw input.
        return pooled + v_bias[0] * jnp.sum(x_in, axis=-1)