示例#1
0
文件: to_relay.py 项目: zheng-xq/tvm
def _lookup_entry(relay_map, node_id, index):
    """Return the converted Relay value for output ``index`` of node ``node_id``.

    When the stored entry is an ``expr.TupleWrapper``, the element at
    ``index`` is selected; any other entry is returned as-is.
    """
    value = relay_map[node_id]
    if isinstance(value, expr.TupleWrapper):
        return value[index]
    return value


def to_relay(graph, shape_dict, dtype_dict, params):
    """Convert an NNVM graph into the corresponding Relay expression.

    Parameters
    ----------
    graph : Graph or Symbol
        The input graph. A ``Symbol`` is first converted with ``graph_create``.

    shape_dict : dict of str to shape
        The input shapes. The caller's dict is not modified; the shapes of
        ``params`` are merged into a copy.

    dtype_dict : dict of str to str/dtype
        The input dtypes.

    params : dict of str to array
        The parameters.

    Returns
    -------
    (func, params) : Tuple[relay.Function, dict of str to array]
        The corresponding Relay function and the parameters (returned as
        passed in).

    Raises
    ------
    Exception
        If the graph contains an operator with no entry in
        ``NNVM_OP_2_RELAY_OP``.
    """
    if isinstance(graph, Symbol):
        graph = graph_create(graph)

    # Work on a copy so the caller's shape_dict is left untouched.
    shape_dict = shape_dict.copy()
    shape_dict.update({name: value.shape for name, value in params.items()})
    graph = graph_attr.set_shape_inputs(graph, shape_dict)
    graph = graph_attr.set_dtype_inputs(graph, dtype_dict)
    graph = graph.apply(["InferShape", "InferType"])
    shape = graph.json_attr("shape")
    dtype = [graph_attr.TCODE_TO_DTYPE[di] for di in graph.json_attr("dtype")]
    # Keep (node_id, output_index) for every head, in declared order. The
    # layout presumably mirrors node["inputs"] entries (node id first, output
    # index second); default the index to 0 if it is absent.
    heads = [
        (entry[0], entry[1] if len(entry) > 1 else 0)
        for entry in json.loads(graph.json())["heads"]
    ]

    gidx = graph.index
    relay_map = {}
    fn_params = []

    for nid, node in enumerate(gidx.nodes):
        children = [_lookup_entry(relay_map, i[0], i[1]) for i in node["inputs"]]

        oshape = shape[gidx.entry_id(nid, 0)]
        odtype = dtype[gidx.entry_id(nid, 0)]
        attrs = node.get("attrs", {})
        node_name = node["name"]
        op_name = node["op"]

        if op_name == "null":
            # Graph inputs (including params) become Relay function parameters.
            v = var(node_name, shape=oshape, dtype=odtype)
            fn_params.append(v)
            relay_map[nid] = v
        elif op_name in NNVM_OP_2_RELAY_OP:
            str_attrs = StrAttrsDict(attrs)
            relay_map[nid] = NNVM_OP_2_RELAY_OP[op_name](children, str_attrs, odtype)
        else:
            raise Exception(
                "nnvm.to_relay: unsupported operator: {0}".format(op_name))

    # Resolve outputs from the declared heads, not from node-id order, so that:
    # output order matches the graph, heads that are graph inputs are kept,
    # and the requested output of a multi-output node is selected.
    outputs = [_lookup_entry(relay_map, nid, index) for nid, index in heads]
    body = outputs[0] if len(outputs) == 1 else expr.Tuple(outputs)

    func = relay.Function(fn_params, body)
    return func, params
示例#2
0
def test_key_is_not_present():
    """has_attr must not report a key that was never supplied."""
    missing_key = "b"
    wrapped = StrAttrsDict({"a": 1})
    assert not wrapped.has_attr(missing_key)