def __call__(self, x: jnp.ndarray, *, is_training: bool) -> jnp.ndarray:
        """Upsample `x` with a stack of transposed 3-D convolutions.

        The time and space dimensions are upsampled independently: each
        dimension uses stride 2 for its first `_n_time_upsamples` /
        `_n_space_upsamples` layers and stride 1 afterwards. A ReLU is
        applied between layers but not after the final one.

        Args:
            x: Input array for the transposed-convolution stack.
            is_training: Unused here; kept for interface compatibility.

        Returns:
            The upsampled array with `_n_outputs` channels.
        """
        total_steps = max(self._n_time_upsamples, self._n_space_upsamples)

        for step in range(total_steps):
            # Stride 2 only while this dimension still needs upsampling;
            # afterwards the dimension is left at its current size.
            t_stride = 2 if step < self._n_time_upsamples else 1
            s_stride = 2 if step < self._n_space_upsamples else 1

            # Channel count halves each step, ending at _n_outputs.
            out_channels = self._n_outputs * 2 ** (total_steps - 1 - step)

            x = hk.Conv3DTranspose(
                output_channels=out_channels,
                stride=[t_stride, s_stride, s_stride],
                kernel_shape=[4, 4, 4],
                name=f'conv3d_transpose_{step}')(x)

            # No activation on the final layer's output.
            if step != total_steps - 1:
                x = jax.nn.relu(x)

        return x
# Example #2
        # NOTE(review): this line closes a ModuleDescriptor whose opening
        # lies before this chunk — presumably a 1-D conv taking a
        # (batch, spatial, channels) input; confirm against the full file.
        shape=(BATCH_SIZE, 2, 2)),
    # Each entry pairs a Haiku convolution constructor with a minimal
    # input shape used to exercise it (batch, spatial dims..., channels).
    ModuleDescriptor(
        name="Conv2D",
        create=lambda: hk.Conv2D(3, 3),
        shape=(BATCH_SIZE, 2, 2, 2)),
    ModuleDescriptor(
        name="Conv2DTranspose",
        create=lambda: hk.Conv2DTranspose(3, 3),
        shape=(BATCH_SIZE, 2, 2, 2)),
    ModuleDescriptor(
        name="Conv3D",
        create=lambda: hk.Conv3D(3, 3),
        shape=(BATCH_SIZE, 2, 2, 2, 2)),
    ModuleDescriptor(
        name="Conv3DTranspose",
        create=lambda: hk.Conv3DTranspose(3, 3),
        shape=(BATCH_SIZE, 2, 2, 2, 2)),
    ModuleDescriptor(
        name="DepthwiseConv2D",
        create=lambda: hk.DepthwiseConv2D(1, 3),
        shape=(BATCH_SIZE, 2, 2, 2)),
)


class DummyCore(hk.RNNCore):
  """Minimal RNN core: an identity step over a fixed all-ones state."""

  def initial_state(self, batch_size):
    """Return an all-ones state of shape [batch_size, 128, 1]."""
    state_shape = [batch_size, 128, 1]
    return jnp.ones(state_shape)

  def __call__(self, inputs, state):
    """Pass `inputs` through unchanged and leave `state` untouched."""
    return inputs, state