Ejemplo n.º 1
0
 def _pad(*indices):
     """Compute one element of the zero-padded view of ``tensor``.

     ``dims``, ``pad_before``, ``pad_after`` and ``tensor`` are taken from the
     enclosing scope.

     Parameters
     ----------
     *indices
         One index expression per dimension of the padded output.

     Returns
     -------
     The expression reading ``tensor`` at the un-padded position. When at
     least one dimension has non-trivial padding, the read is guarded by an
     ``if_then_else`` that yields zero outside the original extent.
     """
     not_zero = []  # Bounds conditions for the dimensions that are really padded
     index_tuple = []  # The indices with which to access the un-padded tensor
     for i in range(dims):
         if equal_const_int(pad_before[i], 0) and equal_const_int(pad_after[i], 0):
             # No padding on this axis: the index passes through unchanged.
             index_tuple.append(indices[i])
         else:
             # Shift back into the source tensor's coordinates and remember
             # the range in which the shifted index is valid.
             index_tuple.append(indices[i] - pad_before[i])
             not_zero.append(indices[i] >= pad_before[i])
             not_zero.append(indices[i] < tensor.shape[i] + pad_before[i])
     if not_zero:
         not_zero = tvm.tir.all(*not_zero)
         # The fill value uses the tensor's own dtype (instead of a hard-coded
         # "uint8") so both branches of the select have the same type for any
         # input dtype.
         return tvm.tir.if_then_else(
             not_zero, tensor(*index_tuple), tvm.tir.const(0, tensor.dtype)
         )
     return tensor(*index_tuple)
Ejemplo n.º 2
0
    def _fold_buffer_dim(buf, scope, elem_block):
        """Merge contiguous dimensions of ``buf`` into a folded shape/stride pair.

        The innermost dimensions are first consumed until their combined
        extent equals ``elem_block``; they become a single entry of extent
        ``elem_block`` with stride 1. The remaining dimensions are then
        grouped, innermost first, into runs whose strides chain together
        (each stride equals the previous extent times the run's stride).

        Parameters
        ----------
        buf
            Buffer exposing ``shape`` and ``strides``.
        scope
            Scope name, used only in error messages.
        elem_block
            Number of elements that form one block.

        Returns
        -------
        (shape, strides)
            Two lists, outermost dimension first.

        Raises
        ------
        RuntimeError
            If the innermost dimensions cannot be shown to form exactly one
            dense block, or an outer stride is not a multiple of
            ``elem_block``.
        """
        rank = len(buf.shape)

        # Phase 1: locate the innermost dense run that makes up one block.
        run = 1
        for depth in range(1, rank + 1):
            axis = rank - depth
            if not utils.equal_const_int(buf.strides[axis] - run, 0):
                raise RuntimeError("scope %s needs to have block=%d" %
                                   (scope, elem_block))
            run = run * buf.shape[axis]
            if utils.equal_const_int(run - elem_block, 0):
                # ``cursor`` is the 1-based depth of the first unconsumed dim.
                cursor = depth + 1
                break
        else:
            # Never reached an extent equal to elem_block.
            raise RuntimeError("scope %s need to have block=%d, shape=%s" %
                               (scope, elem_block, buf.shape))

        # Both lists are built innermost-first and flipped before returning.
        folded_shape = [elem_block]
        folded_strides = [1]

        # When the next dim does not step by exactly one block, insert a
        # unit-extent dim with stride elem_block.
        if cursor <= rank and not utils.equal_const_int(
                buf.strides[rank - cursor], elem_block):
            folded_shape.append(1)
            folded_strides.append(elem_block)

        # Phase 2: greedily fuse the remaining dims into stride-compatible runs.
        analyzer = tvm.arith.Analyzer()
        while cursor <= rank:
            group_stride = buf.strides[rank - cursor]
            if not utils.equal_const_int(idxm(group_stride, elem_block), 0):
                raise RuntimeError(
                    "scope %s need to have block=%d, shape=%s, strides=%s" %
                    (scope, elem_block, buf.shape, buf.strides))
            group_size = 1
            resume_at = cursor
            for depth in range(cursor, rank + 1):
                axis = rank - depth
                if not utils.equal_const_int(
                        group_size * group_stride - buf.strides[axis], 0):
                    break
                group_size = group_size * buf.shape[axis]
                resume_at = depth + 1
            folded_shape.append(analyzer.simplify(group_size))
            folded_strides.append(group_stride)
            # The first dim of a run always matches itself, so progress is made.
            assert resume_at != cursor
            cursor = resume_at

        return folded_shape[::-1], folded_strides[::-1]
Ejemplo n.º 3
0
 def _check_compact(buf):
     """Verify that the strides of ``buf`` describe a densely packed layout.

     Going from the last dimension to the first, each stride must be provably
     equal to the product of the extents of all later dimensions.

     Raises
     ------
     RuntimeError
         If any stride cannot be proven equal to the expected value.
     """
     rank = len(buf.shape)
     # Running product of the extents already visited, starting at 1.
     expected = tvm.tir.const(1, buf.shape[0].dtype)
     axis = rank - 1
     while axis >= 0:
         if not utils.equal_const_int(expected - buf.strides[axis], 0):
             raise RuntimeError(
                 "Cannot prove compact: shape=%s, strides=%s" %
                 (buf.shape, buf.strides))
         expected = expected * buf.shape[axis]
         axis -= 1
Ejemplo n.º 4
0
    def _inject_copy(src, dst, pad_before, pad_after, pad_value):
        """Lower a copy between ``src`` and ``dst`` to a VTA 2D store/load call.

        A copy whose destination scope is ``"global"`` becomes a
        ``VTAStoreBuffer2D`` extern call; a copy whose source scope is
        ``"global"`` becomes a ``VTALoadBuffer2D`` extern call. In both cases
        the call is emitted under a ``coproc_scope`` attribute carrying the
        task queue id.

        Parameters
        ----------
        src, dst
            Source and destination buffers.
        pad_before, pad_after
            Per-dimension padding; only accepted on the load path.
        pad_value
            Currently ignored.

        Returns
        -------
        The statement produced by the IR builder.

        Raises
        ------
        RuntimeError
            For unsupported scope combinations, or a padded store.
        ValueError
            If the padding has an unsupported rank or touches the innermost
            block dimensions.
        """
        # FIXME: pad_value is ignored...
        env = get_env()
        _ = pad_value

        if dst.scope() == "global":
            # Store
            if pad_before or pad_after:
                raise RuntimeError("Do not support copy into DRAM with pad")
            if src.scope() != env.acc_scope:
                raise RuntimeError("Do not support copy %s->dram" %
                                   (src.scope()))
            elem_width, elem_bytes = env.OUT_WIDTH, env.OUT_ELEM_BYTES
            mem_type = env.dev.MEM_ID_OUT
            data_type = "int%d" % env.OUT_WIDTH
            task_qid = env.dev.QID_STORE_OUT

            _check_compact(src)
            x_size, y_size, x_stride, offset = _get_2d_pattern(
                dst, elem_width, elem_bytes, data_type, src.scope(),
                allow_fold=True)

            irb = tvm.tir.ir_builder.create()
            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                           env.dev.get_task_qid(task_qid))
            store_call = tvm.tir.call_extern(
                "int32",
                "VTAStoreBuffer2D",
                env.dev.command_handle,
                src.access_ptr("r", "int32"),
                mem_type,
                dst.data,
                offset,
                x_size,
                y_size,
                x_stride,
            )
            irb.emit(store_call)
            return irb.get()

        if src.scope() != "global":
            raise RuntimeError("Do not support copy %s->%s" %
                               (src.scope(), dst.scope()))

        # Load: choose the element layout and queue from the destination scope.
        if dst.scope() == env.acc_scope:
            elem_width, elem_bytes = env.ACC_WIDTH, env.ACC_ELEM_BYTES
            mem_type, task_qid = env.dev.MEM_ID_ACC, env.dev.QID_LOAD_OUT
            data_type = "int%d" % env.ACC_WIDTH
        elif dst.scope() == env.inp_scope:
            elem_width, elem_bytes = env.INP_WIDTH, env.INP_ELEM_BYTES
            mem_type, task_qid = env.dev.MEM_ID_INP, env.dev.QID_LOAD_INP
            data_type = "int%d" % env.INP_WIDTH
        elif dst.scope() == env.wgt_scope:
            elem_width, elem_bytes = env.WGT_WIDTH, env.WGT_ELEM_BYTES
            mem_type, task_qid = env.dev.MEM_ID_WGT, env.dev.QID_LOAD_WGT
            data_type = "int%d" % env.WGT_WIDTH
        else:
            raise RuntimeError("Do not support copy dram->%s" %
                               (dst.scope()))

        # collect pad statistics
        if pad_before:
            assert pad_after
            ndim = len(pad_before)
            if ndim <= 2 or ndim > 5:
                raise ValueError(
                    "Limitation of 2D pad load forbid ndim=%d" % ndim)
            # The 5-dim case occurs when batch size N > 1; the padded y/x
            # axes are then at positions 1/2 rather than 0/1.
            y_axis = 1 if ndim == 5 else 0
            x_axis = y_axis + 1
            y_pad_before = pad_before[y_axis]
            x_pad_before = pad_before[x_axis]
            y_pad_after = pad_after[y_axis]
            x_pad_after = pad_after[x_axis]
            # Every dimension after x must be free of padding.
            for dim in range(x_axis + 1, ndim):
                if not (utils.equal_const_int(pad_before[dim], 0)
                        and utils.equal_const_int(pad_after[dim], 0)):
                    raise ValueError(
                        "Do not support pad on the innermost block")
            allow_fold = False
        else:
            x_pad_before = y_pad_before = 0
            x_pad_after = y_pad_after = 0
            allow_fold = True

        _check_compact(dst)
        x_size, y_size, x_stride, offset = _get_2d_pattern(
            src, elem_width, elem_bytes, data_type, dst.scope(),
            allow_fold=allow_fold)

        if data_type != src.dtype:
            # The only accepted mismatch is an input-width source loaded into
            # an accumulator-width destination.
            assert data_type == "int%d" % env.ACC_WIDTH and src.dtype == "int%d" % env.INP_WIDTH
            mem_type = env.dev.MEM_ID_ACC_8BIT

        irb = tvm.tir.ir_builder.create()
        irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                       env.dev.get_task_qid(task_qid))
        load_call = tvm.tir.call_extern(
            "int32",
            "VTALoadBuffer2D",
            env.dev.command_handle,
            src.data,
            offset,
            x_size,
            y_size,
            x_stride,
            x_pad_before,
            y_pad_before,
            x_pad_after,
            y_pad_after,
            dst.access_ptr("r", "int32"),
            mem_type,
        )
        irb.emit(load_call)
        return irb.get()
Ejemplo n.º 5
0
    def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
        """Describe the access pattern of ``buf`` as a 2D transfer in block units.

        One block holds ``elem_bytes * 8 // elem_width`` elements. The
        innermost one dimension (when its extent equals the block size) or
        innermost two dimensions (otherwise) must form exactly one block; up
        to two further outer dimensions become the x and y extents.

        Parameters
        ----------
        buf
            Buffer exposing ``shape``, ``strides`` and ``elem_offset``.
        elem_width, elem_bytes
            Used to derive the number of elements per block.
        dtype
            Not used by this function.
        scope
            Scope name, used only in error messages.
        allow_fold
            When true, dimensions are first merged with ``_fold_buffer_dim``.

        Returns
        -------
        (x_size, y_size, x_stride, offset)
            ``offset`` is ``buf.elem_offset`` divided by the block size.

        Raises
        ------
        RuntimeError
            If the element offset is not block aligned or no supported 2D
            pattern matches the shape/strides.
        """
        elem_block = elem_bytes * 8 // elem_width
        if not utils.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
            raise RuntimeError("scope %s need to have block=%d" %
                               (scope, elem_block))
        if allow_fold:
            shape, strides = _fold_buffer_dim(buf, scope, elem_block)
        else:
            shape, strides = list(buf.shape), list(buf.strides)

        def fail():
            """Report that no 2D pattern could be matched."""
            raise RuntimeError(
                ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" +
                 " shape=%s, strides=%s") %
                (scope, elem_block, buf.shape, buf.strides))

        def require(expr, value):
            """Fail unless ``expr`` is provably the constant ``value``."""
            if not utils.equal_const_int(expr, value):
                fail()

        def pattern(x_size, y_size, x_stride):
            """Assemble the result, appending the offset in block units."""
            return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)

        ndim = len(shape)

        if utils.equal_const_int(shape[-1], elem_block):
            # The innermost dimension alone is one block.
            require(strides[-1], 1)
            if ndim == 1:
                return pattern(1, 1, 1)

            require(strides[-2] - elem_block, 0)
            if ndim == 2:
                return pattern(shape[-2], 1, shape[-2])

            require(idxm(strides[-3], elem_block), 0)
            if ndim == 3:
                return pattern(shape[-2], shape[-3],
                               idxd(strides[-3], elem_block))
        else:
            # The two innermost dimensions together must form one block.
            require(strides[-1], 1)
            require(strides[-2] - shape[-1], 0)
            require(shape[-1] * shape[-2], elem_block)
            if ndim == 2:
                return pattern(1, 1, 1)

            require(strides[-3], elem_block)
            if ndim == 3:
                return pattern(shape[-3], 1, shape[-3])

            require(idxm(strides[-4], elem_block), 0)
            if ndim == 4:
                return pattern(shape[-3], shape[-4],
                               idxd(strides[-4], elem_block))

        # More dimensions than any supported pattern.
        fail()