def _impl(op, params, graph, **kwargs):
    """Collect the calibration threshold for one symbol.

    Runs the operator forward using cached child outputs, stores the
    result in the closure-level ``out_cache`` and updates ``th_dict``
    with the sampled output threshold for this node.

    Fixes vs. original: uses ``logger.warning`` (``warn`` is a
    deprecated alias) and guards against duplicate child entries, in
    line with the sibling checkpoint variant of this visitor.
    """
    deps, old_ths = kwargs['deps'], kwargs['old_ths']
    logger = logging.getLogger('log.mrt.calibrate')
    name, op_name = op.attr('name'), op.attr('op_name')
    childs, attr = sym_iter(op.get_children()), op.list_attr()
    if op_name == 'null':
        # variable node: either a graph input (feed the calibration
        # batch) or a stored parameter
        out = data if is_inputs(op, params) else params[name]
    elif childs is None:
        out = get_nd_op(op_name)(**attr)
    else:
        cinfos = [(c.attr('name'), get_entry_id(c)) for c in childs]
        nd_inputs = [out_cache[n[0]][n[1]] for n in cinfos]
        out = get_nd_op(op_name)(*nd_inputs, **attr)
        for n, _ in cinfos:
            assert n in deps
            if name not in deps[n]:
                # for ops like broadcast_mul(X, X), `cinfos` contains
                # duplicate entries; avoid removing more than once
                continue
            deps[n].remove(name)
            if len(deps[n]) == 0:
                # no remaining consumers: release the cached outputs
                del out_cache[n]
    out = [out] if len(op) == 1 else out
    out_cache[name] = [o.as_in_context(ctx) for o in out]
    opts = float(_get_opt(out[0], kwargs['lambd']))
    if old_ths and name in old_ths:
        # keep the larger of the previously recorded and new threshold
        th_dict[name] = max(old_ths[name], opts)
    else:
        th_dict[name] = opts
    # unusually large thresholds are suspicious: surface as warnings
    p = logger.debug if opts < 30 else logger.warning
    p("collect symbol %-40s out_shape=%-20s th_dict: (%s)",
      name, [o.shape for o in out], th_dict[name])
def _impl(op, params, graph, **kwargs):
    """Forward one operator while profiling its wall-clock time.

    Variable nodes are not timed; computed ops are timed and the
    duration is accumulated into the closure-level ``times`` dict as
    ``times[op_name][name]``.

    Fix vs. original: guards against duplicate child entries (ops such
    as broadcast_mul(X, X)), matching the checkpoint variant of this
    visitor.
    """
    deps = kwargs['deps']
    name, op_name = op.attr('name'), op.attr('op_name')
    childs, attr = sym_iter(op.get_children()), op.list_attr()
    if op_name == 'null':
        start_time = None
        out = data if is_inputs(op, params) else params[name]
    elif childs is None:
        start_time = time.time()
        out = get_nd_op(op_name)(**attr)
        if gpu_flag:
            # block until async GPU execution finishes so the
            # measured duration is the real compute time
            nd.waitall()
        end_time = time.time()
    else:
        cinfos = [(c.attr('name'), get_entry_id(c)) for c in childs]
        nd_inputs = [out_cache[n[0]][n[1]] for n in cinfos]
        start_time = time.time()
        out = get_nd_op(op_name)(*nd_inputs, **attr)
        if gpu_flag:
            nd.waitall()
        end_time = time.time()
        for n, _ in cinfos:
            assert n in deps
            if name not in deps[n]:
                # for ops like broadcast_mul(X, X), `cinfos` contains
                # duplicate entries; avoid removing more than once
                continue
            deps[n].remove(name)
            if len(deps[n]) == 0:
                # no remaining consumers: release the cached outputs
                del out_cache[n]
    if start_time is not None:
        if op_name not in times:
            times[op_name] = {}
        times[op_name][name] = end_time - start_time
    out = [out] if len(op) == 1 else out
    out_cache[name] = [o.as_in_context(ctx) for o in out]
def _impl(op, params, graph, **kwargs):
    """Evaluate one operator, tracking the checkpoint output.

    Forwards the node with cached child outputs, stores the result in
    the closure-level ``out_cache`` and, when the node is the requested
    ``check_point``, records its raw output in ``ans``.
    """
    deps = kwargs['deps']
    name = op.attr('name')
    op_name = op.attr('op_name')
    childs = sutils.sym_iter(op.get_children())
    attr = op.list_attr()
    if op_name == 'null':
        # variable node: graph input gets the batch, otherwise a weight
        if sutils.is_inputs(op, params):
            out = data
        else:
            out = params[name]
    elif childs is None:
        out = sutils.get_nd_op(op_name)(**attr)
    else:
        cinfos = [(c.attr('name'), sutils.get_entry_id(c)) for c in childs]
        nd_inputs = [out_cache[cname][eid] for cname, eid in cinfos]
        out = sutils.get_nd_op(op_name)(*nd_inputs, **attr)
        for cname, _ in cinfos:
            assert cname in deps
            # for op like: op = broadcast_mul(X, X)
            # `cinfos` will have duplicate entries
            # avoid removing more than once
            if name in deps[cname]:
                deps[cname].remove(name)
                if not deps[cname]:
                    # last consumer consumed: drop the cached outputs
                    del out_cache[cname]
    if name == check_point:
        ans[check_point] = out
    out = out if len(op) > 1 else [out]
    out_cache[name] = [o.as_in_context(ctx) for o in out]
def _sum_input(node, params, **kwargs): name = node.attr('name') nonlocal dim_sum, dim_per, dims if is_inputs(node, params): dims[name] = infer_shapes[name][0] dot = np.product(dims[name]) dim_per[name] = dot dim_sum += dot
def _change_node(op, params, graph, **kwargs): name = op.attr('name') if sutils.is_inputs(op, params): nonlocal first, last last = first + dim_per[name] op = mx.sym.slice(data_sum, name=tfm.N.n('slice'), begin=(first,), end=(last,)) op = mx.sym.reshape(op, name=tfm.N.n('reshape'), shape=dims[name]) first = last return op
def _impl(op, params, graph):
    """Attach inferred shapes to variable nodes and record all shapes.

    Re-creates parameter/input ``var`` symbols with an explicit
    ``shape`` attribute and stores each node's inferred output shape in
    the closure-level ``infer_shapes`` dict.
    """
    name, op_name = op.attr('name'), op.attr('op_name')
    _, oshp, _ = op.infer_shape()
    if is_params(op, params):
        if oshp is None:
            # shape inference failed: fall back to the stored
            # parameter's shape and re-declare the var with it
            oshp = [params[name].shape]
            op = mx.sym.var(name, shape=oshp[0])
        # declared shape must agree with the params dict entry
        assert params[name].shape == oshp[0], \
            "Parameter %s's shape %s is inconsistent with \
            params dict %s" % (name, oshp[0], params[name].shape)
    elif is_inputs(op, params):
        if input_shape is None:
            # no explicit input_shape configured: inference must have
            # produced one (attach_input_shape sets it otherwise)
            assert oshp is not None, "It seems that graph doesn't set \
                input_shape, please invoke attach_input_shape first."
        else:
            oshp = [input_shape]
        op = mx.sym.var(name, shape=oshp[0])
    infer_shapes[name] = oshp
    return op
def input_names(self):
    """ List model input names. """
    names = []
    for sym in topo_sort(self.symbol):
        if sutils.is_inputs(sym, self.params):
            names.append(sym.attr("name"))
    return names
def compile_to_cvm(model, model_name, datadir="/data/std_out",
                   input_shape=None, target="cuda"):
    """ Compile Mxnet model into CVM Accept-JSON&BIN-Format

    Converts the mxnet symbol/params into an nnvm graph, checks all
    parameters fit int32 exactly, builds the CVM deploy graph, reduces
    parameter precision per the annotated "precision" attribute, and
    dumps the graph JSON plus serialized params under
    ``datadir/model_name``.
    """
    logger = logging.getLogger("mrt.compile")
    symbol, params = model.symbol, model.params
    datadir = path.join(datadir, model_name)
    os.makedirs(datadir, exist_ok=True)
    # transform from mxnet symbol to cvm
    logger.info("Transform Mxnet symbol into CVM")
    nnvm_sym, _ = to_nnvm(symbol, params)
    dtype, nnvm_params = "int32", {}
    tvm_ctx = tvm.context(target, 0)
    for sym in topo_sort(symbol):
        if sutils.is_params(sym, params):
            key, value = sym.attr('name'), params[sym.attr('name')]
            flat = value.asnumpy()
            # quantized parameters must fit int32 without loss
            assert np.abs(flat).max() <= sutils.INT32_MAX, \
                "key: {}\nvalue: {}".format(key, value)
            # casting to int32 and back must be exact (integral values)
            assert (flat.astype(dtype).astype("float64") == flat).all(), \
                "key: {}\nvalue: {}".format(key, value)
            nnvm_params[key] = tvm.nd.array(flat.astype(dtype), tvm_ctx)
    # compile to JSON&Bytes format
    # graph = nnvm.graph.create(nnvm_sym)
    # open("/tmp/tmp.nnvm.json", "w").write(graph.json())
    logger.info("Compile into CVM graph")
    if input_shape is None:
        # derive the input shape from the graph's (single) input node
        for sym in topo_sort(symbol):
            if sutils.is_inputs(sym, params):
                _, oshp, _ = sym.infer_shape()
                assert len(oshp) == 1
                input_shape = oshp[0]
    input_shapes = {'data': input_shape}
    with nnvm.compiler.build_config(opt_level=0):
        deploy_graph, _, nnvm_params = nnvm.compiler.build(
            nnvm_sym, target=target, shape=input_shapes,
            params=nnvm_params, dtype=dtype)
    # tvm parameters reduce
    logger.info("Parameters precision reduce")
    for sym in topo_sort(nnvm_sym):
        if sutils.is_params(sym, nnvm_params):
            name, attr = sym.attr('name'), sym.list_attr()
            precision = sutils.get_attr(attr, "precision")
            # NOTE(review): `dtype` is rebound per-parameter here,
            # shadowing the "int32" used during build above
            dtype = "int32" if precision > 8 else "int8"
            # NOTE(review): reads `params[name]` (the original mxnet
            # params), not `nnvm_params[name]` — presumably every nnvm
            # param name also exists in `params`; verify against
            # to_nnvm's naming, else this raises KeyError
            nnvm_params[name] = tvm.nd.array(
                params[name].asnumpy().astype(dtype), tvm_ctx)
    # dump
    logger.info("CVM Json&Params dump")
    with open(path.join(datadir, "symbol"), "w") as fout:
        fout.write(deploy_graph.json())
    param_bytes = nnvm.compiler.save_param_dict(nnvm_params)
    with open(path.join(datadir, "params"), "wb") as fout:
        fout.write(param_bytes)
    return deploy_graph, nnvm_params
def _impl(op, params, graph):
    """Re-declare graph inputs with their configured shapes.

    Inputs listed in the closure-level ``input_shapes`` dict are
    replaced by a fresh ``var`` carrying the explicit shape; all other
    nodes pass through unchanged.
    """
    name, attr = op.attr('name'), op.list_attr()
    if not (is_inputs(op, params) and name in input_shapes):
        return op
    return mx.sym.var(name, shape=input_shapes[name], attr=attr)
def _name_replace(op, params, graph):
    """Rewrite each graph input as a var named "data", keeping attrs."""
    attr = op.list_attr()
    if not is_inputs(op, params):
        return op
    return mx.sym.var("data", attr=attr)
def _count(op, params, graph): nonlocal input_count input_count += is_inputs(op, params)