def get_input_info(graph_str: str, params: Dict[str, tvm.nd.NDArray]):
    """Return the 'shape' and 'dtype' dictionaries for the input tensors of a compiled module.

    .. note::
        We can't simply get the input tensors from a TVM graph
        because weight tensors are treated equivalently. Therefore, to
        find the input tensors we look at the 'arg_nodes' in the graph
        (which are either weights or inputs) and check which ones don't
        appear in the params (where the weights are stored). These nodes
        are therefore inferred to be input tensors.

    .. note::
        There exists a more recent API to retrieve the input information
        directly from the module. However, this isn't supported when using
        RPC due to a lack of support for Array and Map datatypes.
        Therefore, this function exists only as a fallback when RPC is in
        use. If RPC isn't being used, please use the more recent API.

    Parameters
    ----------
    graph_str : str
        JSON graph of the module serialized as a string.
    params : dict
        Parameter dictionary mapping name to value.
        NOTE(review): this value is handed straight to ``load_param_dict``,
        which presumably expects the serialized parameter blob rather than
        an already-built dict -- confirm and align the annotation.

    Returns
    -------
    shape_dict : dict
        Shape dictionary - {input_name: shape}. Each shape is taken verbatim
        from the graph JSON, so it is a list rather than a tuple.
    dtype_dict : dict
        dtype dictionary - {input_name: dtype}.
    """
    shape_dict = {}
    dtype_dict = {}
    params_dict = load_param_dict(params)
    # A set makes the per-node membership test below O(1) instead of O(n).
    param_names = {name for name, _ in params_dict.items()}
    graph = json.loads(graph_str)
    # In the TVM graph-executor JSON format the per-tensor attribute arrays
    # ("shape", "dltype") are indexed by output entry id, i.e.
    # node_row_ptr[node_id] + output_index. That only coincides with node_id
    # when every earlier node has a single output, so prefer node_row_ptr
    # when the graph provides it.
    row_ptr = graph.get("node_row_ptr")
    for node_id in graph["arg_nodes"]:
        name = graph["nodes"][node_id]["name"]
        # If a node is not in the params, infer it to be an input node.
        if name in param_names:
            continue
        entry_id = row_ptr[node_id] if row_ptr is not None else node_id
        # attrs entries are indexed at [1]: presumably [type_tag, values] pairs.
        shape_dict[name] = graph["attrs"]["shape"][1][entry_id]
        dtype_dict[name] = graph["attrs"]["dltype"][1][entry_id]
    return shape_dict, dtype_dict
def get_input_info(graph_str: str, params: Dict[str, tvm.nd.NDArray]):
    """Return the 'shape' and 'dtype' dictionaries for the input tensors of a compiled module.

    .. note::
        We can't simply get the input tensors from a TVM graph
        because weight tensors are treated equivalently. Therefore, to
        find the input tensors we look at the 'arg_nodes' in the graph
        (which are either weights or inputs) and check which ones don't
        appear in the params (where the weights are stored). These nodes
        are therefore inferred to be input tensors.

    Parameters
    ----------
    graph_str : str
        JSON graph of the module serialized as a string.
    params : dict
        Parameter dictionary mapping name to value.
        NOTE(review): this value is handed straight to ``load_param_dict``,
        which presumably expects the serialized parameter blob rather than
        an already-built dict -- confirm and align the annotation.

    Returns
    -------
    shape_dict : dict
        Shape dictionary - {input_name: shape}. Each shape is taken verbatim
        from the graph JSON, so it is a list rather than a tuple.
    dtype_dict : dict
        dtype dictionary - {input_name: dtype}.
    """
    shape_dict = {}
    dtype_dict = {}
    params_dict = load_param_dict(params)
    # A set makes the per-node membership test below O(1) instead of O(n).
    param_names = {name for name, _ in params_dict.items()}
    graph = json.loads(graph_str)
    # In the TVM graph-executor JSON format the per-tensor attribute arrays
    # ("shape", "dltype") are indexed by output entry id, i.e.
    # node_row_ptr[node_id] + output_index. That only coincides with node_id
    # when every earlier node has a single output, so prefer node_row_ptr
    # when the graph provides it.
    row_ptr = graph.get("node_row_ptr")
    for node_id in graph["arg_nodes"]:
        name = graph["nodes"][node_id]["name"]
        # If a node is not in the params, infer it to be an input node.
        if name in param_names:
            continue
        entry_id = row_ptr[node_id] if row_ptr is not None else node_id
        # attrs entries are indexed at [1]: presumably [type_tag, values] pairs.
        shape_dict[name] = graph["attrs"]["shape"][1][entry_id]
        dtype_dict[name] = graph["attrs"]["dltype"][1][entry_id]

    logger.debug("Collecting graph input shape and type:")
    logger.debug("Graph input shape: %s", shape_dict)
    logger.debug("Graph input type: %s", dtype_dict)

    return shape_dict, dtype_dict