def _conv_general_vjp_lhs_padding(
    in_shape, window_dimensions, window_strides, out_shape, padding,
    lhs_dilation, rhs_dilation) -> List[Tuple[int, int]]:
  lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
  rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
  out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
  pad_before = np.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1
  pad_after = (np.add(lhs_dilated_shape, rhs_dilated_shape) - 1
               - out_dilated_shape - pad_before)
  return safe_zip(pad_before, pad_after)
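# Illustrative sketch (comment only, not part of the library): this is the
# "transposed convolution" padding used when differentiating a convolution
# with respect to its lhs. For a 1D VALID convolution with in_shape=(5,),
# window_dimensions=(3,), unit strides and dilations, and padding=((0, 0),),
# the arithmetic above gives
#
#   pad_before = 3 - 0 - 1 = 2
#   pad_after  = 5 + 3 - 1 - 3 - 2 = 2
#
# so _conv_general_vjp_lhs_padding returns [(2, 2)]: padding the length-3
# output to length 7 and convolving with the 3-tap filter recovers a gradient
# of the original length 5.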
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding, base_dilation=None,
                              window_dilation=None):
  if base_dilation is not None:
    operand_shape = lax._dilate_shape(operand_shape, base_dilation)
  if window_dilation is not None:
    window_dimensions = lax._dilate_shape(window_dimensions, window_dilation)
  pads_lo, pads_hi = zip(*padding)
  operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
  return core.stride_shape(operand_padded, window_dimensions, window_strides)
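# Worked example (comment only): a (10, 10) operand reduced with a (3, 3)
# window, strides (2, 2), no padding, and no dilation:
#
#   reduce_window_shape_tuple((10, 10), (3, 3), (2, 2), ((0, 0), (0, 0)))
#
# pads the operand to (10, 10) and then strides the window across it, giving
# an output shape of (4, 4), i.e. floor((10 - 3) / 2) + 1 per dimension.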
def _conv_general_vjp_rhs_padding(
    in_shape, window_dimensions, window_strides, out_shape, padding,
    lhs_dilation, rhs_dilation):
  if len(in_shape) == 0:  # 0D conv
    return []
  lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
  rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
  out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
  pads_lo, _ = zip(*padding)
  pads_from_lhs = core.diff_shape(out_dilated_shape, lhs_dilated_shape)
  pads_from_rhs = core.diff_shape(core.diff_shape(rhs_dilated_shape, pads_lo),
                                  (1,) * len(pads_lo))
  pads_hi = core.sum_shapes(pads_from_lhs, pads_from_rhs)
  return list(zip(pads_lo, pads_hi))
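# Worked example (comment only): for the same 1D VALID convolution as above
# (in_shape=(5,), window_dimensions=(3,), unit strides and dilations,
# padding=((0, 0),)), pads_lo = 0 and
#
#   pads_hi = (3 - 5) + (3 - 0 - 1) = 0
#
# so _conv_general_vjp_rhs_padding returns [(0, 0)]: correlating the length-5
# input with the length-3 output yields the length-3 filter gradient without
# any extra padding.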
def reduce_window(operand, init_value, computation: Callable,
                  window_dimensions: core.Shape, window_strides: Sequence[int],
                  padding: Union[str, Sequence[Tuple[int, int]]],
                  base_dilation: Optional[Sequence[int]] = None,
                  window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `ReduceWindowWithGeneralPadding
  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
  operator.
  """
  flat_operands, operand_tree = tree_util.tree_flatten(operand)
  flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
  if operand_tree != init_value_tree:
    raise ValueError('Operands must have the same tree structure as '
                     f'init_values: {operand_tree} vs. {init_value_tree}')
  if len(flat_operands) == 0:
    raise ValueError('reduce_window must have at least one operand.')
  if len(flat_operands) != len(flat_init_values):
    raise ValueError('Must have same total number of operands as init_values: '
                     f' {len(flat_operands)} vs. {len(flat_init_values)}')
  if isinstance(padding, str):
    dilated_window_dims = (
        window_dimensions if window_dilation is None
        else lax._dilate_shape(window_dimensions, window_dilation))
    padding = tuple(lax.padtype_to_pads(
        flat_operands[0].shape, dilated_window_dims, window_strides, padding))
  else:
    padding = tuple(padding)
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  monoid_reducer = _get_monoid_window_reducer(computation, flat_init_values)
  if monoid_reducer:
    return monoid_reducer(operand, window_dimensions, window_strides, padding,
                          base_dilation, window_dilation)
  else:
    flat_init_avals = map(lax._abstractify, flat_init_values)
    jaxpr, consts, out_tree = lax._variadic_reduction_jaxpr(
        computation, tuple(flat_init_avals), init_value_tree)
    if operand_tree != out_tree:
      raise ValueError(
          'reduce_window output must have the same tree structure as the '
          f'operands: {operand_tree} vs. {out_tree}')
    out_flat = reduce_window_p.bind(
        *(flat_operands + flat_init_values), jaxpr=jaxpr, consts=consts,
        window_dimensions=tuple(window_dimensions),
        window_strides=tuple(window_strides), padding=padding,
        base_dilation=tuple(base_dilation),
        window_dilation=tuple(window_dilation))
    return tree_util.tree_unflatten(out_tree, out_flat)
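# Example usage (illustrative sketch, assuming `jnp` refers to jax.numpy): a
# 2x2 max pooling over an NHWC-shaped array can be expressed with this
# wrapper as
#
#   x = jnp.arange(16.0).reshape(1, 4, 4, 1)
#   pooled = reduce_window(x, -jnp.inf, lax.max,
#                          window_dimensions=(1, 2, 2, 1),
#                          window_strides=(1, 2, 2, 1),
#                          padding='VALID')
#   # pooled.shape == (1, 2, 2, 1)
#
# Because lax.max with a -inf init value matches a known monoid reducer, this
# call should dispatch to the specialized max-window path rather than the
# generic variadic reduction.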
def _conv_general_dilated_shape_rule(
    lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
  assert type(dimension_numbers) is ConvDimensionNumbers
  if len(lhs.shape) != len(rhs.shape):
    msg = ("conv_general_dilated lhs and rhs must have the same number of "
           "dimensions, but got {} and {}.")
    raise ValueError(msg.format(lhs.shape, rhs.shape))
  if not feature_group_count > 0:
    msg = ("conv_general_dilated feature_group_count "
           "must be a positive integer, got {}.")
    raise ValueError(msg.format(feature_group_count))
  lhs_feature_count = lhs.shape[dimension_numbers.lhs_spec[1]]
  quot, rem = divmod(lhs_feature_count, feature_group_count)
  if rem:
    msg = ("conv_general_dilated feature_group_count must divide lhs feature "
           "dimension size, but {} does not divide {}.")
    raise ValueError(msg.format(feature_group_count, lhs_feature_count))
  if not core.symbolic_equal_dim(quot,
                                 rhs.shape[dimension_numbers.rhs_spec[1]]):
    msg = ("conv_general_dilated lhs feature dimension size divided by "
           "feature_group_count must equal the rhs input feature dimension "
           "size, but {} // {} != {}.")
    raise ValueError(msg.format(lhs_feature_count, feature_group_count,
                                rhs.shape[dimension_numbers.rhs_spec[1]]))
  if rhs.shape[dimension_numbers.rhs_spec[0]] % feature_group_count:
    msg = ("conv_general_dilated rhs output feature dimension size must be a "
           "multiple of feature_group_count, but {} is not a multiple of {}.")
    raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                                feature_group_count))
  if not batch_group_count > 0:
    msg = ("conv_general_dilated batch_group_count "
           "must be a positive integer, got {}.")
    raise ValueError(msg.format(batch_group_count))
  lhs_batch_count = lhs.shape[dimension_numbers.lhs_spec[0]]
  if batch_group_count > 1 and lhs_batch_count % batch_group_count != 0:
    msg = ("conv_general_dilated batch_group_count must divide lhs batch "
           "dimension size, but {} does not divide {}.")
    raise ValueError(msg.format(batch_group_count, lhs_batch_count))
  if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
    msg = ("conv_general_dilated rhs output feature dimension size must be a "
           "multiple of batch_group_count, but {} is not a multiple of {}.")
    raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                                batch_group_count))
  if batch_group_count > 1 and feature_group_count > 1:
    msg = ("At most one of batch_group_count and feature_group_count may be > "
           "1, got batch_group_count={} and feature_group_count={}")
    raise ValueError(msg.format(batch_group_count, feature_group_count))
  if len(_conv_sdims(dimension_numbers.rhs_spec)) != len(window_strides):
    msg = ("conv_general_dilated window and window_strides must have "
           "the same number of dimensions, but got {} and {}")
    raise ValueError(
        msg.format(len(_conv_sdims(dimension_numbers.rhs_spec)),
                   len(window_strides)))
  lhs_perm, rhs_perm, out_perm = dimension_numbers
  lhs_trans = lax._dilate_shape(np.take(lhs.shape, lhs_perm), lhs_dilation)
  rhs_trans = lax._dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
                               batch_group_count)
  return tuple(np.take(out_trans, np.argsort(out_perm)))  # type: ignore[arg-type]
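# Illustrative example (comment only, assuming standard NCHW/OIHW/NCHW
# dimension numbers): an (8, 3, 32, 32) lhs convolved with a (16, 3, 3, 3)
# rhs, unit strides and dilations, feature_group_count=1, batch_group_count=1,
# and padding ((1, 1), (1, 1)) passes every check above and yields an output
# shape of (8, 16, 32, 32), since each spatial dimension works out to
# (32 + 1 + 1 - 3) // 1 + 1 = 32.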