def test_get_direct_ancestor():
    """Check the direct ancestors reported for node 5 of a small conv2d graph.

    The graph is conv2d(data, w0) added to (data * 5.0), shifted by 2.5 and
    fed into a second conv2d. After converting it with expr2graph, the test
    queries get_direct_ancestor for node index 5 and expects [2, 0].
    """
    data = relay.var("data")
    w0 = relay.var("w0")
    first_conv = relay.nn.conv2d(data, w0)
    scaled_data = data * relay.expr.const(5.0)
    summed = relay.add(first_conv, scaled_data)
    shifted = summed + relay.expr.const(2.5)
    w1 = relay.var("w1")
    second_conv = relay.nn.conv2d(shifted, w1)

    input_shapes = {
        "data": (1, 16, 224, 224),
        "w0": (16, 16, 1, 1),
        "w1": (16, 16, 1, 1),
    }
    net = relay.Function(relay.ir_pass.free_vars(second_conv), second_conv)
    net = bind_inputs(net, input_shapes)

    # Flatten the expression into a node list that the graph utilities consume.
    target_ops = ["conv2d"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)

    visited_dict = {}
    input_names = ["data"]
    # NOTE(review): index 5 is presumably the relay.add node - confirm against
    # the ordering produced by expr2graph.
    out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
    assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)
def test_expr2graph():
    """Check that expr2graph records operators in post-order visit order.

    A ResNet-50 workload is walked with relay.ir_pass.post_order_visit to
    build the expected operator-name sequence, which is then compared entry
    by entry against the "op" field of the nodes produced by expr2graph.
    """
    net, _ = resnet.get_workload(num_layers=50, batch_size=1)
    node_dict = {}
    node_list = []
    target_ops = ["conv2d"]
    op_name_list = []

    def _count_node(node):
        """Append the expected "op" label for one visited expression."""
        # Skip bare operator objects. The original guard was inverted
        # (`if not isinstance(...)`), which returned for every non-Op node and
        # left the Call / TupleGetItem / Tuple branches unreachable.
        # NOTE(review): assumes expr2graph emits no entry for Op nodes - confirm.
        if isinstance(node, relay.op.op.Op):
            return
        if isinstance(node, Call):
            # Keep only the last component of a dotted operator name.
            op_name_list.append(node.op.name.split(".")[-1])
        elif isinstance(node, TupleGetItem):
            op_name_list.append("TupleGetItem")
        elif isinstance(node, Tuple):
            op_name_list.append("Tuple")
        else:
            op_name_list.append("null")

    relay.ir_pass.post_order_visit(net, _count_node)
    expr2graph(net, target_ops, node_dict, node_list)

    # zip stops at the shorter sequence, so a trailing length difference
    # between the two lists is not detected here.
    for i, (op_name, node) in enumerate(zip(op_name_list, node_list)):
        assert op_name == node["op"], (
            "%dth Node operator mismatch: expecting %s but got %s"
            % (i, str(op_name), str(node["op"]))
        )
def test_get_in_nodes():
    """Check the input-node mapping that get_in_nodes returns for a small graph.

    The graph is conv2d(data, w0) added to data, shifted by 2.5 and fed into a
    second conv2d. The mapping returned by get_in_nodes must equal
    {7: [3], 3: [2, 0], 2: [0]}.

    Raises
    ------
    RuntimeError
        If the returned mapping differs from the expected one.
    """
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data)
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(
        net,
        {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)},
    )

    target_ops = ["conv2d"]
    input_names = ["data"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)

    out = get_in_nodes(node_list, target_ops, input_names)
    expected_out = {7: [3], 3: [2, 0], 2: [0]}
    # Compare the whole mapping. The previous check,
    # `set(out) ^ set(expected_out)`, only compared the dictionary keys, so a
    # wrong ancestor list for a node went unnoticed.
    if out != expected_out:
        raise RuntimeError(
            "Output mismatch: expecting %s but got %s."
            % (str(expected_out), str(out))
        )
def test_expr2graph():
    """Check that expr2graph mirrors a post-order walk of the synthetic workload.

    The expected sequence holds the operator of every Call node and None for
    every Var, TupleGetItem and Tuple node, in visiting order. It must have
    the same length as the node list built by expr2graph, and each entry must
    equal the "op" field of the corresponding node.
    """
    mod, _ = synthetic.get_workload()
    node_dict = {}
    node_list = []
    target_ops = [relay.op.get("nn.conv2d")]
    expected_ops = []

    def _record_op(expr):
        """Append the expected "op" value for one visited expression."""
        if isinstance(expr, Call):
            expected_ops.append(expr.op)
            return
        if isinstance(expr, (Var, TupleGetItem, Tuple)):
            expected_ops.append(None)
        # Any other expression kind contributes no entry.

    relay.analysis.post_order_visit(mod["main"], _record_op)
    expr2graph(mod["main"], target_ops, node_dict, node_list)

    assert len(node_list) == len(expected_ops)
    for idx, (expected_op, entry) in enumerate(zip(expected_ops, node_list)):
        actual_op = entry["op"]
        assert expected_op == actual_op, (
            "%dth Node operator mismatch: expecting %s but got %s"
            % (idx, str(expected_op), str(entry["op"]))
        )
def test_has_multiple_inputs():
    """Check the multiple-input verdict for three nodes of a small graph.

    The graph adds (data * 3.0) to conv2d(data, w0). After conversion with
    expr2graph, verify_has_multiple_inputs is called for node indices 2, 4
    and 5 with expected results False, False and True respectively.
    """
    data = relay.var("data")
    scaled = data * relay.expr.const(3.0)
    w0 = relay.var("w0")
    conv = relay.nn.conv2d(data, w0)
    total = relay.add(scaled, conv)

    shapes = {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)}
    net = relay.Function(relay.ir_pass.free_vars(total), total)
    net = bind_inputs(net, shapes)

    target_ops = ["conv2d"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)

    input_names = ["data"]
    # (node index, expected result) pairs, checked in the original order.
    cases = ((2, False), (4, False), (5, True))
    for node_idx, expected in cases:
        verify_has_multiple_inputs(node_list, node_idx, input_names, expected)