예제 #1
0
파일: mlir.py 프로젝트: rsepassi/jax
def _ndarray_constant_handler(val: np.ndarray,
                              canonicalize_types) -> Sequence[ir.Value]:
    """Lower an ndarray literal to IR values, compacting zero-stride axes.

    Three cases are handled, in order:

    * ``float0`` arrays are emitted as an all-zeros boolean constant of the
      same shape, with type canonicalization disabled.
    * Non-empty arrays with at least one zero stride (as produced by, e.g.,
      ``np.broadcast_to``) are emitted as a constant holding only the
      non-broadcast data, followed by a ``BroadcastInDimOp`` that expands it
      back to the full shape. This avoids staging a large literal whose
      contents are just repeats of a smaller one.
    * Everything else is handed directly to ``_numpy_array_constant``.

    Args:
      val: the ndarray to lower.
      canonicalize_types: forwarded to ``_numpy_array_constant`` (except in
        the ``float0`` case, where ``False`` is always passed).

    Returns:
      The IR values for the constant: a one-element tuple in the broadcast
      case, otherwise whatever ``_numpy_array_constant`` returns.
    """
    if dtypes.result_type(val) == dtypes.float0:
        placeholder = np.zeros(val.shape, dtype=np.bool_)
        return _numpy_array_constant(placeholder, canonicalize_types=False)

    # One boolean per axis: True where that axis has stride 0.
    stride_is_zero = np.equal(0, val.strides)
    if not (np.any(stride_is_zero) and val.size > 0):
        return _numpy_array_constant(val, canonicalize_types)

    # Axes that carry real data; these become the broadcast dimensions.
    (dense_axes,) = np.nonzero(np.logical_not(stride_is_zero))

    # Index 0 along every zero-stride axis, keep the others whole, so that
    # only the non-repeated data remains.
    selector = tuple(0 if repeated else slice(None)
                     for repeated in stride_is_zero)
    compact = val[selector]

    result_type = aval_to_ir_type(xla.abstractify(val))
    operand = _numpy_array_constant(compact, canonicalize_types)[0]
    broadcast_dims = dense_int_elements(dense_axes)
    broadcast_op = mhlo.BroadcastInDimOp(result_type, operand, broadcast_dims)
    return (broadcast_op.result,)
예제 #2
0
파일: prng.py 프로젝트: xueeinstein/jax
 def _broadcast(x, aval):
     """Broadcast ``x`` to the IR type of ``aval_out``.

     The dimensions of ``x`` (described by ``aval``) are mapped onto the last
     ``len(aval.shape)`` dimensions of the output, i.e. the indices
     ``rank - len(aval.shape)`` through ``rank - 1``.

     ``aval_out`` and ``rank`` are taken from the enclosing scope.
     NOTE(review): presumably ``rank`` is the rank of ``aval_out`` -- confirm
     against the enclosing function.
     """
     target_type = mlir.aval_to_ir_type(aval_out)
     first_mapped_dim = rank - len(aval.shape)
     mapped_dims = mlir.dense_int_elements(range(first_mapped_dim, rank))
     broadcast_op = mhlo.BroadcastInDimOp(target_type, x, mapped_dims)
     return broadcast_op.result