def test_expr2graph():
    """Check that ``expr2graph`` emits one node per recorded expression, in post-order.

    The reference sequence is built by walking ``mod["main"]`` with
    ``relay.analysis.post_order_visit``. A ``Call`` contributes its ``op``.
    A ``Var``, ``TupleGetItem`` or ``Tuple`` contributes ``None``. Any other
    expression kind is not recorded. The graph that ``expr2graph`` builds
    (with ``nn.conv2d`` as the only target op) must then have the same
    length, and each node's ``"op"`` entry must match the reference at the
    same position.
    """
    mod, _ = synthetic.get_workload()
    graph_nodes_by_expr = {}
    graph_nodes = []
    conv2d_only = [relay.op.get("nn.conv2d")]
    expected_ops = []

    def _record(expr):
        # Build the expected per-node "op" sequence in post-order.
        if isinstance(expr, Call):
            expected_ops.append(expr.op)
        elif isinstance(expr, (Var, TupleGetItem, Tuple)):
            expected_ops.append(None)

    relay.analysis.post_order_visit(mod["main"], _record)
    expr2graph(mod["main"], conv2d_only, graph_nodes_by_expr, graph_nodes)

    assert len(graph_nodes) == len(expected_ops)
    for idx, (want, got) in enumerate(zip(expected_ops, graph_nodes)):
        assert want == got["op"], "%dth Node operator mismatch: expecting %s but got %s" % (
            idx,
            str(want),
            str(got["op"]),
        )
def test_extract_resnet():
    """Check that ``extract_fused_functions`` finds exactly six fused functions.

    NOTE(review): ``get_workload`` is imported elsewhere in the file. Given the
    test name it is presumably the ResNet workload; confirm against the imports.
    """
    mod, _ = get_workload()
    fused_functions = relay.analysis.extract_fused_functions(mod)
    assert len(fused_functions) == 6
def test_change_batch_synthetic():
    """Check that ``ChangeBatch`` sets the output's leading dimension to 123.

    The pass is given the first parameter of ``main`` mapped to ``0``
    (presumably the batch axis; confirm against ``ChangeBatch``'s docs) and
    ``batch_size=123``. The test asserts that dimension 0 of ``main``'s
    checked return type is 123.
    """
    net, _ = synthetic.get_workload()
    data_var = net["main"].params[0]
    change_batch = transform.ChangeBatch({data_var: 0}, batch_size=123)
    rewritten = change_batch(net)
    out_shape = rewritten["main"].checked_type.ret_type.shape
    assert out_shape[0] == 123