def _select_and_gather_add_transpose(
    t, tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Transpose select-and-gather-add with respect to ``tangents``.

  The incoming cotangent ``t`` is routed back through
  ``_select_and_scatter_add`` using ``operand`` to decide which window
  position receives each value.

  Args:
    t: cotangent of the output, or an ``ad_util.Zero``.
    tangents: the undefined primal being transposed; only its ``aval`` is
      read, and only on the zero-cotangent path.
    operand: the defined primal that drives the window selection.
    select_prim: ``lax.ge_p`` or ``lax.le_p``.
    window_dimensions, window_strides, padding: forwarded unchanged to
      ``_select_and_scatter_add``.
    base_dilation: per-dimension base dilation factors.
    window_dilation: per-dimension window dilation factors; every entry
      must be 1.

  Returns:
    A two-element list ``[cotangent, None]``. The first entry is the
    cotangent for ``tangents`` (an ``ad_util.Zero`` when ``t`` is one) and
    the second entry is always ``None``.

  Raises:
    NotImplementedError: if any window dilation factor differs from 1.
  """
  assert select_prim in (lax.le_p, lax.ge_p)
  assert (ad.is_undefined_primal(tangents)
          and not ad.is_undefined_primal(operand))

  # Window dilation is rejected before the zero-cotangent short-circuit, so
  # the error is raised even when ``t`` is a symbolic zero.
  if any(factor != 1 for factor in window_dilation):
    raise NotImplementedError(
        ("VJP not implemented for select_and_gather (MaxPool) with window "
         "dilation, got window_dilation={}.").format(window_dilation))

  if type(t) is ad_util.Zero:
    return [ad_util.Zero(tangents.aval), None]

  if not any(factor != 1 for factor in base_dilation):
    # No base dilation: scatter the cotangent straight onto the operand grid.
    scattered = _select_and_scatter_add(
        t, operand, select_prim, window_dimensions, window_strides, padding)
    return [scattered, None]

  # Base dilation: insert ``factor - 1`` interior elements between operand
  # entries, filled with the identity obtained from the helper matching
  # ``select_prim``.
  if select_prim is lax.ge_p:
    identity_of = lax._get_max_identity
  else:
    identity_of = lax._get_min_identity
  fill_value = identity_of(operand.dtype)
  interior_config = tuple((0, 0, factor - 1) for factor in base_dilation)
  dilated_operand = lax.pad(operand, fill_value, interior_config)

  scattered = _select_and_scatter_add(
      t, dilated_operand, select_prim, window_dimensions, window_strides,
      padding)

  # Undo the dilation: a strided slice starting at the origin keeps every
  # ``factor``-th element along each dimension.
  origin = (0,) * len(scattered.shape)
  undilated = slicing.slice(scattered, origin, scattered.shape, base_dilation)
  return [undilated, None]
def _select_and_scatter_add_impl(source, operand, *,
                                 select_prim, window_dimensions,
                                 window_strides, padding, expand_padding):
  """Implement select-and-scatter-add on top of ``_select_and_scatter``.

  Args:
    source: values to scatter; its dtype picks the scatter combiner and the
      padding identity.
    operand: array over which windows are selected.
    select_prim: comparison primitive bound pairwise to choose the window
      element; ``lax.ge_p`` picks the max identity for padding, anything
      else the min identity.
    window_dimensions, window_strides: forwarded unchanged to
      ``_select_and_scatter``.
    padding: sequence of ``(lo, hi)`` pairs, one per dimension.
    expand_padding: when true, the padding is applied to ``operand``
      explicitly with ``lax.pad`` and cropped from the result afterwards,
      and ``_select_and_scatter`` is called with all-zero padding.

  Returns:
    The output of ``_select_and_scatter``, cropped to the original operand
    extent when ``expand_padding`` is true.
  """
  dtype = source.dtype
  select = lambda x, y: select_prim.bind(x, y)

  # Booleans are combined with bitwise-or; every other dtype with addition.
  if dtype == np.bool_:
    scatter = lax.bitwise_or
  else:
    scatter = lax.add

  if not expand_padding:
    return _select_and_scatter(
        operand, select, window_dimensions, window_strides, padding, source,
        lax._zero(operand), scatter)

  unpadded_shape = operand.shape

  # Materialise the padding on the operand itself, filled with the identity
  # chosen by ``select_prim``.
  if select_prim is lax.ge_p:
    identity_of = lax._get_max_identity
  else:
    identity_of = lax._get_min_identity
  edge_config = [(lo, hi, 0) for lo, hi in padding]
  padded_operand = lax.pad(operand, identity_of(dtype), edge_config)
  no_padding = [(0, 0) for _ in padding]

  padded_out = _select_and_scatter(
      padded_operand, select, window_dimensions, window_strides, no_padding,
      source, lax._zero(padded_operand), scatter)

  # Crop the explicit padding: keep ``[lo, lo + dim)`` along each dimension.
  starts = [lo for lo, _ in padding]
  stops = [lo + dim for (lo, _), dim in zip(padding, unpadded_shape)]
  return slicing.slice(padded_out, starts, stops)