def _decode(channels: int) -> Callable[[jnp.ndarray], jnp.ndarray]:
  """Builds the decoder stack.

  The decoder consists of three identical stages. Each stage applies, in
  order: ``hk.GroupNorm(8)`` -> ``_upsample`` -> ``_conv(channels)`` ->
  ``nn.leaky_relu``.

  This constructs Haiku modules, so it must be called from inside an
  ``hk.transform``-ed function.

  Args:
    channels: Channel count forwarded to every ``_conv`` call (presumably
      its output channels -- see ``_conv``).

  Returns:
    An ``hk.Sequential`` that applies the three stages in sequence.
  """
  stages = []
  for _ in range(3):
    # Modules are constructed stage by stage: GroupNorm, then conv, then the
    # next stage. This is the same interleaved order a flat list literal
    # would use.
    stages += [hk.GroupNorm(8), _upsample, _conv(channels), nn.leaky_relu]
  return hk.Sequential(stages)
shape=(BATCH_SIZE, 2, 2, 3)), ModuleDescriptor( name="Bias", create=lambda: hk.Bias(), shape=(BATCH_SIZE, 3, 3, 3)), ModuleDescriptor( name="Flatten", create=lambda: hk.Flatten(), shape=(BATCH_SIZE, 3, 3, 3)), ModuleDescriptor( name="InstanceNorm", create=lambda: hk.InstanceNorm(True, True), shape=(BATCH_SIZE, 3, 2)), ModuleDescriptor( name="GroupNorm", create=lambda: hk.GroupNorm(5), shape=(BATCH_SIZE, 4, 4, 10)), ModuleDescriptor( name="LayerNorm", create=lambda: hk.LayerNorm(1, True, True), shape=(BATCH_SIZE, 3, 2)), ModuleDescriptor( name="MultiHeadAttention", create=lambda: MultiInput( # pylint: disable=g-long-lambda hk.MultiHeadAttention(num_heads=8, key_size=64, w_init_scale=1.0), num_inputs=3), shape=(BATCH_SIZE, 3, 2)), ModuleDescriptor( name="RMSNorm", create=lambda: hk.RMSNorm(1), shape=(BATCH_SIZE, 3, 2)),