def update_site(self, inputs: Array, index: int) -> Array:
    """
    Adds an input site into the cache, and applies the masked convolution to the cache.

    Args:
      inputs: an input site to be added into the cache with dimensions (batch, features).
      index: the index of the output site. The index of the input site should be
        `index - self.exclusive`.

    Returns:
      The next output site with dimensions (batch, features).
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    kernel_size = self.kernel_size - self.exclusive
    dilation = self.kernel_dilation

    is_single_input = False
    if inputs.ndim == 1:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    batch, in_features = inputs.shape
    assert in_features % self.feature_group_count == 0
    cache_size = kernel_size * dilation - (not self.exclusive) * (dilation - 1)

    # Initialize the cache with zeros, and the RNG key is None
    # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
    _cache = self.variable(
        "cache",
        "inputs",
        zeros,
        None,
        (batch, cache_size, in_features),
        inputs.dtype,
    )

    initializing = self.is_mutable_collection("params")
    if not initializing:
        # Add the input site into the cache
        # To write the cache, use `_cache.value` as the left value of the assignment
        _cache.value = lax.cond(
            index - self.exclusive >= 0,
            lambda _: jnp.concatenate(
                [_cache.value[:, 1:, :], jnp.expand_dims(inputs, axis=1)],
                axis=1),
            lambda _: _cache.value,
            None,
        )

    cache = _cache.value
    cache = jnp.asarray(cache, dtype)

    kernel_shape = (
        kernel_size,
        in_features // self.feature_group_count,
        self.features,
    )
    kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
    kernel = jnp.asarray(kernel, dtype)

    if self.exclusive and dilation > 1:
        cache = cache[:, :-(dilation - 1), :]

    dimension_numbers = flax.linen.linear._conv_dimension_numbers(cache.shape)
    y_i = lax.conv_general_dilated(
        cache,
        kernel,
        window_strides=(1,),
        padding="VALID",
        lhs_dilation=(1,),
        rhs_dilation=(dilation,),
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
    )

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y_i = y_i + bias

    y_i = y_i.squeeze(axis=1)

    if is_single_input:
        y_i = y_i.squeeze(axis=0)

    return y_i
_make_harness("clamp", "",
              lax.clamp,
              [RandArg((3, 4, 5), _f32), RandArg((3, 4, 5), _f32),
               RandArg((3, 4, 5), _f32)],
              poly_axes=[0, 0, 0]),

_make_harness("conv_general_dilated", "",
              lambda lhs, rhs: lax.conv_general_dilated(
                  lhs, rhs,
                  window_strides=(2, 3),
                  padding=((0, 0), (0, 0)),
                  lhs_dilation=(1, 1),
                  rhs_dilation=(1, 2),
                  dimension_numbers=("NCHW", "OIHW", "NCHW"),
                  feature_group_count=1,
                  batch_group_count=1,
                  precision=None),
              [RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)],
              poly_axes=[0, None]),

_make_harness("cummax", "",
              lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
              [RandArg((3, 4, 5), _f32)],
              poly_axes=[0]),

_make_harness("dot_general",
def _extract_image_patches(
    lhs: np.ndarray,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: str,
    lhs_dilation: Sequence[int] = None,
    rhs_dilation: Sequence[int] = None,
    dimension_numbers: lax.ConvDimensionNumbers = None,
    precision: lax.Precision = None,
) -> np.ndarray:
    """Extract patches subject to the receptive field of a general convolution.

    Runs the input through a convolution that packs input spatial and channel
    entries into output channel `"C"` entries. The order of dimensions packed
    is `"C" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
    `rhs_spec == dimension_numbers[1]`.

    Docstring below adapted from `jax.lax.conv_general_dilated`.

    See Also:
      https://www.tensorflow.org/xla/operation_semantics#conv_convolution

    Args:
      lhs: a rank `n+2` dimensional input array.
      filter_shape: a sequence of `n` integers, representing the receptive
        window spatial shape in the order as specified in
        `rhs_spec = dimension_numbers[1]`.
      window_strides: a sequence of `n` integers, representing the
        inter-window strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply
        before and after each spatial dimension.
      lhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
        factor to apply in each spatial dimension of `lhs`. LHS dilation is
        also known as transposed convolution.
      rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
        factor to apply in each spatial dimension of `rhs`. RHS dilation is
        also known as atrous convolution.
      dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or a
        3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a
        string of length `n+2`.
      precision: Optional. Either ``None``, which means the default precision
        for the backend, or a ``lax.Precision`` enum value
        (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``).

    Returns:
      An array containing the image patches flattened inside the `"C"` output
      dimension. The size of this dimension is
      `C_input * onp.prod(filter_shape)`.

    In the string case of `dimension_numbers`, each character identifies by
    position:

    - the batch dimensions in `lhs`, `rhs`, and the output with the character
      'N',
    - the feature dimensions in `lhs` and the output with the character 'C',
    - the input and output feature dimensions in rhs with the characters 'I'
      and 'O' respectively, and
    - spatial dimension correspondences between lhs, rhs, and the output using
      any distinct characters.

    For example, to indicate dimension numbers consistent with the `conv`
    function with two spatial dimensions, one could use
    `('NCHW', 'OIHW', 'NCHW')`. As another example, to indicate dimension
    numbers consistent with the TensorFlow Conv2D operation, one could use
    `('NHWC', 'HWIO', 'NHWC')`. When using the latter form of convolution
    dimension specification, window strides are associated with spatial
    dimension character labels according to the order in which the labels
    appear in the `rhs_spec` string, so that `window_strides[0]` is matched
    with the dimension corresponding to the first character appearing in
    `rhs_spec` that is not `'I'` or `'O'`.

    If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
    (for a 2D convolution).
    """
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    filter_shape = tuple(filter_shape)
    spatial_size = onp.prod(filter_shape)
    n_channels = lhs.shape[lhs_spec.index('C')]

    # Move separate `lhs` spatial locations into separate `rhs` channels.
    rhs = np.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)
    rhs = rhs.reshape((spatial_size, 1) + filter_shape)
    rhs = np.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
    rhs = np.moveaxis(rhs, (0, 1), (rhs_spec.index('O'), rhs_spec.index('I')))

    out = lax.conv_general_dilated(
        lhs=lhs,
        rhs=rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        precision=precision,
        feature_group_count=n_channels)
    return out
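# Hedged usage sketch (not from the original source): how `_extract_image_patches`
# might be called on an NCHW batch. The shapes and the `np` alias (jax.numpy, as
# in the snippet above) are assumptions.
import jax.numpy as np

lhs = np.zeros((2, 3, 8, 8))  # (N, C, H, W)
patches = _extract_image_patches(
    lhs,
    filter_shape=(3, 3),
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
assert patches.shape == (2, 3 * 9, 8, 8)  # C_input * prod(filter_shape) output channels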
def convNd(
    input,
    filter,
    strides=1,
    padding="VALID",
    input_format=None,
    filter_format=None,
    output_format=None,
    input_dilation=None,
    filter_dilation=None,
):
    """General n-dimensional convolution operator, with input/filter dilation.

    Wraps JAX's conv_general_dilated function, and thus also XLA's `Conv
    <https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
    operator.

    Args:
        input (Tensor): a rank `n+2` dimensional input array.
        filter (Tensor): a rank `n+2` dimensional array of kernel weights.
        strides (int, sequence of int, optional): a (sequence of) `n` integers,
            representing the inter-window strides. If a scalar is given, it is
            used `n` times. Defaults to `1`.
        padding (sequence of couple, `'SAME'`, `'VALID'`, optional): a sequence
            of `n` `(low, high)` integer pairs that give the padding to apply
            before and after each spatial dimension. For `'VALID'`, no padding
            is applied. For `'SAME'`, the padding is chosen so that the output
            has the same spatial shape as the input when the strides are `1`.
            Defaults to `'VALID'`.
        input_format (`None` or str, optional): a string of the same length as
            the number of dimensions in `input` which specifies their role (see
            below). Defaults to `'NCW'` for 1d conv, `'NCHW'` for 2d conv, and
            `'NCDHW'` for 3d conv.
        input_dilation (`None`, int or sequence of int, optional): the dilation
            factor to apply in each spatial dimension of `input`. Input
            dilation is also known as transposed convolution, as it allows to
            increase the output spatial dimension by inserting in the input any
            number of `0`s between each spatial value.
        filter_dilation (`None`, int or sequence of int): the dilation factor
            to apply in each spatial dimension of `filter`. Filter dilation is
            also known as atrous convolution, as it corresponds to inserting
            any number of `0`s in between the filter values, similar to
            performing the non-dilated filter convolution with a subsampled
            version of the input across the spatial dimensions.

    Returns:
        Tensor: An array containing the convolution result.

    Format of `input`, `filter` and `output`:
        For example, to indicate dimension numbers consistent with the `conv`
        function with two spatial dimensions, one could use
        `('NCHW', 'OIHW', 'NCHW')`. As another example, to indicate dimension
        numbers consistent with the TensorFlow Conv2D operation, one could use
        `('NHWC', 'HWIO', 'NHWC')`. When using the latter form of convolution
        dimension specification, window strides are associated with spatial
        dimension character labels according to the order in which the labels
        appear in the `rhs_spec` string, so that `window_strides[0]` is matched
        with the dimension corresponding to the first character appearing in
        `rhs_spec` that is not `'I'` or `'O'`.

    :param filter_format:
    :param output_format:
    """
    # setting up the strides
    if numpy.isscalar(strides):
        strides = (strides,) * (input.ndim - 2)
    elif len(strides) != (input.ndim - 2):
        msg = "given strides: {} should match the ".format(
            strides) + "number of spatial dim. in input: {}".format(
            input.ndim - 2)
        raise ValueError(msg)

    # setting up the padding
    if type(padding) != str:
        if len(padding) != (input.ndim - 2):
            msg = "given padding: {} should match the ".format(
                padding) + "number of spatial dim. in input: {}".format(
                input.ndim - 2)
            raise ValueError(msg)

    # setting up the filter_format
    if filter_format is None:
        if filter.ndim == 3:
            filter_format = "OIW"
        elif filter.ndim == 4:
            filter_format = "OIHW"
        elif filter.ndim == 5:
            filter_format = "OIDHW"
        else:
            msg = "filter_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(filter_format) != filter.ndim:
        msg = "given filter_format: {} should ".format(
            filter_format) + "match the number of dimensions in filter: {}".format(
            filter.ndim)
        raise ValueError(msg)

    # setting up the input format
    if input_format is None:
        if input.ndim == 3:
            input_format = "NCW"
        elif input.ndim == 4:
            input_format = "NCHW"
        elif input.ndim == 5:
            input_format = "NCDHW"
        else:
            msg = "input_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(input_format) != input.ndim:
        msg = "given input_format: {} should ".format(
            input_format) + "match the number of dimensions in input: {}".format(
            input.ndim)
        raise ValueError(msg)

    # setting up the output format
    if output_format is None:
        if input.ndim == 3:
            output_format = "NCW"
        elif input.ndim == 4:
            output_format = "NCHW"
        elif input.ndim == 5:
            output_format = "NCDHW"
        else:
            msg = "output_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(output_format) != input.ndim:
        msg = "given output_format: {} should ".format(
            output_format) + "match the number of dimensions in output: {}".format(
            input.ndim)
        raise ValueError(msg)

    # setting up dilations
    if numpy.isscalar(input_dilation):
        input_dilation = (input_dilation,) * (input.ndim - 2)
    if numpy.isscalar(filter_dilation):
        filter_dilation = (filter_dilation,) * (input.ndim - 2)

    specs = (input_format, filter_format, output_format)

    return jla.conv_general_dilated(
        lhs=input,
        rhs=filter,
        window_strides=strides,
        padding=padding,
        lhs_dilation=input_dilation,
        rhs_dilation=filter_dilation,
        dimension_numbers=specs,
        precision=None,
    )
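# Hedged usage sketch (not from the original source): calling convNd on an NCHW
# batch, assuming `jla` aliases `jax.lax` and plain JAX arrays are accepted.
# The array contents below are placeholders.
import numpy
import jax.numpy as jnp
from jax import lax as jla

x = jnp.ones((4, 3, 32, 32))  # (N, C, H, W)
w = jnp.ones((8, 3, 5, 5))    # (O, I, H, W)
y = convNd(x, w, strides=2, padding="SAME", filter_dilation=1)
print(y.shape)  # (4, 8, 16, 16)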
def __call__(
    self,
    inputs: jnp.ndarray,
    *,
    precision: Optional[lax.Precision] = None,
) -> jnp.ndarray:
    """Connects ``ConvND`` layer.

    Args:
      inputs: An array of shape ``[spatial_dims, C]`` and rank-N+1 if
        unbatched, or an array of shape ``[N, spatial_dims, C]`` and rank-N+2
        if batched.
      precision: Optional :class:`jax.lax.Precision` to pass to
        :func:`jax.lax.conv_general_dilated`.

    Returns:
      An array of shape ``[spatial_dims, output_channels]`` and rank-N+1 if
      unbatched, or an array of shape ``[N, spatial_dims, output_channels]``
      and rank-N+2 if batched.
    """
    unbatched_rank = self.num_spatial_dims + 1
    allowed_ranks = [unbatched_rank, unbatched_rank + 1]
    if inputs.ndim not in allowed_ranks:
        raise ValueError(
            f"Input to ConvND needs to have rank in {allowed_ranks},"
            f" but input has shape {inputs.shape}.")

    unbatched = inputs.ndim == unbatched_rank
    if unbatched:
        inputs = jnp.expand_dims(inputs, axis=0)

    if inputs.shape[self.channel_index] % self.feature_group_count != 0:
        raise ValueError(
            f"Inputs channels {inputs.shape[self.channel_index]} "
            f"should be a multiple of feature_group_count "
            f"{self.feature_group_count}")

    w_shape = self.kernel_shape + (
        inputs.shape[self.channel_index] // self.feature_group_count,
        self.output_channels)

    if self.mask is not None and self.mask.shape != w_shape:
        raise ValueError("Mask needs to have the same shape as weights. "
                         f"Shapes are: {self.mask.shape}, {w_shape}")

    w_init = self.w_init
    if w_init is None:
        fan_in_shape = np.prod(w_shape[:-1])
        stddev = 1. / np.sqrt(fan_in_shape)
        w_init = hk.initializers.TruncatedNormal(stddev=stddev)
    w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)

    if self.mask is not None:
        w *= self.mask

    out = lax.conv_general_dilated(
        inputs,
        w,
        window_strides=self.stride,
        padding=self.padding,
        lhs_dilation=self.lhs_dilation,
        rhs_dilation=self.kernel_dilation,
        dimension_numbers=self.dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=precision)

    if self.with_bias:
        if self.channel_index == -1:
            bias_shape = (self.output_channels,)
        else:
            bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
        b = hk.get_parameter("b", bias_shape, inputs.dtype, init=self.b_init)
        b = jnp.broadcast_to(b, out.shape)
        out = out + b

    if unbatched:
        out = jnp.squeeze(out, axis=0)
    return out
def f(params, x):
    one = (1, 1)
    dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
    y = lax.conv_general_dilated(
        x, params, one, 'SAME', one, one, dimension_numbers)
    return y

"""
# convert to jax array
X_image_jax = jnp.array(X_images_scaled, dtype=jnp.float32)

# define the kernel
kernel = jnp.ones(shape=(3, 3), dtype=jnp.float32)  # better orthogonal kernel

X_image_transform = conv_general_dilated(
    lhs=X_image_jax,              # input
    rhs=kernel[..., None, None],  # kernel
    window_strides=(1, 1),
    padding="SAME",
    lhs_dilation=(1, 1),
    rhs_dilation=(1, 1),
    dimension_numbers=("NHWC", "IOHW", "NHWC"),
)

fig, ax = plt.subplots()
plt.imshow(X_image_transform[0])
ax.set_yticks([])
ax.set_xticks([])
plt.tight_layout()
plt.show()

#%%
"""Invertible???"""
def apply_fun(params, inputs, rng=None):
    W, b = params
    return lax.conv_general_dilated(inputs, W, strides, padding, one, one,
                                    dimension_numbers) + b
def apply(
    self,
    inputs,
    features,
    kernel_size,
    is_first_layer=False,
    strides=None,
    padding="SAME",
    input_dilation=None,
    kernel_dilation=None,
    feature_group_count=1,
    bias=True,
    dtype=jnp.float32,
    precision=None,
    kernel_init=default_kernel_init,
    bias_init=initializers.zeros,
):
    """Applies a convolution to the inputs."""
    inputs = jnp.asarray(inputs, dtype)

    assert len(kernel_size) == 1, "kernel_shape must be one dimensional"
    assert kernel_size[0] % 2 != 0, "kernel_shape must be odd"

    # Autoregressive mask: the first layer must not see the current site
    # (the centre tap is masked out), later layers may include it.
    mask = onp.ones(kernel_size[0])
    if is_first_layer:
        i = (kernel_size[0] - 1) // 2
    else:
        i = (kernel_size[0] + 1) // 2
    mask[i:] = 0
    mask = jnp.asarray(mask[:, onp.newaxis, onp.newaxis], dtype)

    if strides is None:
        strides = (1,) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    assert in_features % feature_group_count == 0
    kernel_shape = kernel_size + (in_features // feature_group_count, features)
    kernel = self.param("kernel", kernel_shape, kernel_init)
    kernel = jnp.asarray(kernel, dtype)
    kernel = kernel * mask

    dimension_numbers = _conv_dimension_numbers(inputs.shape)
    y = lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        lhs_dilation=input_dilation,
        rhs_dilation=kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
        precision=precision,
    )

    if bias:
        bias = self.param("bias", (features,), bias_init)
        bias = jnp.asarray(bias, dtype)
        y = y + bias
    return y
def __call__(self, inputs: Array) -> Array:
    """
    Applies a masked convolution to the inputs.

    For 1D convolution, there is not really a mask. We only need to apply
    appropriate padding.

    Args:
      inputs: input data with dimensions (batch, length, features).

    Returns:
      The convolved data.
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    kernel_size = self.kernel_size - self.exclusive
    dilation = self.kernel_dilation

    is_single_input = False
    if inputs.ndim == 2:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    in_features = inputs.shape[-1]
    assert in_features % self.feature_group_count == 0
    kernel_shape = (
        kernel_size,
        in_features // self.feature_group_count,
        self.features,
    )
    kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
    kernel = jnp.asarray(kernel, dtype)

    if self.exclusive:
        inputs = inputs[:, :-dilation, :]

    # Zero padding
    y = jnp.pad(
        inputs,
        (
            (0, 0),
            ((kernel_size - (not self.exclusive)) * dilation, 0),
            (0, 0),
        ),
    )

    dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
    y = lax.conv_general_dilated(
        y,
        kernel,
        window_strides=(1,),
        padding="VALID",
        lhs_dilation=(1,),
        rhs_dilation=(dilation,),
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
    )

    if is_single_input:
        y = y.squeeze(axis=0)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    return y
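# Hedged standalone sketch (not from the original source) of the causal-padding
# trick used above: left-padding by (kernel_width - 1) * dilation and running a
# VALID convolution makes each output depend only on the current and past sites.
# All names below are local to this sketch.
import jax.numpy as jnp
from jax import lax

def causal_conv1d(x, kernel, dilation=1):
    # x: (batch, length, in_features); kernel: (width, in_features, out_features)
    width = kernel.shape[0]
    x = jnp.pad(x, ((0, 0), ((width - 1) * dilation, 0), (0, 0)))
    return lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1,),
        padding="VALID",
        rhs_dilation=(dilation,),
        dimension_numbers=("NWC", "WIO", "NWC"),
    )

y = causal_conv1d(jnp.ones((2, 16, 4)), jnp.ones((3, 4, 8)), dilation=2)
print(y.shape)  # (2, 16, 8): same length as the input, 8 output features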
def __call__(self, inputs: Array) -> Array:
    """
    Applies a masked convolution to the inputs.

    Args:
      inputs: input data with dimensions (batch, width, height, features).

    Returns:
      The convolved data.
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    kernel_h, kernel_w = self.kernel_size
    dilation_h, dilation_w = self.kernel_dilation
    ones = (1, 1)

    is_single_input = False
    if inputs.ndim == 3:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    in_features = inputs.shape[-1]
    assert in_features % self.feature_group_count == 0
    kernel_shape = self.kernel_size + (
        in_features // self.feature_group_count,
        self.features,
    )
    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, self.mask),
        kernel_shape,
        self.dtype,
    )
    mask = jnp.asarray(self.mask, dtype)
    kernel = jnp.asarray(kernel, dtype)

    # Zero padding
    y = jnp.pad(
        inputs,
        (
            (0, 0),
            ((kernel_h - 1) * dilation_h, 0),
            (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
            (0, 0),
        ),
    )

    dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
    y = lax.conv_general_dilated(
        y,
        mask * kernel,
        window_strides=ones,
        padding="VALID",
        lhs_dilation=ones,
        rhs_dilation=self.kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
    )

    if is_single_input:
        y = y.squeeze(axis=0)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    return y
def conv2d(x, filters, strides, padding, data_format='NHWC', dilations=1):
    strides = [strides] * 2 if isinstance(strides, int) else strides
    dilations = [dilations] * 2 if isinstance(dilations, int) else dilations
    return _jlax.conv_general_dilated(x, filters, strides, padding, None,
                                      dilations,
                                      (data_format, 'HWIO', data_format))
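# Hedged usage sketch (not from the original source): the wrapper above follows
# TensorFlow's NHWC/HWIO convention. `_jlax` is assumed to alias `jax.lax`;
# the array contents are placeholders.
import jax.numpy as jnp
from jax import lax as _jlax

x = jnp.ones((1, 28, 28, 3))  # (N, H, W, C)
w = jnp.ones((3, 3, 3, 16))   # (H, W, I, O)
y = conv2d(x, w, strides=1, padding='SAME')
print(y.shape)  # (1, 28, 28, 16)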
def apply(self,
          inputs,
          filters,
          kernel_size,
          block_size,
          strides=None,
          padding='SAME',
          input_dilation=None,
          kernel_dilation=None,
          feature_group_count=1,
          bias=True,
          dtype=jnp.float32,
          precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros):
    """Applies a convolution to the inputs.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).
      filters: number of convolution filters.
      kernel_size: shape of the convolutional kernel.
      block_size: shape of space-to-depth blocks.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply
        before and after each spatial dimension.
      input_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`.
        Convolution with input dilation `d` is equivalent to transposed
        convolution with stride `d`.
      kernel_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel. Convolution with kernel dilation is also known as 'atrous
        convolution'.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups.
      bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.

    Returns:
      The convolved data.
    """
    inputs = jnp.asarray(inputs, dtype)

    if strides is None:
        strides = block_size
    assert strides[0] % block_size[0] == 0
    assert strides[1] % block_size[1] == 0
    strides = tuple(s // b for s, b in zip(strides, block_size))

    # create kernel as if there were no space to depth
    batch_size, h, w, features = inputs.shape
    original_input_shape = (batch_size, h * block_size[0], w * block_size[1],
                            features // block_size[0] // block_size[1])
    in_features = original_input_shape[-1]
    assert in_features % feature_group_count == 0
    kernel_shape = kernel_size + (in_features // feature_group_count, filters)
    kernel = self.param('kernel', kernel_shape, kernel_init)
    kernel = jnp.asarray(kernel, dtype)

    # zero-pad kernel to a multiple of the block size (e.g. 7x7 --> 8x8); the
    # block count is only incremented when padding is actually added
    h_blocks, h_ragged = divmod(kernel_size[0], block_size[0])
    if h_ragged != 0:
        h_blocks = h_blocks + 1
        kernel = jnp.pad(kernel,
                         pad_width=[[block_size[0] - h_ragged, 0], [0, 0],
                                    [0, 0], [0, 0]],
                         mode='constant',
                         constant_values=0.)

    w_blocks, w_ragged = divmod(kernel_size[1], block_size[1])
    if w_ragged != 0:
        w_blocks = w_blocks + 1
        kernel = jnp.pad(kernel,
                         pad_width=[[0, 0], [block_size[1] - w_ragged, 0],
                                    [0, 0], [0, 0]],
                         mode='constant',
                         constant_values=0.)

    # transform kernel following space-to-depth logic: http://shortn/_9YvHW96xPJ
    kernel = jnp.reshape(kernel, [
        h_blocks, block_size[0], w_blocks, block_size[1],
        in_features // feature_group_count, filters
    ])
    kernel = jnp.transpose(kernel, [0, 2, 1, 3, 4, 5])
    kernel = jnp.reshape(kernel, [h_blocks, w_blocks, features, filters])
    kernel = kernel.astype(inputs.dtype)

    dimension_numbers = nn.linear._conv_dimension_numbers(inputs.shape)  # pylint: disable=protected-access
    y = lax.conv_general_dilated(lhs=inputs,
                                 rhs=kernel,
                                 window_strides=strides,
                                 padding=padding,
                                 lhs_dilation=input_dilation,
                                 rhs_dilation=kernel_dilation,
                                 dimension_numbers=dimension_numbers,
                                 feature_group_count=feature_group_count,
                                 precision=precision)

    if bias:
        # the bias has one entry per output filter
        bias = self.param('bias', (filters,), bias_init)
        bias = jnp.asarray(bias, dtype)
        y = y + bias
    return y
def call(self, x, params=(), **kwargs):
    del kwargs
    w, b = params
    return lax.conv_general_dilated(
        x, w, self._strides, self._padding, self._one, self._one,
        self._dimension_numbers) + b
def update_site(self, inputs: Array, index: int) -> Array:
    """
    Adds an input site into the cache, and applies the masked convolution to the cache.

    Args:
      inputs: an input site to be added into the cache with dimensions (batch, features).
      index: the index of the output site. The index of the input site should be
        `index - self.exclusive`.

    Returns:
      The next output site with dimensions (batch, features).
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    L = self.L
    index_w = index % L

    kernel_h, kernel_w = self.kernel_size
    dilation_h, dilation_w = self.kernel_dilation
    ones = (1, 1)

    is_single_input = False
    if inputs.ndim == 1:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    batch, in_features = inputs.shape
    assert in_features % self.feature_group_count == 0
    recep_h = (kernel_h - 1) * dilation_h + 1
    recep_w = (kernel_w - 1) * dilation_w + 1

    # Initialize the cache with zeros, and the RNG key is None
    # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
    _cache = self.variable(
        "cache",
        "inputs",
        zeros,
        None,
        (batch, recep_h, L, in_features),
        inputs.dtype,
    )

    initializing = self.is_mutable_collection("params")
    if not initializing:
        # Add the input site into the cache
        # To write the cache, use `_cache.value` as the left value of the assignment
        inputs = jnp.expand_dims(inputs, axis=(1, 2))

        # Index of the input site in the width direction
        index_w_in = (index - self.exclusive) % L

        def _add(cache):
            # Equivalent to cache.at[:, -1, index_w_in, :].set(inputs).
            # `lax.dynamic_update_slice` clamps negative start indices to 0,
            # so the last row is addressed explicitly as `recep_h - 1`.
            return lax.dynamic_update_slice(cache, inputs,
                                            (0, recep_h - 1, index_w_in, 0))

        def _shift(cache):
            return jnp.concatenate(
                [
                    cache[:, 1:, :, :],
                    jnp.zeros((batch, 1, L, in_features), dtype=inputs.dtype),
                ],
                axis=1,
            )

        cache_new_row = lax.cond(
            index_w_in == 0,
            lambda _: _add(_shift(_cache.value)),
            lambda _: _shift(_add(_cache.value)),
            None,
        )

        cache_new = lax.cond(
            index_w == 0,
            lambda _: cache_new_row,
            lambda _: _add(_cache.value),
            None,
        )

        _cache.value = lax.cond(
            index - self.exclusive >= 0,
            lambda _: cache_new,
            lambda _: _cache.value,
            None,
        )

    cache = _cache.value
    cache = jnp.asarray(cache, dtype)

    kernel_shape = self.kernel_size + (
        in_features // self.feature_group_count,
        self.features,
    )
    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, self.mask),
        kernel_shape,
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    # Zero padding
    cache = jnp.pad(
        cache,
        (
            (0, 0),
            (0, 0),
            (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
            (0, 0),
        ),
    )

    # Equivalent to cache[:, :, index_w : index_w + recep_w, :]
    cache = lax.dynamic_slice(cache, (0, 0, index_w, 0),
                              (batch, recep_h, recep_w, in_features))

    dimension_numbers = flax.linen.linear._conv_dimension_numbers(cache.shape)
    y_i = lax.conv_general_dilated(
        cache,
        kernel,
        window_strides=ones,
        padding="VALID",
        lhs_dilation=ones,
        rhs_dilation=self.kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
    )

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y_i = y_i + bias

    y_i = y_i.squeeze(axis=(1, 2))

    if is_single_input:
        y_i = y_i.squeeze(axis=0)

    return y_i
def energy(state_mat, jvalue):
    # Calculate energy
    logits = lax.conv_general_dilated(state_mat, jvalue * kernel, (1, 1),
                                      'SAME', (1, 1), (1, 1), dn)
    return logits
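# Hedged sketch (not from the original source): one plausible way to define the
# free variables `kernel` and `dn` used by `energy`, assuming an Ising-like
# nearest-neighbour coupling on an NCHW lattice of spins. All shapes and values
# below are assumptions for illustration only.
import jax.numpy as jnp
from jax import lax

kernel = jnp.array([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]], dtype=jnp.float32)[None, None, :, :]  # (O, I, H, W)
dn = lax.conv_dimension_numbers((1, 1, 8, 8), kernel.shape,
                                ('NCHW', 'OIHW', 'NCHW'))

state_mat = jnp.ones((1, 1, 8, 8), dtype=jnp.float32)  # all spins +1
print(energy(state_mat, jvalue=1.0).shape)  # (1, 1, 8, 8)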
def apply(self,
          inputs,
          features,
          kernel_size,
          strides=None,
          padding='SAME',
          input_dilation=None,
          kernel_dilation=None,
          feature_group_count=1,
          bias=True,
          dtype=jnp.float32,
          precision=None,
          kernel_init=default_kernel_init,
          bias_init=initializers.zeros):
    """Applies a convolution to the inputs.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).
      features: number of convolution filters.
      kernel_size: shape of the convolutional kernel.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply
        before and after each spatial dimension.
      input_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`.
        Convolution with input dilation `d` is equivalent to transposed
        convolution with stride `d`.
      kernel_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel. Convolution with kernel dilation is also known as 'atrous
        convolution'.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups.
      bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.

    Returns:
      The convolved data.
    """
    inputs = jnp.asarray(inputs, dtype)

    if strides is None:
        strides = (1,) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    assert in_features % feature_group_count == 0
    kernel_shape = kernel_size + (in_features // feature_group_count, features)
    kernel = self.param('kernel', kernel_shape, kernel_init)
    kernel = jnp.asarray(kernel, dtype)

    dimension_numbers = _conv_dimension_numbers(inputs.shape)
    y = lax.conv_general_dilated(inputs,
                                 kernel,
                                 strides,
                                 padding,
                                 lhs_dilation=input_dilation,
                                 rhs_dilation=kernel_dilation,
                                 dimension_numbers=dimension_numbers,
                                 feature_group_count=feature_group_count,
                                 precision=precision)

    if bias:
        bias = self.param('bias', (features,), bias_init)
        bias = jnp.asarray(bias, dtype)
        y = y + bias
    return y
def __call__(self, inputs):
    """Applies a convolution to the inputs with optional quantization.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
      The convolved data.
    """
    hparams = self.hparams
    if hparams.weight_prec is not None and hparams.weight_prec > 8:
        raise NotImplementedError(
            'If you want to use more than 8bits for quantization, please revisit '
            'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
        )
    jax_precision = jax.lax.Precision.DEFAULT

    if self.strides is None:
        strides = (1,) * (inputs.ndim - 2)
    else:
        strides = self.strides

    in_features = inputs.shape[-1]
    assert in_features % self.feature_group_count == 0
    kernel_shape = self.kernel_size + (in_features // self.feature_group_count,
                                       self.features)
    kernel = self.param('kernel', self.kernel_init, kernel_shape)
    inputs = jnp.asarray(inputs, self.dtype)
    kernel = jnp.asarray(kernel, self.dtype)

    # Activation quantization
    if hparams.quant_act is not None:
        inputs = QuantOps.create_inputs_fake_quant(
            inputs=inputs,
            hparams=hparams.quant_act,
            get_bounds_params=get_bounds.GetBounds.Params(
                update_bounds=self.quant_context.update_bounds,
                update_stats=self.train,
                paxis_name=self.paxis_name))

    # Weight quantization
    if hparams.weight_prec is not None:
        kernel_reduction_axis = tuple(range(kernel.ndim - 1))
        expected_scale_shape = (1,) * (kernel.ndim - 1) + (self.features,)
        assert hparams.quant_type == QuantType.fake_quant, (
            'we only support fake_quant style of aqt for ConvAqt.')
        quantized_type = hparams.quant_type.to_jax_type()
        kernel = QuantOps.create_weights_fake_quant(
            kernel,
            weight_params=QuantOps.WeightParams(
                prec=hparams.weight_prec,
                half_shift=hparams.weight_half_shift,
                axis=kernel_reduction_axis,
                expected_scale_shape=expected_scale_shape),
            quantized_type=quantized_type)

    # Convolution
    dimension_numbers = flax.nn.linear._conv_dimension_numbers(inputs.shape)  # pylint: disable=protected-access
    metadata_context = contextlib.suppress()
    # Use metadata context to annotate op metadata with quantization info
    act_prec = None if hparams.quant_act is None else hparams.quant_act.prec
    if flags.FLAGS.metadata_enabled:
        metadata_context = compute_cost_utils.ConvMetadataMonkeyPatch(
            weight_prec=hparams.weight_prec, act_prec=act_prec)
    with metadata_context:
        y = lax.conv_general_dilated(
            inputs,
            kernel,
            strides,
            self.padding,
            lhs_dilation=self.input_dilation,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=jax_precision)
    # TODO(shivaniagrawal): create quantized conv general dilated.

    # bias
    if self.use_bias:
        bias = self.param('bias', self.bias_init, (self.features,))
        bias = jnp.asarray(bias, self.dtype)
        # The inputs can have an arbitrary number of spatial dims, so we
        # broadcast the bias to match: (batch_size, spatial_dim,... features)
        # TODO(shivaniagrawal): Consider making ConvAqt rank static (e.g. 2D)
        # or maybe add error checking (e.g. expect inputs to have rank N, but
        # this may already be checked by lax.conv_general_dilated).
        bias = utils.broadcast_rank(bias, inputs)
        y = y + bias
    return y
def apply_fun(params, inputs, **kwargs):
    W = params
    return lax.conv_general_dilated(inputs, W, strides, padding, one, one,
                                    dimension_numbers)
def conv(lhs, rhs):
    return lax.conv_general_dilated(
        lhs, rhs, strides, padding,
        lhs_dilation=lhs_dilation,
        dimension_numbers=dimension_numbers)
def apply(self,
          inputs,
          features,
          kernel_size,
          strides=None,
          padding='SAME',
          lhs_dilation=None,
          rhs_dilation=None,
          feature_group_count=1,
          bias=True,
          dtype=jnp.float32,
          precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=initializers.zeros,
          scale_init=initializers.ones,
          compensate_padding=True):
    """Applies a convolution to the inputs.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).
      features: number of convolution filters.
      kernel_size: shape of the convolutional kernel.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply
        before and after each spatial dimension.
      lhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
        factor to apply in each spatial dimension of `lhs`. LHS dilation is
        also known as transposed convolution.
      rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
        factor to apply in each spatial dimension of `rhs`. RHS dilation is
        also known as atrous convolution.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups.
      bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.
      scale_init: initializer for the scale.
      compensate_padding: Renormalize output based on introduced zero padding.

    Returns:
      The convolved data.
    """
    inputs = jnp.asarray(inputs, dtype)

    if strides is None:
        strides = (1,) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    assert in_features % feature_group_count == 0
    kernel_shape = kernel_size + (in_features // feature_group_count, features)
    kernel_unnorm = self.param('kernel', kernel_shape, kernel_init)
    kernel_unnorm = jnp.asarray(kernel_unnorm, dtype)

    # Weight normalization: normalize each filter to unit norm, then rescale
    # by a learned per-filter scale.
    kernel_unnorm = jnp.reshape(kernel_unnorm, (-1, features))
    kernel = kernel_unnorm / (
        jnp.linalg.norm(kernel_unnorm, axis=0, keepdims=True) + 1e-5)
    scale = self.param('scale', (features,), scale_init)
    kernel *= scale.reshape((-1, features))
    kernel = jnp.reshape(kernel, kernel_shape)

    # pylint: disable=protected-access
    dimension_numbers = nn.linear._conv_dimension_numbers(inputs.shape)
    # pylint: enable=protected-access

    y = lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
        precision=precision)

    if bias:
        bias = self.param('bias', (features,), bias_init)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    if compensate_padding:
        y = padding_compensate(inputs, kernel_size, lhs_dilation, padding,
                               precision, rhs_dilation, strides, y)
    return y
def high_precision_conv(*args, **kwargs):
    kwargs.pop('precision')
    kwargs.pop('lhs_shape')
    kwargs.pop('rhs_shape')
    return lax.conv_general_dilated(*args, precision=lax.Precision.HIGH,
                                    **kwargs)