Ejemplo n.º 1
0
 def _impl(op, params, graph, **kwargs):
     """Execute one symbol on real data and collect its output threshold.

     Runs the NDArray implementation of ``op``, caches the outputs in
     ``out_cache`` for downstream consumers, and records the calibration
     threshold for this symbol into ``th_dict``.

     kwargs:
         deps: dict mapping child name -> set/list of pending consumer names;
             a child's cached output is freed once all consumers ran.
         old_ths: optional dict of previously collected thresholds; kept if
             larger than the freshly observed one.
         lambd: parameter forwarded to ``_get_opt`` for threshold selection.
     """
     deps, old_ths = kwargs['deps'], kwargs['old_ths']
     logger = logging.getLogger('log.mrt.calibrate')
     name, op_name = op.attr('name'), op.attr('op_name')
     childs, attr = sym_iter(op.get_children()), op.list_attr()
     if op_name == 'null':
         # variable symbol: either the graph input batch or a parameter
         out = data if is_inputs(op, params) else params[name]
     elif childs is None:
         out = get_nd_op(op_name)(**attr)
     else:
         cinfos = [(c.attr('name'), get_entry_id(c)) for c in childs]
         nd_inputs = [out_cache[n[0]][n[1]] for n in cinfos]
         out = get_nd_op(op_name)(*nd_inputs, **attr)
         for n, _ in cinfos:
             assert n in deps
             if name not in deps[n]:
                 # For ops like broadcast_mul(X, X), `cinfos` contains
                 # duplicate entries; avoid removing the dependency twice.
                 continue
             deps[n].remove(name)
             if len(deps[n]) == 0:
                 # no consumer left: release the cached outputs
                 del out_cache[n]
     out = [out] if len(op) == 1 else out
     out_cache[name] = [o.as_in_context(ctx) for o in out]
     opts = float(_get_opt(out[0], kwargs['lambd']))
     if old_ths and name in old_ths:
         th_dict[name] = max(old_ths[name], opts)
     else:
         th_dict[name] = opts
         # logger.warn is a deprecated alias of logger.warning
         p = logger.debug if opts < 30 else logger.warning
         p("collect symbol %-40s out_shape=%-20s th_dict: (%s)", name,
           [o.shape for o in out], th_dict[name])
Ejemplo n.º 2
0
    def _impl(op, params, graph, **kwargs):
        """Execute one symbol on real data while timing its NDArray op.

        Wall-clock time per operator invocation is accumulated into the
        enclosing ``times[op_name][name]`` table. When ``gpu_flag`` is set,
        ``nd.waitall()`` forces completion of the asynchronous MXNet engine
        so the measured interval covers the actual computation.

        kwargs:
            deps: dict mapping child name -> pending consumer names; a
                child's cached output is freed once all consumers ran.
        """
        deps = kwargs['deps']
        name, op_name = op.attr('name'), op.attr('op_name')
        childs, attr = sym_iter(op.get_children()), op.list_attr()

        if op_name == 'null':
            # variable symbols are lookups, not computations: don't time them
            start_time = None
            out = data if is_inputs(op, params) else params[name]
        elif childs is None:
            start_time = time.time()
            out = get_nd_op(op_name)(**attr)
            if gpu_flag:
                nd.waitall()
            end_time = time.time()
        else:
            cinfos = [(c.attr('name'), get_entry_id(c)) for c in childs]
            nd_inputs = [out_cache[n[0]][n[1]] for n in cinfos]
            start_time = time.time()
            out = get_nd_op(op_name)(*nd_inputs, **attr)
            if gpu_flag:
                nd.waitall()
            end_time = time.time()
            for n, _ in cinfos:
                assert n in deps
                if name not in deps[n]:
                    # For ops like broadcast_mul(X, X), `cinfos` contains
                    # duplicate entries; avoid removing the dependency twice.
                    continue
                deps[n].remove(name)
                if len(deps[n]) == 0:
                    # no consumer left: release the cached outputs
                    del out_cache[n]
        if start_time is not None:
            if op_name not in times:
                times[op_name] = {}
            times[op_name][name] = end_time - start_time
        out = [out] if len(op) == 1 else out
        out_cache[name] = [o.as_in_context(ctx) for o in out]
Ejemplo n.º 3
0
    def _impl(op, params, graph, **kwargs):
        """Execute one symbol on real data, caching outputs for consumers.

        Stores the raw output of the symbol named ``check_point`` into the
        enclosing ``ans`` dict; frees cached child outputs once every one
        of their consumers has run.
        """
        deps = kwargs['deps']
        name, op_name = op.attr('name'), op.attr('op_name')
        childs, attr = sutils.sym_iter(op.get_children()), op.list_attr()

        if op_name == 'null':
            # variable: graph input batch or stored parameter
            if sutils.is_inputs(op, params):
                out = data
            else:
                out = params[name]
        elif childs is None:
            out = sutils.get_nd_op(op_name)(**attr)
        else:
            child_infos = [
                (c.attr('name'), sutils.get_entry_id(c)) for c in childs]
            inputs = [out_cache[cn][eid] for cn, eid in child_infos]
            out = sutils.get_nd_op(op_name)(*inputs, **attr)
            for cn, _ in child_infos:
                assert cn in deps
                if name not in deps[cn]:
                    # op like broadcast_mul(X, X) duplicates entries in
                    # `child_infos`; skip so we don't remove twice
                    continue
                deps[cn].remove(name)
                if len(deps[cn]) == 0:
                    del out_cache[cn]
        if name == check_point:
            ans[check_point] = out
        out = [out] if len(op) == 1 else out
        out_cache[name] = [o.as_in_context(ctx) for o in out]
Ejemplo n.º 4
0
 def _sum_input(node, params, **kwargs):
     name = node.attr('name')
     nonlocal dim_sum, dim_per, dims
     if is_inputs(node, params):
         dims[name] = infer_shapes[name][0]
         dot = np.product(dims[name])
         dim_per[name] = dot
         dim_sum += dot
Ejemplo n.º 5
0
 def _change_node(op, params, graph, **kwargs):
     name = op.attr('name')
     if sutils.is_inputs(op, params):
         nonlocal first, last
         last = first + dim_per[name]
         op = mx.sym.slice(data_sum, name=tfm.N.n('slice'),
                 begin=(first,), end=(last,))
         op = mx.sym.reshape(op, name=tfm.N.n('reshape'),
                 shape=dims[name])
         first = last
     return op
Ejemplo n.º 6
0
    def _impl(op, params, graph):
        """Record the inferred output shape of every symbol.

        For parameters, falls back to the stored NDArray shape when shape
        inference yields nothing and checks consistency between the two.
        For graph inputs, applies the enclosing ``input_shape`` override
        when given. Results are written into ``infer_shapes[name]``.
        """
        name = op.attr('name')
        _, oshp, _ = op.infer_shape()
        if is_params(op, params):
            if oshp is None:
                # shape inference failed; trust the params dict
                oshp = [params[name].shape]
                op = mx.sym.var(name, shape=oshp[0])
            # message was garbled by broken line continuations; rebuilt via
            # adjacent-literal concatenation
            assert params[name].shape == oshp[0], \
                "Parameter %s's shape %s is inconsistent with " \
                "params dict %s" % (name, oshp[0], params[name].shape)
        elif is_inputs(op, params):
            if input_shape is None:
                assert oshp is not None, \
                    "It seems that graph doesn't set input_shape, " \
                    "please invoke attach_input_shape first."
            else:
                oshp = [input_shape]
                op = mx.sym.var(name, shape=oshp[0])
        infer_shapes[name] = oshp
        return op
Ejemplo n.º 7
0
 def input_names(self):
     """Return the names of the model's input symbols, in topological order."""
     names = []
     for sym in topo_sort(self.symbol):
         if sutils.is_inputs(sym, self.params):
             names.append(sym.attr("name"))
     return names
Ejemplo n.º 8
0
def compile_to_cvm(model,
                   model_name,
                   datadir="/data/std_out",
                   input_shape=None,
                   target="cuda"):
    """ Compile Mxnet model into CVM Accept-JSON&BIN-Format

    Pipeline: (1) convert the MXNet symbol to NNVM, checking every
    parameter fits int32; (2) build the NNVM graph; (3) narrow parameter
    dtypes to int8 where their recorded precision allows; (4) dump the
    graph JSON and serialized params under ``datadir/model_name``.

    Parameters:
        model: object exposing ``.symbol`` and ``.params`` (MXNet model).
        model_name: subdirectory name used for the dump.
        datadir: root output directory; created if missing.
        input_shape: shape of the 'data' input; inferred from the symbol
            when None.
        target: TVM build target string (e.g. "cuda").

    Returns:
        (deploy_graph, nnvm_params) from the NNVM build, with params
        precision-reduced.
    """
    logger = logging.getLogger("mrt.compile")
    symbol, params = model.symbol, model.params

    datadir = path.join(datadir, model_name)
    os.makedirs(datadir, exist_ok=True)

    # transform from mxnet symbol to cvm
    logger.info("Transform Mxnet symbol into CVM")
    nnvm_sym, _ = to_nnvm(symbol, params)
    dtype, nnvm_params = "int32", {}
    tvm_ctx = tvm.context(target, 0)
    for sym in topo_sort(symbol):
        if sutils.is_params(sym, params):
            key, value = sym.attr('name'), params[sym.attr('name')]
            flat = value.asnumpy()
            # every quantized parameter must be representable in int32
            # without loss (magnitude bound + exact round-trip check)
            assert np.abs(flat).max() <= sutils.INT32_MAX, \
                "key: {}\nvalue: {}".format(key, value)
            assert (flat.astype(dtype).astype("float64") == flat).all(), \
                "key: {}\nvalue: {}".format(key, value)
            nnvm_params[key] = tvm.nd.array(flat.astype(dtype), tvm_ctx)

    # compile to JSON&Bytes format
    # graph = nnvm.graph.create(nnvm_sym)
    # open("/tmp/tmp.nnvm.json", "w").write(graph.json())
    logger.info("Compile into CVM graph")
    if input_shape is None:
        # fall back to the shape recorded on the graph's input variable
        for sym in topo_sort(symbol):
            if sutils.is_inputs(sym, params):
                _, oshp, _ = sym.infer_shape()
                assert len(oshp) == 1
                input_shape = oshp[0]
    input_shapes = {'data': input_shape}
    with nnvm.compiler.build_config(opt_level=0):
        deploy_graph, _, nnvm_params = nnvm.compiler.build(nnvm_sym,
                                                           target=target,
                                                           shape=input_shapes,
                                                           params=nnvm_params,
                                                           dtype=dtype)

    # tvm parameters reduce
    logger.info("Parameters precision reduce")
    for sym in topo_sort(nnvm_sym):
        if sutils.is_params(sym, nnvm_params):
            name, attr = sym.attr('name'), sym.list_attr()
            precision = sutils.get_attr(attr, "precision")
            # values proven <= 8-bit precision are stored as int8 to
            # shrink the dumped params file
            dtype = "int32" if precision > 8 else "int8"
            # NOTE(review): this reads from the original mxnet `params`
            # dict, not the post-build `nnvm_params` — presumably the
            # names coincide; verify, otherwise nnvm_params[name] seems
            # intended here.
            nnvm_params[name] = tvm.nd.array(
                params[name].asnumpy().astype(dtype), tvm_ctx)

    # dump
    logger.info("CVM Json&Params dump")
    with open(path.join(datadir, "symbol"), "w") as fout:
        fout.write(deploy_graph.json())
    param_bytes = nnvm.compiler.save_param_dict(nnvm_params)
    with open(path.join(datadir, "params"), "wb") as fout:
        fout.write(param_bytes)
    return deploy_graph, nnvm_params
Ejemplo n.º 9
0
 def _impl(op, params, graph):
     """Re-declare input variables with their shapes from ``input_shapes``.

     Leaves parameters and inputs without a recorded shape untouched.
     """
     name, attr = op.attr('name'), op.list_attr()
     if not is_inputs(op, params):
         return op
     if name in input_shapes:
         op = mx.sym.var(name, shape=input_shapes[name], attr=attr)
     return op
Ejemplo n.º 10
0
 def _name_replace(op, params, graph):
     """Rename every graph input variable to the canonical name "data"."""
     attr = op.list_attr()
     if not is_inputs(op, params):
         return op
     return mx.sym.var("data", attr=attr)
Ejemplo n.º 11
0
 def _count(op, params, graph):
     nonlocal input_count
     input_count += is_inputs(op, params)