Example #1
0
 def test(self):
     """Compiling invalid source must report the kernel file name.

     ``compile_using_nvrtc('a')`` is expected to fail with
     ``compiler.CompileException`` whose message mentions ``kern.cu``.
     """
     # The third argument is a regular expression: escape the dot so the
     # pattern only matches the literal file name "kern.cu" (an unescaped
     # '.' would also accept e.g. "kern_cu").
     with six.assertRaisesRegex(
             self, compiler.CompileException, r'kern\.cu'):
         compiler.compile_using_nvrtc('a')
Example #2
0
 def test2(self):
     """Compiling ``'a'`` must fail with an "unknown type name" error.

     The compiler diagnostic is expected to be included in the message of
     the raised ``compiler.CompileException``.
     """
     with self.assertRaises(compiler.CompileException) as cm:
         compiler.compile_using_nvrtc('a')
     # The message check must come after the ``with`` block: any statement
     # placed after the raising call inside the block is never executed.
     # ``cm`` is the assertRaises context manager; the caught exception is
     # available as ``cm.exception``.
     self.assertIn("unknown type name 'a'", str(cm.exception))
Example #3
0
 def _compile(self, arch):
     """Compile an empty source string with ``compile_using_nvrtc``.

     ``arch`` is forwarded unchanged as the ``arch`` keyword argument.
     Nothing is returned, and any exception raised by the compiler
     propagates to the caller.
     """
     empty_source = ''
     compiler.compile_using_nvrtc(empty_source, arch=arch)
Example #4
0
                float dsq = dx * dx + dy * dy + dz * dz;

                g += (dsq * wprod);
            }
        }

        result[0] = g;
    }

    }
    """

    # get the PTX
    from cupy.cuda.compiler import compile_using_nvrtc
    with open('cupy_ptx.txt', 'w') as fp:
        fp.write(compile_using_nvrtc(source_code))

    # compile and load CUDA kernel using CuPy
    brute_force_pairs_kernel = cp.RawKernel(source_code,
                                            'brute_force_pairs_kernel')

    d_x1 = cp.asarray(x1, dtype=cp.float32)
    d_y1 = cp.asarray(y1, dtype=cp.float32)
    d_z1 = cp.asarray(z1, dtype=cp.float32)
    d_w1 = cp.asarray(w1, dtype=cp.float32)

    d_x2 = cp.asarray(x2, dtype=cp.float32)
    d_y2 = cp.asarray(y2, dtype=cp.float32)
    d_z2 = cp.asarray(z2, dtype=cp.float32)
    d_w2 = cp.asarray(w2, dtype=cp.float32)