Ejemplo n.º 1
0
Archivo: prng.py Proyecto: cloudhan/jax
def _threefry2x32_gpu_lowering(threefry2x32_lowering, ctx, k1, k2, x1, x2):
  """Broadcast all threefry2x32 operands to the output shape, then lower.

  Args:
    threefry2x32_lowering: callable invoked as
      ``threefry2x32_lowering((k1, k2), (x1, x2))`` on the broadcast operands;
      its return value is returned unchanged.
    ctx: lowering context. ``ctx.avals_in`` must hold exactly four avals
      (for k1, k2, x1, x2) and ``ctx.avals_out`` exactly two; the first output
      aval determines the shape and IR type every operand is broadcast to.
    k1, k2: IR values for the two key halves.
    x1, x2: IR values for the two counter halves.

  Returns:
    ``[zeros, zeros]`` (constants shaped like the first output aval) when the
    output shape contains a zero-sized dimension; otherwise the result of
    ``threefry2x32_lowering``.
  """
  out_aval, _ = ctx.avals_out
  k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in

  # An empty result needs no computation: emit zero constants instead.
  if 0 in out_aval.shape:
    empty = mlir.full_like_aval(0, out_aval)
    return [empty, empty]

  out_rank = len(out_aval.shape)
  out_type = mlir.aval_to_ir_type(out_aval)

  def to_out_shape(value, value_aval):
    # Map the operand's dimensions onto the trailing dimensions of the output.
    first_dim = out_rank - len(value_aval.shape)
    dims = mlir.dense_int_elements(range(first_dim, out_rank))
    return mhlo.BroadcastInDimOp(out_type, value, dims).result

  # Broadcasts are emitted in the order k1, k2, x1, x2.
  keys = (to_out_shape(k1, k1_aval), to_out_shape(k2, k2_aval))
  counters = (to_out_shape(x1, x1_aval), to_out_shape(x2, x2_aval))
  return threefry2x32_lowering(keys, counters)
Ejemplo n.º 2
0
def _reduce_window_lower(reduce_op, init_value, ctx, operand, *,
                         window_dimensions, window_strides, padding,
                         base_dilation, window_dilation):
    """Lower a windowed reduction to a single ``mhlo.ReduceWindowOp``.

    Args:
      reduce_op: callable applied to the two scalar arguments of the reducer
        region; its result is passed to ``mhlo.ReturnOp``.
      init_value: callable taking the operand dtype and returning the initial
        value of the reduction.
      ctx: lowering context with exactly one input aval (the operand) and
        exactly one output aval.
      operand: IR value to reduce.
      window_dimensions, window_strides, base_dilation, window_dilation:
        integer sequences forwarded as dense int attributes.
      padding: sequence of ``(low, high)`` integer pairs. It may be empty,
        presumably for a rank-0 operand.

    Returns:
      The results of the created ``ReduceWindowOp``.
    """
    aval_out, = ctx.avals_out
    operand_aval, = ctx.avals_in
    scalar_aval = operand_aval.update(shape=())
    scalar_type = mlir.aval_to_ir_type(scalar_aval)
    # ReduceWindow requires its padding attribute to be an (N, 2) tensor.
    # np.asarray(()) has shape (0,), not (0, 2), so reshape explicitly so that
    # an empty padding still yields a rank-2 attribute.
    padding_array = np.asarray(padding, np.int64).reshape(len(padding), 2)
    rw = mhlo.ReduceWindowOp(
        mlir.aval_to_ir_types(aval_out), [operand],
        [mlir.full_like_aval(init_value(scalar_aval.dtype), scalar_aval)],
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        mlir.dense_int_elements(base_dilation),
        mlir.dense_int_elements(window_dilation),
        ir.DenseIntElementsAttr.get(padding_array))
    # The reducer region takes two scalars of the operand's type and returns
    # reduce_op applied to them.
    reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(reducer):
        mhlo.ReturnOp(reduce_op(*reducer.arguments))
    return rw.results