Esempio n. 1
0
    def __init__(
        self,
        num_spatial_dims: int,
        output_channels: int,
        kernel_shape: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        data_format: str = "channels_last",
        mask: Optional[jnp.ndarray] = None,
        name: Optional[str] = None,
    ):
        """Initializes the module.

        Args:
          num_spatial_dims: The number of spatial dimensions of the input. Must
            be greater than 0.
          output_channels: Number of output channels.
          kernel_shape: The shape of the kernel. Either an integer or a sequence
            of length ``num_spatial_dims``.
          stride: Optional stride for the kernel. Either an integer or a sequence
            of length ``num_spatial_dims``. Defaults to 1.
          padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
            sequence of ``(low, high)`` integer pairs. Defaults to ``SAME``. The
            value is stored unchanged. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          with_bias: Whether to add a bias. By default, true.
          w_init: Optional weight initialization. By default, truncated normal.
          b_init: Optional bias initialization. By default, zeros.
          data_format: The data format of the input. Can be either
            ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
            default, ``channels_last``.
          mask: Optional mask of the weights.
          name: The name of the module.

        Raises:
          ValueError: If ``num_spatial_dims`` is not greater than 0. An
            unrecognised ``data_format`` is also expected to raise ValueError
            from ``utils.get_channel_index``.
        """
        super().__init__(name=name)

        if num_spatial_dims <= 0:
            raise ValueError(
                "We only support convolution operations for `num_spatial_dims` "
                f"greater than 0, received num_spatial_dims={num_spatial_dims}."
            )

        self.num_spatial_dims = num_spatial_dims
        self.output_channels = output_channels
        # Integers are expanded to one value per spatial dimension.
        self.kernel_shape = utils.replicate(kernel_shape, num_spatial_dims,
                                            "kernel_shape")
        self.with_bias = with_bias
        self.stride = utils.replicate(stride, num_spatial_dims, "strides")
        self.w_init = w_init
        # Bias falls back to a zeros initializer when none is supplied.
        self.b_init = b_init or jnp.zeros
        self.mask = mask
        # TODO(tomhennigan) Make use of hk.pad.create here?
        self.padding = padding
        self.data_format = data_format
        self.channel_index = utils.get_channel_index(data_format)
        # transpose=True asks the (externally defined) helper for the
        # dimension numbers used by the transposed convolution.
        self.dimension_numbers = to_dimension_numbers(
            num_spatial_dims,
            channels_last=(self.channel_index == -1),
            transpose=True)
Esempio n. 2
0
    def __call__(self, inputs):
        """Applies the depthwise convolution to ``inputs``.

        Each input channel is convolved separately (``feature_group_count``
        equals the number of input channels) with ``self._channel_multiplier``
        filters.

        Args:
          inputs: Input array; its channel axis is chosen by the module's data
            format (``self._channel_index``).

        Returns:
          The convolution result, plus the bias when ``self._with_bias`` is set.
        """
        # Use the index computed once in __init__ rather than recomputing it.
        channel_index = self._channel_index
        in_channels = inputs.shape[channel_index]
        out_channels = self._channel_multiplier * in_channels
        weight_shape = self._kernel_shape + (1, out_channels)

        w_init = self._w_init
        if w_init is None:
            # Default: truncated normal scaled by 1/sqrt(fan_in); only computed
            # when no initializer was supplied.
            fan_in_shape = np.prod(weight_shape[:-1])
            stddev = 1. / np.sqrt(fan_in_shape)
            w_init = initializers.TruncatedNormal(stddev=stddev)
        w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)

        if channel_index == -1:
            dn = DIMENSION_NUMBERS[self._num_spatial_dims]
        else:
            dn = DIMENSION_NUMBERS_NCSPATIAL[self._num_spatial_dims]

        result = lax.conv_general_dilated(
            inputs,
            w,
            self._stride,
            self._padding,
            self._lhs_dilation,
            self._rhs_dilation,
            dn,
            feature_group_count=in_channels)

        if self._with_bias:
            if channel_index == -1:
                bias_shape = (out_channels,)
            else:
                # One trailing singleton axis per spatial dimension so the bias
                # broadcasts over all spatial positions.
                bias_shape = (out_channels,) + (1,) * self._num_spatial_dims
            b = base.get_parameter("b", bias_shape, init=self._b_init)
            result = result + b
        return result
Esempio n. 3
0
    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Convolves each input channel with its own set of filters.

        Args:
          inputs: Input array whose channel axis is selected by
            ``self.data_format``.

        Returns:
          The depthwise convolution output, with the bias (when
          ``self.with_bias``) broadcast to the output shape and added.
        """
        ch_axis = utils.get_channel_index(self.data_format)
        n_in = inputs.shape[ch_axis]
        n_out = self.channel_multiplier * n_in
        kernel_shape = self.kernel_shape + (1, n_out)

        if self.w_init is not None:
            init = self.w_init
        else:
            # Truncated normal scaled by 1/sqrt(fan_in).
            fan_in = np.prod(kernel_shape[:-1])
            init = hk.initializers.TruncatedNormal(stddev=1. / np.sqrt(fan_in))
        kernel = hk.get_parameter("w", kernel_shape, inputs.dtype, init=init)

        y = lax.conv_general_dilated(
            inputs,
            kernel,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            dimension_numbers=self.dn,
            feature_group_count=n_in)

        if not self.with_bias:
            return y

        bias_shape = (n_out,) if ch_axis == -1 else (n_out, 1, 1)
        bias = hk.get_parameter("b", bias_shape, init=self.b_init)
        return y + jnp.broadcast_to(bias, y.shape)
Esempio n. 4
0
    def __init__(self,
                 create_scale,
                 create_offset,
                 decay_rate,
                 eps=1e-5,
                 scale_init=None,
                 offset_init=None,
                 axis=None,
                 cross_replica_axis=None,
                 data_format="channels_last",
                 name=None):
        """Builds a BatchNorm module.

        Args:
          create_scale: If true, a trainable scale (gain) is created.
          create_offset: If true, a trainable offset (bias) is created.
          decay_rate: Decay rate shared by the mean and variance EMAs.
          eps: Small constant guarding against division by a zero variance.
            Defaults to 1e-5, as in the paper and Sonnet.
          scale_init: Optional initializer for the scale; only allowed when
            `create_scale=True`. Falls back to ones.
          offset_init: Optional initializer for the offset; only allowed when
            `create_offset=True`. Falls back to zeros.
          axis: Axes to reduce over. `None` (the default) means all but the
            channel axis; otherwise a list of axis indices over which
            normalization statistics are calculated.
          cross_replica_axis: If not None, the name of the `jax.pmap` axis this
            module runs under; batch statistics are then computed across all
            replicas on that axis.
          data_format: Input data format: `channels_first`, `channels_last`,
            `N...C` or `NC...`. Defaults to `channels_last`.
          name: The module name.

        Raises:
          ValueError: If an initializer is given for a parameter that is not
            being created.
        """
        super(BatchNorm, self).__init__(name=name)
        self._create_scale = create_scale
        self._create_offset = create_offset

        if scale_init is not None and not self._create_scale:
            raise ValueError("Cannot set `scale_init` if `create_scale=False`")
        if offset_init is not None and not self._create_offset:
            raise ValueError(
                "Cannot set `offset_init` if `create_offset=False`")

        # Unit scale and zero offset unless an initializer was supplied.
        self._scale_init = scale_init or jnp.ones
        self._offset_init = offset_init or jnp.zeros
        self._eps = eps

        self._cross_replica_axis = cross_replica_axis
        self._data_format = data_format
        self._channel_index = utils.get_channel_index(data_format)
        self._axis = axis

        # Mean EMA is created before the variance EMA, as before.
        self._mean_ema, self._var_ema = (
            moving_averages.ExponentialMovingAverage(decay_rate, name=ema_name)
            for ema_name in ("mean_ema", "var_ema"))
Esempio n. 5
0
  def __init__(self,
               num_spatial_dims,
               output_channels,
               kernel_shape,
               stride=1,
               padding="SAME",
               with_bias=True,
               w_init=None,
               b_init=None,
               data_format="channels_last",
               mask=None,
               name=None):
    """Initializes a ConvNDTranspose module.

    Args:
      num_spatial_dims: The number of spatial dimensions of the input. Must be
        1, 2 or 3.
      output_channels: Number of output channels.
      kernel_shape: The shape of the kernel. Either an integer or a sequence of
        length `num_spatial_dims`.
      stride: Optional stride for the kernel. Either an integer or a sequence of
        length `num_spatial_dims`. Defaults to 1.
      padding: Optional padding algorithm. Either "VALID" or "SAME".
        Defaults to "SAME". The value is stored unchanged. See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. By default, zeros.
      data_format: The data format of the input. Can be either
        `channels_first`, `channels_last`, `N...C` or `NC...`. By default,
        `channels_last`.
      mask: Optional mask of the weights.
      name: The name of the module.

    Raises:
      ValueError: If `num_spatial_dims` is not 1, 2 or 3.
    """
    super(ConvNDTranspose, self).__init__(name=name)
    if not 1 <= num_spatial_dims <= 3:
      raise ValueError(
          "We only support convolution operations for num_spatial_dims=1, 2 or "
          "3, received num_spatial_dims={}.".format(num_spatial_dims))
    self._num_spatial_dims = num_spatial_dims
    self._output_channels = output_channels
    # Integers are expanded to one value per spatial dimension.
    self._kernel_shape = utils.replicate(kernel_shape, num_spatial_dims,
                                         "kernel_shape")
    self._with_bias = with_bias
    self._stride = utils.replicate(stride, num_spatial_dims, "strides")
    self._w_init = w_init
    self._b_init = b_init or jnp.zeros
    self._mask = mask
    self._padding = padding

    self._data_format = data_format
    self._channel_index = utils.get_channel_index(data_format)
    # Select the channels-last or channels-first dimension-number table.
    if self._channel_index == -1:
      self._dn = DIMENSION_NUMBERS[self._num_spatial_dims]
    else:
      self._dn = DIMENSION_NUMBERS_NCSPATIAL[self._num_spatial_dims]
Esempio n. 6
0
    def __init__(
        self,
        create_scale: bool,
        create_offset: bool,
        decay_rate: float,
        eps: float = 1e-5,
        scale_init: Optional[hk.Initializer] = None,
        offset_init: Optional[hk.Initializer] = None,
        axis: Optional[Sequence[int]] = None,
        cross_replica_axis: Optional[str] = None,
        data_format: str = "channels_last",
        name: Optional[str] = None,
    ):
        """Builds a BatchNorm module.

        Args:
          create_scale: Whether a trainable scale (gain) is created.
          create_offset: Whether a trainable offset (bias) is created.
          decay_rate: Decay rate used by both the mean and variance EMAs.
          eps: Small constant guarding against division by a zero variance.
            Defaults to ``1e-5``, as in the paper and Sonnet.
          scale_init: Optional scale initializer; only allowed when
            ``create_scale=True``. Falls back to ``jnp.ones``.
          offset_init: Optional offset initializer; only allowed when
            ``create_offset=True``. Falls back to ``jnp.zeros``.
          axis: Axes to reduce over. ``None`` (the default) means all but the
            channel axis; otherwise a list of axis indices over which
            normalization statistics are calculated.
          cross_replica_axis: If not ``None``, the name of the ``jax.pmap`` axis
            this module runs under; batch statistics are then computed across
            all replicas on that axis.
          data_format: Input data format: ``channels_first``,
            ``channels_last``, ``N...C`` or ``NC...``. Defaults to
            ``channels_last``.
          name: The module name.

        Raises:
          ValueError: If an initializer is given for a parameter that is not
            being created.
        """
        super().__init__(name=name)

        # An initializer may only be supplied for a parameter that is created.
        for enabled, initializer, message in (
            (create_scale, scale_init,
             "Cannot set `scale_init` if `create_scale=False`"),
            (create_offset, offset_init,
             "Cannot set `offset_init` if `create_offset=False`"),
        ):
            if initializer is not None and not enabled:
                raise ValueError(message)

        self.create_scale = create_scale
        self.create_offset = create_offset
        self.eps = eps
        self.scale_init = scale_init if scale_init else jnp.ones
        self.offset_init = offset_init if offset_init else jnp.zeros
        self.axis = axis
        self.cross_replica_axis = cross_replica_axis
        self.channel_index = utils.get_channel_index(data_format)
        # Mean EMA first, then variance EMA.
        self.mean_ema = hk.ExponentialMovingAverage(decay_rate, name="mean_ema")
        self.var_ema = hk.ExponentialMovingAverage(decay_rate, name="var_ema")
Esempio n. 7
0
    def __init__(
        self,
        channel_multiplier: int,
        kernel_shape: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        data_format: str = "NHWC",
        name: Optional[str] = None,
    ):
        """Construct a 2D Depthwise Convolution.

        Args:
          channel_multiplier: Multiplicity of output channels. To keep the number
            of output channels the same as the number of input channels, set 1.
          kernel_shape: The shape of the kernel. Either an integer or a sequence
            of length 2.
          stride: Optional stride for the kernel. Either an integer or a sequence
            of length 2. Defaults to 1.
          padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
            sequence of ``before, after`` pairs. Defaults to ``SAME``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          with_bias: Whether to add a bias. By default, true.
          w_init: Optional weight initialization. By default, truncated normal.
          b_init: Optional bias initialization. By default, zeros.
          data_format: The data format of the input. Can be either
            ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
            default, ``NHWC`` (channels last).
          name: The name of the module.
        """
        super().__init__(name=name)
        # This module always has exactly two spatial dimensions.
        self.num_spatial_dims = 2
        self.kernel_shape = utils.replicate(kernel_shape, self.num_spatial_dims,
                                            "kernel_shape")
        # No input or kernel dilation: one unit per spatial dimension.
        self.lhs_dilation = (1,) * len(self.kernel_shape)
        self.rhs_dilation = (1,) * len(self.kernel_shape)
        self.channel_multiplier = channel_multiplier
        self.padding = padding
        self.stride = utils.replicate(stride, self.num_spatial_dims, "strides")
        self.data_format = data_format
        self.channel_index = utils.get_channel_index(data_format)
        self.with_bias = with_bias
        self.w_init = w_init
        self.b_init = b_init or jnp.zeros
        if self.channel_index == -1:
            self.dn = DIMENSION_NUMBERS[self.num_spatial_dims]
        else:
            self.dn = DIMENSION_NUMBERS_NCSPATIAL[self.num_spatial_dims]
Esempio n. 8
0
    def __init__(self,
                 channel_multiplier,
                 kernel_shape,
                 stride=1,
                 padding="SAME",
                 with_bias=True,
                 w_init=None,
                 b_init=None,
                 data_format="NHWC",
                 name=None):
        """Construct a 2D Depthwise Convolution.

        Args:
          channel_multiplier: Multiplicity of output channels. To keep the number
            of output channels the same as the number of input channels, set 1.
          kernel_shape: The shape of the kernel. Either an integer or a sequence
            of length 2.
          stride: Optional stride for the kernel. Either an integer or a sequence
            of length 2. Defaults to 1.
          padding: Optional padding algorithm, e.g. "VALID" or "SAME". Defaults
            to "SAME". The value is stored as given and is not resolved through
            haiku.pad, so padding callables are not supported here. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          with_bias: Whether to add a bias. By default, true.
          w_init: Optional weight initialization. By default, truncated normal.
          b_init: Optional bias initialization. By default, zeros.
          data_format: The data format of the input. Can be either
            `channels_first`, `channels_last`, `N...C` or `NC...`. By default,
            `NHWC` (channels last).
          name: The name of the module.
        """
        super(DepthwiseConv2D, self).__init__(name=name)
        self._kernel_shape = utils.replicate(kernel_shape, 2, "kernel_shape")
        # No input or kernel dilation: one unit per spatial dimension.
        self._lhs_dilation = (1,) * len(self._kernel_shape)
        self._rhs_dilation = (1,) * len(self._kernel_shape)
        self._channel_multiplier = channel_multiplier
        self._padding = padding
        self._stride = utils.replicate(stride, 2, "strides")
        self._data_format = data_format
        self._channel_index = utils.get_channel_index(data_format)
        self._with_bias = with_bias
        self._w_init = w_init
        self._b_init = b_init or jnp.zeros
        # This module always has exactly two spatial dimensions.
        self._num_spatial_dims = 2
Esempio n. 9
0
    def __init__(
        self,
        create_scale: bool = True,
        create_offset: bool = True,
        eps: float = 1e-5,
        scale_init: Optional[initializers.Initializer] = None,
        offset_init: Optional[initializers.Initializer] = None,
        data_format: str = "channels_last",
        **kwargs
    ):
        """Builds an `InstanceNormalization` module.

        Statistics are taken over the spatial axes only: every axis except the
        leading axis and the channel axis selected by ``data_format``.

        Args:
            create_scale: Whether to create a trainable per-channel scale that
                is applied after normalization.
            create_offset: Whether to create a trainable per-channel offset that
                is applied after normalization and scaling.
            eps: Small constant guarding against division by a zero variance.
                Defaults to ``1e-5``.
            scale_init: Optional scale initializer; only allowed when
                ``create_scale=True``. The scale defaults to ``1``.
            offset_init: Optional offset initializer; only allowed when
                ``create_offset=True``. The offset defaults to ``0``.
            data_format: Input data format: ``channels_first``,
                ``channels_last``, ``N...C`` or ``NC...``. Defaults to
                ``channels_last``.
            kwargs: Extra keyword arguments forwarded to the parent constructor.
        """
        channels_first = hk_utils.get_channel_index(data_format) == 1
        # Channels-first: skip axes 0 and 1. Channels-last: skip axis 0 and -1.
        spatial_axes = slice(2, None) if channels_first else slice(1, -1)
        super().__init__(
            axis=spatial_axes,
            create_scale=create_scale,
            create_offset=create_offset,
            eps=eps,
            scale_init=scale_init,
            offset_init=offset_init,
            **kwargs
        )
Esempio n. 10
0
  def __init__(self,
               create_scale,
               create_offset,
               eps=1e-5,
               scale_init=None,
               offset_init=None,
               data_format="channels_last",
               name=None):
    """Builds an `InstanceNorm` module.

    Normalization statistics are taken over the spatial axes only: every axis
    except the leading axis and the channel axis selected by `data_format`.

    Args:
      create_scale: `bool`; whether to create a trainable per-channel scale
        applied after normalization.
      create_offset: `bool`; whether to create a trainable per-channel offset
        applied after normalization and scaling.
      eps: Small constant guarding against division by a zero variance.
        Defaults to `1e-5`.
      scale_init: Optional scale initializer; only allowed when
        `create_scale=True`. The scale defaults to `1`.
      offset_init: Optional offset initializer; only allowed when
        `create_offset=True`. The offset defaults to `0`.
      data_format: Input data format: `channels_first`, `channels_last`,
        `N...C` or `NC...`. Defaults to `channels_last`.
      name: Name of the module.
    """
    is_channels_first = utils.get_channel_index(data_format) == 1
    # Channels-first: skip axes 0 and 1. Channels-last: skip axis 0 and -1.
    reduce_axes = slice(2, None) if is_channels_first else slice(1, -1)
    super(InstanceNorm, self).__init__(
        axis=reduce_axes,
        create_scale=create_scale,
        create_offset=create_offset,
        eps=eps,
        scale_init=scale_init,
        offset_init=offset_init,
        name=name)
Esempio n. 11
0
 def test_invalid_strings(self, data_format):
     # Unrecognised data formats must raise a ValueError naming the format.
     import re  # Local import: the expected message is matched literally.

     # Escape the message so "." and any regex metacharacters inside
     # data_format are not interpreted as regex syntax.
     expected = re.escape(
         "Unable to extract channel information from '{}'.".format(
             data_format))
     with self.assertRaisesRegex(ValueError, expected):
         utils.get_channel_index(data_format)
Esempio n. 12
0
 def test_returns_index_channels_last(self, data_format):
     # Channels-last formats report the channel axis as the last one (-1).
     channel_index = utils.get_channel_index(data_format)
     self.assertEqual(channel_index, -1)
Esempio n. 13
0
  def __init__(self,
               num_spatial_dims,
               output_channels,
               kernel_shape,
               stride=1,
               rate=1,
               padding="SAME",
               with_bias=True,
               w_init=None,
               b_init=None,
               data_format="channels_last",
               mask=None,
               name=None):
    """Builds an N-dimensional convolution (`ConvND`) module.

    Args:
      num_spatial_dims: Number of spatial dimensions of the input; must be 1, 2
        or 3.
      output_channels: Number of output channels.
      kernel_shape: Kernel shape, as an integer or a sequence of length
        `num_spatial_dims`.
      stride: Kernel stride, as an integer or a sequence of length
        `num_spatial_dims`. Defaults to 1.
      rate: Kernel dilation rate, as an integer or a sequence of length
        `num_spatial_dims`. 1 gives a standard ND convolution and `rate > 1` a
        dilated one. Defaults to 1.
      padding: A padding algorithm name ("VALID" or "SAME"; upper-cased before
        storing), or a callable / sequence of callables of size
        `num_spatial_dims`, which is resolved via `pad.create`. Callables take
        the effective kernel size and return the padding before and after; see
        haiku.pad.* for examples. Defaults to "SAME". See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. By default, zeros.
      data_format: Input data format: `channels_first`, `channels_last`,
        `N...C` or `NC...`. Defaults to `channels_last`.
      mask: Optional mask of the weights.
      name: The name of the module.

    Raises:
      ValueError: If `num_spatial_dims` is not 1, 2 or 3.
    """
    super(ConvND, self).__init__(name=name)

    if not 1 <= num_spatial_dims <= 3:
      raise ValueError(
          "We only support convolution operations for num_spatial_dims=1, 2 or "
          "3, received num_spatial_dims={}.".format(num_spatial_dims))

    n = num_spatial_dims
    self._num_spatial_dims = n
    self._output_channels = output_channels
    self._kernel_shape = utils.replicate(kernel_shape, n, "kernel_shape")
    self._with_bias = with_bias
    self._stride = utils.replicate(stride, n, "strides")
    self._w_init = w_init
    self._b_init = b_init or jnp.zeros
    self._mask = mask
    self._lhs_dilation = utils.replicate(1, n, "lhs_dilation")
    self._kernel_dilation = utils.replicate(rate, n, "kernel_dilation")
    self._data_format = data_format
    self._channel_index = utils.get_channel_index(data_format)
    dn_table = (DIMENSION_NUMBERS if self._channel_index == -1
                else DIMENSION_NUMBERS_NCSPATIAL)
    self._dn = dn_table[n]

    if not isinstance(padding, str):
      # Non-string padding is turned into explicit pads up front.
      self._padding = pad.create(
          padding=padding,
          kernel=self._kernel_shape,
          rate=self._kernel_dilation,
          n=n)
    else:
      self._padding = padding.upper()
Esempio n. 14
0
    def __init__(
        self,
        num_spatial_dims: int,
        output_channels: int,
        kernel_shape: tp.Union[int, tp.Sequence[int]],
        stride: tp.Union[int, tp.Sequence[int]] = 1,
        rate: tp.Union[int, tp.Sequence[int]] = 1,
        padding: tp.Union[str, tp.Sequence[tp.Tuple[int, int]],
                          types.PadFnOrFns] = "SAME",
        with_bias: bool = True,
        w_init: tp.Optional[types.Initializer] = None,
        b_init: tp.Optional[types.Initializer] = None,
        data_format: str = "channels_last",
        mask: tp.Optional[np.ndarray] = None,
        groups: int = 1,
        **kwargs,
    ):
        """
        Initializes the module.

        Args:
            num_spatial_dims: The number of spatial dimensions of the input. Must
                be greater than 0.
            output_channels: Number of output channels. Must be divisible by
                ``groups``.
            kernel_shape: The shape of the kernel. Either an integer or a sequence of
                length ``num_spatial_dims``.
            stride: Optional stride for the kernel. Either an integer or a sequence of
                length ``num_spatial_dims``. Defaults to 1.
            rate: Optional kernel dilation rate. Either an integer or a sequence of
                length ``num_spatial_dims``. 1 corresponds to standard ND convolution,
                ``rate > 1`` corresponds to dilated convolution. Defaults to 1.
            padding: Optional padding algorithm. Either ``VALID`` or ``SAME``, a
                sequence of n ``(low, high)`` integer pairs that give the padding to
                apply before and after each spatial dimension, or a callable or
                sequence of callables of size ``num_spatial_dims``. Any callables must
                take a single integer argument equal to the effective kernel size and
                return a sequence of two integers representing the padding before and
                after. See ``haiku.pad.*`` for more details and example functions.
                String values are upper-cased. Defaults to ``SAME``. See:
                https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
            with_bias: Whether to add a bias. By default, true.
            w_init: Optional weight initialization. By default, truncated normal.
            b_init: Optional bias initialization. By default, zeros.
            data_format: The data format of the input. Can be either
                ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
                default, ``channels_last``.
            mask: Optional mask of the weights.
            groups: A positive integer specifying the number of groups in which the
                input is split along the channel axis. Each group is convolved separately
                with filters / groups filters. The output is the concatenation of all the
                groups results along the channel axis. Input channels and filters must both
                be divisible by groups.
            kwargs: Additional keyword arguments passed to Module.

        Raises:
            ValueError: If ``num_spatial_dims`` is not greater than 0, if
                ``groups`` is not positive, or if ``output_channels`` is not
                divisible by ``groups``.
        """
        super().__init__(**kwargs)
        if num_spatial_dims <= 0:
            raise ValueError(
                "We only support convolution operations for `num_spatial_dims` "
                f"greater than 0, received num_spatial_dims={num_spatial_dims}."
            )
        # Enforce the documented `groups` contract here rather than failing
        # later inside the convolution primitive.
        if groups <= 0:
            raise ValueError(
                f"`groups` must be a positive integer, received groups={groups}."
            )
        if output_channels % groups != 0:
            raise ValueError(
                "`output_channels` must be divisible by `groups`, received "
                f"output_channels={output_channels} and groups={groups}."
            )

        self.num_spatial_dims = num_spatial_dims
        self.output_channels = output_channels
        self.kernel_shape = hk_utils.replicate(kernel_shape, num_spatial_dims,
                                               "kernel_shape")
        self.with_bias = with_bias
        self.stride = hk_utils.replicate(stride, num_spatial_dims, "strides")
        self.w_init = w_init
        self.b_init = b_init or jnp.zeros
        self.mask = mask
        self.lhs_dilation = hk_utils.replicate(1, num_spatial_dims,
                                               "lhs_dilation")
        self.kernel_dilation = hk_utils.replicate(rate, num_spatial_dims,
                                                  "kernel_dilation")
        self.data_format = data_format
        self.channel_index = hk_utils.get_channel_index(data_format)
        self.dimension_numbers = to_dimension_numbers(
            num_spatial_dims,
            channels_last=(self.channel_index == -1),
            transpose=False)
        self.groups = groups

        if isinstance(padding, str):
            self.padding = padding.upper()
        else:
            # Explicit pairs / pad callables are resolved into concrete pads.
            self.padding = hk.pad.create(
                padding=padding,
                kernel=self.kernel_shape,
                rate=self.kernel_dilation,
                n=self.num_spatial_dims,
            )
Esempio n. 15
0
    def __init__(
        self,
        output_channels: int,
        uniform_init_minval: float,
        uniform_init_maxval: float,
        kernel_shape: Union[int, Sequence[int]],
        num_spatial_dims: int = 2,
        stride: Union[int, Sequence[int]] = 1,
        rate: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn,
                       Sequence[hk.pad.PadFn]] = "SAME",
        with_bias: bool = True,
        w_init: Union[Optional[hk.initializers.Initializer], str] = "uniform",
        b_init: Union[Optional[hk.initializers.Initializer], str] = "uniform",
        data_format: str = "channels_last",
        mask: Optional[jnp.ndarray] = None,
        feature_group_count: int = 1,
        name: Optional[str] = None,
        stochastic_parameters: bool = False,
    ):
        """Initializes the module.

    Args:
      output_channels: Number of output channels.
      uniform_init_minval: TODO(nband).
      uniform_init_maxval: TODO(nband).
      kernel_shape: The shape of the kernel. Either an integer or a sequence
        of length ``num_spatial_dims``.
      num_spatial_dims: The number of spatial dimensions of the input.
      stride: Optional stride for the kernel. Either an integer or a sequence
        of length ``num_spatial_dims``. Defaults to 1.
      rate: Optional kernel dilation rate. Either an integer or a sequence of
        length ``num_spatial_dims``. 1 corresponds to standard ND convolution,
        ``rate > 1`` corresponds to dilated convolution. Defaults to 1.
      padding: Optional padding algorithm. Either ``VALID`` or ``SAME`` or a
        sequence of n ``(low, high)`` integer pairs that give the padding to
        apply before and after each spatial dimension. or a callable or
        sequence of callables of size ``num_spatial_dims``. Any callables must
        take a single integer argument equal to the effective kernel size and
        return a sequence of two integers representing the padding before and
        after. See ``haiku.pad.*`` for more details and example functions.
        Defaults to
          ``SAME``. See:
          https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. By default, zeros.
      data_format: The data format of the input.  Can be either
        ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
        default, ``channels_last``.
      mask: Optional mask of the weights.
      feature_group_count: Optional number of groups in group convolution.
        Default value of 1 corresponds to normal dense convolution. If a
        higher value is used, convolutions are applied separately to that many
        groups, then stacked together. This reduces the number of parameters
          and possibly the compute for a given ``output_channels``. See:
          https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      name: The name of the module.
      stochastic_parameters: TODO(nband).
    """
        super().__init__(name=name)
        if num_spatial_dims <= 0:
            raise ValueError(
                "We only support convolution operations for `num_spatial_dims` "
                f"greater than 0, received num_spatial_dims={num_spatial_dims}."
            )

        self.num_spatial_dims = num_spatial_dims
        self.output_channels = output_channels
        self.kernel_shape = utils.replicate(kernel_shape, num_spatial_dims,
                                            "kernel_shape")
        self.with_bias = with_bias
        self.stride = utils.replicate(stride, num_spatial_dims, "strides")
        self.w_init = w_init
        self.b_init = b_init
        self.uniform_init_minval = uniform_init_minval
        self.uniform_init_maxval = uniform_init_maxval
        self.mask = mask
        self.feature_group_count = feature_group_count
        self.lhs_dilation = utils.replicate(1, num_spatial_dims,
                                            "lhs_dilation")
        self.kernel_dilation = utils.replicate(rate, num_spatial_dims,
                                               "kernel_dilation")
        self.data_format = data_format
        self.channel_index = utils.get_channel_index(data_format)
        self.dimension_numbers = to_dimension_numbers(
            num_spatial_dims,
            channels_last=(self.channel_index == -1),
            transpose=False)
        self.stochastic_parameters = stochastic_parameters

        if isinstance(padding, str):
            self.padding = padding.upper()
        else:
            self.padding = hk.pad.create(
                padding=padding,
                kernel=self.kernel_shape,
                rate=self.kernel_dilation,
                n=self.num_spatial_dims,
            )
Esempio n. 16
0
    def __init__(
        self,
        groups: int,
        axis: Union[int, slice, Sequence[int]] = slice(1, None),
        create_scale: bool = True,
        create_offset: bool = True,
        eps: float = 1e-5,
        scale_init: Optional[hk.initializers.Initializer] = None,
        offset_init: Optional[hk.initializers.Initializer] = None,
        data_format: str = "channels_last",
        name: Optional[str] = None,
    ):
        """Constructs a ``GroupNorm`` module.

    Args:
      groups: number of groups to divide the channels by. The number of channels
        must be divisible by this.
      axis: ``int``, ``slice`` or iterable of ints representing the axes which
        should be normalized across. By default this is all but the first
        dimension. For time series data use ``slice(2, None)`` to average over
        the non-batch and non-time dimensions. A single ``int`` or an iterable
        of ints (including a one-shot iterator such as a generator) is stored
        as a ``tuple``; a ``slice`` is stored as given.
      create_scale: whether to create a trainable scale per channel applied
        after the normalization.
      create_offset: whether to create a trainable offset per channel applied
        after normalization and scaling.
      eps: Small epsilon to add to the variance to avoid division by zero.
        Defaults to ``1e-5``.
      scale_init: Optional initializer for the scale parameter. Can only be set
        if ``create_scale=True``. By default scale is initialized to ``1``.
      offset_init: Optional initializer for the offset parameter. Can only be
        set if ``create_offset=True``. By default offset is initialized to
        ``0``.
      data_format: The data format of the input. Can be either
        ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
        default it is ``channels_last``.
      name: Name of the module.

    Raises:
      ValueError: If ``axis`` is not an int, slice or iterable of ints; if
        ``scale_init`` is given while ``create_scale=False``; or if
        ``offset_init`` is given while ``create_offset=False``.
    """
        super().__init__(name=name)

        if isinstance(axis, slice):
            self.axis = axis
        elif isinstance(axis, int):
            self.axis = (axis, )
        elif isinstance(axis, collections.abc.Iterable):
            # Materialize exactly once: validating a one-shot iterator (e.g. a
            # generator) with `all(...)` would consume it, leaving an exhausted
            # iterator in `self.axis`. Storing a tuple also matches the `int`
            # branch above.
            axis_tuple = tuple(axis)
            if not all(isinstance(ax, int) for ax in axis_tuple):
                raise ValueError(
                    "`axis` should be an int, slice or iterable of ints.")
            self.axis = axis_tuple
        else:
            raise ValueError(
                "`axis` should be an int, slice or iterable of ints.")

        self.groups = groups
        self.eps = eps
        self.data_format = data_format
        # Resolved from `data_format` by the project helper.
        self.channel_index = utils.get_channel_index(data_format)
        self.create_scale = create_scale
        self.create_offset = create_offset
        # Not known until the module sees an input; starts as None.
        self.rank = None

        # `self.scale_init` is only assigned when a scale will be created;
        # passing an initializer without creating the parameter is an error.
        if self.create_scale:
            if scale_init is None:
                scale_init = jnp.ones
            self.scale_init = scale_init
        elif scale_init is not None:
            raise ValueError(
                "Cannot set `scale_init` if `create_scale=False`.")

        # Same contract for the offset parameter.
        if self.create_offset:
            if offset_init is None:
                offset_init = jnp.zeros
            self.offset_init = offset_init
        elif offset_init is not None:
            raise ValueError(
                "Cannot set `offset_init` if `create_offset=False`.")