Exemplo n.º 1
0
def test_get_direct_ancestor():
    """Check get_direct_ancestor on a small two-conv2d graph.

    The graph is conv2d(data, w0) -> add(conv, data * 5.0) -> + 2.5 ->
    conv2d(., w1). After converting it with expr2graph, the direct
    ancestors of node index 5 are expected to be [2, 0].
    """
    shapes = {
        "data": (1, 16, 224, 224),
        "w0": (16, 16, 1, 1),
        "w1": (16, 16, 1, 1),
    }

    data = relay.var("data")
    w0 = relay.var("w0")
    first_conv = relay.nn.conv2d(data, w0)
    scaled_data = data * relay.expr.const(5.0)
    summed = relay.add(first_conv, scaled_data)
    shifted = summed + relay.expr.const(2.5)
    w1 = relay.var("w1")
    last_conv = relay.nn.conv2d(shifted, w1)

    func = relay.Function(relay.ir_pass.free_vars(last_conv), last_conv)
    func = bind_inputs(func, shapes)

    target_ops = ["conv2d"]
    node_list, node_dict = [], {}
    expr2graph(func, target_ops, node_dict, node_list)

    # Fresh visited-cache; only "data" is treated as a graph input.
    ancestors = get_direct_ancestor(node_list, {}, target_ops, 5, ["data"])
    assert ancestors == [2, 0], \
        "Output mismatch: expecting [2, 0] but got %s." % str(ancestors)
Exemplo n.º 2
0
def test_expr2graph():
    """Check that expr2graph labels each graph node with the expected operator.

    A reference list of operator names is built by visiting the ResNet-50
    workload in post order, the same traversal order the node list follows.
    It is then compared element-wise, and by length, with the ``"op"``
    field of every node expr2graph produces.
    """
    net, _ = resnet.get_workload(num_layers=50, batch_size=1)
    node_dict = {}
    node_list = []
    target_ops = ["conv2d"]
    op_name_list = []

    def _count_node(node):
        # Operator references (the ``op`` of a Call) and the root Function are
        # visited by post_order_visit but are not expected to become graph
        # nodes, so they are skipped. The previous guard was inverted
        # (``if not isinstance(...)``): it kept only Op nodes, which made the
        # Call/TupleGetItem/Tuple branches unreachable.
        # NOTE(review): assumes expr2graph emits no entry for Op/Function
        # nodes, matching the newer version of this test — confirm.
        if isinstance(node, (relay.op.op.Op, relay.Function)):
            return
        if isinstance(node, Call):
            op_name_list.append(node.op.name.split(".")[-1])
        elif isinstance(node, TupleGetItem):
            op_name_list.append("TupleGetItem")
        elif isinstance(node, Tuple):
            op_name_list.append("Tuple")
        else:
            # Vars and any other leaf expressions carry no operator.
            op_name_list.append("null")

    relay.ir_pass.post_order_visit(net, _count_node)

    expr2graph(net, target_ops, node_dict, node_list)
    # zip() truncates silently, so check the lengths explicitly. Otherwise a
    # short (or empty) node list would pass the test.
    assert len(node_list) == len(op_name_list), \
        "Node count mismatch: expecting %d but got %d" \
        % (len(op_name_list), len(node_list))
    for i, (op_name, node) in enumerate(zip(op_name_list, node_list)):
        assert op_name == node["op"], (
            "%dth Node operator mismatch: expecting %s but got %s"
            % (i, str(op_name), str(node["op"]))
        )
Exemplo n.º 3
0
def test_get_in_nodes():
    """Check get_in_nodes on a conv2d -> add -> add -> conv2d graph.

    The result maps selected node indices to lists of their input node
    indices. Both the set of keys and each input list are compared with the
    expected mapping.
    """
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data)
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(net, {
        "data": (1, 16, 224, 224),
        "w0": (16, 16, 1, 1),
        "w1": (16, 16, 1, 1)
    })
    target_ops = ["conv2d"]
    input_names = ["data"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    out = get_in_nodes(node_list, target_ops, input_names)
    expected_out = {7: [3], 3: [2, 0], 2: [0]}

    def _normalize(in_nodes):
        # Sort each input list so the check verifies which inputs are
        # reported without depending on the order they are listed in.
        return {key: sorted(val) for key, val in in_nodes.items()}

    # The previous check (set(out) ^ set(expected_out)) compared only the
    # dict keys, so wrong input lists went undetected.
    if _normalize(out) != _normalize(expected_out):
        raise RuntimeError("Output mismatch: expecting %s but got %s." %
                           (str(expected_out), str(out)))
Exemplo n.º 4
0
def test_expr2graph():
    """Check expr2graph against a post-order walk of the synthetic workload.

    Every Call contributes its operator. Every Var, TupleGetItem and Tuple
    contributes None. Any other visited node is ignored. The resulting list
    must match the ``"op"`` field of expr2graph's node list exactly, in both
    length and order.
    """
    mod, _ = synthetic.get_workload()
    main_func = mod["main"]
    node_dict, node_list = {}, []
    target_ops = [relay.op.get("nn.conv2d")]
    expected_ops = []

    def _record_op(expr):
        if isinstance(expr, Call):
            expected_ops.append(expr.op)
            return
        if isinstance(expr, (Var, TupleGetItem, Tuple)):
            expected_ops.append(None)

    relay.analysis.post_order_visit(main_func, _record_op)

    expr2graph(main_func, target_ops, node_dict, node_list)
    assert len(node_list) == len(expected_ops)
    for idx, (expected_op, graph_node) in enumerate(zip(expected_ops, node_list)):
        actual_op = graph_node["op"]
        assert expected_op == actual_op, (
            "%dth Node operator mismatch: expecting %s but got %s"
            % (idx, str(expected_op), str(actual_op))
        )
Exemplo n.º 5
0
def test_has_multiple_inputs():
    """Check the multiple-input detection on add(data * 3.0, conv2d(data, w0)).

    Node indices 2 and 4 are expected to report a single input. Node index 5
    is expected to report multiple inputs.
    """
    data = relay.var("data")
    scaled = data * relay.expr.const(3.0)
    w0 = relay.var("w0")
    conv = relay.nn.conv2d(data, w0)
    total = relay.add(scaled, conv)

    func = relay.Function(relay.ir_pass.free_vars(total), total)
    func = bind_inputs(func, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})

    target_ops = ["conv2d"]
    node_list, node_dict = [], {}
    expr2graph(func, target_ops, node_dict, node_list)

    input_names = ["data"]
    # Pairs of (node index, expected multiple-inputs flag).
    for node_idx, expected in ((2, False), (4, False), (5, True)):
        verify_has_multiple_inputs(node_list, node_idx, input_names, expected)