Example #1
    def unsqueeze_num_directions(graph: Graph, match: dict):
        """ Assuming considered LSTM/GRU/RNN node should has num_directions in output shape and add Unsqueeze
            to match it.
        """

        rnn_layer = match['rnn_layer']
        rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)
        # num_directions is at position 1 in the output shape and at position 0 in the hidden and cell states;
        # please refer to docs in this transform

        direction_dim = [1, 0, 0]  # index of dimension with direction index
        for i in rnn_layer.out_nodes():
            old_data_node = rnn_layer.out_node(i)
            old_shape = old_data_node.shape.copy()
            new_shape = shape_delete(old_shape, direction_dim[i])

            data = Op._create_data_node(graph, name=rnn_layer_name + '/Out/{}/'.format(i), attrs={'shape': new_shape})
            graph.remove_edge(rnn_layer.id, old_data_node.id)
            graph.add_edge(rnn_layer.id, data.id, key=0, out=i)

            unsqueeze = Unsqueeze(graph, dict())

            unsqueeze_dim_data = Const(graph, {'name': rnn_layer_name + '/UnsqueezeNumDirections/{}/Dim'.format(i),
                                               'value': int64_array([direction_dim[i]])}).create_node_with_data()

            unsqueeze.create_node_with_data([data, unsqueeze_dim_data],
                                            dict(name=rnn_layer_name + '/UnsqueezeNumDirections/{}'.format(i)),
                                            data_nodes=[old_data_node])
Example #2
 def replace_output_edges(graph: Graph, output_edges_match: dict):
     """
     Replace existing output edges with new ones that point to the new sub-graph.
     :param graph: networkX graph to operate on.
     :param output_edges_match: match of output edges between old and new sub-graph.
     :return: None
     """
     for old_name_port, new_name_port in output_edges_match.items():
         old_node_name, old_out_port = __class__.extract_port(old_name_port)
         new_node_name, new_out_port = __class__.extract_port(new_name_port)
         for src, dst, edge_attrs in graph.out_edges(old_node_name, data=True):
             if edge_attrs['out'] == old_out_port:
                 new_edge_attrs = edge_attrs.copy()
                 new_edge_attrs['out'] = new_out_port
                 # Add control_flow ports, as we do not copy control flow ports to new node
                 if 'control_flow_edge' in new_edge_attrs and new_edge_attrs['control_flow_edge'] is True:
                     in_port_id = 'control_flow_{}'.format(new_edge_attrs['in'])
                     out_port_id = 'control_flow_{}'.format(new_edge_attrs['out'])
                     in_node, out_node = Node(graph, dst), Node(graph, new_node_name)
                     # if not out_node.has_port('out', out_port_id, control_flow=True):
                     out_node.add_output_port(out_port_id, control_flow=True, skip_if_exist=True)
                     # if not in_node.has_port('in', in_port_id, control_flow=True):
                     in_node.add_input_port(in_port_id, control_flow=True, skip_if_exist=True)
                 graph.add_edge(new_node_name, dst, **new_edge_attrs)
                 log.debug("Created edge from {} to {} with attrs: {}".format(new_node_name, dst, new_edge_attrs))
Example #3
    def find_and_replace_pattern(self, graph: Graph):
        mp = {}
        used = {}
        for node in graph.get_op_nodes(type='Concat'):
            in_nodes = tuple(
                [node.in_node(idx).id for idx in range(len(node.in_nodes()))])
            out_node = (node.id, node.out_node().id)
            if in_nodes in mp:
                log.warning("Something is weird! {} and {}".format(
                    node.id, mp[in_nodes]))
            else:
                mp.update({in_nodes: out_node})
                used.update({node.id: {x: False for x in in_nodes}})

        for key in mp.keys():
            replacers = []
            for i in range(len(key)):
                for j in range(i + 1, len(key)):
                    arr = tuple(key[i:j + 1])
                    if arr in mp.keys() and arr != key:
                        replacers.append((len(arr), arr))

            replacers.sort(reverse=True)

            concat_id = mp[key][0]
            for ln, arr in replacers:
                # Check that none of these inputs has already been replaced
                we_can = True
                for x in arr:
                    if used[concat_id][x]:
                        we_can = False
                        break

                if not we_can:
                    continue

                for x in arr:
                    used[concat_id][x] = True

                edge_attrs = graph.get_edge_data(arr[0], concat_id)[0]
                for in_node in arr:
                    graph.remove_edge(in_node, concat_id)

                new_input = mp[arr][1]
                out_port = len(Node(graph, new_input).out_nodes()) + 1
                edge_attrs['out'] = out_port
                graph.add_edge(new_input, concat_id, **edge_attrs)

                # Renumber 'in' attrs
                concat_node = Node(graph, concat_id)
                ports = [x for x in concat_node.in_nodes().keys()]
                ports.sort()

                for p_id, p in enumerate(ports):
                    in_node = concat_node.in_nodes()[p]
                    graph[in_node.id][concat_id][0]['in'] = p_id
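
The core search looks for contiguous sub-tuples of a Concat's inputs that already have their own Concat; a toy
run of the same loop with hypothetical node ids:

mp = {('a', 'b', 'c', 'd'): ('concat_big', 'concat_big_data'),
      ('b', 'c'): ('concat_small', 'concat_small_data')}

key = ('a', 'b', 'c', 'd')
replacers = []
for i in range(len(key)):
    for j in range(i + 1, len(key)):
        arr = key[i:j + 1]
        if arr in mp and arr != key:
            replacers.append((len(arr), arr))
replacers.sort(reverse=True)
print(replacers)  # [(2, ('b', 'c'))] -> feed concat_small's output into concat_big instead of b and c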
Example #4
    def add_output_reshape(graph: Graph, match: dict):
        """
        Since the MXNet Y output shape is [batch_size, seq_len, hidden_size * num_directions], we need to add a
        reshape from the common format [batch_size, num_directions, seq_len, hidden_size] to the MXNet format.
        """
        lstm = match['rnn_layer']
        input = match['input']
        if not lstm.has_num_directions:
            return
        old_data_node = lstm.out_node(0)
        num_directions = 2 if lstm.direction == 'bidirectional' else 1
        mxnet_shape = lstm.out_node(0).shape.copy()

        if lstm.batch_dim == 0:
            mo_shape = shape_array([
                input.shape[lstm.batch_dim], input.shape[lstm.sequence_dim],
                lstm.hidden_size
            ])
        else:
            mo_shape = shape_array([
                input.shape[lstm.sequence_dim], input.shape[lstm.batch_dim],
                lstm.hidden_size
            ])

        if lstm.has_num_directions:
            mo_shape = shape_insert(mo_shape, 1, np.int64(num_directions))

        lstm_name = lstm.soft_get('name', lstm.id)

        new_data = Op._create_data_node(graph, name=lstm_name + '/Data/Reshape_mxnet/',
                                        attrs={'shape': mo_shape})
        graph.remove_edge(lstm.id, old_data_node.id)
        graph.add_edge(lstm.id, new_data.id, key=0, out=0)

        # Add Transpose
        permute_order = Const(
            graph, {
                'name': lstm_name + '/Transpose_mxnet_order',
                'value': int64_array([0, 2, 1, 3])
            }).create_node_with_data()
        permute_data = Transpose(graph, {
            'name': lstm_name + '/Transpose_mxnet/'
        }).create_node_with_data([new_data, permute_order])

        # Add Reshape
        reshape = Reshape(graph, {'name': lstm_name + '/Reshape_mxnet/'})
        reshape_dim_data = Const(
            graph, {
                'name': lstm_name + '/Reshape_mxnet_dim',
                'value': int64_array(unmask_shape(mxnet_shape))
            }).create_node_with_data()

        reshape.create_node_with_data([permute_data, reshape_dim_data],
                                      dict(),
                                      data_nodes=[old_data_node])
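
The Transpose + Reshape pair implements the layout conversion described in the docstring; a NumPy sketch with
hypothetical sizes:

import numpy as np

batch, num_directions, seq_len, hidden = 2, 2, 5, 8
y_common = np.random.rand(batch, num_directions, seq_len, hidden)
# Transpose order [0, 2, 1, 3]: [b, d, s, h] -> [b, s, d, h],
# then Reshape to [b, s, d * h], the MXNet output layout
y_mxnet = y_common.transpose([0, 2, 1, 3]).reshape([batch, seq_len, num_directions * hidden])
assert y_mxnet.shape == (2, 5, 16)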
Example #5
def batch_norm_fuse_action(graph: Graph, match: dict):
    """
    Multiply convolution kernel by batch normalization coefficient and remove mul op.
    """
    if match['norm'].value is None or match['kernel'].value is None:
        # cannot fuse non-const normalization coefficients
        return
    if len(graph.out_edges(match['conv_output'].node)) > 1 or len(graph.out_edges(match['kernel'].node)) > 1:
        # we cannot modify the original kernel or convolution if they are used multiple times
        # TODO make a copy of conv and kernel instead of this check
        return
    match['kernel'].value = match['kernel'].value * match['norm'].value
    graph.remove_edge(match['conv_output'].node, match['mul'].node)
    graph.remove_edge(match['mul'].node, match['mul_output'].node)
    # graph.remove_node(match['mul'].node)  # if we remove a node, next iteration over isomorphisms gives an error
    graph.add_edge(match['conv'].node, match['mul_output'].node, out=0)
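
The fusion is valid because convolution is linear in its kernel: scaling the output by a constant is the same as
convolving with a pre-scaled kernel. A one-dimensional NumPy check:

import numpy as np

x = np.random.rand(16)
kernel = np.random.rand(3)
norm = 0.5  # batch normalization coefficient, hypothetical

mul_after_conv = np.convolve(x, kernel, mode='valid') * norm
fused_kernel = np.convolve(x, kernel * norm, mode='valid')
assert np.allclose(mul_after_conv, fused_kernel)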
Example #6
 def replace_input_edges(graph: Graph, input_edges_match: dict):
     """
     Replace existing input edges with new ones that point to the new sub-graph.
     :param graph: networkX graph to operate on.
     :param input_edges_match: match of input edges between old and new sub-graph.
     :return: None
     """
     for old_name_port, new_name_port in input_edges_match.items():
         old_node_name, old_in_port = __class__.extract_port(old_name_port)
         new_node_name, new_in_port = __class__.extract_port(new_name_port)
         old_node = Node(graph, old_node_name)
         src_node_name = old_node.get_sorted_inputs()[old_in_port][0]
         edge_attrs = graph[src_node_name][old_node_name][0].copy()
         edge_attrs['in'] = new_in_port
         graph.add_edge(src_node_name, new_node_name, **edge_attrs)
         log.debug("Created edge from {} to {} with attrs: {}".format(src_node_name, new_node_name, edge_attrs))
Example #7
def add_edge_caffe(graph: Graph, bottom: str, dst_layer: str, blob_producers: dict, dst_port: int):
    """
    Creates an edge and adds it to the graph.
    """
    src_layer = blob_producers[bottom][0]
    src_port = blob_producers[bottom][1]
    edge_attrs = {
        'out': src_port,
        'in': dst_port,
        'name': bottom,
        # debug anchor for a framework name and tensor name
        'fw_tensor_debug_info': [(blob_producers[bottom][2], bottom)],
        'in_attrs': ['in', 'name'],
        'out_attrs': ['out', 'name'],
        'data_attrs': ['fw_tensor_debug_info']
    }
    graph.add_edge(src_layer, dst_layer, **edge_attrs)
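
A usage sketch, assuming the blob_producers layout implied by the indices above: blob name -> (producer layer id,
producer output port, framework tensor name). A plain networkx MultiDiGraph stands in for Graph, which subclasses it:

import networkx as nx

g = nx.MultiDiGraph()
g.add_node('conv1')
g.add_node('relu1')
blob_producers = {'conv1_out': ('conv1', 0, 'conv1_out')}

add_edge_caffe(g, bottom='conv1_out', dst_layer='relu1', blob_producers=blob_producers, dst_port=0)
print(g.get_edge_data('conv1', 'relu1')[0]['name'])  # conv1_out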
Example #8
def pad_op_transform(graph: Graph, match: dict):
    op = match['op']
    pad_op = match['pad_op']
    input_data = pad_op.in_node(0)

    if pad_op.mode != 'constant':
        log.info(
            'The pad node "{}" with pad mode "{}" cannot be fused.'.format(
                pad_op.soft_get('name'), pad_op.mode))
        return

    if op.type == 'Pooling' and op.pool_method == 'max':
        return

    if pad_op.mode == 'constant':
        fill_value = pad_op.in_port(3).data.get_value()
        if fill_value is None or fill_value != 0.0:
            log.info(
                'The pad node "{}" with non-zero fill value cannot be fused.'.
                format(pad_op.soft_get('name')))
            return

    input_tensor_dims = len(match['pad_output'].shape)
    for in_port in [1, 2]:
        pads = pad_op.in_port(in_port).data.get_value()
        if pads[get_features_dim(op.graph.graph['layout'], input_tensor_dims)] != 0 or \
                pads[get_batch_dim(op.graph.graph['layout'], input_tensor_dims)] != 0:
            log.info(
                'The pad node "{}" with padding over feature/batch dimension cannot be fused.'
                .format(pad_op.soft_get('name')))
            return

    op.pad += np.concatenate([pad_op.in_port(1).data.get_value().reshape([-1, 1]),
                              pad_op.in_port(2).data.get_value().reshape([-1, 1])], axis=1)
    op.pad_spatial_shape = op.pad[op.spatial_dims]
    op['auto_pad'] = None
    if op.type == 'Pooling':
        op['exclude_pad'] = False
    assert (graph[match['pad_output'].node][match['op'].node][0]['in'] == 0)
    edge_attrs = graph.get_edge_data(match['pad_output'].id, match['op'].id)[0]
    graph.remove_edge(match['pad_output'].id, match['op'].id)
    graph.add_edge(input_data.id, match['op'].id, **{'in': 0, **edge_attrs})
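
The pad update stacks the Pad node's begin/end values into per-axis (begin, end) pairs that are added to the
operation's own 'pad' attribute; a NumPy sketch with hypothetical NCHW paddings:

import numpy as np

pads_begin = np.array([0, 0, 1, 1])  # N, C, H, W
pads_end = np.array([0, 0, 2, 2])
op_pad = np.zeros((4, 2), dtype=np.int64)  # the operation's current 'pad' attribute
op_pad += np.concatenate([pads_begin.reshape([-1, 1]), pads_end.reshape([-1, 1])], axis=1)
print(op_pad.tolist())  # [[0, 0], [0, 0], [1, 2], [1, 2]]; batch/feature rows must stay zero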
Example #9
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['op']
        node.op = 'Conv2D'

        if node.bias_term:
            num_inputs = len(node.in_nodes()) - 2
            w_node = node.in_node(len(node.in_nodes()) - 2)
            b_node = node.in_node(len(node.in_nodes()) - 1)
        else:
            num_inputs = len(node.in_nodes()) - 1
            w_node = node.in_node(len(node.in_nodes()) - 1)

        for i in range(1, num_inputs):
            in_i = node.in_node(i)
            out_i = node.out_node(i)
            conv_id = graph.unique_id(node.id + '__')
            graph.add_node(conv_id, **copy.deepcopy(node.get_attrs()))
            new_conv = Node(graph, conv_id)
            new_conv.name = conv_id

            graph.remove_edge(in_i.id, node.id)
            graph.remove_edge(node.id, out_i.id)
            graph.add_edges_from([
                (w_node.id, conv_id, {
                    'in': 1,
                    'bin': 'weights'
                }),
            ])

            if node.bias_term:
                graph.add_edges_from([
                    (b_node.id, conv_id, {
                        'in': 2,
                        'bin': 'biases'
                    }),
                ])

            graph.add_edges_from([
                (in_i.id, conv_id, {
                    'in': 0
                }),
            ])
            graph.add_edge(conv_id, out_i.id, **{'out': 0})
Example #10
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        if not node.has_port('in', 2) or node.in_port(2).disconnected() or not node.has_and_set('shape_input'):
            return

        if node.has_valid('layout') and not node.layout.startswith('NC') and graph.graph['layout'] == 'NCHW':
            input_shape_rank = len(node.in_port(0).data.get_shape())
            permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(input_shape_rank)

            data_node = node.in_node(2)

            name = node.soft_get('name', node.id) + '/ShapeGather'
            const = Const(graph, {'value': permutation.perm, 'name': name + '/Const',
                                  'need_shape_inference': True}).create_node_with_data()
            axis_const = Const(graph, {'value': int64_array(0), 'name': name + '/Axis'}).create_node_with_data()
            gather = Gather(graph, {'name': name,
                                    'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
            attrs = graph.get_edge_data(data_node.id, node.id, key=0).copy()

            graph.add_edge(gather.id, node.id, **attrs)
            graph.remove_edge(data_node.id, node.id)
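
The inserted Gather reorders the values of the shape input from NHWC to NCHW; a NumPy equivalent with a
hypothetical rank-4 shape:

import numpy as np

nhwc_shape = np.array([1, 224, 224, 3])
perm = np.array([0, 3, 1, 2])            # NHWC -> NCHW permutation for rank 4
nchw_shape = np.take(nhwc_shape, perm)   # what Gather(shape, perm, axis=0) computes
print(nchw_shape)                        # [1 3 224 224]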
Example #11
def build_graph_with_edge_attrs(nodes_attrs: dict,
                                edges: list,
                                update_attributes: dict = None,
                                cli: Namespace = Namespace(static_shape=False,
                                                           data_type='FP32')):
    """
    Build the Graph with specific nodes and edges.
    :param nodes_attrs: dictionary where key is the node name and the value is the dictionary with node attributes.
    :param edges: list of triples (start node name, end node name, edge attributes dictionary).
    :param update_attributes: optional dictionary which specifies nodes names and their attributes to be updated. The
    key is a node name to update attribute and the value is a dictionary with attribute name and its value.
    :param cli: Namespace with cli keys to associate with the graph
    :return: generated graph.
    """
    graph = Graph()
    for node_1, node_2, attr in edges:
        if node_1 not in graph.nodes():
            graph.add_node(node_1, **deepcopy(nodes_attrs[node_1]))
        if node_2 not in graph.nodes():
            graph.add_node(node_2, **deepcopy(nodes_attrs[node_2]))
        graph.add_edge(node_1, node_2, **attr)
    if update_attributes is not None:
        for node_name, new_attrs in update_attributes.items():
            assert (node_name in graph.nodes())
            for attr, value in new_attrs.items():
                graph.node[node_name][attr] = value

    for node in graph.get_op_nodes():
        # Add in_ports attribute
        in_edges = node.in_edges()
        for attr in in_edges.values():
            node.add_input_port(idx=attr['in'])

        # Add out_ports attribute
        out_edges = node.out_edges()
        for attr in out_edges.values():
            node.add_output_port(idx=attr['out'])

    graph.graph['cmd_params'] = cli
    return graph
Example #12
    def replace_op(self, graph: Graph, node: Node):
        ss_node = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, {'name': 'Split_eltwise_' + node.name,
                                                                                  'num_splits': node['num_inputs']})

        inp = node.get_inputs()
        in_node = inp[0][0]
        edge_attrs = inp[0][1]
        graph.add_edge(in_node, ss_node.id, **edge_attrs)
        if ss_node.num_splits == 2:
            if node['operation'] == 'mul':
                eltwise_node = Mul(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
            elif node['operation'] == 'sum':
                eltwise_node = Add(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
            else:
                raise Error('Error on replacing Kaldi eltwise: unknown type ' + node['operation'])
        elif ss_node.num_splits > 2:
            eltwise_node = EltwiseN(graph, attrs={'name': 'Eltwise_' + node.name,
                                                  'operation': node['operation']}).create_node()
        else:
            raise Error('Error on replacing Kaldi eltwise')
        for i in range(ss_node.num_splits):
            ss_node.out_port(i).get_connection().set_destination(eltwise_node.in_port(i))
        return [eltwise_node.id]
Example #13
    def find_and_replace_pattern(self, graph: Graph):
        # Iterate over all data nodes and find those with more than one consumer
        for input_data in list(graph.get_data_nodes()):
            # Skip constant data nodes
            if input_data.value is not None:
                continue

            if input_data.shape is None:
                continue
            input_shape = shape_array(input_data.shape)

            # Get all unique StridedSlice consumers
            out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice' and
                         node.in_node(0).id == input_data.id]

            if len(out_nodes) <= 1:
                continue

            valid_for_replacement = True
            for n in out_nodes:
                if any(not isinstance(s, slice) for s in n.slices):
                    # this is a slice with dynamic dimension. Such operation is not valid for replacement
                    valid_for_replacement = False
            if not valid_for_replacement:
                continue

            sorted_out_nodes = sorted(out_nodes, key=lambda n: list(n.slices))
            out_nodes = unique_by(sorted_out_nodes, strided_slices_equality)

            for node in out_nodes:
                if len(node.slices) != len(out_nodes[0].slices):
                    valid_for_replacement = False

            # Detect dimension for splitting
            split_channel_dim = None
            for dim_id, s in enumerate(out_nodes[0].slices):
                l, r, stride = s.start, s.stop, s.step
                # if both l and r are None then the dimension is not sliced
                if (l != 0 or r != input_shape[dim_id]) and (l is not None or r is not None):
                    if split_channel_dim is None:
                        split_channel_dim = dim_id
                    else:
                        valid_for_replacement = False

            if split_channel_dim is None:
                valid_for_replacement = False

            # split_dims contains tuples with split range and output data node
            split_dims = []
            for out_id, node in enumerate(out_nodes):
                # Check that the StridedSlice op has stride 1 and slices only along the split dimension
                for slice_id, s in enumerate(node.slices):
                    l, r, stride = s.start, s.stop, s.step
                    # We don't support StridedSlice with stride != 1
                    if stride != 1:
                        valid_for_replacement = False
                    if slice_id == split_channel_dim:
                        split_dims.append((s.start, s.stop, node.out_node()))

            if not valid_for_replacement:
                continue

            # Check feature split intersection
            final_data_nodes_list = []
            sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

            # check if we have similar StridedSlice operations with different outputs
            prev_sd = sorted_split_dims[0]
            to_remove = []
            for i in range(1, len(sorted_split_dims)):
                if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] and sorted_split_dims[i][2].name != prev_sd[2].name:
                    cur_node = sorted_split_dims[i][2]
                    for out in cur_node.out_nodes():
                        attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                        graph.remove_edge(cur_node.id, out.id)
                        graph.add_edge(prev_sd[2].id, out.id, **attrs)
                    to_remove.append(i)

            for ind in reversed(to_remove):
                sorted_split_dims.pop(ind)

            size_splits = []
            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Split dims shouldn't intersect
                if l < prev_r:
                    valid_for_replacement = False
                prev_r = r

            if prev_r > input_shape[split_channel_dim]:
                valid_for_replacement = False

            if not valid_for_replacement:
                continue

            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Save missing tensor part
                if l > prev_r:
                    shape = mo_array(input_shape)
                    size_splits.append(l - prev_r)
                    shape[split_channel_dim] = l - prev_r
                    data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                    add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
                    final_data_nodes_list.append(data_node)

                prev_r = r
                size_splits.append(r - l)
                final_data_nodes_list.append(out)

            if prev_r < input_shape[split_channel_dim]:
                # Add last part of tensor
                shape = input_shape.copy()
                shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
                size_splits.append(input_shape[split_channel_dim] - prev_r)
                data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
                final_data_nodes_list.append(data_node)

            for node in out_nodes:
                if not np.all([x == 0 for x in node.shrink_axis_mask]):
                    out_node = node.out_node()
                    if np.any(node['shrink_axis_mask']):
                        self.add_squeeze_for_shrink(graph, node)
                    if np.any(node['new_axis_mask']):
                        self.add_unsqueeze_for_new(graph, node)

                    for i in range(len(final_data_nodes_list)):
                        if final_data_nodes_list[i].name == out_node.name:
                            final_data_nodes_list[i] = node.out_node()
                            break

            # Insert Split layer and remove old StridedSlice layers
            # 1. Remove connections from input_data to StridedSlice ops
            out_data_nodes = []
            name_for_future_split = out_nodes[0].name
            for node in out_nodes:
                out_data_nodes.append(node.out_node())
                graph.remove_edge(input_data.id, node.id)
                graph.remove_edge(node.id, node.out_node().id)
                graph.remove_node(node.id)
                log.debug("Removed: {}".format(node.id))

            # 2. Create Split layer and reorder outputs
            name = name_for_future_split + "/Split"
            axis_const = Const(graph, {'value': int64_array(split_channel_dim),
                                       'name': name + '/Axis'}).create_node_with_data()
            size_splits_const = Const(graph, {'value': int64_array(size_splits),
                                              'name': name + '/Sizes'}).create_node_with_data()
            split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))

            split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
                                        data_nodes=final_data_nodes_list)
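
How the sorted slice ranges become VariadicSplit sizes, with gaps covered by "fake" outputs; a toy run with
hypothetical [l, r) ranges over a 16-channel dimension:

channel_size = 16
sorted_split_dims = [(0, 4), (8, 12)]  # ranges taken by the StridedSlice consumers

size_splits = []
prev_r = 0
for l, r in sorted_split_dims:
    if l > prev_r:
        size_splits.append(l - prev_r)  # uncovered gap -> fake output data node
    size_splits.append(r - l)
    prev_r = r
if prev_r < channel_size:
    size_splits.append(channel_size - prev_r)  # uncovered tail -> fake output data node
print(size_splits)  # [4, 4, 4, 4]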
Example #14
def build_graph_with_attrs(nodes_with_attrs: list,
                           edges_with_attrs: list,
                           new_nodes_with_attrs: list = [],
                           new_edges_with_attrs: list = [],
                           update_edge_attrs: dict = None,
                           update_nodes_attributes: list = None,
                           nodes_with_edges_only: bool = False,
                           add_nodes_from_edges: bool = False):
    """
    Build the Graph with specific nodes and edges. Also update of edge and node parameters is supported.
    :param nodes_with_attrs: list of tuples ('node_name', {node_attrs})
    :param edges_with_attrs: list of tuples like (start node, end node, (optional) {attrs of the edge}).
    :param new_nodes_with_attrs: same format as nodes_with_attrs, but listing only new nodes.
    :param new_edges_with_attrs: same format as edges_with_attrs, but listing only new edges.
    :param update_edge_attrs: optional dictionary like {('from_node', 'to_node', key): {edge_attrs}}.
    :param update_nodes_attributes: optional list of tuples which specifies nodes names and their attributes to be
    updated. The first element is a node name to update the attribute of, and the second element is a dictionary with
    attribute names and values.
    :param nodes_with_edges_only: add only nodes that have at least one incoming or outgoing edge.
    :param add_nodes_from_edges: whether nodes that are not listed in the nodes but appear in the edges may be added.
    :return: generated graph.
    """
    if not_all_new([node[0] for node in nodes_with_attrs],
                   [node[0] for node in new_nodes_with_attrs]):
        raise Error(
            'Some nodes from new_nodes_with_attrs are already in nodes.'
            ' Please add only NEW nodes to new_nodes_with_attrs.')

    if not_all_new([(edge[0], edge[1]) for edge in edges_with_attrs],
                   [(edge[0], edge[1]) for edge in new_edges_with_attrs]):
        raise Error(
            'Some edges from new_edges_with_attrs are already in edges.'
            ' Please add only NEW edges to new_edges_with_attrs.')

    # Check that all nodes from list of edges are in nodes
    all_nodes = nodes_with_attrs + new_nodes_with_attrs
    all_edges = edges_with_attrs + new_edges_with_attrs
    all_nodes_names = [node[0] for node in all_nodes]
    if not add_nodes_from_edges and not all_edges_in_nodes(
            nodes=all_nodes_names, edges=all_edges):
        raise Error(
            "Some nodes from the list of edges are not in nodes. Please add all necessary nodes.")

    graph = Graph()

    # Create dict for nodes with attrs
    nodes_attrs = {}
    for node_name, attrs in all_nodes:
        nodes_attrs[node_name] = attrs
        if 'name' not in attrs:
            attrs['name'] = node_name

    if nodes_with_edges_only:
        # filter nodes to keep only ones with edges connected
        filtered_nodes = {}
        for edge in all_edges:
            node_1, node_2 = edge[0], edge[1]
            filtered_nodes[node_1] = nodes_attrs[node_1]
            filtered_nodes[node_2] = nodes_attrs[node_2]
        nodes_attrs = filtered_nodes

    # Create all nodes
    for node, attrs in nodes_attrs.items():
        graph.add_node(node, **deepcopy(attrs))

    # Connect nodes with edges (also unpack edge params)
    for edge in all_edges:
        node_1, node_2 = edge[0], edge[1]
        edge_attrs = edge[2] if len(edge) == 3 else {}
        graph.add_edge(node_1, node_2, **edge_attrs)

    # Update attributes of edges
    if update_edge_attrs:
        # it will work in 2.x networkx only
        for edge, attr in update_edge_attrs.items():
            for k, v in attr.items():
                nx.set_edge_attributes(G=graph, name=k, values={edge: v})

    # Update attributes of nodes
    if update_nodes_attributes is not None:
        for node_name, new_attrs in update_nodes_attributes:
            assert (node_name in graph.nodes())
            for attr, value in new_attrs.items():
                graph.node[node_name][attr] = value

    for node_id in graph.nodes():
        node = Node(graph, node_id)
        check_and_update_ports(node, [
            graph.get_edge_data(edge[0], node_id)[0]
            for edge in graph.in_edges(node_id)
        ], True)
        check_and_update_ports(node, [
            graph.get_edge_data(node_id, edge[1])[0]
            for edge in graph.out_edges(node_id)
        ], False)

    for node in graph.get_op_nodes():
        # Add in_ports attribute
        in_edges = node.in_edges()
        for i in range(len(in_edges)):
            node.add_input_port(idx=i)

        # Add out_ports attribute
        out_edges = node.out_edges()
        for i in range(len(out_edges)):
            node.add_output_port(idx=i)
    return graph
Example #15
def build_graph(nodes_attrs: dict,
                edges: list,
                update_attributes: dict = None,
                nodes_with_edges_only: bool = False,
                cli: Namespace = None):
    """
    Build the Graph with specific nodes and edges.
    :param nodes_attrs: dictionary where key is the node name and the value is the dictionary with node attributes.
    :param edges: list of pairs with start and end node names of the edge.
    :param update_attributes: optional dictionary which specifies nodes names and their attributes to be updated. The
    key is a node name to update attribute and the value is a dictionary with attribute name and its value.
    :param nodes_with_edges_only: add only nodes that have at least one incoming or outgoing edge.
    :param cli: Namespace with cli keys to associate with the graph
    :return: generated graph.
    """
    # no mutable values must be set as default function argument
    cli = Namespace(static_shape=False,
                    data_type='FP32') if cli is None else cli
    graph = Graph()

    for node_name, attrs in nodes_attrs.items():
        if 'name' not in attrs:
            attrs['name'] = node_name

    if nodes_with_edges_only:
        # filter nodes to keep only ones with edges connected
        filtered_nodes = {}
        for item in edges:
            if len(item) == 2:  # TODO: is there any better way in python to do that?
                node1, node2 = item
            else:
                node1, node2, _ = item
            filtered_nodes[node1] = nodes_attrs[node1]
            filtered_nodes[node2] = nodes_attrs[node2]

    # create all nodes first
    for node, attrs in nodes_attrs.items():
        assert node not in graph.nodes()
        graph.add_node(node, **deepcopy(attrs))

    # connect nodes with edges
    for item in edges:
        if len(item) == 2:  # TODO: is there any better way in python to do that?
            node_1, node_2 = item
            edge_attrs = {}
        else:
            node_1, node_2, edge_attrs = item

        common_attrs = {
            'in': len(graph.in_edges(node_2)),
            'out': len(graph.out_edges(node_1)),
            'name': nodes_attrs[node_1]['name']
        }
        common_attrs.update(edge_attrs)
        graph.add_edge(node_1, node_2, **common_attrs)

    if update_attributes is not None:
        for node_name, new_attrs in update_attributes.items():
            assert node_name in graph.nodes(), 'Node with name "{}" is not in the graph'.format(node_name)
            for attr, value in new_attrs.items():
                graph.node[node_name][attr] = value

    for node in graph.get_op_nodes():
        # Add in_ports attribute
        in_edges = node.in_edges(control_flow=True)
        for attr in in_edges.values():
            control_flow = attr.get('control_flow_edge') is True
            node.add_input_port(idx=attr['in'], control_flow=control_flow)

        # Add out_ports attribute
        out_edges = node.out_edges(control_flow=True)
        for attr in out_edges.values():
            control_flow = attr.get('control_flow_edge') is True
            node.add_output_port(idx=attr['out'], control_flow=control_flow)

    graph.graph['cmd_params'] = cli
    return graph
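
A minimal usage sketch of build_graph as defined above; the 'kind'/'op' node attribute conventions follow the
Model Optimizer style seen in the other examples, and the exact attributes a real test needs may vary:

nodes_attrs = {
    'param': {'kind': 'op', 'op': 'Parameter'},
    'param_d': {'kind': 'data', 'shape': None, 'value': None},
    'relu': {'kind': 'op', 'op': 'ReLU'},
}
edges = [
    ('param', 'param_d', {'out': 0}),
    ('param_d', 'relu', {'in': 0}),
]
graph = build_graph(nodes_attrs, edges)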
Example #16
def merge_nodes(graph: Graph,
                nodes_to_merge_names: list,
                inputs_desc: list = None,
                outputs_desc: list = None):
    """
    Merges the nodes specified in the list 'nodes_to_merge_names' into one mega-node, creating new edges between the
    mega-node and the input/output nodes of the mega-node. The added edges contain the names of the input/output nodes,
    which will be used to generate placeholders and will be saved to the IR XML so the IE plug-in knows how to map
    input/output data for the layer. The function also adds the protobufs of the sub-graph nodes and of the 'Const'
    ops consumed by nodes in the sub-graph to the node's attribute 'pbs'.
    :param graph: the graph object to operate on.
    :param nodes_to_merge_names: list of nodes names that should be merged into a single node.
    :param inputs_desc: optional list describing input nodes order.
    :param outputs_desc: optional list describing output nodes order.
    """
    if not is_connected_component(graph, nodes_to_merge_names):
        log.warning(
            "The following nodes do not form connected sub-graph: {}".format(
                nodes_to_merge_names))
        # graph.dump_graph_for_graphviz(nodes_to_dump=nodes_to_merge_names)

    new_node_name = graph.unique_id("TFSubgraphCall_")
    log.info("Create new node with name '{}' for nodes '{}'".format(
        new_node_name, ', '.join(nodes_to_merge_names)))
    graph.add_node(new_node_name)
    new_node_attrs = graph.node[new_node_name]

    new_node_attrs['name'] = new_node_name
    set_tf_custom_call_node_attrs(new_node_attrs)
    new_node = Node(graph, new_node_name)

    added_input_tensors_names = set()  # tensors that were added as inputs to the sub-graph
    added_new_node_output_tensors = dict()  # key - tensor name, value - out port

    for node_name in nodes_to_merge_names:
        node = Node(graph, node_name)
        add_node_pb_if_not_yet_added(node, new_node)
        # TODO: any improvements?
        for in_node_name, edge_attrs in Node(graph, node_name).get_inputs():
            in_node = Node(graph, in_node_name)

            # internal edges between nodes of the sub-graph
            if in_node_name in nodes_to_merge_names:
                add_node_pb_if_not_yet_added(in_node, new_node)
                continue

            # edge outside of sub-graph into sub-graph
            if in_node_name not in nodes_to_merge_names:
                # we cannot use the 'in_node_name' as a protobuf operation name here
                # because the 'in_node_name' could be a sub-graph matched before.
                input_tensor_name = node.pb.input[edge_attrs['in']]
                if input_tensor_name not in added_input_tensors_names:
                    if not new_node.has_port('in', edge_attrs['in']):
                        new_node.add_input_port(edge_attrs['in'])
                    graph.add_edge(
                        in_node_name, new_node_name,
                        **merge_edge_props(
                            {
                                'in': find_input_port(new_node, inputs_desc, node_name, edge_attrs['in']),
                                'out': edge_attrs['out'],
                                'internal_input_node_name': input_tensor_name,
                                'original_dst_node_name': node_name,
                                'original_dst_port': edge_attrs['in'],
                                'in_attrs': ['in', 'internal_input_node_name', 'original_dst_node_name',
                                             'original_dst_port', 'placeholder_name'],
                                'out_attrs': ['out']
                            }, edge_attrs))
                    log.debug(
                        "Creating edge from outside of sub-graph to inside sub-graph: {} -> {}"
                        .format(in_node_name, new_node_name))
                    added_input_tensors_names.add(input_tensor_name)

        # edge from inside sub-graph to outside sub-graph
        for out_node_name, edge_attrs in Node(graph, node_name).get_outputs():
            if out_node_name not in nodes_to_merge_names:
                log.debug(
                    "Creating edge from inside of sub-graph to outside sub-graph: {} -> {}"
                    .format(new_node_name, out_node_name))
                out_name = internal_output_name_for_node(
                    node_name, edge_attrs['out'])
                if out_name not in added_new_node_output_tensors.keys():
                    added_new_node_output_tensors[out_name] = find_output_port(
                        new_node, outputs_desc, node_name, edge_attrs['out'])
                if not new_node.has_port(
                        'out', added_new_node_output_tensors[out_name]):
                    new_node.add_output_port(
                        added_new_node_output_tensors[out_name])
                graph.add_edge(
                    new_node_name, out_node_name,
                    **merge_edge_props(
                        {
                            'in': edge_attrs['in'],
                            'out': added_new_node_output_tensors[out_name],
                            'internal_output_node_name': out_name,
                            'in_attrs': ['in', 'internal_input_node_name'],
                            'out_attrs': ['out', 'internal_output_node_name']
                        }, edge_attrs))
        new_node['output_tensors_names'] = [
            val for val in
            {v: k
             for k, v in added_new_node_output_tensors.items()}.values()
        ]

    # add nodes using the same order as in initial GraphDef so we can dump them to IR in "correct" order
    new_node['nodes_order'] = [
        node for node in graph.graph['initial_nodes_order']
        if node in new_node['pbs'].keys()
    ]

    for n in nodes_to_merge_names:
        if graph.has_node(n):  # the node may have been removed by another (similar) pattern
            graph.remove_node(n)
    return Node(graph, new_node_name)
Example #17
    def replace_pattern(self, graph: Graph, match: dict):
        log.debug('================== ConditionFind ===============')
        # init_1
        init_1 = match['init_1_data'].value
        assert init_1 is not None
        init_1 = int(init_1)

        # init_2
        init_2 = match['init_2_data'].value
        assert init_2 is not None
        init_2 = int(init_2)

        # step_1
        assert match['add_1_y_data'].value is not None
        step_1 = int(match['add_1_y_data'].value)

        # step_2
        assert match['add_2_y_data'].value is not None
        step_2 = int(match['add_2_y_data'].value)

        dynamic_seq_len = self.check_dynamic_seq_len(graph, match)

        # Create condition node and delete all useless nodes from condition pattern
        loop_condition = match['loop_cond_data']
        iterator_data = self.looking_for_iteration_counter(graph, match)

        condition_attrs = dict(time=dict(init=init_2, step=step_2),
                               iter=dict(init=init_1, step=step_1),
                               name=match['loop_cond'].name +
                               '/TensorIteratorCondition_')
        condition = TensorIteratorCondition(graph, attrs=condition_attrs)
        condition_data = condition.create_node_with_data(
            inputs=[match['Strided_slice_data'], match['minimum_data']],
            data_nodes=[loop_condition, iterator_data])

        safe_nodes = [
            'loop_cond_data', 'Identity_1_data', 'Identity_2_data',
            'Strided_slice', 'Strided_slice_data', 'minimum', 'minimum_data'
        ]

        identity_ops = [n.op for n in iterator_data.out_nodes()]
        if 'GreaterEqual' in identity_ops:
            greater_equal_id = [
                n.id for n in iterator_data.out_nodes()
                if n.op == 'GreaterEqual'
            ][0]

            if dynamic_seq_len:
                # Add BackEdge for time iterator node
                backedge = TensorIteratorBackEdge(
                    graph, dict(name='/TimeIterator/TensorIteratorBackEdge_'))
                backedge_data = backedge.create_node_with_data(inputs=[
                    match['init_2_data'], match['add_2_data'],
                    condition_data[0]
                ], )

                graph.remove_edge(match['add_2'].in_node(0).id,
                                  match['add_2'].id)
                graph.add_edge(backedge_data.id, match['add_2'].id,
                               **{'in': 0})

                graph.remove_edge(iterator_data.id, greater_equal_id)
                graph.add_edge(backedge_data.id, greater_equal_id, **{'in': 0})

                # nodes for time iterator
                safe_nodes += [
                    'init_2_data', 'init_2', 'Identity_2_data', 'add_2_data',
                    'add_2', 'add_2_y', 'add_2_y_data'
                ]

                # Manually reshape all iterator nodes (for time) from 0D to 1D
                iterator_data_nodes = [
                    backedge_data, match['add_2_data'], match['add_2_y_data'],
                    match['add_2_y'], match['init_2_data'], match['init_2']
                ]
                make_nodes_1D(iterator_data_nodes)
            else:
                # Delete Selects from this cycle so that it is no longer dynamic:
                greater_equal_idxs = [
                    n.id for n in iterator_data.out_nodes()
                    if n.op == 'GreaterEqual'
                ]
                delete_selects_from(graph, greater_equal_idxs)

        # Delete useless nodes
        nodes_for_remove = []
        for node in match.keys():
            if node not in safe_nodes:
                nodes_for_remove.append(match[node].id)
        graph.remove_nodes_from(nodes_for_remove)
Example #18
    def replace_pattern(graph: Graph, match: dict):
        time_len = match['concatenated_hidden_states'].shape[0]
        r"""
        Working with the concatenated_cell_states_data part first, because the IE TensorIterator primitive doesn't
        have a concatenated cell states output; if we cannot collapse it, this type of BlockLSTM is not supported

        We simplify the sub-graph below by taking another output of BlockLSTM:
        concatenated cell states over the whole time sequence -> last cell state

        BlockLSTM
           || out 1 (concatenated cell states coming out of BlockLSTM)
           \/  in 1
        ConcatV2
           || (concatenation with initial state or another unused data)
           \/
        Reshape
           ||
           \/
         Gather (taking the last cell state from previous BlockLSTM, if Gather indexes == time_len)
        """
        # check that there are no other consumers of concatenated_cell_states_data data flow
        valid_output_names = [
            'concat_1', 'concat_1_data', 'reshape_1', 'reshape_1_data',
            'gather_1', 'gather_1_data'
        ]
        valid_output_node_ids = [match[name].id for name in valid_output_names]
        node_names_to_check_outputs = [
            'concatenated_cell_states_data', 'concat_1_data', 'reshape_1_data'
        ]
        for name in node_names_to_check_outputs:
            for node in match[name].out_nodes():
                if node.id not in valid_output_node_ids:
                    raise Error(
                        "BlockLSTM node {} has output which contains concatenated cell states over the whole "
                        "time sequence. It is not replaceable by another output and is not supported "
                        "originally".format(match['BlockLSTM'].id))

        # check that we really take the last cell state data by Gather
        gather_indexes = match['gather_1'].in_node(1).value
        if len(gather_indexes) == 1:
            gather_index = gather_indexes[0]
        else:
            raise Error(
                "BlockLSTM node {} has output which contains concatenated cell states over the whole "
                "time sequence. It is not replaceable by another output and is not supported "
                "originally".format(match['BlockLSTM'].id))
        if gather_index != time_len:
            raise Error(
                "BlockLSTM node {} has output which contains concatenated cell states over the whole "
                "time sequence. It is not replaceable by another output and is not supported "
                "originally".format(match['BlockLSTM'].id))
        """
        We have passed stages #1 and #2 from the class description. This means we can translate the rest of the
        pattern to LSTMSequence even without the following optimizations
        """

        node = match['BlockLSTM']
        weights_node = node.in_node(1)
        biases_node = node.in_node(2)
        shift_const = node.forget_bias

        # Assign temporary shape for them for easier manipulation
        # TF stores weights in IO order
        input_size = node.in_node(0).shape[-1]
        hidden_size = node.in_node(3).shape[-1]
        weights = weights_node.value
        biases = biases_node.value
        assert weights.shape[0] == input_size + hidden_size, \
            "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
        assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
            "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)

        weights = weights.reshape([
            weights.shape[0],
            4,  # gates
            hidden_size
        ])

        biases = biases.reshape([
            4,  # gates
            hidden_size
        ])

        # Reorder gates icfo --> fico for both weights and biases
        gate_reorder = [2, 0, 1, 3]
        weights = np.take(weights, gate_reorder, axis=1)
        biases = np.take(biases, gate_reorder, axis=0)

        # shift_const.value should be added to the first 1/4th part of the biases (f-gate: 0)
        # Note: in case of moving this code up before gate reordering, the addition
        # should be applied at different place
        biases[0] += shift_const

        # Return to the original shapes
        weights = weights.reshape([weights.shape[0], -1])
        biases = biases.flatten()

        # TF stores weights in IO, but IE requires it in OI: transpose
        weights = weights.transpose()

        weights_node.value = weights
        weights_node.shape = int64_array(weights.shape)
        biases_node.value = biases
        biases_node.shape = int64_array(biases.shape)

        attrs = dict(
            graph.get_edge_data(match['gather_1'].id,
                                match['gather_1_data'].id)[0])
        attrs.update({'out': 2})
        graph.remove_edge(match['BlockLSTM'].id,
                          match['concatenated_cell_states_data'].id)
        graph.remove_edge(match['gather_1'].id, match['gather_1_data'].id)

        match['BlockLSTM'].add_output_port(attrs['out'])
        graph.add_edge(match['BlockLSTM'].id, match['gather_1_data'].id,
                       **attrs)
        """
        #3 Renumbering h_init_state, c_init_state input ports to match RNNSequence ports order.
        """
        h_init_port = 4
        c_init_port = 5
        # c_init_state
        if 4 in node.in_nodes():
            assert c_init_port not in node.in_nodes()
            cell_state_edge = graph.get_edge_data(node.in_node(4).id, node.id)
            cell_state_edge[0]['in'] = c_init_port

        # h_init_state
        if 3 in node.in_nodes():
            assert h_init_port not in node.in_nodes()
            hidden_state_edge = graph.get_edge_data(
                node.in_node(3).id, node.id)
            hidden_state_edge[0]['in'] = h_init_port

        new_attrs = {
            'sequence_dim': 0,
            'batch_dim': 1,
            'direction': 'forward',
            'hidden_size': match['concatenated_hidden_states'].shape[-1],
            'format': 'tf',
        }

        LSTM.update_node_stat(match['BlockLSTM'], new_attrs)
        """
        Optional #4 optimization from class description following
        """
        data_to_mul = [
            n for n in match['mul'].in_nodes().values()
            if n.id != match['concatenated_hidden_states'].id
        ]
        if len(data_to_mul) != 1:
            return  # unexpected type of mul
        data_to_mul = data_to_mul[0]
        if not data_to_mul.has_valid('value'):
            return  # unexpected type of mul
        data_to_mul_value = data_to_mul.value
        if not np.all(data_to_mul_value == 1):
            return  # unexpected type of mul

        # remove useless mul
        attrs = dict(
            graph.get_edge_data(match['BlockLSTM'].id,
                                match['concatenated_hidden_states'].id)[0])
        graph.remove_edge(match['BlockLSTM'].id,
                          match['concatenated_hidden_states'].id)
        graph.remove_edge(match['mul'].id, match['mul_data'].id)
        graph.add_edge(match['BlockLSTM'].id, match['mul_data'].id, **attrs)

        # find true usages of concatenated hidden states data (not last hidden state)
        valid_output_names = [
            'mul_data', 'concat_0', 'concat_0_data', 'reshape_0',
            'reshape_0_data', 'gather_0', 'gather_0_data'
        ]
        valid_output_node_ids = [match[name].id for name in valid_output_names]
        node_names_to_check_outputs = [
            'mul_data', 'concat_0_data', 'reshape_0_data'
        ]

        list_of_concatenated_hidden_states_children_node_ids = []
        for name in node_names_to_check_outputs:
            for node in match[name].out_nodes():
                if node.id not in valid_output_node_ids:
                    list_of_concatenated_hidden_states_children_node_ids.append(
                        node.id)

        if len(list_of_concatenated_hidden_states_children_node_ids) != 1:
            return  # not supported placement of pattern
        concatenated_child_node_id = list_of_concatenated_hidden_states_children_node_ids[0]
        if concatenated_child_node_id != match['after_mul_op_to_the_rest_of_model'].id:
            return  # not supported placement of pattern

        gather_indexes = match['gather_0'].in_node(1).value
        if len(gather_indexes) == 1:
            gather_index = gather_indexes[0]
        else:
            return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is
        if gather_index != time_len:
            return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is

        attrs = dict(
            graph.get_edge_data(match['gather_0'].id,
                                match['gather_0_data'].id)[0])
        attrs.update({'out': 1})
        graph.remove_edge(match['mul_data'].id, match['concat_0'].id)
        graph.remove_edge(match['gather_0'].id, match['gather_0_data'].id)

        graph.add_edge(match['BlockLSTM'].id, match['gather_0_data'].id,
                       **attrs)
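
The icfo -> fico gate reordering can be checked in isolation; a NumPy sketch where axis 1 of the reshaped
weights indexes the four gates:

import numpy as np

gates = np.array(['i', 'c', 'f', 'o']).reshape([1, 4, 1])  # stand-in for weights of shape [input+hidden, 4, hidden]
gate_reorder = [2, 0, 1, 3]
reordered = np.take(gates, gate_reorder, axis=1)
print(reordered.ravel().tolist())  # ['f', 'i', 'c', 'o']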
Example #19
def protobuf2nx(graph: Graph, pb):
    """
    Convert a proto message with an ONNX model to the equivalent NX representation. All nodes and edges are restored
    here because the ONNX model has an op/data representation, meaning nodes are connected via tensor names. Tensor
    names are defined on demand in nodes, so the code here is similar to the Caffe case.

    :param graph: the Graph object to load the graph into
    :param pb: the ONNX file protobuf message
    :return: None
    """
    # maps a tensor name to the node that produced it and the node port: str -> (node_id, node_port)
    data_nodes_map = {}

    graph_pb = pb.graph
    add_initializers_and_inputs_to_graph(graph, graph_pb, data_nodes_map)

    output_ids = []
    for outp in graph_pb.output:
        name = str(outp.name)
        if graph.has_node(name):
            log.error(
                'Name {} of output node already exists in graph. Ignoring this output. If the output is required,'
                ' please rename it.'.format(name),
                extra={'is_warning': True})
            continue
        else:
            # add fake node on output
            graph.add_node(name, kind='op', op='FakeOutput', pb=outp)
            output_ids.append(name)

    # Go through all nodes in the original model order (because data nodes are defined on-the-fly and order is
    # important)
    for node in graph_pb.node:
        # create an NX node
        fw_name = node_id(node)
        id = graph.unique_id(fw_name)
        graph.add_node(id, pb=node, kind='op')
        if hasattr(graph, 'op_names_statistic') and hasattr(node, 'op_type'):
            graph.op_names_statistic[node.op_type] += 1

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(node.input):
            # should add edge inp --> id
            if inp not in data_nodes_map:
                if inp == '':
                    # input is omitted; most likely it corresponds to an optional input for an operator
                    continue
                else:
                    raise Error(
                        'Reference to {} is not satisfied. A node refers to a non-existing data tensor. The ONNX '
                        'model is not consistent. Protobuf fragment: {}', inp, node)
            src_id, src_port = data_nodes_map[inp]

            assert (graph.has_node(src_id))
            edge_attrs = {
                'out': src_port,
                'in': dst_port,
                'name': inp,
                'fw_tensor_debug_info': [(src_id, inp)],
                'in_attrs': ['in', 'name'],
                'out_attrs': ['out', 'name'],
                'data_attrs': ['fw_tensor_debug_info']
            }
            graph.add_edge(src_id, id, **edge_attrs)

        # add outgoing edges to data_nodes_map
        for src_port, out in enumerate(node.output):
            if out in output_ids:
                edge_attrs = {
                    'out': src_port,
                    'in': 0,
                    'name': out,
                    'fw_tensor_debug_info': [(fw_name, out)],
                    'in_attrs': ['in', 'name'],
                    'out_attrs': ['out', 'name'],
                    'data_attrs': ['fw_tensor_debug_info']
                }
                graph.add_edge(id, out, **edge_attrs)
            if out in data_nodes_map:
                log.debug("Detected reuse of blob {}.".format(out))
            data_nodes_map[out] = (id, src_port)

    graph.graph['tensor_mapping'] = data_nodes_map  # save main graph tensor names mapping for Loop op parsing
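
The resulting tensor mapping is a plain dictionary; a hypothetical snapshot after parsing two nodes:

# tensor name -> (producer node id, producer output port)
data_nodes_map = {'conv1_y': ('conv1', 0), 'relu1_y': ('relu1', 0)}
# a later node listing 'relu1_y' among its inputs gets the edge
# relu1 --(out=0, in=<that input's index>)--> node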
Example #20
    def find_and_replace_pattern(self, graph: Graph):
        for node in list(graph.nodes()):
            node = Node(graph, node)
            node_name = node.soft_get('name', node.id)
            # Check that node layout mismatch with graph layout
            # For example: NHWC and NCHW or NCDHW and NDHWC
            if node.kind == 'op' and node.has_valid('layout') and \
                    node.layout != indices_mapping[len(node.layout)][graph.graph['layout']]:
                input = node.in_node()
                output = node.out_node()

                # Calculate permutation for further Transpose operations
                if graph.graph['layout'] == 'NCHW':
                    # the node has NHWC layout while the graph is NCHW
                    permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.layout))
                else:
                    # the node has NCHW layout while the graph is NHWC
                    permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(len(node.layout))

                # Schematic representation of the transformation below:
                #
                #   before:  data --> Convolution(example) --> data                      (node layout, e.g. NHWC)
                #
                #   after:   data --> Transpose(permutation) --> data(NCHW) --> Convolution
                #                 --> data(NCHW) --> Transpose(inverse permutation) --> data

                # 1. Insert input Transpose
                #    This Transpose will permute input from original input layout to operation layout
                edge_attrs = graph.get_edge_data(input.id, node.id)[0]
                graph.remove_edge(input.id, node.id)

                input_permute_name = node_name + '/input_transpose'
                input_order_const = Const(
                    graph, {
                        'name': input_permute_name + '/order',
                        'value': permutation.perm
                    }).create_node_with_data()
                input_permute_op = Transpose(graph,
                                             {'name': input_permute_name})
                input_permute_data_node = input_permute_op.create_node_with_data(
                    [input, input_order_const])

                graph.add_edge(input_permute_data_node.id, node.id,
                               **edge_attrs)

                # 2. Insert output Transpose
                #    This Transpose will permute output from operation layout to original input layout
                edge_attrs = graph.get_edge_data(node.id, output.id)[0]
                graph.remove_edge(node.id, output.id)

                input_data_node = Op.create_data_node(
                    graph, node, {'shape': output.shape[permutation.perm]},
                    edge_attrs)

                output_permute_name = node_name + '/output_transpose'
                output_order_const = Const(
                    graph, {
                        'name': output_permute_name + '/order',
                        'value': permutation.inv
                    }).create_node_with_data()
                output_permute_op = Transpose(graph, {
                    'name': output_permute_name
                }).create_node_with_data([input_data_node, output_order_const],
                                         data_nodes=output)

                # 3. Add permutations for Node
                #    Here we use permutation mechanism where data nodes takes permutation attribute.
                #    And then we call permute_attrs method that permutes node attributes according to permutations on
                #    data nodes.
                node.in_node()['permutation'] = permutation
                node.out_node()['permutation'] = permutation
                node.permute_attrs.permute_attrs(node)

                node.in_node()['permutation'] = None
                node.out_node()['permutation'] = None
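
The Transpose sandwich is a no-op for the rest of the graph because the output permutation is the inverse of the
input one; a NumPy round trip with a hypothetical rank-4 tensor:

import numpy as np

perm = np.array([0, 3, 1, 2])     # hypothetical layout permutation (permutation.perm)
inv = np.empty_like(perm)
inv[perm] = np.arange(len(perm))  # its inverse (permutation.inv)
x = np.random.rand(1, 224, 224, 3)
assert (x.transpose(perm).transpose(inv) == x).all()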