コード例 #1
0
def expr2graph(expr, target_ops, node_dict, node_list):
    """Populate ``node_list`` from a relay function and attach tuning workloads.

    The relay expression is walked by ``_expr2graph_impl`` while the task
    extraction environment is active. Afterwards every node whose ``"op"``
    is one of ``target_ops`` is paired, in order, with the next entry of
    the environment's task collection.

    Parameters
    ----------
    expr : tvm.relay.Expr.Function
        Relay function expression to convert.

    target_ops: List of str
        Names of the relay base ops to collect workloads for. Every name
        must be a key of ``OP2COMPUTE``.

    node_dict : dictionary from tvm.relay.Expr to int
        Mapping used to record the index of each node.

    node_list : list of dictionary
        Output list holding one dictionary per expr of the input function,
        in the format of
        {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
         "name": str, "workloads": [tuple], "topi_op": [function]}
        For target ops, "workloads" and "topi_op" are filled in here; the
        value stored under "topi_op" is the task name taken from the task
        collection.

    Raises
    ------
    RuntimeError
        If a name in ``target_ops`` is not present in ``OP2COMPUTE``.
    """
    extract_env = TaskExtractEnv.get(allow_duplicate=True)

    # Gather everything OP2COMPUTE lists for the requested ops, rejecting
    # any op name it does not know about.
    compute_funcs = []
    for name in target_ops:
        if name in OP2COMPUTE:
            compute_funcs.extend(OP2COMPUTE[name])
            continue
        raise RuntimeError("Not supported relay op in graph tuner: %s" % name)
    extract_env.reset(compute_funcs)

    with extract_env:
        _expr2graph_impl(expr, target_ops, node_dict, node_list)

        # The n-th target-op node is matched with the n-th collected task,
        # so this relies on both sequences being in the same order.
        tuned_nodes = (entry for entry in node_list if entry["op"] in target_ops)
        for task_idx, entry in enumerate(tuned_nodes):
            task_name, task_args = extract_env.task_collection[task_idx]
            tuning_task = autotvm.task.create(
                task_name,
                task_args,
                target="llvm",
                target_host=None,
                template_key='direct',
            )
            entry["workloads"] = [tuning_task.workload]
            entry["topi_op"] = [task_name]
コード例 #2
0
def expr2graph(expr, target_ops, node_dict, node_list):
    """Populate ``node_list`` from a relay function and attach tuning workloads.

    The relay expression is walked by ``_expr2graph_impl`` while the task
    extraction environment is active. Afterwards every node whose ``"op"``
    is one of ``target_ops`` is paired, in order, with the next entry of
    the environment's task collection.

    Parameters
    ----------
    expr : tvm.relay.Expr.Function
        Relay function expression to convert.

    target_ops: List of relay.op.Op
        Relay ops to collect workloads for; also passed to the extraction
        environment's ``reset``.

    node_dict : dictionary from tvm.relay.Expr to int
        Mapping used to record the index of each node.

    node_list : list of dictionary
        Output list holding one dictionary per expr of the input function,
        in the format of
        {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
         "name": str, "workloads": [tuple], "topi_op": [function]}
        For target ops, "workloads" and "topi_op" are filled in here; the
        value stored under "topi_op" is the task name taken from the task
        collection.
    """
    # TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact
    #   that # autotvm tasks == # ops. But this won't be true after having relay op
    #   strategy. We need to find a solution to fix this.
    extract_env = TaskExtractEnv.get(allow_duplicate=True)
    extract_env.reset(target_ops)
    # pylint: disable=not-context-manager
    with extract_env:
        _expr2graph_impl(expr, target_ops, node_dict, node_list)

        # Index of the next unconsumed entry in the task collection; it only
        # advances when a target-op node is encountered.
        next_task = 0
        for entry in node_list:
            if entry["op"] not in target_ops:
                continue
            task_name, task_args = extract_env.task_collection[next_task]
            next_task += 1
            tuning_task = autotvm.task.create(
                task_name, task_args, target="llvm", target_host=None
            )
            entry["workloads"] = [tuning_task.workload]
            entry["topi_op"] = [task_name]