def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
  """Return the lax padding-type name that reproduces explicit `padding`.

  "VALID" is tried before "SAME", so "VALID" wins when both produce the same
  pads (e.g. for unit windows). If neither matches, "EXPLICIT" is returned.

  Args:
    in_shape: shape the window slides over.
    window_shape: window / filter shape.
    window_strides: window strides.
    padding: sequence of `(low, high)` pairs, one per dimension. Pairs may be
      tuples, lists or arrays; they are compared by value.

  Returns:
    "VALID", "SAME" or "EXPLICIT".
  """
  # Normalize every pair to a tuple: a list never compares equal to a tuple in
  # Python, so [[0, 0]] would otherwise fail to match [(0, 0)].
  wanted = [tuple(pair) for pair in padding]
  for pad_str in ("VALID", "SAME"):
    pads = lax.padtype_to_pads(in_shape, window_shape, window_strides, pad_str)
    if [tuple(pair) for pair in pads] == wanted:
      return pad_str
  return "EXPLICIT"
def init_fun(rng, input_shape):
  """Compute the output shape of the pooling window applied to `input_shape`.

  `window_shape`, `strides` and `padding` are taken from the enclosing scope.
  `rng` is accepted for the init-function calling convention but is not used.

  Returns:
    `(output_shape, ())`; the empty tuple is presumably the (empty) parameter
    set of this layer.
  """
  unit_dilation = tuple(1 for _ in window_shape)
  explicit_pads = lax.padtype_to_pads(input_shape, window_shape, strides,
                                      padding)
  output_shape = lax.reduce_window_shape_tuple(
      input_shape, window_shape, strides, explicit_pads,
      unit_dilation, unit_dilation)
  return output_shape, ()
def _reduce_window(jax_f, reducer, init_val, operand, window_dimensions,
                   window_strides, padding, input_shape=None):
  """TensorFlow implementation of reduce_window_{sum,min,max}.

  The padding is converted to explicit `(low, high)` pairs with
  `lax.padtype_to_pads` (so `padding` is presumably a padding-type name such as
  "VALID"/"SAME" -- confirm against callers). The reduction runs through
  `tfxla.reduce_window` using a concrete function traced from `reducer` on two
  scalars of the operand's dtype. The statically computed output shape is
  attached with `set_shape`.

  `input_shape` is accepted but ignored.
  """
  del input_shape  # Not needed; the shape is read from `operand` below.
  # TODO(tomhennigan): tf2xla should have a shape inference function.
  static_out_shape = _reduce_window_shape(jax_f, operand, window_dimensions,
                                          window_strides, padding)
  operand_shape = _get_shape_from_tensor_or_array(operand)
  explicit_pads = lax.padtype_to_pads(operand_shape, window_dimensions,
                                      window_strides, padding)
  scalar_zero = tf.constant(0, operand.dtype)
  concrete_reducer = reducer.get_concrete_function(scalar_zero, scalar_zero)
  init_tensor = tf.constant(init_val, operand.dtype)
  result = tfxla.reduce_window(operand, init_tensor, concrete_reducer,
                               window_dimensions, window_strides,
                               padding=explicit_pads)
  result.set_shape(static_out_shape)
  return result
def _onnx_pads_to_pairs(pads, num_spatial):
  """Convert ONNX `pads` into the `(low, high)` pairs that `lax` expects.

  ONNX stores pads flat as `[x1_begin, x2_begin, ..., x1_end, x2_end, ...]`.
  Input that already consists of pairs is passed through unchanged. `None` or
  an empty sequence means no padding.
  """
  if pads is None or len(pads) == 0:
    return [(0, 0)] * num_spatial
  if all(hasattr(p, '__len__') for p in pads):
    return [tuple(p) for p in pads]
  if len(pads) != 2 * num_spatial:
    raise ValueError(
        f'expected {2 * num_spatial} ONNX pads values, got {len(pads)}')
  return list(zip(pads[:num_spatial], pads[num_spatial:]))


def onnx_conv(x, w, b=0, group=1, kernel_shape=None, pads=None, strides=None,
              dilations=None, auto_pad=None):
  """`lax`-backed implementation of the ONNX Conv op.

  Args:
    x: input in NC<spatial...> layout.
    w: kernel in OI<spatial...> layout.
    b: bias added to the result.
    group: only 1 is supported.
    kernel_shape: ignored; the kernel shape is taken from `w`.
    pads: ONNX flat pads `[begins..., ends...]` (or `(low, high)` pairs);
      used only when `auto_pad` is unset / "NOTSET".
    strides: spatial strides (default all 1).
    dilations: kernel (rhs) dilations (default all 1).
    auto_pad: ONNX auto_pad as bytes or str: "NOTSET", "SAME_UPPER",
      "SAME_LOWER" or "VALID".

  Returns:
    A one-element list holding the convolution result plus `b`.

  Raises:
    NotImplementedError: if `group != 1`.
  """
  if group != 1:
    raise NotImplementedError(f'onnx_conv only supports group=1, got {group}')
  del kernel_shape  # Implied by w.shape.
  num_spatial = w.ndim - 2
  strides = list(strides or [1] * num_spatial)
  rhs_dilation = list(dilations or [1] * num_spatial)
  lhs_dilation = [1] * num_spatial

  if isinstance(auto_pad, bytes):
    auto_pad = auto_pad.decode()
  if not auto_pad or auto_pad == 'NOTSET':
    # NOTSET means "use the explicit pads attribute".
    padding = _onnx_pads_to_pairs(pads, num_spatial)
  elif auto_pad.startswith('SAME'):
    # SAME padding must be computed for the dilated kernel extent.
    effective_window = [(k - 1) * d + 1
                        for k, d in zip(w.shape[2:], rhs_dilation)]
    padding = lax.padtype_to_pads(x.shape[2:], effective_window, strides,
                                  'SAME')
    if auto_pad == 'SAME_LOWER':
      # lax 'SAME' puts the odd extra pad at the end (SAME_UPPER); swap it.
      padding = [(hi, lo) for lo, hi in padding]
  else:
    padding = lax.padtype_to_pads(x.shape[2:], w.shape[2:], strides, 'VALID')
  return [lax.conv_with_general_padding(x, w, strides, padding,
                                        lhs_dilation, rhs_dilation) + b]
def spec(cls, in_spec, window_shape, strides=None, padding='VALID'):
  """Return the output `state.Shape` of pooling a single unbatched input.

  Args:
    cls: unused in this body.
    in_spec: object with `.shape` (unbatched, at most 3 dims; channel-last,
      i.e. HWC or HC per the layout noted below) and `.dtype`.
    window_shape: spatial window sizes; any sequence of ints.
    strides: spatial strides; any sequence of ints. Defaults to all ones.
    padding: padding type passed to `lax.padtype_to_pads` ('VALID' default).

  Raises:
    ValueError: if `in_spec` has more than 3 dims (batch with `jax.vmap`).
  """
  in_shape = in_spec.shape
  if len(in_shape) > 3:
    raise ValueError('Need to `jax.vmap` in order to batch')
  # Normalize to tuples so list arguments can be concatenated below.
  window_shape = tuple(window_shape)
  strides = tuple(strides) if strides else (1,) * len(window_shape)
  # Add a unit batch dim, and unit window/stride entries on the batch and
  # channel axes (NHWC or NHC).
  batched_shape = (1,) + tuple(in_shape)
  full_window = (1,) + window_shape + (1,)
  full_strides = (1,) + strides + (1,)
  pads = lax.padtype_to_pads(batched_shape, full_window, full_strides, padding)
  out_shape = lax.reduce_window_shape_tuple(batched_shape, full_window,
                                            full_strides, pads)
  # Drop the unit batch dim added above.
  return state.Shape(out_shape[1:], dtype=in_spec.dtype)
def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
  """Checks batching of `lax._select_and_scatter_add` over all batch dims."""
  rng = jtu.rand_small(self.rng())
  pads = lax.padtype_to_pads(shape, dims, strides, padding)
  unit_dilation = tuple(1 for _ in shape)

  def scatter_add(operand, cotangents):
    return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims,
                                       strides, pads)

  def gather_add(x):
    return lax._select_and_gather_add(x, x, lax.ge_p, dims, strides, pads,
                                      unit_dilation, unit_dilation)

  # The cotangent takes the shape of `_select_and_gather_add`'s output.
  cotangent_shape = api.eval_shape(gather_add, np.ones(shape, dtype)).shape
  for bdims in all_bdims(cotangent_shape, shape):
    self._CheckBatching(scatter_add, 3, bdims, (cotangent_shape, shape),
                        (dtype, dtype), rng)
def compute_output_shape(self):
  """Return the pooled output shape for `self._input_shape`.

  `lax.reduce_window_shape_tuple()` does not accept a `None` batch size, so a
  placeholder batch of 1 is used for the computation only. The input's own
  batch dimension (possibly `None`) is restored in the returned shape.

  Assumes `self.pool_size` and `self.strides` cover every input axis
  (batch, spatial..., channel), since the dilation tuples built below have
  `self.dims + 2` entries. TODO confirm against the layer's constructor.
  """
  batch_size = self._input_shape[0]
  placeholder_shape = (1, *self._input_shape[1:])
  padding_vals = lax.padtype_to_pads(placeholder_shape, self.pool_size,
                                     self.strides, self.padding)
  unit_dilation = (1,) + (1,) * self.dims + (1,)
  out_shape = lax.reduce_window_shape_tuple(
      operand_shape=placeholder_shape,
      window_dimensions=self.pool_size,
      window_strides=self.strides,
      padding=padding_vals,
      base_dilation=unit_dilation,
      window_dilation=unit_dilation,
  )
  # Put the caller's batch dimension back instead of the placeholder 1.
  return (batch_size, *out_shape[1:])
def _same_pad_for_filter_shape(x, filter_shape, strides, axes, mode):
  """Pad `x` so that a `VALID` convolution imitates `SAME` padding.

  Typically needed to implement `CIRCULAR` padding on top of `VALID` padding.

  Args:
    x: `np.ndarray` to pad, e.g. a 4D `NHWC` image.
    filter_shape: tuple of positive integers, the spatial shape of the
      convolutional filter (e.g. `(3, 3)` for a 2D convolution).
    strides: tuple of positive integers, the spatial strides of the
      convolution, e.g. `(1, 1)` for a 2D convolution.
    axes: tuple of non-negative integers, the spatial axes the convolution is
      applied over (e.g. `(1, 2)` for an `NHWC` image).
    mode: padding mode string forwarded to `np.pad`; for all options see
      https://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html.

  Returns:
    `x` unchanged when `_is_array(x)` is false. Otherwise, `x` with the same
    number of dimensions, padded only along `axes` to a possibly larger shape
    such that a `VALID` convolution with `filter_shape` over `axes` outputs
    an array of the same shape as `x`.
  """
  if not _is_array(x):
    return x

  spatial_sizes = tuple(np.size(x, axis) for axis in axes)
  spatial_pads = lax.padtype_to_pads(spatial_sizes, filter_shape, strides,
                                     Padding.SAME.value)
  # Non-convolved axes are left unpadded.
  pad_widths = [(0, 0) for _ in range(x.ndim)]
  for k, axis in enumerate(axes):
    pad_widths[axis] = spatial_pads[k]
  return np.pad(x, pad_widths, mode)
def fun(operand, tangents):
  """Apply `lax._select_and_gather_add` with the `ge_p` select predicate.

  `dims`, `strides` and `padding` come from the enclosing scope. Base and
  window dilations are all ones.
  """
  unit_dilation = tuple(1 for _ in operand.shape)
  explicit_pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
  return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims, strides,
                                    explicit_pads, unit_dilation,
                                    unit_dilation)
StaticArg(window_dimensions), StaticArg(window_strides), StaticArg(padding), StaticArg(base_dilation), StaticArg(window_dilation)], shape=shape, dtype=dtype, window_dimensions=window_dimensions, window_strides=window_strides, padding=padding, base_dilation=base_dilation, window_dilation=window_dilation) for dtype in jtu.dtypes.all_floating for shape in [(4, 6)] for select_prim in [lax.le_p, lax.ge_p] for window_dimensions in [(2, 1), (1, 2)] for window_strides in [(1, 1), (2, 1), (1, 2)] for padding in tuple(set([tuple(lax.padtype_to_pads(shape, window_dimensions, window_strides, p)) for p in ['VALID', 'SAME']] + [((0, 3), (1, 2))])) for base_dilation in [(1, 1)] for window_dilation in [(1, 1)] ) + tuple( # Tests with 4d shapes (see tests.lax_autodiff_test.testReduceWindowGrad) Harness(f"4d_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}", lax._select_and_gather_add, [RandArg(shape, dtype), RandArg(shape, dtype), StaticArg(select_prim), StaticArg(window_dimensions), StaticArg(window_strides), StaticArg(padding), StaticArg(base_dilation), StaticArg(window_dilation)], shape=shape, dtype=dtype, window_dimensions=window_dimensions,
StaticArg(base_dilation), StaticArg(window_dilation) ], shape=shape, dtype=dtype, window_dimensions=window_dimensions, window_strides=window_strides, padding=padding, base_dilation=base_dilation, window_dilation=window_dilation) for dtype in jtu.dtypes.all_floating for shape in [(4, 6)] for select_prim in [lax.le_p, lax.ge_p] for window_dimensions in [(2, 1), (1, 2)] for window_strides in [(1, 1), (2, 1), (1, 2)] for padding in tuple( set([ tuple( lax.padtype_to_pads(shape, window_dimensions, window_strides, p)) for p in ['VALID', 'SAME'] ] + [((0, 3), (1, 2))])) for base_dilation in [(1, 1)] for window_dilation in [(1, 1)] ) + tuple( # Tests with 4d shapes (see tests.lax_autodiff_test.testReduceWindowGrad) Harness( f"4d_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}", lax._select_and_gather_add, [ RandArg(shape, dtype), RandArg(shape, dtype), StaticArg(select_prim), StaticArg(window_dimensions), StaticArg(window_strides), StaticArg(padding), StaticArg(base_dilation), StaticArg(window_dilation)