def __call__(self, inputs: Array) -> Array:
    """Applies a transposed convolution to the inputs.

    Behaviour mirrors that of `jax.lax.conv_transpose`.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
      The convolved data.
    """
    inputs = jnp.asarray(inputs, self.dtype)
    strides = self.strides or (1,) * (inputs.ndim - 2)
    in_features = inputs.shape[-1]
    kernel_shape = self.kernel_size + (in_features, self.features)
    kernel = self.param('kernel', self.kernel_init, kernel_shape)
    kernel = jnp.asarray(kernel, self.dtype)
    y = lax.conv_transpose(inputs,
                           kernel,
                           strides,
                           self.padding,
                           rhs_dilation=self.kernel_dilation,
                           precision=self.precision)
    if self.use_bias:
        bias = self.param('bias', self.bias_init, (self.features,))
        bias = jnp.asarray(bias, self.dtype)
        y = y + bias
    return y
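# A minimal usage sketch (not part of the snippet above; names and shapes are
# illustrative). It shows the shape convention the Flax-style module relies on:
# NHWC inputs, an HWIO kernel of shape kernel_size + (in_features, features),
# and 'SAME' padding, so lax.conv_transpose scales each spatial dim by its stride.
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 8, 8, 3))                    # (batch, H, W, features)
k = jnp.ones((3, 3, 3, 16))                   # (kH, kW, in_features, features)
y = lax.conv_transpose(x, k, (2, 2), 'SAME')  # default dims: ('NHWC', 'HWIO', 'NHWC')
assert y.shape == (1, 16, 16, 16)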
def apply_fun(params, inputs, **kwargs):
    W = params
    return lax.conv_transpose(inputs, W, strides, padding,
                              dimension_numbers=dimension_numbers)
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Computes the transposed convolution of the input.

    Args:
      inputs: An array of shape ``[spatial_dims, C]`` and rank-N+1 if unbatched,
        or an array of shape ``[N, spatial_dims, C]`` and rank-N+2 if batched.

    Returns:
      An array of shape ``[spatial_dims, output_channels]`` and rank-N+1 if
      unbatched, or an array of shape ``[N, spatial_dims, output_channels]``
      and rank-N+2 if batched.
    """
    unbatched_rank = self.num_spatial_dims + 1
    allowed_ranks = [unbatched_rank, unbatched_rank + 1]
    if inputs.ndim not in allowed_ranks:
        raise ValueError(
            f"Input to ConvNDTranspose needs to have rank in "
            f"{allowed_ranks}, but input has shape {inputs.shape}.")

    unbatched = inputs.ndim == unbatched_rank
    if unbatched:
        inputs = jnp.expand_dims(inputs, axis=0)

    input_channels = inputs.shape[self.channel_index]
    w_shape = self.kernel_shape + (self.output_channels, input_channels)

    if self.mask is not None and self.mask.shape != w_shape:
        raise ValueError("Mask needs to have the same shape as weights. "
                         f"Shapes are: {self.mask.shape}, {w_shape}")

    w_init = self.w_init
    if w_init is None:
        fan_in_shape = self.kernel_shape + (input_channels,)
        stddev = 1. / np.sqrt(np.prod(fan_in_shape))
        w_init = hk.initializers.TruncatedNormal(stddev=stddev)
    w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)

    if self.mask is not None:
        w = w * self.mask

    out = lax.conv_transpose(inputs,
                             w,
                             strides=self.stride,
                             padding=self.padding,
                             dimension_numbers=self.dimension_numbers)

    if self.with_bias:
        if self.channel_index == -1:
            bias_shape = (self.output_channels,)
        else:
            bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
        b = hk.get_parameter("b", bias_shape, init=self.b_init)
        b = jnp.broadcast_to(b, out.shape)
        out = out + b

    if unbatched:
        out = jnp.squeeze(out, axis=0)
    return out
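# Hedged usage sketch, assuming the snippet above is Haiku's ConvNDTranspose:
# the transformed function accepts both batched [N, H, W, C] and unbatched
# [H, W, C] inputs, since the unbatched case is expanded and squeezed
# internally. Module and argument names follow the public Haiku API.
import jax
import jax.numpy as jnp
import haiku as hk

forward = hk.transform(
    lambda x: hk.Conv2DTranspose(output_channels=4, kernel_shape=3, stride=2)(x))
params = forward.init(jax.random.PRNGKey(0), jnp.ones((5, 8, 8, 2)))
batched = forward.apply(params, None, jnp.ones((5, 8, 8, 2)))  # (5, 16, 16, 4)
unbatched = forward.apply(params, None, jnp.ones((8, 8, 2)))   # (16, 16, 4)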
def __call__(self, x: JaxArray) -> JaxArray:
    """Returns the results of applying the transposed convolution to input x."""
    y = lax.conv_transpose(x, self.w.value, self.strides, self.padding,
                           rhs_dilation=self.dilations,
                           dimension_numbers=('NCHW', 'HWIO', 'NCHW'),
                           transpose_kernel=True)
    if self.b:
        y += self.b.value
    return y
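# A minimal sketch of the layout the Objax-style module above assumes: the
# stored kernel uses the *forward* convolution's layout, (kH, kW, C_out, C_in)
# under 'HWIO', and transpose_kernel=True flips its spatial axes and swaps the
# channel axes, so the op matches the gradient of that forward convolution.
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 3, 8, 8))        # NCHW input with C_in = 3
w = jnp.ones((3, 3, 16, 3))       # (kH, kW, C_out, C_in)
y = lax.conv_transpose(x, w, (2, 2), 'SAME',
                       dimension_numbers=('NCHW', 'HWIO', 'NCHW'),
                       transpose_kernel=True)
assert y.shape == (1, 16, 16, 16)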
def conv_transpose(scope,
                   inputs,
                   features,
                   kernel_size,
                   strides=None,
                   padding='SAME',
                   kernel_dilation=None,
                   bias=True,
                   dtype=jnp.float32,
                   precision=None,
                   kernel_init=default_kernel_init,
                   bias_init=initializers.zeros):
    """Applies a transposed convolution to the inputs.

    Behaviour mirrors that of `jax.lax.conv_transpose`.

    Args:
      scope: functional scope.
      inputs: input data with dimensions (batch, spatial_dims..., features).
      features: number of convolution filters.
      kernel_size: shape of the convolutional kernel.
      strides: a sequence of `n` integers, representing the inter-window strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
      kernel_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel. Convolution with kernel dilation is also known as 'atrous
        convolution'.
      bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).
      precision: numerical precision of the computation, see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.

    Returns:
      The convolved data.
    """
    inputs = jnp.asarray(inputs, dtype)
    strides = strides or (1,) * (inputs.ndim - 2)
    in_features = inputs.shape[-1]
    kernel_shape = kernel_size + (in_features, features)
    kernel = scope.param('kernel', kernel_init, kernel_shape)
    kernel = jnp.asarray(kernel, dtype)
    y = lax.conv_transpose(inputs,
                           kernel,
                           strides,
                           padding,
                           rhs_dilation=kernel_dilation,
                           precision=precision)
    if bias:
        bias = scope.param('bias', bias_init, (features,))
        bias = jnp.asarray(bias, dtype)
        y = y + bias
    return y
def _call_batched(self, x):
    params, info = self.params, self.info
    result = lax.conv_transpose(x, params.kernel, info.strides, info.padding,
                                dimension_numbers=DIMENSION_NUMBERS)
    if info.use_bias:
        result += params.bias
    return result
def convolution_transpose_op(self, params, inputs, **kwargs):
    output = lax.conv_transpose(
        inputs, params[0], self.strides, self.padding, dimension_numbers=self.dn
    )
    if self.use_bias:
        output = jnp.add(output, params[1])
    if self.activation:
        output = self.activation(output)
    return output
def call(self, inputs: np.ndarray) -> np.ndarray:
    """
    Computes the transposed convolution of the input.

    Args:
        inputs: A rank-N+2 array with shape ``[N, spatial_dims, C]``.

    Returns:
        A rank-N+2 array with shape ``[N, spatial_dims, output_channels]``.
    """
    required_rank = self.num_spatial_dims + 2
    if inputs.ndim != required_rank:
        raise ValueError(
            f"Input to ConvND needs to have rank {required_rank}, "
            f"but input has shape {inputs.shape}."
        )

    input_channels = inputs.shape[self.channel_index]
    w_shape = self.kernel_shape + (self.output_channels, input_channels)

    if self.mask is not None and self.mask.shape != w_shape:
        raise ValueError(
            "Mask needs to have the same shape as weights. "
            f"Shapes are: {self.mask.shape}, {w_shape}"
        )

    w_init = self.w_init
    if w_init is None:
        fan_in_shape = self.kernel_shape + (input_channels,)
        stddev = 1.0 / np.sqrt(np.prod(fan_in_shape))
        w_init = initializers.TruncatedNormal(stddev=stddev)
    w = hooks.get_parameter("w", w_shape, inputs.dtype, initializer=w_init)

    if self.mask is not None:
        w = w * self.mask

    out = lax.conv_transpose(
        inputs,
        w,
        strides=self.stride,
        padding=self.padding,
        dimension_numbers=self.dimension_numbers,
    )

    if self.with_bias:
        if self.channel_index == -1:
            bias_shape = (self.output_channels,)
        else:
            bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
        b = hooks.get_parameter("b", bias_shape, initializer=self.b_init)
        b = jnp.broadcast_to(b, out.shape)
        out = out + b

    return out
def forward(self, x):
    w, b = self.weights
    x_shape = list(x.shape)
    if len(x_shape) > 4:
        self._check_nhwc()
        new_batch_dim = functools.reduce(operator.mul, x.shape[:-3])
        x = jnp.reshape(x, [new_batch_dim] + list(x.shape[-3:]))
    res = lax.conv_transpose(x, w, self._strides, self._padding,
                             self._rhs_dilation, self._dimension_numbers) + b
    if len(x_shape) > 4:
        res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
    return res
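# Hedged sketch of the reshape trick above: lax.conv_transpose accepts a single
# batch dimension, so any extra leading dimensions are folded into the batch
# axis before the call and restored afterwards (shapes below are illustrative).
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 5, 8, 8, 3))                    # (extra, batch, H, W, C)
w = jnp.ones((3, 3, 3, 4))                       # HWIO kernel
flat = jnp.reshape(x, (-1,) + x.shape[-3:])      # (10, 8, 8, 3)
y = lax.conv_transpose(flat, w, (2, 2), 'SAME')  # (10, 16, 16, 4)
y = jnp.reshape(y, x.shape[:-3] + y.shape[-3:])  # (2, 5, 16, 16, 4)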
def conv_transpose(inputs):
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == 'O' else
                    inputs.shape[lhs_spec.index('C')] if c == 'I' else
                    next(filter_shape_iter) for c in rhs_spec]
    bias_shape = tuple(
        itertools.dropwhile(lambda x: x == 1,
                            [out_chan if c == 'C' else 1 for c in out_spec]))
    kernel = parameter(kernel_shape, kernel_init, 'kernel')
    bias = parameter(bias_shape, bias_init, 'bias')
    return lax.conv_transpose(inputs, kernel, strides, padding,
                              dimension_numbers=dimension_numbers) + bias
def _upsample_nearest_neighbour(inputs_nchw):
    # nearest neighbour upsampling on NCHW input
    _n, input_c, h, w = inputs_nchw.shape
    flat_inputs_shape = (-1, h, w, 1)
    flat_inputs = jnp.reshape(inputs_nchw, flat_inputs_shape)
    resize_kernel = jnp.ones((2, 2, 1, 1))
    strides = (2, 2)
    flat_outputs = conv_transpose(flat_inputs, resize_kernel, strides,
                                  padding="SAME")
    outputs_nchw_shape = (-1, input_c, 2 * h, 2 * w)
    outputs_nchw = jnp.reshape(flat_outputs, outputs_nchw_shape)
    return outputs_nchw
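# Hedged usage sketch, assuming conv_transpose above is jax.lax.conv_transpose:
# every pixel of an NCHW input becomes a 2x2 block, doubling both spatial dims.
import jax.numpy as jnp

x = jnp.arange(2 * 3 * 4 * 4, dtype=jnp.float32).reshape(2, 3, 4, 4)
y = _upsample_nearest_neighbour(x)
assert y.shape == (2, 3, 8, 8)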
def __call__(self, inputs: Array) -> Array:
    """Applies a transposed convolution to the inputs.

    Behaviour mirrors that of `jax.lax.conv_transpose`.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
      The convolved data.
    """
    dtype = nkjax.maybe_promote_to_complex(self.dtype, inputs.dtype)
    inputs = jnp.asarray(inputs, dtype)

    if isinstance(self.kernel_size, int):
        kernel_size = (self.kernel_size,)
    else:
        kernel_size = self.kernel_size

    is_single_input = False
    if inputs.ndim == len(kernel_size) + 1:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    strides = self.strides or (1,) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    kernel_shape = kernel_size + (in_features, self.features)
    kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
    kernel = jnp.asarray(kernel, dtype)

    y = lax.conv_transpose(
        inputs,
        kernel,
        strides,
        self.padding,
        rhs_dilation=self.kernel_dilation,
        precision=self.precision,
    )

    if is_single_input:
        y = jnp.squeeze(y, axis=0)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    return y
def __call__(self, inputs):
    """Connects Conv2DTranspose layer.

    Args:
      inputs: A rank-N+2 array with shape [N, spatial_dims, C].

    Returns:
      A rank-N+2 array with shape [N, spatial_dims, output_channels].
    """
    if len(inputs.shape) != self._num_spatial_dims + 2:
        raise ValueError(
            "Input to ConvND needs to have rank {}, but input "
            "has shape {}.".format(self._num_spatial_dims + 2, inputs.shape))

    weight_shape = self._kernel_shape + (inputs.shape[self._channel_index],
                                         self._output_channels)
    fan_in = np.prod(weight_shape[:-1])
    stddev = 1. / np.sqrt(fan_in)
    w_init = self._w_init or initializers.TruncatedNormal(stddev=stddev)
    w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)

    if self._mask is not None:
        if self._mask.shape != w.shape:
            raise ValueError(
                "Mask needs to have the same shape as weights. "
                "Shapes are: {}, {}".format(self._mask.shape, w.shape))
        w *= self._mask

    result = lax.conv_transpose(inputs, w, self._stride, self._padding,
                                dimension_numbers=self._dn)

    if self._with_bias:
        if self._channel_index == -1:
            bias_shape = (self._output_channels,)
        else:
            bias_shape = (self._output_channels,) + (1,) * self._num_spatial_dims
        b = base.get_parameter("b", bias_shape, init=self._b_init)
        result = result + b

    return result
def batch_convolve_transpose(
    input,
    filter,
    strides=1,
    padding="VALID",
    input_format=None,
    filter_format=None,
    output_format=None,
    input_dilation=None,
    filter_dilation=None,
    transpose_kernel=False,
):
    """General n-dimensional transposed convolution operator, with optional dilation.

    Wraps Jax's conv_transpose function, and thus also XLA's
    `Conv <https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
    operator.

    Args:
        input (Tensor): a rank `n+2` dimensional input array.
        filter (Tensor): a rank `n+2` dimensional array of kernel weights.
        strides (int, sequence of int, optional): a (sequence of) `n` integers,
            representing the inter-window strides. If a scalar is given, it is
            used `n` times. Defaults to `1`.
        padding (sequence of couple, `'SAME'`, `'VALID'`, optional): a sequence
            of `n` `(low, high)` integer pairs that give the padding to apply
            before and after each spatial dimension. For `'VALID'`, those are
            `0`. For `'SAME'`, they are the `input length - filter length + 1`
            for each dim. Defaults to `'VALID'`.
        input_format (`None` or str, optional): a string of same length as the
            number of dimensions in `input` which specifies their role (see
            below). Defaults to `'NCW'` for 1d conv, `'NCHW'` for 2d conv, and
            `'NCDHW'` for 3d conv.
        filter_format (`None` or str, optional): same as `input_format`, but
            describing the dimensions of `filter`.
        output_format (`None` or str, optional): same as `input_format`, but
            describing the dimensions of the output.
        input_dilation (`None`, int or sequence of int, optional): the dilation
            factor to apply in each spatial dimension of `input`. Input
            dilation is also known as transposed convolution as it allows to
            increase the output spatial dimension by inserting in the input any
            number of `0`s between each spatial value.
        filter_dilation (`None`, int or sequence of int): the dilation factor
            to apply in each spatial dimension of `filter`. Filter dilation is
            also known as atrous convolution as it corresponds to inserting any
            number of `0`s in between the filter values, similar to performing
            the non-dilated filter convolution with a subsampled version of the
            input across the spatial dimensions.
        transpose_kernel (bool, optional): if `True`, flips the spatial axes
            and swaps the input/output channel axes of the kernel.

    Returns:
        Tensor: An array containing the convolution result.

    Format of `input`, `filter` and `output`:
        For example, to indicate dimension numbers consistent with the `conv`
        function with two spatial dimensions, one could use
        `('NCHW', 'OIHW', 'NCHW')`. As another example, to indicate dimension
        numbers consistent with the TensorFlow Conv2D operation, one could use
        `('NHWC', 'HWIO', 'NHWC')`. When using the latter form of convolution
        dimension specification, window strides are associated with spatial
        dimension character labels according to the order in which the labels
        appear in the `rhs_spec` string, so that `window_strides[0]` is matched
        with the dimension corresponding to the first character appearing in
        `rhs_spec` that is not `'I'` or `'O'`.
    """
    # setting up the strides
    if numpy.isscalar(strides):
        strides = (strides,) * (input.ndim - 2)
    elif len(strides) != (input.ndim - 2):
        msg = "given strides: {} should match the ".format(
            strides
        ) + "number of spatial dim. in input: {}".format(input.ndim - 2)
        raise ValueError(msg)

    # setting up the padding
    if type(padding) != str:
        if len(padding) != (input.ndim - 2):
            msg = "given padding: {} should match the ".format(
                padding
            ) + "number of spatial dim. in input: {}".format(input.ndim - 2)
            raise ValueError(msg)

    # setting up the filter_format
    if filter_format is None:
        if filter.ndim == 3:
            filter_format = "OIW"
        elif filter.ndim == 4:
            filter_format = "OIHW"
        elif filter.ndim == 5:
            filter_format = "OIDHW"
        else:
            msg = "filter_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(filter_format) != filter.ndim:
        msg = "given filter_format: {} should ".format(
            filter_format
        ) + "match the number of dimensions in filter: {}".format(filter.ndim)
        raise ValueError(msg)

    # setting up the input format
    if input_format is None:
        if len(filter.shape) == 3:
            input_format = "NCW"
        elif len(filter.shape) == 4:
            input_format = "NCHW"
        elif len(filter.shape) == 5:
            input_format = "NCDHW"
        else:
            msg = "input_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(input_format) != input.ndim:
        msg = "given input_format: {} should ".format(
            input_format
        ) + "match the number of dimensions in input: {}".format(input.ndim)
        raise ValueError(msg)

    # setting up the output format
    if output_format is None:
        if len(filter.shape) == 3:
            output_format = "NCW"
        elif len(filter.shape) == 4:
            output_format = "NCHW"
        elif len(filter.shape) == 5:
            output_format = "NCDHW"
        else:
            msg = "output_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(output_format) != input.ndim:
        msg = "given output_format: {} should ".format(
            output_format
        ) + "match the number of dimensions in output: {}".format(input.ndim)
        raise ValueError(msg)

    # setting up dilations (expand scalars to one entry per spatial dim)
    if numpy.isscalar(input_dilation):
        input_dilation = (input_dilation,) * (input.ndim - 2)
    if numpy.isscalar(filter_dilation):
        filter_dilation = (filter_dilation,) * (input.ndim - 2)

    specs = (input_format, filter_format, output_format)

    return jla.conv_transpose(
        lhs=input,
        rhs=filter,
        strides=strides,
        padding=padding,
        rhs_dilation=filter_dilation,
        dimension_numbers=specs,
        precision=None,
        transpose_kernel=transpose_kernel,
    )
def onnx_conv_transpose(x,
                        w,
                        b=None,
                        auto_pad='NOTSET',
                        dilations=None,
                        group=1,
                        kernel_shape=None,
                        output_padding=None,
                        output_shape=None,
                        pads=None,
                        strides=None,
                        **kwargs):
    kernel_shape = kernel_shape or w.shape
    spatial_size = w.ndim - 2
    strides = strides or [1] * spatial_size
    rhs_dilation = dilations or [1] * (w.ndim - 2)

    # pad
    if auto_pad == "NOTSET":
        if pads is None:
            pad_mode = 'VALID'
        elif pads == 'VALID':
            pad_mode = 'VALID'
        elif pads == [0, 0] * spatial_size:
            # all-zero pads are equivalent to no padding
            pad_mode = 'VALID'
        else:
            # ONNX stores pads as [x1_begin, x2_begin, ..., x1_end, x2_end, ...];
            # lax.conv_transpose expects (low, high) pairs per spatial dim.
            pad_mode = []
            pad_pairs = len(pads) // 2
            for idx in range(pad_pairs):
                pad_mode.append((pads[idx], pads[idx + pad_pairs]))
    elif auto_pad == "SAME_UPPER":
        pad_mode = "SAME"
    elif auto_pad == "VALID":
        pad_mode = "VALID"
    elif auto_pad == "SAME_LOWER":
        raise NotImplementedError("Conv with auto_pad `SAME_LOWER`")
    else:
        raise ValueError("Invalid auto_pad attribute: {}".format(auto_pad))

    if b is not None:
        # ONNX ConvTranspose weights are (C_in, C_out / group, kH, kW) and the
        # bias has C_out entries, so it broadcasts over the output channel axis.
        b = b.reshape([1, w.shape[1]] + [1] * spatial_size)
    else:
        b = 0

    res = lax.conv_transpose(
        lhs=x,
        rhs=w,
        strides=strides,
        padding=pad_mode,
        rhs_dilation=rhs_dilation,
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
        transpose_kernel=True,
        precision=None,
    )

    # change output_padding order
    # TODO
    output_padding = ([0, 0, 0, 0] if output_padding is None else
                      [0, 0, output_padding[0], output_padding[1]])

    if output_shape is not None:
        need_append_output_pad = True
        for spatial_idx in range(spatial_size):
            total_pad = (output_padding[spatial_idx] +
                         output_padding[spatial_size + spatial_idx])
            shape_diff = (output_shape[spatial_idx] -
                          res.shape[spatial_idx + 2] - total_pad)
            if shape_diff == 0:
                need_append_output_pad = False
            else:
                need_append_output_pad = True
        if need_append_output_pad:
            for spatial_idx in range(spatial_size):
                shape_diff = (output_shape[spatial_idx] -
                              res.shape[spatial_idx + 2])
                if shape_diff < 0:
                    raise ValueError(
                        'output_shape cannot be smaller than the '
                        'lax.conv_transpose output shape')
                else:
                    output_padding[spatial_idx + spatial_size] += shape_diff

    if output_padding != [0, 0, 0, 0]:
        res = pad_helper(res, output_padding)

    return [res + b]
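# Hedged usage sketch: ONNX ConvTranspose stores weights as
# (C_in, C_out / group, kH, kW), which is why the snippet pairs 'OIHW' with
# transpose_kernel=True. With no pads and no auto_pad, padding falls back to
# 'VALID'. The attribute values below are illustrative.
import jax.numpy as jnp

x = jnp.ones((1, 3, 4, 4))        # NCHW input
w = jnp.ones((3, 8, 2, 2))        # (C_in, C_out, kH, kW)
(y,) = onnx_conv_transpose(x, w, strides=[2, 2])
assert y.shape == (1, 8, 8, 8)    # VALID: out = in * stride + max(k - stride, 0)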