def test_compile_custom_op(self):
    """Compile a minimal CPU custom op inline and verify its output.

    The op fills its output tensor with x[i] = -i; the test checks
    shape, dtype, and every element of the flattened result.
    """
    # C++ declaration: a single-output op named "my".
    op_header = """
    struct MyOp : Op {
        Var* output;
        MyOp(NanoVector shape, NanoString dtype=ns_float32);
        const char* name() const override { return "my"; }
        DECLARE_jit_run;
    };
    """
    # C++ definition: constructor + JIT prepare/run split by the JIT guard.
    op_source = """
    #ifndef JIT
    MyOp::MyOp(NanoVector shape, NanoString dtype) {
        output = create_output(shape, dtype);
    }

    void MyOp::jit_prepare(JK& jk) {
        add_jit_define(jk, "T", output->dtype());
    }

    #else // JIT
    void MyOp::jit_run() {
        index_t num = output->num;
        auto* __restrict__ x = output->ptr<T>();
        for (index_t i=0; i<num; i++)
            x[i] = (T)-i;
    }
    #endif // JIT
    """
    my_op = jt.compile_custom_op(op_header, op_source, "my")

    out = my_op([3, 4, 5], 'float')
    arr = out.data
    assert out.shape == [3, 4, 5] and out.dtype == 'float'
    # Negating recovers 0, 1, 2, ... for every element in order.
    assert (-arr.flatten() == range(3 * 4 * 5)).all(), arr
def test_no_cuda_op(self):
    """A CPU-only custom op must raise when execution is forced onto CUDA.

    The op sets only NodeFlags::_cpu, so fetching its result while CUDA
    is mandatory has no valid backend and should error out.
    """
    no_cuda_op = jt.compile_custom_op("""
    struct NoCudaOp : Op {
        Var* output;
        NoCudaOp(NanoVector shape, string dtype="float");
        const char* name() const override { return "my_cuda"; }
        DECLARE_jit_run;
    };
    """, """
    #ifndef JIT
    NoCudaOp::NoCudaOp(NanoVector shape, string dtype) {
        // CPU-only: no _cuda flag, so a forced-CUDA run cannot schedule it.
        flags.set(NodeFlags::_cpu);
        output = create_output(shape, dtype);
    }

    void NoCudaOp::jit_prepare(JK& jk) {
        add_jit_define(jk, "T", output->dtype());
    }

    #else // JIT
    void NoCudaOp::jit_run() {}
    #endif // JIT
    """, "no_cuda")
    # force use cuda
    # BUG FIX: without this scope nothing actually forces CUDA, the CPU op
    # runs successfully, and expect_error fails. use_cuda=2 means CUDA is
    # mandatory with no CPU fallback — TODO confirm against jt flag docs.
    with jt.var_scope(use_cuda=2):
        a = no_cuda_op([3, 4, 5], 'float')
        expect_error(lambda: a())
def test_cuda_custom_op(self):
    """Compile a CUDA custom op and verify its output values.

    A grid of CUDA threads fills x[i] = -i; the test runs it under a
    CUDA-enabled scope and checks shape, dtype, and every element.
    """
    my_op = jt.compile_custom_op("""
    struct MyCudaOp : Op {
        Var* output;
        MyCudaOp(NanoVector shape, string dtype="float");
        const char* name() const override { return "my_cuda"; }
        DECLARE_jit_run;
    };
    """, """
    #ifndef JIT
    MyCudaOp::MyCudaOp(NanoVector shape, string dtype) {
        flags.set(NodeFlags::_cuda);
        output = create_output(shape, dtype);
    }

    // CONSISTENCY FIX: use the JK& API like the other ops in this file
    // (jit_prepare(JK& jk) / add_jit_define(jk, ...)); the old zero-arg
    // signature would not match the override the framework expects.
    void MyCudaOp::jit_prepare(JK& jk) {
        add_jit_define(jk, "T", output->dtype());
    }

    #else // JIT
    #ifdef JIT_cuda
    __global__ void kernel(index_t n, T *x) {
        int index = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;
        for (int i = index; i < n; i += stride)
            x[i] = (T)-i;
    }

    void MyCudaOp::jit_run() {
        index_t num = output->num;
        auto* __restrict__ x = output->ptr<T>();
        int blockSize = 256;
        int numBlocks = (num + blockSize - 1) / blockSize;
        kernel<<<numBlocks, blockSize>>>(num, x);
    }
    #endif // JIT_cuda
    #endif // JIT
    """, "my_cuda")

    with jt.var_scope(use_cuda=1):
        a = my_op([3, 4, 5], 'float')
        na = a.data
    assert a.shape == [3, 4, 5] and a.dtype == 'float'
    # Negating recovers 0, 1, 2, ... for every element in order.
    assert (-na.flatten() == range(3 * 4 * 5)).all(), na