Example #1
    def add_reshape_after_data_node(graph: Graph, data_node_name: str):
        """
        Adds a Reshape operation that changes the shape of the tensor produced by TFSubgraphCall from 4D to the real
        dimensions of the tensor. The data_node_name node contains the real dimensions of the tensor, but they will be
        changed to 4D in the add_reshapes_for_tf_subgraph_calls function because the IE TF call layer supports 4D
        output only.
        :param graph: graph to operate on.
        :param data_node_name: name of the data node to be reshaped to correct dimensions.
        :return: None
        """
        data_node = Node(graph, data_node_name)

        # if the data node was previously marked as an output, the new reshaped data node must be marked as an output instead
        is_out_node = False
        if len(data_node.out_nodes()) == 1 and data_node.out_node().has('op') and data_node.out_node().op == 'OpOutput':
            is_out_node = True
            graph.remove_node(data_node.out_node().id)

        # save old consumer nodes together with their edge attributes
        old_consumer_nodes_with_attrs = list()
        for out_op in data_node.out_nodes():
            edge_attrs = graph.get_edge_data(data_node_name, out_op.name)[0]
            old_consumer_nodes_with_attrs.append((out_op.name, edge_attrs))

        # remove old consumers from the data node
        for out_op in list(data_node.out_nodes()):
            graph.remove_edge(data_node_name, out_op.name)

        # reshape operation node
        reshape_node_name = graph.unique_id("Reshape_")
        graph.add_node(reshape_node_name, kind='op', precision="FP32", type='Reshape', name=reshape_node_name,
                       op='Reshape',
                       data_type=data_node['data_type'])
        update_ie_fields(graph.node[reshape_node_name])

        # reshape shape data node
        reshape_shape_data_node_name = graph.unique_id("Reshape_shape_")
        graph.add_node(reshape_shape_data_node_name, kind='data', precision="FP32", name=reshape_shape_data_node_name,
                       value=np.array(data_node['shape']), shape=[1])

        # reshaped data node
        reshaped_value = None
        if data_node['value'] is not None:
            reshaped_value = np.array(data_node['value'])
        reshaped_data_node_name = graph.unique_id("reshaped_data_")
        graph.add_node(reshaped_data_node_name, kind='data', precision="FP32", name=reshaped_data_node_name,
                       shape=np.array(data_node['shape']), value=reshaped_value, nchw_layout=True)

        if is_out_node:
            add_opoutput(graph, reshaped_data_node_name, 0, False)

        graph.add_edges_from([
            (data_node_name, reshape_node_name, {'in': 0}),
            (reshape_shape_data_node_name, reshape_node_name, {'in': 1}),
            (reshape_node_name, reshaped_data_node_name, {'out': 0}),
        ])

        for out_node_name, edge_attrs in old_consumer_nodes_with_attrs:
            graph.add_edges_from([
                (reshaped_data_node_name, out_node_name, edge_attrs)
            ])
Example #2
    def infer(node: Node):
        # there are limitations coming from ONNX LSTM definition and normalization rules
        assert len(node.in_nodes()) >= 3  # X, W and R
        assert len(node.in_nodes()) <= 7
        assert len(node.out_nodes()) <= 3
        assert node.batch_dim <= 1
        assert node.sequence_dim <= 1
        assert node.batch_dim != node.sequence_dim

        assert node.direction in ['forward', 'reverse', 'bidirectional']

        if node.blobs_wrb:
            mark_input_bins(node, ['W', 'R', 'B'])
        else:
            mark_input_bins(node)
        input_shape = node.in_node(0).shape
        assert len(input_shape) == 3

        for port in [2, 3]:
            if port in node.in_nodes() and len(node.in_node(port).in_nodes()) > 0 and \
               'zero_shapes' in node.in_node(port).in_node():
                for i in node.in_node(port).in_node().zero_shapes:
                    if node.in_node(port).shape[i] != input_shape[i]:
                        node.in_node(port).value = np.repeat(
                            node.in_node(port).value, input_shape[i], axis=i)
                        node.in_node(port).shape[i] = input_shape[i]

        out_shape = np.array([input_shape[node.sequence_dim], input_shape[node.batch_dim], node.hidden_size],
                             dtype=np.int64)
        assert not node.has_num_directions or node.sequence_dim == 0, \
            'If has_num_directions == True, then node.sequence_dim should be equal to 0, but it is {}'.format(
                node.sequence_dim)
        num_directions = 2 if node.direction == 'bidirectional' else 1
        num_layers = node.num_layers
        if node.has_num_directions:
            # insert extra dimension to output shape for num_directions
            out_shape = np.insert(out_shape, 1, np.int64(num_directions))
        node.out_node(0).shape = out_shape
        # extra outputs for hidden/cell states
        state_size = np.array([input_shape[1], node.hidden_size], dtype=np.int64)
        if node.has_num_directions:
            state_size = np.insert(state_size, 0, num_directions * num_layers)
        for i in [1, 2]:
            if i not in node.out_nodes():
                data_node = Op._create_data_node(node.graph,
                                                 name=node.node + '/ExtraOutput/' + str(i),
                                                 attrs={'executable': True})
                node.graph.add_edge(node.id, data_node.id, key=0, out=i)
                add_opoutput(node.graph, data_node.id, 0, False)
            else:
                data_node = node.out_node(i)
            data_node.shape = state_size.copy()
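
The shape bookkeeping in this infer function can be checked in isolation. A small numpy sketch with made-up dimensions (seq_len=10, batch=4, input_size=64, hidden_size=128, bidirectional, one layer); only the arithmetic is taken from the code above:

import numpy as np

input_shape = np.array([10, 4, 64], dtype=np.int64)  # [seq_len, batch, input_size]
sequence_dim, batch_dim, hidden_size = 0, 1, 128
num_directions, num_layers = 2, 1

out_shape = np.array([input_shape[sequence_dim], input_shape[batch_dim], hidden_size], dtype=np.int64)
out_shape = np.insert(out_shape, 1, np.int64(num_directions))  # has_num_directions case
print(out_shape)  # [ 10   2   4 128]

state_size = np.array([input_shape[1], hidden_size], dtype=np.int64)
state_size = np.insert(state_size, 0, num_directions * num_layers)
print(state_size)  # [  2   4 128]
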
Example #3
def update_body_graph(body_graph: Graph, subgraph_proto: dict,
                      body_parameter_names: list, body_results: list):
    """
    Updates the loop body graph with a sub-graph (for body or condition functions)
    :param body_graph: a loop body graph to be updated
    :param subgraph_proto: a sub-graph in protobuf format to be added to the loop body graph
    :param body_parameter_names: an (unchanged) list of parameters in the loop body graph
    :param body_results: a list of Result nodes that is extended with Result nodes from the sub-graph
    """
    # create a map from a node name in the original model to a name in the loop body graph, assuming
    # that names in the original model are unique
    # initially, the map contains names for parameters that are common to the body and condition graphs
    map_original_name = {}
    for idx, pb_node in enumerate(subgraph_proto['input_arg']):
        map_original_name[pb_node.name] = body_parameter_names[idx]

    # walk through all non-parameter and non-result nodes and add them to the loop body graph
    for pb_node in subgraph_proto['node_def']:
        # create an NX node
        id = body_graph.unique_id(pb_node.name)
        map_original_name[pb_node.name] = id
        body_graph.add_node(id, pb=pb_node, kind='op')
        if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op'):
            body_graph.op_names_statistic[pb_node.op] += 1

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            orig_src_id = inp.split(":")[0]

            # TODO: avoid this temporal workaround for TF 2.4 or higher RNN layers:
            #  skip control flow dependency
            if orig_src_id[0] == '^':
                continue

            src_id = map_original_name[orig_src_id]
            src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
            assert body_graph.has_node(src_id)

            body_graph.add_edges_from([create_tf_edge(src_id + ":" + str(src_port), id, dst_port)])

    # create Result nodes in the loop body graph
    for output in subgraph_proto['output_arg']:
        output_name = subgraph_proto['ret'][output.name]
        orig_src_id = output_name.split(":")[0]
        src_id = map_original_name[orig_src_id]
        src_port = 0 if len(output_name.split(":")) == 1 else int(output_name.split(":")[-1])
        assert body_graph.has_node(src_id), \
            'The body graph does not contain output with name "{}"'.format(src_id)
        body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))

    return True
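
The "name[:port]" parsing repeated above is worth isolating. A self-contained sketch of the same split logic (the tensor names are made up):

def split_tensor_name(inp: str):
    # "add" -> ("add", 0); "while/body:2" -> ("while/body", 2)
    node_name = inp.split(":")[0]
    port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
    return node_name, port

assert split_tensor_name("add") == ("add", 0)        # default output port is 0
assert split_tensor_name("while/body:2") == ("while/body", 2)
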
Example #4
def update_body_graph(body_graph: Graph, subgraph_proto: dict,
                      body_parameter_names: list, body_results: list):
    """
    Updates the loop body graph with a sub-graph (for body or condition functions)
    :param body_graph: a loop body graph to be updated
    :param subgraph_proto: a sub-graph in protobuf format to be added to the loop body graph
    :param body_parameter_names: an (unchanged) list of parameters in the loop body graph
    :param body_results: a list of Result nodes that is extended with Result nodes from the sub-graph
    """
    # create a map from a node name in the original model to a name in the loop body graph, assuming
    # that names in the original model are unique
    # initially, the map contains names for parameters that are common to the body and condition graphs
    map_original_name = {}
    for idx, pb_node in enumerate(subgraph_proto['input_arg']):
        map_original_name[pb_node.name] = body_parameter_names[idx]

    # walk through all non-parameter and non-result nodes and add them to the loop body graph
    for pb_node in subgraph_proto['node_def']:
        # create an NX node
        id = body_graph.unique_id(pb_node.name)
        map_original_name[pb_node.name] = id
        body_graph.add_node(id, pb=pb_node, kind='op')

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            orig_src_id = inp.split(":")[0]
            src_id = map_original_name[orig_src_id]
            src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
            assert body_graph.has_node(src_id)
            edge_attrs = {
                'out': src_port,
                'in': dst_port,
                'name': src_id,
                'fw_tensor_debug_info': [(src_id, src_port)],
                'in_attrs': ['in', 'name'],
                'out_attrs': ['out', 'name'],
                'data_attrs': ['fw_tensor_debug_info']
            }
            body_graph.add_edge(src_id, id, **edge_attrs)

    # create Result nodes in the loop body graph
    for output in subgraph_proto['output_arg']:
        output_name = subgraph_proto['ret'][output.name]
        orig_src_id = output_name.split(":")[0]
        src_id = map_original_name[orig_src_id]
        src_port = 0 if len(output_name.split(":")) == 1 else int(output_name.split(":")[-1])
        assert body_graph.has_node(src_id), \
            'The body graph does not contain output with name "{}"'.format(src_id)
        body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))
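
This variant builds the front-phase edge attributes inline instead of calling create_tf_edge. The convention is ordinary MultiDiGraph edge data, as this hedged sketch shows (node names and ports are illustrative):

import networkx as nx

g = nx.MultiDiGraph()
g.add_node('producer', kind='op')
g.add_node('consumer', kind='op')
edge_attrs = {
    'out': 0, 'in': 1, 'name': 'producer',
    'fw_tensor_debug_info': [('producer', 0)],
    'in_attrs': ['in', 'name'], 'out_attrs': ['out', 'name'],
    'data_attrs': ['fw_tensor_debug_info'],
}
g.add_edge('producer', 'consumer', **edge_attrs)
print(g.get_edge_data('producer', 'consumer')[0]['fw_tensor_debug_info'])  # [('producer', 0)]
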
Example #5
    def replace_pattern(self, graph: Graph, match: dict):
        if match['rnn_layer']['op'] == 'LSTM':
            return

        rnn_layer = match['rnn_layer']

        # Build TensorIterator body first
        body = Graph(name=rnn_layer.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        inputs = [
            Op._create_data_node(
                body, rnn_layer.name + '/inport/' + str(inp), {
                    'shape': rnn_layer.in_node(inp).shape.copy(),
                    'value': rnn_layer.in_node(inp).value.copy()
                    if rnn_layer.in_node(inp).value is not None and inp in [1, 2] else None
                }) for inp in [0, 4, 1, 2]
        ]  # X, h_init, WR, B

        inputs[0].shape[rnn_layer.sequence_dim] = 1
        input_squeeze = Squeeze(body, dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0))
        input_squeeze_dim = Const(body, dict(name=rnn_layer.name + '/input_squeeze_dim',
                                             value=rnn_layer.sequence_dim)).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data([inputs[0], input_squeeze_dim],
                                                        edge_attrs=[{'internal_port_id': 0}])

        # 2. Output unsqueeze Reshape
        outputs = [
            Op._create_data_node(
                body, rnn_layer.name + '/outport/' + str(out), {
                    'shape': rnn_layer.out_node(out).shape.copy() if out in rnn_layer.out_nodes() else None
                }) for out in [0]
        ]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        outputs[0].shape = np.delete(outputs[0].shape.copy(), rnn_layer.sequence_dim)
        output_unsqueeze_dim = Const(body, dict(name=rnn_layer.name + '/output_unsqueeze_dim',
                                                value=rnn_layer.sequence_dim)).create_node_with_data()
        output_unsqueeze = Unsqueeze(body, dict(name=rnn_layer.name + '/output_unsqueeze/', internal_layer_id=2))

        additional_attrs = dict(activations=rnn_layer.activations,
                                activation_alpha=rnn_layer.activation_alpha,
                                activation_beta=rnn_layer.activation_beta,
                                clip=rnn_layer.clip)
        if rnn_layer.op == 'GRU':
            additional_attrs['linear_before_reset'] = rnn_layer.linear_before_reset

        # 3. ***Cell
        rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(
            body,
            dict(hidden_size=rnn_layer.hidden_size,
                 name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                 **additional_attrs,
                 internal_layer_id=1))

        gru_cell = rnn_cell_op.create_node_with_data(inputs,
                                                     data_nodes=outputs,
                                                     edge_attrs=[{}, {'internal_port_id': 1},
                                                                 {'internal_port_id': 2},
                                                                 {'bin': 'weights'},
                                                                 {'bin': 'biases'}])

        # internal ports for outputs of cell
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

        gru_cell = output_unsqueeze.create_node_with_data([gru_cell, output_unsqueeze_dim])
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, gru_cell.id, 0, False)

        # 4. TensorIterator layer creating
        assert rnn_layer.direction in ['forward', 'reverse']
        if rnn_layer.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert rnn_layer.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        # stacked h_state
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': rnn_layer.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding last h_state to outputs
        if len(rnn_layer.out_nodes()) == 2:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }])

        ti_op = TensorIterator(
            graph,
            {
                'name': rnn_layer.name + '/TensorIterator',
                'body': body,
                'in_ports_count': 4,
                'out_ports_count': len(rnn_layer.out_nodes()),
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': rnn_layer.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                ],
                'output_port_map': output_port_map,
                # only for h state
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                ]
            })

        assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
            "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

        outs = ti_op.create_node_with_data(
            [rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
            data_nodes=[rnn_layer.out_node(i) for i in range(len(rnn_layer.out_nodes()))],
            edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}])

        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(rnn_layer.id)
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id

        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
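
The direction handling in step 4 reduces to a small lookup over TensorIterator iteration parameters. A sketch with the values copied from the branches above:

def iteration_params(direction: str) -> dict:
    # forward walks the sequence axis with stride 1; reverse walks it backwards
    if direction == 'forward':
        return {'stride': 1, 'start': None, 'end': None}
    assert direction == 'reverse'
    return {'stride': -1, 'start': -1, 'end': 0}

assert iteration_params('forward') == {'stride': 1, 'start': None, 'end': None}
assert iteration_params('reverse') == {'stride': -1, 'start': -1, 'end': 0}
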
Example #6
    def find_and_replace_pattern(self, graph: Graph):
        # Iterate over all data nodes and find those with more than one unique StridedSlice consumer
        for input_data in list(graph.get_data_nodes()):
            # Skip constant data nodes
            if input_data.value is not None:
                continue

            input_shape = np.array(input_data.shape)

            # Get all unique StridedSlice consumers
            out_nodes = [node for node in input_data.out_nodes()
                         if node.op == 'StridedSlice' and node.in_node(0).name == input_data.name]
            sorted_out_nodes = sorted(out_nodes, key=lambda n: list(n.slices))
            out_nodes = unique_by(sorted_out_nodes, strided_slices_equality)
            if len(out_nodes) <= 1:
                continue

            valid_for_replacement = True

            for node in out_nodes:
                if len(node.slices) != len(out_nodes[0].slices):
                    valid_for_replacement = False

            # Detect dimension for splitting
            split_channel_dim = None
            for dim_id, s in enumerate(out_nodes[0].slices):
                l, r, stride = s.start, s.stop, s.step
                if l != 0 or r != input_shape[dim_id]:
                    if split_channel_dim is None:
                        split_channel_dim = dim_id
                    else:
                        valid_for_replacement = False

            if split_channel_dim is None:
                valid_for_replacement = False

            # split_dims contains tuples with split range and output data node
            split_dims = []
            for out_id, node in enumerate(out_nodes):
                # Check that the StridedSlice op has stride equal to 1 and splits only the feature channel
                for dim_id, s in enumerate(node.slices):
                    l, r, stride = s.start, s.stop, s.step
                    # We don't support StridedSlice with stride != 1
                    if stride != 1:
                        valid_for_replacement = False
                    if dim_id == split_channel_dim:
                        split_dims.append((s.start, s.stop, node.out_node()))

            if not valid_for_replacement:
                continue

            # Check feature split intersection
            final_data_nodes_list = []
            sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

            # check if we have similar StridedSlice operations with different outputs
            prev_sd = sorted_split_dims[0]
            to_remove = []
            for i in range(1, len(sorted_split_dims)):
                if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] and \
                        sorted_split_dims[i][2].name != prev_sd[2].name:
                    cur_node = sorted_split_dims[i][2]
                    for out in cur_node.out_nodes():
                        attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                        graph.remove_edge(cur_node.id, out.id)
                        graph.add_edge(prev_sd[2].id, out.id, **attrs)
                    to_remove.append(i)

            for ind in reversed(to_remove):
                sorted_split_dims.pop(ind)

            size_splits = []
            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Split dims shouldn't intersect
                if l < prev_r:
                    valid_for_replacement = False
                prev_r = r

            if prev_r > input_shape[split_channel_dim]:
                valid_for_replacement = False

            if not valid_for_replacement:
                continue

            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Save missing tensor part
                if l > prev_r:
                    shape = np.array(input_shape)
                    size_splits.append(l - prev_r)
                    shape[split_channel_dim] = l - prev_r
                    data_node = Op._create_data_node(graph, 'fake_data_' + out_nodes[0].name, {'shape': shape})
                    add_opoutput(graph, data_node.id, 0, False)
                    final_data_nodes_list.append(data_node)

                prev_r = r
                size_splits.append(r - l)
                final_data_nodes_list.append(out)

            if prev_r < input_shape[split_channel_dim]:
                # Add last part of tensor
                shape = input_shape.copy()
                shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
                size_splits.append(input_shape[split_channel_dim] - prev_r)
                data_node = Op._create_data_node(graph, 'fake_data_' + out_nodes[0].name, {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False)
                final_data_nodes_list.append(data_node)

            for node in out_nodes:
                if not np.all([x == 0 for x in node.shrink_axis_mask]):
                    out_node = node.out_node()
                    if np.any(node['shrink_axis_mask']):
                        self.add_squeeze_for_shrink(graph, node)
                    if np.any(node['new_axis_mask']):
                        self.add_unsqueeze_for_new(graph, node)

                    for i in range(len(final_data_nodes_list)):
                        if final_data_nodes_list[i].name == out_node.name:
                            final_data_nodes_list[i] = node.out_node()
                            break

            # Insert Split layer and remove old StridedSlice layers
            # 1. Remove connections from input_data to StridedSlice ops
            out_data_nodes = []
            name_for_future_split = out_nodes[0].name
            for node in out_nodes:
                out_data_nodes.append(node.out_node())
                graph.remove_edge(input_data.id, node.id)
                graph.remove_edge(node.id, node.out_node().id)
                graph.remove_node(node.id)
                log.debug("Removed: {}".format(node.id))

            # 2. Create Split layer and reorder outputs
            name = name_for_future_split + "/Split"
            axis_const = Const(graph, {'value': int64_array(split_channel_dim),
                                       'name': name + '/Axis'}).create_node_with_data()
            size_splits_const = Const(graph, {'value': int64_array(size_splits),
                                              'name': name + '/Sizes'}).create_node_with_data()
            split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))

            split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
                                        data_nodes=final_data_nodes_list)
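
The heart of this pass is turning sorted, non-overlapping slice ranges into VariadicSplit sizes, padding any uncovered ranges with "fake" chunks. A pure-Python sketch of that computation (the ranges and total size are made up):

def ranges_to_size_splits(sorted_ranges, total):
    size_splits, prev_r = [], 0
    for l, r in sorted_ranges:
        if l > prev_r:                  # uncovered part of the tensor
            size_splits.append(l - prev_r)
        size_splits.append(r - l)
        prev_r = r
    if prev_r < total:                  # trailing uncovered part
        size_splits.append(total - prev_r)
    return size_splits

assert ranges_to_size_splits([(0, 16), (32, 64)], 64) == [16, 16, 32]
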
Example #7
    def replace_pattern(self, graph: Graph, match: dict):
        lstm = match['lstm']

        # Build TensorIterator body first
        body = Graph(name=lstm.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        inputs = [
            Op._create_data_node(
                body, lstm.name + '/inport/' + str(inp), {
                    'shape': lstm.in_node(inp).shape.copy(),
                    'value': lstm.in_node(inp).value.copy()
                    if lstm.in_node(inp).value is not None and inp in [1, 2] else None
                }) for inp in [0, 4, 5, 1, 2]
        ]  # X, WR, B, h_init, c_init

        inputs[0].shape[lstm.sequence_dim] = 1
        reshape_dim = inputs[0].shape.copy()
        reshape_dim[lstm.batch_dim] = -1
        reshape_dim = np.delete(reshape_dim, lstm.sequence_dim)
        input_squeeze = Reshape(body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
        squeeze_dim_data = Const(body, {'name': lstm.name + '/input_squeeze_dim',
                                        'value': reshape_dim}).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data([inputs[0], squeeze_dim_data],
                                                        edge_attrs=[{'internal_port_id': 0}])

        # 2. Output unsqueeze Reshape
        outputs = [
            Op._create_data_node(
                body, lstm.name + '/outport/' + str(out), {
                    'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                    else lstm.in_node(4).shape.copy()
                }) for out in [0, 1]
        ]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        unsqueezed_output_shape = outputs[0].shape.copy()
        unsqueezed_output_shape[lstm.sequence_dim] = 1
        squeezed_output_shape = np.delete(unsqueezed_output_shape, lstm.sequence_dim)
        outputs[0].shape = squeezed_output_shape
        unsqueezed_output_shape[lstm.batch_dim] = -1
        output_unsqueeze = Reshape(body, dict(name=lstm.name + '/output_unsqueeze', internal_layer_id=2))
        unsqueeze_dim_data = Const(body, {'name': lstm.name + '/output_unsqueeze_dim',
                                          'value': unsqueezed_output_shape}).create_node_with_data()

        # 3. LSTMCell
        lstm_cell_op = LSTMCell(
            body,
            dict(hidden_size=lstm.hidden_size,
                 activations=lstm.activations,
                 activation_alpha=lstm.activation_alpha,
                 activation_beta=lstm.activation_beta,
                 clip=lstm.clip,
                 input_forget=lstm.input_forget,
                 name=lstm.name + '/LSTMCell',
                 internal_layer_id=1))
        lstm_cell_node = lstm_cell_op.create_node_with_data(
            inputs,
            data_nodes=outputs,
            edge_attrs=[{}, {'internal_port_id': 1}, {'internal_port_id': 2},
                        {'bin': 'weights'}, {'bin': 'biases'}])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data(
            [lstm_cell_node[0], unsqueeze_dim_data])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, lstm_cell_node[0].id, 0, False)

        # 4. TensorIterator layer creating
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding h_state, c_state to outputs
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

        ti_op = TensorIterator(
            graph, {
                'name': lstm.name + '/TensorIterator',
                'body': body,
                'in_ports_count': 3,
                'out_ports_count': len(lstm.out_nodes()),
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': lstm.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                    {
                        'external_port_id': 2,
                        'internal_layer_id': 1,
                        'internal_port_id': 2,
                    },
                ],
                'output_port_map': output_port_map,
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                    {
                        'from_layer': 1,
                        'from_port': 5,
                        'to_layer': 1,
                        'to_port': 2,
                    },
                ]
            })

        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

        outs = ti_op.create_node_with_data(
            [lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
            data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
            edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}, {'external_port_id': 2}])

        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(lstm.id)
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id
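
The Reshape dims computed in step 1 squeeze the sequence dimension while keeping the Reshape batch-agnostic. A numpy sketch with illustrative shapes:

import numpy as np

shape = np.array([10, 4, 64])  # [seq_len, batch, input_size]
sequence_dim, batch_dim = 0, 1

shape[sequence_dim] = 1        # one iteration slice: [1, 4, 64]
reshape_dim = shape.copy()
reshape_dim[batch_dim] = -1    # batch stays dynamic
reshape_dim = np.delete(reshape_dim, sequence_dim)
print(reshape_dim)             # [-1 64]
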
Example #8
def rnn_infer(node: Node, out_ports=None):
    """
    General infer function for RNN, GRU and LSTM layers.
    Assumes that the 0-port input of the node is the input data for the recurrent layer and that the node
    has the attribute: hidden_size.
    """
    if out_ports is None:
        out_ports = []

    # 1. Necessary checks (from ONNX specification)
    assert node.batch_dim <= 1
    assert node.sequence_dim <= 1
    assert node.batch_dim != node.sequence_dim
    assert node.direction in ['forward', 'reverse', 'bidirectional']

    if node.blobs_wrb:
        mark_input_bins(node, ['W', 'R', 'B'])
    else:
        mark_input_bins(node)

    # 2. Output shape calculations
    input_shape = node.in_node(0).shape
    assert len(input_shape) == 3

    # Reshape input nodes
    for port in [2, 3]:
        if port in node.in_nodes() and len(node.in_node(port).in_nodes()) > 0 and \
                'zero_shapes' in node.in_node(port).in_node():
            for i in node.in_node(port).in_node().zero_shapes:
                if node.in_node(port).shape[i] != input_shape[i]:
                    node.in_node(port).value = np.repeat(
                        node.in_node(port).value, input_shape[i], axis=i)
                    node.in_node(port).shape[i] = input_shape[i]

    out_shape = np.array([input_shape[node.sequence_dim], input_shape[node.batch_dim], node.hidden_size],
                         dtype=np.int64)

    if node.batch_dim == 0:
        out_shape = np.array([input_shape[node.batch_dim], input_shape[node.sequence_dim], node.hidden_size],
                             dtype=np.int64)

    num_directions = 2 if node.direction == 'bidirectional' else 1
    if node.has_num_directions:
        if node.format == 'mxnet' and node.normalized is False:
            # In MXNet, the RNN layer returns output with shape [seq_len, batch_size, hidden_size * num_directions]
            out_shape[-1] *= num_directions
        else:
            # ONNX-like, insert extra dimension to output shape for num_directions
            out_shape = np.insert(out_shape, 1, np.int64(num_directions))

    # output 0 is required; create it if it doesn't exist
    if 0 not in node.out_nodes():
        data_node = Op._create_data_node(node.graph,
                                         name=node.node + '/ExtraOutput/{}'.format(0),
                                         attrs={'executable': True})
        if 0 not in node.out_ports():
            node.add_output_port(0)
        node.graph.add_edge(node.id, data_node.id, key=0, out=0)
        add_opoutput(node.graph, data_node.id, 0, False)
    node.out_port(0).data.set_shape(out_shape)

    # 3. Extra outputs for hidden/cell states shape calculations (optional)
    state_size = np.array([input_shape[node.batch_dim], node.hidden_size], dtype=np.int64)
    if node.has_num_directions:
        state_size = np.insert(state_size, 0, num_directions)

    if node.multilayers:
        # In the multilayer case, state sizes from every layer are concatenated along the last axis
        num_layers = node.num_layers
        state_size[-1] *= num_layers

    for i in out_ports:
        # If the node has no consumers for the hidden/cell state outputs, create them
        if i not in node.out_nodes():
            data_node = Op._create_data_node(node.graph,
                                             name=node.node + '/ExtraOutput/' + str(i),
                                             attrs={'executable': True})
            if i not in node.out_ports():
                node.add_output_port(i)
            node.graph.add_edge(node.id, data_node.id, key=0, out=i)
            add_opoutput(node.graph, data_node.id, 0, False)
        else:
            data_node = node.out_node(i)
        data_node.shape = state_size.copy()
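
The two num_directions conventions handled above differ only in where the directions axis goes. A numpy sketch with made-up dimensions:

import numpy as np

out_shape = np.array([10, 4, 128], dtype=np.int64)  # [seq_len, batch, hidden_size]
num_directions = 2

# MXNet-style (not normalized): directions are folded into the last axis
mxnet_shape = out_shape.copy()
mxnet_shape[-1] *= num_directions
print(mxnet_shape)  # [ 10   4 256]

# ONNX-like: directions get their own axis
onnx_shape = np.insert(out_shape, 1, np.int64(num_directions))
print(onnx_shape)   # [ 10   2   4 128]
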
Example #9
    def replace_pattern(graph, match: dict):
        # Find all parts of the TensorIterator (condition, inputs/outputs, back edges, body),
        # create the TensorIterator op, and perform all checks needed for it to work
        cond_data = match['condition'].out_node(0) if not match['condition'].out_port(0).disconnected() else None
        time_data = match['condition'].out_node(1) if len(match['condition'].out_nodes()) >= 1 else None
        name = match['condition'].name

        back_edges = []
        inputs = []
        outputs = []

        if cond_data is not None:
            for node in cond_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                    back_edges.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)

        if time_data is not None:
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # any other consumer of the time output means the pattern matching went wrong
                    assert False
        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)

        nodes_to_remove = [
            n.id
            for n in (condition, cond_data, time_data, tensor_sequence_length)
            if n is not None
        ]
        graph.remove_nodes_from(nodes_to_remove)

        body_nodes, extra_inputs = get_body(graph, inputs, outputs)

        if cond_data is not None:
            body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        external_inputs = [{
            'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id': node.out_node(0),
            'axis': node.axis,
            'start': node.start,
            'end': node.end,
            'stride': node.stride,
            'part_size': node.part_size
        } for node in inputs]

        external_outputs = [{
            'external_data_id': node.out_node(0),
            'internal_data_id': node.in_node(1 if node.has_valid('axis') else 0),
            'axis': node.axis,
            'start': node.start,
            'end': node.end,
            'stride': node.stride,
            'part_size': node.part_size
        } for node in outputs]

        back_edges_data = [{
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges]

        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from([
            (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
            if u in body_nodes and v in body_nodes
        ])

        graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                                [inp.id for inp in inputs] +
                                [out.id for out in outputs])
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid('internal_layer_id'):
                edge['from_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node()['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge()['internal_port_id']

            # Look at all consumers of the data node that ends a back edge;
            # for each such consumer there will be a separate back edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(edge)  # all real back edges have the same back-edge start

                consumer = Node(body, consumer)

                if real_edge['to_data_id'].in_node().has_valid('internal_layer_id'):
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert 'internal_port_id' not in real_edge['init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (real_edge['init_data_id'].id, real_edge['consumer'].id,
                 real_edge['consumer_key'], real_edge['attrs'])
                for real_edge in current_real_back_edges
            ])

            body.remove_nodes_from([edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body, ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert a Squeeze at the input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                new_input_data = Op._create_data_node(
                    body,
                    ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                    dict(shape=shape_insert(shape, ext_inp['axis'], 1)))

                reshape_op = Squeeze(body, dict(name=ext_inp['internal_data_id'].name + '/InputSqueeze'))
                reshape_dim_data = Const(body, {'name': ext_inp['internal_data_id'].name + '/ReshapeDim',
                                                'value': ext_inp['axis']}).create_node_with_data()
                reshape_op.create_node_with_data([new_input_data, reshape_dim_data],
                                                 data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            for _, consumer, edge_attrs in body.out_edges(ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if 'internal_port_id' not in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer['internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs['internal_port_id']
                real_external_inputs.append(real_ext_inp)
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body, ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert an Unsqueeze at the output port that has partitioning
                reshape_op = Unsqueeze(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze'))
                reshape_dim_data = Const(body, {'name': ext_out['internal_data_id'].name + '/ReshapeDim',
                                                'value': ext_out['axis']}).create_node_with_data()
                ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                    [ext_out['internal_data_id'], reshape_dim_data])

            # TODO: add here working with simple outputs

            if not any([out_node.soft_get('op', None) == 'Result'
                        for out_node in ext_out['internal_data_id'].out_nodes()]):
                add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if 'internal_layer_id' not in ext_out['internal_data_id'].in_node():
                ext_out['internal_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if 'internal_port_id' not in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node()['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge()['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        # create TensorIterator layer with pre-computed components
        ti_op = TensorIterator(
            graph, {
                'name': name + '/TensorIterator',
                'body': body,
                'in_ports_count': len(external_inputs),
                'out_ports_count': len(external_outputs),
                'input_port_map': [{
                    field: external_input[field]
                    for field in ['external_port_id', 'internal_layer_id', 'internal_port_id',
                                  'axis', 'stride', 'part_size', 'start', 'end']
                } for external_input in real_external_inputs],
                'output_port_map': [{
                    field: external_output[field]
                    for field in ['external_port_id', 'internal_layer_id', 'internal_port_id',
                                  'axis', 'stride', 'part_size', 'start', 'end']
                } for external_output in external_outputs],
                'back_edges': [{
                    field: edge[field]
                    for field in ['from_layer', 'from_port', 'to_layer', 'to_port']
                } for edge in real_back_edges],
            })

        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{'external_port_id': inp['external_port_id']} for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs])

        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        for i, out in enumerate(ti_outs):
            out.in_edge()['external_port_id'] = external_outputs[i]['external_port_id']

        ti = ti_outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
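
A recurring detail above is that internal layer ids and port ids are drawn from one monotonically increasing counter and assigned only on first sight. A sketch of that allocation pattern, with plain dicts standing in for MO nodes and edges:

def assign_id(obj: dict, key: str, counter: int) -> int:
    # assign a fresh id only if the object does not have one yet
    if key not in obj:
        obj[key] = counter
        counter += 1
    return counter

internal_id_count = 0
consumer, edge = {}, {}
internal_id_count = assign_id(consumer, 'internal_layer_id', internal_id_count)
internal_id_count = assign_id(edge, 'internal_port_id', internal_id_count)
print(consumer, edge, internal_id_count)  # {'internal_layer_id': 0} {'internal_port_id': 1} 2
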
Example #10
    def extract(cls, loop_node):
        Loop.update_node_stat(loop_node, {})

        body_graph_proto = onnx_attr(loop_node, 'body', 'g', None)
        main_graph = loop_node.graph

        # create a Graph object for the body and take graph attributes from the main graph
        body_graph = Graph()
        main_graph_attrs_copy = copy.deepcopy(main_graph.graph)
        del main_graph_attrs_copy['tensor_mapping']
        body_graph.graph.update(main_graph_attrs_copy)
        loop_node['body'] = body_graph

        # maps a tensor name to the node that produced it and the node port: str -> (node_id, node_port)
        data_nodes_map = {}
        body_graph.graph['tensor_mapping'] = data_nodes_map  # save the mapping for a possible Loop inside the Loop

        body_parameters = add_initializers_and_inputs_to_graph(body_graph, body_graph_proto, data_nodes_map)

        external_edges = []  # (src_node, src_out_port), dest_body_parameter_node
        additional_params = {}  # (src_node, src_out_port) -> parameter_node (for manually added Parameters)
        # Go through all nodes in the original model order because data nodes are defined on-the-fly and order matters
        for pb_node in body_graph_proto.node:
            # create an NX node
            id = body_graph.unique_id(node_id(pb_node))
            body_graph.add_node(id, pb=pb_node, kind='op')

            # add incoming edges based on data_nodes_map
            for dst_port, inp in enumerate(pb_node.input):
                # should add edge inp --> id
                if inp not in data_nodes_map:
                    if inp == '':
                        # input is omitted; most likely it corresponds to an optional input for an operator
                        continue
                    elif inp in main_graph.graph['tensor_mapping']:
                        log.debug('The edge between outer and inner graphs detected: {} -> {}'.format(inp, id))
                        if main_graph.graph['tensor_mapping'][inp] not in additional_params:
                            # create a new Parameter body node and use it to connect the body node with the outer graph
                            param_id = str(inp)
                            body_graph.add_node(param_id, kind='op', op='Parameter', name=param_id, pb=None,
                                                shape=None)
                            parameter_node = Node(body_graph, param_id)
                            # manually update the necessary attrs for the node because the extractor will not be
                            # called for it since the node does not have a .pb attribute
                            Parameter.update_node_stat(parameter_node, {})
                            external_edges.append((main_graph.graph['tensor_mapping'][inp], parameter_node))
                            src_id, src_port = param_id, 0
                            additional_params[main_graph.graph['tensor_mapping'][inp]] = parameter_node
                        else:
                            src_id, src_port = additional_params[main_graph.graph['tensor_mapping'][inp]].id, 0
                    else:
                        raise Error(
                            'Reference to "{}" is not satisfied. A node refers to a non-existing data tensor. The '
                            'ONNX model is not consistent. Protobuf fragment: {}', inp, pb_node)
                else:
                    src_id, src_port = data_nodes_map[inp]

                assert body_graph.has_node(src_id)
                edge_attrs = {
                    'out': src_port,
                    'in': dst_port,
                    'name': inp,
                    'fw_tensor_debug_info': [(inp, inp)],
                    'in_attrs': ['in', 'name'],
                    'out_attrs': ['out', 'name'],
                    'data_attrs': ['fw_tensor_debug_info']
                }
                body_graph.add_edge(src_id, id, **edge_attrs)

            # add outgoing edges to data_nodes_map
            for src_port, out in enumerate(pb_node.output):
                if out in data_nodes_map:
                    log.debug("Detected reuse of blob {}.".format(out))
                data_nodes_map[out] = (id, src_port)

        body_results = []
        for output in body_graph_proto.output:
            tensor_name = str(output.name)
            node_name, output_port = data_nodes_map[tensor_name]
            assert body_graph.has_node(node_name), \
                'The body graph does not contain output with name "{}"'.format(node_name)
            body_results.append(Node(body_graph, add_opoutput(body_graph, node_name, output_port, False)))

        # add 'internal_layer_id' attribute which is a must have attribute for the loop body node
        for idx, body_node in enumerate(body_graph.get_op_nodes()):
            body_node['internal_layer_id'] = idx

        loop_carried_dependencies_count = len(body_graph_proto.input) - 2
        scan_outputs_count = len(
            body_graph_proto.output) - 1 - loop_carried_dependencies_count

        # Loop inputs:
        #   0 - trip count
        #   1 - execution condition
        #   2 .. - loop carried dependencies

        # Loop outputs:
        #   0 .. loop_carried_dependencies_count - 1 - loop carried dependencies
        #   loop_carried_dependencies_count .. - scan outputs

        # Body inputs:
        #   0 - iteration number
        #   1 - execution condition
        #   2 .. - loop carried dependencies

        # Body outputs:
        #   0 - execution condition
        #   1 .. loop_carried_dependencies_count - loop carried dependencies
        #   loop_carried_dependencies_count + 1 .. - scan outputs
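        # Illustrative arithmetic (not from the original source): a body with 4 inputs and 4 outputs gives
        #   loop_carried_dependencies_count = 4 - 2 = 2  (inputs minus iteration number and condition)
        #   scan_outputs_count = 4 - 1 - 2 = 1           (outputs minus condition minus carried dependencies)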

        body_graph.stage = 'front'
        # some of the inputs/outputs may not be connected, but the normalization transformation will take care of it
        # connect Loop body nodes with external input edges
        next_loop_input_port_idx = sorted(loop_node.in_edges().keys())[-1] + 1
        for (src_node, src_port), body_node in external_edges:
            main_graph.add_edge(
                src_node, loop_node.id, **{
                    'out': src_port,
                    'in': next_loop_input_port_idx,
                    'name': src_node,
                    'fw_tensor_debug_info': [(src_node, src_node)],
                    'in_attrs': ['in', 'name'],
                    'out_attrs': ['out', 'name'],
                    'data_attrs': ['fw_tensor_debug_info']
                })
            connect_body_input(loop_node, next_loop_input_port_idx, body_node)
            next_loop_input_port_idx += 1

        # mark current iteration input Parameter node
        Loop.mark_current_iteration_parameter_node(loop_node,
                                                   body_parameters[0])

        # connect initial value for "execution condition" input of the loop
        connect_body_input(loop_node, 1, body_parameters[1])
        # add back edge with "execution condition"
        Loop.add_back_edge(loop_node, body_parameters[1], body_results[0])
        # mark "execution condition" Result node
        Loop.mark_execution_condition_result_node(loop_node, body_results[0])

        # connect initial value for "loop carried" dependencies variables
        for idx in range(loop_carried_dependencies_count):
            connect_body_input(loop_node, idx + 2, body_parameters[idx + 2])
        # add back edge for "loop carried" dependencies variables
        for idx in range(loop_carried_dependencies_count):
            Loop.add_back_edge(loop_node, body_parameters[idx + 2],
                               body_results[idx + 1])
        # connect final value for "loop carried" dependencies variables
        for idx in range(loop_carried_dependencies_count):
            connect_body_output(loop_node, idx, body_results[idx + 1])

        # connect "scan outputs" and mark axis for concatenation
        for idx in range(loop_carried_dependencies_count,
                         loop_carried_dependencies_count + scan_outputs_count):
            connect_body_output(loop_node, idx, body_results[idx + 1], axis=0)

        # parse body node attributes in the same way as for the main graph
        extract_node_attrs(
            body_graph, lambda node: onnx_op_extractor(
                node, check_for_duplicates(onnx_op_extractors)))
        return cls.enabled
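# A minimal sketch (an assumption, not part of the original source) of the create_edge_with_attrs
# helper used in the next example: it is reconstructed from the edge_attrs dict built inline above,
# so the name and signature are taken from the call sites below and may differ from the real helper.
def create_edge_with_attrs(graph, tensor_name, src_id, src_port, dst_id, dst_port):
    # mirror the inlined edge creation: connect src_id:src_port -> dst_id:dst_port and keep
    # the framework tensor name for debugging
    assert graph.has_node(src_id)
    edge_attrs = {
        'out': src_port,
        'in': dst_port,
        'name': tensor_name,
        'fw_tensor_debug_info': [(tensor_name, tensor_name)],
        'in_attrs': ['in', 'name'],
        'out_attrs': ['out', 'name'],
        'data_attrs': ['fw_tensor_debug_info']
    }
    graph.add_edge(src_id, dst_id, **edge_attrs)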
Beispiel #11
0
    def extract(cls, loop_node):
        Loop.update_node_stat(loop_node, {})

        body_graph_proto = onnx_attr(loop_node, 'body', 'g', None)
        main_graph = loop_node.graph

        # create a Graph object for the body and take graph attributes from the main graph
        body_graph = Graph()
        main_graph_attrs_copy = {}
        for attr_key, attr_value in main_graph.graph.items():
            if attr_key not in ['tensor_mapping', 'parent_node']:
                main_graph_attrs_copy[attr_key] = copy.deepcopy(attr_value)
        body_graph.graph.update(main_graph_attrs_copy)
        loop_node['body'] = body_graph
        # save the parent node so nested loops know which node contains the body (and which graph is one level up)
        body_graph.graph['parent_node'] = loop_node

        # maps a tensor name to the node that produces it and the output port: str -> (node_id, node_port)
        data_nodes_map = {}
        body_graph.graph['tensor_mapping'] = data_nodes_map  # save mapping for possible Loop inside the Loop

        body_parameters = add_initializers_and_inputs_to_graph(body_graph, body_graph_proto, data_nodes_map)

        external_edges = [[]]  # each element: ((src_node, src_out_port), body_parameter_node, tensor_name)
        # edge information is kept for the graph on each nesting level; the first list is the deepest
        additional_params = [{}]  # (src_node, src_out_port) -> parameter_node (for manually added Parameters)
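        # e.g. after one manually added Parameter on the deepest level (an illustration derived from the
        # unpacking below): external_edges == [[((src_node, src_port), parameter_node, tensor_name)]]
        # and additional_params == [{(src_node, src_port): parameter_node}]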
        # Go through all nodes in the original model order because data nodes are defined on-the-fly and order matters
        for pb_node in body_graph_proto.node:
            # create an NX node
            id = body_graph.unique_id(node_id(pb_node))
            body_graph.add_node(id, pb=pb_node, kind='op')
            if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op_type'):
                body_graph.op_names_statistic[pb_node.op_type] += 1

            # add incoming edges based on data_nodes_map
            for dst_port, inp in enumerate(pb_node.input):
                # should add edge src_internal_id --> dst_id
                if inp not in data_nodes_map:
                    if inp == '':
                        # input is omitted; most likely it corresponds to an optional input for an operator
                        continue
                    else:
                        is_finished = create_cross_body_edge(body_graph, external_edges, additional_params,
                                                             inp, id, dst_port)
                        if not is_finished:
                            raise Error(
                                'Reference to "{}" is not satisfied. A node refers to a non-existing data tensor. '
                                'The ONNX model is not consistent. Protobuf fragment: {}', inp, pb_node)
                else:
                    src_id, src_port = data_nodes_map[inp]
                    create_edge_with_attrs(body_graph, inp, src_id, src_port, id, dst_port)

            # add outgoing edges to data_nodes_map
            for src_port, out in enumerate(pb_node.output):
                if out in data_nodes_map:
                    log.debug("Detected reuse of blob {}.".format(out))
                data_nodes_map[out] = (id, src_port)

        body_results = []
        for output in body_graph_proto.output:
            tensor_name = str(output.name)
            node_name, output_port = data_nodes_map[tensor_name]
            assert body_graph.has_node(node_name), 'The body graph does not contain an output with name "{}"'.format(
                node_name)
            body_results.append(Node(body_graph, add_opoutput(body_graph, node_name, output_port, False)))

        # add 'internal_layer_id' attribute which is a mandatory attribute for Loop body nodes
        for idx, body_node in enumerate(body_graph.get_op_nodes()):
            body_node['internal_layer_id'] = idx

        loop_carried_dependencies_count = len(body_graph_proto.input) - 2
        scan_outputs_count = len(body_graph_proto.output) - 1 - loop_carried_dependencies_count

        # Loop inputs:
        #   0 - trip count
        #   1 - execution condition
        #   2 .. - loop carried dependencies

        # Loop outputs:
        #   0 .. loop_carried_dependencies_count - 1 - loop carried dependencies
        #   loop_carried_dependencies_count .. - scan outputs

        # Body inputs:
        #   0 - iteration number
        #   1 - execution condition
        #   2 .. - loop carried dependencies

        # Body outputs:
        #   0 - execution condition
        #   1 .. loop_carried_dependencies_count - loop carried dependencies
        #   loop_carried_dependencies_count + 1 .. - scan outputs
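        # note the off-by-one shift between body outputs and Loop outputs: body output 0 is the execution
        # condition, so Loop output idx is produced by body result idx + 1 (see the connection code below)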

        # some of the inputs/outputs may not be connected, but the normalization transformation will take care of it
        # connect Loop body nodes with external input edges
        next_loop_input_port_idx = sorted(loop_node.in_edges().keys())[-1] + 1
        cur_graph = body_graph
        for external_edges_subg in external_edges:
            if 'parent_node' not in cur_graph.graph:
                continue
            cur_loop_node = cur_graph.graph['parent_node']
            parent_graph = cur_loop_node.graph
            for (src_node, src_port), body_node, tensor_name in external_edges_subg:
                create_edge_with_attrs(parent_graph, tensor_name, src_node, src_port,
                                       cur_loop_node.id, next_loop_input_port_idx)

                Loop.connect_body_input(cur_loop_node, next_loop_input_port_idx, body_node)
                next_loop_input_port_idx += 1
            cur_graph = parent_graph

        # mark current iteration input Parameter node
        Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])

        # connect initial value for "execution condition" input of the loop
        Loop.connect_body_input(loop_node, 1, body_parameters[1])
        # add back edge with "execution condition"
        Loop.add_back_edge(loop_node, body_parameters[1], body_results[0])
        # mark "execution condition" Result node
        Loop.mark_execution_condition_result_node(loop_node, body_results[0])

        # connect initial value for "loop carried" dependencies variables
        for idx in range(loop_carried_dependencies_count):
            Loop.connect_body_input(loop_node, idx + 2, body_parameters[idx + 2])
        # add back edge for "loop carried" dependencies variables
        for idx in range(loop_carried_dependencies_count):
            Loop.add_back_edge(loop_node, body_parameters[idx + 2], body_results[idx + 1])
        # connect final value for "loop carried" dependencies variables
        for idx in range(loop_carried_dependencies_count):
            Loop.connect_body_output(loop_node, idx, body_results[idx + 1])

        # connect "scan outputs" and mark axis for concatenation
        for idx in range(loop_carried_dependencies_count, loop_carried_dependencies_count + scan_outputs_count):
            Loop.connect_body_output(loop_node, idx, body_results[idx + 1], axis=0)

        # parse body node attributes in the same way as for the main graph
        extract_node_attrs(body_graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
        return cls.enabled
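# A minimal single-level sketch (an assumption, not the original source) of the create_cross_body_edge
# helper used above, reconstructed from the inlined Parameter-creation logic of the previous example;
# the real helper also walks outer nesting levels, which is omitted here.
def create_cross_body_edge(body_graph, external_edges, additional_params, inp, dst_id, dst_port):
    main_graph = body_graph.graph['parent_node'].graph
    if inp not in main_graph.graph['tensor_mapping']:
        return False  # the reference cannot be satisfied on this nesting level
    src = main_graph.graph['tensor_mapping'][inp]
    if src not in additional_params[0]:
        # create a new Parameter node in the body and use it to connect the body with the outer graph
        param_id = str(inp)
        body_graph.add_node(param_id, kind='op', op='Parameter', name=param_id, pb=None, shape=None)
        parameter_node = Node(body_graph, param_id)
        # manually update the necessary attrs because the extractor is not called for nodes without .pb
        Parameter.update_node_stat(parameter_node, {})
        external_edges[0].append((src, parameter_node, inp))
        additional_params[0][src] = parameter_node
    else:
        parameter_node = additional_params[0][src]
    create_edge_with_attrs(body_graph, inp, parameter_node.id, 0, dst_id, dst_port)
    return True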