Example #1
0
def zero_pad2d(inputs, padding=0):
    """Zero padding for 2d tensor

    Args:
    -----------------------------
    inputs : tvm.te.tensor.Tensor
        shape [batch, channel, height, width]
    padding: (optional:0) int or tuple
        expected: (h_pad_up, h_pad_down, w_pad_up, w_pad_down)
    -----------------------------

    Returns:
    -----------------------------
    tvm.te.tensor.Tensor
        shape [batch, channel, padded_height, padded_width]
    -----------------------------
    """
    padding = (padding, padding, padding, padding) if isinstance(padding, (int, tvm.tir.IntImm)) else padding
    assert_print(isinstance(padding, tuple), "type(padding)={}".format(type(padding)))
    if len(padding) == 2:
        padding = (padding[0], padding[0], padding[1], padding[1])
    assert_print(len(padding) == 4)

    padding_zero = tvm.tir.expr.const(0, inputs.dtype)

    batch_size, in_channel, height, width = inputs.shape
    return tvm.te.compute(
        (batch_size, in_channel, height + padding[0] + padding[1], width + padding[2] + padding[3]),
        lambda b, c, h, w: tvm.te.if_then_else(
                            tvm.te.all(h >= padding[0], h < height + padding[0], w >= padding[2], w < width + padding[2]),
                            inputs[b, c, h - padding[0], w - padding[2]],
                            padding_zero
                            ),
        name='Padding', requires_grad=True
        )
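
A minimal usage sketch for zero_pad2d (hedged: it assumes the modified TVM build used throughout these examples, where tvm.te.compute accepts requires_grad; the shape and names are illustrative only):

import tvm
from tvm import te

# hypothetical NCHW input, padded by one pixel on every side
data = te.placeholder((8, 3, 32, 32), dtype="float32", name="data")
padded = zero_pad2d(data, padding=1)
print(padded.shape)  # expected [8, 3, 34, 34]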
Example #2
0
def cross_entropy(inputs, targets):
    '''
    https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss
    loss(x, class) = -x[class] + log(sum_j exp(x[j]))
    -x[class]: since targets is one-hot, an inner product is used to compute x[class]
    log(sum_j exp(x[j])) may overflow when computing exp(x[j])
    Trick: maxval + log(sum_j exp(x[j] - maxval))
    Finally, compute the average over the batch
    '''
    assert_print(inputs.shape[0].value == targets.shape[0].value)
    assert_print(inputs.shape[1].value == targets.shape[1].value)
    N, C = inputs.shape
    c = tvm.te.reduce_axis([0, C], "c")
    k1 = tvm.te.reduce_axis([0, C], name="k1")
    # First compute the maximum for each batch
    max_val = tvm.te.compute([N],
                             lambda n: tvm.te.max(inputs[n, k1], axis=[k1]),
                             name="max_val")
    # Use the log_softmax trick to avoid overflow
    sum_val = tvm.te.compute(
        [N],
        lambda i: tvm.te.sum(tvm.tir.exp(inputs[i, c] - max_val[i]), axis=[c]),
        "sum_val")
    rrn = tvm.te.reduce_axis([0, N], "rrn")
    rrc = tvm.te.reduce_axis([0, C], "rrc")
    x_class = tvm.te.compute(
        [N],
        lambda i: tvm.te.sum(inputs[i, rrc] * targets[i, rrc], axis=[rrc]),
        name="x_class")
    return tvm.te.compute([1],
                          lambda i: tvm.te.sum(
                              (tvm.tir.log(sum_val[i + rrn]) + max_val[i + rrn]
                               - x_class[i + rrn]) / N,
                              axis=[rrn]),
                          name="cross_entropy")
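
The max-subtraction trick described in the docstring can be checked numerically; a NumPy-only sketch (independent of TVM, values illustrative):

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])
naive = np.log(np.sum(np.exp(x)))             # exp overflows -> inf
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))    # ~1002.41, finite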
Example #3
0
    def add_subspace(self, name, subspace, type_key, override=False):
        if name in self.subspaces and not override:
            raise RuntimeError("Same subspace name")
        assert_print(type_key in self.valid_type_keys)
        self.subspaces[name] = subspace
        self.types[type_key].append(name)
        self.dim += subspace.dim
Example #4
0
    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 bias=False,
                 padding=0,
                 stride=1,
                 dilation=1,
                 groups=1):
        super(Conv2dLayer, self).__init__()
        if isinstance(kernel_size, (int, tvm.tir.IntImm)):
            kernel_size = (kernel_size, kernel_size)
        assert_print(isinstance(kernel_size, tuple) and len(kernel_size) == 2)
        self.weight = tvm.te.placeholder((out_channel, in_channel, *kernel_size),
                                         dtype="float32")
        self.params["weight"] = self.weight
        if bias:
            self.bias = tvm.te.placeholder((out_channel, ), dtype="float32")
            self.params["bias"] = self.bias
        else:
            self.bias = None

        def forward_func(inputs):
            return conv2d_nchw(inputs, self.weight, self.bias, stride, padding,
                               dilation, groups)

        self.forward_func = forward_func
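
A construction sketch for this layer (hedged: it assumes the Layer base class, not shown here, initializes self.params, and that conv2d_nchw from example #9 is in scope; sizes are illustrative):

# hypothetical: 3 -> 16 channels, 3x3 kernel, padding 1
layer = Conv2dLayer(in_channel=3, out_channel=16, kernel_size=3, bias=True, padding=1)
data = tvm.te.placeholder((8, 3, 32, 32), dtype="float32", name="data")
out = layer.forward_func(data)   # expected shape [8, 16, 32, 32]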
Example #5
0
    def forward(self, batch_size, policy="random"):
        assert_print(policy in ["random", "best"])
        ret = dict()
        for name, walker in self.walkers.items():
            if policy == "random":
                ret_entities, ret_p_values = walker.random_batch(batch_size)
            elif policy == "best":
                ret_entities, ret_p_values = walker.best_batch(batch_size)
            ret[name] = (ret_entities, ret_p_values)
        return ret
Example #6
0
    def __init__(self, input_len, width, depth, output_len):
        super(Judger, self).__init__()
        assert_print(isinstance(width, int) and width > 0)
        assert_print(isinstance(depth, int) and depth > 1)
        self.net = nn.Sequential()
        self.net.add_module("input", nn.Linear(input_len, width))
        self.net.add_module("input_activate", nn.ReLU())
        for count in range(depth - 2):
            name = "hidden_{}".format(count)
            self.net.add_module(name, nn.Linear(width, width))
            self.net.add_module(name + "_activate", nn.ReLU())
        self.net.add_module("output", nn.Linear(width, output_len))
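
A quick sanity check of the resulting MLP (hedged: the class's forward method is not shown, so this sketch applies self.net directly; sizes are illustrative):

import torch

judger = Judger(input_len=16, width=64, depth=4, output_len=1)
x = torch.randn(32, 16)          # a batch of 32 feature vectors
y = judger.net(x)                # expected shape (32, 1)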
Example #7
0
def gemm(A, B, transposeA=False, transposeB=False):
    """Matrix multiplies matrix

    Args:
    -----------------------------
    A: tvm.te.tensor.Tensor
        shape [height, width]
    B: tvm.te.tensor.Tensor
        shape [width, length]
    transposeA: (optional:False) bool
    transposeB: (optional:False) bool
    -----------------------------

    Returns:
    -----------------------------
    tvm.te.tensor.Tensor
        shape [height, length]
    -----------------------------
    """
    if transposeA and transposeB:
        k = tvm.te.reduce_axis((0, B.shape[1]))
        assert_print(A.shape[0].value == B.shape[1].value)
        return tvm.te.compute((A.shape[1], B.shape[0]), lambda i, j: tvm.te.sum(A[k, i] * B[j, k], axis=k), requires_grad=True)
    elif transposeA and not transposeB:
        k = tvm.te.reduce_axis((0, B.shape[0]))
        assert_print(A.shape[0].value == B.shape[0].value)
        return tvm.te.compute((A.shape[1], B.shape[1]), lambda i, j: tvm.te.sum(A[k, i] * B[k, j], axis=k), requires_grad=True)
    elif not transposeA and transposeB:
        k = tvm.te.reduce_axis((0, B.shape[1]))
        assert_print(A.shape[1].value == B.shape[1].value)
        return tvm.te.compute((A.shape[0], B.shape[0]), lambda i, j: tvm.te.sum(A[i, k] * B[j, k], axis=k), requires_grad=True)
    else:
        k = tvm.te.reduce_axis((0, B.shape[0]))
        assert_print(A.shape[1].value == B.shape[0].value)
        return tvm.te.compute((A.shape[0], B.shape[1]), lambda i, j: tvm.te.sum(A[i, k] * B[k, j], axis=k), requires_grad=True)
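
A usage sketch for the transposed variants (same modified-TVM assumption as above, since requires_grad is forwarded to tvm.te.compute; sizes are illustrative):

M, K, N = 64, 128, 32
A = tvm.te.placeholder((K, M), dtype="float32", name="A")   # A stored transposed
B = tvm.te.placeholder((N, K), dtype="float32", name="B")   # B stored transposed
C = gemm(A, B, transposeA=True, transposeB=True)            # expected shape [M, N]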
Example #8
0
def lltm(input, targets, weight_for_classify, bias_for_classify,
         weight_for_gate, bias_for_gate, old_h, old_c):
    '''
    input: [batch_size, 28*28]
    new/old_h & new/old_c: [batch_size, state_size]
    weight_for_classify: [10, state_size]
    bias_for_classify: [10]
    result: [batch_size, 10]
    targets: [batch_size, 10] one-hot
    '''
    new_h, new_c = internel_lltm(input, weight_for_gate, bias_for_gate, old_h,
                                 old_c)
    assert_print(new_h.shape[1].value == weight_for_classify.shape[1].value)
    result = topi.nn.dense(new_h, weight_for_classify, bias_for_classify)
    loss = cross_entropy(result, targets)
    return loss, result, new_h, new_c
Example #9
0
def conv2d_nchw(inputs,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1):
    batch_size, in_channel, in_h, in_w = inputs.shape
    out_channel, channel_per_group, k_h, k_w = weight.shape
    assert_print((channel_per_group * groups).value == in_channel.value)
    out_channel_per_group = out_channel // groups
    assert_print((out_channel_per_group * groups).value == out_channel.value)

    stride = (stride, stride) if isinstance(stride, (int, tvm.tir.IntImm)) else stride
    padding = (padding, padding) if isinstance(padding, (int, tvm.tir.IntImm)) else padding
    dilation = (dilation, dilation) if isinstance(dilation, (int, tvm.tir.IntImm)) else dilation
    assert_print(isinstance(stride, tuple) and len(stride) == 2)
    assert_print(isinstance(padding, tuple) and len(padding) == 2)
    assert_print(isinstance(dilation, tuple) and len(dilation) == 2)

    out_h = (in_h + 2 * padding[0] - dilation[0] *
             (k_h - 1) - 1) // stride[0] + 1
    out_w = (in_w + 2 * padding[1] - dilation[1] *
             (k_w - 1) - 1) // stride[1] + 1
    rc = tvm.te.reduce_axis((0, channel_per_group), name="rc")
    rh = tvm.te.reduce_axis((0, k_h), name="rh")
    rw = tvm.te.reduce_axis((0, k_w), name="rw")

    padded = zero_pad2d(inputs, padding=padding)
    output = tvm.te.compute(
        (batch_size, out_channel, out_h, out_w),
        lambda b, c, h, w: tvm.te.sum(
            (padded[b,
                    c // out_channel_per_group * channel_per_group + rc,
                    h * stride[0] + rh * dilation[0],
                    w * stride[1] + rw * dilation[1]]
             * weight[c, rc, rh, rw]),
            axis=[rc, rw, rh]),
        requires_grad=True)
    if bias is not None:
        output = tvm.te.compute(
            (batch_size, out_channel, out_h, out_w),
            lambda b, c, h, w: output[b, c, h, w] + bias[c])
    return output
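
A shape sketch that matches the out_h / out_w formula above (same modified-TVM assumption; sizes are illustrative):

data = tvm.te.placeholder((1, 16, 56, 56), dtype="float32", name="data")
weight = tvm.te.placeholder((32, 16, 3, 3), dtype="float32", name="weight")
# out_h = (56 + 2*1 - 1*(3 - 1) - 1) // 1 + 1 = 56, and the same for out_w
out = conv2d_nchw(data, weight, bias=None, stride=1, padding=1, dilation=1, groups=1)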
Example #10
0
def zero_pad2d(inputs, padding=0):
    padding = (padding, padding, padding, padding) if isinstance(padding, (int, tvm.tir.IntImm)) else padding
    assert_print(isinstance(padding, tuple), "type(padding)={}".format(type(padding)))
    if len(padding) == 2:
        padding = (padding[0], padding[0], padding[1], padding[1])
    assert_print(len(padding) == 4)

    padding_zero = tvm.tir.expr.const(0, inputs.dtype)

    batch_size, in_channel, height, width = inputs.shape
    return tvm.te.compute(
        (batch_size, in_channel, height + padding[0] + padding[1], width + padding[2] + padding[3]),
        lambda b, c, h, w: tvm.te.if_then_else(
            tvm.te.all(h >= padding[0], h < height + padding[0], w >= padding[2], w < width + padding[2]),
            inputs[b, c, h - padding[0], w - padding[2]],
            padding_zero
            ),
        name='Padding', requires_grad=True
        )
Example #11
0
    def next_entity(self, pos, d):
        # d is a tuple
        if len(d) == 1:
            next_pos = (pos + d[0]) % self.size
            return next_pos
        elif len(d) == 2:
            asc_pos, dec_pos = d[0], d[1]
            assert_print(0 <= asc_pos < self.dim)
            assert_print(0 <= dec_pos < self.dim)
            assert_print(asc_pos != dec_pos)
            current = self.static_entities[pos]
            ret = current.copy()
            left = current[asc_pos] * current[dec_pos]
            canout = False
            next_pos = -1
            while not canout:
                tmp = ret[asc_pos] + 1
                while tmp <= left:
                    if self.allow_non_divisible == 'continuous':
                        break
                    elif self.allow_non_divisible == 'power2' and is_power_of_x(2, tmp):
                        break
                    elif left % tmp == 0:
                        break
                    tmp += 1
                tmp = min(tmp, left)
                ret[asc_pos] = tmp
                ret[dec_pos] = math.ceil(left / tmp)
                try:
                    next_pos = self.static_entities.index(ret)
                    canout = True
                except ValueError:
                    canout = False
            return next_pos
        else:
            raise RuntimeError(
                "Directions with more than two dims are not supported: {}".format(d))
Example #12
0
def cpu_schedule_simple(config, s, op, op_state):
    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "global")
    # spatial split
    spatial_axes = s[op].op.axis
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert_print(len(config["spatial"]) == len(spatial_axes), "align failed")
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            nfactors = [1]
            count = len(nparts) - 1
            while count >= 0:
                nfactors.append(nparts[count] * nfactors[-1])
                count -= 1
            tmp_buffer = []
            num_factors = len(nfactors)
            for i in range(num_factors - 2):
                factor = nfactors[num_factors - 2 - i]
                part = nparts[i]
                if factor == 1:
                    tmp_buffer.append(axis)
                    axis = None
                elif part == 1:
                    tmp_buffer.append(None)
                else:
                    outer, axis = s[op].split(axis, factor=factor)
                    tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
    assert_print(len(splited_spatial_axes) > 0, "empty spatial axes")  # must be non-empty

    # always reorder and fuse here
    # this part actually suppose there is "spatial" in config
    # which is avoidable
    spatial_fuse_lsts = []
    spatial_fuse_extents = []
    reorder_lst = []
    fused_spatial_axes = []
    spatial_split_num_parts = len(splited_spatial_axes[0])
    for count in range(spatial_split_num_parts):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extent = reduce(lambda a, b: a * b, [x[count] for x in config["spatial"]])
        spatial_fuse_lsts.append(tmp_buffer)
        spatial_fuse_extents.append(tmp_extent)
        reorder_lst.extend(tmp_buffer)
    reorder_lst_without_none = list(filter(lambda x: x is not None, reorder_lst))
    # print("reorder op", reorder_lst_without_none)
    s[op].reorder(*reorder_lst_without_none)
    for fuse_lst in spatial_fuse_lsts[:1]:
        tmp_buffer = list(filter(lambda x: x is not None, fuse_lst))
        # print("fuse op", tmp_buffer)
        fused = s[op].fuse(*tmp_buffer)
        fused_spatial_axes.append(fused)
    kernel_scope = fused_spatial_axes[0]
    if len(spatial_fuse_lsts) > 1:
        count = 0
        while count < len(config["spatial"]) and config["spatial"][count][1] == 1:
            count += 1
        if count == len(config["spatial"]):
            count -= 1
        next_pos_for_compute_at = spatial_fuse_lsts[1][count]
    else:
        next_pos_for_compute_at = kernel_scope

    # always parallel here
    s[op].parallel(kernel_scope)
    # vectorize
    if len(spatial_fuse_lsts) == 2:
        count = len(spatial_fuse_lsts[1]) - 1
        while count >= 1:
            if spatial_fuse_lsts[1][count] is not None and config["spatial"][1][count] > 1:
                # print("vectorize op", spatial_fuse_lsts[1][count])
                s[op].vectorize(spatial_fuse_lsts[1][count])
                break
            count -= 1
    elif len(spatial_fuse_lsts) > 2:
        count = len(spatial_fuse_lsts[-1]) - 1
        while count >= 0:
            if (spatial_fuse_lsts[-1][count] is not None
                    and config["spatial"][count][-1] > 1):
                # print("vectorize op", spatial_fuse_lsts[-1][count])
                s[op].vectorize(spatial_fuse_lsts[-1][count])
                break
            count -= 1
    # always compute at here
    # print("compute at", next_pos_for_compute_at)
    s[write_cache].compute_at(s[op], next_pos_for_compute_at)

    # spatial_split for write cache
    spatial_axes = s[write_cache].op.axis
    num_spatial_axes = len(spatial_axes)
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert_print(len(config["spatial"]) == len(spatial_axes), "align failed")
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            nfactors = [1]
            count = len(nparts) - 1
            while count >= 0:
                nfactors.append(nparts[count] * nfactors[-1])
                count -= 1
            tmp_buffer = []
            num_factors = len(nfactors)
            for i in range(num_factors - 2):
                factor = nfactors[num_factors - 2 - i]
                part = nparts[i]
                if factor == 1:
                    tmp_buffer.append(axis)
                    axis = None
                elif part == 1:
                    tmp_buffer.append(None)
                else:
                    outer, axis = s[write_cache].split(axis, factor=factor)
                    tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
    assert_print(len(splited_spatial_axes) > 0, "empty spatial axes")  # must be non-empty
    # reduce_split for write cache
    reduced_axes = s[write_cache].op.reduce_axis
    num_reduce_axes = len(reduced_axes)
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            nfactors = [1]
            count = len(nparts) - 1
            while count >= 0:
                nfactors.append(nparts[count] * nfactors[-1])
                count -= 1
            tmp_buffer = []
            num_factors = len(nfactors)
            for i in range(num_factors - 2):
                factor = nfactors[num_factors - 2 - i]
                part = nparts[i]
                if factor == 1:
                    tmp_buffer.append(axis)
                    axis = None
                elif part == 1:
                    tmp_buffer.append(None)
                else:
                    outer, axis = s[write_cache].split(axis, factor=factor)
                    tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])
    # for easy align
    # reduce_split_num_parts = len(splited_reduced_axes[0])
    # assert reduce_split_num_parts == spatial_split_num_parts
    # reorder hybrid for spatial and reduce
    hybrid_axes = splited_spatial_axes + splited_reduced_axes
    hybrid_fuse_lsts = []
    hybrid_reorder_lst = []
    for count in range(spatial_split_num_parts):
        tmp_buffer = [x[count] for x in hybrid_axes]
        hybrid_fuse_lsts.append(tmp_buffer)
        hybrid_reorder_lst.extend(tmp_buffer)
    if len(hybrid_fuse_lsts) > 1:
        last_parts = hybrid_reorder_lst[-num_spatial_axes - num_reduce_axes:]
        hybrid_reorder_lst = hybrid_reorder_lst[:-num_spatial_axes - num_reduce_axes]
        tmp_buffer = last_parts[-num_reduce_axes:]
        tmp_buffer.extend(last_parts[:-num_reduce_axes])
        hybrid_reorder_lst.extend(tmp_buffer)
    hybrid_reorder_lst_without_none = list(
        filter(lambda x: x is not None, hybrid_reorder_lst))
    # print("reorder cache write", hybrid_reorder_lst_without_none)
    s[write_cache].reorder(*hybrid_reorder_lst_without_none)
    # fuse without reduce axes
    # assert len(hybrid_fuse_lsts) > 0
    # s[write_cache].fuse(*hybrid_fuse_lsts[0][:-num_reduce_axes])

    # unroll and vectorize without reduce axes
    if len(hybrid_fuse_lsts) > 1:
        rcount = num_spatial_axes - 1
        while rcount >= 0 and config["spatial"][rcount][-1] == 1:
            rcount -= 1
        if rcount >= 0:
            # print("vectorize cache write", hybrid_fuse_lsts[-1][rcount])
            s[write_cache].vectorize(hybrid_fuse_lsts[-1][rcount])
        for count in range(rcount):
            if config["spatial"][count][-1] > 1:
                # print("unroll cache write", hybrid_fuse_lsts[-1][count])
                s[write_cache].unroll(hybrid_fuse_lsts[-1][count])
    if len(hybrid_fuse_lsts) > 2:
        for count in range(num_spatial_axes):
            if config["spatial"][count][-2] > 1:
                # print("unroll cache write", hybrid_fuse_lsts[-2][count])
                s[write_cache].unroll(hybrid_fuse_lsts[-2][count])
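
The config dict consumed here is not documented in this snippet; the sketch below shows one plausible shape inferred from how cpu_schedule_simple indexes it. Every number is a hypothetical choice, not a tuned one: config["spatial"] holds one factor list per spatial axis and config["reduce"] one per reduce axis, all lists sharing the same number of parts (the hybrid reorder walks spatial and reduce parts with the same counter), and the factors of each list are meant to multiply to the axis extent.

# hypothetical config for a (1, 32, 56, 56) output with reduce axes (rc=16, rh=3, rw=3)
config = {
    "spatial": [[1, 1, 1, 1],      # batch: no split
                [4, 2, 2, 2],      # out channel: 4*2*2*2 = 32
                [7, 2, 2, 2],      # height: 7*2*2*2 = 56
                [7, 2, 2, 2]],     # width
    "reduce":  [[4, 2, 2, 1],      # rc: 4*2*2*1 = 16
                [3, 1, 1, 1],      # rh
                [3, 1, 1, 1]],     # rw
}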
Example #13
0
def cpu_schedule_split_reorder_fuse(config, s, op, op_state):
    # assert_print(op in s)

    loop_idx = []
    loop_lst = []

    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "local")

    # spatial split
    spatial_axes = [axis for axis in s[op].op.axis]
    assert len(spatial_axes) > 0, "empty spatial axes"  # must be non-empty
    n = spatial_axes[0]
    kernel_scope, n = s[op].split(n, nparts=1)
    spatial_axes[0] = n

    splited_spatial_axes = []
    splited_spatial_extents = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert len(config["spatial"]) == len(spatial_axes), "align failed"
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            tmp_buffer = []
            tmp_extents = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
                tmp_extents.append(nparts[count])
            tmp_buffer.append(axis)
            tmp_extents.append(nparts[-1])
            splited_spatial_axes.append(tmp_buffer)
            splited_spatial_extents.append(tmp_extents)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
            splited_spatial_extents.append([axis.dom.extent.value])

    # always reorder here
    reorder_lst = []
    reorder_parts = []
    reorder_part_extents = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extents = [x[count] for x in splited_spatial_extents]
        reorder_lst.extend(tmp_buffer)
        reorder_parts.append(tmp_buffer)
        reorder_part_extents.append(tmp_extents)
    s[op].reorder(*reorder_lst)

    # handle fuse request
    fused_parts = []
    fused_part_extents = []
    fused_part_idx = []
    if "fuse" in config and len(config["fuse"]) > 0:
        base_id = 0
        for part, extents in zip(reorder_parts, reorder_part_extents):
            tmp_part = []
            tmp_extents = []
            tmp_idx = []
            idx = 0
            beg = 0
            for end in config["fuse"][0]:
                if end - beg > 1:
                    fuse_lst = part[beg:end]
                    fused = s[op].fuse(*fuse_lst)
                    tmp_part.append(fused)
                    extent = reduce(lambda x, y: x * y, extents[beg:end], 1)
                    tmp_idx.extend([idx] * (end - beg))
                    idx += 1
                    tmp_extents.append(extent)
                elif end - beg == 1:
                    tmp_part.append(part[beg])
                    tmp_extents.append(extents[beg])
                    tmp_idx.append(idx)
                    idx += 1
                beg = end
            fused_parts.append(tmp_part)
            fused_part_extents.append(tmp_extents)
            fused_part_idx.append(tmp_idx)

            # for op state
            loop_lst.extend(tmp_part)
            loop_idx.extend([x + base_id for x in tmp_idx])
            base_id += len(tmp_part)
    else:
        fused_parts = reorder_parts
        fused_part_extents = reorder_part_extents
        fused_part_idx = [list(range(len(x))) for x in reorder_parts]

        # for op state
        loop_lst = reorder_lst
        loop_idx = list(range(len(reorder_lst)))

    # record op state
    op_state.loop_lst = loop_lst
    op_state.loop_idx = loop_idx

    # parallel
    fused = s[op].fuse(*fused_parts[0])
    s[op].parallel(fused)

    # compute at
    num_parts = len(fused_parts)
    if num_parts == 1:
        local_pos = fused
    elif num_parts == 2:
        local_pos = fused_parts[num_parts - 1][0]
    else:
        local_pos = fused_parts[num_parts - 2][-1]

    if "local_pos" in config and len(config["local_pos"]) > 0:
        local_at_part = config["local_pos"][0][0]
        local_at_idx = config["local_pos"][0][1]
        # index changed because of fusion
        cur_idx = fused_part_idx[local_at_part][local_at_idx]
        local_pos = fused_parts[local_at_part][cur_idx]

    # always compute at here
    s[write_cache].compute_at(s[op], local_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])

    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])
        # leave the last part
        for count in range(length - 1):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # the last part
        last_part = [x[length - 1] for x in splited_reduced_axes]
        spatial_remainder = s[write_cache].op.axis
        # change the order of reduce axes and spatial axes
        if "reorder" in config and len(config["reorder"]) > 0:
            pos = config["reorder"][0][0]
            assert pos < len(spatial_remainder)
            tmp_buffer = []
            count = len(spatial_remainder) - 1
            while count > pos:
                tmp_buffer.append(spatial_remainder[count])
                count -= 1
            p = pos
            q = len(last_part) - 1
            while p >= 0 and q >= 0:
                tmp_buffer.append(spatial_remainder[p])
                tmp_buffer.append(last_part[q])
                p -= 1
                q -= 1
            while p >= 0:
                tmp_buffer.append(spatial_remainder[p])
                p -= 1
            while q >= 0:
                tmp_buffer.append(last_part[q])
                q -= 1
            tmp_buffer = list(reversed(tmp_buffer))
            reorder_lst.extend(tmp_buffer)
        else:
            reorder_lst.extend(last_part)
            reorder_lst.extend(spatial_remainder)
        s[write_cache].reorder(*reorder_lst)

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        explicit = config["unroll"][0][1]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)
        s[op].pragma(kernel_scope, 'unroll_explicit', explicit)
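
The split-reorder-fuse variant reads more keys from config; a hypothetical instance, with the meaning of each key inferred from the lookups above (all values illustrative, not tuned):

config = {
    "spatial":   [[1, 1, 1, 1],          # one nparts list per spatial axis
                  [4, 2, 2, 2],
                  [7, 2, 2, 2],
                  [7, 2, 2, 2]],
    "reduce":    [[4, 2, 2],             # one nparts list per reduce axis
                  [3, 1, 1],
                  [3, 1, 1]],
    "fuse":      [[1, 4]],               # fuse axes [1, 4) of every reordered part, keep axis 0 alone
    "reorder":   [[2]],                  # interleave the last reduce parts after spatial axis 2 of the cache
    "local_pos": [[1, 0]],               # compute_at position: part 1, pre-fusion index 0
    "unroll":    [[1500, 1]],            # auto_unroll_max_step, unroll_explicit
}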
Example #14
0
def gemm(A, B):
    k = tvm.te.reduce_axis((0, B.shape[0]))
    assert_print(A.shape[1].value == B.shape[0].value)
    return tvm.te.compute([A.shape[0], B.shape[1]],
                          lambda i, j: tvm.te.sum(A[i, k] * B[k, j], axis=k),
                          requires_grad=True)
Example #15
0
def cuda_schedule_split_fuse(config, s, op, op_state):
    # assert_print(op in s)

    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "local")

    # always cache read here
    read_cache_share_lst = []
    read_cache_local_lst = []
    for t in op.input_tensors:
        share = s.cache_read(t, "shared", [write_cache])
        read_cache_share_lst.append(share)
        local = s.cache_read(share, "local", [write_cache])
        read_cache_local_lst.append(local)

    # spatial split
    spatial_axes = s[op].op.axis
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert_print(
            len(config["spatial"]) == len(spatial_axes), "align failed")
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
    assert_print(len(splited_spatial_axes) > 0,
                 "empty spatial axes")  # must be non-empty

    # always reorder and fuse here
    spatial_fuse_lsts = []
    spatial_fuse_extents = []
    reorder_lst = []
    fused_spatial_axes = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extent = reduce(lambda a, b: a * b,
                            [x[count] for x in config["spatial"]])
        spatial_fuse_lsts.append(tmp_buffer)
        spatial_fuse_extents.append(tmp_extent)
        reorder_lst.extend(tmp_buffer)
    s[op].reorder(*reorder_lst)
    for fuse_lst in spatial_fuse_lsts:
        fused = s[op].fuse(*fuse_lst)
        fused_spatial_axes.append(fused)
    kernel_scope = fused_spatial_axes[0]

    # always bind here
    length = len(fused_spatial_axes)
    thread_extents = 1
    assert_print(length > 1, "fused axes length <= 1")
    if 2 <= length <= 3:
        s[op].bind(fused_spatial_axes[0], te.thread_axis("blockIdx.x"))
        s[op].bind(fused_spatial_axes[1], te.thread_axis("threadIdx.x"))
        thread_pos = fused_spatial_axes[1]
        thread_extents = spatial_fuse_extents[1]
    else:
        s[op].bind(fused_spatial_axes[0], te.thread_axis("blockIdx.x"))
        s[op].bind(fused_spatial_axes[1], te.thread_axis("vthread"))
        s[op].bind(fused_spatial_axes[2], te.thread_axis("threadIdx.x"))
        thread_pos = fused_spatial_axes[2]
        thread_extents = spatial_fuse_extents[2]

    # always compute at here
    s[write_cache].compute_at(s[op], thread_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(
            len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])
    share_pos = None
    local_pos = None
    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])

        for count in range(length):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # change the order of reduce axes and spatial axes
        reorder_lst.extend(s[write_cache].op.axis)
        s[write_cache].reorder(*reorder_lst)

        if length == 1:
            share_pos = reduced_nonfuse_lsts[-1][0]
        else:
            share_pos = reduced_nonfuse_lsts[-2][0]
            local_pos = reduced_nonfuse_lsts[-1][-1]

    # always cache read here
    if share_pos is not None:
        for share in read_cache_share_lst:
            s[share].compute_at(s[write_cache], share_pos)
    else:
        for share in read_cache_share_lst:
            s[share].compute_inline()
    if local_pos is not None:
        for local in read_cache_local_lst:
            s[local].compute_at(s[write_cache], local_pos)
    else:
        for local in read_cache_local_lst:
            s[local].compute_inline()

    # always cooperative fetching
    if share_pos is not None:
        for share in read_cache_share_lst:
            fuse_lst = s[share].op.axis
            fused = s[share].fuse(*fuse_lst)
            outer, inner = s[share].split(fused, nparts=thread_extents)
            s[share].bind(outer, te.thread_axis("threadIdx.x"))

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        explicit = config["unroll"][0][1]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)
        s[op].pragma(kernel_scope, 'unroll_explicit', explicit)
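
For the CUDA split-fuse schedule, the number of parts per spatial factor list determines the binding: with four parts the fused parts map to blockIdx.x, vthread, threadIdx.x and an innermost serial loop. A hypothetical config (illustrative, untuned values):

config = {
    "spatial": [[1, 1, 1, 1],
                [8, 1, 4, 1],
                [7, 2, 4, 1],
                [7, 2, 4, 1]],
    "reduce":  [[8, 2],
                [3, 1],
                [3, 1]],
    "unroll":  [[1500, 0]],
}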
Example #16
0
def cuda_schedule_fuse_split(config, s, op, op_state):
    # assert_print(op in s)

    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "local")

    # always cache read here
    read_cache_share_lst = []
    # read_cache_local_lst = []
    for t in op.input_tensors:
        share = s.cache_read(t, "shared", [write_cache])
        read_cache_share_lst.append(share)
        # local = s.cache_read(share, "local", [write_cache])
        # read_cache_local_lst.append(local)

    # spatial fuse
    spatial_axes = s[op].op.axis
    fused_spatial_axes = []
    if "fuse" in config and len(config["fuse"]) > 0:
        # fuse redundant axes
        beg = 0
        for end in config["fuse"][0]:
            fuse_lst = spatial_axes[beg:end]
            beg = end
            if len(fuse_lst) > 0:
                fused = s[op].fuse(*fuse_lst)
                fused_spatial_axes.append(fused)
    else:
        fused_spatial_axes = spatial_axes

    # spatial split
    split_factor_lst = []
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert len(config["spatial"]) == len(spatial_axes), "align failed"
        # compute split factors
        if "fuse" in config and len(config["fuse"]) > 0:
            beg = 0
            for end in config["fuse"][0]:
                tmp_lst = [1] * len(config["spatial"][0])
                for i in range(beg, end):
                    for j in range(len(config["spatial"][i])):
                        tmp_lst[j] *= config["spatial"][i][j]
                if beg < end:
                    split_factor_lst.append(tmp_lst)
                beg = end
        else:
            split_factor_lst = config["spatial"]
        assert len(fused_spatial_axes) == len(split_factor_lst), "align failed"
        for axis, nparts in zip(fused_spatial_axes, split_factor_lst):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in fused_spatial_axes:
            splited_spatial_axes.append([axis])
    assert len(
        splited_spatial_axes) > 0, "empty spatial axes"  # must be non-empty

    # always reorder here
    reorder_lst = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        reorder_lst.extend(tmp_buffer)
    s[op].reorder(*reorder_lst)

    # fix kernel scope
    kernel_scope = reorder_lst[0]

    # always bind here
    # - prepare thread axis
    bx = te.thread_axis("blockIdx.x")
    by = te.thread_axis("blockIdx.y")
    bz = te.thread_axis("blockIdx.z")
    vx = te.thread_axis("vthread")
    vy = te.thread_axis("vthread")
    vz = te.thread_axis("vthread")
    tx = te.thread_axis("threadIdx.x")
    ty = te.thread_axis("threadIdx.y")
    tz = te.thread_axis("threadIdx.z")

    blocks = [bz, by, bx]
    threads = [tz, ty, tx]
    vthreads = [vz, vy, vx]

    block_extents = [-1, -1, -1]  # z, y, x
    virtual_extents = [-1, -1, -1]
    thread_extents = [-1, -1, -1]

    length = len(splited_spatial_axes)
    assert length >= 1
    # - bind
    count = min(length, len(blocks)) - 1
    while count >= 0:
        parts = len(splited_spatial_axes[count])
        assert parts > 0
        if parts == 1:
            s[op].bind(splited_spatial_axes[count][0], blocks[count])
            block_extents[count] = split_factor_lst[count][0]
        elif parts == 2:
            s[op].bind(splited_spatial_axes[count][0], blocks[count])
            block_extents[count] = split_factor_lst[count][0]
            s[op].bind(splited_spatial_axes[count][1], threads[count])
            thread_extents[count] = split_factor_lst[count][1]
        else:
            s[op].bind(splited_spatial_axes[count][0], blocks[count])
            block_extents[count] = split_factor_lst[count][0]
            s[op].bind(splited_spatial_axes[count][1], vthreads[count])
            virtual_extents[count] = split_factor_lst[count][1]
            s[op].bind(splited_spatial_axes[count][2], threads[count])
            thread_extents[count] = split_factor_lst[count][2]
        count -= 1
    # - compute at pos
    count = min(length, len(blocks)) - 1
    parts = len(splited_spatial_axes[count])
    thread_pos = splited_spatial_axes[count][min(parts - 1, 2)]

    # always compute at here
    s[write_cache].compute_at(s[op], thread_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(
            len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])
    share_pos = None
    # local_pos = None
    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])
        # leave the last part
        for count in range(length - 1):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # the last part
        last_part = [x[length - 1] for x in splited_reduced_axes]
        spatial_remainder = s[write_cache].op.axis
        # change the order of reduce axes and spatial axes
        if "reorder" in config and len(config["reorder"]) > 0:
            pos = config["reorder"][0][0]
            assert pos < len(spatial_remainder)
            tmp_buffer = []
            count = len(spatial_remainder) - 1
            while count > pos:
                tmp_buffer.append(spatial_remainder[count])
                count -= 1
            p = pos
            q = len(last_part) - 1
            while p >= 0 and q >= 0:
                tmp_buffer.append(spatial_remainder[p])
                tmp_buffer.append(last_part[q])
                p -= 1
                q -= 1
            while p >= 0:
                tmp_buffer.append(spatial_remainder[p])
                p -= 1
            while q >= 0:
                tmp_buffer.append(last_part[q])
                q -= 1
            tmp_buffer = list(reversed(tmp_buffer))
            reorder_lst.extend(tmp_buffer)
        else:
            reorder_lst.extend(last_part)
            reorder_lst.extend(spatial_remainder)
        s[write_cache].reorder(*reorder_lst)
        # decide where to compute at
        if length == 1:
            share_pos = last_part[-1]
        else:
            mid = math.ceil(length / 2.0) - 1
            share_pos = reduced_nonfuse_lsts[mid][-1]
            # local_pos = last_part[-1]

    # always cache read here
    if share_pos is not None:
        for share in read_cache_share_lst:
            s[share].compute_at(s[write_cache], share_pos)
    else:
        for share in read_cache_share_lst:
            s[share].compute_inline()
    # if local_pos is not None:
    #     for local in read_cache_local_lst:
    #         s[local].compute_at(s[write_cache], local_pos)
    # else:
    #     for local in read_cache_local_lst:
    #         s[local].compute_inline()

    # always cooperative fetching
    if share_pos is not None:
        for share in read_cache_share_lst:
            fuse_lst = s[share].op.axis
            fused = s[share].fuse(*fuse_lst)
            count = 2
            cur = 1
            limit = 1024
            while count >= 0:
                factor = thread_extents[count]
                if factor < 0:
                    defined = False
                    factor = 16
                else:
                    defined = True
                cur *= factor
                if not defined and cur > limit:
                    break
                fused, inner = s[share].split(fused, factor=factor)
                s[share].bind(inner, threads[count])
                count -= 1

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        explicit = config["unroll"][0][1]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)
        s[op].pragma(kernel_scope, 'unroll_explicit', explicit)
Example #17
0
def cuda_schedule_split_reorder_fuse(config, s, op, op_state):
    # assert_print(op in s)

    loop_lst = []
    loop_idx = []

    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "local")
    # always cache read here
    read_cache_share_lst = []
    # read_cache_local_lst = []
    for t in op.input_tensors:
        share = s.cache_read(t, "shared", [write_cache])
        read_cache_share_lst.append(share)
        # local = s.cache_read(share, "local", [write_cache])
        # read_cache_local_lst.append(local)

    # spatial split
    spatial_axes = [axis for axis in s[op].op.axis]
    assert len(spatial_axes) > 0, "empty spatial axes"  # must be non-empty
    n = spatial_axes[0]
    kernel_scope, n = s[op].split(n, nparts=1)
    spatial_axes[0] = n
    splited_spatial_axes = []
    splited_spatial_extents = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert len(config["spatial"]) == len(spatial_axes), "align failed"
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            tmp_buffer = []
            tmp_extents = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
                tmp_extents.append(nparts[count])
            tmp_buffer.append(axis)
            tmp_extents.append(nparts[-1])
            splited_spatial_axes.append(tmp_buffer)
            splited_spatial_extents.append(tmp_extents)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
            splited_spatial_extents.append([axis.dom.extent.value])

    # always reorder here
    reorder_lst = []
    reorder_parts = []
    reorder_part_extents = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extents = [x[count] for x in splited_spatial_extents]
        reorder_lst.extend(tmp_buffer)
        reorder_parts.append(tmp_buffer)
        reorder_part_extents.append(tmp_extents)
    s[op].reorder(*reorder_lst)
    # handle fuse request
    fused_parts = []
    fused_part_extents = []
    fused_part_idx = []
    if "fuse" in config and len(config["fuse"]) > 0:
        base_id = 0
        for part, extents in zip(reorder_parts, reorder_part_extents):
            tmp_part = []
            tmp_extents = []
            tmp_idx = []
            idx = 0
            beg = 0
            for end in config["fuse"][0]:
                if end - beg > 1:
                    fuse_lst = part[beg:end]
                    fused = s[op].fuse(*fuse_lst)
                    tmp_part.append(fused)
                    extent = reduce(lambda x, y: x * y, extents[beg:end], 1)
                    tmp_idx.extend([idx] * (end - beg))
                    idx += 1
                    tmp_extents.append(extent)
                elif end - beg == 1:
                    tmp_part.append(part[beg])
                    tmp_extents.append(extents[beg])
                    tmp_idx.append(idx)
                    idx += 1
                beg = end
            fused_parts.append(tmp_part)
            fused_part_extents.append(tmp_extents)
            fused_part_idx.append(tmp_idx)

            loop_lst.extend(tmp_part)
            loop_idx.extend([x + base_id for x in tmp_idx])
            base_id += len(tmp_part)
    else:
        fused_parts = reorder_parts
        fused_part_extents = reorder_part_extents
        fused_part_idx = [list(range(len(x))) for x in reorder_parts]

        loop_lst = reorder_lst
        loop_idx = list(range(len(reorder_lst)))
    # record op state
    op_state.loop_lst = loop_lst
    op_state.loop_idx = loop_idx

    # always bind here
    # - prepare thread axis
    bx = te.thread_axis("blockIdx.x")
    by = te.thread_axis("blockIdx.y")
    bz = te.thread_axis("blockIdx.z")
    vx = te.thread_axis("vthread")
    vy = te.thread_axis("vthread")
    vz = te.thread_axis("vthread")
    tx = te.thread_axis("threadIdx.x")
    ty = te.thread_axis("threadIdx.y")
    tz = te.thread_axis("threadIdx.z")

    blocks = [bz, by, bx]
    threads = [tz, ty, tx]
    vthreads = [vz, vy, vx]

    block_extents = [-1, -1, -1]  # z, y, x
    virtual_extents = [-1, -1, -1]
    thread_extents = [-1, -1, -1]

    bind_option = [None, None, None]
    bind_candidate = [blocks, vthreads, threads]
    candidate_extents = [block_extents, virtual_extents, thread_extents]

    # - bind
    num_parts = len(fused_parts)
    if num_parts == 1:
        bind_option[0] = (fused_parts[0], fused_part_extents[0])
        local_pos = fused_parts[0][:len(bind_candidate[0])][-1]
    elif num_parts == 2:
        bind_option[0] = (fused_parts[0], fused_part_extents[0])
        bind_option[2] = (fused_parts[1], fused_part_extents[1])
        local_pos = fused_parts[1][:len(bind_candidate[2])][-1]
    else:
        bind_option[0] = (fused_parts[0], fused_part_extents[0])
        bind_option[1] = (fused_parts[1], fused_part_extents[1])
        bind_option[2] = (fused_parts[2], fused_part_extents[2])
        local_pos = fused_parts[2][:len(bind_candidate[2])][-1]
    for option, candidate, extents in zip(bind_option, bind_candidate,
                                          candidate_extents):
        if option is not None:
            for i, axis in enumerate(option[0][:len(candidate)]):
                s[op].bind(axis, candidate[i])
                extents[i] = option[1][i]
    # compute at
    if "local_pos" in config and len(config["local_pos"]) > 0:
        local_at_part = config["local_pos"][0][0]
        local_at_idx = config["local_pos"][0][1]
        # index changed because of fusion
        cur_idx = fused_part_idx[local_at_part][local_at_idx]
        local_pos = fused_parts[local_at_part][cur_idx]

    # always compute at here
    s[write_cache].compute_at(s[op], local_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(
            len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])
    share_pos = None
    # local_pos = None
    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])
        # leave the last part
        for count in range(length - 1):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # the last part
        last_part = [x[length - 1] for x in splited_reduced_axes]
        spatial_remainder = s[write_cache].op.axis
        # change the order of reduce axes and spatial axes
        if "reorder" in config and len(config["reorder"]) > 0:
            pos = config["reorder"][0][0]
            assert pos < len(spatial_remainder)
            tmp_buffer = []
            count = len(spatial_remainder) - 1
            while count > pos:
                tmp_buffer.append(spatial_remainder[count])
                count -= 1
            p = pos
            q = len(last_part) - 1
            while p >= 0 and q >= 0:
                tmp_buffer.append(spatial_remainder[p])
                tmp_buffer.append(last_part[q])
                p -= 1
                q -= 1
            while p >= 0:
                tmp_buffer.append(spatial_remainder[p])
                p -= 1
            while q >= 0:
                tmp_buffer.append(last_part[q])
                q -= 1
            tmp_buffer = list(reversed(tmp_buffer))
            reorder_lst.extend(tmp_buffer)
        else:
            reorder_lst.extend(last_part)
            reorder_lst.extend(spatial_remainder)
        s[write_cache].reorder(*reorder_lst)
        # decide where to compute at
        if "share_pos" in config and len(config["share_pos"]) > 0:
            share_at = config["share_pos"][0][0]
            share_idx = config["share_pos"][0][1]
            reduced_nonfuse_lsts.append(last_part)
            share_pos = reduced_nonfuse_lsts[share_at][share_idx]
        else:
            if length == 1:
                share_pos = last_part[-1]
            else:
                mid = math.ceil(length / 2.0) - 1
                share_pos = reduced_nonfuse_lsts[mid][-1]
                # local_pos = last_part[-1]

    # always cache read here
    if share_pos is not None:
        for share in read_cache_share_lst:
            s[share].compute_at(s[write_cache], share_pos)
    else:
        for share in read_cache_share_lst:
            s[share].compute_inline()
    # if local_pos is not None:
    #     for local in read_cache_local_lst:
    #         s[local].compute_at(s[write_cache], local_pos)
    # else:
    #     for local in read_cache_local_lst:
    #         s[local].compute_inline()

    # always cooperative fetching
    if share_pos is not None:
        for share in read_cache_share_lst:
            fuse_lst = s[share].op.axis
            fused = s[share].fuse(*fuse_lst)
            count = 2
            cur = 1
            limit = 1024
            while count >= 0:
                factor = thread_extents[count]
                if factor < 0:
                    defined = False
                    factor = 16
                else:
                    defined = True
                cur *= factor
                if not defined and cur > limit:
                    break
                fused, inner = s[share].split(fused, factor=factor)
                s[share].bind(inner, threads[count])
                count -= 1

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        explicit = config["unroll"][0][1]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)
        s[op].pragma(kernel_scope, 'unroll_explicit', explicit)
Example #18
0
def cpu_schedule_split_fuse(config, s, op, op_state):
    # always cache write here
    # if op.num_outputs > 1:
    #     raise RuntimeWarning("Too many outputs in one operation!")
    write_cache = s.cache_write(op.output(0), "global")

    # spatial split
    spatial_axes = s[op].op.axis
    splited_spatial_axes = []
    if "spatial" in config and len(config["spatial"]) > 0:
        # to align each axis
        assert_print(len(config["spatial"]) == len(spatial_axes), "align failed")
        for axis, nparts in zip(spatial_axes, config["spatial"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[op].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_spatial_axes.append(tmp_buffer)
    else:
        for axis in spatial_axes:
            splited_spatial_axes.append([axis])
    assert_print(len(splited_spatial_axes) > 0, "empty spatial axes")  # must be non-empty

    # always reorder and fuse here
    spatial_fuse_lsts = []
    spatial_fuse_extents = []
    reorder_lst = []
    fused_spatial_axes = []
    for count in range(len(splited_spatial_axes[0])):
        tmp_buffer = [x[count] for x in splited_spatial_axes]
        tmp_extent = reduce(lambda a, b: a * b, [x[count] for x in config["spatial"]])
        spatial_fuse_lsts.append(tmp_buffer)
        spatial_fuse_extents.append(tmp_extent)
        reorder_lst.extend(tmp_buffer)
    s[op].reorder(*reorder_lst)
    for fuse_lst in spatial_fuse_lsts:
        fused = s[op].fuse(*fuse_lst)
        fused_spatial_axes.append(fused)
    kernel_scope = fused_spatial_axes[0]

    # always parallel here
    length = len(fused_spatial_axes)
    assert_print(length > 0, "empty spatial axes!")
    s[op].parallel(fused_spatial_axes[0])
    if length == 1:
        thread_pos = fused_spatial_axes[0]
    elif 2 <= length <= 3:
        thread_pos = fused_spatial_axes[1]
    else:
        thread_pos = fused_spatial_axes[2]

    # always compute at here
    s[write_cache].compute_at(s[op], thread_pos)

    # reduce_split
    reduced_axes = s[write_cache].op.reduce_axis
    splited_reduced_axes = []
    if "reduce" in config and len(config["reduce"]) > 0:
        # to align each axis
        assert_print(len(config["reduce"]) == len(reduced_axes), "align reduce failed")
        for axis, nparts in zip(reduced_axes, config["reduce"]):
            tmp_buffer = []
            for count in range(len(nparts) - 1):
                outer, axis = s[write_cache].split(axis, nparts=nparts[count])
                tmp_buffer.append(outer)
            tmp_buffer.append(axis)
            splited_reduced_axes.append(tmp_buffer)
    else:
        for axis in reduced_axes:
            splited_reduced_axes.append([axis])

    # if has reduce axes
    if len(splited_reduced_axes) > 0:
        # always reorder here
        reduced_nonfuse_lsts = []
        reorder_lst = []
        length = len(splited_reduced_axes[0])

        for count in range(length):
            tmp_buffer = [x[count] for x in splited_reduced_axes]
            reduced_nonfuse_lsts.append(tmp_buffer)
            reorder_lst.extend(tmp_buffer)
        # change the order of reduce axes and spatial axes
        rlength = len(splited_reduced_axes)
        if rlength > 1:
            reorder_lst.extend(s[write_cache].op.axis)
        elif rlength == 1:  # in this case, have to interleave otherwise the reorder is of no use
            tmp_order = []
            p_spatial = len(s[write_cache].op.axis) - 1
            p_reduce = len(reorder_lst) - 1
            while p_spatial >= 0 and p_reduce >= 0:
                tmp_order.append(s[write_cache].op.axis[p_spatial])
                tmp_order.append(reorder_lst[p_reduce])
                p_spatial -= 1
                p_reduce -= 1
            while p_spatial >= 0:
                tmp_order.append(s[write_cache].op.axis[p_spatial])
                p_spatial -= 1
            while p_reduce >= 0:
                tmp_order.append(reorder_lst[p_reduce])
                p_reduce -= 1
            tmp_order = list(reversed(tmp_order))
            reorder_lst = tmp_order
        s[write_cache].reorder(*reorder_lst)

    # unroll
    if "unroll" in config and len(config["unroll"]) > 0:
        step = config["unroll"][0][0]
        s[op].pragma(kernel_scope, 'auto_unroll_max_step', step)

    # always vectorize here
    s[write_cache].vectorize(s[write_cache].op.axis[-1])
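
A driver sketch tying one of these schedule templates to the simple gemm from example #14 (hedged: it assumes the modified TVM build where tvm.te.compute accepts requires_grad, and passes None for op_state since cpu_schedule_split_fuse never touches it; the split factors are illustrative):

A = tvm.te.placeholder((512, 512), dtype="float32", name="A")
B = tvm.te.placeholder((512, 512), dtype="float32", name="B")
C = gemm(A, B)                         # the 2-argument gemm from example #14

s = tvm.te.create_schedule(C.op)
config = {
    "spatial": [[16, 8, 4],            # i: 16*8*4 = 512
                [16, 8, 4]],           # j
    "reduce":  [[32, 16]],             # k: 32*16 = 512
    "unroll":  [[1500]],
}
cpu_schedule_split_fuse(config, s, C.op, None)
print(tvm.lower(s, [A, B, C], simple_mode=True))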