Example #1
0
def _decode(channels: int) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """Build a three-stage decoder as a single `hk.Sequential`.

    Each stage applies, in order: `hk.GroupNorm(8)` (8 groups), the
    `_upsample` helper, a `_conv(channels)` layer, and `nn.leaky_relu`.

    Args:
        channels: forwarded to `_conv` for every stage (presumably the
            conv's output channel count -- TODO confirm against `_conv`).

    Returns:
        A callable mapping an input array to the decoded array.
    """
    stages = []
    for _ in range(3):
        # Construct the GroupNorm before the conv within each stage, exactly
        # as the unrolled version did, so Haiku assigns the same module names.
        stages.extend([hk.GroupNorm(8), _upsample, _conv(channels), nn.leaky_relu])
    return hk.Sequential(stages)
Example #2
0
     shape=(BATCH_SIZE, 2, 2, 3)),
 ModuleDescriptor(
     name="Bias",
     create=lambda: hk.Bias(),
     shape=(BATCH_SIZE, 3, 3, 3)),
 ModuleDescriptor(
     name="Flatten",
     create=lambda: hk.Flatten(),
     shape=(BATCH_SIZE, 3, 3, 3)),
 ModuleDescriptor(
     name="InstanceNorm",
     create=lambda: hk.InstanceNorm(True, True),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="GroupNorm",
     create=lambda: hk.GroupNorm(5),
     shape=(BATCH_SIZE, 4, 4, 10)),
 ModuleDescriptor(
     name="LayerNorm",
     create=lambda: hk.LayerNorm(1, True, True),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="MultiHeadAttention",
     create=lambda: MultiInput(  # pylint: disable=g-long-lambda
         hk.MultiHeadAttention(num_heads=8, key_size=64, w_init_scale=1.0),
         num_inputs=3),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="RMSNorm",
     create=lambda: hk.RMSNorm(1),
     shape=(BATCH_SIZE, 3, 2)),