def replace_pattern(graph: nx.MultiDiGraph, match: dict):
    """
    Drop the matched Reshape when its input and output data nodes have identical shapes.

    :param graph: graph to operate on
    :param match: matched nodes; must contain the 'reshape' op node
    :return: False when the shapes differ (nothing is removed), otherwise None
    """
    reshape = match['reshape']
    same_shape = np.array_equal(reshape.in_node().shape, reshape.out_node().shape)
    if same_shape:
        remove_op_node_with_data_node(graph, reshape)
        return None
    return False
Beispiel #2
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Fold the matched direct/inverse ReverseSequence pair into the TensorIterator 'ti'.

        Every TI port that has a non-None 'axis' and an 'external_port_id' gets its 'stride'
        negated (a missing/None stride counts as 1). Its 'start'/'end' are then set to -1/0
        for stride -1, or to None/None for stride 1. Finally both reverse ops are removed.
        """
        ti = match['ti']
        fwd_rev = match['direct_reverse']
        bwd_rev = match['inverse_reverse']

        assert fwd_rev.seq_axis == bwd_rev.seq_axis
        batch_axes_unset = fwd_rev.batch_axis is None and bwd_rev.batch_axis is None
        assert batch_axes_unset or fwd_rev.batch_axis == bwd_rev.batch_axis

        # Modify stride in TI
        sliced_ports = (port
                        for port_map in (ti.input_port_map, ti.output_port_map)
                        for port in port_map
                        if 'axis' in port and port['axis'] is not None and 'external_port_id' in port)
        for port in sliced_ports:
            assert port['axis'] == fwd_rev.seq_axis, \
                'axis == {} != {} == direct_reverse.seq_dim'.format(port['axis'], fwd_rev.seq_axis)
            if 'stride' not in port or port['stride'] is None:
                port['stride'] = 1
            assert port['stride'] in [-1, 1]
            port['stride'] = -port['stride']
            if port['stride'] == -1:
                port['start'], port['end'] = -1, 0
            elif port['stride'] == 1:
                port['start'], port['end'] = None, None

        # Remove reverses
        for reverse_op in (fwd_rev, bwd_rev):
            remove_op_node_with_data_node(graph, reverse_op)
 def find_and_replace_pattern(self, graph: Graph):
     """
     Remove every Switch op in the graph.

     The edge from the Switch's port-1 input is removed first, and then the op is removed
     with remove_op_node_with_data_node.
     """
     switches = (candidate for candidate in graph.pseudo_topological_sort()
                 if candidate.kind != 'data' and candidate.op == 'Switch')
     for switch in switches:
         predicate = switch.in_node(1)
         graph.remove_edge(predicate.id, switch.id)
         remove_op_node_with_data_node(graph, switch)
Beispiel #4
0
 def replace_pattern(graph: Graph, match: dict):
     """
     Remove the output SoftMax layer, but only when its result has exactly one consumer.

     :param graph: graph to operate on
     :param match: dictionary with matched nodes ('softmax_node' and its 'softmax_data')
     """
     consumers = match['softmax_data'].out_nodes()
     if len(consumers) != 1:
         return
     remove_op_node_with_data_node(graph, match['softmax_node'])
Beispiel #5
0
 def find_and_replace_pattern(self, graph: Graph):
     """
     Remove every Switch op in the graph.

     The edge from each Switch's port-1 input is removed first, and then the op is removed
     with remove_op_node_with_data_node.
     """
     for node_id in pseudo_topological_sort(graph):
         if not graph.has_node(node_id):
             # an earlier removal may already have deleted this node
             continue
         # Read attributes through Node rather than graph.node[...]: the `Graph.node` view was
         # removed in networkx 2.4, and soft_get tolerates op nodes without an 'op' attribute.
         switch_op_node = Node(graph, node_id)
         if switch_op_node.kind == 'data' or switch_op_node.soft_get('op') != 'Switch':
             continue
         pred_id_data_node = switch_op_node.in_node(1)
         graph.remove_edge(pred_id_data_node.id, switch_op_node.id)
         remove_op_node_with_data_node(graph, switch_op_node)
Beispiel #6
0
 def replace_pattern(self, graph: Graph, match: dict):
     """
     Remove the matched Merge op when it has at most one incoming edge.

     The Merge is removed via remove_op_node_with_data_node, passing its first input data
     node explicitly. A Merge with no inputs at all is left untouched, because there is no
     input to pass and indexing the empty input list would raise IndexError.

     :param graph: graph to operate on
     :param match: dictionary with the matched 'merge' node
     """
     merge = match['merge']
     if len(graph.in_edges(merge.id)) > 1:
         return
     in_data_nodes = list(merge.in_nodes().values())
     if not in_data_nodes:
         return
     remove_op_node_with_data_node(graph, merge, in_data_nodes[0])
     log.info(
         "Useles Merge op and data nodes was deleted op='{}'".format(
             merge.id))
Beispiel #7
0
 def replace_pattern(graph: nx.MultiDiGraph, match: dict):
     """
     Remove the matched 'reshape_1' op under four conditions: it is a Reshape, it has exactly
     one output node, that output is a data node, and that data node feeds exactly one consumer.
     """
     reshape = match['reshape_1']
     if not (reshape.has_valid('type') and reshape.type == 'Reshape'):
         return
     if len(reshape.out_nodes()) != 1:
         return
     out_data = reshape.out_node()
     if not (out_data.has_valid('kind') and out_data.kind == 'data'):
         return
     if len(out_data.out_nodes()) != 1:
         return
     remove_op_node_with_data_node(graph, reshape)
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Detach every input of the matched Pad except port 0. Then remove the Pad entirely
        when all values of its `pads` attribute are zero.
        """
        pad = match['pad']
        extra_inputs = [src for port, src in pad.in_nodes().items() if port != 0]
        for src in extra_inputs:
            graph.remove_edge(src.id, pad.id)

        # a Pad that pads nothing is an identity
        if np.all(pad.pads == 0):
            remove_op_node_with_data_node(graph, pad)
 def replace_pattern(graph: Graph, match: dict):
     """
     Remove the output SoftMax layer when its result has exactly one consumer. Otherwise
     log a warning-level error and leave the graph unchanged.

     :param graph: graph to operate on
     :param match: dictionary with matched nodes
     """
     has_single_consumer = len(match['softmax_data'].out_nodes()) == 1
     if not has_single_consumer:
         log.error("SoftMax is not last layer, so can't be removed", extra={'is_warning': True})
         return
     remove_op_node_with_data_node(graph, match['softmax_node'])
    def find_and_replace_pattern(self, graph: Graph):
        """
        Remove every Pad op whose inputs on ports 1 and 2 (presumably the pad amounts) are
        produced by Const nodes holding all-zero values.
        """
        for pad in graph.get_op_nodes(op='Pad'):
            zero_flags = []
            # both ports are always inspected, even if port 1 already disqualifies the Pad
            for port_idx in (1, 2):
                producer = pad.in_port(port_idx).get_source().node
                pad_value = producer.soft_get('value', None)
                zero_flags.append(producer.soft_get('type') == 'Const'
                                  and pad_value is not None
                                  and np.all(pad_value == 0))
            if all(zero_flags):
                remove_op_node_with_data_node(graph, pad)
    def replace_pattern(graph: Graph, match: dict):
        """
        Remove a StridedSlice that leaves its input tensor unchanged.

        Nodes whose output value is already computed are skipped (see the comment below).
        Otherwise the output shape is mapped back to a "pre-mask" shape: a 1 is inserted at
        every position set in shrink_axis_mask, and every position set in new_axis_mask is
        popped. If that shape equals the input shape and every slice has step 1, the node is
        removed. Before removal, ConvertGroupedStridedSlice.add_squeeze_for_shrink or
        add_unsqueeze_for_new is called when the corresponding mask is non-zero.

        :param graph: graph to operate on
        :param match: dictionary with the matched 'strided_slice' node
        """
        node_ss = match['strided_slice']

        if node_ss.out_port(0).data.get_value() is not None:
            # StridedSlices(SS) in shape-calculating sub-graphs that should not be deleted that easily
            # Example:
            # In RetinaNetFilteredDetectionsReplacement we have SS that slices first batch
            # We delete such SS for batch 1, but it should be performed while reshaping the model
            return

        output_data_node = node_ss.out_node(0)
        input_data_node = node_ss.in_node(0)

        out_shape = output_data_node.shape

        if not np.all(node_ss.shrink_axis_mask == 0):
            # re-insert a unit dimension at every position flagged in shrink_axis_mask
            out_shape = list(out_shape)
            for i, flag in enumerate(node_ss.shrink_axis_mask):
                if flag == 1:
                    out_shape.insert(i, 1)
            out_shape = int64_array(out_shape)

        if not np.all(node_ss.new_axis_mask == 0):
            # pop every position flagged in new_axis_mask; go backwards so indices stay valid
            out_shape = list(out_shape)
            for i in reversed(range(len(node_ss.new_axis_mask))):
                if node_ss.new_axis_mask[i] == 1:
                    out_shape.pop(i)
            out_shape = int64_array(out_shape)

        if not (np.array_equal(input_data_node.shape, out_shape) and
                all(elem.step == 1 for elem in node_ss.slices)):
            return

        if not np.all(node_ss.shrink_axis_mask == 0):
            ConvertGroupedStridedSlice.add_squeeze_for_shrink(graph, node_ss)
        if not np.all(node_ss.new_axis_mask == 0):
            ConvertGroupedStridedSlice.add_unsqueeze_for_new(graph, node_ss)

        log.info("Useless StridedSlice op '{}' has been detected".format(node_ss.id))
        # Remove inputs to Strided Slice so it has just one input with data, so that
        # remove_op_node_with_data_node can be used. Snapshot the port -> node mapping BEFORE
        # removing any edge: counting in_nodes() after ports 1 and 2 are detached can never
        # exceed 3, so the optional port-3 input would never be detached.
        in_nodes = node_ss.in_nodes()
        graph.remove_edge(in_nodes[1].id, node_ss.id)
        graph.remove_edge(in_nodes[2].id, node_ss.id)
        if 3 in in_nodes:
            graph.remove_edge(in_nodes[3].id, node_ss.id)

        remove_op_node_with_data_node(graph, node_ss)
Beispiel #12
0
def fuse_sequence_of_reshapes(graph: Graph):
    """
    Simplify Reshape chains in two passes.

    Pass 1 handles Reshape ops whose single output data node feeds exactly one consumer.
    If that consumer is also a Reshape, the first Reshape is removed.

    Pass 2 handles every remaining Reshape. The target `dim` is read from the Const producer
    on port 1, or from the `dim` attribute when only one input is connected. If `dim` equals
    the input shape and the node has consumers, those consumers are re-sourced from the
    Reshape's input. Reshapes whose port-1 producer is not a Const are skipped.
    """
    for node_id in list(graph.nodes()):
        if not graph.has_node(node_id):
            # data node can be already removed
            continue
        reshape = Node(graph, node_id)
        feeds_single_consumer = (reshape.has_valid('type') and reshape.type == 'Reshape'
                                 and len(reshape.out_nodes()) == 1
                                 and reshape.out_node().has_valid('kind')
                                 and reshape.out_node().kind == 'data'
                                 and len(reshape.out_node().out_nodes()) == 1)
        if not feeds_single_consumer:
            continue

        log.debug('First phase for Reshape: {}'.format(reshape.name))

        consumer = reshape.out_node().out_node()
        log.debug('second node: {}'.format(consumer.graph.node[consumer.id]))
        if consumer.has_valid('type') and consumer.type == 'Reshape':
            # Detected Reshape1 --> data --> Reshape2 pattern without side edges
            # Remove Reshape1
            log.debug('Second phase for Reshape: {}'.format(reshape.name))
            remove_op_node_with_data_node(graph, reshape)

    for reshape in graph.get_op_nodes(op='Reshape'):
        connected = [p for p in reshape.in_ports().values() if not p.disconnected()]
        assert len(connected) in [1, 2], "`Reshape` node must have 2 inputs or 1 input with `dim`"

        if len(connected) != 2:
            assert reshape.has_valid('dim'), "`Reshape` node with 1 input must have `dim` attribute"
            dim = reshape.dim
        elif reshape.in_port(1).get_source().node.op != 'Const':
            continue
        else:
            dim = reshape.in_port(1).get_connection().data.get_value()

        in_shape = reshape.in_port(0).get_connection().data.get_shape()

        if not (np.array_equal(dim, in_shape) and len(reshape.out_nodes())):
            continue
        log.debug("Useless reshape with dim {} was deleted: {}".format(str(dim), reshape.name))
        reshape.out_port(0).get_connection().set_source(reshape.in_port(0).get_source())
Beispiel #13
0
    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        """
        Remove a StridedSlice whose output shape equals its input shape and whose slices all
        have step 1. The edges from its inputs on ports 1-3 are removed first, so that only
        the port-0 data input is left for remove_op_node_with_data_node.
        """
        ss = match['strided_slice']
        same_shape = np.array_equal(ss.in_node(0).shape, ss.out_node(0).shape)
        if not same_shape or any(elem.step != 1 for elem in ss.slices):
            return

        log.info("Useless StridedSlice op '{}' has been detected".format(ss.id))
        for port in (1, 2, 3):
            graph.remove_edge(ss.in_node(port).id, ss.id)

        remove_op_node_with_data_node(graph, ss)
    def replace_pattern(graph: nx.MultiDiGraph, match: dict):
        """
        Remove an output SoftMax: pattern Parent (any type) -> SoftMax -> OpOutput.

        The SoftMax op is removed only when its output node has `is_output` set. Otherwise
        the graph is left untouched.

        Parameters
        ----------
        graph : nx.MultiDiGraph
           Graph with loaded model.
        match : dict
           Patterns which were found in graph structure.
        """
        softmax = match['softmax_node']
        if softmax.out_node().has_and_set('is_output'):
            remove_op_node_with_data_node(graph, softmax)
Beispiel #15
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Remove a StridedSlice that leaves its input tensor unchanged.

        The output shape is mapped back to a "pre-mask" shape: a 1 is inserted at every
        position set in shrink_axis_mask, and every position set in new_axis_mask is popped.
        If that shape equals the input shape and every slice has step 1, the node is removed.
        Before removal, ConvertGroupedStridedSlice.add_squeeze_for_shrink or
        add_unsqueeze_for_new is called when the corresponding mask is non-zero.

        :param graph: graph to operate on
        :param match: dictionary with the matched 'strided_slice' node
        """
        node_ss = match['strided_slice']
        output_data_node = node_ss.out_node(0)
        input_data_node = node_ss.in_node(0)

        out_shape = output_data_node.shape

        if not np.all(node_ss.shrink_axis_mask == 0):
            # re-insert a unit dimension at every position flagged in shrink_axis_mask
            out_shape = list(out_shape)
            for i, flag in enumerate(node_ss.shrink_axis_mask):
                if flag == 1:
                    out_shape.insert(i, 1)
            out_shape = int64_array(out_shape)

        if not np.all(node_ss.new_axis_mask == 0):
            # pop every position flagged in new_axis_mask; go backwards so indices stay valid
            out_shape = list(out_shape)
            for i in reversed(range(len(node_ss.new_axis_mask))):
                if node_ss.new_axis_mask[i] == 1:
                    out_shape.pop(i)
            out_shape = int64_array(out_shape)

        if not (np.array_equal(input_data_node.shape, out_shape) and
                all(elem.step == 1 for elem in node_ss.slices)):
            return

        if not np.all(node_ss.shrink_axis_mask == 0):
            ConvertGroupedStridedSlice.add_squeeze_for_shrink(graph, node_ss)
        if not np.all(node_ss.new_axis_mask == 0):
            ConvertGroupedStridedSlice.add_unsqueeze_for_new(graph, node_ss)

        log.info("Useless StridedSlice op '{}' has been detected".format(node_ss.id))
        # Remove inputs to Strided Slice so it has just one input with data, so that
        # remove_op_node_with_data_node can be used. Snapshot the port -> node mapping BEFORE
        # removing any edge: counting in_nodes() after ports 1 and 2 are detached can never
        # exceed 3, so the optional port-3 input would never be detached.
        in_nodes = node_ss.in_nodes()
        graph.remove_edge(in_nodes[1].id, node_ss.id)
        graph.remove_edge(in_nodes[2].id, node_ss.id)
        if 3 in in_nodes:
            graph.remove_edge(in_nodes[3].id, node_ss.id)

        remove_op_node_with_data_node(graph, node_ss)
Beispiel #16
0
def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph):
    """
    Remove the first Reshape of every Reshape1 -> data -> Reshape2 chain without side edges.

    A Reshape qualifies when it has a single output data node, that data node has exactly
    one consumer, and that consumer is itself a Reshape.
    """
    for node_id in list(graph.nodes()):
        # Check membership BEFORE wrapping the id in Node. An earlier iteration may already
        # have deleted this id via remove_op_node_with_data_node, and wrapping a missing id
        # is meaningless (and may fail, depending on Node's constructor).
        if not graph.has_node(node_id):
            # data node can be already removed
            continue
        node = Node(graph, node_id)
        if (node.has_valid('type') and node.type == 'Reshape'
                and len(node.out_nodes()) == 1
                and node.out_node().has_valid('kind')
                and node.out_node().kind == 'data'
                and len(node.out_node().out_nodes()) == 1):

            log.debug('First phase for Reshape: {}'.format(node.name))

            next_op = node.out_node().out_node()
            log.debug('second node: {}'.format(next_op.graph.node[next_op.id]))
            if next_op.has_valid('type') and next_op.type == 'Reshape':
                # Detected Reshape1 --> data --> Reshape2 pattern without side edges
                # Remove Reshape1
                log.debug('Second phase for Reshape: {}'.format(node.name))
                remove_op_node_with_data_node(graph, node)
 def find_and_replace_pattern(self, graph: Graph):
     """
     Remove FakeQuantize ops (type='FakeQuantize', keep_in_IR=False) and record their
     ranges as statistics.

     Each removed node must have single-element values on inputs 1 and 2, and the value on
     input 1 must equal input 3 and input 2 must equal input 4; otherwise an assertion fails.
     The values of inputs 1 and 2 are stored as comma-separated strings, repeated once per
     entry of dimension 1 of the producer's output shape (presumably channels). They are
     stored under 'min'/'max' in graph.graph['statistics'], keyed by the id of the node
     that produces the FakeQuantize's input.
     """
     intervals = {}
     for node in graph.get_op_nodes(type='FakeQuantize', keep_in_IR=False):
         prev_node = node.in_node().in_node()
         num_channels = prev_node.out_node()['shape'][1]
         in_low, in_high = node.in_node(1), node.in_node(2)
         out_low, out_high = node.in_node(3), node.in_node(4)
         assert in_low.value.size == 1
         assert in_high.value.size == 1
         # Input and output ranges should match if we want to remove FakeQuantize from model
         assert_msg = "FakeQuantize cannot be removed because input and output intervals do not match"
         assert in_low.value == out_low.value, assert_msg
         assert in_high.value == out_high.value, assert_msg
         # local names avoid shadowing the builtins min()/max()
         min_str = ', '.join([str(in_low.value.flatten()[0])] * num_channels)
         max_str = ', '.join([str(in_high.value.flatten()[0])] * num_channels)
         intervals[prev_node.id] = {'min': min_str, 'max': max_str}
         remove_op_node_with_data_node(graph, node)
     if intervals:
         if 'statistics' not in graph.graph:
             graph.graph['statistics'] = intervals
         else:
             graph.graph['statistics'].update(intervals)
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Fold the matched direct/inverse ReverseSequence pair into the TensorIterator 'ti'.

        If either reverse op fails is_fusable_reverse_sequence, nothing is changed.
        Otherwise every TI port that has a non-None 'axis' and an 'external_port_id' gets its
        'stride' negated (a missing/None stride counts as 1). Its 'start'/'end' are then set
        to -1/0 for stride -1, or to 0/-1 for stride 1. Port 1 of both reverse ops is
        disconnected, and both ops are removed.
        """
        ti = match['ti']
        fwd_rev = match['direct_reverse']
        bwd_rev = match['inverse_reverse']

        assert fwd_rev.seq_axis == bwd_rev.seq_axis
        batch_axes_unset = fwd_rev.batch_axis is None and bwd_rev.batch_axis is None
        assert batch_axes_unset or fwd_rev.batch_axis == bwd_rev.batch_axis

        both_fusable = self.is_fusable_reverse_sequence(fwd_rev) and \
            self.is_fusable_reverse_sequence(bwd_rev)
        if not both_fusable:
            # we can not merge ReverseSequence without equal sequences
            return

        # Modify stride in TI
        sliced_ports = (port
                        for port_map in (ti.input_port_map, ti.output_port_map)
                        for port in port_map
                        if 'axis' in port and port['axis'] is not None and 'external_port_id' in port)
        for port in sliced_ports:
            assert port['axis'] == fwd_rev.seq_axis, \
                'axis == {} != {} == direct_reverse.seq_dim'.format(port['axis'], fwd_rev.seq_axis)
            if 'stride' not in port or port['stride'] is None:
                port['stride'] = 1
            assert port['stride'] in [-1, 1]
            port['stride'] = -port['stride']
            if port['stride'] == -1:
                port['start'], port['end'] = -1, 0
            elif port['stride'] == 1:
                port['start'], port['end'] = 0, -1

        # disconnect subgraph for seq length calculation
        for reverse_op in (fwd_rev, bwd_rev):
            reverse_op.in_port(1).disconnect()
        # Remove reverses
        for reverse_op in (fwd_rev, bwd_rev):
            remove_op_node_with_data_node(graph, reverse_op)
    def find_and_replace_pattern(self, graph: Graph):
        """
        Remove the first Reshape of each Reshape -> Reshape chain where the first Reshape's
        output has exactly one destination. The second Reshape's port-1 value must be known
        and must not contain the special values 0 or -1.
        """
        for reshape in graph.get_op_nodes(type='Reshape'):
            if not graph.has_node(reshape.id):
                # the Reshape node has been removed in the previous iteration
                continue
            if len(reshape.out_port(0).get_destinations()) != 1:
                continue

            log.debug('First phase for Reshape: {}'.format(reshape.soft_get('name')))

            consumer = get_next_operation(reshape)[0]
            log.debug('second node: id={}, type={}'.format(
                consumer.soft_get('id'), consumer.soft_get('type')))
            if not (consumer.has_valid('type') and consumer.type == 'Reshape'):
                continue

            target_dims = consumer.in_port(1).data.get_value()
            if target_dims is None or 0 in target_dims or -1 in target_dims:
                # we do not fuse reshape sequences with special symbols: 0, -1
                continue

            # Detected Reshape1 --> data --> Reshape2 pattern without side edges. Remove Reshape1
            log.debug('Second phase for Reshape: {}'.format(reshape.soft_get('name')))
            remove_op_node_with_data_node(graph, reshape)
Beispiel #20
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Fuse the matched ReLU into the following binary (levels == 2) Quantize op.

        The fusion applies only when both of these hold:
          - the ReLU output has a single consumer;
          - ports 1 and 2 of the Quantize are fed by the same data node, and that node has
            exactly 2 consumers.
        Negative entries of that shared threshold are then set to -inf and the ReLU is removed.
        """
        quantize = match['quantize']
        relu = match['relu']

        # Check for total number of ReLU consumers -- if something else consume its output it cannot be fused
        if len(relu.out_node().out_nodes()) > 1:
            log.debug('ReluQuantizeFuse: cannot fuse because ReLU have multiple consumers')
            return

        # The threshold data node is modified in place below, so it must not be shared with
        # anything but ports 1 and 2 of this Quantize.
        # TODO: relax this limitation and duplicate data nodes accordingly to modify the input range freely
        # Provisional limitation that related to binary quantization
        # TODO: Relax it beyond binarization case
        threshold = quantize.in_node(1)
        if (len(threshold.out_nodes()) != 2
                or len(quantize.in_node(2).out_nodes()) != 2
                or threshold.id != quantize.in_node(2).id
                or quantize.levels != 2):
            log.debug('ReluQuantizeFuse: cannot fuse because Quantize op has '
                      'unexpected number of consumers for ports 1 and 2')
            return

        # Binary case: a threshold > 0 or == 0 stays as is; a threshold < 0 becomes -inf,
        # which means that all inputs map to output_high.
        below_zero = threshold.value < 0
        threshold.value[below_zero] = float('-inf')

        # Remove ReLU as it no longer needed
        remove_op_node_with_data_node(graph, relu)
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Fuse the matched direct/inverse Permute pair around a TensorIterator into the TI.

        Applies only when the TI body has exactly one isomorphism to the
        Reshape -> LSTMCell -> Reshape topology below and both Permutes have order [1, 0, 2].
        When it applies:
          - both Permutes are removed and the 'output' node shape is permuted with [1, 0, 2];
          - the partition axis of the single partitioned input/output port is swapped (0 <-> 1);
          - the body Reshape shapes and `unsqueeze.dim` are permuted with [1, 0, 2].

        :param graph: graph to operate on
        :param match: matched nodes; uses 'ti', 'direct_permute', 'inverse_permute' and 'output'
        """

        # This transformation works if and only if a body of TI
        # matches the following topology (Reshape -> LSTMCell -> Reshape)
        # NOTE: entries like ('input_unsqueezed') are plain strings (node names without
        # attributes), not 1-tuples.
        nodes = [('input_unsqueezed'), ('squeeze', dict(op='Reshape')),
                 ('input_squeezed'), ('input_hidden'), ('input_cell'),
                 ('weights'), ('biases'), ('lstm', dict(op='LSTMCell')),
                 ('output_hidden'), ('output_cell'),
                 ('unsqueeze', dict(op='Reshape')), ('output_unsqueezed'),
                 ('const_w', dict(op='Const')), ('const_b', dict(op='Const')),
                 ('op_output', dict(op='OpOutput')),
                 ('op_output_1', dict(op='OpOutput')),
                 ('op_output_2', dict(op='OpOutput'))]
        edges = [
            ('input_unsqueezed', 'squeeze'),
            ('squeeze', 'input_squeezed'),
            ('input_squeezed', 'lstm', {
                'in': 0
            }),
            ('input_hidden', 'lstm', {
                'in': 1
            }),
            ('input_cell', 'lstm', {
                'in': 2
            }),
            ('weights', 'lstm', {
                'in': 3
            }),
            ('biases', 'lstm', {
                'in': 4
            }),
            ('const_w', 'weights'),
            ('const_b', 'biases'),
            ('lstm', 'output_hidden', {
                'out': 0
            }),
            ('lstm', 'output_cell', {
                'out': 1
            }),
            ('output_hidden', 'unsqueeze'),
            ('unsqueeze', 'output_unsqueezed'),
            ('output_unsqueezed', 'op_output'),
            ('output_hidden', 'op_output_1'),
            ('output_cell', 'op_output_2'),
        ]
        ti = match['ti']
        isomorphisms = find_isomorphisms(ti.body, nodes, edges)
        # NOTE(review): indexing below assumes find_isomorphisms returns a list; a generator
        # would be exhausted by list() here — confirm.
        if len(list(isomorphisms)) != 1:
            return
        isomorphism = isomorphisms[0]

        direct_permute = match['direct_permute']
        inverse_permute = match['inverse_permute']

        permute_order = [1, 0, 2]

        # Check both permute orders exactly match expected one - [1, 0, 2]
        if not direct_permute.has_valid('order') or not np.array_equal(
                direct_permute.order, permute_order):
            return
        if not inverse_permute.has_valid('order') or not np.array_equal(
                inverse_permute.order, permute_order):
            return

        def find_ports(port_map: list, attrs: dict):
            """ Find all ports in a given port map with specified attributes """
            result = []
            for i, port in enumerate(port_map):
                if dict_includes(port, attrs):
                    result.append(i)
            return result

        # Check TI has only single partitioned input/output port; all partitioned ports have defined axis
        data_input_port = find_ports(ti.input_port_map,
                                     {'axis': lambda attr: attr in [0, 1]})
        data_output_port = find_ports(ti.output_port_map,
                                      {'axis': lambda attr: attr in [0, 1]})
        assert len(data_input_port) == 1
        assert len(data_output_port) == 1
        data_input_port = data_input_port[0]
        data_output_port = data_output_port[0]
        # Verify that they are really connected to Permute layers (guaranteed by port numbers of TI, see the pattern)
        assert ti.in_edge(0)['external_port_id'] == ti.input_port_map[
            data_input_port]['external_port_id']
        assert ti.out_edge(0)['external_port_id'] == ti.output_port_map[
            data_output_port]['external_port_id']

        # Verify that the TI body have required Reshapes connected to the found ports
        squeeze = isomorphism['squeeze']
        unsqueeze = isomorphism['unsqueeze']
        assert squeeze['internal_layer_id'] == ti.input_port_map[
            data_input_port]['internal_layer_id']
        assert squeeze.in_edge(0)['internal_port_id'] == ti.input_port_map[
            data_input_port]['internal_port_id']
        assert unsqueeze['internal_layer_id'] == ti.output_port_map[
            data_output_port]['internal_layer_id']
        assert unsqueeze.out_edge(0)['internal_port_id'] == ti.output_port_map[
            data_output_port]['internal_port_id']
        # body Reshapes must go 3D -> 2D (squeeze) and 2D -> 3D (unsqueeze)
        assert len(squeeze.in_node().shape) == 3
        assert len(squeeze.out_node().shape) == 2
        assert len(unsqueeze.in_node().shape) == 2
        assert len(unsqueeze.out_node().shape) == 3

        # Remove permutes
        remove_op_node_with_data_node(graph, direct_permute)
        remove_op_node_with_data_node(graph, inverse_permute)
        match['output'].shape = match['output'].shape[permute_order]

        # swap 0/1 axis for partitioned ports
        ti.input_port_map[data_input_port][
            'axis'] = 1 - ti.input_port_map[data_input_port]['axis']
        ti.output_port_map[data_output_port][
            'axis'] = 1 - ti.output_port_map[data_output_port]['axis']

        # swap 0-th and 1-th shape entries for reshapes inside body
        squeeze.in_node().shape = squeeze.in_node().shape[[1, 0, 2]]
        unsqueeze.out_node().shape = unsqueeze.out_node().shape[[1, 0, 2]]
        unsqueeze.dim = unsqueeze.dim[[1, 0, 2]]
Beispiel #22
0
 def replace_pattern(self, graph: Graph, match: dict):
     """Pass the matched 'op' node to remove_op_node_with_data_node."""
     target = match['op']
     remove_op_node_with_data_node(graph, target)
Beispiel #23
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Fuse the matched direct/inverse Transpose pair around a TensorIterator into the TI.

        Applies only when all of these hold:
          - the TI body has exactly one isomorphism to the Squeeze -> LSTMCell -> Unsqueeze
            topology below;
          - both Transposes have a constant order (input port 1) equal to [1, 0, 2];
          - the direct Transpose has exactly one consumer that is not a ShapeOf.

        Each ShapeOf consumer of the direct Transpose gets a Gather (indices [1, 0, 2],
        axis 0) inserted after it, presumably so the reported shape stays the same once the
        Transpose is gone. Then both Transposes are removed, the 'output' shape is permuted,
        the partition axis of the partitioned input/output port is swapped (0 <-> 1), and
        the body Parameter shape and Squeeze/Unsqueeze axes are updated and re-inferred.

        :param graph: graph to operate on
        :param match: matched nodes; uses 'ti', 'direct_permute', 'inverse_permute' and 'output'
        """

        # This transformation works if and only if a body of TI
        # matches the following topology (Squeeze -> LSTMCell -> Unsqueeze)
        nodes = [
            ('squeeze_dim', dict(kind='op', op='Const')),
            ('squeeze_dim_data', dict(kind='data')),
            ('unsqueeze_dim', dict(kind='op', op='Const')),
            ('unsqueeze_dim_data', dict(kind='data')),
            ('input_unsqueezed', dict(kind='data')),
            ('squeeze', dict(kind='op', op='Squeeze')),
            ('input_squeezed', dict(kind='data')),
            ('input_hidden', dict(kind='data')),
            ('input_cell', dict(kind='data')),
            ('weights', dict(kind='data')),
            ('biases', dict(kind='data')),
            ('lstm', dict(kind='op', op='LSTMCell')),
            ('output_hidden', dict(kind='data')),
            ('output_cell', dict(kind='data')),
            ('unsqueeze', dict(kind='op', op='Unsqueeze')),
            ('output_unsqueezed', dict(kind='data')),
            ('const_w', dict(kind='op', op='Const')),
            ('const_b', dict(kind='op', op='Const')),
            ('op_output', dict(kind='op', op='Result')),
            ('op_output_1', dict(kind='op', op='Result')),
            ('op_output_2', dict(kind='op', op='Result')),
            ('input_unsqueezed_i', dict(kind='op', op='Parameter')),
            ('input_hidden_i', dict(kind='op', op='Parameter')),
            ('input_cell_i', dict(kind='op', op='Parameter')),
        ]
        edges = [
            ('input_unsqueezed', 'squeeze', {
                'in': 0
            }),
            ('squeeze', 'input_squeezed'),
            ('squeeze_dim', 'squeeze_dim_data'),
            ('squeeze_dim_data', 'squeeze', {
                'in': 1
            }),
            ('input_squeezed', 'lstm', {
                'in': 0
            }),
            ('input_hidden', 'lstm', {
                'in': 1
            }),
            ('input_cell', 'lstm', {
                'in': 2
            }),
            ('weights', 'lstm', {
                'in': 3
            }),
            ('biases', 'lstm', {
                'in': 4
            }),
            ('const_w', 'weights'),
            ('const_b', 'biases'),
            ('lstm', 'output_hidden', {
                'out': 0
            }),
            ('lstm', 'output_cell', {
                'out': 1
            }),
            ('output_hidden', 'unsqueeze'),
            ('unsqueeze', 'output_unsqueezed'),
            ('unsqueeze_dim', 'unsqueeze_dim_data'),
            ('unsqueeze_dim_data', 'unsqueeze', {
                'in': 1
            }),
            ('output_unsqueezed', 'op_output'),
            ('output_hidden', 'op_output_1'),
            ('output_cell', 'op_output_2'),
            ('input_unsqueezed_i', 'input_unsqueezed'),
            ('input_hidden_i', 'input_hidden'),
            ('input_cell_i', 'input_cell'),
        ]
        ti = match['ti']
        isomorphisms = find_isomorphisms(ti.body, nodes, edges)
        # NOTE(review): indexing below assumes find_isomorphisms returns a list; a generator
        # would be exhausted by list() here — confirm.
        if len(list(isomorphisms)) != 1:
            return
        isomorphism = isomorphisms[0]

        direct_permute = match['direct_permute']
        inverse_permute = match['inverse_permute']

        permute_order = [1, 0, 2]

        # Check both permute orders exactly match expected one - [1, 0, 2]
        direct_order = direct_permute.in_port(1).data.get_value()
        if direct_order is None or not np.array_equal(direct_order,
                                                      permute_order):
            return
        inverse_order = inverse_permute.in_port(1).data.get_value()
        if inverse_order is None or not np.array_equal(inverse_order,
                                                       permute_order):
            return

        # Check non-ShapeOf output out of direct Transpose is exactly one
        direct_permute_dsts = direct_permute.out_port(0).get_destinations()
        if len([
                dst for dst in direct_permute_dsts
                if dst.node.soft_get('type') != 'ShapeOf'
        ]) != 1:
            return
        # NOTE(review): these Gather insertions modify the graph before the assertions
        # below run; if one of them fails, the graph is left partially rewritten.
        for shape_of_dst in [
                dst for dst in direct_permute_dsts
                if dst.node.soft_get('type') == 'ShapeOf'
        ]:
            name = shape_of_dst.node.soft_get(
                'name', shape_of_dst.node.id) + '/FusedToTITranspose'
            gather = create_op_with_const_inputs(
                graph,
                op=Gather,
                op_attrs={'name': name},
                port_value_dict={
                    1: int64_array(permute_order),
                    2: int64_array(0)
                })
            shape_of_dst.node.out_port(0).get_connection().insert_node(gather)

        def find_ports(port_map: list, attrs: dict):
            """ Find all ports in a given port map with specified attributes """
            result = []
            for i, port in enumerate(port_map):
                if dict_includes(port, attrs):
                    result.append(i)
            return result

        # Check TI has only single partitioned input/output port; all partitioned ports have defined axis
        data_input_port = find_ports(ti.input_port_map,
                                     {'axis': lambda attr: attr in [0, 1]})
        data_output_port = find_ports(ti.output_port_map,
                                      {'axis': lambda attr: attr in [0, 1]})
        assert len(data_input_port) == 1
        assert len(data_output_port) == 1
        data_input_port = data_input_port[0]
        data_output_port = data_output_port[0]
        # Verify that they are really connected to Transpose layers (guaranteed by port numbers of TI, see the pattern)
        assert ti.in_edge(0)['external_port_id'] == ti.input_port_map[
            data_input_port]['external_port_id']
        assert ti.out_edge(0)['external_port_id'] == ti.output_port_map[
            data_output_port]['external_port_id']

        # Verify that the TI body has the required Squeeze/Unsqueeze connected to the found ports
        squeeze = isomorphism['squeeze']
        unsqueeze = isomorphism['unsqueeze']

        # Squeeze must go 3D -> 2D and Unsqueeze 2D -> 3D
        assert len(squeeze.in_node().shape) == 3
        assert len(squeeze.out_node().shape) == 2
        assert len(unsqueeze.in_node().shape) == 2
        assert len(unsqueeze.out_node().shape) == 3

        # Remove permutes
        remove_op_node_with_data_node(graph, direct_permute)
        remove_op_node_with_data_node(graph, inverse_permute)
        match['output'].shape = match['output'].shape[permute_order]

        # swap 0/1 axis for partitioned ports
        ti.input_port_map[data_input_port][
            'axis'] = 1 - ti.input_port_map[data_input_port]['axis']
        ti.output_port_map[data_output_port][
            'axis'] = 1 - ti.output_port_map[data_output_port]['axis']

        # permute the body Parameter shape with [1, 0, 2] and re-run its infer
        isomorphism['input_unsqueezed_i'].shape = isomorphism[
            'input_unsqueezed_i'].shape[[1, 0, 2]]
        isomorphism['input_unsqueezed_i'].infer(
            isomorphism['input_unsqueezed_i'])
        # point Squeeze/Unsqueeze at the swapped partition axis
        isomorphism['squeeze_dim'].value = ti.input_port_map[data_input_port][
            'axis']
        isomorphism['squeeze_dim'].infer(isomorphism['squeeze_dim'])
        isomorphism['squeeze']['need_shape_inference'] = True

        isomorphism['unsqueeze_dim'].value = ti.output_port_map[
            data_output_port]['axis']
        isomorphism['unsqueeze_dim'].infer(isomorphism['unsqueeze_dim'])
        isomorphism['unsqueeze'].infer(isomorphism['unsqueeze'])