示例#1
0
        shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(
        name="Conv1DTranspose",
        create=lambda: hk.Conv1DTranspose(3, 3),
        shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(
        name="Conv2D",
        create=lambda: hk.Conv2D(3, 3),
        shape=(BATCH_SIZE, 2, 2, 2)),
    ModuleDescriptor(
        name="Conv2DTranspose",
        create=lambda: hk.Conv2DTranspose(3, 3),
        shape=(BATCH_SIZE, 2, 2, 2)),
    ModuleDescriptor(
        name="Conv3D",
        create=lambda: hk.Conv3D(3, 3),
        shape=(BATCH_SIZE, 2, 2, 2, 2)),
    ModuleDescriptor(
        name="Conv3DTranspose",
        create=lambda: hk.Conv3DTranspose(3, 3),
        shape=(BATCH_SIZE, 2, 2, 2, 2)),
    ModuleDescriptor(
        name="DepthwiseConv2D",
        create=lambda: hk.DepthwiseConv2D(1, 3),
        shape=(BATCH_SIZE, 2, 2, 2)),
)


class DummyCore(hk.RNNCore):

  def initial_state(self, batch_size):
示例#2
0
    def __init__(self,
                 output_channels: int,
                 kernel_shape: Sequence[int] = (1, 1, 1),
                 stride: Sequence[int] = (1, 1, 1),
                 with_bias: bool = False,
                 separable: bool = False,
                 normalize_fn: Optional[types.NormalizeFn] = None,
                 activation_fn: Optional[types.ActivationFn] = jax.nn.relu,
                 self_gating_fn: Optional[types.GatingFn] = None,
                 name='SUnit3D'):
        """Builds the 3D convolution(s) that make up the SUnit3D module.

        When `separable` is set and the middle kernel dimension differs
        from 1, the unit is built from two `SAME`-padded convolutions: a
        spatial one with kernel `[1, k1, k2]` and stride `[1, s1, s2]`,
        followed by a temporal one with kernel `[k0, 1, 1]` and stride
        `[s0, 1, 1]`. In every other case a single `SAME`-padded
        convolution using the full `kernel_shape` and `stride` is created.

        Args:
          output_channels: Number of output channels of each convolution.
          kernel_shape: Kernel shape; must be a sequence of length 3.
          stride: Kernel stride; must be a sequence of length 3.
          with_bias: Whether the convolution(s) add a bias term.
          separable: Whether to request the spatial/temporal factorised
            form. Only the middle kernel dimension is inspected when
            deciding whether the factorisation is actually applied.
          normalize_fn: Normalization function; stored on the instance.
          activation_fn: Non-linearity; stored on the instance.
          self_gating_fn: Self-gating function; stored on the instance.
          name: The name of the module.

        Raises:
          ValueError: If `kernel_shape` or `stride` does not have length 3.
        """
        super().__init__(name=name)

        # Both geometry arguments must describe exactly three dimensions.
        kernel_rank = len(kernel_shape)
        if kernel_rank != 3:
            raise ValueError(
                'Given `kernel_shape` must have length 3 but has length '
                f'{kernel_rank}.')
        stride_rank = len(stride)
        if stride_rank != 3:
            raise ValueError(
                f'Given `stride` must have length 3 but has length {stride_rank}.'
            )

        self._normalize_fn = normalize_fn
        self._activation_fn = activation_fn
        self._self_gating_fn = self_gating_fn

        # Axis 0 is treated as the temporal axis, axes 1 and 2 as spatial.
        k_t, k_a, k_b = kernel_shape
        if not separable or k_a == 1:
            # Plain case: one convolution with the caller's geometry.
            conv_specs = [(kernel_shape, stride)]
        else:
            s_t, s_a, s_b = stride
            # Factorised case. The order matters: spatial first, then
            # temporal, so the sub-modules are created in that sequence.
            conv_specs = [
                ([1, k_a, k_b], [1, s_a, s_b]),
                ([k_t, 1, 1], [s_t, 1, 1]),
            ]

        self._convolutions = [
            hk.Conv3D(output_channels=output_channels,
                      kernel_shape=conv_kernel,
                      stride=conv_stride,
                      padding='SAME',
                      with_bias=with_bias)
            for conv_kernel, conv_stride in conv_specs
        ]