def sym_config_infos(symbol, params, cfg_dict=None, logger=logging):
    """ Customized graph-level topo pass definition.

        Interface for MRT main2 configuration
        Create customized samplers and optimizors.

        Use it just before calibration.

        Parameters
        ----------
        symbol, params
            Graph and parameters, forwarded to `topo_visit_transformer`.
        cfg_dict : dict or None
            Symbol name maps to configuration dict. It is updated in place
            and returned. `None` (the default) starts from a fresh dict.
            The special key `_RES_NAME` holds the configuration applied to
            every symbol that has no entry of its own.
        logger
            Object exposing `warning(msg, *args)`; defaults to the
            `logging` module.

        Returns
        -------
        dict
            `cfg_dict`, holding one entry per visited symbol.
    """
    # This function mutates and returns cfg_dict, so a shared `{}` default
    # would leak entries from one call into the next.
    if cfg_dict is None:
        cfg_dict = {}

    names = set()

    def _collect_names(symbol, params, **kwargs):
        names.add(symbol.attr("name"))

    topo_visit_transformer(symbol, params, _collect_names)

    # Drop configured names that do not exist in the graph.
    # Iterate over a snapshot of the keys because entries are deleted.
    noncfgs = set()
    for name in list(cfg_dict.keys()):
        if name == _RES_NAME:
            continue
        if name not in names:
            del cfg_dict[name]
            noncfgs.add(name)

    if noncfgs:
        # The names are passed as a logging argument so that they are
        # really substituted into the `%s` placeholder.
        logger.warning(
            "Symbols (names: %s) not found in graph. "
            "Please double check config file (.ini).",
            list(noncfgs))

    # Spread the "rest" configuration over all unconfigured symbols.
    # Every such symbol shares the very same configuration dict object.
    if _RES_NAME in cfg_dict:
        cfg_info = cfg_dict.pop(_RES_NAME)
        keys = cfg_dict.keys()
        for name in [n for n in names if n not in keys]:
            cfg_dict[name] = cfg_info

    def _sym_config_infos(sym, params, **kwargs):
        name = sym.attr("name")
        cfg_info = cfg_dict.get(name, {})

        gn_info = cfg_info.get("gn_info", DEFAULT_GN_INFO)

        quant_type = cfg_info.get("quant_type", DEFAULT_QUANT_TYPE)
        get_quantizer(quant_type)

        # NOTE(review): the lookup key is "opt_type" while the stored key
        # below (and the one read during calibration) is "opt_info" --
        # looks like a typo, confirm before changing.
        opt_info = cfg_info.get("opt_type", make_key_opt(DEFAULT_OPT_INFO))
        get_optimizor(opt_info)

        # A non-empty user configuration is kept untouched; only a missing
        # or empty one is replaced by the defaults.
        cfg_dict[name] = cfg_info if cfg_info else \
            {"gn_info": gn_info, "quant_type": quant_type,
             "opt_info": opt_info}

    topo_visit_transformer(symbol, params, _sym_config_infos)
    return cfg_dict
def sym_separate_bias(symbol, params):
    """ Rewrite-stage pass that pulls the bias input out of an operator.

        Convolution and FullyConnected symbols carrying at least three
        inputs are rebuilt with `no_bias` set, and their third input is
        added back through an explicit `broadcast_add` that keeps the
        original symbol name. Every other symbol is returned unchanged.
    """
    def _separate_bias(op, **kwargs):
        name = op.attr('name')
        op_name = op.attr('op_name')
        attr = op.list_attr()
        childs = sutils.sym_iter(op.get_children())

        # Only the two operator kinds below are rewritten ...
        if op_name not in (Convolution.op_name, FullyConnected.op_name):
            return op
        # ... and only when a third input is present.
        if childs and len(childs) < 3:
            return op

        attr['no_bias'] = True
        builder = sutils.get_mxnet_op(op_name)
        # The bias-free operator takes a generated name because the final
        # broadcast_add reuses the original one.
        op = builder(childs[0], childs[1], **attr, name=N.n(name))

        bias = childs[2]
        bias_name = bias.attr('name')
        if op_name == Convolution.op_name:
            if 'layout' in attr:
                assert attr['layout'] == 'NCHW'
            # Add one leading and two trailing axes to the bias.
            expanded = mx.sym.expand_dims(
                bias, axis=0, name=N.n('expand_dims'))
            expanded = mx.sym.expand_dims(
                expanded, axis=-1, name=N.n('expand_dims'))
            expanded = mx.sym.expand_dims(
                expanded, axis=-1, name=N.n(bias_name))
        else:
            # Add one leading axis to the bias.
            expanded = mx.sym.expand_dims(
                bias, axis=0, name=N.n(bias_name))

        return mx.sym.broadcast_add(op, expanded, name=name)

    return topo_visit_transformer(symbol, params, _separate_bias)
def sym_separate_pad(symbol, params):
    """ Rewrite-stage pass that pulls the `pad` attribute out of Convolution.

        Each Convolution is rebuilt under its original name without the
        `pad` attribute. When either padding value is non-zero, the first
        input is wrapped in an explicit constant-zero `pad` symbol first.
        Every other symbol is returned unchanged.
    """
    def _separate_pad(op, **kwargs):
        name = op.attr('name')
        op_name = op.attr('op_name')
        attr = op.list_attr()
        childs = sutils.sym_iter(op.get_children())

        if op_name != Convolution.op_name:
            return op

        if 'layout' in attr:
            assert attr['layout'] == 'NCHW'

        pad_h, pad_w = sutils.get_attr(attr, 'pad', (0, 0))
        # The rebuilt operator must never carry a pad attribute.
        attr.pop('pad', None)

        if pad_h != 0 or pad_w != 0:
            # Pad only the last two axes of the first input.
            childs[0] = mx.sym.pad(
                childs[0],
                pad_width=(0, 0, 0, 0, pad_h, pad_h, pad_w, pad_w),
                mode='constant', constant_value=0, name=N.n('pad'))

        return sutils.get_mxnet_op(op_name)(*childs, **attr, name=name)

    return topo_visit_transformer(symbol, params, _separate_pad)
def sym_slice_channel(symbol, params, cfg_dict=None):
    """ Customized graph-level topo pass definition.

        Interface for granularity control.
        While layer-wise feature is by default,
        MRT support channel-wise features specified in cfg_dict.

        Parameters
        ----------
        symbol, params
            Graph and parameters, forwarded to `topo_visit_transformer`.
        cfg_dict : dict or None
            Symbol name maps to configuration dict. Only the "gn_info"
            entry is read here; the dict is also handed to the
            "slice_channel" pass. `None` (the default) means no symbol
            has a configuration.

        Returns
        -------
        tuple
            The `(symbol, params)` pair produced by `fuse_constant` after
            the slicing pass.
    """
    # cfg_dict is handed to a sub-pass, so a shared `{}` default must not
    # be reused across calls.
    if cfg_dict is None:
        cfg_dict = {}

    infer_shapes = infer_shape(symbol, params)

    def _slice_channel(op, **kwargs):
        name = op.attr("name")
        # Symbols without a configuration entry fall back to the default
        # granularity instead of raising KeyError.
        gn_info = cfg_dict.get(name, {}).get("gn_info", DEFAULT_GN_INFO)
        if gn_info["gn_type"] == CHANNEL_WISE_TYPE:
            op = apply_pass(
                "slice_channel", cfg_dict=cfg_dict,
                infer_shapes=infer_shapes)(op, **kwargs)
        return op

    sym, params = topo_visit_transformer(symbol, params, _slice_channel)
    sym, params = fuse_constant(sym, params)
    return sym, params
def prepare_for_compile(symbol, params):
    """ Run the "prepare_for_compile" pass over every symbol of the graph.

        Shapes are inferred once up front and handed to the pass.
        Returns whatever `topo_visit_transformer` returns.
    """
    shapes = infer_shape(symbol, params)
    transformer = apply_pass("prepare_for_compile", infer_shapes=shapes)
    return topo_visit_transformer(symbol, params, transformer)
def quantize(symbol, params, features, precs, buffers, cfg_dict,
             op_input_precs, restore_names, shift_bits, softmax_lambd):
    """ Customized graph-level topo pass definition.

        Interface for MRT GEN Quantization. The graph is walked twice:
        first every symbol is quantized (or restored, when its name is in
        `restore_names`) and clipped where a tighter precision is known,
        then symbols with a requested output precision are requantized.

        Parameters
        ----------
        symbol : mxnet.symbol
            the grouped output symbol represent the graph to be quantized.
        params : dict
            symbol name maps to mxnet.NDArray, represent graph parameters
        features : dict
            symbol name maps to mrt.V2.Feature
        precs : dict
            symbol name maps to precision dict
        buffers : dict
            symbol name maps to mrt.V2.Buffer
        cfg_dict : dict
            symbol name maps to configuration dict
        op_input_precs : dict
            symbol name maps to input precision
        restore_names : set
            set of symbol names representing symbols to be restored
        shift_bits : int
            hyperparameter for quantize precision control
        softmax_lambd : float
            hyperparameter for feature optimization

        Notes
        -----
        `features`, `precs`, `buffers` and `cfg_dict` are updated in place
        with entries for the symbols created along the way.
    """
    infer_shapes = infer_shape(symbol, params)

    def restore(op, **kwargs):
        # Rebuild `op` on top of de-scaled inputs, then record its output
        # precision and a buffer of scale 1.
        features = kwargs['features']
        precs = kwargs['precs']
        buffers = kwargs['buffers']
        name = op.attr('name')
        op_name = op.attr('op_name')
        childs = sutils.sym_iter(op.get_children())
        attr = op.list_attr()
        if childs is None:
            childs = []

        def _descale(child):
            # Inputs without a buffer, or with scale 1, pass through.
            cname = child.attr('name')
            if cname not in buffers:
                return child
            scale = buffers[cname].get()
            return child if scale == 1 else child / scale

        inputs = [_descale(c) for c in childs]
        out = sutils.get_mxnet_op(op_name)(*inputs, **attr, name=name)

        ft = features[name]
        assert ft.name == FT_TYPE_EXP
        precs[name][OUT_KEY] = get_bit_exp(ft.get())
        buffers[name] = get_buffer_exp(1)
        return out

    def _quant(op, **kwargs):
        if op.attr('name') in restore_names:
            op = restore(op, **kwargs)
        else:
            quantizer = apply_pass(
                "quantize",
                infer_shapes=kwargs['infer_shapes'],
                features=kwargs['features'],
                cfg_dict=kwargs['cfg_dict'],
            )
            op = quantizer(op, **kwargs)

        if is_var(op, kwargs['params']):
            return op

        name = op.attr('name')
        features = kwargs['features']
        buffers = kwargs['buffers']
        precs = kwargs['precs']

        ft = features[name]
        absmax = ft.get_threshold()
        buf = buffers[name]
        assert buf.name == BUF_TYPE_EXP
        tight_prec = get_bit_exp(absmax * buf.get())

        # No clip is needed unless the recorded precision is wider than
        # the tight one.
        if precs[name][OUT_KEY] <= tight_prec:
            return op

        op = mx.sym.Custom(op, precision=tight_prec,
                           name=N.n('clip'), op_type='cvm_clip')
        clip_name = op.attr('name')
        infer_shapes[clip_name] = infer_shapes[name]
        features[clip_name] = ft
        precs[clip_name] = {OUT_KEY: tight_prec}
        # A requested output precision moves from the clipped symbol to
        # the clip symbol, re-keyed under the clip symbol's name.
        if name in precs and name in precs[name]:
            precs[clip_name][clip_name] = precs[name].pop(name)
        buffers[clip_name] = buf
        cfg_dict[clip_name] = cfg_dict[name]
        return op

    sym, params = topo_visit_transformer(
        symbol, params, _quant,
        infer_shapes=infer_shapes, features=features,
        precs=precs, buffers=buffers, cfg_dict=cfg_dict,
        op_input_precs=op_input_precs,
        shift_bits=shift_bits, softmax_lambd=softmax_lambd)

    def quantize_output(op, **kwargs):
        name = op.attr('name')
        features = kwargs['features']
        precs = kwargs['precs']
        buffers = kwargs['buffers']

        # Only symbols with a requested output precision are requantized.
        if not (name in precs and name in precs[name]):
            return op

        oprec = precs[name][name]
        ft = features[name]
        assert ft.name == FT_TYPE_EXP
        oscale = scale_exp(ft.get(), oprec)
        quant = get_quantizer_exp()
        op, oprec, oscale = quant.quantize(
            op, oprec, oscale=oscale, oname=name, **kwargs)

        oname = op.attr('name')
        features[oname] = features[name]
        precs[oname] = oprec
        buffers[oname] = get_buffer_exp(oscale)
        return op

    return topo_visit_transformer(
        sym, params, quantize_output,
        features=features, precs=precs, buffers=buffers,
        cfg_dict=cfg_dict, shift_bits=shift_bits,
        softmax_lambd=softmax_lambd)
def rewrite(symbol, params):
    """ Run the "rewrite" pass over every symbol of the graph.

        Shapes are inferred once up front and handed to the pass.
        Returns whatever `topo_visit_transformer` returns.
    """
    shapes = infer_shape(symbol, params)
    transformer = apply_pass("rewrite", infer_shapes=shapes)
    return topo_visit_transformer(symbol, params, transformer)
def sym_calibrate(symbol, params, data, cfg_dict, **kwargs):
    """ Customized graph-level topo pass definition.

        Interface for MRT GEN Calibration. The graph is evaluated symbol
        by symbol on `data`; the first output of each symbol is sampled by
        the quantizer configured for it and folded into a feature by the
        configured optimizor.

        Parameters
        ----------
        symbol, params
            Graph and parameters. The parameters go through
            `convert_params_dtype` before evaluation.
        data
            Value used as the output of every input symbol.
        cfg_dict : dict
            Symbol name maps to a dict with the keys "quant_type",
            "opt_info" and "gn_info".
        **kwargs
            Forwarded to `topo_visit_transformer`. `ctx` (default
            `mx.cpu()`) is the context the cached outputs are moved to.

        Returns
        -------
        dict
            Symbol name maps to the feature returned by the optimizor.
    """
    logger = logging.getLogger('log.mrt')
    _, deps = sutils.topo_sort(symbol, logger=logger, with_deps=True)

    features = {}
    out_cache = {}
    ctx = kwargs.get('ctx', mx.cpu())
    logger.info("calibrate model outputs")
    nparams = convert_params_dtype(
        params, src_dtypes="float64", dest_dtype="float32")

    def _impl(op, params, graph, **kwargs):
        deps = kwargs['deps']
        logger = logging.getLogger('log.mrt.calibrate')
        name = op.attr('name')
        op_name = op.attr('op_name')
        childs = sutils.sym_iter(op.get_children())
        attr = op.list_attr()

        op_cfg = cfg_dict[name]
        quant_type = op_cfg["quant_type"]
        opt_info = op_cfg["opt_info"]
        gn_info = op_cfg["gn_info"]
        quantizer = get_quantizer(quant_type)
        optimizor = get_optimizor(opt_info)

        if op_name == 'null':
            # Variables: either the calibration data or a parameter.
            out = data if is_inputs(op, params) else params[name]
        elif childs is None:
            out = get_nd_op(op_name)(**attr)
        else:
            cinfos = [(c.attr('name'), get_entry_id(c)) for c in childs]
            nd_inputs = [out_cache[cname][entry] for cname, entry in cinfos]
            out = get_nd_op(op_name)(*nd_inputs, **attr)
            # Free cached outputs once no other symbol depends on them.
            for cname, _ in cinfos:
                assert cname in deps
                pending = deps[cname]
                if name not in pending:
                    # for op like: op = broadcast_mul(X, X)
                    # `cinfos` will have duplicate entries
                    # avoid removing more than once
                    continue
                pending.remove(name)
                if len(pending) == 0:
                    del out_cache[cname]

        if len(op) == 1:
            out = [out]
        out_cache[name] = [o.as_in_context(ctx) for o in out]

        raw_ft = quantizer.sample(out[0], **gn_info)
        # A feature recorded earlier under this name is passed along.
        hist_ft = features.get(name)
        features[name] = optimizor.get_opt(
            raw_ft, out[0], hist_ft=hist_ft, logger=logger, name=name)

    topo_visit_transformer(symbol, nparams, _impl, logger=logger,
                           deps=deps, data=data, **kwargs)
    out_cache.clear()
    return features