def __init__(self,
             nin: int,
             nout: int,
             k: Union[Tuple[int, int], int],
             strides: Union[Tuple[int, int], int] = 1,
             dilations: Union[Tuple[int, int], int] = 1,
             groups: int = 1,
             padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.SAME,
             use_bias: bool = True,
             w_init: Callable = kaiming_normal):
    """Builds a (optionally grouped) 2D convolution layer.

    Args:
        nin: channel count of the input tensor.
        nout: channel count of the output tensor.
        k: kernel size, given as (height, width) or as one number used for both.
        strides: convolution strides, given as (stride_y, stride_x) or as one number used for both.
        dilations: spacing between kernel points (atrous convolution), given as
            (dilation_y, dilation_x) or as one number used for both.
        groups: number of channel groups. Both nin and nout must be divisible by it.
        padding: input padding; normalised with util.to_padding, so Padding.SAME,
            Padding.VALID or numerical values are expected.
        use_bias: when True a zero-initialised bias of shape (nout, 1, 1) is created,
            otherwise the bias attribute is None.
        w_init: kernel initializer; it is called with the HWIO shape
            (kernel_h, kernel_w, nin // groups, nout).

    Raises:
        AssertionError: if nin or nout is not divisible by groups.
    """
    super().__init__()
    assert nin % groups == 0, 'nin should be divisible by groups'
    assert nout % groups == 0, 'nout should be divisible by groups'

    # The bias is created before the kernel so the attribute order matches
    # the original layer definition.
    if use_bias:
        self.b = TrainVar(jn.zeros((nout, 1, 1)))
    else:
        self.b = None

    kernel_hw = tuple(util.to_tuple(k, 2))
    kernel_shape = kernel_hw + (nin // groups, nout)  # HWIO
    self.w = TrainVar(w_init(kernel_shape))

    self.padding = util.to_padding(padding, 2)
    self.strides = util.to_tuple(strides, 2)
    self.dilations = util.to_tuple(dilations, 2)
    self.groups = groups
def max_pool_2d(
        x: JaxArray,
        size: Union[Tuple[int, int], int] = 2,
        strides: Optional[Union[Tuple[int, int], int]] = None,
        padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.VALID
) -> JaxArray:
    """Max-pools the two trailing axes of a 4D tensor.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: pooling window, a (height, width) pair or one number used for both.
        strides: pooling step; when falsy (the default None) the window size is used.
        padding: padding of the input tensor, either Padding.SAME, Padding.VALID
            or numerical values (normalised with to_padding).

    Returns:
        the pooled tensor; the batch and channel axes are left unpooled
        (window and stride are 1 along them).
    """
    window = to_tuple(size, 2)
    if strides:
        step = to_tuple(strides, 2)
    else:
        step = window

    pad = to_padding(padding, 2)
    if isinstance(pad, tuple):
        # Numeric padding only describes H and W; prepend zero padding for N and C.
        pad = ((0, 0), (0, 0)) + pad

    window_dims = (1, 1) + window
    window_strides = (1, 1) + step
    return lax.reduce_window(x, -jn.inf, lax.max, window_dims, window_strides, padding=pad)
def __init__(self,
             nin: int,
             nout: int,
             k: Union[Tuple[int, int], int],
             strides: Union[Tuple[int, int], int] = 1,
             dilations: Union[Tuple[int, int], int] = 1,
             padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.SAME,
             use_bias: bool = True,
             w_init: Callable = kaiming_normal):
    """Builds a 2D transposed convolution layer.

    The parent initializer is invoked with nin and nout swapped and with its
    own bias disabled; the bias (sized by nout) and the dilations are then set here.

    Args:
        nin: channel count of the input tensor.
        nout: channel count of the output tensor.
        k: kernel size, given as (height, width) or as one number used for both.
        strides: convolution strides, given as (stride_y, stride_x) or as one number used for both.
        dilations: spacing between kernel points (atrous convolution), given as
            (dilation_y, dilation_x) or as one number used for both.
        padding: padding of the input tensor, either Padding.SAME or Padding.VALID.
        use_bias: when True a zero-initialised bias of shape (nout, 1, 1) is created,
            otherwise the bias attribute is None.
        w_init: kernel initializer (a function that takes in a HWIO shape and
            returns a 4D matrix).
    """
    parent_kwargs = dict(
        nin=nout,  # channel roles are swapped for the transposed kernel
        nout=nin,
        k=k,
        strides=strides,
        padding=padding,
        use_bias=False,  # the bias is created below with the true output size
        w_init=w_init,
    )
    super().__init__(**parent_kwargs)

    if use_bias:
        self.b = TrainVar(jn.zeros((nout, 1, 1)))
    else:
        self.b = None
    self.dilations = util.to_tuple(dilations, 2)
def __init__(self,
             in_channels: int,
             out_channels: int,
             kernel_size: Union[Tuple[int, int], int],
             stride: Union[Tuple[int, int], int] = 1,
             padding: Union[str, Tuple[int, int], int] = 0,
             dilation: Union[Tuple[int, int], int] = 1,
             groups: int = 1,
             bias: bool = False,
             kernel_init: Callable = kaiming_normal,
             bias_init: Callable = jnp.zeros,
             ):
    """Creates a Conv2D module instance.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        kernel_size: size of the convolution kernel, either tuple (height, width) or single number
            if they're the same.
        stride: convolution strides, either tuple (stride_y, stride_x) or single number if they're
            the same.
        padding: padding of the input tensor. An int or a (pad_y, pad_x) tuple is applied
            symmetrically (same amount before and after) along each spatial dimension. The string
            'LIKE' derives the per-dimension amount from kernel size, stride and dilation via
            get_like_padding. Any other string is stored unchanged.
        dilation: spacing between kernel points (also known as atrous convolution), either tuple
            (dilation_y, dilation_x) or single number if they're the same.
        groups: number of input and output channels group. When groups > 1 convolution operation is
            applied individually for each group. in_channels and out_channels must both be
            divisible by groups.
        bias: if True then convolution will have bias term.
        kernel_init: initializer for convolution kernel; it is called with the OIHW shape
            (out_channels, in_channels // groups, kernel_h, kernel_w).
        bias_init: initializer for the bias; it is called with the shape (out_channels,).

    Raises:
        AssertionError: if in_channels or out_channels is not divisible by groups.
    """
    super().__init__()
    assert in_channels % groups == 0, 'in_chs should be divisible by groups'
    assert out_channels % groups == 0, 'out_chs should be divisible by groups'
    kernel_size = util.to_tuple(kernel_size, 2)
    self.weight = TrainVar(kernel_init((out_channels, in_channels // groups, *kernel_size)))  # OIHW
    self.bias = TrainVar(bias_init((out_channels,))) if bias else None
    self.strides = util.to_tuple(stride, 2)
    self.dilations = util.to_tuple(dilation, 2)

    # Explicit padding is stored as one (low, high) pair per spatial dimension:
    # [(pad_y, pad_y), (pad_x, pad_x)]. Pairing the two per-dimension amounts as
    # (pad_y, pad_x) for both dimensions would be wrong whenever pad_y != pad_x.
    if isinstance(padding, str):
        if padding == 'LIKE':
            # NOTE(review): assumes get_like_padding returns a single int per dimension - confirm.
            pad_y = get_like_padding(kernel_size[0], self.strides[0], self.dilations[0])
            pad_x = get_like_padding(kernel_size[1], self.strides[1], self.dilations[1])
            padding = [(pad_y, pad_y), (pad_x, pad_x)]
        # Any other string is kept as given.
    else:
        pad_y, pad_x = util.to_tuple(padding, 2)
        padding = [(pad_y, pad_y), (pad_x, pad_x)]
    self.padding = padding
    self.groups = groups
def average_pool_2d(x: JaxArray,
                    size: Union[Tuple[int, int], int] = 2,
                    strides: Optional[Union[Tuple[int, int], int]] = None,
                    padding: ConvPadding = ConvPadding.VALID) -> JaxArray:
    """Average-pools the two trailing axes of a 4D tensor.

    Each window is summed and the sum is divided by the full window area
    (height * width of the pooling filter).

    Args:
        x: input tensor of shape (N, C, H, W).
        size: pooling window, a (height, width) pair or one number used for both.
        strides: pooling step; when falsy (the default None) the window size is used.
        padding: type of padding used in pooling operation; its ``.value`` is
            forwarded to the windowed reduction.

    Returns:
        the pooled tensor; the batch and channel axes are left unpooled
        (window and stride are 1 along them).
    """
    window = to_tuple(size, 2)
    if strides:
        step = to_tuple(strides, 2)
    else:
        step = window

    window_sum = lax.reduce_window(
        x, 0, lax.add,
        (1, 1) + window,
        (1, 1) + step,
        padding=padding.value)
    return window_sum / np.prod(window)
def max_pool_2d(x: JaxArray,
                size: Union[Tuple[int, int], int] = 2,
                strides: Union[Tuple[int, int], int] = 2,
                padding: ConvPadding = ConvPadding.VALID) -> JaxArray:
    """Max-pools the two trailing axes of a 4D tensor.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: pooling window, a (height, width) pair or one number used for both.
        strides: pooling step, a pair or one number used for both.
        padding: type of padding used in pooling operation; its ``.value`` is
            forwarded to the windowed reduction.

    Returns:
        the pooled tensor; the batch and channel axes are left unpooled
        (window and stride are 1 along them).
    """
    window = to_tuple(size, 2)
    step = to_tuple(strides, 2)
    # Prepend unit entries so the batch and channel axes are not pooled.
    window_dims = (1, 1) + window
    window_strides = (1, 1) + step
    return lax.reduce_window(x, -jn.inf, lax.max, window_dims, window_strides,
                             padding=padding.value)
def channel_to_space2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves blocks of channels into the spatial dimensions (H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        size: spatial block size, a (height, width) pair or one number used for both.

    Returns:
        output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
    """
    factors = to_tuple(size, 2)
    block_h, block_w = factors[0], factors[1]
    shape = x.shape
    batch, channels, height, width = shape[0], shape[1], shape[2], shape[3]

    # Split the channel axis into (remaining channels, block_h, block_w).
    split = x.reshape((batch, -1, block_h, block_w, height, width))
    # Place each block factor right after the spatial axis it expands:
    # (N, C', H, block_h, W, block_w).
    interleaved = split.transpose((0, 1, 4, 2, 5, 3))

    out_shape = (batch, channels // (block_h * block_w), height * block_h, width * block_w)
    return interleaved.reshape(out_shape)
def space_to_channel2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Folds spatial blocks of (H, W) into the channel dimension C.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: spatial block size, a (height, width) pair or one number used for both.

    Returns:
        output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
    """
    factors = to_tuple(size, 2)
    block_h, block_w = factors[0], factors[1]
    shape = x.shape
    batch, channels = shape[0], shape[1]
    out_h = shape[2] // block_h
    out_w = shape[3] // block_w

    # Split each spatial axis into (coarse position, offset inside the block).
    split = x.reshape((batch, channels, out_h, block_h, out_w, block_w))
    # Move both in-block offsets next to the channel axis:
    # (N, C, block_h, block_w, H', W').
    grouped = split.transpose((0, 1, 3, 5, 2, 4))

    out_shape = (batch, channels * block_h * block_w, out_h, out_w)
    return grouped.reshape(out_shape)
def upsample_2d(x: JaxArray,
                scale: Union[Tuple[int, int], int],
                method: Union[Interpolate, str] = Interpolate.BILINEAR) -> JaxArray:
    """Upscales the two trailing axes of a 4D tensor.

    The input is transposed to channels-last, resized with jax.image.resize,
    and transposed back to the original axis order.

    Args:
        x: input tensor with 4 dimensions; the last two are the ones scaled.
        scale: int or tuple scaling factor applied to the last two dimensions.
        method: str or Interpolate value selecting the interpolation method,
            e.g. ['bilinear', 'nearest'].

    Returns:
        upscaled 2d image tensor in the same axis order as the input.

    Raises:
        AssertionError: if x does not have exactly 4 dimensions.
    """
    s = x.shape
    assert len(s) == 4, f'{s} must have 4 dimensions to be upsampled, or you can try interpolate function.'

    factors = util.to_tuple(scale, 2)
    channels_last = x.transpose([0, 2, 3, 1])
    target_shape = (s[0], s[2] * factors[0], s[3] * factors[1], s[1])
    resized = jax.image.resize(channels_last,
                               shape=target_shape,
                               method=util.to_interpolate(method))
    # Restore the channels-first layout of the input.
    return resized.transpose([0, 3, 1, 2])