Example #1
0
    def __call__(self, σr, σc):
        """Sum of activated hidden units built from (anti)symmetric input combinations.

        Two bias-free dense layers are applied:

        - ``"Symm"`` to ``0.5 * (σr + σc)``, and
        - ``"ASymm"`` to ``0.5 * (σr - σc)``.

        The second result is multiplied by ``1j`` and added to the first. When
        ``self.use_bias`` is set, a learnable ``"bias"`` parameter is added to
        every hidden unit before ``self.activation``. The bias uses the real
        counterpart of ``self.dtype``.

        Args:
            σr: first input batch. Its last axis sets the hidden width
                ``int(self.alpha * σr.shape[-1])``.
            σc: second input batch, combined elementwise with ``σr``. It is
                assumed broadcast-compatible with ``σr`` (TODO confirm).

        Returns:
            ``self.activation`` applied to the hidden pre-activations, summed
            over the last (hidden) axis.
        """
        n_hidden = int(self.alpha * σr.shape[-1])

        def _bias_free_dense(layer_name):
            # Both layers share every hyperparameter except their name.
            return nknn.Dense(
                name=layer_name,
                features=n_hidden,
                dtype=self.dtype,
                use_bias=False,
                kernel_init=self.kernel_init,
                precision=self.precision,
            )

        symm_layer = _bias_free_dense("Symm")
        asymm_layer = _bias_free_dense("ASymm")

        # "Symm" is evaluated before "ASymm", as in the original expression.
        pre_activation = symm_layer(0.5 * (σr + σc)) + 1j * asymm_layer(
            0.5 * (σr - σc)
        )

        if self.use_bias:
            # The pre-activation is complex because of the 1j term, but the
            # bias parameter is stored with a real dtype.
            hidden_bias = self.param(
                "bias",
                self.bias_init,
                (n_hidden,),
                nkjax.dtype_real(self.dtype),
            )
            pre_activation = pre_activation + hidden_bias

        return self.activation(pre_activation).sum(axis=-1)
Example #2
0
def test_deprecated_layers():
    """Check the deprecated ``dtype=`` spelling of ``nknn.Dense``.

    The test verifies three things:

    - ``nknn.Dense(..., dtype=complex)`` emits a ``FutureWarning``.
    - ``nknn.Dense(..., param_dtype=complex)`` raises ``KeyError``.
    - The deprecated layer compares equal to
      ``nn.Dense(..., param_dtype=complex)``.
    """
    # The old keyword must still construct a layer, but it must warn.
    with pytest.warns(FutureWarning):
        deprecated_layer = nknn.Dense(features=3, dtype=complex)

    # The new keyword is rejected by the nknn wrapper.
    with pytest.raises(KeyError):
        nknn.Dense(features=3, param_dtype=complex)

    reference_layer = nn.Dense(features=3, param_dtype=complex)

    assert deprecated_layer == reference_layer
Example #3
0
    def __call__(self, σr, σc, symmetric=True):
        """Shared-weight RBM-style term evaluated on two input batches.

        A single dense layer named ``"Dense"`` is applied to both ``σr`` and
        ``σc``. Each result goes through ``self.activation`` and is summed over
        the hidden axis. The two sums are added when ``symmetric`` is true and
        subtracted otherwise.

        When ``self.use_visible_bias`` is set, the term
        ``dot(σr ± σc, visible_bias)`` is added, using the same sign
        convention.

        Args:
            σr: first input batch. Its last axis is the visible size.
            σc: second input batch, presumably with the same visible size as
                ``σr`` (TODO confirm).
            symmetric: if True, combine the two contributions with ``+``;
                otherwise combine them with ``-``.

        Returns:
            Half of the combined contribution.
        """
        shared_dense = nknn.Dense(
            name="Dense",
            features=int(self.alpha * σr.shape[-1]),
            dtype=self.dtype,
            use_bias=self.use_hidden_bias,
            kernel_init=self.kernel_init,
            bias_init=self.hidden_bias_init,
            precision=self.precision,
        )

        def _activated_sum(configs):
            # The same layer instance is reused, so σr and σc share weights.
            return self.activation(shared_dense(configs)).sum(axis=-1)

        # σr is evaluated before σc, as in the original.
        term_r = _activated_sum(σr)
        term_c = _activated_sum(σc)
        total = term_r + term_c if symmetric else term_r - term_c

        if self.use_visible_bias:
            visible_bias = self.param(
                "visible_bias", self.visible_bias_init, (σr.shape[-1],), self.dtype
            )
            combined_configs = σr + σc if symmetric else σr - σc
            total = total + jnp.dot(combined_configs, visible_bias)

        return 0.5 * total
Example #4
0
    def __call__(self, x):
        """Complex output built from two independent dense/activation/sum branches.

        Two unnamed ``nknn.Dense`` layers with identical hyperparameters are
        created, in order. The first yields the real part and the second the
        imaginary part. Each branch applies ``self.activation`` and sums over
        the last (hidden) axis.

        Args:
            x: input batch. Its last axis sets the hidden width
                ``int(self.alpha * x.shape[-1])``.

        Returns:
            ``real_branch + 1j * imag_branch``.
        """
        n_hidden = int(self.alpha * x.shape[-1])

        def _branch(inputs):
            # Each call creates a new unnamed layer, so the two branches get
            # separate parameters.
            layer = nknn.Dense(
                features=n_hidden,
                dtype=self.dtype,
                use_bias=self.use_bias,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            return jnp.sum(self.activation(layer(inputs)), axis=-1)

        # Creation order matters for auto-naming: real part first, then imaginary.
        real_part = _branch(x)
        imag_part = _branch(x)
        return real_part + 1j * imag_part
Example #5
0
    def __call__(self, input):
        """RBM-style scalar output for a batch of input configurations.

        A dense layer named ``"Dense"`` is applied. Its output goes through
        ``self.activation`` and is summed over the hidden axis. When
        ``self.use_visible_bias`` is set, ``dot(input, visible_bias)`` is
        added.

        Args:
            input: input batch. Its last axis is the visible size. The
                parameter name shadows the builtin but is kept for callers.

        Returns:
            The hidden-unit sum, plus the visible-bias term when enabled.
        """
        n_visible = input.shape[-1]

        hidden = nknn.Dense(
            name="Dense",
            features=int(self.alpha * n_visible),
            dtype=self.dtype,
            use_bias=self.use_bias,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(input)
        hidden_sum = jnp.sum(self.activation(hidden), axis=-1)

        # Without a visible bias, the hidden contribution is the whole result.
        if not self.use_visible_bias:
            return hidden_sum

        visible_bias = self.param(
            "visible_bias", self.visible_bias_init, (n_visible,), self.dtype
        )
        return hidden_sum + jnp.dot(input, visible_bias)