Example #1
0
def test_get_direct_ancestor():
    """Check ``get_direct_ancestor`` on two small relay graphs.

    Graph 1 is ``conv2d(add(conv2d(data, w0), data * 5.0) + 2.5, w1)``;
    node 5 of its ``node_list`` is queried and ``[0]`` is expected.

    Graph 2 (non-regression case) is ``add(log(data), sqrt(data))``;
    node 3 of its ``node_list`` is queried and ``[0]`` is expected.
    """
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data * relay.expr.const(5.0))
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.analysis.free_vars(out), out)
    net = bind_inputs(net, {
        "data": (1, 16, 224, 224),
        "w0": (16, 16, 1, 1),
        "w1": (16, 16, 1, 1)
    })
    target_ops = [relay.op.get("nn.conv2d")]
    # Empty containers populated by expr2graph; node_list is then indexed by
    # get_direct_ancestor below.
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    visited_dict = {}
    input_names = ["data"]
    out = get_direct_ancestor(node_list, visited_dict, target_ops, 5,
                              input_names)
    assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out)

    # non-regression test
    out = relay.add(relay.log(data), relay.sqrt(data))
    net = relay.Function(relay.analysis.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224)})
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    # visited_dict was filled while querying graph 1 and (presumably) caches
    # results keyed by node index into *that* node_list. Start from an empty
    # dict for the new graph, just like node_list/node_dict above, so the
    # query below cannot be answered from stale graph-1 entries.
    visited_dict = {}
    out = get_direct_ancestor(node_list, visited_dict, target_ops, 3,
                              input_names)
    assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out)
Example #2
0
def test_get_in_nodes():
    """Check ``get_in_nodes`` on ``conv2d(add(conv2d(data, w0), data) + 2.5, w1)``.

    ``expected_out`` maps a node index (into the ``node_list`` built by
    ``expr2graph``) to the list of its in-node indices. Both the keys and
    the in-node lists must match.
    """
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data)
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    # NOTE(review): relay.ir_pass.free_vars and string target_ops are the
    # older relay API; test_get_direct_ancestor uses relay.analysis and
    # relay.op.get. The expected node indices below depend on the API
    # version, so migrate both together.
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(net, {
        "data": (1, 16, 224, 224),
        "w0": (16, 16, 1, 1),
        "w1": (16, 16, 1, 1)
    })
    target_ops = ["conv2d"]
    input_names = ["data"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    out = get_in_nodes(node_list, target_ops, input_names)
    expected_out = {7: [3], 3: [2, 0], 2: [0]}

    def _normalize(in_nodes):
        # Compare in-node lists as multisets so only membership (with
        # multiplicity) is asserted, not the traversal order.
        return {node: sorted(srcs) for node, srcs in in_nodes.items()}

    # The previous check, set(out) ^ set(expected_out), only compared the
    # dict keys, so wrong in-node lists passed silently. Assumes get_in_nodes
    # returns a mapping shaped like expected_out.
    if _normalize(out) != _normalize(expected_out):
        raise RuntimeError("Output mismatch: expecting %s but got %s." %
                           (str(expected_out), str(out)))
Example #3
0
def test_has_multiple_inputs():
    """Check ``has_multiple_inputs`` on ``add(data * 3.0, conv2d(data, w0))``.

    Node indices refer to the ``node_list`` populated by ``expr2graph``.
    Of the nodes checked, only node 5 is expected to report multiple inputs.
    """
    data = relay.var("data")
    scaled = data * relay.expr.const(3.0)
    weight = relay.var("w0")
    conv = relay.nn.conv2d(data, weight)
    summed = relay.add(scaled, conv)
    func = relay.Function(relay.ir_pass.free_vars(summed), summed)
    func = bind_inputs(func, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})
    ops = ["conv2d"]
    nodes = []
    node_map = {}
    expr2graph(func, ops, node_map, nodes)
    inputs = ["data"]
    # (node index, expected result) pairs, verified in this order.
    cases = ((2, False), (4, False), (5, True))
    for idx, expected in cases:
        verify_has_multiple_inputs(nodes, idx, inputs, expected)