Example 1
    def __init__(self,
                 nin: int,
                 nout: int,
                 k: Union[Tuple[int, int], int],
                 strides: Union[Tuple[int, int], int] = 1,
                 dilations: Union[Tuple[int, int], int] = 1,
                 groups: int = 1,
                 padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.SAME,
                 use_bias: bool = True,
                 w_init: Callable = kaiming_normal):
        """Creates a Conv2D module instance.

        Args:
            nin: number of channels of the input tensor.
            nout: number of channels of the output tensor.
            k: size of the convolution kernel, either tuple (height, width) or single number if they're the same.
            strides: convolution strides, either tuple (stride_y, stride_x) or single number if they're the same.
            dilations: spacing between kernel points (also known as atrous convolution),
                       either tuple (dilation_y, dilation_x) or single number if they're the same.
            groups: number of groups into which the input and output channels are split. When groups > 1,
                    the convolution is applied separately to each group; nin and nout must both be divisible by groups.
            padding: padding of the input tensor, either ConvPadding.SAME, ConvPadding.VALID or numerical values.
            use_bias: if True then the convolution has a bias term.
            w_init: initializer for convolution kernel (a function that takes in an HWIO shape and returns a 4D matrix).
        """
        super().__init__()
        assert nin % groups == 0, 'nin should be divisible by groups'
        assert nout % groups == 0, 'nout should be divisible by groups'
        self.b = TrainVar(jn.zeros((nout, 1, 1))) if use_bias else None
        self.w = TrainVar(w_init((*util.to_tuple(k, 2), nin // groups, nout)))  # HWIO
        self.padding = util.to_padding(padding, 2)
        self.strides = util.to_tuple(strides, 2)
        self.dilations = util.to_tuple(dilations, 2)
        self.groups = groups
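
A minimal usage sketch for this Conv2D module, assuming the class and jax.numpy are importable as below and that the module is callable on NCHW tensors (the __call__ method is not shown here); shapes are illustrative:

import jax.numpy as jn

conv = Conv2D(nin=3, nout=16, k=3)   # 3x3 kernel, stride 1, SAME padding by default
x = jn.zeros((8, 3, 32, 32))         # (N, C, H, W) input batch
y = conv(x)                          # expected output shape: (8, 16, 32, 32)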
Example 2
def max_pool_2d(
    x: JaxArray,
    size: Union[Tuple[int, int], int] = 2,
    strides: Optional[Union[Tuple[int, int], int]] = None,
    padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.VALID
) -> JaxArray:
    """Applies max pooling using a square 2D filter.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of pooling filter.
        strides: stride step; defaults to size when None.
        padding: padding of the input tensor, either ConvPadding.SAME, ConvPadding.VALID or numerical values.

    Returns:
        output tensor of shape (N, C, H, W).
    """
    size = to_tuple(size, 2)
    strides = to_tuple(strides, 2) if strides else size
    padding = to_padding(padding, 2)
    if isinstance(padding, tuple):
        padding = ((0, 0), (0, 0)) + padding
    return lax.reduce_window(x,
                             -jn.inf,
                             lax.max, (1, 1) + size, (1, 1) + strides,
                             padding=padding)
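
A hedged usage sketch, assuming max_pool_2d and jax.numpy are importable as below:

import jax.numpy as jn

x = jn.arange(2 * 3 * 8 * 8, dtype=jn.float32).reshape((2, 3, 8, 8))
y = max_pool_2d(x, size=2)           # strides default to size, VALID padding
assert y.shape == (2, 3, 4, 4)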
Example 3
    def __init__(self,
                 nin: int,
                 nout: int,
                 k: Union[Tuple[int, int], int],
                 strides: Union[Tuple[int, int], int] = 1,
                 dilations: Union[Tuple[int, int], int] = 1,
                 padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.SAME,
                 use_bias: bool = True,
                 w_init: Callable = kaiming_normal):
        """Creates a ConvTranspose2D module instance.

        Args:
            nin: number of channels of the input tensor.
            nout: number of channels of the output tensor.
            k: size of the convolution kernel, either tuple (height, width) or single number if they're the same.
            strides: convolution strides, either tuple (stride_y, stride_x) or single number if they're the same.
            dilations: spacing between kernel points (also known as atrous convolution),
                       either tuple (dilation_y, dilation_x) or single number if they're the same.
            padding: padding of the input tensor, either ConvPadding.SAME, ConvPadding.VALID or numerical values.
            use_bias: if True then the convolution has a bias term.
            w_init: initializer for convolution kernel (a function that takes in an HWIO shape and returns a 4D matrix).
        """
        super().__init__(nin=nout, nout=nin, k=k, strides=strides, padding=padding, use_bias=False, w_init=w_init)
        self.b = TrainVar(jn.zeros((nout, 1, 1))) if use_bias else None
        self.dilations = util.to_tuple(dilations, 2)
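
A usage sketch for ConvTranspose2D; treating the module as callable on NCHW tensors is an assumption, since __call__ is not shown here:

import jax.numpy as jn

convt = ConvTranspose2D(nin=16, nout=3, k=3, strides=2)
x = jn.zeros((4, 16, 16, 16))        # (N, C, H, W)
y = convt(x)                         # with SAME padding and stride 2 the spatial size roughly doubles: (4, 3, 32, 32)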
Example 4
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[Tuple[int, int], int],
                 stride: Union[Tuple[int, int], int] = 1,
                 padding: Union[str, Tuple[int, int], int] = 0,
                 dilation: Union[Tuple[int, int], int] = 1,
                 groups: int = 1,
                 bias: bool = False,
                 kernel_init: Callable = kaiming_normal,
                 bias_init: Callable = jnp.zeros,
                 ):
        """Creates a Conv2D module instance.

        Args:
            in_channels: number of channels of the input tensor.
            out_channels: number of channels of the output tensor.
            kernel_size: size of the convolution kernel, either tuple (height, width) or single number if they're the same.
            stride: convolution strides, either tuple (stride_y, stride_x) or single number if they're the same.
            dilation: spacing between kernel points (also known as atrous convolution),
                       either tuple (dilation_y, dilation_x) or single number if they're the same.
            groups: number of groups into which the input and output channels are split. When groups > 1,
                    the convolution is applied separately to each group; in_channels and out_channels must both be divisible by groups.
            padding: padding of the input tensor, either the string 'LIKE' or numerical values.
            bias: if True then the convolution has a bias term.
            kernel_init: initializer for convolution kernel (a function that takes in an OIHW shape and returns a 4D matrix).
            bias_init: initializer for the bias term; used only when bias is True.
        """
        super().__init__()
        assert in_channels % groups == 0, 'in_channels should be divisible by groups'
        assert out_channels % groups == 0, 'out_channels should be divisible by groups'
        kernel_size = util.to_tuple(kernel_size, 2)
        self.weight = TrainVar(kernel_init((out_channels, in_channels // groups, *kernel_size)))  # OIHW
        self.bias = TrainVar(bias_init((out_channels,))) if bias else None
        self.strides = util.to_tuple(stride, 2)
        self.dilations = util.to_tuple(dilation, 2)
        if isinstance(padding, str):
            if padding == 'LIKE':
                padding = (
                    get_like_padding(kernel_size[0], self.strides[0], self.dilations[0]),
                    get_like_padding(kernel_size[1], self.strides[1], self.dilations[1]))
                padding = [padding, padding]
        else:
            padding = util.to_tuple(padding, 2)
            padding = [padding, padding]
        self.padding = padding
        self.groups = groups
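
A usage sketch for this OIHW variant; the forward pass and get_like_padding are defined elsewhere, so the output shape below is an assumption based on 'LIKE' padding behaving like SAME:

import jax.numpy as jnp

conv = Conv2D(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding='LIKE')
x = jnp.zeros((8, 3, 32, 32))        # (N, C, H, W)
y = conv(x)                          # expected output shape: (8, 16, 16, 16)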
Example 5
def average_pool_2d(x: JaxArray,
                    size: Union[Tuple[int, int], int] = 2,
                    strides: Optional[Union[Tuple[int, int], int]] = None,
                    padding: ConvPadding = ConvPadding.VALID) -> JaxArray:
    """Applies average pooling using a square 2D filter.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of pooling filter.
        strides: stride step; defaults to size when None.
        padding: type of padding used in pooling operation.

    Returns:
        output tensor of shape (N, C, H, W).
    """
    size = to_tuple(size, 2)
    strides = to_tuple(strides, 2) if strides else size
    return lax.reduce_window(
        x, 0, lax.add, (1, 1) + size,
        (1, 1) + strides, padding=padding.value) / np.prod(size)
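
A usage sketch, assuming average_pool_2d and jax.numpy are importable as below:

import jax.numpy as jn

x = jn.ones((2, 3, 8, 8))
y = average_pool_2d(x, size=2)       # sums each 2x2 window, then divides by 4
assert y.shape == (2, 3, 4, 4)       # values stay 1.0 for a constant input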
Example 6
def max_pool_2d(x: JaxArray,
                size: Union[Tuple[int, int], int] = 2,
                strides: Union[Tuple[int, int], int] = 2,
                padding: ConvPadding = ConvPadding.VALID) -> JaxArray:
    """Applies max pooling using a square 2D filter.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of pooling filter.
        strides: stride step.
        padding: type of padding used in pooling operation.

    Returns:
        output tensor of shape (N, C, H, W).
    """
    size = to_tuple(size, 2)
    strides = to_tuple(strides, 2)
    return lax.reduce_window(x,
                             -jn.inf,
                             lax.max, (1, 1) + size, (1, 1) + strides,
                             padding=padding.value)
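
A usage sketch showing overlapping pooling with an explicit stride (shapes are illustrative):

import jax.numpy as jn

x = jn.zeros((2, 3, 8, 8))
y = max_pool_2d(x, size=3, strides=2)  # VALID padding: floor((8 - 3) / 2) + 1 = 3
assert y.shape == (2, 3, 3, 3)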
Example 7
def channel_to_space2d(x: JaxArray,
                       size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Transfer channel dimension C into spatial dimensions (H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area.

    Returns:
        output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
    """
    size = to_tuple(size, 2)
    s = x.shape
    y = x.reshape((s[0], -1, size[0], size[1], s[2], s[3]))
    y = y.transpose((0, 1, 4, 2, 5, 3))
    return y.reshape(
        (s[0], s[1] // (size[0] * size[1]), s[2] * size[0], s[3] * size[1]))
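
A small worked example, assuming channel_to_space2d and jax.numpy are importable as below:

import jax.numpy as jn

x = jn.arange(1 * 4 * 2 * 2, dtype=jn.float32).reshape((1, 4, 2, 2))
y = channel_to_space2d(x, size=2)    # 4 channels fold into one 2x2 spatial block per pixel
assert y.shape == (1, 1, 4, 4)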
Example 8
def space_to_channel2d(x: JaxArray,
                       size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Transfer spatial dimensions (H, W) into channel dimension C.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area.

    Returns:
        output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
    """
    size = to_tuple(size, 2)
    s = x.shape
    y = x.reshape(
        (s[0], s[1], s[2] // size[0], size[0], s[3] // size[1], size[1]))
    y = y.transpose((0, 1, 3, 5, 2, 4))
    return y.reshape(
        (s[0], s[1] * size[0] * size[1], s[2] // size[0], s[3] // size[1]))
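
A usage sketch for the inverse operation of channel_to_space2d above:

import jax.numpy as jn

x = jn.zeros((1, 3, 8, 8))
y = space_to_channel2d(x, size=2)    # each 2x2 spatial block becomes 4 extra channels
assert y.shape == (1, 12, 4, 4)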
Example 9
def upsample_2d(x: JaxArray,
                scale: Union[Tuple[int, int], int],
                method: Union[Interpolate, str] = Interpolate.BILINEAR) -> JaxArray:
    """Function to upscale 2D images.

    Args:
        x: input tensor.
        scale: scaling factor, either tuple (scale_y, scale_x) or single number if they're the same.
        method: interpolation method, either an Interpolate value or a string, e.g. 'bilinear' or 'nearest'.

    Returns:
        upscaled 2D image tensor.
    """
    s = x.shape
    assert len(s) == 4, f'Input must have 4 dimensions (N, C, H, W) to be upsampled, got shape {s}; you can try the interpolate function instead.'
    scale = util.to_tuple(scale, 2)
    y = jax.image.resize(x.transpose([0, 2, 3, 1]),
                         shape=(s[0], s[2] * scale[0], s[3] * scale[1], s[1]),
                         method=util.to_interpolate(method))
    return y.transpose([0, 3, 1, 2])
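
A usage sketch, assuming upsample_2d and jax.numpy are importable and that string method names are accepted, as the docstring suggests:

import jax.numpy as jn

x = jn.zeros((2, 3, 16, 16))
y = upsample_2d(x, scale=2, method='nearest')
assert y.shape == (2, 3, 32, 32)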