Example #1
class SingleDeviceKernel:
  def __init__(self, kernel_name, kernel_dir, device_id, func_name=None):
    self.func_name = func_name or kernel_name
    self.name = kernel_name + ".cu"
    with open(path.join(kernel_dir, self.name), 'r') as cu_f:
      self.kernel_source = cu_f.read().encode()
    self.prog = Program(self.kernel_source, self.name.encode())
    ptx = self.prog.compile([self.get_compute_arch_arg(device_id)])
    self.module = Module()
    self.module.load(ptx.encode())

  def get_compute_arch_arg(self, device_id):
    # Assumed helper (not part of the original snippet): builds the NVRTC
    # arch flag, e.g. '-arch=compute_61', for the target device.
    return '-arch=compute_%d%d' % torch.cuda.get_device_capability(device_id)

  def prep_args(self, kwargs):
    # Unwrap tensors to raw device pointers; pass anything without a
    # data_ptr() (e.g. plain ints) through unchanged.
    args = []
    for k, v in kwargs.items():
      try:
        args.append(v.data_ptr())
      except AttributeError:
        args.append(v)
    return args

  def linear_launch(self, num_threads, **kwargs):
    # One thread per element; keyword order defines the kernel argument
    # order (guaranteed insertion order on Python 3.7+).
    kernel_func = self.module.get_function(self.func_name)
    kernel_func.linear_launch(
      num_threads,
      args=self.prep_args(kwargs),
      stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)
    )
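A minimal usage sketch for the class above (not part of the original example): the Program/Module/linear_launch combination matches the pynvrtc + cupy JIT pattern, and the Stream(ptr=...) call suggests the usual namedtuple shim around the raw CUDA stream pointer. The kernel file and its contents here are hypothetical.

from collections import namedtuple
from os import path

import torch
from pynvrtc.compiler import Program      # NVRTC compiler front end
from cupy.cuda.function import Module     # loads the compiled PTX

Stream = namedtuple('Stream', ['ptr'])    # shim matching Stream(ptr=...) above

# Hypothetical kernels/scale.cu:
#   extern "C" __global__ void scale(float *x, int n)
#   { int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) x[i] *= 2.0f; }
kern = SingleDeviceKernel('scale', 'kernels', torch.cuda.current_device())

x = torch.cuda.FloatTensor(1024).fill_(1.0)
kern.linear_launch(x.numel(), x=x, n=x.numel())  # kwarg order = kernel arg order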
Example #2
    def __call__(self, input):
        if not self.jit or not isinstance(input, torch.cuda.FloatTensor):
            norm = input.norm(2, input.dim() - 1)
            return torch.cat([norm, norm.new(norm.size()).zero_()], input.dim() - 1)

        out = input.new(input.size())
        input = input.contiguous()

        if not iscomplex(input):
            raise TypeError('The input and output should be complex')

        if self.modulus_cache[input.get_device()] is None:
            kernel = """
            extern "C"
            __global__ void abs_complex_value(const float * x, float2 * z, int n)
            {
                int i = blockIdx.x * blockDim.x + threadIdx.x;
                if (i >= n)
                    return;
                z[i] = make_float2(normf(2, x + 2*i), 0);
            }
            """
            print('modulus.cu')
            prog = Program(kernel, 'modulus.cu')
            ptx = prog.compile([('-arch='+get_compute_arch(input))])
            module = Module()
            module.load(ptx.encode())
            self.modulus_cache[input.get_device()] = module
        fabs = self.modulus_cache[input.get_device()].get_function('abs_complex_value')
        fabs(grid=(self.GET_BLOCKS(int(out.nelement())//2), 1, 1),
             block=(self.CUDA_NUM_THREADS, 1, 1),
             args=[input.data_ptr(), out.data_ptr(), out.numel() // 2],
             stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        return out
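The call above relies on helpers the snippet doesn't show: iscomplex, GET_BLOCKS, CUDA_NUM_THREADS, and Stream. A plausible minimal sketch of those pieces, inferred from how they are used (the threads-per-block value is an assumption):

from collections import namedtuple

CUDA_NUM_THREADS = 1024                    # assumed threads-per-block value
Stream = namedtuple('Stream', ['ptr'])     # wraps the raw CUDA stream pointer

def GET_BLOCKS(N, threads=CUDA_NUM_THREADS):
    # Ceiling division: enough blocks to cover N threads.
    return (N + threads - 1) // threads

def iscomplex(input):
    # Complex tensors are stored as real tensors whose trailing
    # dimension holds (real, imag) pairs.
    return input.size(-1) == 2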
Example #3
    def __call__(self, input):
        if not self.jit or not isinstance(input, torch.cuda.FloatTensor):
            norm = input.norm(2, input.dim() - 1)
            return torch.cat([norm, norm.new(norm.size()).zero_()], input.dim() - 1)

        out = input.new(input.size())
        input = input.contiguous()

        if not iscomplex(input):
            raise TypeError('The input and output should be complex')

        if self.modulus_cache[input.get_device()] is None:
            kernel = b"""
            extern "C"
            __global__ void abs_complex_value(const float * x, float2 * z, int n)
            {
                int i = blockIdx.x * blockDim.x + threadIdx.x;
                if (i >= n)
                    return;
                z[i] = make_float2(normf(2, x + 2*i), 0);
            }
            """
            print('modulus.cu')
            prog = Program(kernel, b'modulus.cu')
            ptx = prog.compile(['-arch='+get_compute_arch(input)])
            module = Module()
            module.load(ptx.encode())
            self.modulus_cache[input.get_device()] = module
        fabs = self.modulus_cache[input.get_device()].get_function('abs_complex_value')
        fabs(grid=(self.GET_BLOCKS(int(out.nelement())//2), 1, 1),
             block=(self.CUDA_NUM_THREADS, 1, 1),
             args=[input.data_ptr(), out.data_ptr(), out.numel() // 2],
             stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        return out
Example #4
    def __call__(self, input, k):
        out = input.new(input.size(0), input.size(1),
                        input.size(2) // k,
                        input.size(3) // k, 2)

        if not self.jit or isinstance(input,
                                      (torch.FloatTensor, torch.DoubleTensor)):
            y = input.view(input.size(0), input.size(1),
                           input.size(2) // out.size(2), out.size(2),
                           input.size(3) // out.size(3), out.size(3), 2)

            out = y.mean(4).squeeze(4).mean(2).squeeze(2)
            return out

        if not iscomplex(input):
            raise TypeError('The input and output should be complex')

        input = input.contiguous()

        if self.periodize_cache[(input.size(), out.size(),
                                 input.get_device())] is None:
            kernel = '''
            #define NW (${W} / ${k})
            #define NH (${H} / ${k})
            extern "C"
            __global__ void periodize(const ${Dtype}2 *input, ${Dtype}2 *output)
            {
              int tx = blockIdx.x * blockDim.x + threadIdx.x;
              int ty = blockIdx.y * blockDim.y + threadIdx.y;
              int tz = blockIdx.z * blockDim.z + threadIdx.z;
              if(tx >= NW || ty >= NH || tz >= ${B})
                return;
              input += tz * ${H} * ${W} + ty * ${W} + tx;
              ${Dtype}2 res = make_${Dtype}2(0.f, 0.f);
              for (int j=0; j<${k}; ++j)
                for (int i=0; i<${k}; ++i)
                {
                  const ${Dtype}2 &c = input[j * NH * ${W} + i * NW];
                  res.x += c.x;
                  res.y += c.y;
                }
              res.x /= ${k} * ${k};
              res.y /= ${k} * ${k};
              output[tz * NH * NW + ty * NW + tx] = res;
            }
            '''
            B = input.nelement() // (2 * input.size(-2) * input.size(-3))
            W = input.size(-2)
            H = input.size(-3)
            k = input.size(-2) // out.size(-2)
            kernel = Template(kernel).substitute(B=B,
                                                 H=H,
                                                 W=W,
                                                 k=k,
                                                 Dtype=getDtype(input))
            name = str(input.get_device()) + '-' + str(B) + '-' + str(
                k) + '-' + str(H) + '-' + str(W) + '-periodize.cu'
            print(name)
            prog = Program(kernel, name.encode())
            ptx = prog.compile(['-arch=' + get_compute_arch(input)])
            module = Module()
            module.load(ptx.encode())
            self.periodize_cache[(input.size(), out.size(),
                                  input.get_device())] = module
        grid = (self.GET_BLOCKS(out.size(-3), self.block[0]),
                self.GET_BLOCKS(out.size(-2), self.block[1]),
                self.GET_BLOCKS(
                    out.nelement() // (2 * out.size(-2) * out.size(-3)),
                    self.block[2]))
        periodize = self.periodize_cache[(
            input.size(), out.size(),
            input.get_device())].get_function('periodize')
        periodize(grid=grid,
                  block=self.block,
                  args=[input.data_ptr(), out.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        return out
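Both branches of this periodize compute the same k-by-k fold-and-average: the JIT kernel sums samples spaced NH rows / NW columns apart, exactly what the view/mean fallback does. That makes a quick consistency check possible. A sketch, assuming a hypothetical Periodize class wrapping the __call__ above with jit enabled, arbitrary sizes, and the legacy PyTorch this code targets:

import torch

per = Periodize(jit=True)              # hypothetical constructor
x = torch.randn(2, 3, 32, 32, 2)       # (..., H, W, 2) complex layout

ref = per(x, 4)                # CPU FloatTensor -> view/mean fallback
gpu = per(x.cuda(), 4).cpu()   # CUDA tensor -> JIT periodize kernel
print((ref - gpu).abs().max()) # should be ~0 up to float rounding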
Example #5
  def compile_and_prep_kernel(self, device_id):
    # Compile the stored NVRTC Program to PTX for this device and load
    # it into a fresh CUDA module.
    ptx = self.prog.compile([self.get_compute_arch_arg(device_id)])
    module = Module()
    module.load(ptx.encode())
    return module
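A plausible call site for this helper (an assumption; the snippet doesn't show one) is populating a per-device module cache, mirroring the get_device()-keyed caches in the other examples:

import torch

def build_module_cache(kernel):
    # One compiled module per visible CUDA device.
    return {dev: kernel.compile_and_prep_kernel(dev)
            for dev in range(torch.cuda.device_count())}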
    def __call__(self, input, k):
        out = input.new(input.size(0), input.size(1), input.size(2) // k, input.size(3) // k, 2)

        if not self.jit or isinstance(input, (torch.FloatTensor, torch.DoubleTensor)):
            y = input.view(input.size(0), input.size(1),
                           input.size(2)//out.size(2), out.size(2),
                           input.size(3)//out.size(3), out.size(3),
                           2)
            out = y.mean(4).squeeze(4).mean(2).squeeze(2)
            return out

        if not iscomplex(input):
            raise TypeError('The input and output should be complex')

        input = input.contiguous()

        if self.periodize_cache[(input.size(), out.size(), input.get_device())] is None:
            kernel = '''
            #define NW (${W} / ${k})
            #define NH (${H} / ${k})
            extern "C"
            __global__ void periodize(const ${Dtype}2 *input, ${Dtype}2 *output)
            {
              int tx = blockIdx.x * blockDim.x + threadIdx.x;
              int ty = blockIdx.y * blockDim.y + threadIdx.y;
              int tz = blockIdx.z * blockDim.z + threadIdx.z;
              if(tx >= NW || ty >= NH || tz >= ${B})
                return;
              input += tz * ${H} * ${W} + ty * ${W} + tx;
              ${Dtype}2 res = make_${Dtype}2(0.f, 0.f);
              for (int j=0; j<${k}; ++j)
                for (int i=0; i<${k}; ++i)
                {
                  const ${Dtype}2 &c = input[j * NH * ${W} + i * NW];
                  res.x += c.x;
                  res.y += c.y;
                }
              res.x /= ${k} * ${k};
              res.y /= ${k} * ${k};
              output[tz * NH * NW + ty * NW + tx] = res;
            }
            '''
            B = input.nelement() // (2*input.size(-2) * input.size(-3))
            W = input.size(-2)
            H = input.size(-3)
            k = input.size(-2) // out.size(-2)
            kernel = Template(kernel).substitute(B=B, H=H, W=W, k=k, Dtype=getDtype(input))
            name = str(input.get_device())+'-'+str(B)+'-'+str(k)+'-'+str(H)+'-'+str(W)+'-periodize.cu'
            print(name)
            prog = Program(kernel, name.encode())
            ptx = prog.compile(['-arch='+get_compute_arch(input)])
            module = Module()
            module.load(ptx.encode())
            self.periodize_cache[(input.size(), out.size(), input.get_device())] = module
        grid = (self.GET_BLOCKS(out.size(-3), self.block[0]),
                self.GET_BLOCKS(out.size(-2), self.block[1]),
                self.GET_BLOCKS(out.nelement() // (2*out.size(-2) * out.size(-3)), self.block[2]))
        periodize = self.periodize_cache[(input.size(), out.size(), input.get_device())].get_function('periodize')
        periodize(grid=grid, block=self.block, args=[input.data_ptr(), out.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        return out