def onnx_avgpool(
    x,
    kernel_shape,
    pads=None,
    strides=None,
    auto_pad='NOTSET',
    ceil_mode=0,
    count_include_pad=0,
):
    if ceil_mode != 0:
        raise NotImplementedError('ceil_mode != 0')

    dims = (1,) * (x.ndim - len(kernel_shape)) + tuple(kernel_shape)
    strides = (((1,) * (x.ndim - len(strides)) + tuple(strides))
               if strides else (1,) * x.ndim)

    if auto_pad == "NOTSET":
        pads = pad_helper(x.ndim, pads) if pads else 'VALID'
    elif auto_pad == "SAME_UPPER":
        pads = "SAME"
    elif auto_pad == "VALID":
        pads = "VALID"
    elif auto_pad == "SAME_LOWER":
        raise NotImplementedError("AveragePool with auto_pad `SAME_LOWER`")
    else:
        raise ValueError(f"Invalid auto_pad attribute: {auto_pad}")

    if count_include_pad == 0:
        # Divide by the number of valid (non-padding) entries in each window.
        one = jnp.ones_like(x, dtype=x.dtype)
        wsizes = lax.reduce_window(one, 0.0, lax.add, dims, strides, pads)
    else:
        wsizes = np.prod(kernel_shape)

    return lax.reduce_window(x, 0.0, lax.add, dims, strides, pads, None,
                             None) / wsizes
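# Usage sketch (not from the original source) for onnx_avgpool above. It
# assumes the module-level imports `import numpy as np`, `import jax.numpy as
# jnp`, `from jax import lax`, and the module's own `pad_helper`; with
# pads=None the 'VALID' branch is taken and pad_helper is never called.
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(1, 1, 4, 4)              # NCHW input
y = onnx_avgpool(x, kernel_shape=(2, 2), strides=(2, 2))
print(y.shape)                                        # (1, 1, 2, 2)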
def avg_pool(value, window_shape, strides, padding):
    """Average pool.

    Args:
      value: Value to pool.
      window_shape: Shape of window to pool over. Same rank as value.
      strides: Strides for the window. Same rank as value.
      padding: Padding algorithm. Either "VALID" or "SAME".

    Returns:
      Pooled result. Same rank as value.
    """
    reduce_window_args = (0., lax.add, window_shape, strides, padding)
    pooled = lax.reduce_window(value, *reduce_window_args)
    if padding == "VALID":
        # Avoid the extra reduce_window.
        return pooled / jnp.prod(window_shape)
    else:
        # Count the number of valid entries at each input point, then use that
        # for computing the average. Assumes that any two arrays of the same
        # shape will be padded the same.
        # TODO(tycai): This mask is computed at runtime. Give option to bake it
        # in as a constant instead.
        window_counts = lax.reduce_window(jnp.ones_like(value),
                                          *reduce_window_args)
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
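# Hypothetical check (not from the original source) of the "SAME" branch of
# avg_pool above: edge windows are divided by the count of valid entries, so
# the result stays an exact mean. Assumes `import jax.numpy as jnp`.
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])
out = avg_pool(v, window_shape=(3,), strides=(1,), padding="SAME")
print(out)  # [1.5 2.  2.5] -- edge windows divide by 2 and 3, not always 3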
def avgpool_op(self, params: Tuple, inputs: DeviceArray):
    out = lax.reduce_window(inputs, 0.0, lax.add, self.pool_size,
                            self.strides, self.padding)
    ones = jnp.ones((1, inputs.shape[1], inputs.shape[2], 1),
                    dtype=inputs.dtype)
    window_sizes = lax.reduce_window(ones, 0.0, lax.add, self.pool_size,
                                     self.strides, self.padding)
    return lax.div(out, window_sizes)
def _call_batched(self, x):
    info = self.info
    one = np.ones(x.shape[1:-1], dtype=x.dtype)
    window_strides = info.strides[1:-1]
    window_sizes = lax.reduce_window(one, 0., lax.add, info.window_shape,
                                     window_strides, info.padding)
    outputs = lax.reduce_window(x, 0., lax.add, info.dims, info.strides,
                                info.padding)
    return outputs / window_sizes[..., np.newaxis]
def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    """Helper function to define pooling functions.

    Pooling functions are implemented using the ReduceWindow XLA op.

    NOTE: Be aware that pooling is not generally differentiable. That means
    providing a reduce_fn that is differentiable does not imply that pool is
    differentiable.

    Args:
      inputs: input data with dimensions (batch, window dims..., features).
      init: the initial value for the reduction.
      reduce_fn: a reduce function of the form `(T, T) -> T`.
      window_shape: a shape tuple defining the window to reduce over.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply
        before and after each spatial dimension.

    Returns:
      The output of the reduction for each window slice.
    """
    strides = strides or (1,) * len(window_shape)
    strides = (1,) + strides + (1,)
    dims = (1,) + window_shape + (1,)
    return lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
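# Usage sketch (not from the original source): max pooling expressed via the
# generic pool helper above by choosing -inf / lax.max as the reduction.
# Inputs follow the documented (batch, window dims..., features) layout.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(1, 4, 4, 1)
y = pool(x, -jnp.inf, lax.max, window_shape=(2, 2), strides=(2, 2),
         padding='VALID')
print(y.shape)  # (1, 2, 2, 1)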
def max_pool(
    value: jnp.ndarray,
    window_shape: ShapeLike,
    strides: ShapeLike,
    padding: Text,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:
    """Max pool.

    Args:
      value: Value to pool.
      window_shape: Shape of the pooling window, an int or same rank as value.
      strides: Strides of the pooling window, an int or same rank as value.
      padding: Padding algorithm. Either "VALID" or "SAME".
      channel_axis: Axis of the spatial channels for which pooling is skipped,
        used to infer `window_shape` or `strides` if they are an integer.

    Returns:
      Pooled result. Same rank as value.
    """
    if padding not in ["SAME", "VALID"]:
        raise ValueError(
            f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
    window_shape = _infer_shape(value, window_shape, channel_axis)
    strides = _infer_shape(value, strides, channel_axis)
    return lax.reduce_window(value, -jnp.inf, lax.max, window_shape, strides,
                             padding)
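# Hypothetical call (not from the original source) for max_pool above,
# assuming _infer_shape broadcasts an int over every non-channel axis, which
# is the contract this signature suggests. NHWC input, channel_axis=-1.
import jax.numpy as jnp

x = jnp.ones((2, 8, 8, 3))
y = max_pool(x, window_shape=2, strides=2, padding="VALID")
print(y.shape)  # expected (2, 4, 4, 3): H and W pooled, channels untouched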
def onnx_maxpool(x, kernel_shape, pads=None, strides=None):
    """Numpy-backed implementation of ONNX MaxPool op."""
    prefix = (1,) * (x.ndim - len(kernel_shape))
    dims = prefix + tuple(kernel_shape)
    # Padding is hard-coded to 'VALID' below; `pads` is accepted for API
    # compatibility but not yet honored.
    pads = tuple(pads) if pads else (0,) * len(kernel_shape)
    # Strides must match the operand rank, so default to ones over all dims.
    strides = (prefix + tuple(strides)) if strides else (1,) * x.ndim
    return [lax.reduce_window(x, -jnp.inf, lax.max, dims, strides, 'VALID')]
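# Smoke test (not from the original source) for onnx_maxpool above; the op
# returns a single-element list, mirroring ONNX's list-of-outputs convention.
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(1, 1, 4, 4)
(y,) = onnx_maxpool(x, kernel_shape=(2, 2), strides=(2, 2))
print(y)  # [[[[ 5.  7.]  [13. 15.]]]]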
def max_pool_2d(
    x: JaxArray,
    size: Union[Tuple[int, int], int] = 2,
    strides: Optional[Union[Tuple[int, int], int]] = None,
    padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.VALID
) -> JaxArray:
    """Applies max pooling using a square 2D filter.

    Args:
      x: input tensor of shape (N, C, H, W).
      size: size of pooling filter.
      strides: stride step, use size when stride is none (default).
      padding: padding of the input tensor, either Padding.SAME, Padding.VALID
        or numerical values.

    Returns:
      output tensor of shape (N, C, H, W).
    """
    size = to_tuple(size, 2)
    strides = to_tuple(strides, 2) if strides else size
    padding = to_padding(padding, 2)
    if isinstance(padding, tuple):
        padding = ((0, 0), (0, 0)) + padding
    return lax.reduce_window(x, -jn.inf, lax.max, (1, 1) + size,
                             (1, 1) + strides, padding=padding)
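# Usage sketch (not from the original source) for max_pool_2d above, assuming
# objax-style aliases and helpers (`import jax.numpy as jn`, `to_tuple`,
# `to_padding`) from the surrounding module. NCHW layout.
import jax.numpy as jn

x = jn.arange(16.0).reshape(1, 1, 4, 4)
y = max_pool_2d(x, size=2)          # stride defaults to the window size
print(y.shape)                      # (1, 1, 2, 2)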
def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    """Helper function to define pooling functions.

    Pooling functions are implemented using the ReduceWindow XLA op.

    NOTE: Be aware that pooling is not generally differentiable. That means
    providing a reduce_fn that is differentiable does not imply that pool is
    differentiable.

    Args:
      inputs: input data with dimensions (batch, window dims..., features).
      init: the initial value for the reduction.
      reduce_fn: a reduce function of the form `(T, T) -> T`.
      window_shape: a shape tuple defining the window to reduce over.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply
        before and after each spatial dimension.

    Returns:
      The output of the reduction for each window slice.
    """
    strides = strides or (1,) * len(window_shape)
    strides = (1,) + strides + (1,)
    dims = (1,) + window_shape + (1,)
    if not isinstance(padding, str):
        padding = tuple(map(tuple, padding))
        assert len(padding) == len(window_shape), (
            f"padding {padding} must specify pads for same number of dims as "
            f"window_shape {window_shape}")
        assert all([len(x) == 2 for x in padding]), (
            f"each entry in padding {padding} must be length 2")
        padding = ((0, 0),) + padding + ((0, 0),)
    return lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
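# Sketch (not from the original source) of the explicit-padding path in the
# pool helper above: per-spatial (low, high) pairs get (0, 0) prepended and
# appended for the batch and feature axes.
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 3, 3, 1))
y = pool(x, 0.0, lax.add, window_shape=(2, 2), strides=(1, 1),
         padding=((1, 0), (1, 0)))  # one extra row/column at the low edge
print(y.shape)  # (1, 3, 3, 1)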
def max_pool(
    value: jnp.ndarray,
    window_shape: Union[int, Sequence[int]],
    strides: Union[int, Sequence[int]],
    padding: str,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:
    """Max pool.

    Args:
      value: Value to pool.
      window_shape: Shape of the pooling window, an int or same rank as value.
      strides: Strides of the pooling window, an int or same rank as value.
      padding: Padding algorithm. Either ``VALID`` or ``SAME``.
      channel_axis: Axis of the spatial channels for which pooling is skipped,
        used to infer ``window_shape`` or ``strides`` if they are an integer.

    Returns:
      Pooled result. Same rank as value.
    """
    if padding not in ("SAME", "VALID"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
    _warn_if_unsafe(window_shape, strides)
    window_shape = _infer_shape(value, window_shape, channel_axis)
    strides = _infer_shape(value, strides, channel_axis)
    return lax.reduce_window(value, -jnp.inf, lax.max, window_shape, strides,
                             padding)
def onnx_maxpool(
    x,
    kernel_shape,
    pads=None,
    strides=None,
    dilations=None,
    auto_pad='NOTSET',
    ceil_mode=0,
    storage_order=0,
):
    dims = (1,) * (x.ndim - len(kernel_shape)) + tuple(kernel_shape)
    strides = (((1,) * (x.ndim - len(strides)) + tuple(strides))
               if strides else (1,) * x.ndim)
    dilations = (((1,) * (x.ndim - len(dilations)) + tuple(dilations))
                 if dilations else (1,) * x.ndim)

    if auto_pad == "NOTSET":
        pads = pad_helper(x.ndim, pads) if pads else 'VALID'
    elif auto_pad == "SAME_UPPER":
        pads = "SAME"
    elif auto_pad == "VALID":
        pads = "VALID"
    elif auto_pad == "SAME_LOWER":
        raise NotImplementedError("MaxPool with auto_pad `SAME_LOWER`")
    else:
        raise ValueError(f"Invalid auto_pad attribute: {auto_pad}")

    return lax.reduce_window(x, -jnp.inf, lax.max, dims, strides, pads, None,
                             dilations)
def f(params, x):
    one = (1, 1)
    dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    y = lax.conv_general_dilated(x, params, one, 'SAME', one, one,
                                 dimension_numbers)
    y = lax.reduce_window(y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
    return y
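# Hypothetical driver (not from the original source) for f above: a 3x3 HWIO
# kernel with matching channel counts, so the 'SAME' convolution and the 2x2
# sum-pool both preserve the NHWC spatial shape.
import jax.numpy as jnp

x = jnp.ones((1, 8, 8, 4))       # NHWC input
params = jnp.ones((3, 3, 4, 4))  # HWIO kernel
print(f(params, x).shape)        # (1, 8, 8, 4)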
def avg_pool(
    value: jnp.ndarray,
    window_shape: ShapeLike,
    strides: ShapeLike,
    padding: str,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:
    """Average pool.

    Args:
      value: Value to pool.
      window_shape: Shape of the pooling window, an int or same rank as value.
      strides: Strides of the pooling window, an int or same rank as value.
      padding: Padding algorithm. Either "VALID" or "SAME".
      channel_axis: Axis of the spatial channels for which pooling is skipped,
        used to infer `window_shape` or `strides` if they are an integer.

    Returns:
      Pooled result. Same rank as value.

    Raises:
      ValueError: If the padding is not valid.
    """
    if padding not in ["SAME", "VALID"]:
        raise ValueError(
            f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
    window_shape = _infer_shape(value, window_shape, channel_axis)
    strides = _infer_shape(value, strides, channel_axis)
    reduce_window_args = (0., lax.add, window_shape, strides, padding)
    pooled = lax.reduce_window(value, *reduce_window_args)
    if padding == "VALID":
        # Avoid the extra reduce_window.
        return pooled / jnp.prod(window_shape)
    else:
        # Count the number of valid entries at each input point, then use that
        # for computing the average. Assumes that any two arrays of the same
        # shape will be padded the same.
        # TODO(tycai): This mask is computed at runtime. Give option to bake it
        # in as a constant instead.
        window_counts = lax.reduce_window(jnp.ones_like(value),
                                          *reduce_window_args)
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
def _pooling_general(inputs, reducer, init_val, rescaler=None,
                     pool_size=(2, 2), strides=None, padding='VALID'):
    """Helper: general pooling computation used in pooling layers later."""
    spatial_strides = strides or (1,) * len(pool_size)
    rescale = rescaler(pool_size, spatial_strides,
                       padding) if rescaler else None
    dims = (1,) + pool_size + (1,)  # NHWC
    strides = (1,) + spatial_strides + (1,)
    out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
    return rescale(out, inputs) if rescale else out  # pylint: disable=not-callable
def avg_pool(
    value: jnp.ndarray,
    window_shape: Union[int, Sequence[int]],
    strides: Union[int, Sequence[int]],
    padding: str,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:
    """Average pool.

    Args:
      value: Value to pool.
      window_shape: Shape of the pooling window, an int or same rank as value.
      strides: Strides of the pooling window, an int or same rank as value.
      padding: Padding algorithm. Either ``VALID`` or ``SAME``.
      channel_axis: Axis of the spatial channels for which pooling is skipped,
        used to infer ``window_shape`` or ``strides`` if they are an integer.

    Returns:
      Pooled result. Same rank as value.

    Raises:
      ValueError: If the padding is not valid.
    """
    if padding not in ("SAME", "VALID"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
    _warn_if_unsafe(window_shape, strides)
    window_shape = _infer_shape(value, window_shape, channel_axis)
    strides = _infer_shape(value, strides, channel_axis)
    reduce_window_args = (0.0, lax.add, window_shape, strides, padding)
    pooled = lax.reduce_window(value, *reduce_window_args)
    if padding == "VALID":
        # Avoid the extra reduce_window.
        return pooled / jnp.prod(window_shape)
    else:
        # Count the number of valid entries at each input point, then use that
        # for computing the average. Assumes that any two arrays of the same
        # shape will be padded the same.
        window_counts = lax.reduce_window(jnp.ones_like(value),
                                          *reduce_window_args)
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
def _average_pool_nngp_6d(nngp, window_shape, strides, padding):
    """Get covariances of average pooling outputs given input covariances `nngp`.

    Args:
      nngp: a 6D `np.ndarray` containing sample-sample-pixel-pixel covariances.
        Has shape `[batch_size_1, batch_size_2, height, height, width, width]`.
      window_shape: tuple of two positive integers, the pooling spatial shape
        (e.g. `(3, 3)`).
      strides: tuple of two positive integers, the pooling strides, e.g.
        `(1, 1)`.
      padding: a `Padding` enum, e.g. `Padding.CIRCULAR`.

    Returns:
      a 6D `np.ndarray` containing sample-sample-pixel-pixel covariances of the
      average pooling outputs. Has shape
      `[batch_size_1, batch_size_2, new_height, new_height, new_width,
      new_width]`.
    """
    if not _is_array(nngp):
        return nngp

    if padding == Padding.CIRCULAR:
        pixel_axes = tuple(range(2, nngp.ndim))
        nngp = _same_pad_for_filter_shape(nngp, _double_tuple(window_shape),
                                          _double_tuple(strides), pixel_axes,
                                          'wrap')
        padding = Padding.VALID

    window_shape = _double_tuple((1,) + window_shape)
    strides = _double_tuple((1,) + strides)

    nngp_out = lax.reduce_window(nngp, 0., lax.add, window_shape, strides,
                                 padding.value)

    if padding == Padding.SAME:
        # `SAME` padding in `jax.experimental.stax.AvgPool` normalizes by
        # actual window size, which is smaller at the edges.
        one = np.ones(nngp.shape, nngp.dtype)
        window_sizes = lax.reduce_window(one, 0., lax.add, window_shape,
                                         strides, padding.value)
        nngp_out /= window_sizes
    else:
        nngp_out /= np.prod(window_shape)

    return nngp_out
def rescale(outputs, inputs, spec):
    non_spatial_axes = spec.index('N'), spec.index('C')
    spatial_shape = tuple(inputs.shape[i] for i in range(inputs.ndim)
                          if i not in non_spatial_axes)
    one = np.ones(spatial_shape, dtype=inputs.dtype)
    window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
    for i in sorted(non_spatial_axes):
        window_sizes = np.expand_dims(window_sizes, i)
    return outputs / window_sizes
def poolNd(
    input,
    window_shape,
    reducer="MAX",
    strides=None,
    padding="VALID",
    init_val=None,
    rescalor=None,
):
    # set up the reducer
    if reducer == "MAX":
        reducer = jla.max
        rescalor = numpy.float32(1.0)
        init_val = -numpy.inf
    elif reducer == "SUM" or reducer == "AVG":
        # check for "AVG" before `reducer` is rebound to the jax op
        if reducer == "AVG":
            rescalor = numpy.float32(1.0 / numpy.prod(window_shape))
        else:
            rescalor = numpy.float32(1.0)
        reducer = jla.add
        if init_val is None:
            init_val = 0.0

    # set up the window_shape
    if numpy.isscalar(window_shape):
        window_shape = (window_shape,) * input.ndim
    elif len(window_shape) != input.ndim:
        msg = "Given window_shape {} not the same length ".format(
            window_shape) + "as input shape {}".format(input.ndim)
        raise ValueError(msg)

    # set up the strides
    if strides is None:
        strides = window_shape
    elif numpy.isscalar(strides):
        strides = (strides,) * len(window_shape)
    elif len(strides) != len(window_shape):
        msg = "Given strides {} not the same length ".format(
            strides) + "as window_shape {}".format(window_shape)
        raise ValueError(msg)

    out = jla.reduce_window(
        operand=input * rescalor,
        init_value=init_val,
        computation=reducer,
        window_dimensions=window_shape,
        window_strides=strides,
        padding=padding,
    )
    return out
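# Usage sketch (not from the original source) for poolNd above, assuming its
# aliases are `import jax.lax as jla` and `import numpy`. An AVG pool
# pre-scales the input by 1/prod(window_shape) and then sum-reduces.
import jax.numpy as jnp

x = jnp.ones((1, 1, 4, 4))
y = poolNd(x, window_shape=(1, 1, 2, 2), reducer="AVG", strides=(1, 1, 2, 2))
print(y.shape)  # (1, 1, 2, 2); every entry is 1.0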
def max_pool(value, window_shape, strides, padding):
    """Max pool.

    Args:
      value: Value to pool.
      window_shape: Shape of window to pool over. Same rank as value.
      strides: Strides for the window. Same rank as value.
      padding: Padding algorithm. Either "VALID" or "SAME".

    Returns:
      Pooled result. Same rank as value.
    """
    return lax.reduce_window(value, -jnp.inf, lax.max, window_shape, strides,
                             padding)
def normpool(x):
    norms = jnp.linalg.norm(x, axis=-1)
    idxs = jnp.arange(x.shape[0])

    def g(a, b):
        an, ai = a
        bn, bi = b
        which = an >= bn
        return (jnp.where(which, an, bn), jnp.where(which, ai, bi))

    _, idxs = lax.reduce_window((norms, idxs), (-np.inf, -1), g,
                                window_dimensions=(2,),
                                window_strides=(2,),
                                padding=((0, 0),))
    return x[idxs]
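# Toy run (not from the original source) of normpool above: within each
# non-overlapping pair of rows, the row with the larger L2 norm survives.
# The variadic (value, index) reduce_window needs a reasonably recent jax.
import jax.numpy as jnp

x = jnp.array([[1.0, 0.0],
               [3.0, 4.0],
               [0.0, 2.0],
               [1.0, 1.0]])
print(normpool(x))  # [[3. 4.]  [0. 2.]]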
def average_pool_2d(x: JaxArray,
                    size: Union[Tuple[int, int], int] = 2,
                    strides: Optional[Union[Tuple[int, int], int]] = None,
                    padding: ConvPadding = ConvPadding.VALID) -> JaxArray:
    """Applies average pooling using a square 2D filter.

    Args:
      x: input tensor of shape (N, C, H, W).
      size: size of pooling filter.
      strides: stride step, use size when stride is none (default).
      padding: type of padding used in pooling operation.

    Returns:
      output tensor of shape (N, C, H, W).
    """
    size = to_tuple(size, 2)
    strides = to_tuple(strides, 2) if strides else size
    return lax.reduce_window(x, 0, lax.add, (1, 1) + size, (1, 1) + strides,
                             padding=padding.value) / np.prod(size)
def max_pool_2d(x: JaxArray,
                size: Union[Tuple[int, int], int] = 2,
                strides: Union[Tuple[int, int], int] = 2,
                padding: ConvPadding = ConvPadding.VALID) -> JaxArray:
    """Applies max pooling using a square 2D filter.

    Args:
      x: input tensor of shape (N, C, H, W).
      size: size of pooling filter.
      strides: stride step.
      padding: type of padding used in pooling operation.

    Returns:
      output tensor of shape (N, C, H, W).
    """
    size = to_tuple(size, 2)
    strides = to_tuple(strides, 2)
    return lax.reduce_window(x, -jn.inf, lax.max, (1, 1) + size,
                             (1, 1) + strides, padding=padding.value)
def apply_fun(params, inputs, rng=None):
    out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
    return rescale(out, inputs) if rescale else out
def apply_fun(params, inputs, **kwargs):
    out = lax.reduce_window(inputs, init_val, reducer, window_shape, strides,
                            padding)
    return rescale(out, inputs, spec) if rescale else out
def fun(operand):
    return lax.reduce_window(operand, init_val, op, dims, strides, padding)
def _call_batched(self, x):
    info = self.info
    return lax.reduce_window(x, -np.inf, lax.max, info.dims, info.strides,
                             info.padding)
def rescale(outputs, inputs):
    one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype)
    window_sizes = lax.reduce_window(one, 0., lax.add, dims, spatial_strides,
                                     padding)
    return outputs / window_sizes[..., jnp.newaxis]
def pool(
    input,
    window_shape,
    reducer="MAX",
    strides=None,
    padding="VALID",
    base_dilation=None,
    window_dilation=None,
):
    """Apply arbitrary pooling on a Tensor.

    Parameters
    ----------
    input: Tensor
        the input to pool over
    window_shape: tuple of int
        the shape of the pooling operator; it must have the same length as
        the number of dimensions in `input`
    reducer: str
        the type of pooling to apply, must be one of `MAX`, `AVG`, `SUM`
    strides: tuple of int
        the stride of the pooling operator
    padding: str
        the type of padding, must be `VALID`, `SAME` or `FULL`
    base_dilation: tuple (optional)
        dilations of the input; this corresponds to the number of `0`s
        inserted in between the values in the input
    window_dilation: tuple (optional)
        dilations of the window; this corresponds to the number of `0`s
        inserted in the window

    Returns
    -------
    pooled Tensor
    """
    # set up the reducer
    if reducer == "MAX":
        reduce = jla.max
        init_val = -jnp.inf
    elif reducer == "SUM" or reducer == "AVG":
        reduce = jla.add
        init_val = 0.0

    # set up the window_shape
    if len(window_shape) != input.ndim:
        msg = "Given window_shape {} not the same length ".format(
            window_shape) + "as input shape {}".format(input.ndim)
        raise ValueError(msg)

    # set up the strides
    if strides is None:
        strides = window_shape
    elif len(strides) != len(window_shape):
        msg = "Given strides {} not the same length ".format(
            strides) + "as window_shape {}".format(window_shape)
        raise ValueError(msg)

    # AVG pooling: pre-scale the input so the windowed sum becomes a mean
    if reducer == "AVG":
        input = input / numpy.prod(window_shape)

    out = jla.reduce_window(
        operand=input,
        init_value=init_val,
        computation=reduce,
        window_dimensions=window_shape,
        window_strides=strides,
        padding=padding,
        base_dilation=base_dilation,
        window_dilation=window_dilation,
    )
    return out
def fun(operand):
    return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                             base_dilation, window_dilation)
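# Free-standing sketch (not from the original source) of the dilated
# reduce_window that fun above closes over: window_dilation=(2,) makes a
# length-2 window span 3 input elements, so each output is x[i] + x[i + 2].
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0)
y = lax.reduce_window(x, 0.0, lax.add, window_dimensions=(2,),
                      window_strides=(1,), padding='VALID',
                      base_dilation=(1,), window_dilation=(2,))
print(y)  # [2. 4. 6. 8.]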
def pool(inputs):
    out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
    return rescale(out, inputs) if rescale else out