Ejemplo n.º 1
0
def symbol2nx(graph, model_nodes, model_params, input_names: str = ''):
    """Populate ``graph`` with the nodes and edges of an MXNet symbol.

    Every entry of ``model_nodes`` becomes a graph node under a unique id
    derived from its name.  Afterwards edges are added between them using
    ``get_mxnet_node_edges``.  Nodes backed by a trained parameter or by an
    RNN state receive a constant ``'value'`` array (float32).

    :param graph: graph filled in place; presumably a networkx-style graph
        that also provides ``unique_id`` -- confirm against callers.
    :param model_nodes: re-iterable sequence of node dicts, each with at least
        a ``'name'`` key.  NOTE: the dicts are mutated -- a ``'value'`` key is
        added to parameter and RNN-state nodes.
    :param model_params: object exposing ``_arg_params`` and ``_aux_params``
        mappings from name to arrays that support ``.asnumpy()``.
    :param input_names: comma-separated names of the network inputs.  These
        never receive a parameter value, even if a parameter of the same name
        exists.  When empty, ``'data'`` is used.
    :return: the same ``graph`` object.
    """
    if input_names:
        # Tolerate spaces around commas ("data, label") and empty entries.
        excluded = {name.strip() for name in input_names.split(',') if name.strip()}
    else:
        excluded = {'data'}

    rnn_states = init_rnn_states(model_nodes)
    arg_params = model_params._arg_params
    aux_params = model_params._aux_params

    # MXNet refers to a node's inputs by their position in model_nodes.  Remember
    # which graph node id each position received so edges can be built later.
    index_node_keys = {}
    for i, node in enumerate(model_nodes):
        name = node['name']
        if name not in excluded and name in arg_params:
            node['value'] = np.array(arg_params[name].asnumpy(), dtype=np.float32)
        elif name not in excluded and name in aux_params:
            node['value'] = np.array(aux_params[name].asnumpy(), dtype=np.float32)
        elif name in rnn_states:
            # Match the float32 parameter constants; np.zeros defaults to float64.
            node['value'] = np.zeros(rnn_states[name], dtype=np.float32)
        node_name = graph.unique_id(name)
        graph.add_node(node_name, **symbol_attrs(node))
        # NOTE(review): plain networkx removed `Graph.node` in 2.4; presumably
        # `graph` still provides it -- confirm before upgrading networkx.
        graph.node[node_name].update(common_mxnet_fields(Node(graph, node_name)))
        index_node_keys[i] = node_name

    # Copy the node sequence once instead of once per node (was O(n^2)).
    nodes_list = list(model_nodes)
    for i, node in enumerate(nodes_list):
        edges = get_mxnet_node_edges(node, i, nodes_list, index_node_keys)
        if edges:
            graph.add_edges_from(edges)

    return graph
Ejemplo n.º 2
0
def symbol2nx(graph, model_nodes, model_params, input_names: str = ''):
    """Populate ``graph`` with the nodes and edges of an MXNet symbol.

    Every entry of ``model_nodes`` becomes a graph node under a unique id
    derived from its name.  Edges are then added using
    ``get_mxnet_node_edges``.  Nodes whose index is never reported as used by
    any other node are treated as model outputs.  Each output gets a fake
    Identity output through ``add_outputs_identity`` so that its framework
    tensor name can be stored on an outgoing edge.

    :param graph: graph filled in place; presumably a networkx-style graph
        that also provides ``unique_id`` -- confirm against callers.
    :param model_nodes: re-iterable sequence of node dicts, each with at least
        a ``'name'`` key.  NOTE: the dicts are mutated -- a ``'value'`` key is
        added to parameter and RNN-state nodes.
    :param model_params: object exposing ``_arg_params`` and ``_aux_params``
        mappings from name to arrays that support ``.asnumpy()``.
    :param input_names: comma-separated names of the network inputs.  These
        never receive a parameter value, even if a parameter of the same name
        exists.  When empty, ``'data'`` is used.
    :return: the same ``graph`` object.
    """
    if input_names:
        # Tolerate spaces around commas ("data, label") and empty entries.
        excluded = {
            name.strip() for name in input_names.split(',') if name.strip()
        }
    else:
        excluded = {'data'}

    rnn_states = init_rnn_states(model_nodes)
    arg_params = model_params._arg_params
    aux_params = model_params._aux_params

    # MXNet refers to a node's inputs by their position in model_nodes.  Remember
    # which graph node id each position received so edges can be built later.
    index_node_keys = {}
    # graph node id -> original framework (MXNet) node name
    fw_name_map = {}
    for i, node in enumerate(model_nodes):
        name = node['name']
        if name not in excluded and name in arg_params:
            node['value'] = np.array(arg_params[name].asnumpy(),
                                     dtype=np.float32)
        elif name not in excluded and name in aux_params:
            node['value'] = np.array(aux_params[name].asnumpy(),
                                     dtype=np.float32)
        elif name in rnn_states:
            # Match the float32 parameter constants; np.zeros defaults to
            # float64.
            node['value'] = np.zeros(rnn_states[name], dtype=np.float32)
        node_name = graph.unique_id(name)
        graph.add_node(node_name, **symbol_attrs(node))
        # NOTE(review): plain networkx removed `Graph.node` in 2.4; presumably
        # `graph` still provides it -- confirm before upgrading networkx.
        graph.node[node_name].update(
            common_mxnet_fields(Node(graph, node_name)))
        index_node_keys[i] = node_name
        fw_name_map[node_name] = name

    # Copy the node sequence once instead of once per node (was O(n^2)).
    nodes_list = list(model_nodes)
    used_indices_set = set()
    for i, node in enumerate(nodes_list):
        edges, used_indices = get_mxnet_node_edges(node, i, nodes_list,
                                                   index_node_keys)
        if edges:
            graph.add_edges_from(edges)
        # In-place update instead of building a new set on every iteration.
        used_indices_set.update(used_indices)

    # Indices never reported as used by another node are the model outputs.
    # Sort them so outputs are handled in model order, not set-iteration order.
    output_ids = [
        index_node_keys[node_id]
        for node_id in sorted(set(range(len(nodes_list))) - used_indices_set)
    ]

    # Tensor names for a node are stored on its outgoing edges.  Output nodes
    # have no outgoing edges, so a fake Identity output is added for each one.
    # The output's tensor name is kept on the (output, fake output) edge.  The
    # fake outputs are deleted later, after the transformation that adds Result
    # nodes has run.
    add_outputs_identity(
        graph, output_ids,
        lambda g, output_id, fake_node_id, fw_name: g.add_edges_from([
            create_mxnet_edge(output_id, fake_node_id, 0, 0,
                              fw_name[output_id])
        ]), {'fw_name': fw_name_map})

    return graph