Example 1
    def update_site(self, inputs: Array, index: int) -> Array:
        """
        Adds an input site into the cache, and applies the masked convolution to the cache.

        Args:
          inputs: an input site to be added into the cache with dimensions (batch, features).
          index: the index of the output site. The index of the input site should be `index - self.exclusive`.

        Returns:
          The next output site with dimensions (batch, features).
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        kernel_size = self.kernel_size - self.exclusive
        dilation = self.kernel_dilation

        is_single_input = False
        if inputs.ndim == 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        batch, in_features = inputs.shape
        assert in_features % self.feature_group_count == 0
        cache_size = kernel_size * dilation - (not self.exclusive) * (
            dilation - 1)

        # Initialize the cache with zeros, and the RNG key is None
        # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
        _cache = self.variable(
            "cache",
            "inputs",
            zeros,
            None,
            (batch, cache_size, in_features),
            inputs.dtype,
        )

        initializing = self.is_mutable_collection("params")
        if not initializing:
            # Add the input site into the cache
            # To write the cache, use `_cache.value` as the left value of the assignment
            _cache.value = lax.cond(
                index - self.exclusive >= 0,
                lambda _: jnp.concatenate(
                    [_cache.value[:, 1:, :],
                     jnp.expand_dims(inputs, axis=1)],
                    axis=1),
                lambda _: _cache.value,
                None,
            )

        cache = _cache.value
        cache = jnp.asarray(cache, dtype)

        kernel_shape = (
            kernel_size,
            in_features // self.feature_group_count,
            self.features,
        )
        kernel = self.param("kernel", self.kernel_init, kernel_shape,
                            self.dtype)
        kernel = jnp.asarray(kernel, dtype)

        if self.exclusive and dilation > 1:
            cache = cache[:, :-(dilation - 1), :]

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(
            cache.shape)
        y_i = lax.conv_general_dilated(
            cache,
            kernel,
            window_strides=(1, ),
            padding="VALID",
            lhs_dilation=(1, ),
            rhs_dilation=(dilation, ),
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            y_i = y_i + bias

        y_i = y_i.squeeze(axis=1)

        if is_single_input:
            y_i = y_i.squeeze(axis=0)

        return y_i
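For reference, here is a minimal, self-contained sketch (illustrative shapes and names, not the module's API above) of the rolling-cache idea this method uses: each call shifts the cache by one site, appends the newest input, and a VALID convolution over the cached window produces the single new output site.

import jax.numpy as jnp
from jax import lax

batch, cache_size, features = 2, 4, 3
cache = jnp.zeros((batch, cache_size, features))
new_site = jnp.ones((batch, features))

# Shift the cache one step along the length axis and append the new site.
cache = jnp.concatenate([cache[:, 1:, :], new_site[:, None, :]], axis=1)

# A kernel spanning the whole cache yields exactly one output site.
kernel = jnp.ones((cache_size, features, features))  # (width, in_features, out_features)
y = lax.conv_general_dilated(
    cache, kernel,
    window_strides=(1,), padding="VALID",
    dimension_numbers=("NWC", "WIO", "NWC"),
)
print(y.shape)  # (2, 1, 3): one output site per batch element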
Example 2
 _make_harness("clamp",
               "",
               lax.clamp, [
                   RandArg((3, 4, 5), _f32),
                   RandArg((3, 4, 5), _f32),
                   RandArg((3, 4, 5), _f32)
               ],
               poly_axes=[0, 0, 0]),
 _make_harness("conv_general_dilated",
               "",
               lambda lhs, rhs: lax.conv_general_dilated(
                   lhs,
                   rhs,
                   window_strides=(2, 3),
                   padding=((0, 0), (0, 0)),
                   lhs_dilation=(1, 1),
                   rhs_dilation=(1, 2),
                   dimension_numbers=("NCHW", "OIHW", "NCHW"),
                   feature_group_count=1,
                   batch_group_count=1,
                   precision=None),
               [RandArg((7, 3, 9, 10), _f32),
                RandArg((3, 3, 4, 5), _f32)],
               poly_axes=[0, None]),
 _make_harness("cummax",
               "",
               lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
               [RandArg((3, 4, 5), _f32)],
               poly_axes=[0]),
 _make_harness(
     "dot_general",
Example 3
def _extract_image_patches(
    lhs: np.ndarray,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: str,
    lhs_dilation: Sequence[int] = None,
    rhs_dilation: Sequence[int] = None,
    dimension_numbers: lax.ConvDimensionNumbers = None,
    precision: lax.Precision = None,
) -> np.ndarray:
    """Extract patches subject to the receptive field of a general convolution.

  Runs the input through a convolution that packs input spatial and channel
  entries into output channel `"C"` entries. The order of dimensions packed is
  `"C" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
  `rhs_spec == dimension_numbers[1]`.

  Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
      a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
      of length `n+2`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    An array containing the image patches flattened inside the `"C"` output
    dimension. The size of this dimension is `C_input * onp.prod(filter_shape)`.

  In the string case of `dimension_numbers`, each character identifies by
  position:

  - the batch dimensions in `lhs`, `rhs`, and the output with the character
    'N',
  - the feature dimensions in `lhs` and the output with the character 'C',
  - the input and output feature dimensions in rhs with the characters 'I'
    and 'O' respectively, and
  - spatial dimension correspondences between lhs, rhs, and the output using
    any distinct characters.

  For example, to indicate dimension numbers consistent with the `conv` function
  with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
  another example, to indicate dimension numbers consistent with the TensorFlow
  Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
  latter form of convolution dimension specification, window strides are
  associated with spatial dimension character labels according to the order in
  which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
  is matched with the dimension corresponding to the first character
  appearing in rhs_spec that is not `'I'` or `'O'`.

  If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
  (for a 2D convolution).
  """
    lhs_spec, rhs_spec, out_spec = dimension_numbers

    filter_shape = tuple(filter_shape)
    spatial_size = onp.prod(filter_shape)
    n_channels = lhs.shape[lhs_spec.index('C')]

    # Move separate `lhs` spatial locations into separate `rhs` channels.
    rhs = np.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)

    rhs = rhs.reshape((spatial_size, 1) + filter_shape)
    rhs = np.tile(rhs, (n_channels, ) + (1, ) * (rhs.ndim - 1))
    rhs = np.moveaxis(rhs, (0, 1), (rhs_spec.index('O'), rhs_spec.index('I')))

    out = lax.conv_general_dilated(lhs=lhs,
                                   rhs=rhs,
                                   window_strides=window_strides,
                                   padding=padding,
                                   lhs_dilation=lhs_dilation,
                                   rhs_dilation=rhs_dilation,
                                   dimension_numbers=dimension_numbers,
                                   precision=precision,
                                   feature_group_count=n_channels)
    return out
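A minimal, self-contained sketch (toy shapes assumed, not the library's API) of the identity-kernel construction described in the docstring: an eye() kernel combined with feature_group_count=C copies every input value inside the receptive field into its own output channel.

import jax.numpy as jnp
from jax import lax

N, H, W, C = 1, 5, 5, 2
fh, fw = 3, 3
lhs = jnp.arange(N * H * W * C, dtype=jnp.float32).reshape(N, H, W, C)

spatial = fh * fw
rhs = jnp.eye(spatial, dtype=lhs.dtype).reshape(spatial, 1, fh, fw)  # one hot tap per output channel
rhs = jnp.tile(rhs, (C, 1, 1, 1))                                    # repeat the block per input channel
rhs = jnp.moveaxis(rhs, (0, 1), (3, 2))                              # to HWIO layout

patches = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
    feature_group_count=C,
)
print(patches.shape)  # (1, 3, 3, C * fh * fw) == (1, 3, 3, 18)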
Example 4
def convNd(
    input,
    filter,
    strides=1,
    padding="VALID",
    input_format=None,
    filter_format=None,
    output_format=None,
    input_dilation=None,
    filter_dilation=None,
):
    """General n-dimensional convolution operator, with input/filter dilation.

    Wraps JAX's ``conv_general_dilated`` function, and thus also XLA's `Conv
    <https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
    operator.

    Args:
        input (Tensor): a rank `n+2` dimensional input array.
        filter (Tensor): a rank `n+2` dimensional array of kernel weights.
        strides (int, sequence of int, optional): a (sequence) of `n` integers,
            representing the inter-window strides. If a scalar is given, it is
            used `n` times. Defaults to `1`.
        padding (sequence of pairs, `'SAME'` or `'VALID'`, optional): a sequence of
            `n` `(low, high)` integer pairs that give the padding to apply
            before and after each spatial dimension. `'VALID'` applies no
            padding; `'SAME'` pads so that the output has the same spatial
            shape as the input when the stride is `1`. Defaults to `'VALID'`.
        input_format (`None` or str, optional): a string of the same length as the
            number of dimensions in `input` that specifies their roles
            (see below). Defaults to `'NCW'` for 1d conv, `'NCHW'` for 2d conv,
            and `'NCDHW'` for 3d conv.
        input_dilation (`None`, int or sequence of int, optional): the
            dilation factor to apply in each spatial dimension of `input`.
            Input dilation is also known as transposed convolution, as it
            increases the output spatial dimensions by inserting any number of
            `0`s between each spatial value of the input.
        filter_dilation (`None`, int or sequence of int): the dilation
            factor to apply in each spatial dimension of `filter`. Filter
            dilation is also known as atrous convolution, as it corresponds to
            inserting any number of `0`s between the filter values, similar to
            performing the non-dilated convolution with a subsampled version
            of the input across the spatial dimensions.

    Returns:
        Tensor: An array containing the convolution result.

    Format of `input`, `filter` and `output`:
    For example, to indicate dimension numbers consistent with the `conv` function
    with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
    another example, to indicate dimension numbers consistent with the TensorFlow
    Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
    latter form of convolution dimension specification, window strides are
    associated with spatial dimension character labels according to the order in
    which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
    is matched with the dimension corresponding to the first character
    appearing in rhs_spec that is not `'I'` or `'O'`.
    :param filter_format: `None` or a string of the same length as ``filter.ndim``
        specifying the roles of the filter dimensions (e.g. `'OIHW'`). Defaults
        to `'OIW'`, `'OIHW'` or `'OIDHW'` for 1d, 2d and 3d convolutions.
    :param output_format: `None` or a string specifying the roles of the output
        dimensions (e.g. `'NCHW'`). Defaults to `'NCW'`, `'NCHW'` or `'NCDHW'`.
    """
    # setting up the strides
    if numpy.isscalar(strides):
        strides = (strides, ) * (input.ndim - 2)
    elif len(strides) != (input.ndim - 2):
        msg = "given strides: {} should match the number ".format(
            strides) + "of spatial dim. in input: {}".format(input.ndim - 2)
        raise ValueError(msg)

    # setting up the padding
    if not isinstance(padding, str):
        if len(padding) != (input.ndim - 2):
            msg = "given padding: {} should match the ".format(
                padding) + "number of spatial dim. in input: {}".format(
                    input.ndim - 2)
            raise ValueError(msg)

    # setting up the filter_format
    if filter_format is None:
        if filter.ndim == 3:
            filter_format = "OIW"
        elif filter.ndim == 4:
            filter_format = "OIHW"
        elif filter.ndim == 5:
            filter_format = "OIDHW"
        else:
            msg = "filter_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(filter_format) != filter.ndim:
        msg = "given filter_format: {} should ".format(
            filter_format
        ) + "match the number of dimensions in filter: {}".format(filter.ndim)
        raise ValueError(msg)

    # setting up the input format
    if input_format is None:
        if len(filter.shape) == 3:
            input_format = "NCW"
        elif len(filter.shape) == 4:
            input_format = "NCHW"
        elif len(filter.shape) == 5:
            input_format = "NCDHW"
        else:
            msg = "input_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(input_format) != input.ndim:
        msg = "given input_format: {} should ".format(
            input_format
        ) + "match the number of dimensions in input: {}".format(input.ndim)
        raise ValueError(msg)

    # setting up the output format
    if output_format is None:
        if len(filter.shape) == 3:
            output_format = "NCW"
        elif len(filter.shape) == 4:
            output_format = "NCHW"
        elif len(filter.shape) == 5:
            output_format = "NCDHW"
        else:
            msg = "output_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(output_format) != input.ndim:
        msg = "given output_format: {} should ".format(
            output_format
        ) + "match the number of dimensions in output: {}".format(input.ndim)
        raise ValueError(msg)

    # setting up dilations (expand scalars to one factor per spatial dimension)
    if numpy.isscalar(input_dilation):
        input_dilation = (input_dilation, ) * (input.ndim - 2)
    if numpy.isscalar(filter_dilation):
        filter_dilation = (filter_dilation, ) * (input.ndim - 2)

    specs = (input_format, filter_format, output_format)
    return jla.conv_general_dilated(
        lhs=input,
        rhs=filter,
        window_strides=strides,
        padding=padding,
        lhs_dilation=input_dilation,
        rhs_dilation=filter_dilation,
        dimension_numbers=specs,
        precision=None,
    )
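For the 2D case, the wrapper above reduces to roughly the following direct lax call (illustrative shapes; the NCHW/OIHW/NCHW dimension numbers are the defaults the code constructs).

import jax.numpy as jnp
from jax import lax

x = jnp.ones((8, 3, 32, 32))   # NCHW input
w = jnp.ones((16, 3, 5, 5))    # OIHW filter: 16 out channels, 3 in channels, 5x5 window

y = lax.conv_general_dilated(
    lhs=x,
    rhs=w,
    window_strides=(1, 1),
    padding="VALID",
    lhs_dilation=None,
    rhs_dilation=None,
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
)
print(y.shape)  # (8, 16, 28, 28)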
Example 5
    def __call__(
        self,
        inputs: jnp.ndarray,
        *,
        precision: Optional[lax.Precision] = None,
    ) -> jnp.ndarray:
        """Connects ``ConvND`` layer.

    Args:
      inputs: An array of shape ``[spatial_dims, C]`` and rank-N+1 if unbatched,
        or an array of shape ``[N, spatial_dims, C]`` and rank-N+2 if batched.
      precision: Optional :class:`jax.lax.Precision` to pass to
        :func:`jax.lax.conv_general_dilated`.

    Returns:
      An array of shape ``[spatial_dims, output_channels]`` and rank-N+1 if
        unbatched, or an array of shape ``[N, spatial_dims, output_channels]``
        and rank-N+2 if batched.
    """
        unbatched_rank = self.num_spatial_dims + 1
        allowed_ranks = [unbatched_rank, unbatched_rank + 1]
        if inputs.ndim not in allowed_ranks:
            raise ValueError(
                f"Input to ConvND needs to have rank in {allowed_ranks},"
                f" but input has shape {inputs.shape}.")

        unbatched = inputs.ndim == unbatched_rank
        if unbatched:
            inputs = jnp.expand_dims(inputs, axis=0)

        if inputs.shape[self.channel_index] % self.feature_group_count != 0:
            raise ValueError(
                f"Inputs channels {inputs.shape[self.channel_index]} "
                f"should be a multiple of feature_group_count "
                f"{self.feature_group_count}")
        w_shape = self.kernel_shape + (inputs.shape[self.channel_index] //
                                       self.feature_group_count,
                                       self.output_channels)

        if self.mask is not None and self.mask.shape != w_shape:
            raise ValueError("Mask needs to have the same shape as weights. "
                             f"Shapes are: {self.mask.shape}, {w_shape}")

        w_init = self.w_init
        if w_init is None:
            fan_in_shape = np.prod(w_shape[:-1])
            stddev = 1. / np.sqrt(fan_in_shape)
            w_init = hk.initializers.TruncatedNormal(stddev=stddev)
        w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)

        if self.mask is not None:
            w *= self.mask

        out = lax.conv_general_dilated(
            inputs,
            w,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=self.dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=precision)

        if self.with_bias:
            if self.channel_index == -1:
                bias_shape = (self.output_channels, )
            else:
                bias_shape = (
                    self.output_channels, ) + (1, ) * self.num_spatial_dims
            b = hk.get_parameter("b",
                                 bias_shape,
                                 inputs.dtype,
                                 init=self.b_init)
            b = jnp.broadcast_to(b, out.shape)
            out = out + b

        if unbatched:
            out = jnp.squeeze(out, axis=0)
        return out
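A minimal usage sketch of a Haiku convolution module of this kind, assuming hk.Conv2D (which follows the same __call__ contract shown above); shapes are illustrative.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    conv = hk.Conv2D(output_channels=8, kernel_shape=3, stride=1, padding="SAME")
    return conv(x)

net = hk.transform(forward)
x = jnp.ones((4, 32, 32, 3))                  # NHWC batch
params = net.init(jax.random.PRNGKey(0), x)
y = net.apply(params, None, x)                # no rng needed at apply time
print(y.shape)                                # (4, 32, 32, 8)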
Example 6
 def f(params, x):
   one = (1, 1)
   dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
   y = lax.conv_general_dilated(
       x, params, one, 'SAME', one, one, dimension_numbers)
   return y
Example 7
"""


# convert to jax array
X_image_jax = jnp.array(X_images_scaled, dtype=jnp.float32)

# define the kernel
kernel = jnp.ones(shape=(3, 3), dtype=jnp.float32)

# better orthogonal kernel

X_image_transform = conv_general_dilated(
    lhs=X_image_jax,  # input
    rhs=kernel[..., None, None],  # kernel
    window_strides=(1, 1),
    padding="SAME",
    lhs_dilation=(1, 1),
    rhs_dilation=(1, 1),
    dimension_numbers=("NHWC", "IOHW", "NHWC"),
)

fig, ax = plt.subplots()
plt.imshow(X_image_transform[0])
ax.set_yticks([])
ax.set_xticks([])
plt.tight_layout()
plt.show()

#%%
"""Invertible???"""
Example 8
 def apply_fun(params, inputs, rng=None):
     W, b = params
     return lax.conv_general_dilated(inputs, W, strides, padding, one, one,
                                     dimension_numbers) + b
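This apply_fun has the shape of a stax-style convolution layer; a usage sketch with jax.example_libraries.stax.Conv (assuming that is the intended layer) looks like:

import jax
import jax.numpy as jnp
from jax.example_libraries import stax

init_fun, apply_fun = stax.Conv(16, (3, 3), padding="SAME")  # NHWC / HWIO / NHWC by default
rng = jax.random.PRNGKey(0)
out_shape, params = init_fun(rng, (4, 28, 28, 1))
y = apply_fun(params, jnp.ones((4, 28, 28, 1)))
print(out_shape, y.shape)  # (4, 28, 28, 16) (4, 28, 28, 16)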
Example 9
    def apply(
        self,
        inputs,
        features,
        kernel_size,
        is_first_layer=False,
        strides=None,
        padding="SAME",
        input_dilation=None,
        kernel_dilation=None,
        feature_group_count=1,
        bias=True,
        dtype=jnp.float32,
        precision=None,
        kernel_init=default_kernel_init,
        bias_init=initializers.zeros,
    ):
        """Applies a convolution to the inputs.

    """

        inputs = jnp.asarray(inputs, dtype)

        assert len(kernel_size) == 1, "kernel_shape must be one dimensional"
        assert kernel_size[0] % 2 != 0, "kernel_shape must be odd"

        # Build a causal mask over the kernel width: the first layer zeroes the
        # center tap and everything to its right (strictly earlier sites only),
        # while later layers keep the center tap.
        mask = onp.ones(kernel_size[0])
        if is_first_layer:
            i = (kernel_size[0] - 1) // 2
        else:
            i = (kernel_size[0] + 1) // 2
        mask[i:] = 0
        mask = jnp.asarray(mask[:, onp.newaxis, onp.newaxis], dtype)

        if strides is None:
            strides = (1, ) * (inputs.ndim - 2)

        in_features = inputs.shape[-1]
        assert in_features % feature_group_count == 0
        kernel_shape = kernel_size + (in_features // feature_group_count,
                                      features)
        kernel = self.param("kernel", kernel_shape, kernel_init)
        kernel = jnp.asarray(kernel, dtype)
        kernel = kernel * mask

        dimension_numbers = _conv_dimension_numbers(inputs.shape)
        y = lax.conv_general_dilated(
            inputs,
            kernel,
            strides,
            padding,
            lhs_dilation=input_dilation,
            rhs_dilation=kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=feature_group_count,
            precision=precision,
        )

        if bias:
            bias = self.param("bias", (features, ), bias_init)
            bias = jnp.asarray(bias, dtype)
            y = y + bias
        return y
Example 10
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked convolution to the inputs.
        For 1D convolution, there is not really a mask. We only need to apply
        appropriate padding.

        Args:
          inputs: input data with dimensions (batch, length, features).

        Returns:
          The convolved data.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        kernel_size = self.kernel_size - self.exclusive
        dilation = self.kernel_dilation

        is_single_input = False
        if inputs.ndim == 2:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
        kernel_shape = (
            kernel_size,
            in_features // self.feature_group_count,
            self.features,
        )

        kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
        kernel = jnp.asarray(kernel, dtype)

        if self.exclusive:
            inputs = inputs[:, :-dilation, :]

        # Zero padding
        y = jnp.pad(
            inputs,
            (
                (0, 0),
                ((kernel_size - (not self.exclusive)) * dilation, 0),
                (0, 0),
            ),
        )

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
        y = lax.conv_general_dilated(
            y,
            kernel,
            window_strides=(1,),
            padding="VALID",
            lhs_dilation=(1,),
            rhs_dilation=(dilation,),
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias

        return y
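A quick numerical check (illustrative shapes, non-exclusive case, dilation 1) of the claim in the docstring that left padding plus a VALID convolution is all the "mask" a 1D causal convolution needs: perturbing site t only changes outputs at sites >= t.

import jax.numpy as jnp
from jax import lax

batch, length, feat, k = 1, 8, 1, 3
kernel = jnp.ones((k, feat, feat))

def causal_conv(x):
    xp = jnp.pad(x, ((0, 0), (k - 1, 0), (0, 0)))   # pad on the left only
    return lax.conv_general_dilated(
        xp, kernel, window_strides=(1,), padding="VALID",
        dimension_numbers=("NWC", "WIO", "NWC"))

x = jnp.zeros((batch, length, feat))
y0 = causal_conv(x)
y1 = causal_conv(x.at[:, 4, :].set(1.0))            # perturb site 4
changed = jnp.nonzero(jnp.any(y0 != y1, axis=(0, 2)))[0]
print(changed)                                      # [4 5 6]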
Example 11
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked convolution to the inputs.

        Args:
          inputs: input data with dimensions (batch, width, height, features).

        Returns:
          The convolved data.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        kernel_h, kernel_w = self.kernel_size
        dilation_h, dilation_w = self.kernel_dilation
        ones = (1, 1)

        is_single_input = False
        if inputs.ndim == 3:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
        kernel_shape = self.kernel_size + (
            in_features // self.feature_group_count,
            self.features,
        )

        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, self.mask),
            kernel_shape,
            self.dtype,
        )
        mask = jnp.asarray(self.mask, dtype)
        kernel = jnp.asarray(kernel, dtype)

        # Zero padding
        y = jnp.pad(
            inputs,
            (
                (0, 0),
                ((kernel_h - 1) * dilation_h, 0),
                (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
                (0, 0),
            ),
        )

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
        y = lax.conv_general_dilated(
            y,
            mask * kernel,
            window_strides=ones,
            padding="VALID",
            lhs_dilation=ones,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias

        return y
Example 12
def conv2d(x, filters, strides, padding, data_format='NHWC', dilations=1):
    strides = [strides]*2 if isinstance(strides, int) else strides
    dilations = [dilations]*2 if isinstance(dilations, int) else dilations
    return _jlax.conv_general_dilated(x, filters, strides, padding,
                                      None, dilations,
                                      (data_format, 'HWIO', data_format))
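An illustrative call of the wrapper above (assuming it is in scope), with an NHWC input and an HWIO filter as it expects:

import jax.numpy as jnp

x = jnp.ones((2, 28, 28, 3))   # NHWC input
f = jnp.ones((3, 3, 3, 8))     # HWIO filter: 3x3 window, 3 in channels, 8 out channels
y = conv2d(x, f, strides=1, padding='SAME')
print(y.shape)                 # (2, 28, 28, 8)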
Example 13
    def apply(self,
              inputs,
              filters,
              kernel_size,
              block_size,
              strides=None,
              padding='SAME',
              input_dilation=None,
              kernel_dilation=None,
              feature_group_count=1,
              bias=True,
              dtype=jnp.float32,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros):
        """Applies a convolution to the inputs.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).
      filters: number of convolution filters.
      kernel_size: shape of the convolutional kernel.
      block_size: shape of space-to-depth blocks.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
      input_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`.
        Convolution with input dilation `d` is equivalent to transposed
        convolution with stride `d`.
      kernel_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel. Convolution with kernel dilation is also known as 'atrous
        convolution'.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups.
      bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.
    Returns:
      The convolved data.
    """
        inputs = jnp.asarray(inputs, dtype)

        if strides is None:
            strides = block_size
        assert strides[0] % block_size[0] == 0
        assert strides[1] % block_size[1] == 0
        strides = tuple(s // b for s, b in zip(strides, block_size))

        # create kernel as if there were no space to depth
        batch_size, h, w, features = inputs.shape
        original_input_shape = (batch_size, h * block_size[0],
                                w * block_size[1],
                                features // block_size[0] // block_size[1])
        in_features = original_input_shape[-1]
        assert in_features % feature_group_count == 0
        kernel_shape = kernel_size + (in_features // feature_group_count,
                                      filters)
        kernel = self.param('kernel', kernel_shape, kernel_init)
        kernel = jnp.asarray(kernel, dtype)

        # zero-pad kernel to multiple of block size (e.g. 7x7 --> 8x8)
        h_blocks, h_ragged = divmod(kernel_size[0], block_size[0])
        if h_ragged != 0:
            h_blocks = h_blocks + 1
            kernel = jnp.pad(kernel,
                             pad_width=[[block_size[0] - h_ragged, 0], [0, 0],
                                        [0, 0], [0, 0]],
                             mode='constant',
                             constant_values=0.)
        w_blocks, w_ragged = divmod(kernel_size[1], block_size[1])
        if w_ragged != 0:
            w_blocks = w_blocks + 1
            kernel = jnp.pad(kernel,
                             pad_width=[[0, 0], [block_size[1] - w_ragged, 0],
                                        [0, 0], [0, 0]],
                             mode='constant',
                             constant_values=0.)

        # transform kernel following space-to-depth logic: http://shortn/_9YvHW96xPJ
        kernel = jnp.reshape(kernel, [
            h_blocks, block_size[0], w_blocks, block_size[1],
            in_features // feature_group_count, filters
        ])
        kernel = jnp.transpose(kernel, [0, 2, 1, 3, 4, 5])
        kernel = jnp.reshape(kernel, [h_blocks, w_blocks, features, filters])
        kernel = kernel.astype(inputs.dtype)

        dimension_numbers = nn.linear._conv_dimension_numbers(inputs.shape)  # pylint: disable=protected-access

        y = lax.conv_general_dilated(lhs=inputs,
                                     rhs=kernel,
                                     window_strides=strides,
                                     padding=padding,
                                     lhs_dilation=input_dilation,
                                     rhs_dilation=kernel_dilation,
                                     dimension_numbers=dimension_numbers,
                                     feature_group_count=feature_group_count,
                                     precision=precision)
        if bias:
            # The bias must match the number of output channels (`filters`),
            # not the space-to-depth input channel count.
            bias = self.param('bias', (filters, ), bias_init)
            bias = jnp.asarray(bias, dtype)
            y = y + bias
        return y
Example 14
 def call(self, x, params=(), **kwargs):
   del kwargs
   w, b = params
   return lax.conv_general_dilated(
       x, w, self._strides, self._padding, self._one, self._one,
       self._dimension_numbers) + b
Example 15
    def update_site(self, inputs: Array, index: int) -> Array:
        """
        Adds an input site into the cache, and applies the masked convolution to the cache.

        Args:
          inputs: an input site to be added into the cache with dimensions (batch, features).
          index: the index of the output site. The index of the input site should be `index - self.exclusive`.

        Returns:
          The next output site with dimensions (batch, features).
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        L = self.L
        index_w = index % L

        kernel_h, kernel_w = self.kernel_size
        dilation_h, dilation_w = self.kernel_dilation
        ones = (1, 1)

        is_single_input = False
        if inputs.ndim == 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        batch, in_features = inputs.shape
        assert in_features % self.feature_group_count == 0
        recep_h = (kernel_h - 1) * dilation_h + 1
        recep_w = (kernel_w - 1) * dilation_w + 1

        # Initialize the cache with zeros, and the RNG key is None
        # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
        _cache = self.variable(
            "cache",
            "inputs",
            zeros,
            None,
            (batch, recep_h, L, in_features),
            inputs.dtype,
        )

        initializing = self.is_mutable_collection("params")
        if not initializing:
            # Add the input site into the cache
            # To write the cache, use `_cache.value` as the left value of the assignment

            inputs = jnp.expand_dims(inputs, axis=(1, 2))

            # Index of the input site in the width direction
            index_w_in = (index - self.exclusive) % L

            def _add(cache):
                # Equivalent to cache.at[:, -1, index_w_in, :].set(inputs).
                # `lax.dynamic_update_slice` clamps negative start indices to 0,
                # so the last row has to be addressed as `recep_h - 1`.
                return lax.dynamic_update_slice(cache, inputs,
                                                (0, recep_h - 1, index_w_in, 0))

            def _shift(cache):
                return jnp.concatenate(
                    [
                        cache[:, 1:, :, :],
                        jnp.zeros(
                            (batch, 1, L, in_features), dtype=inputs.dtype),
                    ],
                    axis=1,
                )

            cache_new_row = lax.cond(
                index_w_in == 0,
                lambda _: _add(_shift(_cache.value)),
                lambda _: _shift(_add(_cache.value)),
                None,
            )

            cache_new = lax.cond(
                index_w == 0,
                lambda _: cache_new_row,
                lambda _: _add(_cache.value),
                None,
            )

            _cache.value = lax.cond(
                index - self.exclusive >= 0,
                lambda _: cache_new,
                lambda _: _cache.value,
                None,
            )

        cache = _cache.value
        cache = jnp.asarray(cache, dtype)

        kernel_shape = self.kernel_size + (
            in_features // self.feature_group_count,
            self.features,
        )
        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, self.mask),
            kernel_shape,
            self.dtype,
        )
        kernel = jnp.asarray(kernel, dtype)

        # Zero padding
        cache = jnp.pad(
            cache,
            (
                (0, 0),
                (0, 0),
                (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
                (0, 0),
            ),
        )

        # cache = cache[:, :, index_w : index_w + recep_w, :]
        cache = lax.dynamic_slice(cache, (0, 0, index_w, 0),
                                  (batch, recep_h, recep_w, in_features))

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(
            cache.shape)
        y_i = lax.conv_general_dilated(
            cache,
            kernel,
            window_strides=ones,
            padding="VALID",
            lhs_dilation=ones,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            y_i = y_i + bias

        y_i = y_i.squeeze(axis=(1, 2))

        if is_single_input:
            y_i = y_i.squeeze(axis=0)

        return y_i
Example 16
def energy(state_mat, jvalue):
  # Calculate energy
  logits = lax.conv_general_dilated(state_mat, jvalue*kernel, 
                                    (1,1), 'SAME', (1,1), (1,1), dn)  
  return logits
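The snippet relies on a kernel and dimension numbers dn defined elsewhere; one common choice for a 2D Ising-style lattice is sketched below (an assumption for illustration, not necessarily the author's definitions).

import jax
import jax.numpy as jnp
from jax import lax

# 4-nearest-neighbour coupling kernel in OIHW layout; lattice stored as NCHW.
kernel = jnp.array([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]], dtype=jnp.float32)[None, None]
dn = lax.conv_dimension_numbers((1, 1, 8, 8), kernel.shape, ('NCHW', 'OIHW', 'NCHW'))

spins = jnp.where(jax.random.bernoulli(jax.random.PRNGKey(0), shape=(1, 1, 8, 8)), 1., -1.)
jvalue = 1.0
neighbour_sum = lax.conv_general_dilated(spins, jvalue * kernel,
                                         (1, 1), 'SAME', (1, 1), (1, 1), dn)
energy = -0.5 * jnp.sum(spins * neighbour_sum)  # each bond counted twice
print(energy)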
Example 17
    def apply(self,
              inputs,
              features,
              kernel_size,
              strides=None,
              padding='SAME',
              input_dilation=None,
              kernel_dilation=None,
              feature_group_count=1,
              bias=True,
              dtype=jnp.float32,
              precision=None,
              kernel_init=default_kernel_init,
              bias_init=initializers.zeros):
        """Applies a convolution to the inputs.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).
      features: number of convolution filters.
      kernel_size: shape of the convolutional kernel.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
      input_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`.
        Convolution with input dilation `d` is equivalent to transposed
        convolution with stride `d`.
      kernel_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel. Convolution with kernel dilation is also known as 'atrous
        convolution'.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups.
      bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.
    Returns:
      The convolved data.
    """

        inputs = jnp.asarray(inputs, dtype)

        if strides is None:
            strides = (1, ) * (inputs.ndim - 2)

        in_features = inputs.shape[-1]
        assert in_features % feature_group_count == 0
        kernel_shape = kernel_size + (in_features // feature_group_count,
                                      features)
        kernel = self.param('kernel', kernel_shape, kernel_init)
        kernel = jnp.asarray(kernel, dtype)

        dimension_numbers = _conv_dimension_numbers(inputs.shape)
        y = lax.conv_general_dilated(inputs,
                                     kernel,
                                     strides,
                                     padding,
                                     lhs_dilation=input_dilation,
                                     rhs_dilation=kernel_dilation,
                                     dimension_numbers=dimension_numbers,
                                     feature_group_count=feature_group_count,
                                     precision=precision)

        if bias:
            bias = self.param('bias', (features, ), bias_init)
            bias = jnp.asarray(bias, dtype)
            y = y + bias
        return y
Example 18
  def __call__(self, inputs):
    """Applies a convolution to the inputs with optional quantization.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
      The convolved data.
    """
    hparams = self.hparams
    if hparams.weight_prec is not None and hparams.weight_prec > 8:
      raise NotImplementedError(
          'If you want to use more than 8bits for quantization, please revisit '
          'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
      )
    jax_precision = jax.lax.Precision.DEFAULT

    if self.strides is None:
      strides = (1,) * (inputs.ndim - 2)
    else:
      strides = self.strides

    in_features = inputs.shape[-1]
    assert in_features % self.feature_group_count == 0
    kernel_shape = self.kernel_size + (in_features // self.feature_group_count,
                                       self.features)
    kernel = self.param('kernel', self.kernel_init, kernel_shape)

    inputs = jnp.asarray(inputs, self.dtype)
    kernel = jnp.asarray(kernel, self.dtype)

    # Activation quantization
    if hparams.quant_act is not None:
      inputs = QuantOps.create_inputs_fake_quant(
          inputs=inputs,
          hparams=hparams.quant_act,
          get_bounds_params=get_bounds.GetBounds.Params(
              update_bounds=self.quant_context.update_bounds,
              update_stats=self.train,
              paxis_name=self.paxis_name))

    # Weight quantization
    if hparams.weight_prec is not None:
      kernel_reduction_axis = tuple(range(kernel.ndim - 1))
      expected_scale_shape = (1,) * (kernel.ndim - 1) + (self.features,)
      assert hparams.quant_type == QuantType.fake_quant, (
          'we only support fake_quant style of aqt for ConvAqt.')
      quantized_type = hparams.quant_type.to_jax_type()
      kernel = QuantOps.create_weights_fake_quant(
          kernel,
          weight_params=QuantOps.WeightParams(
              prec=hparams.weight_prec,
              half_shift=hparams.weight_half_shift,
              axis=kernel_reduction_axis,
              expected_scale_shape=expected_scale_shape),
          quantized_type=quantized_type)

    # Convolution
    dimension_numbers = flax.nn.linear._conv_dimension_numbers(inputs.shape)  # pylint: disable=protected-access
    metadata_context = contextlib.suppress()
    # Use metadata context to annotate op metadata with quantization info
    act_prec = None if hparams.quant_act is None else hparams.quant_act.prec

    if flags.FLAGS.metadata_enabled:
      metadata_context = compute_cost_utils.ConvMetadataMonkeyPatch(
          weight_prec=hparams.weight_prec, act_prec=act_prec)
    with metadata_context:
      y = lax.conv_general_dilated(
          inputs,
          kernel,
          strides,
          self.padding,
          lhs_dilation=self.input_dilation,
          rhs_dilation=self.kernel_dilation,
          dimension_numbers=dimension_numbers,
          feature_group_count=self.feature_group_count,
          precision=jax_precision)
    # TODO(shivaniagrawal): create quantized conv general dilated.

    # bias
    if self.use_bias:
      bias = self.param('bias', self.bias_init, (self.features,))
      bias = jnp.asarray(bias, self.dtype)
      # The inputs can have an arbitrary number of spatial dims, so we broadcast
      # the bias to match: (batch_size, spatial_dim,... features)
      # TODO(shivaniagrawal): Consider making ConvAqt rank static (e.g. 2D)
      # or maybe add error checking (e.g. expect inputs to have rank N, but this
      # may already be checked by lax.conv_general_dilated).
      bias = utils.broadcast_rank(bias, inputs)
      y = y + bias
    return y
Example 19
 def apply_fun(params, inputs, **kwargs):
     W = params
     return lax.conv_general_dilated(inputs, W, strides, padding, one, one,
                                     dimension_numbers)
Example 20
 def conv(lhs, rhs):
   return lax.conv_general_dilated(
     lhs, rhs, strides, padding,
     lhs_dilation=lhs_dilation, dimension_numbers=dimension_numbers)
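Since this is the example that passes lhs_dilation, here is a small self-contained illustration (assumed shapes) of what input dilation does: it inserts zeros between input pixels, which is how a transposed (upsampling) convolution is expressed with conv_general_dilated.

import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 4, 4, 1))   # NHWC
w = jnp.ones((3, 3, 1, 1))   # HWIO
y = lax.conv_general_dilated(
    x, w, window_strides=(1, 1), padding=((2, 2), (2, 2)),
    lhs_dilation=(2, 2), dimension_numbers=("NHWC", "HWIO", "NHWC"))
print(y.shape)  # (1, 9, 9, 1): dilated input is 7x7, plus 2+2 padding, minus (3-1)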
Example 21
  def apply(self,
            inputs,
            features,
            kernel_size,
            strides=None,
            padding='SAME',
            lhs_dilation=None,
            rhs_dilation=None,
            feature_group_count=1,
            bias=True,
            dtype=jnp.float32,
            precision=None,
            kernel_init=nn.linear.default_kernel_init,
            bias_init=initializers.zeros,
            scale_init=initializers.ones,
            compensate_padding=True):
    """Applies a convolution to the inputs.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).
      features: number of convolution filters.
      kernel_size: shape of the convolutional kernel.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
      lhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
        factor to apply in each spatial dimension of `lhs`. LHS dilation is also
        known as transposed convolution.
      rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
        factor to apply in each spatial dimension of `rhs`. RHS dilation is also
        known as atrous convolution.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups.
      bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.
      scale_init: initializer for the scale.
      compensate_padding: Renormalize output based on introduced zero padding.

    Returns:
      The convolved data.
    """

    inputs = jnp.asarray(inputs, dtype)

    if strides is None:
      strides = (1,) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    assert in_features % feature_group_count == 0
    kernel_shape = kernel_size + (in_features // feature_group_count, features)
    kernel_unnorm = self.param('kernel', kernel_shape, kernel_init)
    kernel_unnorm = jnp.asarray(kernel_unnorm, dtype)
    kernel_unnorm = jnp.reshape(
        kernel_unnorm,
        (-1, features),
    )
    kernel = kernel_unnorm / (
        jnp.linalg.norm(kernel_unnorm, axis=0, keepdims=True) + 1e-5)

    scale = self.param('scale', (features,), scale_init)
    kernel *= scale.reshape((-1, features))
    kernel = jnp.reshape(kernel, kernel_shape)

    # pylint: disable=protected-access
    dimension_numbers = nn.linear._conv_dimension_numbers(inputs.shape)
    # pylint: enable=protected-access
    y = lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
        precision=precision)

    if bias:
      bias = self.param('bias', (features,), bias_init)
      bias = jnp.asarray(bias, dtype)
      y = y + bias

    if compensate_padding:
      y = padding_compensate(inputs, kernel_size, lhs_dilation, padding,
                             precision, rhs_dilation, strides, y)
    return y
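A minimal sketch (illustrative shapes, with constants standing in for the learned params) of the weight normalization performed above: flatten the kernel to (fan_in, features), normalize each column to unit norm, then rescale by a per-feature scale.

import jax.numpy as jnp

kernel_shape = (3, 3, 16, 32)            # (kh, kw, in_features, features)
features = kernel_shape[-1]
kernel_unnorm = jnp.ones(kernel_shape)   # stands in for the 'kernel' param
scale = jnp.full((features,), 2.0)       # stands in for the 'scale' param

flat = kernel_unnorm.reshape(-1, features)
flat = flat / (jnp.linalg.norm(flat, axis=0, keepdims=True) + 1e-5)
kernel = (flat * scale).reshape(kernel_shape)
print(jnp.linalg.norm(kernel.reshape(-1, features), axis=0)[:3])  # ~[2. 2. 2.]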
def high_precision_conv(*args, **kwargs):
  kwargs.pop('precision')
  kwargs.pop('lhs_shape')
  kwargs.pop('rhs_shape')
  return lax.conv_general_dilated(*args, precision=lax.Precision.HIGH, **kwargs)