def _select_and_scatter(operand: Array, select: Callable,
                        window_dimensions: core.Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]], source: Array,
                        init_value: Array, scatter: Callable) -> Array:
  """Trace ``select`` and ``scatter``, then bind ``select_and_scatter_p``.

  Each callable is traced with ``lax._reduction_jaxpr`` against the abstract
  value of ``init_value``. The resulting jaxprs and constants are passed to the
  primitive as parameters, together with the window arguments converted to
  tuples.

  Args:
    operand: first positional operand of the primitive.
    select: callable traced into ``select_jaxpr`` / ``select_consts``.
    window_dimensions: window dimensions; forwarded as a tuple.
    window_strides: window strides; forwarded as a tuple.
    padding: per-dimension integer pairs; forwarded as a tuple.
    source: second positional operand of the primitive.
    init_value: third positional operand; its abstract value is also used
      when tracing both ``select`` and ``scatter``.
    scatter: callable traced into ``scatter_jaxpr`` / ``scatter_consts``.

  Returns:
    Whatever ``select_and_scatter_p.bind`` returns.
  """
  def trace(fn):
    # Abstractify on every call so the call sequence matches tracing each
    # callable independently.
    return lax._reduction_jaxpr(fn, lax._abstractify(init_value))

  select_jaxpr, select_consts = trace(select)
  scatter_jaxpr, scatter_consts = trace(scatter)

  params = dict(
      select_jaxpr=select_jaxpr,
      select_consts=select_consts,
      scatter_jaxpr=scatter_jaxpr,
      scatter_consts=scatter_consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
  )
  return select_and_scatter_p.bind(operand, source, init_value, **params)
def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]],
                        base_dilation: Optional[Sequence[int]] = None,
                        window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Windowed product reduction built on ``reduce_window_p``.

  The reduction computation is ``lax.mul`` traced with the abstract value of
  a constant ``1`` shaped like ``operand`` (via ``lax._const``). That same
  constant is passed as the initial value.

  Args:
    operand: array to reduce.
    window_dimensions: window dimensions; forwarded as a tuple.
    window_strides: window strides; forwarded as a tuple.
    padding: per-dimension integer pairs; forwarded as a tuple.
    base_dilation: optional; when ``None`` it defaults to
      ``(1,) * len(window_dimensions)``.
    window_dilation: optional; when ``None`` it defaults to
      ``(1,) * len(window_dimensions)``.

  Returns:
    The single output unpacked from ``reduce_window_p.bind``.
  """
  one = lax._const(operand, 1)
  mul_jaxpr, mul_consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(one))

  # Unit dilation along every window dimension when not supplied.
  base = ((1,) * len(window_dimensions)
          if base_dilation is None else base_dilation)
  win = ((1,) * len(window_dimensions)
         if window_dilation is None else window_dilation)

  (result,) = reduce_window_p.bind(
      operand, one,
      jaxpr=mul_jaxpr,
      consts=mul_consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(win),
  )
  return result