def __call__(self, σr, σc):
    """Evaluate the mixing term for a pair of row/column configurations.

    The half-sum ``0.5 * (σr + σc)`` is passed through the "Symm" dense layer
    and the half-difference ``0.5 * (σr - σc)`` through the "ASymm" dense
    layer; the latter's output is multiplied by ``1j`` before the two are
    added. When ``self.use_bias`` is set, a "bias" parameter (stored with the
    real counterpart of ``self.dtype``) is added before the activation. The
    activated hidden units are finally summed over the last axis.

    Args:
        σr: row configuration(s); the last axis sets the input width.
        σc: column configuration(s); combined elementwise with ``σr``.

    Returns:
        ``self.activation(pre_activation)`` summed over its last axis.
    """
    n_hidden = int(self.alpha * σr.shape[-1])

    symmetric_layer = nknn.Dense(
        name="Symm",
        features=n_hidden,
        dtype=self.dtype,
        use_bias=False,
        kernel_init=self.kernel_init,
        precision=self.precision,
    )
    antisymmetric_layer = nknn.Dense(
        name="ASymm",
        features=n_hidden,
        dtype=self.dtype,
        use_bias=False,
        kernel_init=self.kernel_init,
        precision=self.precision,
    )

    half_sum = 0.5 * (σr + σc)
    half_diff = 0.5 * (σr - σc)
    pre_activation = symmetric_layer(half_sum) + 1j * antisymmetric_layer(half_diff)

    if self.use_bias:
        # NOTE(review): the bias is created with a real dtype even though the
        # pre-activation is complex; presumably intentional so the bias only
        # shifts the real part — confirm against the model's design.
        hidden_bias = self.param(
            "bias",
            self.bias_init,
            (n_hidden,),
            nkjax.dtype_real(self.dtype),
        )
        pre_activation = pre_activation + hidden_bias

    return self.activation(pre_activation).sum(axis=-1)
def test_deprecated_layers():
    """Check the deprecated ``nknn.Dense`` entry point against ``nn.Dense``.

    - Building ``nknn.Dense`` with ``dtype=`` must emit a ``FutureWarning``.
    - Building ``nknn.Dense`` with ``param_dtype=`` must raise ``KeyError``.
    - The module built through the ``dtype=`` spelling must compare equal to
      ``nn.Dense`` built with ``param_dtype=``.
    """
    with pytest.warns(FutureWarning):
        legacy = nknn.Dense(features=3, dtype=complex)

    with pytest.raises(KeyError):
        nknn.Dense(features=3, param_dtype=complex)

    reference = nn.Dense(features=3, param_dtype=complex)

    assert legacy == reference
def __call__(self, σr, σc, symmetric=True):
    """Apply one shared dense layer to both the row and column configuration.

    The same "Dense" submodule (hence the same weights) is applied to ``σr``
    and ``σc``; each result is activated and summed over the last axis. The
    two sums are added when ``symmetric`` is true and subtracted otherwise.
    If ``self.use_visible_bias`` is set, ``jnp.dot(σr ± σc, visible_bias)``
    (same sign choice) is added as well. The total is scaled by ``0.5``.

    Args:
        σr: row configuration(s); the last axis sets the input width.
        σc: column configuration(s); combined elementwise with ``σr``.
        symmetric: ``True`` selects the ``+`` combination, ``False`` the ``-``.

    Returns:
        ``0.5 * (row_term ± col_term [+ visible-bias term])``.
    """
    shared_dense = nknn.Dense(
        name="Dense",
        features=int(self.alpha * σr.shape[-1]),
        dtype=self.dtype,
        use_bias=self.use_hidden_bias,
        kernel_init=self.kernel_init,
        bias_init=self.hidden_bias_init,
        precision=self.precision,
    )

    def _hidden_sum(config):
        # Activate the shared layer's output and reduce over the last axis.
        return self.activation(shared_dense(config)).sum(axis=-1)

    # Row first, then column: the first call is what initializes the layer.
    row_term = _hidden_sum(σr)
    col_term = _hidden_sum(σc)
    total = row_term + col_term if symmetric else row_term - col_term

    if self.use_visible_bias:
        visible_bias = self.param(
            "visible_bias", self.visible_bias_init, (σr.shape[-1],), self.dtype
        )
        combined = σr + σc if symmetric else σr - σc
        total = total + jnp.dot(combined, visible_bias)

    return 0.5 * total
def __call__(self, x):
    """Combine two independent dense branches into a complex output.

    Two separate (auto-named) dense layers with identical hyper-parameters
    are applied to ``x``; each output is activated and summed over the last
    axis. The first branch gives the real part and the second the imaginary
    part of the returned value.

    Args:
        x: input configuration(s); the last axis sets the input width.

    Returns:
        ``first_branch + 1j * second_branch``.
    """
    n_hidden = int(self.alpha * x.shape[-1])

    def _branch(inputs):
        # Each call constructs a fresh Dense submodule; construction order
        # determines the automatic submodule names.
        activations = nknn.Dense(
            features=n_hidden,
            dtype=self.dtype,
            use_bias=self.use_bias,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(inputs)
        return jnp.sum(self.activation(activations), axis=-1)

    real_part = _branch(x)
    imag_part = _branch(x)
    return real_part + 1j * imag_part
def __call__(self, input):
    """Evaluate a dense layer + activation + sum, with an optional visible bias.

    ``input`` is passed through the "Dense" submodule, activated, and summed
    over the last axis. If ``self.use_visible_bias`` is set, the term
    ``jnp.dot(input, visible_bias)`` is added to that sum.

    Args:
        input: input configuration(s); the last axis sets the input width.
            (The name shadows the builtin but is kept for caller compatibility.)

    Returns:
        The summed activations, plus the visible-bias term when enabled.
    """
    pre_activation = nknn.Dense(
        name="Dense",
        features=int(self.alpha * input.shape[-1]),
        dtype=self.dtype,
        use_bias=self.use_bias,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
    )(input)
    hidden_sum = jnp.sum(self.activation(pre_activation), axis=-1)

    if not self.use_visible_bias:
        return hidden_sum

    visible_bias = self.param(
        "visible_bias", self.visible_bias_init, (input.shape[-1],), self.dtype
    )
    return hidden_sum + jnp.dot(input, visible_bias)