Exemple #1
0
def test_preprocess():
    """Round-trip ``Main`` through pickle and check ``Net`` across backends.

    ``Main`` is traced with a ``{"data": ...}`` dict input, pickled and
    reloaded, then wrapped by ``Net``. The eager output of ``Net`` on
    ``(data, quad)`` is the reference for the re-traced module, the
    ``jit.trace`` function with captured constants, and the dumped graph run
    through ``cgtools.GraphInference``.
    """
    batch_size = 2

    def _assert_items_equal(reference, candidate):
        # Compare the two results item by item, as yielded by iterating them.
        for ref_item, cand_item in zip(reference, candidate):
            np.testing.assert_array_equal(ref_item, cand_item)

    main = Main()
    pixels = np.random.randint(0, 256, size=(batch_size, 3, 48, 160))
    data = mge.tensor(pixels, dtype=np.float32)
    reloaded = pickle.loads(pickle.dumps(trace_module(main, {"data": data})))

    net = Net(reloaded)
    net.eval()
    quad = mge.tensor(np.random.normal(size=(batch_size, 4, 2)),
                      dtype=np.float32)
    expect = net(data, quad)

    traced_net = trace_module(net, data, quad)
    _assert_items_equal(expect, traced_net(data, quad))

    func = trace(traced_net, capture_as_const=True)
    _assert_items_equal(expect, func(data, quad))

    model = io.BytesIO()
    func.dump(model, arg_names=("data", "quad"))
    model.seek(0)
    feed = {"data": data.numpy(), "quad": quad.numpy()}
    outputs = cgtools.GraphInference(model).run(inp_dict=feed)
    np.testing.assert_allclose(expect, next(iter(outputs.values())))
Exemple #2
0
def test_preprocess():
    """Check a pickled TracedModule wrapped by ``Net`` across backends.

    ``Main`` is traced on a ``uint8`` input, pickled and reloaded, then wrapped
    by ``Net`` with extra ``idx``/``roi`` inputs. The wrapper's eager output is
    the reference for the re-traced module, the ``jit.trace``-captured
    function and the dumped graph run through ``cgtools.GraphInference``.
    """
    main = Main()
    data = F.ones((1, 14, 8, 8), dtype=np.uint8)
    reloaded = pickle.loads(pickle.dumps(trace_module(main, data)))

    net = Net(reloaded)
    net.eval()
    idx = F.zeros((1, ), dtype=np.int32)
    roi = F.ones((1, 2, 2), dtype=np.float32)
    inputs = (data, idx, roi)
    expected = net(*inputs)

    traced_net = trace_module(net, *inputs)
    np.testing.assert_array_equal(traced_net(*inputs), expected)

    func = trace(traced_net, capture_as_const=True)
    np.testing.assert_array_equal(func(*inputs), expected)

    buf = io.BytesIO()
    arg_names = ("data", "idx", "roi")
    func.dump(buf, arg_names=arg_names)
    buf.seek(0)
    # Feed each dumped graph input by the name it was dumped under.
    feed = {name: tensor.numpy() for name, tensor in zip(arg_names, inputs)}
    results = cgtools.GraphInference(buf).run(inp_dict=feed)
    np.testing.assert_allclose(
        next(iter(results.values())),
        expected,
        atol=1e-6,
    )
Exemple #3
0
def test_float_func_conv():
    """Convert a module that calls ``F.conv2d`` directly and compare outputs."""

    class FConvOpr(M.Module):
        def __init__(self):
            super().__init__()
            # Not used by ``forward`` (which calls ``F.conv2d`` directly);
            # retained so the module's attributes are unchanged.
            self.conv = F.conv2d

        def forward(
                self,
                inp,
                weight,
                bias=None,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                groups=1,
        ):
            return F.conv2d(inp, weight, bias, stride, padding, dilation,
                            groups)

    net = FConvOpr()
    inp = mge.tensor(np.random.random((1, 16, 32, 32))).astype("float32")
    kernel = mge.tensor(np.random.random((32, 16, 2, 2))).astype("float32")
    traced = trace_module(net, inp, kernel)
    reference = traced(inp, kernel)
    _test_convert_result([inp, kernel], traced, reference, max_err=1e-4)


def test_dump_and_load():
    """Pickle a TracedModule; the copy must keep node/expr names and ids."""
    module = MyModule()
    x = Tensor(np.ones((1, 8, 14, 14)))
    expect = module(x)
    traced_module = trace_module(module, x)
    np.testing.assert_array_equal(expect, traced_module(x))

    new_tm = pickle.loads(pickle.dumps(traced_module))
    _check_id(new_tm)
    _check_expr_users(new_tm)

    # Reset ids on the original graph before comparing against the copy.
    traced_module.graph._reset_ids()
    old_graph = traced_module.graph
    new_graph = new_tm.graph
    old_nodes = old_graph.nodes().as_list()
    new_nodes = new_graph.nodes().as_list()
    old_exprs = old_graph.exprs().as_list()
    new_exprs = new_graph.exprs().as_list()

    assert len(old_nodes) == len(new_nodes)
    for old_node, new_node in zip(old_nodes, new_nodes):
        assert old_node._name == new_node._name
        assert old_node._qualname == new_node._qualname
        assert old_node._id == new_node._id

    assert len(old_exprs) == len(new_exprs)
    for old_expr, new_expr in zip(old_exprs, new_exprs):
        assert old_expr._id == new_expr._id

    # The original module must still produce the same output.
    np.testing.assert_array_equal(expect, traced_module(x))
Exemple #5
0
    def _check_module(build_func: Callable):
        """Load a reference net's state dict into freshly traced copies.

        ``build_func`` builds the reference network and, again, each network
        that gets traced. The reference's state dict is saved once and loaded
        into a plain and then a flattened TracedModule; each is compared with
        the reference via ``_check_param``.
        """
        net = build_func()
        net.eval()
        state_buf = io.BytesIO()
        mge.save(net.state_dict(), state_buf)

        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
        for flatten in (False, True):
            # Rewind so every copy reads the same serialized state dict.
            state_buf.seek(0)
            traced_net = trace_module(build_func(), inp)
            if flatten:
                traced_net = traced_net.flatten()
            traced_net.load_state_dict(mge.load(state_buf))
            _check_param(net, traced_net)
Exemple #6
0
def test_id_and_name():
    """Node/expr ids stay unique and dense; flattened node names stay unique.

    Checked on the freshly traced module; after pickling, moving the global
    id counters and inserting ``F.neg`` after every ``F.relu``; and after
    re-tracing the rewritten module wrapped in ``NewModule``.
    """

    def _check_id(traced_module):
        # Ids must be unique, and the graph's id counters must be exactly
        # one past the largest node / expr id in use.
        total_ids = traced_module.graph._total_ids
        graph = traced_module.graph

        node_ids = [node._id for node in graph.nodes().as_list()]
        assert len(set(node_ids)) == len(node_ids)
        assert max(node_ids) + 1 == total_ids[0]

        expr_ids = [expr._id for expr in graph.exprs().as_list()]
        assert len(set(expr_ids)) == len(expr_ids)
        assert max(expr_ids) + 1 == total_ids[1]

    def _check_name(flattened_module):
        names = [
            node._name for node in flattened_module.graph.nodes().as_list()
        ]
        assert len(set(names)) == len(names)

    def _check_all(module):
        # Ids of the module and of its flattened form, plus flattened names.
        _check_id(module)
        flattened = module.flatten()
        _check_id(flattened)
        _check_name(flattened)

    traced_module, x, _ = _init_module()
    _check_all(traced_module)

    # pickle check
    traced_module = pickle.loads(pickle.dumps(traced_module))
    Node._set_next_id(159)
    Expr._set_next_id(1024)

    relu_exprs = traced_module.graph.get_function_by_type(F.relu).as_list()
    for relu_expr in relu_exprs:
        relu_out = relu_expr.outputs[0]
        owner_graph = relu_expr.top_graph
        with owner_graph.insert_exprs():
            neg_out = F.neg(relu_out)
        owner_graph.replace_node({relu_out: neg_out})
        owner_graph.compile()
    _check_all(traced_module)

    # check trace TracedModule
    traced_module = pickle.loads(pickle.dumps(traced_module))
    traced_module = trace_module(NewModule(traced_module), x)
    _check_all(traced_module)
def test_tensor_method_loader():
    """A registered ``__add__`` loader can rewrite a pickled ``x + 1`` graph.

    The loader replaces the non-tensor operand with a ``Constant`` that is
    cast via ``astype`` and added to itself, so the reloaded module computes
    ``x + 2`` with four exprs, the first being the ``Constant``.

    ``S.TENSORMETHOD_LOADER`` is global state. It is swapped for an empty
    table during the test and restored in ``finally``, so a failing
    assertion cannot leak this test's loader into other tests.
    """

    class MyModule3(Module):
        def forward(self, x):
            return x + 1

    m = MyModule3()
    x = Tensor(np.ones((20)))
    traced_module = trace_module(m, x)
    orig_loader_dict = S.TENSORMETHOD_LOADER
    S.TENSORMETHOD_LOADER = {}
    try:
        @register_tensor_method_loader("__add__")
        def add_loader(expr):
            args = list(expr.args)
            if not isinstance(args[1], TensorNode):
                args[1] = Tensor(args[1])
                node = Constant(args[1], "const").outputs[0]

                # const.astype(<dtype of the first input>)
                astype_expr = CallMethod(node, "astype")
                oup = TensorNode(
                    astype_expr,
                    shape=node.shape,
                    dtype=node.dtype,
                    qparams=node.qparams,
                )
                astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
                astype_expr.return_val = (oup, )

                # oup + oup replaces the original scalar operand.
                add_expr = CallMethod(oup, "__add__")
                add_expr.set_args_kwargs(oup, oup)
                oup1 = TensorNode(
                    add_expr,
                    shape=oup.shape,
                    dtype=oup.dtype,
                    qparams=node.qparams,
                )
                add_expr.return_val = oup1
                args[1] = oup1
                expr.set_args_kwargs(*args)

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        _check_expr_users(new_module)
        _check_id(new_module)
        result = new_module(x)
        assert (isinstance(new_module.graph._exprs[0], Constant)
                and len(new_module.graph._exprs) == 4)
        np.testing.assert_equal(result.numpy(), (x + 2).numpy())
    finally:
        S.TENSORMETHOD_LOADER = orig_loader_dict
Exemple #8
0
def test_backward_fold_scale(conv_cls, bn_cls):
    """``BackwardFoldScale`` preserves outputs and removes every ``__mul__``."""
    module = MyModule(conv_cls, bn_cls)
    module.eval()
    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    desired = module(inp)

    flat_net = tm.trace_module(module, inp).flatten()
    optimized_net = tm.optimize(flat_net, "BackwardFoldScale")
    np.testing.assert_allclose(desired=desired,
                               actual=optimized_net(inp),
                               atol=1e-4)

    # Every multiplication should have been folded into a conv.
    remaining_muls = optimized_net.graph.get_method_by_type(
        "__mul__").as_list()
    assert len(remaining_muls) == 0
Exemple #9
0
def test_jit_trace():
    """A TracedModule can be captured by ``jit.trace``, dumped and re-run."""
    module = MyModule()
    module.eval()
    x = F.ones((1, 8, 14, 14))
    expect = module(x)

    func = trace(trace_module(module, x), capture_as_const=True)
    np.testing.assert_array_equal(func(x), expect)

    buf = io.BytesIO()
    func.dump(buf)
    buf.seek(0)
    results = cgtools.GraphInference(buf).run(x.numpy())
    np.testing.assert_allclose(next(iter(results.values())),
                               expect,
                               atol=1e-6)
Exemple #10
0
 def _check_qualname(net):
     """Each traced node's qualname resolves on both the traced and eager net."""
     inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
     net.eval()
     traced_net = trace_module(net, inp)
     # Length of the graph's own qualname plus the separating character.
     prefix_len = len(traced_net.graph.qualname) + 1
     for node in traced_net.graph.nodes():
         attr_path = node.qualname[prefix_len:]
         if attr_path.endswith("]"):
             # Drop the last dotted component when it ends with "]".
             attr_path = attr_path.rsplit(".", 1)[0]
         if attr_path.startswith("["):
             attr_path = ""
         traced_attr = get_subattr(traced_net, attr_path)
         orig_attr = get_subattr(net, attr_path)
         assert traced_attr is not None
         assert orig_attr is not None
Exemple #11
0
def test_training_converge(test_traced_module):
    """Train ``XORNet`` (optionally traced first) and check it converges.

    Requires the mean loss over the last 100 of 2000 steps to be below 0.1,
    and ``calculate_precision`` on a 10x10 grid over [-1, 1]^2 to be 1.0.
    """
    net = XORNet()
    if test_traced_module:
        net = trace_module(net, Tensor(np.random.random((14, 2))))
    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    gm = ad.GradManager().attach(net.parameters())

    @trace(symbolic=False)
    def train(data, label):
        with gm:
            loss = F.nn.cross_entropy(net(data), label)
            gm.backward(loss)
            optim.clip_grad_norm(net.parameters(), max_norm=0.2, ord=2.0)
        return loss

    losses = []
    for batch, target in itertools.islice(minibatch_generator(), 2000):
        data = Tensor(batch, dtype=np.float32)
        label = Tensor(target, dtype=np.int32)
        opt.clear_grad()
        loss = train(data, label)
        optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)
        opt.step()
        losses.append(loss.numpy())

    final_loss = np.mean(losses[-100:])
    assert final_loss < 0.1, (
        "Final training Loss must be low enough, get {}".format(final_loss))

    ngrid = 10
    axis_points = np.linspace(-1.0, 1.0, ngrid)
    xx, yy = (mesh.reshape((ngrid * ngrid, 1))
              for mesh in np.meshgrid(axis_points, axis_points))
    grid = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
    pred = net(grid)
    precision = calculate_precision(grid.numpy(), pred.numpy())
    assert precision == 1.0, (
        "Test precision must be high enough, get {}".format(precision))
Exemple #12
0
def test_fuse_bn(conv_cls, bn_cls):
    """``FuseConvBn`` preserves outputs and removes every batch norm.

    After optimization the graph may contain neither functional
    ``F.batch_norm`` calls nor ``M.BatchNorm2d`` module calls.
    """
    module = MyModule(conv_cls, bn_cls)
    module.eval()
    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    desired = module(inp)
    traced_net = tm.trace_module(module, inp)

    traced_net = traced_net.flatten()
    optimized_net = tm.optimize(traced_net, "FuseConvBn")

    actual = optimized_net(inp)
    np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)

    # Every batch norm must have been fused into its conv.
    graph = optimized_net.graph
    bn_func_calls = graph.get_function_by_type(F.batch_norm).as_list()
    assert len(bn_func_calls) == 0

    bn_module_calls = graph.get_module_by_type(M.BatchNorm2d).as_list()
    assert len(bn_module_calls) == 0
Exemple #13
0
def test_trace_module_2():
    """Tracing ``x.shape`` + raw ``apply`` records GetVarShape, Constant, Elemwise."""

    class Model(M.Module):
        def forward(self, x):
            shape = x.shape
            return apply(builtin.Elemwise(mode="ADD"), shape, Tensor(1))

    traced_model = trace_module(Model(), Tensor([1]))

    exprs = traced_model.graph._exprs
    assert isinstance(exprs[0], Apply)
    assert isinstance(exprs[0].opdef, builtin.GetVarShape)
    assert isinstance(exprs[1], Constant)
    assert isinstance(exprs[2], Apply)
    assert isinstance(exprs[2].opdef, builtin.Elemwise)
    assert int(traced_model(Tensor([1, 2]))[0]) == 3
Exemple #14
0
def test_functional_loader():
    """A registered ``conv2d`` functional loader can rewrite a pickled graph.

    The loader inserts ``weight.astype(inp.dtype)`` in front of the conv. The
    reloaded graph must then hold two exprs, with the ``astype`` CallMethod
    first, and still match the eager module's output.

    ``S.FUNCTIONAL_LOADER`` is global state. It is swapped for an empty table
    during the test and restored in ``finally``, so a failing assertion
    cannot leak this test's loader into other tests.
    """

    class MyModule2(Module):
        def forward(self, x, y):
            return F.conv2d(x, y)

    m = MyModule2()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    y = Tensor(np.random.random((3, 3, 3, 3)))
    traced_module = trace_module(m, x, y)
    orig_loader_dict = S.FUNCTIONAL_LOADER
    S.FUNCTIONAL_LOADER = {}
    try:
        @register_functional_loader(("megengine.functional.nn", "conv2d"))
        def conv2df_loader(expr):
            # expr.func = ("megengine.functional.nn","conv2d")
            orig_weight = expr.named_args["weight"]

            astype_expr = CallMethod(orig_weight, "astype")
            oup = TensorNode(
                astype_expr,
                shape=orig_weight.shape,
                dtype=orig_weight.dtype,
                qparams=orig_weight.qparams,
            )

            astype_expr.set_args_kwargs(orig_weight,
                                        expr.named_args["inp"].dtype)
            astype_expr.return_val = (oup, )

            expr.set_arg("weight", oup)

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        _check_expr_users(new_module)
        _check_id(new_module)
        result = new_module(x, y)
        gt = m(x, y)
        assert (isinstance(new_module.graph._exprs[0], CallMethod)
                and len(new_module.graph._exprs) == 2)
        np.testing.assert_equal(result.numpy(), gt.numpy())
    finally:
        S.FUNCTIONAL_LOADER = orig_loader_dict
Exemple #15
0
def test_extra_block():
    """A traced block followed by eager post-processing can be re-traced."""

    class PostProcess(M.Module):
        def forward(self, x):
            return x * 2

    class Net(M.Module):
        def __init__(self, traced_module):
            super().__init__()
            self.post_process = PostProcess()
            self.traced_module = traced_module

        def forward(self, x):
            return self.post_process(self.traced_module(x))

    traced_block, x, expect = _init_block()
    net = Net(traced_block)
    np.testing.assert_allclose(2 * expect, net(x), atol=1e-6)
    retraced = trace_module(net, x)
    np.testing.assert_allclose(2 * expect, retraced(x), atol=1e-6)
Exemple #16
0
def test_opdef_loader():
    """A registered ``Elemwise`` opdef loader can rewrite a pickled graph.

    The loader turns each Elemwise ``ADD`` into ``MUL`` and casts the second
    input to the first input's dtype through an inserted ``astype``
    CallMethod.

    ``S.OPDEF_LOADER`` is global state. It is swapped for an empty table
    during the test and restored in ``finally``, so a failing assertion
    cannot leak this test's loader into other tests.
    """

    class MyModule1(Module):
        def forward(self, x, y):
            op = Elemwise("ADD")
            return apply(op, x, y)[0]

    m = MyModule1()
    x = Tensor(np.ones((20)))
    y = Tensor(np.ones((20)))
    traced_module = trace_module(m, x, y)
    orig_loader_dict = S.OPDEF_LOADER
    S.OPDEF_LOADER = {}
    try:
        @register_opdef_loader(Elemwise)
        def add_opdef_loader(expr):
            if expr.opdef_state["mode"] == "ADD":
                expr.opdef_state["mode"] = "MUL"
                node = expr.inputs[1]
                astype_expr = CallMethod(node, "astype")
                oup = TensorNode(
                    astype_expr,
                    shape=node.shape,
                    dtype=expr.inputs[0].dtype,
                    qparams=node.qparams,
                )
                astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
                astype_expr.return_val = (oup, )
                expr.inputs[1] = oup

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        _check_id(new_module)
        _check_expr_users(new_module)
        _check_name(new_module.flatten())
        assert (isinstance(new_module.graph._exprs[0], CallMethod)
                and new_module.graph._exprs[1].opdef.mode == "MUL"
                and len(new_module.graph._exprs) == 2)
        result = new_module(x, y)
        # Inputs are all ones, so the MUL result equals x.
        np.testing.assert_equal(result.numpy(), x.numpy())
    finally:
        S.OPDEF_LOADER = orig_loader_dict
Exemple #17
0
def test_module_loader():
    """A registered ``Conv2d`` module loader can rewrite a pickled graph.

    The loader casts the conv input to the conv weight's dtype through an
    inserted ``astype`` CallMethod. The reloaded graph must then hold three
    exprs, with the CallMethod at index 1, and still match the eager output.

    ``S.MODULE_LOADER`` is global state. It is swapped for an empty table
    during the test and restored in ``finally``, so a failing assertion
    cannot leak this test's loader into other tests.
    """

    class MyModule4(Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    m = MyModule4()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    traced_module = trace_module(m, x)
    orig_loader_dict = S.MODULE_LOADER
    S.MODULE_LOADER = {}
    try:
        @register_module_loader(("megengine.module.conv", "Conv2d"))
        def conv2dm_loader(expr):
            module = expr.inputs[0].owner
            args = list(expr.args)
            orig_inp = args[1]
            astype_expr = CallMethod(orig_inp, "astype")
            oup = TensorNode(
                astype_expr,
                shape=orig_inp.shape,
                dtype=orig_inp.dtype,
                qparams=orig_inp.qparams,
            )
            astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
            astype_expr.return_val = (oup, )
            args[1] = oup
            expr.set_args_kwargs(*args)

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        result = new_module(x)
        gt = m(x)
        assert (isinstance(new_module.graph._exprs[1], CallMethod)
                and len(new_module.graph._exprs) == 3)
        np.testing.assert_equal(result.numpy(), gt.numpy())
    finally:
        S.MODULE_LOADER = orig_loader_dict
Exemple #18
0
    def _construct_tm(self):
        """Build (once) and return a TracedModule for ``self.ir_graph``.

        For every operator, collects its class, its extracted params (via
        ``PARAMEXTRACT``, or ``{}`` when no extractor is registered), its input
        tensors' names/dtypes/shapes/numpy data and its output names. These are
        wrapped in an ``ONNXModule``, which is traced on ``self.inp_data``. The
        result is cached in ``self.tm``.

        Raises:
            AssertionError: if an operator has more than one output tensor.
        """
        if self.tm is not None:
            return self.tm

        all_opr = []
        params = []
        inputs, dtypes, shapes, datas = [], [], [], []
        outputs = []
        for opr in self.ir_graph.all_oprs:
            opr_type = type(opr)
            extractor = PARAMEXTRACT.get(opr_type, None)
            params.append({} if extractor is None else extractor(opr).extract())
            all_opr.append(opr_type)

            in_tensors = opr.inp_tensors
            inputs.append([t.name for t in in_tensors])
            dtypes.append([t.dtype for t in in_tensors])
            shapes.append([t.shape for t in in_tensors])
            datas.append([t.np_data for t in in_tensors])

            assert len(opr.out_tensors) == 1, (
                "MegEngine Cannot supports multiple outputs of one Opr")
            outputs.append([t.name for t in opr.out_tensors])

        onnx_module = ONNXModule(
            all_opr,
            inputs,
            dtypes,
            shapes,
            datas,
            outputs,
            params,
            self.quantizer,
            self.map_ir_tensor_2_mge_tensor,
            self.graph_outputs,
        )
        self.tm = tm.trace_module(onnx_module, *self.inp_data)
        return self.tm
Exemple #19
0
    def _check_qat_module(qat_net: QATModule):
        """Tracing a QAT net keeps each QAT submodule's qparams and node names.

        For every ``QATModule`` inside ``qat_net``, the weight/activation
        qparams of the traced counterpart are compared with the original's
        (each only when the original has them). After flattening,
        ``MyModule_block0_conv0`` must be owned by a TracedModule and be the
        first input of the expr producing ``MyModule_block0_conv0_out``.
        """
        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
        traced_net = trace_module(qat_net, inp)

        for name, qat_module in qat_net.named_modules():
            if isinstance(qat_module, QATModule):
                traced_qat_module = get_subattr(traced_net, name)
                weight_qparams, act_qparams = get_qparams(qat_module)
                traced_weight_qparams, traced_act_qparams = get_qparams(
                    traced_qat_module)
                if weight_qparams:
                    check_qparams(weight_qparams, traced_weight_qparams)
                if act_qparams:
                    check_qparams(act_qparams, traced_act_qparams)

        # Keep the flattened module bound while its nodes are inspected.
        flatten_traced_net = traced_net.flatten()
        flat_graph = flatten_traced_net.graph
        conv0_node = flat_graph.get_node_by_name(
            "MyModule_block0_conv0").as_unique()
        conv0_out_node = flat_graph.get_node_by_name(
            "MyModule_block0_conv0_out").as_unique()
        assert isinstance(conv0_node.owner, TracedModule)
        assert conv0_out_node.expr.inputs[0] is conv0_node
Exemple #20
0
def test_shared_module():
    """A submodule reachable via two attributes stays shared after pickling."""

    class MyModule(M.Module):
        def __init__(self):
            super().__init__()
            self.a = M.Elemwise("ADD")
            self.b = self.a

        def forward(self, x, y):
            return self.b(self.a(x, y), y)

    x, y = Tensor(1), Tensor(2)
    traced = trace_module(MyModule(), x, y)
    reloaded = pickle.loads(pickle.dumps(traced))
    _check_expr_users(reloaded)
    _check_name(reloaded.flatten())
    _check_id(reloaded)
    assert reloaded.a is reloaded.b
Exemple #21
0
def get_traced_module(net, *x):
    """Trace ``net`` on ``x`` and run the traced module once on the same inputs.

    Returns:
        ``(traced_module, expect)`` where ``expect`` is the traced module's
        output for ``x``.
    """
    traced = trace_module(net, *x)
    return traced, traced(*x)
Exemple #22
0
def test_trace_module():
    """Exercise ``trace_module`` on several modules and calling conventions.

    Traced ``MyModule1``..``MyModule4`` must reproduce the eager outputs.
    ``MyModule4`` is traced with positional, mixed and keyword arguments, and
    is re-traced from its traced form, each traced result being called all
    three ways. Finally, a traced module embedded in another module loses
    its top-level status when the outer module is traced.
    """
    enable_expr_checker()

    def _assert_outputs_equal(outputs, expected):
        for out, ref in zip(outputs, expected):
            np.testing.assert_equal(out.numpy(), ref.numpy())

    def _assert_call_forms(module, first, second):
        # Positional, mixed and keyword-only calls must all return 3.
        np.testing.assert_equal(module(first, second).numpy(), 3)
        np.testing.assert_equal(module(first, y=second).numpy(), 3)
        np.testing.assert_equal(module(x=first, y=second).numpy(), 3)

    x = Tensor(1)
    m1 = MyModule1()
    tm1 = trace_module(m1, x)

    m2 = MyModule2()
    tm2 = trace_module(m2, x)
    inp = Tensor(2)
    expected = m1(inp)
    _assert_outputs_equal(tm1(inp), expected)

    expected = m2(inp)
    _assert_outputs_equal(tm2(inp), expected)

    a, b = Tensor(1), Tensor(2)
    m3 = MyModule3()
    expected = m3(a, b)
    tm3 = trace_module(m3, a, b)
    np.testing.assert_equal(tm3(a, b).numpy(), expected.numpy())
    submodules = tm3.modules.__dict__
    assert isinstance(submodules["0"], M.Elemwise)
    assert isinstance(submodules["2"], TracedModule)
    assert isinstance(submodules["2"].a, M.Elemwise)
    assert isinstance(submodules["3"], M.Elemwise)

    # (positional args, keyword args) used when tracing.
    trace_arg_forms = (
        ((a, b), {}),
        ((a, ), {"y": b}),
        ((), {"x": a, "y": b}),
    )

    m4 = MyModule4()
    for trace_args, trace_kwargs in trace_arg_forms:
        tm4 = trace_module(m4, *trace_args, **trace_kwargs)
        _assert_call_forms(tm4, a, b)

    # Re-trace the last traced form (traced with keyword-only arguments).
    for trace_args, trace_kwargs in trace_arg_forms:
        tm5 = trace_module(tm4, *trace_args, **trace_kwargs)
        _assert_call_forms(tm5, a, b)

    assert len(tm4.graph._exprs) == 1
    assert isinstance(tm4.graph._exprs[0], CallFunction)

    class MyModule5(Module):
        def __init__(self):
            super().__init__()
            self.m1 = tm4

        def forward(self, x, y):
            return self.m1(x, y)

    tm6 = trace_module(MyModule5(), a, b)
    assert tm6.m1.argspec is None
    assert tm6.m1._is_top is False
Exemple #23
0
def _init_cls(cls):
    """Instantiate ``cls``, run it eagerly and trace it on an all-ones input.

    Returns:
        ``(traced_module, x, y)``: the traced module, the ``(1, 3, 3, 3)``
        ones input, and the eager module's output on it.
    """
    module = cls()
    x = F.ones((1, 3, 3, 3))
    y = module(x)
    return trace_module(module, x), x, y
Exemple #24
0
from test.utils import ConvOpr, dump_mge_model

import megengine as mge
import numpy as np
from megengine.core.tensor import dtype
from megengine.quantization.quantize import quantize_qat
from megengine.traced_module import trace_module

def _main():
    """Export float and QAT TracedModules of ``ConvOpr("normal")``.

    Saves ``float_model.tm`` and calls ``dump_mge_model`` with the name
    ``"float_model"``. It then converts the net with ``quantize_qat`` and saves
    ``qat_model.tm``, traced on a random input whose qparams are set to qint8
    with scale 16/128.
    """
    net = ConvOpr("normal")
    mge.save(trace_module(net, mge.tensor(net.data)), "float_model.tm")
    dump_mge_model(net, net.data, "float_model")

    qat_net = quantize_qat(net)

    inp_dtype = dtype.qint8(16.0 / 128)
    scaled = mge.tensor(np.random.random((1, 3, 224, 224))) * 16
    quantized = scaled.astype(inp_dtype)
    inp = mge.tensor(dtype.convert_from_qint8(quantized.numpy()))
    # Attach the qint8 quantization parameters to the float input tensor.
    inp.qparams.scale = mge.tensor(dtype.get_scale(inp_dtype))
    inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]

    mge.save(trace_module(qat_net, inp), "qat_model.tm")


if __name__ == "__main__":
    _main()