def _ndarray_constant_handler(val: np.ndarray, canonicalize_types
                              ) -> Sequence[ir.Value]:
  """Constant handler for ndarray literals, handling zero-size strides.

  In most cases this function simply calls ``_numpy_array_constant(val)``.
  Arrays with at least one stride of size zero are handled specially. An
  example is the output of ``np.broadcast_to``, which lax.broadcast uses.
  For those arrays, every element along a zero-stride axis aliases the same
  memory. So only a single slice along each such axis is staged in as a
  constant, and an ``mhlo.BroadcastInDimOp`` expands it back to the full
  shape. This avoids staging in large literals with repeated contents.

  ``float0`` arrays carry no data and are lowered as an all-False boolean
  constant of the same shape.

  Args:
    val: an ndarray.
    canonicalize_types: forwarded to ``_numpy_array_constant`` to control
      whether the constant's dtype is canonicalized. Not used for float0
      inputs, which are always lowered with ``canonicalize_types=False``.

  Returns:
    A sequence of MLIR values representing the constant ndarray.
  """
  if dtypes.result_type(val) == dtypes.float0:
    return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_),
                                 canonicalize_types=False)

  # Boolean mask over axes: True where the axis has stride 0. For a 0-d
  # array, strides is () and the mask is empty, so .any() is False.
  is_zero_stride = np.equal(0, val.strides)
  if is_zero_stride.any() and val.size > 0:
    # Index 0 along each zero-stride axis; keep every other axis whole.
    collapsed_val = val[tuple(0 if zero else slice(None)
                              for zero in is_zero_stride)]
    operand = _numpy_array_constant(collapsed_val, canonicalize_types)[0]
    # Derive the result element type from the operand rather than from
    # xla.abstractify(val). abstractify is computed without regard to
    # `canonicalize_types`, so its dtype need not match the operand's.
    # broadcast_in_dim requires the two element types to agree.
    element_type = ir.RankedTensorType(operand.type).element_type
    result_type = ir.RankedTensorType.get(val.shape, element_type)
    # The operand's dimensions are the non-zero-stride axes of `val`, in
    # order; map each onto its position in the full result shape.
    other_axes, = np.where(np.logical_not(is_zero_stride))
    out = mhlo.BroadcastInDimOp(
        result_type, operand, dense_int_elements(other_axes)).result
    return (out,)

  return _numpy_array_constant(val, canonicalize_types)
def _broadcast(x, aval):
  """Broadcast ``x`` (with abstract value ``aval``) to the type of ``aval_out``.

  The dimensions of ``x`` are aligned with the trailing dimensions of the
  result. Operand dimension ``i`` maps to result dimension
  ``rank - len(aval.shape) + i``.

  NOTE(review): ``aval_out`` and ``rank`` are not bound in this block. They
  presumably come from an enclosing function scope (the output aval and its
  rank) — confirm. If neither name is bound in an enclosing or module scope,
  calling this raises NameError.
  """
  target_type = mlir.aval_to_ir_type(aval_out)
  first_dim = rank - len(aval.shape)
  dim_mapping = mlir.dense_int_elements(range(first_dim, rank))
  broadcast_op = mhlo.BroadcastInDimOp(target_type, x, dim_mapping)
  return broadcast_op.result