shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv1DTranspose", create=lambda: hk.Conv1DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv2D", create=lambda: hk.Conv2D(3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv2DTranspose", create=lambda: hk.Conv2DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv3D", create=lambda: hk.Conv3D(3, 3), shape=(BATCH_SIZE, 2, 2, 2, 2)), ModuleDescriptor( name="Conv3DTranspose", create=lambda: hk.Conv3DTranspose(3, 3), shape=(BATCH_SIZE, 2, 2, 2, 2)), ModuleDescriptor( name="DepthwiseConv2D", create=lambda: hk.DepthwiseConv2D(1, 3), shape=(BATCH_SIZE, 2, 2, 2)), ) class DummyCore(hk.RNNCore): def initial_state(self, batch_size):
def __init__(self,
             output_channels: int,
             kernel_shape: Sequence[int] = (1, 1, 1),
             stride: Sequence[int] = (1, 1, 1),
             with_bias: bool = False,
             separable: bool = False,
             normalize_fn: Optional[types.NormalizeFn] = None,
             activation_fn: Optional[types.ActivationFn] = jax.nn.relu,
             self_gating_fn: Optional[types.GatingFn] = None,
             name='SUnit3D'):
    """Builds the convolution(s) backing an SUnit3D module.

    Depending on `separable`, the unit holds either a single 3D convolution
    or a pair of convolutions: one whose kernel/stride is 1 along axis 0
    (the "spatial" one) followed by one whose kernel/stride is 1 along
    axes 1 and 2 (the "temporal" one).

    Args:
      output_channels: Number of output channels of every convolution.
      kernel_shape: Kernel shape; must be a sequence of length 3.
      stride: Kernel stride; must be a sequence of length 3.
      with_bias: Whether the convolution(s) add a bias term.
      separable: Whether to factorise the convolution into a spatial and a
        temporal one. The factorisation is only applied when
        `kernel_shape[1] != 1`; otherwise a single convolution is built.
      normalize_fn: Optional normalization function, stored on the instance.
      activation_fn: Optional non-linearity, stored on the instance.
      self_gating_fn: Optional self-gating function, stored on the instance.
      name: The name of the module.

    Raises:
      ValueError: If `kernel_shape` or `stride` does not have length 3.
    """
    super().__init__(name=name)

    # Validate the arguments before creating any sub-module.
    kernel_rank = len(kernel_shape)
    if kernel_rank != 3:
        raise ValueError(
            f'Given `kernel_shape` must have length 3 but has length {kernel_rank}.'
        )
    stride_rank = len(stride)
    if stride_rank != 3:
        raise ValueError(
            f'Given `stride` must have length 3 but has length {stride_rank}.'
        )

    self._normalize_fn = normalize_fn
    self._activation_fn = activation_fn
    self._self_gating_fn = self_gating_fn

    def make_conv(conv_kernel, conv_stride):
        # All convolutions of the unit share everything but kernel and stride.
        return hk.Conv3D(
            output_channels=output_channels,
            kernel_shape=conv_kernel,
            stride=conv_stride,
            padding='SAME',
            with_bias=with_bias,
        )

    k_temporal, k_row, k_col = kernel_shape

    # Only the middle kernel dimension decides whether factorising is
    # worthwhile; in every other case a single full convolution is used.
    if not (separable and k_row != 1):
        self._convolutions = [make_conv(kernel_shape, stride)]
        return

    s_temporal, s_row, s_col = stride
    # Creation order matters: spatial convolution first, temporal second.
    self._convolutions = [
        make_conv([1, k_row, k_col], [1, s_row, s_col]),
        make_conv([k_temporal, 1, 1], [s_temporal, 1, 1]),
    ]