Example 1
0
    def _fold_buffer_dim(buf, scope, elem_block):
        """Fold buffer dims so the innermost unit is a single elem_block.

        The innermost dims whose sizes multiply to exactly ``elem_block``
        are collapsed into one [elem_block] dim with stride 1; the
        remaining outer dims are then greedily merged wherever they are
        contiguous.  Returns ``(shape, strides)`` lists ordered
        outer-to-inner.  Raises RuntimeError when the layout cannot be
        folded this way.
        """
        ndim = len(buf.shape)
        x_size = 1
        base = 0
        # Absorb inner dims until their product reaches elem_block; each
        # must be densely packed (stride == product of the inner sizes).
        for i in range(1, ndim + 1):
            if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0):
                raise RuntimeError("scope %s needs to have block=%d" %
                                   (scope, elem_block))
            x_size = x_size * buf.shape[ndim - i]
            if util.equal_const_int(x_size - elem_block, 0):
                # base is 1 past the (inner-end) index of the last dim
                # absorbed into the block.
                base = i + 1
                break
        if base == 0:
            # Inner dims never multiplied up to exactly elem_block.
            raise RuntimeError("scope %s need to have block=%d, shape=%s" %
                               (scope, elem_block, buf.shape))
        shape = [elem_block]
        strides = [1]

        # If the next outer dim is not contiguous with the folded block,
        # insert a unit dim so the stride sequence below stays consistent.
        if base < ndim + 1 and not util.equal_const_int(
                buf.strides[ndim - base], elem_block):
            shape.append(1)
            strides.append(elem_block)

        analyzer = tvm.arith.Analyzer()
        # Greedily merge each run of mutually-contiguous outer dims into a
        # single (size, stride) pair.
        while base < ndim + 1:
            x_size = 1
            x_stride = buf.strides[ndim - base]
            next_base = base
            # Every outer stride must be a whole multiple of elem_block.
            if not util.equal_const_int(idxm(x_stride, elem_block), 0):
                raise RuntimeError(
                    "scope %s need to have block=%d, shape=%s, strides=%s" %
                    (scope, elem_block, buf.shape, buf.strides))
            for i in range(base, ndim + 1):
                k = ndim - i
                if not util.equal_const_int(x_size * x_stride - buf.strides[k],
                                            0):
                    break
                x_size = x_size * buf.shape[k]
                next_base = i + 1
            shape.append(analyzer.simplify(x_size))
            strides.append(x_stride)
            # The inner loop always absorbs at least one dim (its first
            # iteration compares x_stride against itself), so we progress.
            assert next_base != base
            base = next_base

        # Lists were built inner-to-outer; present them outer-to-inner.
        strides = list(reversed(strides))
        shape = list(reversed(shape))
        return shape, strides
Example 2
0
 def _check_compact(buf):
     """Verify *buf* is dense row-major (compact); raise RuntimeError if not.

     A compact buffer has, for each axis, a stride equal to the product of
     all shape extents to its right (so the innermost stride is 1).
     """
     expected_stride = tvm.tir.const(1, buf.shape[0].dtype)
     for axis in range(len(buf.shape) - 1, -1, -1):
         if not util.equal_const_int(expected_stride - buf.strides[axis], 0):
             raise RuntimeError(
                 "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides))
         expected_stride = expected_stride * buf.shape[axis]
Example 3
0
 def _check_compact(buf):
     """Raise RuntimeError unless *buf* is a compact (contiguous) buffer.

     Walks axes innermost-first, checking that each stride matches the
     running product of the inner shape extents.
     """
     ndim = len(buf.shape)
     running = tvm.const(1, buf.shape[0].dtype)
     for idx in reversed(range(ndim)):
         if not util.equal_const_int(running - buf.strides[idx], 0):
             raise RuntimeError(
                 "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides))
         running = running * buf.shape[idx]
Example 4
0
    def _fold_buffer_dim(buf, scope, elem_block):
        """Fold buffer dims so the innermost unit is a single elem_block.

        The innermost dims whose sizes multiply to exactly ``elem_block``
        are collapsed into one [elem_block] dim with stride 1; remaining
        outer dims are greedily merged wherever contiguous.  Returns
        ``(shape, strides)`` lists ordered outer-to-inner.  Raises
        RuntimeError when the layout cannot be folded this way.
        """
        ndim = len(buf.shape)
        x_size = 1
        base = 0
        # Absorb inner dims until their product reaches elem_block; each
        # must be densely packed (stride == product of the inner sizes).
        for i in range(1, ndim + 1):
            if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0):
                raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
            x_size = x_size * buf.shape[ndim - i]
            if util.equal_const_int(x_size - elem_block, 0):
                # base is 1 past the (inner-end) index of the last dim
                # absorbed into the block.
                base = i + 1
                break
        if base == 0:
            # Inner dims never multiplied up to exactly elem_block.
            raise RuntimeError("scope %s need to have block=%d, shape=%s" % (
                scope, elem_block, buf.shape))
        shape = [elem_block]
        strides = [1]

        # If the next outer dim is not contiguous with the folded block,
        # insert a unit dim so the stride sequence below stays consistent.
        if base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block):
            shape.append(1)
            strides.append(elem_block)

        # Greedily merge each run of mutually-contiguous outer dims into a
        # single (size, stride) pair.
        while base < ndim + 1:
            x_size = 1
            x_stride = buf.strides[ndim - base]
            next_base = base
            # Every outer stride must be a whole multiple of elem_block.
            if not util.equal_const_int(x_stride % elem_block, 0):
                raise RuntimeError(
                    "scope %s need to have block=%d, shape=%s, strides=%s" % (
                        scope, elem_block, buf.shape, buf.strides))
            for i in range(base, ndim + 1):
                k = ndim - i
                if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0):
                    break
                x_size = x_size * buf.shape[k]
                next_base = i + 1
            shape.append(tvm.ir_pass.Simplify(x_size))
            strides.append(x_stride)
            # The inner loop always absorbs at least one dim (its first
            # iteration compares x_stride against itself), so we progress.
            assert next_base != base
            base = next_base

        # Lists were built inner-to-outer; present them outer-to-inner.
        strides = list(reversed(strides))
        shape = list(reversed(shape))
        return shape, strides
Example 5
0
    def _inject_copy(src, dst, pad_before, pad_after, pad_value):
        """Lower a buffer copy into a VTA 2D DMA load or store.

        Dispatches on the scopes of *src* and *dst*: on-chip accumulator
        scope -> "global" emits a ``VTAStoreBuffer2D`` call; "global" ->
        an on-chip scope emits a ``VTALoadBuffer2D`` call with optional
        2D (x/y) padding.  Returns the IR statement built via ir_builder.
        Raises RuntimeError for unsupported scope pairs and ValueError
        for unsupported padding layouts.
        """
        # FIXME: pad_value is ignored...
        _ = pad_value
        if dst.scope == "global":
            # Store
            if pad_before or pad_after:
                raise RuntimeError("Do not support copy into DRAM with pad")
            if src.scope == env.acc_scope:
                # Only the accumulator scope may be stored back to DRAM.
                elem_width = env.OUT_WIDTH
                elem_bytes = env.OUT_ELEM_BYTES
                mem_type = env.dev.MEM_ID_OUT
                data_type = "int%d" % env.OUT_WIDTH
                task_qid = env.dev.QID_STORE_OUT
            else:
                raise RuntimeError("Do not support copy %s->dram" %
                                   (src.scope))
            _check_compact(src)
            # Derive the 2D access pattern from the DRAM-side buffer.
            x_size, y_size, x_stride, offset = _get_2d_pattern(dst,
                                                               elem_width,
                                                               elem_bytes,
                                                               data_type,
                                                               src.scope,
                                                               allow_fold=True)
            irb = tvm.tir.ir_builder.create()
            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                           env.dev.get_task_qid(task_qid))
            irb.emit(
                tvm.tir.call_extern("int32", "VTAStoreBuffer2D",
                                    env.dev.command_handle,
                                    src.access_ptr("r", "int32"), mem_type,
                                    dst.data, offset, x_size, y_size,
                                    x_stride))
            return irb.get()
        elif src.scope == "global":
            # Load: pick element geometry and task queue by target scope.
            if dst.scope == env.acc_scope:
                elem_width = env.ACC_WIDTH
                elem_bytes = env.ACC_ELEM_BYTES
                mem_type = env.dev.MEM_ID_ACC
                data_type = "int%d" % env.ACC_WIDTH
                task_qid = env.dev.QID_LOAD_OUT
            elif dst.scope == env.inp_scope:
                elem_width = env.INP_WIDTH
                elem_bytes = env.INP_ELEM_BYTES
                mem_type = env.dev.MEM_ID_INP
                data_type = "int%d" % env.INP_WIDTH
                task_qid = env.dev.QID_LOAD_INP
            elif dst.scope == env.wgt_scope:
                elem_width = env.WGT_WIDTH
                elem_bytes = env.WGT_ELEM_BYTES
                mem_type = env.dev.MEM_ID_WGT
                data_type = "int%d" % env.WGT_WIDTH
                task_qid = env.dev.QID_LOAD_WGT
            else:
                raise RuntimeError("Do not support copy dram->%s" %
                                   (dst.scope))
            # collect pad statistics
            if pad_before:
                assert pad_after
                ndim = len(pad_before)
                # Only 2D padding is supported; the innermost block dims
                # must be unpadded.
                if ndim <= 2 or ndim > 5:
                    raise ValueError(
                        "Limitation of 2D pad load forbid ndim=%d" % ndim)
                if ndim == 5:
                    # This case occurs when batch size N > 1
                    y_pad_before = pad_before[1]
                    x_pad_before = pad_before[2]
                    y_pad_after = pad_after[1]
                    x_pad_after = pad_after[2]
                    for dim in range(3, ndim):
                        if not util.equal_const_int(pad_before[dim], 0):
                            raise ValueError(
                                "Do not support pad on the innermost block")
                        if not util.equal_const_int(pad_after[dim], 0):
                            raise ValueError(
                                "Do not support pad on the innermost block")
                else:
                    y_pad_before = pad_before[0]
                    x_pad_before = pad_before[1]
                    y_pad_after = pad_after[0]
                    x_pad_after = pad_after[1]
                    for dim in range(2, ndim):
                        if not util.equal_const_int(pad_before[dim], 0):
                            raise ValueError(
                                "Do not support pad on the innermost block")
                        if not util.equal_const_int(pad_after[dim], 0):
                            raise ValueError(
                                "Do not support pad on the innermost block")
                # Padded loads cannot fold dims: pad indices refer to the
                # original (unfolded) layout.
                allow_fold = False
            else:
                x_pad_before = 0
                y_pad_before = 0
                x_pad_after = 0
                y_pad_after = 0
                allow_fold = True

            _check_compact(dst)
            x_size, y_size, x_stride, offset = _get_2d_pattern(
                src,
                elem_width,
                elem_bytes,
                data_type,
                dst.scope,
                allow_fold=allow_fold)

            irb = tvm.tir.ir_builder.create()
            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                           env.dev.get_task_qid(task_qid))

            # NOTE(review): the destination pointer is taken with "r"
            # access — presumably the VTA runtime treats it as an SRAM
            # index rather than a host write; confirm against the runtime.
            irb.emit(
                tvm.tir.call_extern("int32", "VTALoadBuffer2D",
                                    env.dev.command_handle, src.data, offset,
                                    x_size, y_size, x_stride, x_pad_before,
                                    y_pad_before, x_pad_after, y_pad_after,
                                    dst.access_ptr("r", "int32"), mem_type))
            return irb.get()

        else:
            raise RuntimeError("Do not support copy %s->%s" %
                               (src.scope, dst.scope))
Example 6
0
    def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
        """Match *buf* to a 2D DMA pattern ``(x_size, y_size, x_stride, offset)``.

        All four results are expressed in units of ``elem_block`` (the
        number of elements in one hardware transfer block).  When
        *allow_fold* is set, dims are first collapsed via
        ``_fold_buffer_dim``; otherwise the buffer's shape/strides are
        used as-is.  Raises RuntimeError when no 2D pattern matches.
        """
        # Elements per hardware block.
        elem_block = elem_bytes * 8 // elem_width
        if buf.dtype != dtype:
            raise RuntimeError("Expect buffer type to be %s instead of %s" %
                               (dtype, buf.dtype))
        shape, strides = buf.shape, buf.strides
        # The buffer must start on an elem_block boundary.
        if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
            raise RuntimeError("scope %s need to have block=%d" %
                               (scope, elem_block))
        if allow_fold:
            shape, strides = _fold_buffer_dim(buf, scope, elem_block)
        else:
            shape = list(x for x in shape)
            strides = list(x for x in strides)

        def raise_error():
            """Internal function to raise error """
            raise RuntimeError(
                ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" +
                 " shape=%s, strides=%s") %
                (scope, elem_block, buf.shape, buf.strides))

        ndim = len(shape)

        # Check if the inner-tensor is already flat
        flat = util.equal_const_int(shape[-1], elem_block)

        if flat:
            # Flat case: the innermost dim is exactly one elem_block.
            if not util.equal_const_int(strides[-1], 1):
                raise_error()

            if ndim == 1:
                # Single block: degenerate 1x1 transfer.
                x_size = 1
                x_stride = 1
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset,
                                                      elem_block)
            if not util.equal_const_int(strides[-2] - elem_block, 0):
                raise_error()

            if ndim == 2:
                # One contiguous row of shape[-2] blocks.
                x_size = shape[-2]
                x_stride = shape[-2]
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset,
                                                      elem_block)
            if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
                raise_error()

            if ndim == 3:
                # y rows of x blocks; row stride in block units.
                x_size = shape[-2]
                x_stride = idxd(strides[-3], elem_block)
                y_size = shape[-3]
                return x_size, y_size, x_stride, idxd(buf.elem_offset,
                                                      elem_block)

        else:
            # Non-flat case: the innermost TWO dims together form one
            # elem_block (an inner 2D tile).
            if not util.equal_const_int(strides[-1], 1):
                raise_error()
            if not util.equal_const_int(strides[-2] - shape[-1], 0):
                raise_error()
            if not util.equal_const_int(shape[-1] * shape[-2], elem_block):
                raise_error()

            if ndim == 2:
                x_size = 1
                x_stride = 1
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset,
                                                      elem_block)
            if not util.equal_const_int(strides[-3], elem_block):
                raise_error()

            if ndim == 3:
                x_size = shape[-3]
                x_stride = shape[-3]
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset,
                                                      elem_block)
            if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
                raise_error()

            if ndim == 4:
                x_size = shape[-3]
                x_stride = idxd(strides[-4], elem_block)
                y_size = shape[-4]
                return x_size, y_size, x_stride, idxd(buf.elem_offset,
                                                      elem_block)

        # ndim too large for a 2D pattern.
        raise_error()
Example 7
0
    def _inject_copy(src, dst, pad_before, pad_after, pad_value):
        """Lower a buffer copy into a VTA 2D DMA load or store.

        Dispatches on the scopes of *src* and *dst*: on-chip accumulator
        scope -> "global" emits a ``VTAStoreBuffer2D`` call; "global" ->
        an on-chip scope emits a ``VTALoadBuffer2D`` call with optional
        2D (x/y) padding.  Returns the IR statement built via ir_builder.
        Raises RuntimeError for unsupported scope pairs and ValueError
        for unsupported padding layouts.
        """
        # FIXME: pad_value is ignored...
        _ = pad_value
        if dst.scope == "global":
            # Store
            if pad_before or pad_after:
                raise RuntimeError("Do not support copy into DRAM with pad")
            if src.scope == env.acc_scope:
                # Only the accumulator scope may be stored back to DRAM.
                elem_width = env.OUT_WIDTH
                elem_bytes = env.OUT_ELEM_BYTES
                mem_type = env.dev.MEM_ID_OUT
                data_type = "int%d" % env.OUT_WIDTH
                task_qid = env.dev.QID_STORE_OUT
            else:
                raise RuntimeError("Do not support copy %s->dram" % (src.scope))
            _check_compact(src)
            # Derive the 2D access pattern from the DRAM-side buffer.
            x_size, y_size, x_stride, offset = _get_2d_pattern(
                dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
            irb = tvm.ir_builder.create()
            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                           env.dev.get_task_qid(task_qid))
            irb.emit(tvm.call_extern(
                "int32", "VTAStoreBuffer2D",
                env.dev.command_handle,
                src.access_ptr("r", "int32"),
                mem_type, dst.data, offset, x_size, y_size, x_stride))
            return irb.get()
        elif src.scope == "global":
            # Load: pick element geometry and task queue by target scope.
            if dst.scope == env.acc_scope:
                elem_width = env.ACC_WIDTH
                elem_bytes = env.ACC_ELEM_BYTES
                mem_type = env.dev.MEM_ID_ACC
                data_type = "int%d" % env.ACC_WIDTH
                task_qid = env.dev.QID_LOAD_OUT
            elif dst.scope == env.inp_scope:
                elem_width = env.INP_WIDTH
                elem_bytes = env.INP_ELEM_BYTES
                mem_type = env.dev.MEM_ID_INP
                data_type = "int%d" % env.INP_WIDTH
                task_qid = env.dev.QID_LOAD_INP
            elif dst.scope == env.wgt_scope:
                elem_width = env.WGT_WIDTH
                elem_bytes = env.WGT_ELEM_BYTES
                mem_type = env.dev.MEM_ID_WGT
                data_type = "int%d" % env.WGT_WIDTH
                task_qid = env.dev.QID_LOAD_WGT
            else:
                raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
            # collect pad statistics
            if pad_before:
                assert pad_after
                ndim = len(pad_before)
                # Only 2D padding is supported; the innermost block dims
                # must be unpadded.
                if ndim <= 2 or ndim > 4:
                    raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim)
                if ndim > 2:
                    if not util.equal_const_int(pad_before[ndim - 1], 0):
                        raise ValueError("Do not support pad on the innermost block")
                    if not util.equal_const_int(pad_after[ndim - 1], 0):
                        raise ValueError("Do not support pad on the innermost block")
                if ndim > 3:
                    if not util.equal_const_int(pad_before[ndim - 2], 0):
                        raise ValueError("Do not support pad on the innermost block")
                    if not util.equal_const_int(pad_after[ndim - 2], 0):
                        raise ValueError("Do not support pad on the innermost block")
                # Outer two dims carry the (y, x) padding.
                y_pad_before = pad_before[0]
                x_pad_before = pad_before[1]
                y_pad_after = pad_after[0]
                x_pad_after = pad_after[1]
                # Padded loads cannot fold dims: pad indices refer to the
                # original (unfolded) layout.
                allow_fold = False
            else:
                x_pad_before = 0
                y_pad_before = 0
                x_pad_after = 0
                y_pad_after = 0
                allow_fold = True

            _check_compact(dst)
            x_size, y_size, x_stride, offset = _get_2d_pattern(
                src, elem_width, elem_bytes, data_type,
                dst.scope, allow_fold=allow_fold)

            irb = tvm.ir_builder.create()
            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
                           env.dev.get_task_qid(task_qid))

            # NOTE(review): the destination pointer is taken with "r"
            # access — presumably the VTA runtime treats it as an SRAM
            # index rather than a host write; confirm against the runtime.
            irb.emit(tvm.call_extern(
                "int32", "VTALoadBuffer2D",
                env.dev.command_handle,
                src.data, offset, x_size, y_size, x_stride,
                x_pad_before, y_pad_before,
                x_pad_after, y_pad_after,
                dst.access_ptr("r", "int32"), mem_type))
            return irb.get()

        else:
            raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
Example 8
0
    def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
        """Match *buf* to a 2D DMA pattern ``(x_size, y_size, x_stride, offset)``.

        All four results are expressed in units of ``elem_block`` (the
        number of elements in one hardware transfer block).  When
        *allow_fold* is set, dims are first collapsed via
        ``_fold_buffer_dim``; otherwise the buffer's shape/strides are
        used as-is.  Raises RuntimeError when no 2D pattern matches.

        NOTE(review): the ``/`` below operates on TVM index expressions,
        which in this (older) API performed integer division — confirm
        before porting to an environment where ``/`` is float division.
        """
        # Elements per hardware block.
        elem_block = elem_bytes * 8 // elem_width
        if buf.dtype != dtype:
            raise RuntimeError("Expect buffer type to be %s instead of %s" %
                               (dtype, buf.dtype))
        shape, strides = buf.shape, buf.strides
        # The buffer must start on an elem_block boundary.
        if not util.equal_const_int(buf.elem_offset % elem_block, 0):
            raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
        if allow_fold:
            shape, strides = _fold_buffer_dim(buf, scope, elem_block)
        else:
            shape = list(x for x in shape)
            strides = list(x for x in strides)

        def raise_error():
            """Internal function to raise error """
            raise RuntimeError(
                ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" +
                 " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides))

        ndim = len(shape)

        # Check if the inner-tensor is already flat
        flat = util.equal_const_int(shape[-1], elem_block)

        if flat:
            # Flat case: the innermost dim is exactly one elem_block.
            if not util.equal_const_int(strides[-1], 1):
                raise_error()

            if ndim == 1:
                # Single block: degenerate 1x1 transfer.
                x_size = 1
                x_stride = 1
                y_size = 1
                return x_size, y_size, x_stride, buf.elem_offset / elem_block
            if not util.equal_const_int(strides[-2] - elem_block, 0):
                raise_error()

            if ndim == 2:
                # One contiguous row of shape[-2] blocks.
                x_size = shape[-2]
                x_stride = shape[-2]
                y_size = 1
                return x_size, y_size, x_stride, buf.elem_offset / elem_block
            if not util.equal_const_int(strides[-3] % elem_block, 0):
                raise_error()

            if ndim == 3:
                # y rows of x blocks; row stride in block units.
                x_size = shape[-2]
                x_stride = strides[-3] / elem_block
                y_size = shape[-3]
                return x_size, y_size, x_stride, buf.elem_offset / elem_block

        else:
            # Non-flat case: the innermost TWO dims together form one
            # elem_block (an inner 2D tile).
            if not util.equal_const_int(strides[-1], 1):
                raise_error()
            if not util.equal_const_int(strides[-2] - shape[-1], 0):
                raise_error()
            if not util.equal_const_int(shape[-1] * shape[-2], elem_block):
                raise_error()

            if ndim == 2:
                x_size = 1
                x_stride = 1
                y_size = 1
                return x_size, y_size, x_stride, buf.elem_offset / elem_block
            if not util.equal_const_int(strides[-3], elem_block):
                raise_error()

            if ndim == 3:
                x_size = shape[-3]
                x_stride = shape[-3]
                y_size = 1
                return x_size, y_size, x_stride, buf.elem_offset / elem_block
            if not util.equal_const_int(strides[-4] % elem_block, 0):
                raise_error()

            if ndim == 4:
                x_size = shape[-3]
                x_stride = strides[-4] / elem_block
                y_size = shape[-4]
                return x_size, y_size, x_stride, buf.elem_offset / elem_block

        # ndim too large for a 2D pattern.
        raise_error()