def down_pass(us, uls):
    """Run three ResnetDownBlock stages over the skip stacks and emit logits.

    ``us`` and ``uls`` are stacks (they support ``pop``) of skip activations,
    presumably produced by a matching upward pass -- TODO confirm with caller.
    Each ``ResnetDownBlock`` receives the current stacks and hands back the
    updated ones. Both stacks must be empty once the last stage has run.

    Returns ``NIN(10 * nr_logistic_mix)`` applied to ``elu(ul)``.
    """
    # First stage: seeded with the top entry popped from each skip stack.
    u, ul, us, uls = ResnetDownBlock(nr_resnet)(us.pop(), uls.pop(), us, uls)

    # Two more stages built with nr_resnet + 1. Before each one, the running
    # streams pass through DoubleDown / DoubleDownRight. The block is
    # constructed before its arguments are computed, as in a direct call.
    for _stage in range(2):
        block = ResnetDownBlock(nr_resnet + 1)
        u, ul, us, uls = block(DoubleDown()(u), DoubleDownRight()(ul), us, uls)

    # Every skip activation must have been consumed by the stages above.
    assert len(us) == 0
    assert len(uls) == 0

    output_layer = NIN(10 * nr_logistic_mix)
    return output_layer(elu(ul))
def testEluValue(self):
    """ELU of a large positive input should return (approximately) the input.

    Presumably this guards against overflow in ELU's exponential branch for
    big arguments -- TODO confirm intent.
    """
    big = 1e4
    result = nn.elu(big)
    self.assertAllClose(result, big, check_dtypes=False)
def testEluMemory(self):
    """Tracing ELU over a huge input must not allocate that input.

    Regression test for https://github.com/google/jax/pull/1640. The input
    has 10**12 elements, which is terabytes of memory. The test therefore
    only succeeds if ``nn.elu`` can be staged out to a jaxpr without
    materializing the array ("don't oom").
    """
    # See https://github.com/google/jax/pull/1640
    with core.skipping_checks():  # With checks we materialize the array
        # ``jax.make_jaxpr`` returns a function, so it must be *called* for
        # the lambda to actually be traced. Without the trailing ``()``
        # nothing is traced and the test passes vacuously.
        jax.make_jaxpr(lambda: nn.elu(jnp.ones((10**12,))))()  # don't oom
def concat_elu(x, axis=-1):
    """Apply ``elu`` to ``x`` and ``-x`` stacked side by side along ``axis``.

    The result has twice the size of ``x`` along ``axis``. The first half
    holds ``elu(x)`` and the second half holds ``elu(-x)``.
    """
    both_signs = np.concatenate([x, -x], axis=axis)
    return elu(both_signs)
def nonlin(x):
    """Apply the model's pointwise nonlinearity (currently ``elu``) to ``x``."""
    # tanh was used here at one point; ELU is the active choice.
    return elu(x)