def _select_and_scatter_add_translation(
    ctx, avals_in, avals_out, source, operand, *, select_prim,
    window_dimensions, window_strides, padding, expand_padding):
  """Lower select-and-scatter-add to XLA ``SelectAndScatterWithGeneralPadding``.

  Args:
    ctx: translation context; its ``builder``, ``platform`` and ``axis_env``
      are used to build constants and scalar subcomputations.
    avals_in: pair ``(source_aval, operand_aval)``; only the operand aval's
      dtype and shape are consulted.
    avals_out: unused.
    source: XLA op holding the values to scatter.
    operand: XLA op the windowed selection runs over.
    select_prim: comparison primitive used as the window selector.
    window_dimensions, window_strides, padding: window configuration;
      ``padding`` is a sequence of ``(lo, hi)`` pairs.
    expand_padding: if truthy and any padding is non-zero, pad the operand
      explicitly and pass all-zero padding to the XLA op instead.

  Returns:
    A one-element list containing the resulting XLA op.
  """
  _, operand_aval = avals_in
  builder = ctx.builder
  dtype = operand_aval.dtype
  scalar = ShapedArray((), dtype)

  def scalar_subcomputation(prim):
    # Binary scalar -> scalar computation of `prim` at the operand dtype.
    return xla.primitive_subcomputation(
        ctx.platform, ctx.axis_env, prim, scalar, scalar)

  select = scalar_subcomputation(select_prim)
  # Booleans cannot be added, so scattered values are combined with `or`.
  scatter = scalar_subcomputation(lax.or_p if dtype == np.bool_ else lax.add_p)
  zero = xla.pyval_to_ir_constant(builder, np.array(0, dtype))

  # TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
  pad_explicitly = expand_padding and any(
      lo != 0 or hi != 0 for lo, hi in padding)

  if not pad_explicitly:
    return [xops.SelectAndScatterWithGeneralPadding(
        operand, select, window_dimensions, window_strides, padding,
        source, zero, scatter)]

  # Fill the padded border with the identity for the selector, so that
  # padding elements are never the ones a window selects.
  if select_prim is lax.ge_p:
    identity = lax._get_max_identity
  else:
    identity = lax._get_min_identity
  fill_value = xla.pyval_to_ir_constant(builder, identity(dtype))
  padded_operand = xops.Pad(
      operand, fill_value,
      xc.make_padding_config([(lo, hi, 0) for lo, hi in padding]))
  zero_padding = [(0, 0) for _ in padding]

  scattered = xops.SelectAndScatterWithGeneralPadding(
      padded_operand, select, window_dimensions, window_strides,
      zero_padding, source, zero, scatter)

  # Crop the padded result back to the original operand's extent.
  starts = [lo for lo, _ in padding]
  stops = [lo + size for (lo, _), size in zip(padding, operand_aval.shape)]
  return [xops.Slice(scattered, starts, stops, [1] * len(starts))]
def _reduce_window_sum_translation_rule(
    ctx, avals_in, avals_out, operand, *, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Lower a windowed sum to XLA ``ReduceWindowWithGeneralPadding``.

  The reduction uses a zero init value of the operand's dtype and a scalar
  ``add`` subcomputation. ``avals_in`` must contain exactly one aval (the
  operand's); ``avals_out`` is unused.

  Returns:
    A one-element list containing the resulting XLA op.
  """
  (operand_aval,) = avals_in
  dtype = operand_aval.dtype
  scalar = ShapedArray((), dtype)
  init_value = xla.pyval_to_ir_constant(ctx.builder, np.array(0, dtype))
  add_computation = xla.primitive_subcomputation(
      ctx.platform, ctx.axis_env, lax.add_p, scalar, scalar)
  reduced = xops.ReduceWindowWithGeneralPadding(
      operand, init_value, add_computation, window_dimensions,
      window_strides, base_dilation, window_dilation, padding)
  return [reduced]
def sparse_array_constant_handler(c, val, canonicalize_dtypes):
  """Build XLA constants for the ``data`` and ``indices`` fields of ``val``.

  Args:
    c: the XLA builder the constants are added to.
    val: a sparse-array value exposing ``.data`` and ``.indices`` arrays.
    canonicalize_dtypes: forwarded to ``xla.pyval_to_ir_constant``.

  Returns:
    A ``(data_constant, indices_constant)`` tuple of XLA ops.
  """
  # The builder must come first, matching the other
  # `xla.pyval_to_ir_constant(builder, value)` call sites. Previously `c` was
  # omitted, so the array was passed as the builder and the flag as the value.
  return (
      xla.pyval_to_ir_constant(c, val.data, canonicalize_dtypes),
      xla.pyval_to_ir_constant(c, val.indices, canonicalize_dtypes),
  )