def register_vta_tuning_tasks():
    """Register the ``conv2d_packed.vta`` AutoTVM template.

    A ``TaskExtractEnv`` is instantiated first, then a template function is
    registered under the name ``"conv2d_packed.vta"``.  The template chains a
    packed conv2d, a right shift by 8, a clip to ``[0, 127]`` and a cast to
    ``int8``, and returns a schedule together with the tensors involved.
    """
    from tvm.autotvm.task import TaskExtractEnv

    @tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
    def my_clip(x, a_min, a_max):
        """Clip ``x`` to ``[a_min, a_max]`` using two separate compute stages.

        Unlike topi's current clip, the min and the max are emitted as two
        stages named ``clipA`` and ``clipB``.
        """
        lower = tvm.tir.const(a_min, x.dtype)
        upper = tvm.tir.const(a_max, x.dtype)
        # First stage caps values from above ...
        capped = te.compute(
            x.shape, lambda *idx: tvm.te.min(x(*idx), upper), name="clipA"
        )
        # ... second stage caps the result from below.
        return te.compute(
            capped.shape, lambda *idx: tvm.te.max(capped(*idx), lower), name="clipB"
        )

    # init autotvm env to register VTA operator
    TaskExtractEnv()

    @autotvm.template("conv2d_packed.vta")
    def _topi_nn_conv2d(*args, **kwargs):
        """Template body: build the conv2d pipeline and its schedule.

        The first two positional arguments are returned alongside the result
        tensor; keyword arguments are rejected.
        """
        assert not kwargs, "Do not support kwargs in template function call"
        data, kernel = args[:2]

        with tvm.target.vta():
            conv = vta.top.conv2d_packed(*args, **kwargs)
            shifted = topi.right_shift(conv, 8)
            clipped = my_clip(shifted, 0, 127)
            out = topi.cast(clipped, "int8")

        # Checked outside the ``with tvm.target.vta()`` block above.
        on_vta = tvm.target.Target.current().device_name == 'vta'
        sched = (
            vta.top.schedule_conv2d_packed([out])
            if on_vta
            else te.create_schedule([out.op])
        )
        return sched, [data, kernel, out]
Ejemplo n.º 2
0
def expr2graph(expr, target_ops, node_dict, node_list):
    """Translate a relay expression into a node list and attach workloads.

    The expression is walked by ``_expr2graph_impl`` inside a task-extraction
    environment; afterwards every node whose ``"op"`` is one of ``target_ops``
    receives the workload and task name of the matching collected task.
    ``node_list`` entries are updated in place and nothing is returned.

    Parameters
    ----------
    expr : tvm.relay.Expr.Function
        Input relay function expression.

    target_ops: List of str
        List of target relay base op name

    node_dict : dictionary from tvm.relay.Expr to int
        Dictionary to record node index

    node_list : list of dictionary
        List of nodes which contains all expr in the input relay function.
        Each node will be stored as a dictionary in the format of
        {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
         "name": str, "workloads": [tuple], "topi_op": [function]}

    Raises
    ------
    RuntimeError
        If an entry of ``target_ops`` is not a key of ``OP2COMPUTE``.
    """
    env = TaskExtractEnv.get(allow_duplicate=True)

    # Gather the compute functions of every requested op, rejecting unknown ones.
    topi_funcs = []
    for op_name in target_ops:
        if op_name not in OP2COMPUTE:
            raise RuntimeError("Not supported relay op in graph tuner: %s" % op_name)
        topi_funcs.extend(OP2COMPUTE[op_name])
    env.reset(topi_funcs)

    # pylint: disable=not-context-manager
    with env:
        _expr2graph_impl(expr, target_ops, node_dict, node_list)

        # The i-th target node (in node_list order) is paired with the i-th
        # entry of env.task_collection.
        target_nodes = (entry for entry in node_list if entry["op"] in target_ops)
        for task_pos, entry in enumerate(target_nodes):
            task_name, args = env.task_collection[task_pos]
            task = autotvm.task.create(
                task_name,
                args,
                target="llvm",
                target_host=None,
                template_key='direct',
            )
            entry["workloads"] = [task.workload]
            entry["topi_op"] = [task_name]
Ejemplo n.º 3
0
def expr2graph(expr, target_ops, node_dict, node_list):
    """Translate a relay expression into a node list and attach workloads.

    The expression is walked by ``_expr2graph_impl`` inside a task-extraction
    environment reset with ``target_ops``; afterwards every node whose ``"op"``
    is one of ``target_ops`` receives the workload and task name of the
    matching collected task.  ``node_list`` entries are updated in place and
    nothing is returned.

    Parameters
    ----------
    expr : tvm.relay.Expr.Function
        Input relay function expression.

    target_ops: List of relay.op.Op
        List of target relay ops

    node_dict : dictionary from tvm.relay.Expr to int
        Dictionary to record node index

    node_list : list of dictionary
        List of nodes which contains all expr in the input relay function.
        Each node will be stored as a dictionary in the format of
        {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
         "name": str, "workloads": [tuple], "topi_op": [function]}
    """
    # TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact
    #   that # autotvm tasks == # ops. But this won't be true after having relay op
    #   strategy. We need to find a solution to fix this.
    env = TaskExtractEnv.get(allow_duplicate=True)
    env.reset(target_ops)

    # pylint: disable=not-context-manager
    with env:
        _expr2graph_impl(expr, target_ops, node_dict, node_list)

        # The i-th target node (in node_list order) is paired with the i-th
        # entry of env.task_collection.
        target_nodes = (entry for entry in node_list if entry["op"] in target_ops)
        for task_pos, entry in enumerate(target_nodes):
            task_name, args = env.task_collection[task_pos]
            task = autotvm.task.create(
                task_name, args, target="llvm", target_host=None
            )
            entry["workloads"] = [task.workload]
            entry["topi_op"] = [task_name]