示例#1
0
def default_schedule(outs):
    """
    Create the default cuda schedule for the given compute outputs.

    Starting from the first output tensor, every ``tvm.tensor.ComputeOp``
    reachable through ``input_tensors`` is collected. For each collected op,
    its first axis is split by ``DEFAULT_GPU_THREAD`` and the two resulting
    axes are bound to ``blockIdx.x`` and ``threadIdx.x``.

    Args:
        outs (Union[tvm.tensor.Tensor, list[tvm.tensor.Tensor]]): outputs of compute.
            When a list is given, only its first element is used to create
            the schedule and to seed the traversal.

    Returns:
        sch (schedule.Schedule): The created schedule.

    Raises:
        ValueError: If outs is neither a tensor nor a list, or is an empty list.
        SystemError: If the cuda context does not exist.
    """
    if not isinstance(outs, (tvm.tensor.Tensor, list)):
        raise ValueError("outs should be list of _akg.tvm.tensor.Tensor or _akg.tvm.tensor.Tensor")
    device = 'cuda'
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        raise SystemError("Skip because %s is not enabled" % device)
    outs_list = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    # Fail with a clear message instead of an IndexError on outs_list[0] below.
    if not outs_list:
        raise ValueError("outs should not be an empty list")
    root = outs_list[0]
    with tvm.target.create(device):
        sch = tvm.create_schedule(root.op)
        # Walk the graph from the root output towards its inputs.
        pending = Queue()
        pending.put(root)
        op_list = []
        while not pending.empty():
            out = pending.get()
            # Only compute ops are scheduled; each op is recorded once, and the
            # inputs of skipped tensors are not followed.
            if out.op not in op_list and isinstance(out.op, tvm.tensor.ComputeOp):
                op_list.append(out.op)
                for input_tensor in out.op.input_tensors:
                    pending.put(input_tensor)
        for op in op_list:
            stage = sch[op.output(0)]
            # Split the first axis and bind the two parts to the block and thread indices.
            bx, tx = stage.split(op.axis[0], factor=DEFAULT_GPU_THREAD)
            stage.bind(bx, tvm.thread_axis("blockIdx.x"))
            stage.bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch
示例#2
0
def gpu_schedule_HSigmoid(outs):
    """
    Build the gpu schedule for HSigmoid.

    Args:
        outs: outputs of compute, passed unchanged to
            ``topi.cuda.schedule_elemwise``.

    Returns:
        The schedule returned by ``topi.cuda.schedule_elemwise``.

    Raises:
        SystemError: If the cuda context does not exist.
    """
    target_name = 'cuda'
    # Refuse to build a schedule when device 0 of the target is unavailable.
    if not tvm.context(target_name, 0).exist:
        raise SystemError("Skip because %s is not enabled" % target_name)
    with tvm.target.create(target_name):
        return topi.cuda.schedule_elemwise(outs)
示例#3
0
def gpu_schedule_ReLU6(outs):
    """
    Build the gpu schedule for ReLU6.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute, passed unchanged to
            ``topi.cuda.schedule_elemwise``.

    Returns:
        sch (schedule.Schedule): The created schedule.

    Raises:
        SystemError: If the cuda context does not exist.
    """
    target_name = 'cuda'
    # Refuse to build a schedule when device 0 of the target is unavailable.
    if not tvm.context(target_name, 0).exist:
        raise SystemError("Skip because %s is not enabled" % target_name)
    with tvm.target.create(target_name):
        return topi.cuda.schedule_elemwise(outs)