Example #1
def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
    for pad_str in ["VALID", "SAME"]:
        pads = lax.padtype_to_pads(in_shape, window_shape, window_strides,
                                   pad_str)
        if list(pads) == list(padding):
            return pad_str
    return "EXPLICIT"
Example #2
def init_fun(rng, input_shape):
  padding_vals = lax.padtype_to_pads(input_shape, window_shape,
                                     strides, padding)
  ones = (1,) * len(window_shape)
  out_shape = lax.reduce_window_shape_tuple(
    input_shape, window_shape, strides, padding_vals, ones, ones)
  return out_shape, ()
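For intuition, a standalone trace of the same computation with concrete values (the (28, 28) input and (3, 3) window are invented): under "SAME" padding the reduced shape works out to ceil(input / stride) per spatial dimension.

from jax import lax

input_shape, window_shape, strides = (28, 28), (3, 3), (2, 2)
padding_vals = lax.padtype_to_pads(input_shape, window_shape, strides, "SAME")
ones = (1,) * len(window_shape)
out_shape = lax.reduce_window_shape_tuple(
  input_shape, window_shape, strides, padding_vals, ones, ones)
assert out_shape == (14, 14)  # ceil(28 / 2) along each dimension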
Example #3
def _reduce_window(jax_f,
                   reducer,
                   init_val,
                   operand,
                   window_dimensions,
                   window_strides,
                   padding,
                   input_shape=None):
    """TensorFlow implementation of reduce_window_{sum,min,max}."""
    del input_shape
    # TODO(tomhennigan): tf2xla should have a shape inference function.
    out_shape = _reduce_window_shape(jax_f, operand, window_dimensions,
                                     window_strides, padding)
    padding = lax.padtype_to_pads(_get_shape_from_tensor_or_array(operand),
                                  window_dimensions, window_strides, padding)
    a = tf.constant(0, operand.dtype)
    reducer_fn = reducer.get_concrete_function(a, a)
    out = tfxla.reduce_window(operand,
                              tf.constant(init_val, operand.dtype),
                              reducer_fn,
                              window_dimensions,
                              window_strides,
                              padding=padding)
    out.set_shape(out_shape)
    return out
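Only the padding conversion in this wrapper actually needs JAX; a minimal sketch of that step in isolation (sizes invented): the padding string is lowered to the explicit (low, high) pairs per dimension that tfxla.reduce_window expects.

from jax import lax

pads = lax.padtype_to_pads((1, 8, 8, 1), (1, 3, 3, 1), (1, 2, 2, 1), "SAME")
assert pads == [(0, 0), (0, 1), (0, 1), (0, 0)]  # (low, high) per dimension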
Example #4
def onnx_conv(x, w, b=0, group=1, kernel_shape=None, pads=None, strides=None,
              dilations=None, auto_pad=None):
  """Numpy-backed implementation of ONNX Conv op."""
  assert group == 1
  kernel_shape = kernel_shape or w.shape
  strides = strides or [1] * (w.ndim - 2)
  if auto_pad:
    auto_pad = 'SAME' if auto_pad.startswith(b'SAME') else 'VALID'
    pads = lax.padtype_to_pads(x.shape[2:], w.shape[2:], strides, auto_pad)
  else:
    pads = pads or [0] * (w.ndim - 2)
  lhs_dilation = [1] * (w.ndim - 2)
  rhs_dilation = dilations or [1] * (w.ndim - 2)
  return [lax.conv_with_general_padding(x, w, strides, pads,
                                        lhs_dilation, rhs_dilation) + b]
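A hedged smoke test of this wrapper (array sizes invented for illustration): with auto_pad=b'SAME_UPPER' and unit strides, the spatial shape should be preserved.

import numpy as np

x = np.ones((1, 1, 5, 5), np.float32)  # NCHW input
w = np.ones((2, 1, 3, 3), np.float32)  # OIHW kernel
(out,) = onnx_conv(x, w, auto_pad=b'SAME_UPPER')
assert out.shape == (1, 2, 5, 5)  # spatial dims unchanged under SAME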
Example #5
def spec(cls, in_spec, window_shape, strides=None, padding='VALID'):
    in_shape = in_spec.shape
    if len(in_shape) > 3:
        raise ValueError('Need to `jax.vmap` in order to batch')
    in_shape = (1,) + in_shape
    dims = (1,) + window_shape + (1,)  # NHWC or NHC
    non_spatial_axes = 0, len(window_shape) + 1
    strides = strides or (1,) * len(window_shape)
    for i in sorted(non_spatial_axes):
        window_shape = window_shape[:i] + (1,) + window_shape[i:]
        strides = strides[:i] + (1,) + strides[i:]
    padding = lax.padtype_to_pads(in_shape, window_shape, strides, padding)
    out_shape = lax.reduce_window_shape_tuple(in_shape, dims, strides,
                                              padding)
    out_shape = out_shape[1:]
    return state.Shape(out_shape, dtype=in_spec.dtype)
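Stripped of the framework plumbing (cls, in_spec, state.Shape), the shape arithmetic above reduces to the following sketch; the (28, 28, 3) HWC input and (2, 2) window are invented values.

from jax import lax

in_shape = (1, 28, 28, 3)  # leading batch dim of 1, as prepended above
dims, strides = (1, 2, 2, 1), (1, 1, 1, 1)
padding = lax.padtype_to_pads(in_shape, dims, strides, "VALID")
out_shape = lax.reduce_window_shape_tuple(in_shape, dims, strides, padding)
assert out_shape[1:] == (27, 27, 3)  # 28 - 2 + 1 per spatial axis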
Example #6
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    rng = jtu.rand_small(self.rng())

    pads = lax.padtype_to_pads(shape, dims, strides, padding)

    def fun(operand, cotangents):
      return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims,
                                         strides, pads)
    ones = (1,) * len(shape)
    cotangent_shape = api.eval_shape(
      lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                           pads, ones, ones),
      np.ones(shape, dtype)).shape

    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                          (dtype, dtype), rng)
Example #7
    def compute_output_shape(self):
        # lax.reduce_window_shape_tuple() does not accept a batch size of
        # None, so it is replaced with 1 within this function only.
        input_shape = (1, *self._input_shape[1:])
        padding_vals = lax.padtype_to_pads(input_shape, self.pool_size,
                                           self.strides, self.padding)

        num_dims = tuple(1 for _ in range(self.dims))
        base_dilation = (1, *num_dims, 1)
        window_dilation = (1, *num_dims, 1)

        out_shape = lax.reduce_window_shape_tuple(
            operand_shape=input_shape,
            window_dimensions=self.pool_size,
            window_strides=self.strides,
            padding=padding_vals,
            base_dilation=base_dilation,
            window_dilation=window_dilation,
        )
        return out_shape
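A standalone rendering of the workaround above, with invented sizes: substituting 1 for the unknown batch dimension lets the shape computation go through, and the batch entry of the result can be ignored afterwards.

from jax import lax

input_shape = (1, 32, 32, 8)  # None batch replaced by 1
pool_size, strides = (1, 2, 2, 1), (1, 2, 2, 1)
padding_vals = lax.padtype_to_pads(input_shape, pool_size, strides, "SAME")
out_shape = lax.reduce_window_shape_tuple(
    operand_shape=input_shape,
    window_dimensions=pool_size,
    window_strides=strides,
    padding=padding_vals,
    base_dilation=(1, 1, 1, 1),
    window_dilation=(1, 1, 1, 1),
)
assert out_shape == (1, 16, 16, 8)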
Example #8
def _same_pad_for_filter_shape(x, filter_shape, strides, axes, mode):
    """Pad an array to imitate `SAME` padding with `VALID`.

    See `Returns` section for details. This method is usually needed to
    implement `CIRCULAR` padding using `VALID` padding.

    Args:
      x: `np.ndarray` to pad, e.g. a 4D `NHWC` image.
      filter_shape: tuple of positive integers, the convolutional filter's
        spatial shape (e.g. `(3, 3)` for a 2D convolution).
      strides: tuple of positive integers, the convolutional spatial strides,
        e.g. `(1, 1)` for a 2D convolution.
      axes: tuple of non-negative integers, the spatial axes to apply
        convolution over (e.g. `(1, 2)` for an `NHWC` image).
      mode: a string, padding mode, for all options see
        https://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html.

    Returns:
      A `np.ndarray` of the same dimensionality as `x` padded to a potentially
      larger shape such that a `VALID` convolution with `filter_shape` applied
      to `x` over `axes` outputs an array of the same shape as `x`.
    """
    if not _is_array(x):
        return x

    axes_shape = tuple(np.size(x, axis) for axis in axes)
    axes_pads = lax.padtype_to_pads(axes_shape, filter_shape, strides,
                                    Padding.SAME.value)

    pads = [(0, 0)] * x.ndim
    for i, axis in enumerate(axes):
        pads[axis] = axes_pads[i]

    return np.pad(x, pads, mode)
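A self-contained illustration of the same idea, with the original helpers (_is_array, Padding) replaced by plain values: the "SAME" pads grow a 4x4 image to 6x6, so a subsequent VALID 3x3 convolution returns to 4x4. Numpy's "wrap" mode stands in for CIRCULAR padding here.

import numpy as np
from jax import lax

x = np.arange(16.).reshape(1, 4, 4, 1)  # NHWC image, sizes invented
filter_shape, strides, axes = (3, 3), (1, 1), (1, 2)

axes_shape = tuple(np.size(x, axis) for axis in axes)
axes_pads = lax.padtype_to_pads(axes_shape, filter_shape, strides, "SAME")
pads = [(0, 0)] * x.ndim
for i, axis in enumerate(axes):
    pads[axis] = axes_pads[i]

padded = np.pad(x, pads, "wrap")
assert padded.shape == (1, 6, 6, 1)  # 4 + 1 + 1 along each spatial axis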
Example #9
def fun(operand, tangents):
    pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
    ones = (1,) * len(operand.shape)
    return lax._select_and_gather_add(operand, tangents, lax.ge_p,
                                      dims, strides, pads, ones, ones)
Example #10
           StaticArg(window_dimensions), StaticArg(window_strides),
           StaticArg(padding), StaticArg(base_dilation),
           StaticArg(window_dilation)],
          shape=shape,
          dtype=dtype,
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          padding=padding,
          base_dilation=base_dilation,
          window_dilation=window_dilation)
  for dtype in jtu.dtypes.all_floating
  for shape in [(4, 6)]
  for select_prim in [lax.le_p, lax.ge_p]
  for window_dimensions in [(2, 1), (1, 2)]
  for window_strides in [(1, 1), (2, 1), (1, 2)]
  for padding in tuple(set([tuple(lax.padtype_to_pads(shape, window_dimensions,
                                                      window_strides, p))
                            for p in ['VALID', 'SAME']] +
                           [((0, 3), (1, 2))]))
  for base_dilation in [(1, 1)]
  for window_dilation in [(1, 1)]
) + tuple(
  # Tests with 4d shapes (see tests.lax_autodiff_test.testReduceWindowGrad)
  Harness(f"4d_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
          lax._select_and_gather_add,
          [RandArg(shape, dtype), RandArg(shape, dtype), StaticArg(select_prim),
           StaticArg(window_dimensions), StaticArg(window_strides),
           StaticArg(padding), StaticArg(base_dilation),
           StaticArg(window_dilation)],
          shape=shape,
          dtype=dtype,
          window_dimensions=window_dimensions,
Example #11
            StaticArg(base_dilation),
            StaticArg(window_dilation)
        ],
        shape=shape,
        dtype=dtype,
        window_dimensions=window_dimensions,
        window_strides=window_strides,
        padding=padding,
        base_dilation=base_dilation,
        window_dilation=window_dilation) for dtype in jtu.dtypes.all_floating
    for shape in [(4, 6)] for select_prim in [lax.le_p, lax.ge_p]
    for window_dimensions in [(2, 1), (1, 2)]
    for window_strides in [(1, 1), (2, 1), (1, 2)] for padding in tuple(
        set([
            tuple(
                lax.padtype_to_pads(shape, window_dimensions, window_strides,
                                    p)) for p in ['VALID', 'SAME']
        ] + [((0, 3), (1, 2))])) for base_dilation in [(1, 1)]
    for window_dilation in [(1, 1)]
) + tuple(
    # Tests with 4d shapes (see tests.lax_autodiff_test.testReduceWindowGrad)
    Harness(
        f"4d_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
        lax._select_and_gather_add, [
            RandArg(shape, dtype),
            RandArg(shape, dtype),
            StaticArg(select_prim),
            StaticArg(window_dimensions),
            StaticArg(window_strides),
            StaticArg(padding),
            StaticArg(base_dilation),
            StaticArg(window_dilation)