Beispiel #1
0
def _select_and_scatter_add_translation(ctx, avals_in, avals_out, source,
                                        operand, *, select_prim,
                                        window_dimensions, window_strides,
                                        padding, expand_padding):
    """Lower select-and-scatter-add to XLA SelectAndScatterWithGeneralPadding.

    Within each window of ``operand`` an element is chosen using the
    ``select_prim`` subcomputation. The matching values from ``source`` are
    scattered into a zero-initialised output. Collisions are combined with
    ``lax.or_p`` for boolean dtypes and with ``lax.add_p`` otherwise.

    Args:
      ctx: translation context. ``ctx.builder``, ``ctx.platform`` and
        ``ctx.axis_env`` are read.
      avals_in: pair ``(source_aval, operand_aval)``. Only the operand's
        ``dtype`` and ``shape`` are used.
      avals_out: unused.
      source: XLA op with the values to scatter.
      operand: XLA op over which windows are evaluated.
      select_prim: primitive used to build the per-window select
        computation. When it is ``lax.ge_p``, explicit padding uses
        ``lax._get_max_identity``; otherwise it uses ``lax._get_min_identity``.
      window_dimensions: window sizes, forwarded to XLA.
      window_strides: window strides, forwarded to XLA.
      padding: sequence of ``(lo, hi)`` pairs, one per dimension.
      expand_padding: if truthy and any dimension has non-zero padding, pad
        ``operand`` explicitly. XLA is then given zero padding, and the
        result is sliced back to the operand's shape.

    Returns:
      A one-element list holding the resulting XLA op.
    """
    _, operand_aval = avals_in
    builder = ctx.builder
    operand_dtype = operand_aval.dtype
    scalar_aval = ShapedArray((), operand_dtype)

    select_comp = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                               select_prim, scalar_aval,
                                               scalar_aval)
    # Booleans cannot be summed, so overlapping scatters are OR-ed instead.
    combine_prim = lax.or_p if operand_dtype == np.bool_ else lax.add_p
    scatter_comp = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                                combine_prim, scalar_aval,
                                                scalar_aval)
    init_value = xla.pyval_to_ir_constant(builder,
                                          np.array(0, operand_dtype))

    # TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
    # Explicit padding is only needed when at least one side is non-zero.
    pad_explicitly = expand_padding and any(
        lo != 0 or hi != 0 for (lo, hi) in padding)

    window_padding = padding
    if pad_explicitly:
        pad_fn = (lax._get_max_identity if select_prim is lax.ge_p
                  else lax._get_min_identity)
        pad_value = xla.pyval_to_ir_constant(builder, pad_fn(operand_dtype))
        operand = xops.Pad(
            operand, pad_value,
            xc.make_padding_config([(lo, hi, 0) for (lo, hi) in padding]))
        # The operand now carries the padding, so XLA must not add more.
        window_padding = [(0, 0) for _ in padding]

    result = xops.SelectAndScatterWithGeneralPadding(
        operand, select_comp, window_dimensions, window_strides,
        window_padding, source, init_value, scatter_comp)

    if pad_explicitly:
        # Strip the explicit padding so the output matches the operand shape.
        starts = [lo for (lo, _) in padding]
        limits = [
            lo + size
            for ((lo, _), size) in zip(padding, operand_aval.shape)
        ]
        result = xops.Slice(result, starts, limits, [1] * len(starts))
    return [result]
Beispiel #2
0
def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
                                        window_dimensions, window_strides,
                                        padding, base_dilation,
                                        window_dilation):
    """Lower a windowed sum to XLA ReduceWindowWithGeneralPadding.

    The reduction starts from a zero constant of the operand's dtype and
    combines elements with a ``lax.add_p`` scalar subcomputation.

    Args:
      ctx: translation context. ``ctx.builder``, ``ctx.platform`` and
        ``ctx.axis_env`` are read.
      avals_in: one-element sequence holding the operand's abstract value.
      avals_out: unused.
      operand: XLA op to reduce.
      window_dimensions: window sizes, forwarded unchanged to XLA.
      window_strides: window strides, forwarded unchanged to XLA.
      padding: window padding, forwarded unchanged to XLA.
      base_dilation: operand dilation, forwarded unchanged to XLA.
      window_dilation: window dilation, forwarded unchanged to XLA.

    Returns:
      A one-element list holding the reduce-window op.
    """
    (operand_aval,) = avals_in
    dtype = operand_aval.dtype
    scalar_aval = ShapedArray((), dtype)
    init_value = xla.pyval_to_ir_constant(ctx.builder, np.array(0, dtype))
    add_comp = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                            lax.add_p, scalar_aval,
                                            scalar_aval)
    reduced = xops.ReduceWindowWithGeneralPadding(
        operand, init_value, add_comp, window_dimensions, window_strides,
        base_dilation, window_dilation, padding)
    return [reduced]
Beispiel #3
0
def sparse_array_constant_handler(c, val, canonicalize_dtypes):
  """Build XLA constants for a sparse array's ``data`` and ``indices``.

  Args:
    c: builder passed as the first argument to ``xla.pyval_to_ir_constant``.
    val: object exposing ``data`` and ``indices`` attributes.
    canonicalize_dtypes: forwarded to ``xla.pyval_to_ir_constant``
      (presumably its dtype-canonicalization flag -- confirm against xla).

  Returns:
    A ``(data_constant, indices_constant)`` tuple.
  """
  # pyval_to_ir_constant takes the builder as its first argument; without
  # `c` the arrays would be passed in the builder slot.
  return (
    xla.pyval_to_ir_constant(c, val.data, canonicalize_dtypes),
    xla.pyval_to_ir_constant(c, val.indices, canonicalize_dtypes),
  )