Ejemplo n.º 1
0
def get_target_source(best_config, dir_sid=None):
  """Generate device source code for a kernel configuration and save it.

  Note: Not thread safe due to multiple invokes of target codegen.

  Two code paths exist:

  * ``AntaresGlobal.mode == 'antares'``: ``best_config`` is parsed as JSON
    (a JSON ``null`` becomes ``{}``) and split into kernel slices by
    ``backend_config.to_kernel_slices``.
  * Otherwise a TVM schedule is built. A ``best_config`` starting with
    ``'['`` is treated as a one-element JSON list holding an Ansor record;
    any other string is a JSON candidate passed to
    ``AntaresGlobal.attrs.auto_config.set_candidate``. The schedule is
    lowered (saved to ``my_kernel.lower``), built with ``tvm.build``, and the
    source of its single imported module is split by ``translate_code``.

  Either way the slices are rendered under a metadata header and written to
  ``my_kernel.cc`` in the directory selected by ``dir_sid``.

  Args:
    best_config: Configuration as a JSON string.
    dir_sid: Optional directory/session id forwarded to ``local_get_dir_file``.

  Returns:
    Tuple ``(device_source, kernel_path)``.

  Raises:
    AssertionError: ``best_config`` is not a ``str`` on the TVM path, or the
      built module does not have exactly one imported module.
    RuntimeError: The Ansor record file yields no record, or no schedule
      was produced for the configuration.
  """
  global_arg_props = get_global_arg_props()

  def get_kernel_metadata(config):
    """Return the '// GLOBALS', '// BACKEND', '// CONFIG' and '// COMPUTE_V1' header lines."""
    inp_args, outp_args = [], []

    for buf in global_arg_props['_in']:
      if buf['name'].startswith('_'):
        # Just for Auto Shard: '_'-prefixed int32[1] inputs are left out of the header.
        assert(buf['dtype'] == 'int32' and buf['shape'] == [1])
        continue
      inp_args.append('%s:%s%s' % (buf['name'], buf['dtype'], buf['shape']))
    for buf in global_arg_props['_out']:
      outp_args.append('%s:%s%s' % (buf['name'], buf['dtype'], buf['shape']))

    # An unset or empty DEVICE_NAME is reported as 'default'.
    device_code = os.environ.get('DEVICE_NAME', '') or 'default'
    header_meta = '// GLOBALS: ' + ', '.join(inp_args) + ' -> ' + ', '.join(outp_args) + '\n// BACKEND: %s (%s)\n' % (backend, device_code)
    # NOTE(review): raises KeyError if COMPUTE_V1 is not set in the environment.
    properties = "// CONFIG: %s\n// COMPUTE_V1: %s\n" % (config.strip() if isinstance(config, str) else '', os.environ['COMPUTE_V1'])
    return header_meta + properties

  def slices_to_code(kernel_slices):
    """Translate each (kernel_id, kernel_name, args, body) slice to native code and join the results."""
    def tensor_display(encoded_name, prop):
      return f'{encoded_name}:{prop["dtype"]}{str(prop["shape"])}'

    # sorted() instead of list.sort(): do not reorder the caller's list as a side effect.
    ordered_slices = sorted(kernel_slices)
    num_global_outputs = len(global_arg_props['_out'])
    code = ['']
    for i, (_kernel_id, kernel_name, args, body) in enumerate(ordered_slices):
      # The last slice's trailing len(_out) args are its outputs; every other slice has one output.
      num_outputs = num_global_outputs if i + 1 == len(ordered_slices) else 1
      # Explicit split index: args[-0:] would be the whole list instead of an empty one.
      split = max(len(args) - num_outputs, 0)
      inputs, outputs = args[:split], args[split:]
      display_inputs = ', '.join([tensor_display(x, prop) for _, x, prop in inputs])
      display_outputs = ', '.join([tensor_display(x, prop) for _, x, prop in outputs])
      kernel = backend_config.do_native_translation_v2((kernel_name, inputs, outputs, body), attrs=AntaresGlobal.attrs).strip()
      code.append(f'// LOCAL: {kernel_name} -- {display_inputs} -> {display_outputs}\n\n{kernel}\n')

    return '\n// ---------------------------------------------------------------------------\n'.join(code)

  def pack_device_source(kernel_slices):
    """Render the slices, prepend the metadata header, write my_kernel.cc and return (source, path)."""
    device_source = slices_to_code(kernel_slices)
    device_source = '%s\n%s' % (get_kernel_metadata(best_config), device_source)
    kernel_path = local_get_dir_file('my_kernel.cc', dir_sid=dir_sid)
    with open(kernel_path, 'w') as fp:
      fp.write(device_source)
    return device_source, kernel_path

  if getattr(AntaresGlobal, 'mode', None) == 'antares':
    json_config = json.loads(best_config)
    kernel_slices = backend_config.to_kernel_slices(AntaresGlobal.compute_graph, json_config if json_config is not None else {})
    return pack_device_source(kernel_slices)

  with open(local_get_dir_file('my_kernel.time', dir_sid=dir_sid), 'w') as fp:
    fp.write('%s' % time.time())
  default_tune_op = AntaresGlobal.default_tune_op
  assert isinstance(best_config, str), "Config value must be string type, got: %s" % best_config.__class__
  if best_config.startswith('['):
    # Ansor config: a one-element JSON list wrapping a serialized record.
    # The import makes sure the tvm.auto_scheduler submodule is loaded before the attribute access below.
    from tvm import auto_scheduler
    [origin_cfg] = json.loads(best_config)
    origin_cfg_file = local_get_dir_file('my_kernel.cfg', dir_sid=dir_sid)
    with open(origin_cfg_file, 'w') as fp:
      fp.write(json.dumps(origin_cfg))
    origin_cfg = tvm.auto_scheduler.measure_record.load_records(origin_cfg_file)

    from tuner.Ansor.main import create_auto_task
    target = tvm.target.Target(tvm_target)
    auto_task = create_auto_task(target)

    # Only the first loaded record is used.
    first_record = next(iter(origin_cfg), None)
    if first_record is None:
      raise RuntimeError('No Ansor record could be loaded from %s' % origin_cfg_file)
    inp, _ = first_record
    s, arg_bufs = auto_task.compute_dag.apply_steps_from_state(inp.state)
    with open(local_get_dir_file('my_kernel.sched', dir_sid=dir_sid), 'w') as fp:
      fp.write(auto_task.compute_dag.print_python_code_from_state(inp.state))
  else:
    AntaresGlobal.attrs.auto_config.set_candidate(json.loads(best_config))
    with tvm.target.Target(tvm_target):
      s, arg_bufs = default_tune_op.get_template_op()

  # Without a schedule there is nothing to lower or build; fail explicitly
  # instead of hitting an unbound `func` further down.
  if s is None:
    raise RuntimeError('No schedule was produced for config: %s' % best_config)

  lower_source = str(tvm.lower(s, arg_bufs, simple_mode=True))
  lower_file = local_get_dir_file('my_kernel.lower', dir_sid=dir_sid)
  with open(lower_file, 'w') as fp:
    fp.write(lower_source)

  # Compile Source Code
  func = tvm.build(s, arg_bufs, tvm_target, name='template_op')

  assert(len(func.imported_modules) == 1)
  kernel_slices = translate_code(func.imported_modules[0].get_source(), best_config)
  return pack_device_source(kernel_slices)
Ejemplo n.º 2
0
def get_target_source(best_config, dir_sid=None):
    """Build a TVM kernel for ``best_config``, validate it, and save its device source.

    Note: Not thread-safe due to multiple ordered updates for config spaces.

    A ``best_config`` starting with ``'['`` is treated as an Ansor state; it is
    wrapped into a record, written to ``my_kernel.cfg``, and reloaded through
    ``tvm.auto_scheduler.measure_record.load_records``. Any other string is a
    JSON candidate passed to ``AntaresGlobal.attrs.auto_config.set_candidate``.
    The resulting schedule is lowered (saved to ``my_kernel.lower``) and checked
    by ``_check_lowered_source``. It is then built with ``tvm.build``, and the
    translated source of its single imported module is written to
    ``my_kernel.cc``.

    Args:
        best_config: Configuration as a JSON string.
        dir_sid: Optional directory/session id forwarded to ``local_get_dir_file``.

    Returns:
        Tuple ``(device_source, kernel_path)``.

    Raises:
        AssertionError: ``best_config`` is not a ``str``, a device or kernel
            limit assertion fails, or the built module does not have exactly
            one imported module.
        RuntimeError: The Ansor record file yields no record, or no schedule
            was produced.
        Exception: The lowered code has more than one ``primfn`` or exceeds
            the shared-memory limit.
    """
    with open(local_get_dir_file('my_kernel.time', dir_sid=dir_sid),
              'w') as fp:
        fp.write('%s' % time.time())
    default_tune_op = AntaresGlobal.default_tune_op
    assert isinstance(
        best_config, str
    ), "Config value must be string type, got: %s" % best_config.__class__
    if best_config.startswith('['):
        # Ansor config.
        # The import makes sure the tvm.auto_scheduler submodule is loaded
        # before the attribute access below.
        from tvm import auto_scheduler
        ansor_state = json.loads(best_config)
        props = device_properties()
        # Wrap the state in the record layout that load_records parses
        # (presumably 'i' = input, 'r' = placeholder result, 'v' = version).
        # NOTE(review): the target string is hard-coded to CUDA.
        origin_cfg = {
            "i": [[
                '["main_compute.<locals>.auto_template"]',
                'cuda -keys=cuda,gpu -max_num_threads=%d -thread_warp_size=%d'
                % (props.max_threads_per_block, props.warp_size)
            ], ansor_state],
            "r": [[0], 0, 0, 0],
            "v": "v0.2",
        }
        origin_cfg_file = local_get_dir_file('my_kernel.cfg', dir_sid=dir_sid)
        with open(origin_cfg_file, 'w') as fp:
            fp.write(json.dumps(origin_cfg))
        records = tvm.auto_scheduler.measure_record.load_records(
            origin_cfg_file)

        from tuner.Ansor.main import create_auto_task
        target = tvm.target.Target(tvm_target)
        auto_task = create_auto_task(target)

        # Only the first loaded record is used.
        first_record = next(iter(records), None)
        if first_record is None:
            raise RuntimeError('No Ansor record could be loaded from %s' %
                               origin_cfg_file)
        inp, _ = first_record
        s, arg_bufs = auto_task.compute_dag.apply_steps_from_state(inp.state)
        with open(local_get_dir_file('my_kernel.sched', dir_sid=dir_sid),
                  'w') as fp:
            fp.write(
                auto_task.compute_dag.print_python_code_from_state(inp.state))
    else:
        AntaresGlobal.attrs.auto_config.set_candidate(json.loads(best_config))
        with tvm.target.Target(tvm_target):
            s, arg_bufs = default_tune_op.get_template_op()

    # Without a schedule there is nothing to lower or build; fail explicitly
    # instead of hitting an unbound `func` further down.
    if s is None:
        raise RuntimeError('No schedule was produced for config: %s' %
                           best_config)

    lower_source = str(tvm.lower(s, arg_bufs, simple_mode=True))
    lower_file = local_get_dir_file('my_kernel.lower', dir_sid=dir_sid)
    with open(lower_file, 'w') as fp:
        fp.write(lower_source)

    _check_lowered_source(lower_source)

    # Compile Source Code
    func = tvm.build(s, arg_bufs, tvm_target, name='template_op')

    assert (len(func.imported_modules) == 1)
    device_source = translate_code(func.imported_modules[0].get_source(),
                                   best_config)
    kernel_path = local_get_dir_file('my_kernel.cc', dir_sid=dir_sid)
    with open(kernel_path, 'w') as fp:
        fp.write(device_source)
    return device_source, kernel_path


def _check_lowered_source(lower_source):
    """Reject lowered code with several primfns or exceeding device thread/shared-memory limits."""
    # Verify Lower Code
    if len(('\n' + lower_source).split('\nprimfn(')) != 2:
        raise Exception('[Not Support Multi Unfuse-able kernels]\n\n' +
                        lower_source)

    props = device_properties()
    max_threads_per_block = props.max_threads_per_block
    max_shared_memory_per_block = props.max_shared_memory_per_block
    assert max_threads_per_block > 0 and max_shared_memory_per_block >= 0, '[Error] Invalid device properties, maybe device is not detected correctly.'

    # Scan the lowered text for thread-extent attributes and '.shared'
    # allocations; both are parsed from their textual form.
    thread_extents, allocate_shared = [], []
    for ll in lower_source.split('\n'):
        stripped = ll.strip()
        if stripped.startswith('attr [IterVar(') and ' "thread_extent" = ' in ll:
            thread_name = ll.split('attr [IterVar(')[-1].split(':')[0]
            thread_val = int(
                ll.split(' "thread_extent" = ')[-1].split(';')
                [0].strip().split(' ')[0])
            thread_extents.append((thread_name, thread_val))
        elif stripped.startswith('allocate(') and '.shared, ' in ll:
            last_arg_id = ll.rindex(', [')
            dims = [
                int(x)
                for x in ll[last_arg_id + 3:ll.rindex(']')].split(', ')
            ]
            # np.prod replaces np.product, which NumPy 2.0 removed.
            allocate_val = int(np.prod(dims))
            allocate_type = ll[ll.index(', ') + 2:last_arg_id]
            allocate_shared.append((allocate_type, allocate_val))

    # Every hint for the same thread axis must agree on its extent.
    reserved_axes = dict()
    for thread_name, thread_val in thread_extents:
        previous = reserved_axes.setdefault(thread_name, thread_val)
        assert previous == thread_val, "Invalid code: Multiple hints for thread extent conflict with each other: %d v.s. %d" % (
            previous, thread_val)

    num_threads = 1
    for thread_name in ('threadIdx.x', 'threadIdx.y', 'threadIdx.z'):
        num_threads *= reserved_axes.get(thread_name, 1)
    assert num_threads <= max_threads_per_block, "Invalid kernel code: using num_threads(%d) > max_threads_per_block(%d)" % (
        num_threads, max_threads_per_block)

    shared_memory_in_bytes = 0
    for allocate_type, allocate_size in allocate_shared:
        # 'custom[<name>]...' types are sized by their inner type name.
        if allocate_type.startswith('custom['):
            type_name = allocate_type[7:].split(']')[0]
        else:
            type_name = allocate_type
        shared_memory_in_bytes += get_type_size(type_name) * allocate_size

    if shared_memory_in_bytes > max_shared_memory_per_block:
        raise Exception(
            "Invalid kernel code: using shared_memory_in_bytes %d > max_shared_memory_per_block %d"
            % (shared_memory_in_bytes, max_shared_memory_per_block))
Ejemplo n.º 3
0
def get_target_source(best_config, dir_sid=None):
    """Build a TVM kernel for ``best_config`` and save its device source.

    Note: Not thread-safe due to multiple ordered updates for config spaces.

    A ``best_config`` starting with ``'['`` is treated as an Ansor state; it is
    wrapped into a record, written to ``my_kernel.cfg``, and reloaded through
    ``tvm.auto_scheduler.measure_record.load_records``. Any other string is a
    JSON candidate passed to ``AntaresGlobal.attrs.auto_config.set_candidate``.
    The resulting schedule is lowered (saved to ``my_kernel.lower``) and built
    with ``tvm.build``. The translated source of its single imported module is
    then written to ``my_kernel.cc``.

    Args:
        best_config: Configuration as a JSON string.
        dir_sid: Optional directory/session id forwarded to ``local_get_dir_file``.

    Returns:
        Tuple ``(device_source, kernel_path)``.

    Raises:
        AssertionError: ``best_config`` is not a ``str``, or the built module
            does not have exactly one imported module.
        RuntimeError: The Ansor record file yields no record, or no schedule
            was produced.
    """
    with open(local_get_dir_file('my_kernel.time', dir_sid=dir_sid),
              'w') as fp:
        fp.write('%s' % time.time())
    default_tune_op = AntaresGlobal.default_tune_op
    assert isinstance(
        best_config, str
    ), "Config value must be string type, got: %s" % best_config.__class__
    if best_config.startswith('['):
        # Ansor config.
        # The import makes sure the tvm.auto_scheduler submodule is loaded
        # before the attribute access below.
        from tvm import auto_scheduler
        ansor_state = json.loads(best_config)
        props = device_properties()
        # Wrap the state in the record layout that load_records parses
        # (presumably 'i' = input, 'r' = placeholder result, 'v' = version).
        # NOTE(review): the target string is hard-coded to CUDA.
        origin_cfg = {
            "i": [[
                '["main_compute.<locals>.auto_template"]',
                'cuda -keys=cuda,gpu -max_num_threads=%d -thread_warp_size=%d'
                % (props.max_threads_per_block, props.warp_size)
            ], ansor_state],
            "r": [[0], 0, 0, 0],
            "v": "v0.2",
        }
        origin_cfg_file = local_get_dir_file('my_kernel.cfg', dir_sid=dir_sid)
        with open(origin_cfg_file, 'w') as fp:
            fp.write(json.dumps(origin_cfg))
        records = tvm.auto_scheduler.measure_record.load_records(
            origin_cfg_file)

        from tuner.Ansor.main import create_auto_task
        target = tvm.target.Target(tvm_target)
        auto_task = create_auto_task(target)

        # Only the first loaded record is used.
        first_record = next(iter(records), None)
        if first_record is None:
            raise RuntimeError('No Ansor record could be loaded from %s' %
                               origin_cfg_file)
        inp, _ = first_record
        s, arg_bufs = auto_task.compute_dag.apply_steps_from_state(inp.state)
        with open(local_get_dir_file('my_kernel.sched', dir_sid=dir_sid),
                  'w') as fp:
            fp.write(
                auto_task.compute_dag.print_python_code_from_state(inp.state))
    else:
        AntaresGlobal.attrs.auto_config.set_candidate(json.loads(best_config))
        with tvm.target.Target(tvm_target):
            s, arg_bufs = default_tune_op.get_template_op()

    # Without a schedule there is nothing to lower or build; fail explicitly
    # instead of hitting an unbound `func` further down.
    if s is None:
        raise RuntimeError('No schedule was produced for config: %s' %
                           best_config)

    lower_source = str(tvm.lower(s, arg_bufs, simple_mode=True))
    lower_file = local_get_dir_file('my_kernel.lower', dir_sid=dir_sid)
    with open(lower_file, 'w') as fp:
        fp.write(lower_source)

    # Compile Source Code
    func = tvm.build(s, arg_bufs, tvm_target, name='template_op')

    assert (len(func.imported_modules) == 1)
    device_source = translate_code(func.imported_modules[0].get_source(),
                                   best_config)
    kernel_path = local_get_dir_file('my_kernel.cc', dir_sid=dir_sid)
    with open(kernel_path, 'w') as fp:
        fp.write(device_source)
    return device_source, kernel_path