Ejemplo n.º 1
0
def sym_simulate(symbol, params, inputs_ext, outputs_ext, precs, th_dict):
    """Simulate quantization of ``symbol`` and return the simulated graph.

    For every output listed in ``outputs_ext``:
      * a ``'threshold'`` entry overwrites ``th_dict[name]`` (the caller's
        ``th_dict`` is mutated in place);
      * a truthy ``'fixed'`` entry pins that output's scale to 1.

    Scales and precisions are then refreshed by ``_update_scale_and_precs``.
    Two graph passes follow: ``_simulate_layer`` rewrites the graph, then
    ``_simulate_parameters`` rewrites the parameters of the rewritten graph.
    The result is finally passed through ``examine_parameters``.

    Returns:
        (ssym, sparams, scales): simulated symbol, its parameters, and the
        scale dict built above and updated by the helpers.
    """
    log = logging.getLogger('log.simulate')

    # The inferred shapes are not used below. The call is kept unchanged.
    # NOTE(review): presumably it acts as a shape-validity check; confirm
    # before removing.
    spass.sym_infer_shape(symbol, params, inputs_ext)

    scales = {}
    for out_name, ext in outputs_ext.items():
        if 'threshold' in ext:
            log.debug("Update thresholds of output %s", out_name)
            th_dict[out_name] = ext['threshold']
        if ext.get('fixed'):
            scales[out_name] = 1
    _update_scale_and_precs(symbol, params, inputs_ext, th_dict, precs, scales)

    # Both passes share the op factory and logger.
    shared = {'get_op': get_mxnet_op, 'logger': log}
    ssym, sparams = topo_visit(symbol, params, inputs_ext,
                               callback=_simulate_layer,
                               scales=scales, precs=precs, **shared)
    _, sparams = topo_visit(ssym, sparams, inputs_ext,
                            callback=_simulate_parameters,
                            scales=scales, **shared)

    return ssym, examine_parameters(ssym, sparams, inputs_ext), scales
Ejemplo n.º 2
0
def split_model(symbol, params, inputs_ext, keys):
    """Split ``symbol`` at the nodes named in ``keys``.

    The *base* graph groups every node whose name is in ``keys``, in
    topological order. The *top* graph is ``symbol`` rebuilt with each such
    node replaced by a fresh variable of the same name. The variable's shape
    comes from ``spass.sym_infer_shape``.

    Returns:
        (base, base_params, base_inputs_ext, top, top_params, top_inputs_ext)

        * ``base_params`` / ``top_params`` hold the ``params`` entries that
          appear in each graph's ``list_inputs()``.
        * ``base_inputs_ext`` is the very same object as ``inputs_ext``
          (not a copy).
        * ``top_inputs_ext`` holds only the newly introduced split variables,
          i.e. names not already in ``inputs_ext``.
          NOTE(review): if the top graph still consumes an original input,
          that input is not listed here. Confirm this is intended.
    """
    shapes = spass.sym_infer_shape(symbol, params, inputs_ext)

    base = mx.sym.Group(
        [s for s in topo_sort(symbol) if s.attr('name') in keys])
    base_params = {k: params[k] for k in base.list_inputs() if k in params}

    rebuilt = {}
    known_inputs = dict(inputs_ext)
    for sym in topo_sort(symbol):
        name = sym.attr('name')
        op_name = sym.attr('op_name')
        children = sym_iter(sym.get_children())
        attrs = sym.list_attr()

        node = sym
        if children is not None:
            # Re-create the op on top of the already rebuilt children.
            args = [rebuilt[c.attr('name')] for c in children]
            node = get_mxnet_op(op_name)(*args, **attrs, name=name)
        if name in keys:
            # Cut point: the top graph sees this node as a plain input.
            node = mx.sym.var(name)
            known_inputs[name] = {'shape': shapes[name]}
        rebuilt[name] = node

    outputs = [rebuilt[out.attr('name')] for out in symbol]
    if len(outputs) == 1:
        top = outputs[0]
    else:
        top = mx.sym.Group(outputs)
    top_params = {k: params[k] for k in top.list_inputs() if k in params}
    top_inputs_ext = {
        k: v for k, v in known_inputs.items() if k not in inputs_ext
    }

    return base, base_params, inputs_ext, top, top_params, top_inputs_ext
Ejemplo n.º 3
0
    def _set_prerequisites(self):
        """Normalize the graph, reset per-node precisions and infer shapes.

        Steps:
          1. ``self.sym`` / ``self.prm`` are replaced by ``check_graph``'s
             result.
          2. Every node name in ``topo_sort(self.sym)`` gets a fresh, empty
             dict in ``self.precs``. Entries for other names are left as they
             were.
          3. Each input in ``self.ins_ext`` has ``precs[name][out_key]`` set to
             ``PREC(8, L0)``.
          4. ``self.shpes`` is filled from ``spass.sym_infer_shape``.
        """
        self.sym, self.prm = check_graph(self.sym, self.prm)

        # A new {} is built per node, so no two nodes share a dict.
        self.precs.update(
            (node.attr('name'), {}) for node in topo_sort(self.sym))

        # NOTE(review): PREC(8, L0) presumably means an 8-bit precision; the
        # meaning of L0 is defined elsewhere. Confirm.
        for inp_name in self.ins_ext:
            self.precs[inp_name][out_key] = PREC(8, L0)

        self.shpes = spass.sym_infer_shape(self.sym, self.prm, self.ins_ext)
Ejemplo n.º 4
0
def sym_annotate(symbol,
                 params,
                 inputs_ext,
                 outputs_ext,
                 th_dict,
                 in_bit=8,
                 out_bit=8):
    """Infer per-node precisions for ``symbol`` and annotate the graph.

    The inference passes run in this fixed order (the order matters because
    each pass reads and extends the shared ``precs`` dict):
      1. ``_infer_fixed_precs``;
      2. ``_update_input_precs`` with ``in_bit``;
      3. ``_infer_dynamic_precs`` with ``fix_param=False``;
      4. ``_infer_parameter_precs`` using ``outputs_ext``;
      5. ``_infer_dynamic_precs`` again with ``fix_param=True``.

    Each graph output then gets ``precs[name][target_key] = out_bit``, and
    ``_sym_annotate`` (with ``th_dict``) produces the annotated graph. Every
    node's output precision and input precisions are logged at DEBUG level.

    Returns:
        (symbol, params, precs): annotated symbol, its parameters, and the
        per-node precision dict.
    """
    log = logging.getLogger('log.infer.precision')
    precs = {}

    def visit(sym, prm, callback, **extra):
        # All passes share the inputs, op factory and logger.
        return topo_visit(sym, prm, inputs_ext,
                          get_op=get_mxnet_op,
                          logger=log,
                          callback=callback,
                          **extra)

    visit(symbol, params, _infer_fixed_precs, precs=precs)
    _update_input_precs(precs, in_bit, inputs_ext)

    shapes = spass.sym_infer_shape(symbol, params, inputs_ext)
    visit(symbol, params, _infer_dynamic_precs,
          infer_shapes=shapes, precs=precs, fix_param=False)
    visit(symbol, params, _infer_parameter_precs,
          precs=precs, outputs_ext=outputs_ext)
    visit(symbol, params, _infer_dynamic_precs,
          infer_shapes=shapes, precs=precs, fix_param=True)

    # Iterating a symbol walks its outputs; pin their target precision.
    for out in symbol:
        precs[out.attr('name')][target_key] = out_bit

    symbol, params = visit(symbol, params, _sym_annotate,
                           precs=precs, th_dict=th_dict)

    for node in topo_sort(symbol):
        name = node.attr('name')
        children = sym_iter(node.get_children()) or []
        log.debug("%-20s name=%-40s out_prec=%s in_precs=%s",
                  node.attr('op_name'), name,
                  precs[name][out_key],
                  [precs[child.attr('name')][name] for child in children])
    return symbol, params, precs
Ejemplo n.º 5
0
    data = gluon.utils.split_and_load(data,
                                      ctx_list=ctx,
                                      batch_axis=0,
                                      even_split=False)
    res = [net1.forward(d) for d in data]
    res = nd.concatenate(res)
    acc_top1.update(label, res)
    _, top1 = acc_top1.get()
    acc_top5.update(label, res)
    _, top5 = acc_top5.get()
    return "top1={:6.2%} top5={:6.2%}".format(top1, top5)


# Script body: load the model, prepare it for quantization, then quantize it.
# NOTE(review): both `if True:` guards hard-select the first branch, so the
# `else` branch below is unreachable as written.
# sym_file, param_file, inputs_ext, data and calib_ctx are defined outside
# this excerpt.
if True:
    sym, params = mx.sym.load(sym_file), nd.load(param_file)
    # NOTE(review): infer_shapes is not used in the visible code.
    infer_shapes = (spass.sym_infer_shape(sym, params, inputs_ext))
    sym, params = spass.sym_quant_prepare(sym, params, inputs_ext)

    if True:
        # MRT path: attach the data, calibrate on calib_ctx, set the output
        # precision (presumably 8 bits), then quantize. quantize() also
        # returns a new inputs_ext, which rebinds the name.
        mrt = _mrt.MRT(sym, params, inputs_ext)
        mrt.set_data('data', data)
        mrt.calibrate(ctx=calib_ctx)
        mrt.set_output_prec(8)
        qsym, qparams, inputs_ext = mrt.quantize()
    else:
        # Legacy path: store data in inputs_ext, calibrate thresholds,
        # simulate, then realize with the "cvm" target.
        inputs_ext['data']['data'] = data
        th_dict = calib.sym_calibrate(sym, params, inputs_ext, ctx=calib_ctx)
        # NOTE(review): this passes 4 positional args and unpacks 4 results.
        # The sym_simulate shown in Example 1 takes 6 parameters and returns
        # 3 values. Verify against the calib module actually imported here.
        qsym, qparams, precs, _ = calib.sym_simulate(sym, params, inputs_ext,
                                                     th_dict)
        qsym, qparams = calib.sym_realize(qsym, qparams, inputs_ext, precs,
                                          "cvm")