예제 #1
0
def test_expr2graph():
    """Check that expr2graph emits one graph node per visited Relay node.

    The reference list is built by walking ``mod["main"]`` in post order:
    ``Call`` nodes contribute their operator, while ``Var``,
    ``TupleGetItem`` and ``Tuple`` nodes contribute ``None``. The node
    entries produced by ``expr2graph`` are then compared position by
    position against that list, so both the count and the ordering must
    agree.
    """
    mod, _ = synthetic.get_workload()
    graph_nodes = {}
    graph_list = []
    conv_ops = [relay.op.get("nn.conv2d")]
    expected_ops = []

    def _record(expr):
        # Call nodes carry an operator; the other listed node kinds are
        # expected to appear in the graph with an ``op`` of None.
        if isinstance(expr, Call):
            expected_ops.append(expr.op)
            return
        if isinstance(expr, (Var, TupleGetItem, Tuple)):
            expected_ops.append(None)

    relay.analysis.post_order_visit(mod["main"], _record)

    expr2graph(mod["main"], conv_ops, graph_nodes, graph_list)
    assert len(graph_list) == len(expected_ops)
    for idx, (want, got) in enumerate(zip(expected_ops, graph_list)):
        actual = got["op"]
        assert want == actual, (
            "{}th Node operator mismatch: expecting {} but got {}".format(
                idx, str(want), str(actual)
            )
        )
예제 #2
0
def test_extract_resnet():
    """Check how many fused functions are extracted from the test workload.

    ``relay.analysis.extract_fused_functions`` is applied to the module
    returned by ``get_workload()``; the result size is pinned at exactly 6.
    """
    mod, _ = get_workload()
    fused_functions = relay.analysis.extract_fused_functions(mod)
    expected_count = 6
    assert len(fused_functions) == expected_count
예제 #3
0
def test_change_batch_synthetic():
    """Check that ChangeBatch propagates a new batch size to main's output.

    The first parameter of ``net["main"]`` is mapped to axis 0 in the
    ``ChangeBatch`` pass (presumably its batch axis). The type-checked
    return type of the rewritten ``main`` must then carry the new batch
    size as its leading dimension.
    """
    batch_size = 123
    net, _ = synthetic.get_workload()
    first_input = net["main"].params[0]
    change_batch = transform.ChangeBatch({first_input: 0},
                                         batch_size=batch_size)
    new_net = change_batch(net)
    ret_shape = new_net["main"].checked_type.ret_type.shape
    assert ret_shape[0] == batch_size