def __init__(
    self,
    num_spatial_dims: int,
    output_channels: int,
    kernel_shape: Union[int, Sequence[int]],
    stride: Union[int, Sequence[int]] = 1,
    padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
    with_bias: bool = True,
    w_init: Optional[hk.initializers.Initializer] = None,
    b_init: Optional[hk.initializers.Initializer] = None,
    data_format: str = "channels_last",
    mask: Optional[jnp.ndarray] = None,
    name: Optional[str] = None,
):
    """Initializes the module.

    Args:
      num_spatial_dims: The number of spatial dimensions of the input.
      output_channels: Number of output channels.
      kernel_shape: The shape of the kernel. Either an integer or a sequence
        of length ``num_spatial_dims``.
      stride: Optional stride for the kernel. Either an integer or a sequence
        of length ``num_spatial_dims``. Defaults to 1.
      padding: Optional padding algorithm. Either "VALID" or "SAME" (the type
        annotation also admits a sequence of integer pairs). Defaults to
        "SAME". The value is stored unchanged by this constructor. See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. By default, zeros.
      data_format: The data format of the input. Can be either
        ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
        default, ``channels_last``.
      mask: Optional mask of the weights.
      name: The name of the module.

    Raises:
      ValueError: If ``num_spatial_dims`` is not greater than 0.
    """
    super().__init__(name=name)
    if num_spatial_dims <= 0:
        raise ValueError(
            "We only support convolution operations for `num_spatial_dims` "
            f"greater than 0, received num_spatial_dims={num_spatial_dims}."
        )

    self.num_spatial_dims = num_spatial_dims
    self.output_channels = output_channels
    # Scalars are expanded to one entry per spatial dimension.
    self.kernel_shape = utils.replicate(
        kernel_shape, num_spatial_dims, "kernel_shape")
    self.with_bias = with_bias
    self.stride = utils.replicate(stride, num_spatial_dims, "strides")
    self.w_init = w_init
    # A missing bias initializer falls back to zeros.
    self.b_init = b_init or jnp.zeros
    self.mask = mask
    # TODO(tomhennigan) Make use of hk.pad.create here?
    self.padding = padding
    self.data_format = data_format
    self.channel_index = utils.get_channel_index(data_format)
    # A channel index of -1 means the channel axis is the trailing one.
    self.dimension_numbers = to_dimension_numbers(
        num_spatial_dims,
        channels_last=(self.channel_index == -1),
        transpose=True)
def __call__(self, inputs):
    """Applies a depthwise convolution (one feature group per input channel).

    Args:
      inputs: Array to convolve. The size of its channel axis (located via
        the configured data format) determines the parameter shapes.

    Returns:
      The output of ``lax.conv_general_dilated``, with the bias added when
      the module was configured with a bias.
    """
    channel_index = utils.get_channel_index(self._data_format)
    input_channels = inputs.shape[channel_index]
    output_channels = self._channel_multiplier * input_channels

    weight_shape = self._kernel_shape + (1, output_channels)
    fan_in_shape = np.prod(weight_shape[:-1])
    stddev = 1. / np.sqrt(fan_in_shape)
    w_init = self._w_init or initializers.TruncatedNormal(stddev=stddev)
    w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)

    if self._channel_index == -1:
        dn = DIMENSION_NUMBERS[self._num_spatial_dims]
    else:
        dn = DIMENSION_NUMBERS_NCSPATIAL[self._num_spatial_dims]

    result = lax.conv_general_dilated(
        inputs,
        w,
        self._stride,
        self._padding,
        self._lhs_dilation,
        self._rhs_dilation,
        dn,
        # One group per input channel is what makes the convolution depthwise.
        feature_group_count=input_channels)

    if self._with_bias:
        if channel_index == -1:
            bias_shape = (output_channels,)
        else:
            # Trailing unit axes let the bias broadcast over the two spatial
            # dimensions when channels are not the last axis.
            bias_shape = (output_channels, 1, 1)
        # Create the bias in the input dtype, like the weights, instead of
        # relying on the default parameter dtype; otherwise adding it can
        # change the dtype of the result for non-default input dtypes.
        b = base.get_parameter("b", bias_shape, inputs.dtype, init=self._b_init)
        result = result + b
    return result
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Applies a depthwise convolution (one feature group per input channel).

    Args:
      inputs: Array to convolve. The size of its channel axis (located via
        ``self.data_format``) determines the parameter shapes.

    Returns:
      The output of ``lax.conv_general_dilated``, with the bias added when
      ``self.with_bias`` is set.
    """
    channel_index = utils.get_channel_index(self.data_format)
    input_channels = inputs.shape[channel_index]
    output_channels = self.channel_multiplier * input_channels

    w_shape = self.kernel_shape + (1, output_channels)
    w_init = self.w_init
    if w_init is None:
        # Default initializer: truncated normal scaled by the fan-in.
        fan_in_shape = np.prod(w_shape[:-1])
        stddev = 1. / np.sqrt(fan_in_shape)
        w_init = hk.initializers.TruncatedNormal(stddev=stddev)
    w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)

    out = lax.conv_general_dilated(
        inputs,
        w,
        window_strides=self.stride,
        padding=self.padding,
        lhs_dilation=self.lhs_dilation,
        rhs_dilation=self.rhs_dilation,
        dimension_numbers=self.dn,
        # One group per input channel is what makes the convolution depthwise.
        feature_group_count=input_channels)

    if self.with_bias:
        if channel_index == -1:
            b_shape = (output_channels,)
        else:
            # Trailing unit axes let the bias broadcast over the two spatial
            # dimensions when channels are not the last axis.
            b_shape = (output_channels, 1, 1)
        # Create the bias in the input dtype, like the weights, instead of
        # relying on the default parameter dtype; otherwise adding it can
        # change the dtype of the result for non-default input dtypes.
        b = hk.get_parameter("b", b_shape, inputs.dtype, init=self.b_init)
        b = jnp.broadcast_to(b, out.shape)
        out = out + b
    return out
def __init__(self,
             create_scale,
             create_offset,
             decay_rate,
             eps=1e-5,
             scale_init=None,
             offset_init=None,
             axis=None,
             cross_replica_axis=None,
             data_format="channels_last",
             name=None):
    """Constructs a BatchNorm module.

    Args:
      create_scale: Whether to include a trainable scaling factor.
      create_offset: Whether to include a trainable offset.
      decay_rate: Decay rate handed to the two exponential moving averages
        (mean and variance) created here.
      eps: Small epsilon to avoid division by zero variance. Defaults to
        1e-5, as in the paper and Sonnet.
      scale_init: Optional initializer for gain (aka scale). Can only be set
        if `create_scale=True`. By default, one.
      offset_init: Optional initializer for bias (aka offset). Can only be
        set if `create_offset=True`. By default, zero.
      axis: Which axes to reduce over. The default (None) signifies that all
        but the channel axis should be normalized. Otherwise this is a list
        of axis indices which will have normalization statistics calculated.
      cross_replica_axis: If not None, it should be a string representing the
        axis name over which this module is being run within a jax.pmap.
        Supplying this argument means that batch statistics are calculated
        across all replicas on that axis.
      data_format: The data format of the input. Can be either
        `channels_first`, `channels_last`, `N...C` or `NC...`. By default it
        is `channels_last`.
      name: The module name.

    Raises:
      ValueError: If an initializer is supplied for a scale or offset that is
        not being created.
    """
    super(BatchNorm, self).__init__(name=name)

    # Reject initializers for parameters that will never be created.
    if scale_init is not None and not create_scale:
        raise ValueError("Cannot set `scale_init` if `create_scale=False`")
    if offset_init is not None and not create_offset:
        raise ValueError(
            "Cannot set `offset_init` if `create_offset=False`")

    self._create_scale = create_scale
    self._create_offset = create_offset
    self._scale_init = scale_init if scale_init else jnp.ones
    self._offset_init = offset_init if offset_init else jnp.zeros
    self._eps = eps
    self._axis = axis
    self._cross_replica_axis = cross_replica_axis
    self._data_format = data_format
    self._channel_index = utils.get_channel_index(data_format)

    # Running statistics, created in the same order as before (mean, var).
    self._mean_ema = moving_averages.ExponentialMovingAverage(
        decay_rate, name="mean_ema")
    self._var_ema = moving_averages.ExponentialMovingAverage(
        decay_rate, name="var_ema")
def __init__(self,
             num_spatial_dims,
             output_channels,
             kernel_shape,
             stride=1,
             padding="SAME",
             with_bias=True,
             w_init=None,
             b_init=None,
             data_format="channels_last",
             mask=None,
             name=None):
    """Initializes a ConvNDTranspose module.

    Args:
      num_spatial_dims: The number of spatial dimensions of the input. Must
        be 1, 2 or 3.
      output_channels: Number of output channels.
      kernel_shape: The shape of the kernel. Either an integer or a sequence
        of length `num_spatial_dims`.
      stride: Optional stride for the kernel. Either an integer or a sequence
        of length `num_spatial_dims`. Defaults to 1.
      padding: Optional padding algorithm. Either "VALID" or "SAME". Defaults
        to "SAME". See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. By default, zeros.
      data_format: The data format of the input. Can be either
        `channels_first`, `channels_last`, `N...C` or `NC...`. By default,
        `channels_last`.
      mask: Optional mask of the weights.
      name: The name of the module.

    Raises:
      ValueError: If `num_spatial_dims` is outside the range 1 to 3.
    """
    super(ConvNDTranspose, self).__init__(name=name)

    dims_supported = 1 <= num_spatial_dims <= 3
    if not dims_supported:
        raise ValueError(
            "We only support convolution operations for num_spatial_dims=1, 2 or "
            "3, received num_spatial_dims={}.".format(num_spatial_dims))

    # Plain pass-through configuration.
    self._num_spatial_dims = num_spatial_dims
    self._output_channels = output_channels
    self._with_bias = with_bias
    self._w_init = w_init
    self._mask = mask
    self._padding = padding
    self._data_format = data_format

    # Derived configuration.
    self._kernel_shape = utils.replicate(
        kernel_shape, num_spatial_dims, "kernel_shape")
    self._stride = utils.replicate(stride, num_spatial_dims, "strides")
    self._b_init = b_init if b_init else jnp.zeros
    self._channel_index = utils.get_channel_index(data_format)

    # Pick the dimension-number table that matches the channel position.
    dn_table = (DIMENSION_NUMBERS if self._channel_index == -1
                else DIMENSION_NUMBERS_NCSPATIAL)
    self._dn = dn_table[self._num_spatial_dims]
def __init__(
    self,
    create_scale: bool,
    create_offset: bool,
    decay_rate: float,
    eps: float = 1e-5,
    scale_init: Optional[hk.Initializer] = None,
    offset_init: Optional[hk.Initializer] = None,
    axis: Optional[Sequence[int]] = None,
    cross_replica_axis: Optional[str] = None,
    data_format: str = "channels_last",
    name: Optional[str] = None,
):
    """Constructs a BatchNorm module.

    Args:
      create_scale: Whether to include a trainable scaling factor.
      create_offset: Whether to include a trainable offset.
      decay_rate: Decay rate handed to the two exponential moving averages
        (mean and variance) created here.
      eps: Small epsilon to avoid division by zero variance. Defaults to
        ``1e-5``, as in the paper and Sonnet.
      scale_init: Optional initializer for gain (aka scale). Can only be set
        if ``create_scale=True``. By default, ``1``.
      offset_init: Optional initializer for bias (aka offset). Can only be
        set if ``create_offset=True``. By default, ``0``.
      axis: Which axes to reduce over. The default (``None``) signifies that
        all but the channel axis should be normalized. Otherwise this is a
        list of axis indices which will have normalization statistics
        calculated.
      cross_replica_axis: If not ``None``, it should be a string representing
        the axis name over which this module is being run within a
        ``jax.pmap``. Supplying this argument means that batch statistics are
        calculated across all replicas on that axis.
      data_format: The data format of the input. Can be either
        ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
        default it is ``channels_last``. Only the channel index derived from
        it is kept on the module.
      name: The module name.

    Raises:
      ValueError: If an initializer is supplied for a scale or offset that is
        not being created.
    """
    super().__init__(name=name)

    # Reject initializers for parameters that will never be created.
    if scale_init is not None and not create_scale:
        raise ValueError("Cannot set `scale_init` if `create_scale=False`")
    if offset_init is not None and not create_offset:
        raise ValueError(
            "Cannot set `offset_init` if `create_offset=False`")

    self.create_scale = create_scale
    self.scale_init = scale_init if scale_init else jnp.ones
    self.create_offset = create_offset
    self.offset_init = offset_init if offset_init else jnp.zeros
    self.eps = eps
    self.axis = axis
    self.cross_replica_axis = cross_replica_axis
    self.channel_index = utils.get_channel_index(data_format)

    # Running statistics, created in the same order as before (mean, var).
    self.mean_ema = hk.ExponentialMovingAverage(decay_rate, name="mean_ema")
    self.var_ema = hk.ExponentialMovingAverage(decay_rate, name="var_ema")
def __init__(
    self,
    channel_multiplier: int,
    kernel_shape: Union[int, Sequence[int]],
    stride: Union[int, Sequence[int]] = 1,
    padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
    with_bias: bool = True,
    w_init: Optional[hk.initializers.Initializer] = None,
    b_init: Optional[hk.initializers.Initializer] = None,
    data_format: str = "NHWC",
    name: Optional[str] = None,
):
    """Construct a 2D Depthwise Convolution.

    Args:
      channel_multiplier: Multiplicity of output channels. To keep the number
        of output channels the same as the number of input channels, set 1.
      kernel_shape: The shape of the kernel. Either an integer or a sequence
        of length 2 (the number of spatial dimensions is fixed to 2 here).
      stride: Optional stride for the kernel. Either an integer or a sequence
        of length 2. Defaults to 1.
      padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
        sequence of ``before, after`` pairs. Defaults to ``SAME``. See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. By default, zeros.
      data_format: The data format of the input. Can be either
        ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
        default, ``NHWC``.
      name: The name of the module.
    """
    super().__init__(name=name)

    # This module is two-dimensional by construction.
    self.num_spatial_dims = 2

    # Plain pass-through configuration.
    self.channel_multiplier = channel_multiplier
    self.padding = padding
    self.with_bias = with_bias
    self.w_init = w_init
    self.data_format = data_format

    # Derived configuration.
    self.kernel_shape = utils.replicate(kernel_shape, 2, "kernel_shape")
    self.stride = utils.replicate(stride, 2, "strides")
    self.b_init = b_init if b_init else jnp.zeros
    self.channel_index = utils.get_channel_index(data_format)

    # Neither the input nor the kernel is dilated.
    no_dilation = (1, ) * len(self.kernel_shape)
    self.lhs_dilation = no_dilation
    self.rhs_dilation = no_dilation

    # Pick the dimension-number table that matches the channel position.
    dn_table = (DIMENSION_NUMBERS if self.channel_index == -1
                else DIMENSION_NUMBERS_NCSPATIAL)
    self.dn = dn_table[self.num_spatial_dims]
def __init__(self,
             channel_multiplier,
             kernel_shape,
             stride=1,
             padding="SAME",
             with_bias=True,
             w_init=None,
             b_init=None,
             data_format="NHWC",
             name=None):
    """Construct a 2D Depthwise Convolution.

    Args:
      channel_multiplier: Multiplicity of output channels. To keep the number
        of output channels the same as the number of input channels, set 1.
      kernel_shape: The shape of the kernel. Either an integer or a sequence
        of length 2 (the number of spatial dimensions is fixed to 2 here).
      stride: Optional stride for the kernel. Either an integer or a sequence
        of length 2. Defaults to 1.
      padding: Optional padding algorithm. Either "VALID" or "SAME" or a
        callable or sequence of callables of size `num_spatial_dims`. Any
        callables must take a single integer argument equal to the effective
        kernel size and return a list of two integers representing the
        padding before and after. See haiku.pad.* for more details and
        example functions. Defaults to "SAME". See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
        NOTE(review): this constructor stores the value as given, without
        any conversion - confirm callables are handled by the caller.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. By default, zeros.
      data_format: The data format of the input. Can be either
        `channels_first`, `channels_last`, `N...C` or `NC...`. By default,
        `NHWC`.
      name: The name of the module.
    """
    super(DepthwiseConv2D, self).__init__(name=name)

    # This module is two-dimensional by construction.
    self._num_spatial_dims = 2

    # Plain pass-through configuration.
    self._channel_multiplier = channel_multiplier
    self._padding = padding
    self._with_bias = with_bias
    self._w_init = w_init
    self._data_format = data_format

    # Derived configuration.
    self._kernel_shape = utils.replicate(kernel_shape, 2, "kernel_shape")
    self._stride = utils.replicate(stride, 2, "strides")
    self._b_init = b_init if b_init else jnp.zeros
    self._channel_index = utils.get_channel_index(data_format)

    # Neither the input nor the kernel is dilated.
    no_dilation = (1, ) * len(self._kernel_shape)
    self._lhs_dilation = no_dilation
    self._rhs_dilation = no_dilation
def __init__(
    self,
    create_scale: bool = True,
    create_offset: bool = True,
    eps: float = 1e-5,
    scale_init: Optional[initializers.Initializer] = None,
    offset_init: Optional[initializers.Initializer] = None,
    data_format: str = "channels_last",
    **kwargs
):
    """Constructs an `InstanceNormalization` module.

    The module normalizes over the spatial dimensions: the reduction axes
    are chosen from the channel position implied by ``data_format`` and
    forwarded, with the remaining options, to the parent constructor.

    Args:
        create_scale: ``bool`` representing whether to create a trainable
            scale per channel applied after the normalization.
        create_offset: ``bool`` representing whether to create a trainable
            offset per channel applied after normalization and scaling.
        eps: Small epsilon to avoid division by zero variance. Defaults to
            ``1e-5``.
        scale_init: Optional initializer for the scale variable. Can only be
            set if ``create_scale=True``. By default scale is initialized to
            ``1``.
        offset_init: Optional initializer for the offset variable. Can only
            be set if ``create_offset=True``. By default offset is
            initialized to ``0``.
        data_format: The data format of the input. Can be either
            ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``.
            By default it is ``channels_last``.
        kwargs: Additional keyword arguments passed to Module.
    """
    channels_first = hk_utils.get_channel_index(data_format) == 1
    # Channels at index 1: reduce over everything after them.
    # Otherwise (channel_index = -1): reduce over everything between the
    # leading axis and the trailing channel axis.
    axis = slice(2, None) if channels_first else slice(1, -1)

    super().__init__(
        axis=axis,
        create_scale=create_scale,
        create_offset=create_offset,
        eps=eps,
        scale_init=scale_init,
        offset_init=offset_init,
        **kwargs
    )
def __init__(self,
             create_scale,
             create_offset,
             eps=1e-5,
             scale_init=None,
             offset_init=None,
             data_format="channels_last",
             name=None):
    """Constructs an `InstanceNorm` module.

    The module normalizes over the spatial dimensions: the reduction axes
    are chosen from the channel position implied by `data_format` and
    forwarded, with the remaining options, to the parent constructor.

    Args:
      create_scale: `bool` representing whether to create a trainable scale
        per channel applied after the normalization.
      create_offset: `bool` representing whether to create a trainable
        offset per channel applied after normalization and scaling.
      eps: Small epsilon to avoid division by zero variance. Defaults to
        `1e-5`.
      scale_init: Optional initializer for the scale variable. Can only be
        set if `create_scale=True`. By default scale is initialized to `1`.
      offset_init: Optional initializer for the offset variable. Can only be
        set if `create_offset=True`. By default offset is initialized to `0`.
      data_format: The data format of the input. Can be either
        `channels_first`, `channels_last`, `N...C` or `NC...`. By default it
        is `channels_last`.
      name: Name of the module.
    """
    channels_first = utils.get_channel_index(data_format) == 1
    # Channels at index 1: reduce over everything after them.
    # Otherwise (channel_index = -1): reduce over everything between the
    # leading axis and the trailing channel axis.
    axis = slice(2, None) if channels_first else slice(1, -1)

    super(InstanceNorm, self).__init__(
        axis=axis,
        create_scale=create_scale,
        create_offset=create_offset,
        eps=eps,
        scale_init=scale_init,
        offset_init=offset_init,
        name=name)
def test_invalid_strings(self, data_format):
    """Checks that an unrecognised `data_format` raises a `ValueError`.

    Args:
      data_format: A data-format string that `utils.get_channel_index` is
        expected to reject.
    """
    import re  # Local import so this block stays self-contained.

    # Escape the interpolated value: `assertRaisesRegex` treats its argument
    # as a regular expression, so metacharacters in `data_format` would
    # otherwise be interpreted as a pattern instead of matched literally.
    expected_regex = "Unable to extract channel information from '{}'.".format(
        re.escape(data_format))
    with self.assertRaisesRegex(ValueError, expected_regex):
        utils.get_channel_index(data_format)
def test_returns_index_channels_last(self, data_format):
    """Checks that `data_format` resolves to a channel index of -1."""
    channel_index = utils.get_channel_index(data_format)
    self.assertEqual(channel_index, -1)
def __init__(self,
             num_spatial_dims,
             output_channels,
             kernel_shape,
             stride=1,
             rate=1,
             padding="SAME",
             with_bias=True,
             w_init=None,
             b_init=None,
             data_format="channels_last",
             mask=None,
             name=None):
    """Constructs a `ConvND` module.

    Args:
      num_spatial_dims: The number of spatial dimensions of the input. Must
        be 1, 2 or 3.
      output_channels: Number of output channels.
      kernel_shape: The shape of the kernel. Either an integer or a sequence
        of length `num_spatial_dims`.
      stride: Optional stride for the kernel. Either an integer or a sequence
        of length `num_spatial_dims`. Defaults to 1.
      rate: Optional kernel dilation rate. Either an integer or a sequence of
        length `num_spatial_dims`. 1 corresponds to standard ND convolution,
        `rate > 1` corresponds to dilated convolution. Defaults to 1.
      padding: Optional padding algorithm. Either "VALID" or "SAME" or a
        callable or sequence of callables of size `num_spatial_dims`. Any
        callables must take a single integer argument equal to the effective
        kernel size and return a list of two integers representing the
        padding before and after. See haiku.pad.* for more details and
        example functions. Defaults to "SAME". Strings are upper-cased;
        anything else is converted with `pad.create`. See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization. By default, truncated normal.
      b_init: Optional bias initialization. By default, zeros.
      data_format: The data format of the input. Can be either
        `channels_first`, `channels_last`, `N...C` or `NC...`. By default,
        `channels_last`.
      mask: Optional mask of the weights.
      name: The name of the module.

    Raises:
      ValueError: If `num_spatial_dims` is outside the range 1 to 3.
    """
    super(ConvND, self).__init__(name=name)

    dims_supported = 1 <= num_spatial_dims <= 3
    if not dims_supported:
        raise ValueError(
            "We only support convolution operations for num_spatial_dims=1, 2 or "
            "3, received num_spatial_dims={}.".format(num_spatial_dims))

    # Plain pass-through configuration.
    self._num_spatial_dims = num_spatial_dims
    self._output_channels = output_channels
    self._with_bias = with_bias
    self._w_init = w_init
    self._mask = mask
    self._data_format = data_format

    # Per-dimension settings: scalars are expanded to one entry per spatial
    # dimension.
    self._kernel_shape = utils.replicate(
        kernel_shape, num_spatial_dims, "kernel_shape")
    self._stride = utils.replicate(stride, num_spatial_dims, "strides")
    self._lhs_dilation = utils.replicate(1, num_spatial_dims, "lhs_dilation")
    self._kernel_dilation = utils.replicate(
        rate, num_spatial_dims, "kernel_dilation")

    self._b_init = b_init if b_init else jnp.zeros
    self._channel_index = utils.get_channel_index(data_format)

    # Pick the dimension-number table that matches the channel position.
    dn_table = (DIMENSION_NUMBERS if self._channel_index == -1
                else DIMENSION_NUMBERS_NCSPATIAL)
    self._dn = dn_table[self._num_spatial_dims]

    # Named padding algorithms are normalised to upper case; every other
    # form goes through `pad.create`.
    if not isinstance(padding, str):
        self._padding = pad.create(
            padding=padding,
            kernel=self._kernel_shape,
            rate=self._kernel_dilation,
            n=self._num_spatial_dims)
    else:
        self._padding = padding.upper()
def __init__(
    self,
    num_spatial_dims: int,
    output_channels: int,
    kernel_shape: tp.Union[int, tp.Sequence[int]],
    stride: tp.Union[int, tp.Sequence[int]] = 1,
    rate: tp.Union[int, tp.Sequence[int]] = 1,
    padding: tp.Union[str, tp.Sequence[tp.Tuple[int, int]], types.PadFnOrFns] = "SAME",
    with_bias: bool = True,
    w_init: tp.Optional[types.Initializer] = None,
    b_init: tp.Optional[types.Initializer] = None,
    data_format: str = "channels_last",
    mask: tp.Optional[np.ndarray] = None,
    groups: int = 1,
    **kwargs,
):
    """
    Initializes the module.

    Args:
        num_spatial_dims: The number of spatial dimensions of the input.
        output_channels: Number of output channels.
        kernel_shape: The shape of the kernel. Either an integer or a
            sequence of length ``num_spatial_dims``.
        stride: Optional stride for the kernel. Either an integer or a
            sequence of length ``num_spatial_dims``. Defaults to 1.
        rate: Optional kernel dilation rate. Either an integer or a sequence
            of length ``num_spatial_dims``. 1 corresponds to standard ND
            convolution, ``rate > 1`` corresponds to dilated convolution.
            Defaults to 1.
        padding: Optional padding algorithm. Either ``VALID`` or ``SAME``,
            or a sequence of n ``(low, high)`` integer pairs that give the
            padding to apply before and after each spatial dimension, or a
            callable or sequence of callables of size ``num_spatial_dims``.
            Any callables must take a single integer argument equal to the
            effective kernel size and return a sequence of two integers
            representing the padding before and after. See ``haiku.pad.*``
            for more details and example functions. Defaults to ``SAME``.
            Strings are upper-cased; anything else is converted with
            ``hk.pad.create``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
        with_bias: Whether to add a bias. By default, true.
        w_init: Optional weight initialization. By default, truncated normal.
        b_init: Optional bias initialization. By default, zeros.
        data_format: The data format of the input. Can be either
            ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``.
            By default, ``channels_last``.
        mask: Optional mask of the weights.
        groups: A positive integer specifying the number of groups in which
            the input is split along the channel axis. Each group is
            convolved separately with filters / groups filters. The output
            is the concatenation of all the groups results along the channel
            axis. Input channels and filters must both be divisible by
            groups. The value is stored as given; it is not validated here.
        kwargs: Additional keyword arguments passed to Module.

    Raises:
        ValueError: If ``num_spatial_dims`` is not greater than 0.
    """
    super().__init__(**kwargs)

    if num_spatial_dims <= 0:
        raise ValueError(
            "We only support convolution operations for `num_spatial_dims` "
            f"greater than 0, received num_spatial_dims={num_spatial_dims}."
        )

    # Plain pass-through configuration.
    self.num_spatial_dims = num_spatial_dims
    self.output_channels = output_channels
    self.with_bias = with_bias
    self.w_init = w_init
    self.mask = mask
    self.data_format = data_format
    self.groups = groups

    # Per-dimension settings: scalars are expanded to one entry per spatial
    # dimension.
    self.kernel_shape = hk_utils.replicate(
        kernel_shape, num_spatial_dims, "kernel_shape"
    )
    self.stride = hk_utils.replicate(stride, num_spatial_dims, "strides")
    self.lhs_dilation = hk_utils.replicate(1, num_spatial_dims, "lhs_dilation")
    self.kernel_dilation = hk_utils.replicate(
        rate, num_spatial_dims, "kernel_dilation"
    )

    self.b_init = b_init if b_init else jnp.zeros
    self.channel_index = hk_utils.get_channel_index(data_format)
    # A channel index of -1 means the channel axis is the trailing one.
    self.dimension_numbers = to_dimension_numbers(
        num_spatial_dims,
        channels_last=(self.channel_index == -1),
        transpose=False,
    )

    # Named padding algorithms are normalised to upper case; every other
    # form goes through `hk.pad.create`.
    if not isinstance(padding, str):
        self.padding = hk.pad.create(
            padding=padding,
            kernel=self.kernel_shape,
            rate=self.kernel_dilation,
            n=self.num_spatial_dims,
        )
    else:
        self.padding = padding.upper()
def __init__(
    self,
    output_channels: int,
    uniform_init_minval: float,
    uniform_init_maxval: float,
    kernel_shape: Union[int, Sequence[int]],
    num_spatial_dims: int = 2,
    stride: Union[int, Sequence[int]] = 1,
    rate: Union[int, Sequence[int]] = 1,
    padding: Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn,
                   Sequence[hk.pad.PadFn]] = "SAME",
    with_bias: bool = True,
    w_init: Union[Optional[hk.initializers.Initializer], str] = "uniform",
    b_init: Union[Optional[hk.initializers.Initializer], str] = "uniform",
    data_format: str = "channels_last",
    mask: Optional[jnp.ndarray] = None,
    feature_group_count: int = 1,
    name: Optional[str] = None,
    stochastic_parameters: bool = False,
):
    """Initializes the module.

    Args:
      output_channels: Number of output channels.
      uniform_init_minval: Stored on the module unchanged. Presumably the
        lower bound used by the "uniform" initialization - TODO(nband)
        confirm.
      uniform_init_maxval: Stored on the module unchanged. Presumably the
        upper bound used by the "uniform" initialization - TODO(nband)
        confirm.
      kernel_shape: The shape of the kernel. Either an integer or a sequence
        of length ``num_spatial_dims``.
      num_spatial_dims: The number of spatial dimensions of the input.
        Defaults to 2.
      stride: Optional stride for the kernel. Either an integer or a sequence
        of length ``num_spatial_dims``. Defaults to 1.
      rate: Optional kernel dilation rate. Either an integer or a sequence of
        length ``num_spatial_dims``. 1 corresponds to standard ND
        convolution, ``rate > 1`` corresponds to dilated convolution.
        Defaults to 1.
      padding: Optional padding algorithm. Either ``VALID`` or ``SAME``, or a
        sequence of n ``(low, high)`` integer pairs that give the padding to
        apply before and after each spatial dimension, or a callable or
        sequence of callables of size ``num_spatial_dims``. Any callables
        must take a single integer argument equal to the effective kernel
        size and return a sequence of two integers representing the padding
        before and after. See ``haiku.pad.*`` for more details and example
        functions. Defaults to ``SAME``. Strings are upper-cased; anything
        else is converted with ``hk.pad.create``. See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      with_bias: Whether to add a bias. By default, true.
      w_init: Optional weight initialization, given as an initializer or a
        string. Defaults to the string ``"uniform"``; stored unchanged. How
        the string is interpreted is not visible in this constructor.
      b_init: Optional bias initialization, given as an initializer or a
        string. Defaults to the string ``"uniform"``; stored unchanged. How
        the string is interpreted is not visible in this constructor.
      data_format: The data format of the input. Can be either
        ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
        default, ``channels_last``.
      mask: Optional mask of the weights.
      feature_group_count: Optional number of groups in group convolution.
        Default value of 1 corresponds to normal dense convolution. If a
        higher value is used, convolutions are applied separately to that
        many groups, then stacked together. This reduces the number of
        parameters and possibly the compute for a given
        ``output_channels``. See:
        https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
      name: The name of the module.
      stochastic_parameters: Stored on the module unchanged; its semantics
        are not documented here - TODO(nband).

    Raises:
      ValueError: If ``num_spatial_dims`` is not greater than 0.
    """
    super().__init__(name=name)

    if num_spatial_dims <= 0:
        raise ValueError(
            "We only support convolution operations for `num_spatial_dims` "
            f"greater than 0, received num_spatial_dims={num_spatial_dims}."
        )

    # Plain pass-through configuration.
    self.num_spatial_dims = num_spatial_dims
    self.output_channels = output_channels
    self.with_bias = with_bias
    self.w_init = w_init
    self.b_init = b_init
    self.uniform_init_minval = uniform_init_minval
    self.uniform_init_maxval = uniform_init_maxval
    self.mask = mask
    self.feature_group_count = feature_group_count
    self.data_format = data_format
    self.stochastic_parameters = stochastic_parameters

    # Per-dimension settings: scalars are expanded to one entry per spatial
    # dimension.
    self.kernel_shape = utils.replicate(
        kernel_shape, num_spatial_dims, "kernel_shape")
    self.stride = utils.replicate(stride, num_spatial_dims, "strides")
    self.lhs_dilation = utils.replicate(1, num_spatial_dims, "lhs_dilation")
    self.kernel_dilation = utils.replicate(
        rate, num_spatial_dims, "kernel_dilation")

    self.channel_index = utils.get_channel_index(data_format)
    # A channel index of -1 means the channel axis is the trailing one.
    self.dimension_numbers = to_dimension_numbers(
        num_spatial_dims,
        channels_last=(self.channel_index == -1),
        transpose=False)

    # Named padding algorithms are normalised to upper case; every other
    # form goes through `hk.pad.create`.
    if not isinstance(padding, str):
        self.padding = hk.pad.create(
            padding=padding,
            kernel=self.kernel_shape,
            rate=self.kernel_dilation,
            n=self.num_spatial_dims,
        )
    else:
        self.padding = padding.upper()
def __init__(
    self,
    groups: int,
    axis: Union[int, slice, Sequence[int]] = slice(1, None),
    create_scale: bool = True,
    create_offset: bool = True,
    eps: float = 1e-5,
    scale_init: Optional[hk.initializers.Initializer] = None,
    offset_init: Optional[hk.initializers.Initializer] = None,
    data_format: str = "channels_last",
    name: Optional[str] = None,
):
    """Constructs a ``GroupNorm`` module.

    Args:
      groups: number of groups to divide the channels by. The number of
        channels must be divisible by this.
      axis: ``int``, ``slice`` or sequence of ints representing the axes
        which should be normalized across. By default this is all but the
        first dimension. For time series data use ``slice(2, None)`` to
        average over the non-batch, non-time dimensions. An ``int`` is
        stored as a one-element tuple and any other iterable of ints is
        stored as a tuple.
      create_scale: whether to create a trainable scale per channel applied
        after the normalization.
      create_offset: whether to create a trainable offset per channel
        applied after normalization and scaling.
      eps: Small epsilon to add to the variance to avoid division by zero.
        Defaults to ``1e-5``.
      scale_init: Optional initializer for the scale parameter. Can only be
        set if ``create_scale=True``. By default scale is initialized to
        ``1``.
      offset_init: Optional initializer for the offset parameter. Can only
        be set if ``create_offset=True``. By default offset is initialized
        to ``0``.
      data_format: The data format of the input. Can be either
        ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
        default it is ``channels_last``.
      name: Name of the module.

    Raises:
      ValueError: If ``axis`` is not an int, slice or iterable of ints, or
        if an initializer is supplied for a scale or offset that is not
        being created.
    """
    super().__init__(name=name)

    if isinstance(axis, slice):
        self.axis = axis
    elif isinstance(axis, int):
        self.axis = (axis, )
    elif isinstance(axis, collections.abc.Iterable):
        # Materialise before validating: a one-shot iterable (for example a
        # generator) would otherwise be exhausted by the `all(...)` check
        # and then stored empty.
        axis = tuple(axis)
        if not all(isinstance(ax, int) for ax in axis):
            raise ValueError(
                "`axis` should be an int, slice or iterable of ints.")
        self.axis = axis
    else:
        raise ValueError(
            "`axis` should be an int, slice or iterable of ints.")

    self.groups = groups
    self.eps = eps
    self.data_format = data_format
    self.channel_index = utils.get_channel_index(data_format)
    self.create_scale = create_scale
    self.create_offset = create_offset
    self.rank = None

    # `scale_init` is only stored when a scale is actually created.
    if self.create_scale:
        if scale_init is None:
            scale_init = jnp.ones
        self.scale_init = scale_init
    elif scale_init is not None:
        raise ValueError(
            "Cannot set `scale_init` if `create_scale=False`.")

    # `offset_init` is only stored when an offset is actually created.
    if self.create_offset:
        if offset_init is None:
            offset_init = jnp.zeros
        self.offset_init = offset_init
    elif offset_init is not None:
        raise ValueError(
            "Cannot set `offset_init` if `create_offset=False`.")