Example No. 1
  def __call__(self, inputs: Array) -> Array:
    """Applies a transposed convolution to the inputs. Behaviour mirrors of
    `jax.lax.conv_transpose`.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
      The convolved data.
    """
    inputs = jnp.asarray(inputs, self.dtype)
    strides = self.strides or (1,) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    kernel_shape = self.kernel_size + (in_features, self.features)
    kernel = self.param('kernel', self.kernel_init, kernel_shape)
    kernel = jnp.asarray(kernel, self.dtype)

    y = lax.conv_transpose(inputs,
                           kernel,
                           strides,
                           self.padding,
                           rhs_dilation=self.kernel_dilation,
                           precision=self.precision)

    if self.use_bias:
      bias = self.param('bias', self.bias_init, (self.features,))
      bias = jnp.asarray(bias, self.dtype)
      y = y + bias
    return y
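Since no dimension_numbers are passed here, lax.conv_transpose falls back to the TensorFlow-style convention (NHWC inputs, HWIO kernels for 2-D), which is why the kernel above is built as kernel_size + (in_features, features). A minimal standalone shape check (array sizes are arbitrary illustrations, not taken from the snippet):

import jax.numpy as jnp
from jax import lax, random

x = random.normal(random.PRNGKey(0), (8, 16, 16, 3))   # (batch, H, W, in_features)
k = random.normal(random.PRNGKey(1), (3, 3, 3, 32))    # (kH, kW, in_features, features)
y = lax.conv_transpose(x, k, (2, 2), 'SAME')
print(y.shape)  # (8, 32, 32, 32): 'SAME' scales each spatial dim by its stride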
Example No. 2
 def apply_fun(params, inputs, **kwargs):
     W = params
     return lax.conv_transpose(inputs,
                               W,
                               strides,
                               padding,
                               dimension_numbers=dimension_numbers)
Example No. 3
    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Computes the transposed convolution of the input.

        Args:
          inputs: An array of shape ``[spatial_dims, C]`` and rank-N+1 if unbatched,
            or an array of shape ``[N, spatial_dims, C]`` and rank-N+2 if batched.

        Returns:
          An array of shape ``[spatial_dims, output_channels]`` and rank-N+1 if
            unbatched, or an array of shape ``[N, spatial_dims, output_channels]``
            and rank-N+2 if batched.
        """
        unbatched_rank = self.num_spatial_dims + 1
        allowed_ranks = [unbatched_rank, unbatched_rank + 1]
        if inputs.ndim not in allowed_ranks:
            raise ValueError(
                f"Input to ConvNDTranspose needs to have rank in "
                f"{allowed_ranks}, but input has shape {inputs.shape}.")

        unbatched = inputs.ndim == unbatched_rank
        if unbatched:
            inputs = jnp.expand_dims(inputs, axis=0)

        input_channels = inputs.shape[self.channel_index]
        w_shape = self.kernel_shape + (self.output_channels, input_channels)

        if self.mask is not None and self.mask.shape != w_shape:
            raise ValueError("Mask needs to have the same shape as weights. "
                             f"Shapes are: {self.mask.shape}, {w_shape}")

        w_init = self.w_init
        if w_init is None:
            fan_in_shape = self.kernel_shape + (input_channels, )
            stddev = 1. / np.sqrt(np.prod(fan_in_shape))
            w_init = hk.initializers.TruncatedNormal(stddev=stddev)
        w = hk.get_parameter("w", w_shape, inputs.dtype, init=w_init)

        if self.mask is not None:
            w = w * self.mask

        out = lax.conv_transpose(inputs,
                                 w,
                                 strides=self.stride,
                                 padding=self.padding,
                                 dimension_numbers=self.dimension_numbers)

        if self.with_bias:
            if self.channel_index == -1:
                bias_shape = (self.output_channels, )
            else:
                bias_shape = (
                    self.output_channels, ) + (1, ) * self.num_spatial_dims
            b = hk.get_parameter("b", bias_shape, init=self.b_init)
            b = jnp.broadcast_to(b, out.shape)
            out = out + b

        if unbatched:
            out = jnp.squeeze(out, axis=0)
        return out
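The __call__ above expects to run inside a Haiku transform (hk.get_parameter requires one). A hedged usage sketch; since the surrounding class is not shown, hk.Conv2DTranspose stands in for it here:

import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
    # hypothetical stand-in for the module defined above
    net = hk.Conv2DTranspose(output_channels=16, kernel_shape=3, stride=2, padding='SAME')
    return net(x)

f = hk.without_apply_rng(hk.transform(forward))
x = jnp.ones((4, 8, 8, 3))                  # NHWC batch
params = f.init(jax.random.PRNGKey(0), x)
y = f.apply(params, x)                      # y.shape == (4, 16, 16, 16)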
Example No. 4
 def __call__(self, x: JaxArray) -> JaxArray:
     """Returns the results of applying the transposed convolution to input x."""
     y = lax.conv_transpose(x, self.w.value, self.strides, self.padding,
                            rhs_dilation=self.dilations,
                            dimension_numbers=('NCHW', 'HWIO', 'NCHW'), transpose_kernel=True)
     if self.b:
         y += self.b.value
     return y
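Unlike the previous examples, this one passes transpose_kernel=True, which (per the jax.lax.conv_transpose documentation) flips the kernel's spatial axes and swaps its input/output channel axes, matching the gradient-style transposed convolution. A rough shape sketch under that convention (sizes are arbitrary, not the objax layout itself):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 4, 8, 8))     # NCHW, in_channels = 4
w = jnp.ones((3, 3, 6, 4))     # laid out (kH, kW, out_channels, in_channels); the
                               # I/O swap from transpose_kernel=True makes this valid
y = lax.conv_transpose(x, w, (2, 2), 'SAME',
                       dimension_numbers=('NCHW', 'HWIO', 'NCHW'),
                       transpose_kernel=True)
print(y.shape)                 # (2, 6, 16, 16)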
Example No. 5
def conv_transpose(scope,
                   inputs,
                   features,
                   kernel_size,
                   strides=None,
                   padding='SAME',
                   kernel_dilation=None,
                   bias=True,
                   dtype=jnp.float32,
                   precision=None,
                   kernel_init=default_kernel_init,
                   bias_init=initializers.zeros):
    """Applies a transposed convolution to the inputs. Behaviour mirrors that of
  `jax.lax.conv_transpose`.

  Args:
    scope: functional scope.
    inputs: input data with dimensions (batch, spatial_dims..., features).
    features: number of convolution filters.
    kernel_size: shape of the convolutional kernel.
    strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
      of `n` `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    kernel_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of the convolution
      kernel. Convolution with kernel dilation is also known as 'atrous
      convolution'.
    bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the convolutional kernel.
    bias_init: initializer for the bias.
  Returns:
    The convolved data.
  """
    inputs = jnp.asarray(inputs, dtype)
    strides = strides or (1, ) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    kernel_shape = kernel_size + (in_features, features)
    kernel = scope.param('kernel', kernel_init, kernel_shape)
    kernel = jnp.asarray(kernel, dtype)

    y = lax.conv_transpose(inputs,
                           kernel,
                           strides,
                           padding,
                           rhs_dilation=kernel_dilation,
                           precision=precision)

    if bias:
        bias = scope.param('bias', bias_init, (features, ))
        bias = jnp.asarray(bias, dtype)
        y = y + bias
    return y
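For reference, the three padding forms the docstring describes behave like this when fed straight to lax.conv_transpose (a minimal 1-D sketch, independent of the flax scope machinery; sizes are arbitrary):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 10, 8))     # (batch, width, in_features)
k = jnp.ones((4, 8, 16))     # (kW, in_features, features)
same = lax.conv_transpose(x, k, (2,), 'SAME')
valid = lax.conv_transpose(x, k, (2,), 'VALID')
explicit = lax.conv_transpose(x, k, (2,), [(1, 1)])
print(same.shape)                     # (1, 20, 16): width * stride
print(valid.shape, explicit.shape)    # 'VALID' and explicit pads change the width accordingly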
Example No. 6
 def _call_batched(self, x):
     params, info = self.params, self.info
     result = lax.conv_transpose(x,
                                 params.kernel,
                                 info.strides,
                                 info.padding,
                                 dimension_numbers=DIMENSION_NUMBERS)
     if info.use_bias:
         result += params.bias
     return result
Example No. 7
    def convolution_transpose_op(self, params, inputs, **kwargs):
        output = lax.conv_transpose(
            inputs, params[0], self.strides, self.padding, dimension_numbers=self.dn
        )
        if self.use_bias:
            output = jnp.add(output, params[1])

        if self.activation:
            output = self.activation(output)
        return output
Example No. 8
File: conv.py Project: chjort/elegy
    def call(self, inputs: np.ndarray) -> np.ndarray:
        """
        Computes the transposed convolution of the input.

        Args:
            inputs: A rank-N+2 array with shape ``[N, spatial_dims, C]``.

        Returns:
            A rank-N+2 array with shape ``[N, spatial_dims, output_channels]``.
        """
        required_rank = self.num_spatial_dims + 2
        if inputs.ndim != required_rank:
            raise ValueError(
                f"Input to ConvND needs to have rank {required_rank}, "
                f"but input has shape {inputs.shape}."
            )

        input_channels = inputs.shape[self.channel_index]
        w_shape = self.kernel_shape + (self.output_channels, input_channels)

        if self.mask is not None and self.mask.shape != w_shape:
            raise ValueError(
                "Mask needs to have the same shape as weights. "
                f"Shapes are: {self.mask.shape}, {w_shape}"
            )

        w_init = self.w_init
        if w_init is None:
            fan_in_shape = self.kernel_shape + (input_channels,)
            stddev = 1.0 / np.sqrt(np.prod(fan_in_shape))
            w_init = initializers.TruncatedNormal(stddev=stddev)
        w = hooks.get_parameter("w", w_shape, inputs.dtype, initializer=w_init)

        if self.mask is not None:
            w = w * self.mask

        out = lax.conv_transpose(
            inputs,
            w,
            strides=self.stride,
            padding=self.padding,
            dimension_numbers=self.dimension_numbers,
        )

        if self.with_bias:
            if self.channel_index == -1:
                bias_shape = (self.output_channels,)
            else:
                bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
            b = hooks.get_parameter("b", bias_shape, initializer=self.b_init)
            b = jnp.broadcast_to(b, out.shape)
            out = out + b

        return out
Example No. 9
 def forward(self, x):
     w, b = self.weights
     x_shape = list(x.shape)
     if len(x_shape) > 4:
         self._check_nhwc()
         new_batch_dim = functools.reduce(operator.mul, x.shape[:-3])
         x = jnp.reshape(x, [new_batch_dim] + list(x.shape[-3:]))
     res = lax.conv_transpose(x, w, self._strides, self._padding,
                              self._rhs_dilation,
                              self._dimension_numbers) + b
     if len(x_shape) > 4:
         res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
     return res
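The reshape trick above folds any extra leading batch axes into a single one (lax.conv_transpose expects one batch dimension) and unfolds them again afterwards. A standalone sketch of the same pattern with made-up sizes:

import functools
import operator
import jax.numpy as jnp
from jax import lax

x = jnp.ones((5, 3, 8, 8, 4))                        # e.g. (time, batch, H, W, C)
w = jnp.ones((3, 3, 4, 2))                           # HWIO kernel
lead = functools.reduce(operator.mul, x.shape[:-3])  # 5 * 3 = 15
y = lax.conv_transpose(jnp.reshape(x, (lead,) + x.shape[-3:]), w, (2, 2), 'SAME')
y = jnp.reshape(y, x.shape[:-3] + y.shape[-3:])      # back to (5, 3, 16, 16, 2)
print(y.shape)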
Example No. 10
    def conv_transpose(inputs):
        filter_shape_iter = iter(filter_shape)

        kernel_shape = [out_chan if c == 'O' else
                        inputs.shape[lhs_spec.index('C')] if c == 'I' else
                        next(filter_shape_iter) for c in rhs_spec]

        bias_shape = tuple(
            itertools.dropwhile(lambda x: x == 1, [out_chan if c == 'C' else 1 for c in out_spec]))

        kernel = parameter(kernel_shape, kernel_init, 'kernel')
        bias = parameter(bias_shape, bias_init, 'bias')
        return lax.conv_transpose(inputs, kernel, strides, padding,
                                  dimension_numbers=dimension_numbers) + bias
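To see what those comprehensions produce, here they are evaluated for a concrete (assumed) spec: rhs_spec='HWIO' with out_spec='NHWC', out_chan=16, filter_shape=(3, 3) and a (8, 32, 32, 4) input:

import itertools

out_chan, filter_shape = 16, (3, 3)
lhs_spec, rhs_spec, out_spec = 'NHWC', 'HWIO', 'NHWC'
input_shape = (8, 32, 32, 4)

filter_shape_iter = iter(filter_shape)
kernel_shape = [out_chan if c == 'O' else
                input_shape[lhs_spec.index('C')] if c == 'I' else
                next(filter_shape_iter) for c in rhs_spec]
bias_shape = tuple(
    itertools.dropwhile(lambda x: x == 1, [out_chan if c == 'C' else 1 for c in out_spec]))
print(kernel_shape, bias_shape)   # [3, 3, 4, 16] (16,)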
Example No. 11
def _upsample_nearest_neighbour(inputs_nchw):
    # nearest neighbour upsampling on NCHW input
    _n, input_c, h, w = inputs_nchw.shape
    flat_inputs_shape = (-1, h, w, 1)
    flat_inputs = jnp.reshape(inputs_nchw, flat_inputs_shape)
    resize_kernel = jnp.ones((2, 2, 1, 1))
    strides = (2, 2)
    flat_outputs = conv_transpose(flat_inputs,
                                  resize_kernel,
                                  strides,
                                  padding="SAME")
    outputs_nchw_shape = (-1, input_c, 2 * h, 2 * w)
    outputs_nchw = jnp.reshape(flat_outputs, outputs_nchw_shape)
    return outputs_nchw
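A quick numeric check of the resize-kernel trick (assuming conv_transpose here is jax.lax.conv_transpose with its default NHWC/HWIO dimension numbers): a 2x2 ones kernel with stride 2 and 'SAME' padding copies each input value into a 2x2 output block, i.e. nearest-neighbour upsampling.

import jax.numpy as jnp
from jax import lax

x = jnp.array([[1., 2.], [3., 4.]]).reshape(1, 2, 2, 1)   # single-channel NHWC image
k = jnp.ones((2, 2, 1, 1))
y = lax.conv_transpose(x, k, (2, 2), 'SAME')
print(y[0, :, :, 0])
# [[1. 1. 2. 2.]
#  [1. 1. 2. 2.]
#  [3. 3. 4. 4.]
#  [3. 3. 4. 4.]]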
Example No. 12
    def __call__(self, inputs: Array) -> Array:
        """Applies a transposed convolution to the inputs. Behaviour mirrors of
        `jax.lax.conv_transpose`.
        Args:
          inputs: input data with dimensions (batch, spatial_dims..., features).
        Returns:
          The convolved data.
        """
        dtype = nkjax.maybe_promote_to_complex(self.dtype, inputs.dtype)

        inputs = jnp.asarray(inputs, dtype)

        if isinstance(self.kernel_size, int):
            kernel_size = (self.kernel_size, )
        else:
            kernel_size = self.kernel_size

        is_single_input = False
        if inputs.ndim == len(kernel_size) + 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        strides = self.strides or (1, ) * (inputs.ndim - 2)

        in_features = inputs.shape[-1]
        kernel_shape = kernel_size + (in_features, self.features)
        kernel = self.param("kernel", self.kernel_init, kernel_shape,
                            self.dtype)
        kernel = jnp.asarray(kernel, dtype)

        y = lax.conv_transpose(
            inputs,
            kernel,
            strides,
            self.padding,
            rhs_dilation=self.kernel_dilation,
            precision=self.precision,
        )

        if is_single_input:
            y = jnp.squeeze(y, axis=0)
        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias
        return y
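The nkjax.maybe_promote_to_complex helper is not shown here; conceptually it plays the same role as NumPy-style dtype promotion, upgrading the computation dtype to a complex type whenever either the declared parameter dtype or the input dtype is complex. A rough stand-in using jnp.promote_types (not the NetKet implementation):

import jax.numpy as jnp

# dtype promotion only; no array computation involved
print(jnp.promote_types(jnp.float32, jnp.complex64))   # complex64
print(jnp.promote_types(jnp.float32, jnp.float32))     # float32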
Example No. 13
    def __call__(self, inputs):
        """Connects Conv2DTranspose layer.

        Args:
          inputs: A rank-N+2 array with shape [N, spatial_dims, C].

        Returns:
          A rank-N+2 array with shape [N, spatial_dims, output_channels].
        """
        if len(inputs.shape) != self._num_spatial_dims + 2:
            raise ValueError(
                "Input to ConvND needs to have rank {}, but input "
                "has shape {}.".format(self._num_spatial_dims + 2,
                                       inputs.shape))
        weight_shape = self._kernel_shape + (inputs.shape[self._channel_index],
                                             self._output_channels)

        fan_in = np.prod(weight_shape[:-1])
        stddev = 1. / np.sqrt(fan_in)
        w_init = self._w_init or initializers.TruncatedNormal(stddev=stddev)
        w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)

        if self._mask is not None:
            if self._mask.shape != w.shape:
                raise ValueError(
                    "Mask needs to have the same shape as weights. "
                    "Shapes are: {}, {}".format(self._mask.shape, w.shape))
            w *= self._mask

        result = lax.conv_transpose(inputs,
                                    w,
                                    self._stride,
                                    self._padding,
                                    dimension_numbers=self._dn)
        if self._with_bias:
            if self._channel_index == -1:
                bias_shape = (self._output_channels, )
            else:
                bias_shape = (
                    self._output_channels, ) + (1, ) * self._num_spatial_dims
            b = base.get_parameter("b", bias_shape, init=self._b_init)
            result = result + b
        return result
Example No. 14
def batch_convolve_transpose(
    input,
    filter,
    strides=1,
    padding="VALID",
    input_format=None,
    filter_format=None,
    output_format=None,
    input_dilation=None,
    filter_dilation=None,
    transpose_kernel=False,
):
    """General n-dimensional convolution operator, with optional dilation.

    Wraps Jax's conv_general_dilated functin, and thus also the XLA's `Conv
    <https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
    operator.

    Args:
        input (Tensor): a rank `n+2` dimensional input array.
        filter (Tensor): a rank `n+2` dimensional array of kernel weights.
        strides (int, sequence of int, optional): a (sequence) of `n` integers,
            representing the inter-window strides. If a scalar is given, it is
            used `n` times. Defaults to `1`.
        padding (sequence of pairs, `'SAME'`, `'VALID'`, optional): a sequence of
            `n` `(low, high)` integer pairs that give the padding to apply
            before and after each spatial dimension. For `'VALID'`, those are
            `0`. For `'SAME'`, they are chosen so that the output spatial shape
            equals the input shape times the stride (for a transposed
            convolution). Defaults to `'VALID'`.
        input_format (`None` or str, optional): a string of same length as the
            number of dimensions in `input` which specify their role
            (see below). Defaults to `'NCW'` for 1d conv, `'NCHW'` for 2d conv,
            and `'NCDHW'` for 3d conv.
        input_dilation (`None`, int or sequence of int, optional): giving the
            dilation factor to apply in each spatial dimension of `input`.
            Input dilation is also known as transposed convolution as it allows
            one to increase the output spatial dimension by inserting into the
            input any number of `0`s between each spatial value.
        filter_dilation (`None`, int or sequence of int): giving the dilation
            factor to apply in each spatial dimension of `filter`. Filter
            dilation is also known as atrous convolution as it corresponds to
            inserting any number of `0`s in between the filter values, similar
            to performing the non-dilated filter convolution with a subsample
            version of the input across the spatial dimensions.

    Returns:
        Tensor: An array containing the convolution result.

    Format of `input`, `filter` and `output`:
    For example, to indicate dimension numbers consistent with the `conv` function
    with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
    another example, to indicate dimension numbers consistent with the TensorFlow
    Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
    latter form of convolution dimension specification, window strides are
    associated with spatial dimension character labels according to the order in
    which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
    is matched with the dimension corresponding to the first character
    appearing in rhs_spec that is not `'I'` or `'O'`.
    :param filter_format: a string describing the role of each dimension of
        `filter` (e.g. `'OIHW'`); defaults are analogous to `input_format`.
    :param output_format: a string describing the role of each dimension of
        the output; defaults are analogous to `input_format`.
    :param transpose_kernel: if `True`, flip the spatial axes and swap the
        input/output channel axes of the kernel before convolving.
    """

    # setting up the strides
    if numpy.isscalar(strides):
        strides = (strides,) * (input.ndim - 2)
    elif len(strides) != (input.ndim - 2):
        msg = "given strides: {} should match the number".format(
            strides
        ) + "of spatial dim. in input: {}".format(input.ndim - 2)
        raise ValueError(msg)

    # setting up the padding
    if type(padding) != str:
        if len(padding) != (input.ndim - 2):
            msg = "given padding: {} should match the ".format(
                padding
            ) + "number of spatial dim. in input: {}".format(input.ndim - 2)
            raise ValueError(msg)

    # setting up the filter_format
    if filter_format is None:
        if filter.ndim == 3:
            filter_format = "OIW"
        elif filter.ndim == 4:
            filter_format = "OIHW"
        elif filter.ndim == 5:
            filter_format = "OIDHW"
        else:
            msg = "filter_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(filter_format) != filter.ndim:
        msg = "given filter_format: {} should".format(
            len(filter_format)
        ) + "match the number of dimension in filter: {}".format(filter.ndim)
        raise ValueError(msg)

    # setting up the input format
    if input_format is None:
        if len(filter.shape) == 3:
            input_format = "NCW"
        elif len(filter.shape) == 4:
            input_format = "NCHW"
        elif len(filter.shape) == 5:
            input_format = "NCDHW"
        else:
            msg = "input_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(input_format) != input.ndim:
        msg = "given input_format: {} should".format(
            len(input_format)
        ) + "match the number of dimension in input: {}".format(input.ndim)
        raise ValueError(msg)

    # setting up the output format
    if output_format is None:
        if len(filter.shape) == 3:
            output_format = "NCW"
        elif len(filter.shape) == 4:
            output_format = "NCHW"
        elif len(filter.shape) == 5:
            output_format = "NCDHW"
        else:
            msg = "output_format should be given for >5 dimensions."
            raise ValueError(msg)
    elif len(output_format) != input.ndim:
        msg = "given output_format: {} should".format(
            len(output_format)
        ) + "match the number of dimension in output: {}".format(input.ndim)
        raise ValueError(msg)

    # setting up dilations
    if numpy.isscalar(input_dilation):
        input_dilation = (input_dilation,) * (input.ndim - 2)
    if numpy.isscalar(filter_dilation):
        filter_dilation = (filter_dilation,) * (input.ndim - 2)

    specs = (input_format, filter_format, output_format)
    return jla.conv_transpose(
        lhs=input,
        rhs=filter,
        strides=strides,
        padding=padding,
        rhs_dilation=filter_dilation,
        dimension_numbers=specs,
        precision=None,
        transpose_kernel=transpose_kernel,
    )
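The docstring's dimension-number strings are easiest to read off a concrete call. A minimal sketch passing explicit ('NCHW', 'OIHW', 'NCHW') specs straight to lax.conv_transpose (sizes are arbitrary):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 4, 8, 8))          # NCHW: in channels = 4
w = jnp.ones((6, 4, 3, 3))          # OIHW: (out_channels, in_channels, kH, kW)
y = lax.conv_transpose(x, w, (2, 2), 'SAME',
                       dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
print(y.shape)                      # (2, 6, 16, 16)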
Example No. 15
def onnx_conv_transpose(x,
                        w,
                        b=None,
                        auto_pad='NOTSET',
                        dilations=None,
                        group=1,
                        kernel_shape=None,
                        output_padding=None,
                        output_shape=None,
                        pads=None,
                        strides=None,
                        **kwargs):

    kernel_shape = kernel_shape or w.shape
    spatial_size = w.ndim - 2
    strides = strides or [1] * spatial_size
    rhs_dilation = dilations or [1] * (w.ndim - 2)

    # pad
    if auto_pad == "NOTSET":
        if pads is None:
            pad_mode = 'VALID'
        elif pads == 'VALID':
            pad_mode = 'VALID'
    elif pads == [0, 0] * spatial_size:
        pad_mode = [(0, 0)] * spatial_size  # explicit zero padding, i.e. 'VALID'
        else:
            pad_mode = []
            pad_pairs = len(pads) // 2
            for idx in range(pad_pairs):
                pad_mode.append((pads[idx], pads[idx + pad_pairs]))
    elif auto_pad == "SAME_UPPER":
        pad_mode = "SAME"
    elif auto_pad == "VALID":
        pad_mode = "VALID"
    elif auto_pad == "SAME_LOWER":
        raise NotImplementedError("Conv with auto_pad `SAME_LOWER`")
    else:
        raise ValueError("Invalid auto_pad attribute: {}".format(auto_pad))

    if b is not None:
        b = b.reshape([1, w.shape[0]] + [1] * spatial_size)
    else:
        b = 0

    res = lax.conv_transpose(
        lhs=x,
        rhs=w,
        strides=strides,
        padding=pad_mode,
        rhs_dilation=rhs_dilation,
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
        transpose_kernel=True,
        precision=None,
    )

    # change output_padding order
    # TODO
    output_padding = ([0, 0, 0, 0] if output_padding is None else
                      [0, 0, output_padding[0], output_padding[1]])
    if output_shape is not None:
        need_append_output_pad = False
        for spatial_idx in range(spatial_size):
            total_pad = (output_padding[spatial_idx] +
                         output_padding[spatial_size + spatial_idx])
            shape_diff = (output_shape[spatial_idx] -
                          res.shape[spatial_idx + 2] - total_pad)
            if shape_diff != 0:
                need_append_output_pad = True

        if need_append_output_pad:
            for spatial_idx in range(spatial_size):
                shape_diff = output_shape[spatial_idx] - res.shape[spatial_idx
                                                                   + 2]
                if shape_diff < 0:
                    raise ValueError(
                        'output_shape cannot be smaller than the '
                        'lax.conv_transpose output shape')
                else:
                    output_padding[spatial_idx + spatial_size] += shape_diff

    if output_padding != [0, 0, 0, 0]:
        res = pad_helper(res, output_padding)

    return [res + b]
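The output_shape/output_padding juggling above follows the ONNX ConvTranspose convention. For sanity-checking results, the spec's per-dimension output-size formula can be written as a small helper (a hypothetical utility, not part of the snippet above):

def onnx_conv_transpose_out_size(in_size, kernel, stride=1, dilation=1,
                                 pad_begin=0, pad_end=0, output_padding=0):
    # ONNX ConvTranspose: out = stride*(in-1) + output_padding
    #                          + ((kernel-1)*dilation + 1) - pad_begin - pad_end
    return (stride * (in_size - 1) + output_padding
            + ((kernel - 1) * dilation + 1) - pad_begin - pad_end)

print(onnx_conv_transpose_out_size(4, 3, stride=2))   # 9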