Example #1
0
def test_sym_pass(iter_num=10):
    """Quantize the model with MRT and compare it against the float model.

    Loads the float symbol/params for ``version``, calibrates and quantizes
    them on the first batch from ``val_loader``, dumps the quantized model
    under the "sym.quantize" tag, reloads it, and runs
    ``utils.multi_eval_accuracy`` on the float graph vs. the quantized graph
    for ``iter_num`` iterations.

    NOTE(review): ``batch_size``, ``val_loader``, ``version``, ``ctx`` and
    ``load_fname`` are not defined in this function -- presumably
    module-level globals; confirm.
    """
    inputs_ext = {'data': {
        'shape': (batch_size, 1, 28, 28),
    }}
    inputs = [mx.sym.var(n) for n in inputs_ext]

    data_iter = iter(val_loader)

    def data_iter_func():
        return next(data_iter)

    # The first batch is used as the calibration sample.
    data, _ = data_iter_func()

    net1 = utils.load_model(*load_fname(version), inputs, ctx=ctx)

    def graph_func(data):
        return net1.forward(data.as_in_context(ctx))

    sym_file, param_file = load_fname(version)
    sym, params = mx.sym.load(sym_file), nd.load(param_file)
    sym, params = spass.sym_quant_prepare(sym, params, inputs_ext)
    if True:
        # Active path: MRT calibration + quantization with 8-bit outputs.
        mrt = _mrt.MRT(sym, params, inputs_ext)
        mrt.set_data('data', data)
        mrt.calibrate(ctx=ctx)
        mrt.set_output_prec(8)
        qsym, qparams, inputs_ext = mrt.quantize()
    else:
        # Disabled path: the older calibrate/simulate/realize pipeline.
        inputs_ext['data']['data'] = data
        th_dict = calib.sym_calibrate(sym, params, inputs_ext, ctx=ctx)
        qsym, qparams, precs, _ = calib.sym_simulate(sym, params, inputs_ext, th_dict)
        qsym, qparams = calib.sym_realize(qsym, qparams, inputs_ext, precs, "cvm")
    dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
    sim.save_ext(dump_ext, inputs_ext)
    nd.save(dump_params, qparams)
    # Context manager guarantees the JSON is flushed and the handle closed.
    with open(dump_sym, "w") as fout:
        fout.write(qsym.tojson())

    # Reload the dumped quantized model from disk.
    dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
    (inputs_ext,) = sim.load_ext(dump_ext)
    inputs = [mx.sym.var(n) for n in inputs_ext]
    net2 = utils.load_model(dump_sym, dump_params, inputs, ctx=ctx)

    def cvm_quantize(data):
        # Map float input into the quantized input domain before inference.
        data = sim.load_real_data(data, 'data', inputs_ext)
        return net2.forward(data.as_in_context(ctx))

    utils.multi_eval_accuracy(graph_func, data_iter_func,
            cvm_quantize,
            iter_num=iter_num)
Example #2
0
def test_sym_pass(batch_size=10, iter_num=10, quantize=True):
    """Quantize Inception-v3 and compare top-1/top-5 accuracy with the float model.

    Args:
        batch_size: images per ImageNet validation batch.
        iter_num: number of batches validated by ``utils.multi_validate``.
        quantize: when True, build and dump the quantized model first;
            when False, reuse the model previously dumped under the "mrt" tag.
    """
    logger = logging.getLogger("log.test.sym.pass")

    calib_ctx = mx.gpu(2)
    ctx = [mx.gpu(int(i)) for i in "1,2,3,4".split(',') if i.strip()]
    input_size = 299
    version = "v3"
    h, w = input_size, input_size
    inputs_ext = {
        'data': {
            'shape': (batch_size, 3, h, w),
        }
    }
    inputs = [mx.sym.var(name) for name in inputs_ext]

    logger.info("load dataset, symbol and parameters")
    data_iter = ds.load_imagenet_rec(batch_size, input_size)

    def data_iter_func():
        data = data_iter.next()
        return data.data[0], data.label[0]

    net1 = utils.load_model(*load_fname(version), inputs, ctx=ctx)
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    acc_top1.reset()
    acc_top5.reset()

    def inception_v3(data, label):
        # Split the batch across all devices, run the float model, and
        # accumulate running top-1/top-5 accuracy.
        data = gluon.utils.split_and_load(data,
                                          ctx_list=ctx,
                                          batch_axis=0,
                                          even_split=False)
        res = [net1.forward(d) for d in data]
        res = nd.concatenate(res)
        acc_top1.update(label, res)
        _, top1 = acc_top1.get()
        acc_top5.update(label, res)
        _, top5 = acc_top5.get()
        return "top1={:6.2%} top5={:6.2%}".format(top1, top5)

    # Tag under which the quantized model is dumped; the reload below uses
    # the same variable so dump and load can never disagree.
    quant_tag = "mrt"
    if quantize:
        sym_file, param_file = load_fname(version)
        sym, params = mx.sym.load(sym_file), nd.load(param_file)
        sym, params = spass.sym_quant_prepare(sym, params, inputs_ext)
        # One batch is consumed as the calibration sample.
        data, _ = data_iter_func()
        if True:
            mrt = _mrt.MRT(sym, params, inputs_ext)
            mrt.set_data('data', data)
            mrt.calibrate(ctx=calib_ctx)
            mrt.set_output_prec(8)
            qsym, qparams, inputs_ext = mrt.quantize()
        else:
            quant_tag = "sym.quantize"
            inputs_ext['data']['data'] = data
            th_dict = calib.sym_calibrate(sym,
                                          params,
                                          inputs_ext,
                                          ctx=calib_ctx)
            qsym, qparams, precs, _ = calib.sym_simulate(
                sym, params, inputs_ext, th_dict)
            qsym, qparams = calib.sym_realize(qsym, qparams, inputs_ext, precs)
        dump_sym, dump_params, dump_ext = load_fname(version, quant_tag, True)
        sim.save_ext(dump_ext, inputs_ext)
        nd.save(dump_params, qparams)
        # Context manager guarantees the JSON is flushed and the handle closed.
        with open(dump_sym, "w") as fout:
            fout.write(qsym.tojson())

    dump_sym, dump_params, dump_ext = load_fname(version, quant_tag, True)
    (inputs_ext, ) = sim.load_ext(dump_ext)
    net2 = utils.load_model(dump_sym, dump_params, inputs, ctx=ctx)
    qacc_top1 = mx.metric.Accuracy()
    qacc_top5 = mx.metric.TopKAccuracy(5)
    qacc_top1.reset()
    qacc_top5.reset()

    def cvm_quantize(data, label):
        # Same as ``inception_v3`` but for the quantized model; the input is
        # first mapped into the quantized input domain.
        data = sim.load_real_data(data, 'data', inputs_ext)
        data = gluon.utils.split_and_load(data,
                                          ctx_list=ctx,
                                          batch_axis=0,
                                          even_split=False)
        res = [net2.forward(d) for d in data]
        res = nd.concatenate(res)
        qacc_top1.update(label, res)
        _, top1 = qacc_top1.get()
        qacc_top5.update(label, res)
        _, top5 = qacc_top5.get()
        return "top1={:6.2%} top5={:6.2%}".format(top1, top5)

    utils.multi_validate(inception_v3,
                         data_iter_func,
                         cvm_quantize,
                         iter_num=iter_num,
                         logger=logger)
Example #3
0
def test_mx_quantize(batch_size=10, iter_num=10):
    """Quantize MobileNet with MRT and compare top-1/top-5 accuracy.

    Calibrates on the first ImageNet batch, quantizes with 8-bit outputs,
    dumps the result under the "sym.quantize" tag, converts and builds it via
    ``spass.mxnet_to_nnvm`` / ``spass.cvm_build``, then validates the float
    model (``mobilenet``) against the reloaded quantized model
    (``cvm_quantize``) for ``iter_num`` batches.

    NOTE(review): ``version`` is not defined in this function -- presumably
    a module-level global; confirm.

    Returns:
        A "top1=... top5=..." summary of the float model's accumulated
        accuracy after validation.
    """
    logger = logging.getLogger("log.test.mx.quantize")

    ctx = [mx.gpu(int(i)) for i in "1,3".split(',') if i.strip()]
    inputs_ext = { 'data': {
        'shape': (batch_size, 3, 224, 224),
    }}
    inputs = [mx.sym.var(n) for n in inputs_ext]

    data_iter = ds.load_imagenet_rec(batch_size)

    def data_iter_func():
        data = data_iter.next()
        return data.data[0], data.label[0]

    # The first batch is used as the calibration sample.
    data, _ = data_iter_func()

    net1 = utils.load_model(*load_fname(version), inputs, ctx=ctx)
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    acc_top1.reset()
    acc_top5.reset()

    def mobilenet(data, label):
        # Run the float model across all devices and accumulate accuracy.
        data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0, even_split=False)
        res = [net1.forward(d) for d in data]
        res = nd.concatenate(res)
        acc_top1.update(label, res)
        _, top1 = acc_top1.get()
        acc_top5.update(label, res)
        _, top5 = acc_top5.get()
        return "top1={:6.2%} top5={:6.2%}".format(top1, top5)

    calib_ctx = mx.gpu(1)
    sym_fname, param_fname = load_fname(version)
    sym, params = mx.sym.load(sym_fname), nd.load(param_fname)
    sym, params = spass.sym_quant_prepare(sym, params, inputs_ext)
    if True:
        if True:
            mrt = _mrt.MRT(sym, params, inputs_ext)
            mrt.set_data('data', data)
            # Calibrate on the dedicated calibration device, consistent with
            # the calib-based branch below.
            mrt.calibrate(ctx=calib_ctx)
            # Manual threshold overrides, kept for reference. Each
            # "[ ... ]" line records calibration statistics for the node.
            # [ 0.0008745864 0.03330660510427334 ] 0.6670066884888368 0.7753906
            # mrt.set_threshold("mobilenet0_dense0_weight", 0.67)
            # # [ -0.0036011334 0.054821780899052534 ] 1.100036751338784 1.4626989
            # mrt.set_threshold("mobilenet0_conv24_batchnorm24_fwd_weight", 1.1)
            # # [ 0.013243316 1.7543557133786065 ] 70.18747185088569 94.66275
            # mrt.set_threshold("mobilenet0_conv23_batchnorm23_fwd_weight", 35.10)
            # # [ -0.0016149869 0.05713169649243355 ] 1.1442489167675376 1.7122083
            # mrt.set_threshold("mobilenet0_conv20_batchnorm20_fwd_weight", 1.144)
            # # [ -0.0015804865 0.04523811489343643 ] 0.9063427844084799 1.0745146
            # mrt.set_threshold("mobilenet0_conv16_batchnorm16_fwd_weight", 0.90)
            # # [ 0.4315614 2.447332109723772 ] 49.37820360490254 63.959927
            # mrt.set_threshold("mobilenet0_conv2_batchnorm2_fwd", 49.37)
            # # [ 0.9770754 1.3392452512468611 ] 27.761980422905516 40.729546
            # mrt.set_threshold("mobilenet0_relu2_fwd", 27.76)
            # [ 1.0975745 1.0489919010632773 ] 22.077412493692915 23.784576
            # mrt.set_threshold("mobilenet0_relu4_fwd", 22.08)
            # # [ 0.9885562 2.360489403014386 ] 48.19834426651407 69.22121
            # mrt.set_threshold("mobilenet0_conv5_batchnorm5_fwd", 48.2)
            # # [ 0.7895588 1.0544661745870065 ] 21.878882319617176 30.95745
            # mrt.set_threshold("mobilenet0_relu17_fwd", 21.88)
            # # [ 0.8717863 1.0887600296120434 ] 22.646986888608513 28.265652
            # mrt.set_threshold("mobilenet0_relu19_fwd", 22.65)
            # # [ 0.35124516 0.6501711574631898 ] 13.354668314135012 20.770807
            # mrt.set_threshold("mobilenet0_relu20_fwd", 13.35)
            # # [ 0.9378179 1.110470714216975 ] 23.147232155910086 27.886068
            # mrt.set_threshold("mobilenet0_relu21_fwd", 23.15)
            # # [ 0.36263302 0.6352599878026505 ] 13.067832775738754 17.18809
            # mrt.set_threshold("mobilenet0_relu22_fwd", 13.07)
            # # [ 0.19875833 0.49999100821358816 ] 10.198578498193196 16.625143
            # mrt.set_threshold("mobilenet0_relu24_fwd", 10.2)
            # # [ 0.32357717 1.6308352606637138 ] 65.55698759215218 75.84912
            # mrt.set_threshold("mobilenet0_conv25_batchnorm25_fwd", 32.94)
            # # [ 0.36793178 1.512995992388044 ] 30.62785163096019 49.464615
            # mrt.set_threshold("mobilenet0_relu26_fwd", 30.63)
            # # [ 18.028658 38.61970520019531 ] 790.4227619171143 805.51886
            # mrt.set_threshold("sum0", 790.423)
            mrt.set_output_prec(8)
            qsym, qparams, inputs_ext = mrt.quantize()
        else:
            inputs_ext['data']['data'] = data
            th_dict = calib.sym_calibrate(sym, params, inputs_ext, ctx=calib_ctx)
            qsym, qparams, precs, _ = calib.sym_simulate(sym, params, inputs_ext, th_dict)
            qsym, qparams = calib.sym_realize(qsym, qparams, inputs_ext, precs)
        dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
        sim.save_ext(dump_ext, inputs_ext)
        nd.save(dump_params, qparams)
        # Context manager guarantees the JSON is flushed and the handle closed.
        with open(dump_sym, "w") as fout:
            fout.write(qsym.tojson())

        # Also convert the quantized model and build it for the CVM runtime.
        dump_sym, dump_params = load_fname(version, "nnvm.compile")
        nnvm_sym, nnvm_params = spass.mxnet_to_nnvm(qsym, qparams, inputs_ext)
        spass.cvm_build(nnvm_sym, nnvm_params, inputs_ext, dump_sym, dump_params)

    dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
    (inputs_ext,) = sim.load_ext(dump_ext)
    net2 = utils.load_model(dump_sym, dump_params, inputs, ctx=ctx)
    qacc_top1 = mx.metric.Accuracy()
    qacc_top5 = mx.metric.TopKAccuracy(5)
    qacc_top1.reset()
    qacc_top5.reset()

    def cvm_quantize(data, label):
        # Same as ``mobilenet`` but for the quantized model; the input is
        # first mapped into the quantized input domain.
        data = sim.load_real_data(data, 'data', inputs_ext)
        data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0, even_split=False)
        res = [net2.forward(d) for d in data]
        res = nd.concatenate(res)
        qacc_top1.update(label, res)
        _, top1 = qacc_top1.get()
        qacc_top5.update(label, res)
        _, top5 = qacc_top5.get()
        return "top1={:6.2%} top5={:6.2%}".format(top1, top5)

    utils.multi_validate(mobilenet, data_iter_func,
            cvm_quantize,
            iter_num=iter_num, logger=logger)
    # Summarize the float model's accumulated accuracy. No extra metric
    # update here: every validated batch was already recorded in
    # ``mobilenet``, and no batch/prediction is in scope at this point.
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    return "top1={:6.2%} top5={:6.2%}".format(top1, top5)


if True:
    # Debug path: dump the prepared (pre-quantization) graph to JSON and
    # stop the script.
    # NOTE(review): sym_file, param_file, inputs_ext, data, calib_ctx and
    # version are not defined in this fragment -- presumably provided by
    # the enclosing script; confirm.
    sym, params = mx.sym.load(sym_file), nd.load(param_file)
    sym, params = spass.sym_quant_prepare(sym, params, inputs_ext)

    import os
    with open(os.path.expanduser('~/tvm-cvm/data/test_ryt2.json'),
              'w') as fout:
        fout.write(sym.tojson())
    # NOTE(review): exit() terminates the process here, so none of the
    # statements below (quantize/dump, and the reload after this ``if``)
    # ever run. Remove this call to re-enable them.
    exit()

    qsym, qparams, precs, _ = calib.sym_simulate(sym, params, inputs_ext, data,
                                                 calib_ctx)
    qsym, qparams = calib.sym_realize(qsym, qparams, inputs_ext, precs, "tvm")
    dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
    sim.save_ext(dump_ext, inputs_ext)
    nd.save(dump_params, qparams)
    with open(dump_sym, "w") as fout:
        fout.write(qsym.tojson())

# Reload the dumped quantized model and set up its accuracy metrics.
dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
sym, params = mx.sym.load(dump_sym), nd.load(dump_params)
(inputs_ext, ) = sim.load_ext(dump_ext)
inputs = [mx.sym.var(n) for n in inputs_ext]
net2 = utils.load_model(dump_sym, dump_params, inputs, ctx=ctx)
qacc_top1 = mx.metric.Accuracy()
qacc_top5 = mx.metric.TopKAccuracy(5)
qacc_top1.reset()
qacc_top5.reset()
Example #5
0
    return "top1={:6.2%} top5={:6.2%}".format(top1, top5)


if True:
    # Quantize the prepared float model (the active path uses MRT; the
    # ``else`` path is the older calibrate/simulate/realize pipeline) and
    # dump the result under the "sym.quantize" tag.
    # NOTE(review): sym_file, param_file, inputs_ext, data, calib_ctx,
    # version and ctx are not defined in this fragment -- presumably
    # provided by the enclosing script; confirm.
    sym, params = mx.sym.load(sym_file), nd.load(param_file)
    sym, params = spass.sym_quant_prepare(sym, params, inputs_ext)
    if True:
        mrt = _mrt.MRT(sym, params, inputs_ext)
        mrt.set_data('data', data)
        mrt.calibrate(ctx=calib_ctx)
        mrt.set_output_prec(8)
        qsym, qparams, inputs_ext = mrt.quantize()
    else:
        inputs_ext['data']['data'] = data
        th_dict = calib.sym_calibrate(sym, params, inputs_ext, ctx=calib_ctx)
        qsym, qparams, precs, _ = calib.sym_simulate(sym, params, inputs_ext,
                                                     th_dict)
        qsym, qparams = calib.sym_realize(qsym, qparams, inputs_ext, precs,
                                          "cvm")
    dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
    sim.save_ext(dump_ext, inputs_ext)
    nd.save(dump_params, qparams)
    # Context manager guarantees the JSON is flushed and the handle closed.
    with open(dump_sym, "w") as fout:
        fout.write(qsym.tojson())

# Reload the dumped quantized model and set up its accuracy metrics.
dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
(inputs_ext, ) = sim.load_ext(dump_ext)
inputs = [mx.sym.var(n) for n in inputs_ext]
net2 = utils.load_model(dump_sym, dump_params, inputs, ctx=ctx)
qacc_top1 = mx.metric.Accuracy()
qacc_top5 = mx.metric.TopKAccuracy(5)
qacc_top1.reset()
qacc_top5.reset()
def test_sym_pass(batch_size=10, iter_num=10):
    """Split, quantize and validate YOLOv3 (Darknet-53 backbone, VOC).

    The float graph is split at the three YOLO output convolutions (``keys``)
    into a "base" and a "top" sub-graph. Individual stages are switched on or
    off with ``if True`` / ``if False`` guards:
    - threshold calibration
    - base quantization
    - top quantization (mixed precision)
    - merging the two quantized halves
    - dumping quantized input samples to disk

    Finally ``utils.multi_validate`` compares the float model (``yolov3``)
    with ``top_quantize`` (float base + quantized top) for ``iter_num``
    batches.

    Args:
        batch_size: images per VOC validation batch.
        iter_num: number of batches to validate.
    """
    logger = logging.getLogger("log.test.sym.pass")

    def _write_json(path, symbol):
        # Serialize ``symbol`` to ``path``; the context manager closes (and
        # flushes) the file instead of leaving it to garbage collection.
        with open(path, "w") as fout:
            fout.write(symbol.tojson())

    base_ctx = mx.gpu(1)
    ctx = mx.gpu(2)
    input_size = 416
    h, w = input_size, input_size
    inputs_ext = {
        'data': {
            'shape': (batch_size, 3, h, w),
        }
    }

    val_data = dataset.load_voc(batch_size, input_size)
    val_data_iter = iter(val_data)

    def data_iter_func():
        data, label = next(val_data_iter)
        return data, label

    sym_file, param_file = load_fname("_darknet53_voc")
    sym, params = mx.sym.load(sym_file), nd.load(param_file)
    sym, params = spass.sym_quant_prepare(sym, params, inputs_ext)
    # Calibrate thresholds over 16 batches and cache them (disabled; the
    # cached dictionary is loaded right below).
    if False:
        th_dict = {}
        for _ in range(16):
            data, _ = data_iter_func()
            for v in inputs_ext.values():
                v['data'] = data
            th_dict = calib.sym_calibrate(sym,
                                          params,
                                          inputs_ext,
                                          old_ths=th_dict,
                                          ctx=ctx)
        _, _, dump_ext = load_fname("_darknet53_voc", "dict", True)
        sim.save_ext(dump_ext, th_dict)

    _, _, dump_ext = load_fname("_darknet53_voc", "dict", True)
    (th_dict, ) = sim.load_ext(dump_ext)
    inputs = [mx.sym.var(name) for name in inputs_ext]
    net1 = mx.gluon.nn.SymbolBlock(sym, inputs)
    utils.load_parameters(net1, params, ctx=ctx)
    metric = dataset.load_voc_metric()
    metric.reset()

    def yolov3(data, label):
        # Float reference model.
        def net(data):
            out = net1(data.as_in_context(ctx))
            print([o[0][0][:] for o in out])
            return out

        acc = validate_data(net, data, label, metric)
        return "{:6.2%}".format(acc)

    keys = [
        'yolov30_yolooutputv30_conv0_fwd',
        'yolov30_yolooutputv31_conv0_fwd',
        'yolov30_yolooutputv32_conv0_fwd',
    ]
    base, base_params, base_inputs_ext, top, top_params, top_inputs_ext \
            = split_model(sym, params, inputs_ext, keys, logger)
    # NOTE(review): only the base *symbol* is written here; ``dump_params``
    # is never used, so base parameters are not saved -- confirm intended.
    dump_sym, dump_params = load_fname("_darknet53_voc", "base")
    _write_json(dump_sym, base)
    dump_sym, dump_params, dump_ext = load_fname("_darknet53_voc", "top", True)
    _write_json(dump_sym, top)
    nd.save(dump_params, top_params)
    sim.save_ext(dump_ext, top_inputs_ext)

    base_inputs = [mx.sym.var(n) for n in base_inputs_ext]
    base_graph = mx.gluon.nn.SymbolBlock(base, base_inputs)
    utils.load_parameters(base_graph, base_params, ctx=base_ctx)

    top_inputs = [mx.sym.var(n) for n in top_inputs_ext]
    top_graph = mx.gluon.nn.SymbolBlock(top, top_inputs)
    utils.load_parameters(top_graph, top_params, ctx=ctx)

    # quantize base graph
    if False:
        qbase, qbase_params, qbase_prec, base_oscales = calib.sym_simulate(
            base, base_params, base_inputs_ext, th_dict)
        qbase, qbase_params = calib.sym_realize(qbase, qbase_params,
                                                base_inputs_ext, qbase_prec)
        dump_sym, dump_params, dump_ext = load_fname("_darknet53_voc",
                                                     "base.quantize", True)
        _write_json(dump_sym, qbase)
        sim.save_ext(dump_ext, base_inputs_ext, base_oscales)
        nd.save(dump_params, qbase_params)

    if False:
        qb_sym, qb_params, qb_ext = load_fname("_darknet53_voc",
                                               "base.quantize", True)
        net2_inputs_ext, base_oscales = sim.load_ext(qb_ext)
        net2_inputs = [mx.sym.var(n) for n in net2_inputs_ext]
        net2 = utils.load_model(qb_sym, qb_params, net2_inputs, ctx=ctx)
        base_metric = dataset.load_voc_metric()
        base_metric.reset()

        def base_quantize(data, label):
            # Quantized base, rescaled back to float, feeding the float top.
            def net(data):
                data = sim.load_real_data(data, 'data', net2_inputs_ext)
                tmp = list(net2(data.as_in_context(ctx)))
                tmp = [t / base_oscales[i] for i, t in enumerate(tmp)]
                return top_graph(*tmp)

            acc = validate_data(net, data, label, base_metric)
            return "{:6.2%}".format(acc)

    # quantize top graph
    if False:
        in_bit, out_bit = 8, 30
        outputs_ext = {
            'yolov30_yolooutputv30_expand_dims0': {
                'threshold': 1,
                'type': 'score'
            },
            'yolov30_yolooutputv31_expand_dims0': {
                'threshold': 1,
                'type': 'score'
            },
            'yolov30_yolooutputv32_expand_dims0': {
                'threshold': 1,
                'type': 'score'
            },
            'yolov30_yolooutputv30_tile0': {
                'threshold': 416,
                'type': 'bbox'
            },
            'yolov30_yolooutputv31_tile0': {
                'threshold': 416,
                'type': 'bbox'
            },
            'yolov30_yolooutputv32_tile0': {
                'threshold': 416,
                'type': 'bbox'
            },
            'yolov30_yolooutputv30_broadcast_add1': {
                'fixed': True,
                'type': 'ids'
            },
            'yolov30_yolooutputv31_broadcast_add1': {
                'fixed': True,
                'type': 'ids'
            },
            'yolov30_yolooutputv32_broadcast_add1': {
                'fixed': True,
                'type': 'ids'
            },
        }
        qsym, qparams, type_ext = anno.mixed_precision(top,
                                                       top_params,
                                                       top_inputs_ext,
                                                       th_dict,
                                                       in_bit=in_bit,
                                                       out_bit=out_bit,
                                                       out_ext=outputs_ext,
                                                       runtime="cvm")
        out_scales = [type_ext['ids'], type_ext['score'], type_ext['bbox']]

        dump_sym, dump_params, dump_ext = load_fname("_darknet53_voc",
                                                     "top.quantize", True)
        _write_json(dump_sym, qsym)
        sim.save_ext(dump_ext, top_inputs_ext, out_scales)
        nd.save(dump_params, qparams)

    if True:
        sym_file, param_file, ext_file = load_fname("_darknet53_voc",
                                                    "top.quantize", True)
        net3_inputs_ext, net3_scales = sim.load_ext(ext_file)
        top_sym = base_graph(mx.sym.Group(base_inputs))
        # Output names of the base graph, in order; used to look up the
        # matching entries of ``net3_inputs_ext``.
        top_names = [c.attr('name') for c in top_sym]
        net3_inputs = [mx.sym.var(n) for n in net3_inputs_ext]
        net3 = utils.load_model(sym_file, param_file, net3_inputs, ctx=ctx)
        top_qmetric = dataset.load_voc_metric()
        top_qmetric.reset()

        def top_quantize(data, label):
            # Float base on ``base_ctx`` -> quantized top on ``ctx``; the
            # quantized outputs are divided by their scales.
            def net(data):
                tmp = base_graph(data.as_in_context(base_ctx))
                tmp = [t.as_in_context(ctx) for t in tmp]
                tmp = [
                    sim.load_real_data(tmp[i], n, net3_inputs_ext)
                    for i, n in enumerate(top_names)
                ]
                out = net3(*tmp)
                out = [(t / net3_scales[i]) for i, t in enumerate(out)]
                print([o[0][0][:] for o in out])
                return out

            acc = validate_data(net, data, label, top_qmetric)
            return "{:6.2%}".format(acc)

    # merge quantize model
    if False:
        qb_sym, qb_params, qb_ext = load_fname("_darknet53_voc",
                                               "base.quantize", True)
        qbase, qbase_params = mx.sym.load(qb_sym), nd.load(qb_params)
        qbase_inputs_ext, _ = sim.load_ext(qb_ext)
        qt_sym, qt_params, qt_ext = load_fname("_darknet53_voc",
                                               "top.quantize", True)
        qtop, qtop_params = mx.sym.load(qt_sym), nd.load(qt_params)
        _, out_scales = sim.load_ext(qt_ext)
        maps = dict(
            zip([c.attr('name') for c in qbase],
                [c.attr('name') for c in base]))
        qsym, qparams = merge_model(qbase, qbase_params, qbase_inputs_ext,
                                    qtop, qtop_params, maps)
        sym_file, param_file, ext_file = load_fname("_darknet53_voc",
                                                    "all.quantize", True)
        _write_json(sym_file, qsym)
        nd.save(param_file, qparams)
        sim.save_ext(ext_file, qbase_inputs_ext, out_scales)

    if False:
        sym_file, param_file, ext_file = load_fname("_darknet53_voc",
                                                    "all.quantize", True)
        net4_inputs_ext, net4_scales = sim.load_ext(ext_file)
        net4_inputs = [mx.sym.var(n) for n in net4_inputs_ext]
        net4 = utils.load_model(sym_file, param_file, net4_inputs, ctx=ctx)
        all_qmetric = dataset.load_voc_metric()
        all_qmetric.reset()

        def all_quantize(data, label):
            # Fully quantized (merged) model; outputs divided by scales.
            def net(data):
                data = sim.load_real_data(data, 'data', net4_inputs_ext)
                out = net4(data.as_in_context(ctx))
                out = [(t / net4_scales[i]) for i, t in enumerate(out)]
                return out

            acc = validate_data(net, data, label, all_qmetric)
            return "{:6.2%}".format(acc)

    # Dump 50 quantized input batches (int8) and labels as .npy files.
    if False:
        sym_file, param_file, ext_file = load_fname("_darknet53_voc",
                                                    "all.quantize", True)
        net4_inputs_ext, net4_scales = sim.load_ext(ext_file)
        datadir = "/data/voc/data/"
        for i in range(50):
            countdir = os.path.join(datadir, str(i))
            os.makedirs(countdir, exist_ok=True)
            data, label = data_iter_func()
            data = sim.load_real_data(data, 'data', net4_inputs_ext)
            np.save(os.path.join(countdir, "data.npy"),
                    data.asnumpy().astype('int8'))
            np.save(os.path.join(countdir, "label.npy"), label.asnumpy())

        # data = sim.load_real_data(data, 'data', net4_inputs_ext)
        # np.save("/tmp/yolo/data", data.asnumpy().astype('int8'))
        # out = net4(data.as_in_context(ctx))
        # for i, o in enumerate(out):
        #    np.save("/tmp/yolo/result"+str(i), o.asnumpy().astype('int32'))
        exit()

    utils.multi_validate(
        yolov3,
        data_iter_func,
        top_quantize,
        # base_quantize, # top_quantize, all_quantize,
        iter_num=iter_num,
        logger=logger)
def test_sym_pass(batch_size=10, iter_num=10):
    """Quantize an ImageNet classifier and compare top-1/top-5 accuracy.

    The graph is truncated at the node named 'classifier' and calibrated on
    the third ImageNet batch. It is quantized with ``calib.sym_simulate`` /
    ``calib.sym_realize`` and dumped under the "sym.quantize" tag. The float
    model (``resnet``) is then validated against the reloaded quantized
    model (``cvm_quantize``) for ``iter_num`` batches.

    Args:
        batch_size: images per ImageNet validation batch.
        iter_num: number of batches to validate.
    """
    logger = logging.getLogger("log.test.sym.pass")

    version = ""
    sym_fname, param_fname = load_fname(version)
    sym, params = mx.sym.load(sym_fname), nd.load(param_fname)
    # Strip the "<prefix>:" part of each saved parameter key (presumably
    # MXNet's "arg:"/"aux:" prefixes). Keys without ':' are kept unchanged
    # instead of raising IndexError.
    params = {k.split(':', 1)[-1]: v for k, v in params.items()}

    calib_ctx = mx.gpu(2)
    ctx = [mx.gpu(int(i)) for i in "1,2,3,4,5,6,7".split(',') if i.strip()]
    inputs_ext = {
        'data': {
            'shape': (batch_size, 3, 224, 224),
        }
    }
    inputs = [mx.sym.var(name) for name in inputs_ext]

    logger.info("load dataset, symbol and parameters")

    # Truncate the graph at the 'classifier' node. If no such node exists,
    # the loop leaves ``op_head`` at the last node in topological order.
    order = sutils.topo_sort(sym)
    op_head = None
    for op_head in order:
        if op_head.attr('name') == 'classifier':
            break
    else:
        logger.warning("no 'classifier' node found; using the last node "
                       "of the topologically sorted graph")
    sym = op_head
    net = mx.gluon.nn.SymbolBlock(sym, inputs)
    load_parameters(net, params, ctx=ctx)

    data_iter = ds.load_imagenet_rec(batch_size)

    def data_iter_func():
        data = data_iter.next()
        return data.data[0], data.label[0]

    # Use the third batch as the calibration sample, then reset the
    # iterator (presumably back to the first batch) for validation.
    for _step in range(3):
        data, _ = data_iter_func()
    data_iter.reset()

    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    acc_top1.reset()
    acc_top5.reset()

    def resnet(data, label):
        # Run the float model across all devices and accumulate accuracy.
        data = gluon.utils.split_and_load(data,
                                          ctx_list=ctx,
                                          batch_axis=0,
                                          even_split=False)
        res = [net.forward(d) for d in data]
        res = nd.concatenate(res)
        acc_top1.update(label, res)
        _, top1 = acc_top1.get()
        acc_top5.update(label, res)
        _, top5 = acc_top5.get()
        return "top1={:6.2%} top5={:6.2%}".format(top1, top5)

    sym, params = spass.sym_quant_prepare(sym, params, inputs_ext)
    qsym, qparams, precs, _ = calib.sym_simulate(sym, params, inputs_ext, data,
                                                 calib_ctx)
    qsym, qparams = calib.sym_realize(qsym, qparams, inputs_ext, precs, "cvm")
    dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
    sim.save_ext(dump_ext, inputs_ext)
    nd.save(dump_params, qparams)
    # Context manager guarantees the JSON is flushed and the handle closed.
    with open(dump_sym, "w") as fout:
        fout.write(qsym.tojson())

    dump_sym, dump_params, dump_ext = load_fname(version, "sym.quantize", True)
    (inputs_ext, ) = sim.load_ext(dump_ext)
    net3 = utils.load_model(dump_sym, dump_params, inputs, ctx=ctx)
    qacc_top1 = mx.metric.Accuracy()
    qacc_top5 = mx.metric.TopKAccuracy(5)
    qacc_top1.reset()
    qacc_top5.reset()

    def cvm_quantize(data, label):
        # Same as ``resnet`` but for the quantized model; the input is first
        # mapped into the quantized input domain.
        data = sim.load_real_data(data, 'data', inputs_ext)
        data = gluon.utils.split_and_load(data,
                                          ctx_list=ctx,
                                          batch_axis=0,
                                          even_split=False)
        res = [net3.forward(d) for d in data]
        res = nd.concatenate(res)
        qacc_top1.update(label, res)
        _, top1 = qacc_top1.get()
        qacc_top5.update(label, res)
        _, top5 = qacc_top5.get()
        return "top1={:6.2%} top5={:6.2%}".format(top1, top5)

    utils.multi_validate(resnet,
                         data_iter_func,
                         cvm_quantize,
                         iter_num=iter_num,
                         logger=logger)