Example #1
def _cvt_to_def_str(obj):
  # bool
  if isinstance(obj, bool):
    return str(int(obj))
  # tensorflow type
  if fw.has_tensorflow():
    if isinstance(obj, fw.tensorflow.DType):
      return {fw.tensorflow.int8: 'char',
              fw.tensorflow.int16: 'short',
              fw.tensorflow.int32: 'int',
              fw.tensorflow.int64: 'long',
              fw.tensorflow.float16: 'half',
              fw.tensorflow.float32: 'float',
              fw.tensorflow.float64: 'double'}[obj]
  # torch type
  elif fw.has_torch():
    if isinstance(obj, fw.torch.dtype):
      return {fw.torch.int8: 'char',
              fw.torch.int16: 'short',
              fw.torch.int32: 'int',
              fw.torch.int64: 'long',
              fw.torch.float16: 'half',
              fw.torch.float32: 'float',
              fw.torch.float64: 'double'}[obj]
  else:
    assert False
  # default
  return str(obj)
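A quick sanity check of the conversion, assuming only PyTorch is installed (illustrative; with TensorFlow also present the first branch runs instead and torch dtypes fall through to str()):

if fw.has_torch():
  assert _cvt_to_def_str(True) == '1'                  # bools become '0'/'1'
  assert _cvt_to_def_str(fw.torch.float32) == 'float'  # dtypes map to C type names
  assert _cvt_to_def_str(16) == '16'                   # everything else: plain str()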
Example #2
 def apply(cls, *args, **kwargs):
     if fw.has_tensorflow():
         return cls.apply_tensorflow(*args, **kwargs)
     elif fw.has_torch():
         return cls.apply_torch(*args, **kwargs)
     else:
         assert False
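For context, a minimal sketch of the lazy-import shim that fw could implement (hypothetical; the project's actual framework module differs):

import importlib

_modules = {}

def _try_import(name):
    # Import once; remember failures as None so each module is probed only once.
    if name not in _modules:
        try:
            _modules[name] = importlib.import_module(name)
        except ImportError:
            _modules[name] = None
    return _modules[name]

def has_tensorflow():
    return _try_import('tensorflow') is not None

def has_torch():
    return _try_import('torch') is not None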
Example #3
File: utils.py Project: ghosthamlet/triton
def shape(A):
    if fw.has_tensorflow():
        return A.shape.as_list()
    elif fw.has_torch():
        return A.shape
    else:
        assert False
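Usage is uniform across backends; with the torch backend (illustrative):

if fw.has_torch():
    A = fw.torch.empty(3, 4)
    assert list(shape(A)) == [3, 4]  # torch.Size behaves like a list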
Example #4
File: utils.py Project: ghosthamlet/triton
 def _get_key(key):
     if fw.has_tensorflow():
         if isinstance(key, fw.tensorflow.Tensor):
             key = id(key.op)
     if fw.has_torch():
         if isinstance(key, fw.torch.Tensor):
             key = id(key)
     return key
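Reducing tensors to a stable id() lets them serve as dictionary keys; a sketch of the caching pattern this supports (hypothetical helper, not from the project):

_cache = {}

def cached(key, compute):
    key = _get_key(key)  # tensors collapse to id(); other keys pass through
    if key not in _cache:
        _cache[key] = compute()
    return _cache[key]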
Example #5
File: utils.py Project: ghosthamlet/triton
def empty(shape, dtype):
    if fw.has_tensorflow():
        shape = [fw.tensorflow.constant(x) for x in shape]
        shape = fw.tensorflow.stack(shape)
        return tf_empty_proxy(shape, dtype)
        #return fw.tf_extra_ops.alloc_empty(args, T = dtype)
    elif fw.has_torch():
        return fw.torch.empty(shape, dtype=dtype, device='cuda:0')
    else:
        assert False
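An illustrative call with the torch backend (requires a CUDA device, since allocation targets 'cuda:0'):

if fw.has_torch() and fw.torch.cuda.is_available():
    buf = empty([128, 64], dtype=fw.torch.float32)  # uninitialized CUDA tensor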
Example #6
def _make_framework_op(src, outputs, tmp, options):
  src, name = _make_framework_src(src, outputs, tmp, options)
  cache_path = _make_cache_path(src)
  cpp, so = _write_bindings(src, cache_path)
  _build(cpp, cache_path)
  if fw.has_tensorflow():
    return fw.tensorflow.load_op_library(so).__dict__[name]
  elif fw.has_torch():
    fw.torch.ops.load_library(so)
    return getattr(fw.torch.ops.triton, name)
  else:
    assert False
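The returned callable behaves like a native op of the active framework; a hedged sketch of its use under torch, with the argument layout inferred from Example #9 below:

# op = _make_framework_op(src, outputs, tmp, options)
# ret = op(op_id, bench, bench_id, *tensor_args)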
Example #7
def _write_bindings(src, root):
  if fw.has_tensorflow():
    name = 'tensorflow'
  elif fw.has_torch():
    name = 'torch'
  else:
    assert False
  cpp = os.path.join(root, '{name}.cpp'.format(name=name))
  suffix = sysconfig.get_config_var('EXT_SUFFIX')
  so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix))
  recompile = False
  # recompile if the .cpp or .so does not exist
  if not os.path.exists(cpp) or not os.path.exists(so):
    recompile = True
  # recompile if cpp was modified after .so
  elif max(cpp, so, key=os.path.getctime) == cpp:
    recompile = True
  # write cpp file
  if recompile:
    with open(cpp, 'w+') as handle:
      handle.writelines(src)
  # return paths of the .cpp and .so files
  return (cpp, so)
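Note that os.path.getctime reports inode-change time on Linux, so the check above can rebuild more often than strictly needed; a modification-time variant (an alternative sketch, not the project's code):

import os

def _is_stale(cpp, so):
    # Rebuild if either file is missing or the source is newer than the binary.
    if not (os.path.exists(cpp) and os.path.exists(so)):
        return True
    return os.path.getmtime(cpp) >= os.path.getmtime(so)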
Example #8
def _build(src, path):
  ccdir = os.path.join(libtriton.__file__, os.path.pardir)
  ccdir = os.path.realpath(ccdir)
  # include directories
  triton_include_dirs = [os.path.join(ccdir, 'include')]
  include_dirs = triton_include_dirs 
  # library directories
  triton_library_dirs = [ccdir]
  library_dirs = triton_library_dirs
  # libraries
  libraries = ['triton']
  # add framework
  extra_compile_args = []
  if fw.has_tensorflow():
    library_dirs += [fw.tensorflow.sysconfig.get_lib()]
    include_dirs += [fw.tensorflow.sysconfig.get_include()]
    include_dirs += ['/usr/local/cuda/include/']
    libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
    abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
    name = 'tensorflow'
  elif fw.has_torch():
    prefix = os.path.dirname(fw.torch.__file__)
    library_dirs += [os.path.join(prefix, 'lib')]
    include_dirs += ['/usr/local/cuda/include/',
                     os.path.join(prefix, 'lib', 'include'),
                     os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
                     os.path.join(prefix, 'include'),
                     os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
    libraries += ['torch']
    abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
    name = 'torch'
  else:
    assert False
  # extra arguments
  extra_link_args = []
  # dependencies
  depends = [os.path.realpath(libtriton.__file__)]
  # create extension module
  ext = setuptools.Extension(
      name = name,
      language = 'c++',
      sources = [src],
      include_dirs = include_dirs,
      extra_compile_args = extra_compile_args + ['-g0'],
      extra_link_args = extra_link_args,
      library_dirs = library_dirs,
      libraries = libraries,
      depends = depends
  )
  # build extension module
  args = ['build_ext']
  tmp = tempfile.mkdtemp()
  args.append('--build-temp=' + tmp)
  args.append('--build-lib=' + path)
  args.append('-q')
  args = dict(
      name = name,
      ext_modules = [ext],
      script_args = args,
  ) 
  with quiet():
    setuptools.setup(**args)
  shutil.rmtree(tmp)
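Taken together with Example #7, the intended flow is: write the generated bindings, then compile them in place (paths hypothetical):

# cpp, so = _write_bindings(src, cache_path)  # writes <name>.cpp when stale
# _build(cpp, cache_path)                     # emits <name><EXT_SUFFIX> beside it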
Example #9
  def __call__(self, *args, **kwargs):

    ########################
    # keyword arguments
    ########################
    num_warps = kwargs.get('num_warps', [2, 4, 8])
    defines = kwargs.get('defines', dict())
    bench = kwargs.get('bench', 0)
    if 'grid' not in kwargs:
      raise RuntimeError('Must provide grid for kernel launch')
    grid = kwargs['grid']


    ########################
    # cache
    ########################

    # create a new framework op when defines are different
    key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in defines.items()])
    if key not in self.fw_id:
      # code generation options
      macros = []
      for k, v in defines.items():
        if isinstance(v, list):
          values = list(map(_cvt_to_def_str, v))
        else:
          values = [_cvt_to_def_str(v)]
        macros.append((k, values))
      opt = libtriton.options_space()
      opt.defines = macros
      opt.num_warps = num_warps
      # create unique id for this op
      op_id = libtriton.make_op_id()
      self.fw_id[key] = op_id
      # register function
      libtriton.register_fn(op_id, self.src, opt)
      for name, value in self.cst.items():
        libtriton.register_cst(op_id, name, value)
      if self.fw_op is None:
        self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)

    ########################
    # initialize
    ########################
    op_id = self.fw_id[key]
    libtriton.register_grid(op_id, grid)
    bench_id = libtriton.make_scalar_id() if bench > 0 else -1

    #########################
    # call framework function
    #########################
    if fw.has_tensorflow():
      empty = [x for x in args if isinstance(x, triton.utils.tf_empty_proxy)]
      if len(empty) != len(self.outputs):
        raise ValueError('Number of empty arguments does not match the number of outputs provided')
      # operands
      operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args]
      # output data types
      kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
      for i, x in enumerate(args):
        if isinstance(x, triton.utils.tf_empty_proxy):
          kwargs['T' + str(i)] = x.dtype
      # launch
      ret = self.fw_op(*operands, **kwargs)
      ret = [ret] if isinstance(ret, fw.tensorflow.Tensor) else ret
      op_def = ret[0].op.op_def
      # fill empty tensors with corresponding values
      for j, y in enumerate(op_def.output_arg):
        found = False
        for i, x in enumerate(op_def.input_arg):
          if y.name + '_shape' == x.name:
            args[i].tensor = ret[j]
            found = True
        assert found
      # store timing information
      if bench > 0:
        for y in ret: 
          bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
    
    ############################
    # call torch function
    ############################
    elif fw.has_torch():
      # tensor and non-tensor arguments alike are forwarded as-is
      args = list(args)
      ret = self.fw_op(op_id, bench, bench_id, *args)
      if bench > 0:
        bench_registry[ret] = libtriton.retrieve_scalar(bench_id)

    else:
      assert False
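A hedged sketch of a launch through __call__ with the torch backend (kernel, tensors, and the grid callable are illustrative, not the project's API):

# out = kernel(x, w,
#              grid=grid_fn,                 # required keyword
#              defines={'TILE': [64, 128]},  # folded into the cache key
#              num_warps=[4],
#              bench=0)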