예제 #1
0
def onnx_avgpool(
    x,
    kernel_shape,
    pads=None,
    strides=None,
    auto_pad='NOTSET',
    ceil_mode=0,
    count_include_pad=0,
):
    """Implementation of the ONNX AveragePool op on top of ``lax.reduce_window``.

    Args:
      x: input array; its trailing ``len(kernel_shape)`` axes are pooled.
      kernel_shape: window size for each pooled (trailing) axis.
      pads: explicit padding, converted by ``pad_helper`` when ``auto_pad`` is
        ``'NOTSET'``; ``None``/empty means no padding.
      strides: stride for each pooled axis; leading axes get stride 1.
        Defaults to stride 1 everywhere.
      auto_pad: ``'NOTSET'``, ``'SAME_UPPER'`` or ``'VALID'``
        (``'SAME_LOWER'`` is not supported).
      ceil_mode: only ``0`` is supported.
      count_include_pad: if ``0``, each window sum is divided by the number of
        non-padding elements the window covered; otherwise by
        ``prod(kernel_shape)``.

    Returns:
      The average-pooled array.

    Raises:
      NotImplementedError: if ``ceil_mode != 0`` or ``auto_pad == 'SAME_LOWER'``.
      ValueError: for any other unrecognised ``auto_pad`` value.
    """
    # `NotImplemented` is a comparison sentinel, not an exception class;
    # raising it fails with a TypeError, so raise NotImplementedError instead.
    if ceil_mode != 0:
        raise NotImplementedError('ceil_mode != 0')

    # Leading (non-pooled) axes get a window of size 1 and stride 1.
    dims = (1, ) * (x.ndim - len(kernel_shape)) + tuple(kernel_shape)
    strides = (((1, ) * (x.ndim - len(strides)) +
                tuple(strides)) if strides else (1, ) * x.ndim)

    if auto_pad == "NOTSET":
        pads = pad_helper(x.ndim, pads) if pads else 'VALID'
    elif auto_pad == "SAME_UPPER":
        pads = "SAME"
    elif auto_pad == "VALID":
        pads = "VALID"
    elif auto_pad == "SAME_LOWER":
        raise NotImplementedError("AveragePool with auto_pad `SAME_LOWER`")
    else:
        raise ValueError(f"Invalid auto_pad attribute: {auto_pad}")

    if count_include_pad == 0:
        # Pooling a tensor of ones counts, per output position, how many real
        # (non-padding) elements each window covered.
        one = jnp.ones_like(x, dtype=x.dtype)
        wsizes = lax.reduce_window(one, 0.0, lax.add, dims, strides, pads)
    else:
        wsizes = np.prod(kernel_shape)

    return lax.reduce_window(x, 0.0, lax.add, dims, strides, pads, None,
                             None) / wsizes
예제 #2
0
def avg_pool(value, window_shape, strides, padding):
    """Average pool.

  Args:
    value: Value to pool.
    window_shape: Shape of window to pool over. Same rank as value.
    strides: Strides for the window. Same rank as value.
    padding: Padding algorithm. Either "VALID" or "SAME".

  Returns:
    Pooled result. Same rank as value.

  Raises:
    ValueError: If the padding is not VALID.
  """
    reduce_window_args = (0., lax.add, window_shape, strides, padding)
    pooled = lax.reduce_window(value, *reduce_window_args)
    if padding == "VALID":
        # Avoid the extra reduce_window.
        return pooled / jnp.prod(window_shape)
    else:
        # Count the number of valid entries at each input point, then use that for
        # computing average. Assumes that any two arrays of same shape will be
        # padded the same.
        # TODO(tycai): This mask is computed at runtime. Give option to bake it
        # in as a constant instead.
        window_counts = lax.reduce_window(jnp.ones_like(value),
                                          *reduce_window_args)
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
예제 #3
0
 def avgpool_op(self, params: Tuple, inputs: DeviceArray):
     """Average-pool ``inputs`` with ``self.pool_size``, ``self.strides`` and
     ``self.padding``.

     Each window sum is divided by the number of non-padding positions the
     window covered, so padded borders are not pulled towards zero.

     Args:
       params: unused by this op.
       inputs: 4-D array. The count mask is shaped ``(1, H, W, 1)`` from
         ``inputs.shape[1:3]``, which presumes NHWC layout — TODO confirm.

     Returns:
       The average-pooled array.
     """
     out = lax.reduce_window(inputs, 0.0, lax.add, self.pool_size,
                             self.strides, self.padding)
     ones = jnp.ones((1, inputs.shape[1], inputs.shape[2], 1),
                     dtype=inputs.dtype)
     window_sizes = lax.reduce_window(ones, 0.0, lax.add, self.pool_size,
                                      self.strides, self.padding)
     # `window_sizes` has singleton batch/channel axes. `lax` binary ops do not
     # perform NumPy-style broadcasting (older JAX releases require identical
     # shapes), so broadcast explicitly. `lax.div` is kept so integer inputs
     # keep lax's division semantics.
     return lax.div(out, jnp.broadcast_to(window_sizes, out.shape))
예제 #4
0
 def _call_batched(self, x):
     info = self.info
     one = np.ones(x.shape[1:-1], dtype=x.dtype)
     window_strides = info.strides[1:-1]
     window_sizes = lax.reduce_window(one, 0., lax.add, info.window_shape,
                                      window_strides, info.padding)
     outputs = lax.reduce_window(x, 0., lax.add, info.dims, info.strides,
                                 info.padding)
     return outputs / window_sizes[..., np.newaxis]
예제 #5
0
파일: pooling.py 프로젝트: vballoli/flax
def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    """Helper function to define pooling functions.

    Pooling functions are implemented using the ReduceWindow XLA op.
    NOTE: Be aware that pooling is not generally differentiable.
    That means providing a reduce_fn that is differentiable does not imply
    that pool is differentiable.

    Args:
      inputs: input data with dimensions (batch, window dims..., features).
      init: the initial value for the reduction
      reduce_fn: a reduce function of the form `(T, T) -> T`.
      window_shape: a shape tuple (or other sequence) defining the window to
        reduce over.
      strides: a sequence of `n` integers, representing the inter-window
        strides; defaults to 1 in every window dimension when falsy.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
    Returns:
      The output of the reduction for each window slice.
    Raises:
      ValueError: if an explicit `padding` does not hold exactly one
        `(low, high)` pair per window dimension.
    """
    # Normalise to tuples so lists are accepted: the tuple concatenation below
    # raises TypeError for lists.
    window_shape = tuple(window_shape)
    strides = tuple(strides) if strides else (1, ) * len(window_shape)
    strides = (1, ) + strides + (1, )
    dims = (1, ) + window_shape + (1, )
    if not isinstance(padding, str):
        # Explicit pads only cover the window dims; reduce_window needs a pair
        # for every axis, so the batch and feature axes get (0, 0).
        padding = tuple(map(tuple, padding))
        if len(padding) != len(window_shape):
            raise ValueError(
                f"padding {padding} must specify pads for same number of dims as "
                f"window_shape {window_shape}")
        if not all(len(pair) == 2 for pair in padding):
            raise ValueError(
                f"each entry in padding {padding} must be length 2")
        padding = ((0, 0), ) + padding + ((0, 0), )
    return lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
예제 #6
0
def max_pool(
    value: jnp.ndarray,
    window_shape: ShapeLike,
    strides: ShapeLike,
    padding: Text,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:
  """Max-pools ``value`` with ``lax.reduce_window``.

  Args:
    value: Array to pool.
    window_shape: Pooling window size; an int or a shape of the same rank as
      ``value``.
    strides: Window strides; an int or a shape of the same rank as ``value``.
    padding: Either "VALID" or "SAME".
    channel_axis: Axis of the spatial channels for which pooling is skipped;
      used to infer ``window_shape`` or ``strides`` when they are an int.

  Returns:
    The pooled array, with the same rank as ``value``.

  Raises:
    ValueError: if ``padding`` is neither "SAME" nor "VALID".
  """
  supported = ("SAME", "VALID")
  if padding in supported:
    full_window = _infer_shape(value, window_shape, channel_axis)
    full_strides = _infer_shape(value, strides, channel_axis)
    return lax.reduce_window(value, -jnp.inf, lax.max, full_window,
                             full_strides, padding)
  raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
예제 #7
0
파일: onnx2xla.py 프로젝트: yashk2810/jax
def onnx_maxpool(x, kernel_shape, pads=None, strides=None):
  """Implementation of the ONNX MaxPool op on top of ``lax.reduce_window``.

  Args:
    x: input array; its trailing ``len(kernel_shape)`` axes are pooled.
    kernel_shape: window size for each pooled axis.
    pads: optional ONNX-style pads ``[x1_begin, x2_begin, ..., x1_end,
      x2_end, ...]`` for the pooled axes; ``None``/empty means no padding.
      Padded positions take the ``-inf`` init value, so they never win the max.
    strides: optional stride for each pooled axis (default 1).

  Returns:
    A one-element list holding the pooled array.

  Raises:
    ValueError: if ``pads`` does not have ``2 * len(kernel_shape)`` entries.
  """
  num_pooled = len(kernel_shape)
  prefix = (1,) * (x.ndim - num_pooled)
  dims = prefix + tuple(kernel_shape)
  # Strides must cover every axis of `x`. The old default
  # `[1] * len(kernel_shape)` had the wrong rank whenever `x` had leading
  # non-pooled axes.
  strides = prefix + (tuple(strides) if strides else (1,) * num_pooled)
  if pads:
    pads = tuple(pads)
    if len(pads) != 2 * num_pooled:
      raise ValueError(
          f"pads {pads} must have 2 * {num_pooled} entries "
          "(all begin values followed by all end values)")
    # `pads` used to be parsed and then ignored ('VALID' was always used).
    padding = (((0, 0),) * len(prefix) +
               tuple(zip(pads[:num_pooled], pads[num_pooled:])))
  else:
    padding = 'VALID'
  return [lax.reduce_window(x, -jnp.inf, lax.max, dims, strides, padding)]
예제 #8
0
def max_pool_2d(
    x: JaxArray,
    size: Union[Tuple[int, int], int] = 2,
    strides: Optional[Union[Tuple[int, int], int]] = None,
    padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.VALID
) -> JaxArray:
    """Max-pools the two trailing (spatial) axes of an (N, C, H, W) tensor.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: pooling window size, an int or an (H, W) pair.
        strides: window step; when falsy (the default ``None``) the window size is reused.
        padding: Padding.SAME, Padding.VALID, a string, or explicit numerical padding for H and W.

    Returns:
        output tensor of shape (N, C, H, W).
    """
    window = to_tuple(size, 2)
    step = to_tuple(strides, 2) if strides else window
    spatial_padding = to_padding(padding, 2)
    # Explicit (tuple) padding only covers H and W; prepend zero padding for N and C.
    if isinstance(spatial_padding, tuple):
        full_padding = ((0, 0), (0, 0)) + spatial_padding
    else:
        full_padding = spatial_padding
    batch_channel = (1, 1)
    return lax.reduce_window(x,
                             -jn.inf,
                             lax.max,
                             batch_channel + window,
                             batch_channel + step,
                             padding=full_padding)
예제 #9
0
파일: pooling.py 프로젝트: ykumards/flax
def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    """Helper function to define pooling functions.

    Pooling functions are implemented using the ReduceWindow XLA op.
    NOTE: Be aware that pooling is not generally differentiable.
    That means providing a reduce_fn that is differentiable does not imply
    that pool is differentiable.

    Args:
      inputs: input data with dimensions (batch, window dims..., features).
      init: the initial value for the reduction
      reduce_fn: a reduce function of the form `(T, T) -> T`.
      window_shape: a shape tuple (or other sequence) defining the window to
        reduce over.
      strides: a sequence of `n` integers, representing the inter-window
        strides; defaults to 1 in every window dimension when falsy.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
    Returns:
      The output of the reduction for each window slice.
    Raises:
      ValueError: if an explicit `padding` does not hold exactly one
        `(low, high)` pair per window dimension.
    """
    # Normalise to tuples so lists are accepted: the tuple concatenation below
    # raises TypeError for lists.
    window_shape = tuple(window_shape)
    strides = tuple(strides) if strides else (1, ) * len(window_shape)
    strides = (1, ) + strides + (1, )
    dims = (1, ) + window_shape + (1, )
    if not isinstance(padding, str):
        # Explicit pads only cover the window dims; reduce_window needs a pair
        # for every axis, so the batch and feature axes get (0, 0).
        # Validation raises instead of asserting so it survives `python -O`.
        padding = tuple(map(tuple, padding))
        if len(padding) != len(window_shape):
            raise ValueError(
                f"padding {padding} must specify pads for same number of dims as "
                f"window_shape {window_shape}")
        if not all(len(pair) == 2 for pair in padding):
            raise ValueError(
                f"each entry in padding {padding} must be length 2")
        padding = ((0, 0), ) + padding + ((0, 0), )
    return lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
예제 #10
0
파일: pool.py 프로젝트: sooheon/elegy
def max_pool(
    value: jnp.ndarray,
    window_shape: Union[int, Sequence[int]],
    strides: Union[int, Sequence[int]],
    padding: str,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:
    """Max-pools ``value`` with ``lax.reduce_window``.

    Args:
      value: Array to pool.
      window_shape: Pooling window; an int or a shape of the same rank as ``value``.
      strides: Window strides; an int or a shape of the same rank as ``value``.
      padding: Either ``VALID`` or ``SAME``.
      channel_axis: Axis of the spatial channels for which pooling is skipped,
        used to infer ``window_shape`` or ``strides`` when they are an int.

    Returns:
      The pooled array, with the same rank as ``value``.

    Raises:
      ValueError: if ``padding`` is neither ``SAME`` nor ``VALID``.
    """
    allowed = ("SAME", "VALID")
    if padding not in allowed:
        raise ValueError(
            f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")

    _warn_if_unsafe(window_shape, strides)
    return lax.reduce_window(
        value,
        -jnp.inf,
        lax.max,
        _infer_shape(value, window_shape, channel_axis),
        _infer_shape(value, strides, channel_axis),
        padding,
    )
예제 #11
0
def onnx_maxpool(
    x,
    kernel_shape,
    pads=None,
    strides=None,
    dilations=None,
    auto_pad='NOTSET',
    ceil_mode=0,
    storage_order=0,
):
    """Implementation of the ONNX MaxPool op on top of ``lax.reduce_window``.

    Args:
      x: input array; its trailing ``len(kernel_shape)`` axes are pooled.
      kernel_shape: window size for each pooled (trailing) axis.
      pads: explicit padding, converted by ``pad_helper`` when ``auto_pad`` is
        ``'NOTSET'``; ``None``/empty means no padding.
      strides: stride per pooled axis; leading axes get stride 1.
      dilations: window dilation per pooled axis; leading axes get 1.
      auto_pad: ``'NOTSET'``, ``'SAME_UPPER'`` or ``'VALID'``
        (``'SAME_LOWER'`` is not supported).
      ceil_mode: only ``0`` is supported.
      storage_order: unused (it only affects the Indices output, which is not
        produced here).

    Returns:
      The max-pooled array.

    Raises:
      NotImplementedError: if ``ceil_mode != 0`` or ``auto_pad == 'SAME_LOWER'``.
      ValueError: for any other unrecognised ``auto_pad`` value.
    """
    # ceil_mode changes the output shape; fail loudly instead of silently
    # returning floor-mode results (consistent with onnx_avgpool).
    if ceil_mode != 0:
        raise NotImplementedError('ceil_mode != 0')

    dims = (1, ) * (x.ndim - len(kernel_shape)) + tuple(kernel_shape)
    strides = (((1, ) * (x.ndim - len(strides)) +
                tuple(strides)) if strides else (1, ) * x.ndim)
    dilations = (((1, ) * (x.ndim - len(dilations)) +
                  tuple(dilations)) if dilations else (1, ) * x.ndim)

    if auto_pad == "NOTSET":
        pads = pad_helper(x.ndim, pads) if pads else 'VALID'
    elif auto_pad == "SAME_UPPER":
        pads = "SAME"
    elif auto_pad == "VALID":
        pads = "VALID"
    elif auto_pad == "SAME_LOWER":
        # `NotImplemented` is not an exception; raising it is a TypeError.
        raise NotImplementedError("MaxPool with auto_pad `SAME_LOWER`")
    else:
        raise ValueError(f"Invalid auto_pad attribute: {auto_pad}")

    return lax.reduce_window(x, -jnp.inf, lax.max, dims, strides, pads, None,
                             dilations)
예제 #12
0
 def f(params, x):
   """Stride-1 'SAME' NHWC/HWIO convolution followed by a 2x2 'SAME' sum-pool."""
   unit = (1, 1)
   conv_out = lax.conv_general_dilated(
       x, params,
       window_strides=unit,
       padding='SAME',
       lhs_dilation=unit,
       rhs_dilation=unit,
       dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
   return lax.reduce_window(conv_out, 0.0, lax.add, (1, 2, 2, 1), (1,) * 4,
                            'SAME')
예제 #13
0
파일: pool.py 프로젝트: vballoli/dm-haiku
def avg_pool(
    value: jnp.ndarray,
    window_shape: ShapeLike,
    strides: ShapeLike,
    padding: str,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:
    """Average pool.

    Args:
      value: Value to pool.
      window_shape: Shape of the pooling window, an int or same rank as value.
      strides: Strides of the pooling window, an int or same rank as value.
      padding: Padding algorithm. Either "VALID" or "SAME".
      channel_axis: Axis of the spatial channels for which pooling is skipped,
        used to infer `window_shape` or `strides` if they are an integer.

    Returns:
      Pooled result. Same rank as value. With "VALID" padding each window sum is
      divided by the full window size; with "SAME" padding it is divided by the
      number of non-padding entries the window covered.

    Raises:
      ValueError: If the padding is neither "SAME" nor "VALID".
    """
    if padding not in ["SAME", "VALID"]:
        raise ValueError(
            f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")

    window_shape = _infer_shape(value, window_shape, channel_axis)
    strides = _infer_shape(value, strides, channel_axis)

    reduce_window_args = (0., lax.add, window_shape, strides, padding)
    pooled = lax.reduce_window(value, *reduce_window_args)
    if padding == "VALID":
        # Avoid the extra reduce_window: every window is full-sized.
        return pooled / jnp.prod(window_shape)
    else:
        # Count the number of valid entries at each input point, then use that for
        # computing average. Assumes that any two arrays of same shape will be
        # padded the same.
        # TODO(tycai): This mask is computed at runtime. Give option to bake it
        # in as a constant instead.
        window_counts = lax.reduce_window(jnp.ones_like(value),
                                          *reduce_window_args)
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
예제 #14
0
파일: jax.py 프로젝트: ulrikSebastienR/trax
def _pooling_general(inputs, reducer, init_val, rescaler=None,
                     pool_size=(2, 2), strides=None, padding='VALID'):
  """Helper: general pooling computation used in pooling layers later."""
  spatial_strides = strides or (1,) * len(pool_size)
  rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None
  dims = (1,) + pool_size + (1,)  # NHWC
  strides = (1,) + spatial_strides + (1,)
  out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding)
  return rescale(out, inputs) if rescale else out  # pylint: disable=not-callable
예제 #15
0
파일: pool.py 프로젝트: sooheon/elegy
def avg_pool(
    value: jnp.ndarray,
    window_shape: Union[int, Sequence[int]],
    strides: Union[int, Sequence[int]],
    padding: str,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:
    """Average-pools ``value`` with ``lax.reduce_window``.

    Args:
      value: Array to pool.
      window_shape: Pooling window; an int or a shape of the same rank as ``value``.
      strides: Window strides; an int or a shape of the same rank as ``value``.
      padding: Either ``VALID`` or ``SAME``.
      channel_axis: Axis of the spatial channels for which pooling is skipped,
        used to infer ``window_shape`` or ``strides`` when they are an int.

    Returns:
      The pooled array, with the same rank as ``value``. ``VALID`` windows are
      divided by the full window size, ``SAME`` windows by the number of
      non-padding entries they covered.

    Raises:
      ValueError: if ``padding`` is neither ``SAME`` nor ``VALID``.
    """
    if padding not in ("SAME", "VALID"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")

    _warn_if_unsafe(window_shape, strides)
    full_window = _infer_shape(value, window_shape, channel_axis)
    full_strides = _infer_shape(value, strides, channel_axis)

    window_sums = lax.reduce_window(value, 0.0, lax.add, full_window,
                                    full_strides, padding)
    if padding != "VALID":
        # SAME: windows at the border overlap padding, so count the real
        # entries each window saw by pooling an all-ones array the same way.
        counts = lax.reduce_window(jnp.ones_like(value), 0.0, lax.add,
                                   full_window, full_strides, padding)
        assert window_sums.shape == counts.shape
        return window_sums / counts
    # VALID: every window is full, no need for a second reduce_window.
    return window_sums / jnp.prod(full_window)
예제 #16
0
def _average_pool_nngp_6d(nngp, window_shape, strides, padding):
    """Get covariances of average pooling outputs given inputs covariances `nngp`.

    Args:
      nngp: a 6D `np.ndarray` containing sample-sample-pixel-pixel covariances.
        Has shape `[batch_size_1, batch_size_2, height, height, width, width]`.
      window_shape: tuple of two positive integers, the pooling spatial shape
        (e.g. `(3, 3)`).
      strides: tuple of two positive integers, the pooling strides, e.g. `(1, 1)`.
      padding: a `Padding` enum, e.g. `Padding.CIRCULAR`.

    Returns:
      a 6D `np.ndarray` containing sample-sample-pixel-pixel covariances of the
      average pooling outputs. `reduce_window` does not permute axes, so the
      result keeps the layout of `nngp`: `[batch_size_1, batch_size_2,
      new_height, new_height, new_width, new_width]`. If `nngp` is not an array
      (per `_is_array`), it is returned unchanged.
    """
    if not _is_array(nngp):
        return nngp

    if padding == Padding.CIRCULAR:
        # Pad the pixel axes (2..ndim-1) up front — presumably wrap-around
        # padding, given the 'wrap' mode — then pool without further padding.
        pixel_axes = tuple(range(2, nngp.ndim))
        nngp = _same_pad_for_filter_shape(nngp, _double_tuple(window_shape),
                                          _double_tuple(strides), pixel_axes,
                                          'wrap')
        padding = Padding.VALID

    # The leading 1 leaves the batch axes unpooled. NOTE(review): assumes
    # `_double_tuple` repeats each entry so it matches the 6-D layout.
    window_shape = _double_tuple((1, ) + window_shape)
    strides = _double_tuple((1, ) + strides)

    nngp_out = lax.reduce_window(nngp, 0., lax.add, window_shape, strides,
                                 padding.value)

    if padding == Padding.SAME:
        # `SAME` padding in `jax.experimental.stax.AvgPool` normalizes by actual
        # window size, which is smaller at the edges.
        one = np.ones(nngp.shape, nngp.dtype)
        window_sizes = lax.reduce_window(one, 0., lax.add, window_shape,
                                         strides, padding.value)
        nngp_out /= window_sizes
    else:
        # No padding: every window is full, so divide by its total size.
        nngp_out /= np.prod(window_shape)

    return nngp_out
예제 #17
0
파일: stax.py 프로젝트: devopsotrator/jax
    def rescale(outputs, inputs, spec):
        """Divide pooled ``outputs`` by the per-position window sizes.

        The counts come from pooling a ones array shaped like the axes of
        ``inputs`` other than those labelled 'N' and 'C' in ``spec``. Pooling
        uses the enclosing ``dims``/``strides``/``padding``. Singleton axes are
        then re-inserted at the N and C positions so the counts broadcast.
        """
        skipped = (spec.index('N'), spec.index('C'))
        spatial_shape = tuple(size for axis, size in enumerate(inputs.shape)
                              if axis not in skipped)
        counts = lax.reduce_window(np.ones(spatial_shape, dtype=inputs.dtype),
                                   0., lax.add, dims, strides, padding)
        for axis in sorted(skipped):
            counts = np.expand_dims(counts, axis)
        return outputs / counts
예제 #18
0
def poolNd(
    input,
    window_shape,
    reducer="MAX",
    strides=None,
    padding="VALID",
    init_val=None,
    rescalor=None,
):
    """Apply max, sum or average pooling to an N-d array.

    Args:
      input: array to pool.
      window_shape: int (reused for every axis of ``input``) or a sequence
        with one window size per axis of ``input``.
      reducer: one of ``"MAX"``, ``"SUM"`` or ``"AVG"``.
      strides: int (reused for every window axis), a sequence with one stride
        per window axis, or ``None`` to use ``window_shape``.
      padding: padding forwarded to ``jla.reduce_window``.
      init_val: initial value of the reduction for ``"SUM"``/``"AVG"``
        (defaults to ``0.0``); ``"MAX"`` always starts from ``-inf``.
      rescalor: ignored; the scale applied to ``input`` is derived from
        ``reducer``. Kept for signature compatibility.

    Returns:
      The pooled array. For ``"AVG"`` the input is pre-scaled by
      ``1 / prod(window_shape)``, so padded positions still count towards the
      window size.

    Raises:
      ValueError: if ``reducer`` is unknown, or ``window_shape``/``strides``
        have the wrong length.
    """
    if reducer not in ("MAX", "SUM", "AVG"):
        raise ValueError(
            "Unknown reducer {!r}, must be 'MAX', 'SUM' or 'AVG'".format(
                reducer))

    # set up the window_shape
    if numpy.isscalar(window_shape):
        window_shape = (window_shape, ) * input.ndim
    elif len(window_shape) != input.ndim:
        msg = "Given window_shape {} not the same length ".format(
            window_shape) + "as input shape {}".format(input.ndim)
        raise ValueError(msg)

    # set up the strides
    if strides is None:
        strides = window_shape
    elif numpy.isscalar(strides):
        strides = (strides, ) * len(window_shape)
    elif len(strides) != len(window_shape):
        msg = "Given strides {} not the same length ".format(
            strides) + "as window_shape {}".format(window_shape)
        raise ValueError(msg)

    # set up the reducer. This runs after `window_shape` is expanded, so an int
    # window is averaged over its full N-d size. It also tests the `reducer`
    # *string*: the previous version rebound `reducer` to `jla.add` before
    # checking for "AVG", so averages were silently returned as sums.
    if reducer == "MAX":
        computation = jla.max
        rescalor = numpy.float32(1.0)
        init_val = -numpy.inf
    else:
        computation = jla.add
        if reducer == "AVG":
            rescalor = numpy.float32(1.0 / numpy.prod(window_shape))
        else:
            rescalor = numpy.float32(1.0)
        if init_val is None:
            init_val = 0.0

    out = jla.reduce_window(
        operand=input * rescalor,
        init_value=init_val,
        computation=computation,
        window_dimensions=window_shape,
        window_strides=strides,
        padding=padding,
    )
    return out
예제 #19
0
def max_pool(value, window_shape, strides, padding):
    """Max-pool ``value`` with ``lax.reduce_window``.

    Args:
      value: Array to pool.
      window_shape: Window size; same rank as ``value``.
      strides: Window strides; same rank as ``value``.
      padding: Padding algorithm, "VALID" or "SAME".

    Returns:
      The pooled array, with the same rank as ``value``.
    """
    lowest = -jnp.inf  # identity element of max
    return lax.reduce_window(value, lowest, lax.max,
                             window_shape, strides, padding)
예제 #20
0
        def normpool(x):
            """From each consecutive pair of leading-axis entries of ``x``,
            keep the one whose norm over the last axis is larger.

            Presumably ``x`` is 2-D (rows x features) — the index array is
            built from ``x.shape[0]`` only; TODO confirm.
            """
            row_norms = jnp.linalg.norm(x, axis=-1)
            row_ids = jnp.arange(x.shape[0])

            def pick_larger(lhs, rhs):
                lhs_norm, lhs_id = lhs
                rhs_norm, rhs_id = rhs
                keep_lhs = lhs_norm >= rhs_norm
                return (jnp.where(keep_lhs, lhs_norm, rhs_norm),
                        jnp.where(keep_lhs, lhs_id, rhs_id))

            _, winners = lax.reduce_window(
                (row_norms, row_ids),
                (-np.inf, -1),
                pick_larger,
                window_dimensions=(2, ),
                window_strides=(2, ),
                padding=((0, 0), ))
            return x[winners]
예제 #21
0
파일: pooling.py 프로젝트: spacexcorp/objax
def average_pool_2d(x: JaxArray,
                    size: Union[Tuple[int, int], int] = 2,
                    strides: Optional[Union[Tuple[int, int], int]] = None,
                    padding: ConvPadding = ConvPadding.VALID) -> JaxArray:
    """Average-pools the two trailing (spatial) axes of an (N, C, H, W) tensor.

    Every window sum is divided by the full window area ``prod(size)``, so
    padded positions (which contribute the zero init value) still count
    towards the divisor.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: pooling window size, an int or an (H, W) pair.
        strides: window step; when falsy (the default ``None``) the window size is reused.
        padding: type of padding used in pooling operation.

    Returns:
        output tensor of shape (N, C, H, W).
    """
    window = to_tuple(size, 2)
    step = window if not strides else to_tuple(strides, 2)
    window_sums = lax.reduce_window(x, 0, lax.add,
                                    (1, 1) + window, (1, 1) + step,
                                    padding=padding.value)
    return window_sums / np.prod(window)
예제 #22
0
파일: pooling.py 프로젝트: rwightman/objax
def max_pool_2d(x: JaxArray,
                size: Union[Tuple[int, int], int] = 2,
                strides: Union[Tuple[int, int], int] = 2,
                padding: ConvPadding = ConvPadding.VALID) -> JaxArray:
    """Max-pools the two trailing (spatial) axes of an (N, C, H, W) tensor.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: pooling window size, an int or an (H, W) pair.
        strides: window step, an int or an (H, W) pair.
        padding: type of padding used in pooling operation.

    Returns:
        output tensor of shape (N, C, H, W).
    """
    # N and C always get a window and stride of 1.
    window = (1, 1) + to_tuple(size, 2)
    step = (1, 1) + to_tuple(strides, 2)
    return lax.reduce_window(x, -jn.inf, lax.max, window, step,
                             padding=padding.value)
예제 #23
0
 def apply_fun(params, inputs, rng=None):
     """Pool ``inputs`` with the enclosing reducer config; ``params`` and ``rng`` are unused."""
     pooled = lax.reduce_window(inputs, init_val, reducer, dims, strides,
                                padding)
     if rescale:
         return rescale(pooled, inputs)
     return pooled
예제 #24
0
 def apply_fun(params, inputs, **kwargs):
     """Pool ``inputs`` with the enclosing config, then optionally rescale.

     ``params`` and any keyword arguments are unused; ``rescale`` (if set) is
     called with the enclosing ``spec``.
     """
     pooled = lax.reduce_window(inputs, init_val, reducer, window_shape,
                                strides, padding)
     if not rescale:
         return pooled
     return rescale(pooled, inputs, spec)
예제 #25
0
 def fun(operand):
   """Windowed reduction of ``operand`` using the enclosing test parameters."""
   return lax.reduce_window(operand, init_val, op,
                            window_dimensions=dims,
                            window_strides=strides,
                            padding=padding)
예제 #26
0
 def _call_batched(self, x):
     info = self.info
     return lax.reduce_window(x, -np.inf, lax.max, info.dims, info.strides,
                              info.padding)
예제 #27
0
파일: jax.py 프로젝트: ulrikSebastienR/trax
 def rescale(outputs, inputs):
   """Divide ``outputs`` by how many non-padding positions each window covered.

   Counts are pooled over ``inputs.shape[1:-1]`` with the enclosing ``dims``,
   ``spatial_strides`` and ``padding``, then broadcast over the last axis.
   """
   spatial_ones = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype)
   counts = lax.reduce_window(spatial_ones, 0., lax.add, dims,
                              spatial_strides, padding)
   return outputs / jnp.expand_dims(counts, -1)
예제 #28
0
def pool(
    input,
    window_shape,
    reducer="MAX",
    strides=None,
    padding="VALID",
    base_dilation=None,
    window_dilation=None,
):
    """apply arbitrary pooling on a Tensor.

    Parameters
    ----------

    input: Tensor
        the input to pool over

    window_shape: tuple of int
        the shape of the pooling operator, it must have same length than the
        number of dimension in `input`

    reducer: str
        the type of pooling to apply, must be one of `MAX`, `AVG`, `SUM`

    strides: tuple of int
        the stride of the pooling operator; defaults to `window_shape`
        (non-overlapping windows)

    padding: str
        the type of padding, must be `VALID`, `SAME` or `FULL`

    base_dilation: tuple (optional)
        dilations of the input. This corresponds to the number of `0` inserted
        in between the values in the input

    window_dilation: tuple (optional)
        dilations of the window, this corresponds to the number
        of `0`s inserted in the window

    Returns
    -------

    pooled Tensor. For `AVG` the input is pre-divided by
    `prod(window_shape)`, so padded positions count towards the window size.

    Raises
    ------

    ValueError
        if `reducer` is unknown, or `window_shape` / `strides` have the wrong
        length.
    """

    # set up the reducer
    if reducer == "MAX":
        reduce = jla.max
        init_val = -jnp.inf
    elif reducer == "SUM" or reducer == "AVG":
        reduce = jla.add
        init_val = 0.0
    else:
        # Without this branch an unknown reducer surfaced later as an
        # UnboundLocalError on `reduce` / `init_val`.
        raise ValueError(
            "Unknown reducer {!r}, must be one of 'MAX', 'AVG', 'SUM'".format(
                reducer))
    # set up the window_shape
    if len(window_shape) != input.ndim:
        msg = "Given window_shape {} not the same length ".format(
            window_shape
        ) + "as input shape {}".format(input.ndim)
        raise ValueError(msg)

    # set up the strides
    if strides is None:
        strides = window_shape
    elif len(strides) != len(window_shape):
        msg = "Given strides {} not the same length ".format(
            strides
        ) + "as window_shape {}".format(window_shape)
        raise ValueError(msg)

    if reducer == "AVG":
        input = input / numpy.prod(window_shape)

    out = jla.reduce_window(
        operand=input,
        init_value=init_val,
        computation=reduce,
        window_dimensions=window_shape,
        window_strides=strides,
        padding=padding,
        base_dilation=base_dilation,
        window_dilation=window_dilation,
    )

    return out
예제 #29
0
파일: lax_vmap_test.py 프로젝트: x1489/jax
 def fun(operand):
     """Windowed reduction of ``operand`` using the enclosing test parameters,
     including base (input) and window dilation."""
     return lax.reduce_window(operand, init_val, op,
                              window_dimensions=dims,
                              window_strides=strides,
                              padding=padding,
                              base_dilation=base_dilation,
                              window_dilation=window_dilation)
예제 #30
0
 def pool(inputs):
     """Reduce ``inputs`` with the enclosing config, rescaling when ``rescale`` is set."""
     reduced = lax.reduce_window(inputs, init_val, reducer, dims, strides,
                                 padding)
     if not rescale:
         return reduced
     return rescale(reduced, inputs)