def _threefry2x32_gpu_lowering(threefry2x32_lowering, ctx, k1, k2, x1, x2):
  """Broadcast the threefry2x32 operands to the output shape and delegate.

  Args:
    threefry2x32_lowering: callable that emits the actual kernel call. It is
      invoked as ``threefry2x32_lowering((k1, k2), (x1, x2))`` with every
      operand already broadcast to the IR type of the first output.
    ctx: lowering context. Exactly four input avals and two output avals are
      expected (the unpacking below raises otherwise).
    k1, k2, x1, x2: IR values for the four operands, in that order.

  Returns:
    Whatever ``threefry2x32_lowering`` returns. If the output shape contains
    a zero-length dimension, it instead returns a two-element list holding
    the same zero-filled constant twice, and no kernel is emitted.
  """
  first_out_aval, _second_out_aval = ctx.avals_out
  key1_aval, key2_aval, cnt1_aval, cnt2_aval = ctx.avals_in
  out_rank = len(first_out_aval.shape)

  # Nothing to compute for an empty output; hand back constants directly.
  if 0 in first_out_aval.shape:
    zero_fill = mlir.full_like_aval(0, first_out_aval)
    return [zero_fill, zero_fill]

  out_ir_type = mlir.aval_to_ir_type(first_out_aval)
  # Each operand's dimensions map onto the *trailing* dimensions of the
  # output (right-aligned broadcasting). Ops are created in k1, k2, x1, x2
  # order.
  widened = [
      mhlo.BroadcastInDimOp(
          out_ir_type,
          value,
          mlir.dense_int_elements(
              range(out_rank - len(value_aval.shape), out_rank)),
      ).result
      for value, value_aval in zip(
          (k1, k2, x1, x2), (key1_aval, key2_aval, cnt1_aval, cnt2_aval))
  ]
  keys = (widened[0], widened[1])
  counts = (widened[2], widened[3])
  return threefry2x32_lowering(keys, counts)
def _reduce_window_padding_array(padding):
  """Return ``padding`` as an int64 array of shape ``(len(padding), 2)``.

  ``np.asarray`` alone yields shape ``(0,)`` for an empty padding sequence
  (a rank-0 operand). The ReduceWindowOp ``padding`` attribute is laid out as
  ``[rank, 2]`` (low, high) pairs, so the array is reshaped explicitly. For
  non-empty padding made of pairs the reshape is a no-op.
  """
  return np.asarray(padding, np.int64).reshape(len(padding), 2)


def _reduce_window_lower(
    reduce_op, init_value, ctx, operand, *,
    window_dimensions, window_strides, padding, base_dilation,
    window_dilation):
  """Lower a windowed reduction to an ``mhlo.ReduceWindowOp``.

  Args:
    reduce_op: callable applied to the reducer region's two scalar block
      arguments. Its result is returned from the region via
      ``mhlo.ReturnOp``.
    init_value: callable mapping the operand dtype to the initial value. That
      value is materialized as a scalar constant with the operand's dtype.
    ctx: lowering context. Exactly one input aval and one output aval are
      expected.
    operand: IR value of the array being reduced.
    window_dimensions, window_strides, base_dilation, window_dilation:
      per-dimension integer sequences, encoded as dense int attributes.
    padding: sequence of ``(low, high)`` pairs, one per operand dimension.
      It may be empty for a rank-0 operand.

  Returns:
    The results of the created ``ReduceWindowOp``.
  """
  aval_out, = ctx.avals_out
  operand_aval, = ctx.avals_in
  # The reducer operates on scalars of the operand's element type.
  scalar_aval = operand_aval.update(shape=())
  scalar_type = mlir.aval_to_ir_type(scalar_aval)
  rw = mhlo.ReduceWindowOp(
      mlir.aval_to_ir_types(aval_out), [operand],
      [mlir.full_like_aval(init_value(scalar_aval.dtype), scalar_aval)],
      mlir.dense_int_elements(window_dimensions),
      mlir.dense_int_elements(window_strides),
      mlir.dense_int_elements(base_dilation),
      mlir.dense_int_elements(window_dilation),
      # Reshape so an empty padding gives a (0, 2) attribute, not (0,).
      ir.DenseIntElementsAttr.get(_reduce_window_padding_array(padding)))
  # Build the reducer body: (scalar, scalar) -> reduce_op(scalar, scalar).
  reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer):
    mhlo.ReturnOp(reduce_op(*reducer.arguments))
  return rw.results