def testGluValue(self):
    """Check that ``nn.glu`` maps the vector [1.0, 0.0] to [0.5]."""
    # NOTE(review): SOURCE also contains a second ``testGluValue``; if both
    # live in the same class, the later definition replaces this one and
    # this test never runs — confirm which variant is intended.
    gate_input = jnp.array([1.0, 0.0])
    # Presumably glu computes first_half * sigmoid(second_half), giving
    # 1.0 * sigmoid(0.0) == 0.5 here — TODO confirm against jax.nn.glu docs.
    gated = nn.glu(gate_input)
    expected = jnp.array([0.5])
    self.assertAllClose(gated, expected)
def testGluValue(self):
    """Check that ``nn.glu`` maps [1.0, 0.0] to [0.5], including dtype."""
    # NOTE(review): an earlier ``testGluValue`` also appears in SOURCE; if
    # both are in the same class, this later definition is the one kept.
    # ``np`` may be NumPy or an alias for jax.numpy — verify the file's imports.
    inputs = np.array([1.0, 0.0])
    result = nn.glu(inputs)
    want = np.array([0.5])
    # check_dtypes=True asks the assertion to compare dtypes as well as values.
    self.assertAllClose(result, want, check_dtypes=True)