def symbol2nx(graph, model_nodes, model_params, input_names: str = ''):
    """Populate ``graph`` with the nodes and edges of an MXNet symbol.

    NOTE(review): this file defines ``symbol2nx`` a second time further down;
    when the module is executed that later definition rebinds the name, so
    this version is shadowed. Unlike the later one, it does not attach fake
    outputs to nodes that no other node consumes.

    For every entry of ``model_nodes`` a graph node is added. Parameter values
    from ``model_params`` (arg params first, then aux params) are attached to
    the node dict as ``'value'`` (float32) unless the node's name is listed in
    ``input_names``; nodes reported by ``init_rnn_states`` get a zero-filled
    ``'value'``. Edges are then created from each node's inputs.

    :param graph: graph to fill; node names are produced by
        ``graph.unique_id`` and nodes are added with ``graph.add_node``.
    :param model_nodes: sequence of MXNet node dicts; each is read for
        ``'name'`` and may be updated in place with ``'value'``.
    :param model_params: object exposing ``_arg_params`` and ``_aux_params``
        mappings from name to an array-like with ``asnumpy()``.
    :param input_names: comma-separated names of model inputs whose values must
        not be taken from ``model_params``; ``'data'`` when empty.
    :return: the same ``graph`` object.
    """
    if not input_names:
        input_names = ('data',)
    else:
        input_names = input_names.split(',')

    rnn_states = init_rnn_states(model_nodes)

    # MXNet refers to a node's inputs by their position in the node list, so
    # record which graph node name each position was given.
    index_node_keys = {}
    for i, node in enumerate(model_nodes):
        name = node['name']
        is_input = name in input_names
        if name in model_params._arg_params and not is_input:
            node['value'] = np.array(model_params._arg_params[name].asnumpy(), dtype=np.float32)
        elif name in model_params._aux_params and not is_input:
            node['value'] = np.array(model_params._aux_params[name].asnumpy(), dtype=np.float32)
        elif name in rnn_states:
            # presumably rnn_states maps state name -> shape; TODO confirm
            node['value'] = np.zeros(rnn_states[name])
        node_name = graph.unique_id(name)
        graph.add_node(node_name, **symbol_attrs(node))
        graph.node[node_name].update(common_mxnet_fields(Node(graph, node_name)))
        index_node_keys[i] = node_name

    # Copy the node list once instead of once per node (was O(n^2)).
    # Assumes get_mxnet_node_edges only reads nodes_list - TODO confirm.
    nodes_list = list(model_nodes)
    for i, node in enumerate(nodes_list):
        # get_mxnet_node_edges yields (edges, used_indices), as unpacked by the
        # later definition of this function; only the edges are needed here.
        edges, _ = get_mxnet_node_edges(node, i, nodes_list, index_node_keys)
        if edges:
            graph.add_edges_from(edges)
    return graph
def symbol2nx(graph, model_nodes, model_params, input_names: str = ''):
    """Populate ``graph`` with the nodes and edges of an MXNet symbol.

    For every entry of ``model_nodes`` a graph node is added. Parameter values
    from ``model_params`` (arg params first, then aux params) are attached to
    the node dict as ``'value'`` (float32) unless the node's name is listed in
    ``input_names``; nodes reported by ``init_rnn_states`` get a zero-filled
    ``'value'``. Edges are then created from each node's inputs. Nodes whose
    index is never used as another node's input are treated as outputs: each
    gets a fake Identity consumer so that its framework (MXNet) name can be
    kept on the connecting edge.

    :param graph: graph to fill; node names are produced by
        ``graph.unique_id`` and nodes are added with ``graph.add_node``.
    :param model_nodes: sequence of MXNet node dicts; each is read for
        ``'name'`` and may be updated in place with ``'value'``.
    :param model_params: object exposing ``_arg_params`` and ``_aux_params``
        mappings from name to an array-like with ``asnumpy()``.
    :param input_names: comma-separated names of model inputs whose values must
        not be taken from ``model_params``; ``'data'`` when empty.
    :return: the same ``graph`` object.
    """
    if not input_names:
        input_names = ('data',)
    else:
        input_names = input_names.split(',')

    rnn_states = init_rnn_states(model_nodes)

    # MXNet refers to a node's inputs by their position in the node list, so
    # record which graph node name each position was given.
    index_node_keys = {}
    # Graph node name -> original MXNet node name.
    fw_name_map = {}
    for i, node in enumerate(model_nodes):
        name = node['name']
        is_input = name in input_names
        if name in model_params._arg_params and not is_input:
            node['value'] = np.array(model_params._arg_params[name].asnumpy(), dtype=np.float32)
        elif name in model_params._aux_params and not is_input:
            node['value'] = np.array(model_params._aux_params[name].asnumpy(), dtype=np.float32)
        elif name in rnn_states:
            # presumably rnn_states maps state name -> shape; TODO confirm
            node['value'] = np.zeros(rnn_states[name])
        node_name = graph.unique_id(name)
        graph.add_node(node_name, **symbol_attrs(node))
        # NOTE(review): `graph.node` was removed in NetworkX 2.4; switch to
        # `graph.nodes` if the project's NetworkX version allows it.
        graph.node[node_name].update(common_mxnet_fields(Node(graph, node_name)))
        index_node_keys[i] = node_name
        fw_name_map[node_name] = name

    # Copy the node list once instead of once per node (was O(n^2)).
    # Assumes get_mxnet_node_edges only reads nodes_list - TODO confirm.
    nodes_list = list(model_nodes)
    used_indices_set = set()
    for i, node in enumerate(nodes_list):
        edges, used_indices = get_mxnet_node_edges(node, i, nodes_list, index_node_keys)
        if edges:
            graph.add_edges_from(edges)
        # Grow the set in place rather than allocating a new one each time.
        used_indices_set.update(used_indices)

    # Indices never consumed as an input are outputs; walk them in ascending
    # order so the output order does not depend on set iteration order.
    output_ids = [
        index_node_keys[node_id]
        for node_id in range(len(nodes_list))
        if node_id not in used_indices_set
    ]

    # Tensor names for a node are stored on its outgoing edges. Output nodes
    # have none, so an Identity fake output is attached to each output and the
    # output's tensor name is kept on the (output, fake output) edge. The fake
    # outputs are deleted from the graph after the transformation that adds
    # Result nodes.
    add_outputs_identity(
        graph,
        output_ids,
        lambda g, output_id, fake_node_id, fw_name: g.add_edges_from([
            create_mxnet_edge(output_id, fake_node_id, 0, 0, fw_name[output_id])
        ]),
        {'fw_name': fw_name_map},
    )
    return graph