def test_get_direct_ancestor():
    """Check get_direct_ancestor on a small conv2d/add/conv2d Relay graph.

    A second, non-regression case runs it on ``add(log(data), sqrt(data))``,
    a graph containing none of the target ops (no ``nn.conv2d``).
    """
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data * relay.expr.const(5.0))
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.analysis.free_vars(out), out)
    net = bind_inputs(
        net,
        {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)},
    )
    target_ops = [relay.op.get("nn.conv2d")]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    visited_dict = {}
    input_names = ["data"]
    out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
    assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out)

    # non-regression test
    out = relay.add(relay.log(data), relay.sqrt(data))
    net = relay.Function(relay.analysis.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224)})
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    # Use a fresh visited_dict for the new graph. It travels alongside
    # node_list and presumably caches results by node index (TODO confirm in
    # get_direct_ancestor). Indices from the first graph's node_list are
    # meaningless here, so reusing the old dict could leak stale entries.
    visited_dict = {}
    out = get_direct_ancestor(node_list, visited_dict, target_ops, 3, input_names)
    assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out)
def test_get_in_nodes():
    """Check get_in_nodes on a conv2d -> add -> add -> conv2d Relay graph.

    Uses the same ``relay.analysis.free_vars`` API and ``relay.op.get``
    target-op objects as the other graph-tuner utility tests.
    """
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data)
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.analysis.free_vars(out), out)
    net = bind_inputs(
        net,
        {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)},
    )
    target_ops = [relay.op.get("nn.conv2d")]
    input_names = ["data"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    out = get_in_nodes(node_list, target_ops, input_names)
    expected_out = {7: [3], 3: [2, 0], 2: [0]}
    # NOTE(review): set() of a dict yields only its keys, so this checks the
    # node indices present in the result but not their in-node lists.
    diff_set = set(out) ^ set(expected_out)
    assert not diff_set, "Output mismatch: expecting %s but got %s." % (
        str(expected_out),
        str(out),
    )
def test_has_multiple_inputs():
    """Check which nodes of ``add(data * 3.0, conv2d(data, w0))`` have multiple inputs.

    Node indices 2 and 4 are expected to have a single input; node 5
    (the final ``add``) is expected to have multiple inputs.
    """
    data = relay.var("data")
    out1 = data * relay.expr.const(3.0)
    w0 = relay.var("w0")
    out2 = relay.nn.conv2d(data, w0)
    out = relay.add(out1, out2)
    # relay.analysis / relay.op.get match the API used by the other tests.
    net = relay.Function(relay.analysis.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})
    target_ops = [relay.op.get("nn.conv2d")]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    input_names = ["data"]
    verify_has_multiple_inputs(node_list, 2, input_names, False)
    verify_has_multiple_inputs(node_list, 4, input_names, False)
    verify_has_multiple_inputs(node_list, 5, input_names, True)