def test_conv(mode):
    """Trace ``ConvOpr(mode)`` on a random float32 input and check conversion.

    The flattened traced graph is printed for debugging. The input, the
    traced module and the result returned by ``get_traced_module`` are then
    passed to ``_test_convert_result``. Presumably that helper compares the
    converted model's output against the traced result; confirm against the
    helper's definition.
    """
    model = ConvOpr(mode)
    sample = mge.tensor(
        np.random.random((1, 3, 224, 224)).astype(np.float32)
    )
    traced, expected = get_traced_module(model, sample)
    print(traced.flatten().graph)
    _test_convert_result(sample, traced, expected)
def test_deconv_qint8():
    """Check conversion of a QAT-quantized transposed-conv net on qint8 input.

    A float input is annotated with qint8 quantization parameters
    (``scale = 16 / 128``). The QAT network is traced on it, and the traced
    result is checked with ``require_quantize=True``. The scale read from the
    traced graph's first output is used as the output scale.
    """
    qat_model = quantize_qat(ConvOpr("tflite_transpose"))
    q_dtype = dtype.qint8(16.0 / 128)

    # Random values in [0, 16), rounded onto the qint8 grid and then
    # de-quantized back to float so the traced input carries quantized values.
    quantized = (mge.tensor(np.random.random((1, 3, 64, 64))) * 16).astype(q_dtype)
    float_inp = mge.tensor(dtype.convert_from_qint8(quantized.numpy()))
    float_inp.qparams.scale = mge.tensor(dtype.get_scale(q_dtype))
    float_inp.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]

    traced, expected = get_traced_module(qat_model, float_inp)
    print(traced.flatten().graph)

    quant_inp = float_inp.astype(q_dtype)
    out_scale = traced.graph.outputs[0].qparams.scale.numpy()
    _test_convert_result(
        quant_inp,
        traced,
        expected,
        scale=out_scale,
        require_quantize=True,
        max_err=max_error,
    )
def test_conv2d(mode):
    """Dump ``ConvOpr(mode)`` to ``tmp_file`` and check the converted result.

    The dumped result returned by ``dump_mge_model`` is compared through
    ``_test_convert_result`` using the tolerance ``max_error``.

    NOTE(review): another ``test_conv2d`` appears after this one in the
    reviewed source. If both are in the same module, the later definition
    shadows this one and it is never collected. Confirm these live in
    separate files.
    """
    model = ConvOpr(mode)
    reference = dump_mge_model(model, model.data, tmp_file)
    _test_convert_result(model.data, tmp_file, reference, max_error)
def test_conv2d(mode):
    """Trace ``ConvOpr(mode)`` on its own sample data and check conversion.

    The network's ``data`` attribute is wrapped in a tensor for tracing. The
    raw ``data`` is passed to ``_test_convert_result`` along with the traced
    module, its result and the tolerance ``max_error``.
    """
    model = ConvOpr(mode)
    traced, reference = get_traced_module(model, mge.tensor(model.data))
    _test_convert_result(model.data, traced, reference, max_error)
import megengine as mge
import numpy as np
from megengine.core.tensor import dtype
from megengine.quantization.quantize import quantize_qat
from megengine.traced_module import trace_module

from test.utils import ConvOpr, dump_mge_model


def _export_float(net):
    """Save ``net`` as a traced module file and as a dumped MegEngine model."""
    mge.save(trace_module(net, mge.tensor(net.data)), "float_model.tm")
    dump_mge_model(net, net.data, "float_model")


def _make_qint8_input():
    """Build a float tensor of qint8-representable values tagged with qparams.

    Random values in [0, 16) of shape (1, 3, 224, 224) are rounded onto the
    qint8 grid with scale 16 / 128, de-quantized back to float, and annotated
    with that scale and the builtin "qint8" dtype metadata.
    """
    q_dtype = dtype.qint8(16.0 / 128)
    quantized = (mge.tensor(np.random.random((1, 3, 224, 224))) * 16).astype(q_dtype)
    sample = mge.tensor(dtype.convert_from_qint8(quantized.numpy()))
    sample.qparams.scale = mge.tensor(dtype.get_scale(q_dtype))
    sample.qparams.dtype_meta = dtype._builtin_quant_dtypes["qint8"]
    return sample


if __name__ == "__main__":
    model = ConvOpr("normal")
    _export_float(model)

    # The float model is exported before QAT conversion. Whether
    # quantize_qat modifies ``model`` in place depends on its defaults;
    # TODO confirm.
    qat_model = quantize_qat(model)
    mge.save(trace_module(qat_model, _make_qint8_input()), "qat_model.tm")