Beispiel #1
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Rewrites every Gather whose indices are produced by a 0D constant.

        The scalar indices constant is swapped for a 1D constant holding the same single value and a Squeeze over
        the Gather 'axis' is inserted on the Gather output to remove the dimension introduced by that change.

        :param graph: graph to transform in place
        """
        for gather_node in graph.get_op_nodes(type='Gather'):
            indices_producer = gather_node.in_port(1).get_source().node
            scalar_value = gather_node.in_port(1).data.get_value()

            is_scalar_const = (indices_producer.op == 'Const'
                               and scalar_value is not None
                               and scalar_value.ndim == 0)
            if not is_scalar_const:
                continue

            log.debug('The Gather node {} has constant 0D input with indices'.format(gather_node.id))

            indices_1d = Const(graph, dict(value=np.array([scalar_value.item()]))).create_node()

            # the indices shape changes, so the old producer must be detached before the new one is attached
            gather_node.in_port(1).disconnect()
            gather_node.in_port(1).connect(indices_1d.out_port(0))

            # the Gather output shape changes as well: request re-inference that overrides the stored shape
            gather_node['override_output_shape'] = True
            gather_node['need_shape_inference'] = True

            # Squeeze removes the 'axis' dimension which became equal to 1 after the indices constant was changed
            squeeze_node = Squeeze(graph, dict(name=gather_node.id + '/Squeeze')).create_node()
            axis_const = Const(graph, dict(name=squeeze_node.id + '/axis',
                                           value=int64_array([gather_node.axis]))).create_node()

            gather_node.out_port(0).get_connection().insert_node(squeeze_node)
            squeeze_node.in_port(1).connect(axis_const.out_port(0))
Beispiel #2
0
 def test_squeeze_squeeze_dims(self, input_value, input_shape, squeeze_dims,
                               ref_value, ref_shape):
     """
     Runs Squeeze.infer on a small graph and compares the inferred output with the reference.

     When 'ref_shape' is None the inference is expected to raise Error; otherwise the output shape
     (and the output value, if 'ref_value' is given) must match the reference exactly.
     """
     edges = [('data', 'squeeze'),
              ('squeeze_dims', 'squeeze_dims_data'),
              ('squeeze_dims_data', 'squeeze'),
              ('squeeze', 'data_out')]
     attrs_update = {
         'data': {'shape': input_shape, 'value': input_value},
         'squeeze_dims': {'value': squeeze_dims, 'shape': squeeze_dims.shape},
         'squeeze_dims_data': {'value': squeeze_dims, 'shape': squeeze_dims.shape},
     }
     graph = build_graph(nodes_attributes, edges, attrs_update)
     squeeze = Node(graph, 'squeeze')

     if ref_shape is None:
         # negative case: shape inference has to fail
         with self.assertRaises(Error):
             Squeeze.infer(squeeze)
         return

     Squeeze.infer(squeeze)
     if ref_value is not None:
         inferred_value = squeeze.out_port(0).data.get_value()
         self.assertTrue(strict_compare_tensors(inferred_value, ref_value))
     inferred_shape = squeeze.out_port(0).data.get_shape()
     self.assertTrue(strict_compare_tensors(inferred_shape, ref_shape))
Beispiel #3
0
    def extract(node):
        """
        Fills the Squeeze attributes of 'node' from its MXNet symbol description.

        'squeeze_dims' is taken from the integer 'axis' layer attribute (None when it is absent) and
        'keep_at_least_1d' is always set to True.

        :param node: node whose 'symbol_dict' holds the MXNet layer attributes
        :return: the 'enabled' flag of the enclosing class
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        squeeze_attrs = dict(squeeze_dims=layer_attrs.int("axis", None),
                             keep_at_least_1d=True)
        Squeeze.update_node_stat(node, squeeze_attrs)
        return __class__.enabled
Beispiel #4
0
    def extract(node):
        """
        Fills the Squeeze attributes of 'node' from the ONNX 'axes' attribute.

        'squeeze_dims' becomes an int64 array of the listed axes, or None when the list is empty.

        :param node: node carrying the ONNX attributes
        :return: the 'enabled' flag of the enclosing class
        """
        axes = np.array(onnx_attr(node, 'axes', 'ints', default=[]), dtype=np.int64)

        if len(axes) == 0:
            squeeze_dims = None
        else:
            squeeze_dims = axes

        # update the attributes of the node
        Squeeze.update_node_stat(node, dict(squeeze_dims=squeeze_dims))
        return __class__.enabled
Beispiel #5
0
    def replace_op(self, graph: Graph, node: Node):
        """
        Inserts a Squeeze with a customized infer function right after 'node'.

        The regular Squeeze infer function is kept in the 'old_infer' attribute and 'infer' is redirected to
        the 'do_infer' function of the enclosing class.

        :param graph: graph being transformed
        :param node: node after which the Squeeze is inserted
        :return: empty list
        """
        op = Squeeze(graph, {})
        default_infer = op.attrs['infer']
        op.attrs['old_infer'] = default_infer
        op.attrs['infer'] = __class__.do_infer

        new_node = op.create_node([], {'name': node.name + '/Squeeze'})
        node.insert_node_after(new_node)
        return []
Beispiel #6
0
    def add_squeeze_for_shrink(graph: Graph, ss_node: Node):
        """
        Moves the shrinking of axes out of a StridedSlice into a separate Squeeze placed on its output.

        Every set entry of 'shrink_axis_mask' is cleared in place, the StridedSlice output shape receives a 1 at
        each such position, and the inserted Squeeze (fed with a constant listing those axes) gets the original
        output shape. Nothing is changed unless the node has exactly 4 input nodes and 1 output node.

        :param graph: graph owning the StridedSlice; its 'layout' attribute is read
        :param ss_node: StridedSlice node with a 'shrink_axis_mask' attribute
        """
        log.info("StridedSlice op with shrink mask '{}' has been detected".format(ss_node.id))

        if not (len(ss_node.in_nodes()) == 4 and len(ss_node.out_nodes()) == 1):
            return

        squeezed_shape = ss_node.out_node().shape
        shrink_mask = ss_node['shrink_axis_mask']
        # the axes must be collected here, before the mask entries are zeroed below
        axes_to_squeeze = np.array(range(len(shrink_mask)))[np.array(shrink_mask, dtype=bool)]

        # Don't permute reshape if channels were squeezed
        layout = graph.graph['layout']
        keep_layout = layout == 'NCHW'
        if layout == 'NHWC' and shrink_mask[-1] == 1:
            keep_layout = True

        # walk the mask and the output shape together, putting 1 wherever an axis used to be shrunk
        unsqueezed_shape = []
        mask_pos = 0
        shape_pos = 0
        while shape_pos < len(squeezed_shape):
            if mask_pos < len(shrink_mask) and shrink_mask[mask_pos]:
                shrink_mask[mask_pos] = 0
                unsqueezed_shape.append(1)
            else:
                unsqueezed_shape.append(squeezed_shape[shape_pos])
                shape_pos += 1
            mask_pos += 1

        # mask entries left after the whole output shape was consumed
        for tail_pos in range(mask_pos, len(shrink_mask)):
            shrink_mask[tail_pos] = 0
            unsqueezed_shape.append(1)

        ss_node.out_port(0).data.set_shape(unsqueezed_shape)

        # insert Squeeze
        squeeze_attrs = {
            'name': ss_node.name + '/Squeeze_shrink',
            'nchw_layout': keep_layout,
            'correct_data_layout': keep_layout,
        }
        squeeze = Squeeze(graph, squeeze_attrs).create_node()
        ss_node.out_port(0).get_connection().insert_node(squeeze)
        squeeze.out_port(0).data.set_shape(squeezed_shape)

        axes_const = Const(graph, dict(name=squeeze.id + '/Indices',
                                       value=int64_array(axes_to_squeeze))).create_node()
        axes_const.out_port(0).connect(squeeze.in_port(1))
Beispiel #7
0
    def replace_op(self, graph: Graph, node: Node):
        """
        Inserts a Squeeze over 'node.axis' on the connection of every output port of 'node'.

        :param graph: graph being transformed
        :param node: node whose outputs have to be squeezed
        :return: empty list, so no output edge is replaced
        """
        for port in node.out_ports().values():
            squeeze = Squeeze(graph, {'name': node.name + '/Squeeze_'}).create_node([])
            axis_const = Const(graph, dict(value=np.array(node.axis),
                                           name=node.name + '/Squeeze_axis')).create_node()

            port.get_connection().insert_node(squeeze)
            axis_const.out_port(0).connect(squeeze.in_port(1))

        # do not replace any output edge
        return []
Beispiel #8
0
 def replace_sub_graph(self, graph: Graph, match: dict):
     """
     Adds a Squeeze over the ArgMax axis right after the matched ArgMax.

     The ONNX ArgMax operation has a 'keepdims' attribute that tells whether the dimension along which
     the maximum is computed is kept. With keepdims=0 this dimension should be removed, but the ArgMax
     operation in IR format is not designed to cover this case, so a Squeeze operation is added
     right after ArgMax.
     """
     argmax = match['argmax']
     squeeze = Squeeze(graph, dict(squeeze_dims=[argmax.axis])).create_node()

     # hand the former ArgMax consumers over to the Squeeze, then feed the Squeeze from the ArgMax
     consumers = argmax.out_port(0).get_connection()
     consumers.set_source(squeeze.out_port(0))
     squeeze.in_port(0).connect(argmax.out_port(0))
Beispiel #9
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Covers a Tile that produces a 3D tensor with Unsqueeze/Squeeze so that the Tile itself works on 4D data.

        This is a workaround for Tile types not supported in Inference Engine (Tiles are supported for 2-D or
        4-D tensors). Tiles whose output is not 3D are left untouched.

        Example: Tile (axis=1, tiles=16):
            in_shape: [1,1,101]
            out_shape: [1,16,101]

        Old behaviour:
            Tile -> [1,16,101]
        New behaviour:
            Reshape [1,1,101,1] -> Tile -> [1,16,101,1] -> Reshape [1,16,101]
        """
        tile = match['tile']
        tile_name = tile.soft_get('name', tile.id)

        output_shape = tile.out_port(0).data.get_shape()
        assert output_shape is not None, 'Output shape is undefined for {} in back phase'.format(tile_name)
        if output_shape.size != 3:
            return

        input_shape = tile.in_port(0).data.get_shape()
        assert input_shape is not None, 'Input shape is undefined for {} in back phase'.format(tile_name)

        # Unsqueeze that appends the fourth dimension in front of the Tile
        unsqueeze_axis = Const(graph, dict(name=tile_name + '/3D_Tile_Unsqueeze_dim',
                                           value=int64_array([3]))).create_node()
        unsqueeze_node = Unsqueeze(graph, dict(name=tile_name + '/3D_Tile_Unsqueeze',
                                               override_output_shape=True)).create_node()
        unsqueeze_axis.out_port(0).connect(unsqueeze_node.in_port(1))

        # the tiles input is extended with a 1 for the appended dimension
        extra_tile = Const(graph, dict(name=tile_name + '/additional_axis',
                                       value=int64_array([1]))).create_node()
        tiles_4d = new_shape_node_from_shape_nodes([tile.in_port(1).get_source().node, extra_tile])

        tile.in_port(1).get_connection().set_source(tiles_4d.out_port(0))

        # Squeeze that removes the fourth dimension behind the Tile
        squeeze_axis = Const(graph, dict(name=tile_name + '/3D_Tile_Squeeze_dim',
                                         value=int64_array([3]))).create_node()
        squeeze_node = Squeeze(graph, dict(name=tile_name + '/3D_Tile_Squeeze',
                                           override_output_shape=True)).create_node()
        squeeze_axis.out_port(0).connect(squeeze_node.in_port(1))

        data_source = tile.in_port(0).get_source()
        tile.in_port(0).get_connection().set_source(unsqueeze_node.out_port(0))
        unsqueeze_node.in_port(0).connect(data_source)

        tile.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
        tile.out_port(0).connect(squeeze_node.in_port(0))

        tile['override_output_shape'] = True
        tiles_4d['override_output_shape'] = True
        tile['need_shape_inference'] = True
Beispiel #10
0
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        """
        Inserts a Squeeze over 'node.axis' after each output of 'node'.

        Outputs are addressed by the indices 0 .. N-1, where N is the number of output nodes.

        :param graph: graph being transformed
        :param node: node whose outputs have to be squeezed
        :return: empty list, so no output edge is replaced
        """
        outputs_count = len(node.out_nodes())
        for out_idx in range(outputs_count):
            squeeze_attrs = {'squeeze_dims': [node.axis], 'name': node.name + '/Squeeze_'}
            squeeze = Squeeze(graph, squeeze_attrs).create_node([])
            insert_node_after(node, squeeze, out_idx)

        # do not replace any output edge
        return []
Beispiel #11
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Places a Squeeze on the output of the matched ArgMax node.

        When the ArgMax has two connected inputs, the producer of its second input is also connected to the
        second input of the Squeeze. Otherwise a constant with 'node.axis' is created and connected to the
        second input of the ArgMax.

        :param graph: graph being transformed
        :param match: matched sub-graph; the ArgMax node is stored under the 'argmax' key
        :return: empty list
        """
        argmax = match['argmax']

        connected_count = sum(1 for port in argmax.in_ports().values() if not port.disconnected())
        squeeze = Squeeze(graph, {}).create_node([], {'name': argmax.name + '/Squeeze'})

        if connected_count == 2:
            argmax.in_port(1).get_source().connect(squeeze.in_port(1))
        else:
            # NOTE(review): the axis constant is attached to the ArgMax, not to the Squeeze - confirm intended
            axis_const = Const(graph, dict(value=argmax.axis)).create_node()
            argmax.in_port(1).connect(axis_const.out_port(0))

        # former ArgMax consumers now read from the Squeeze, which in turn reads from the ArgMax
        argmax.out_port(0).get_connection().set_source(squeeze.out_port(0))
        argmax.out_port(0).connect(squeeze.in_port(0))
        return []
Beispiel #12
0
 def find_and_replace_pattern(self, graph: Graph):
     """
     Inserts a Squeeze on every output port of each operation that has the 'squeeze_axis' attribute set to True.

     The axis is taken from the 'axis' attribute when it is valid, otherwise from the producer connected to
     the second input port of the operation.

     :raises Error: when neither the attribute nor the second input provides the axis
     """
     for node in graph.get_op_nodes(squeeze_axis=True):
         node_name = node.soft_get('name', node.id)
         for port in node.out_ports().values():
             if node.has_valid('axis'):
                 # the axis is known: feed it to the Squeeze as a constant input
                 squeeze = create_op_with_const_inputs(graph, Squeeze,
                                                       {1: np.array(node.axis)},
                                                       {'name': node_name + '/Squeeze_'})
                 port.get_connection().insert_node(squeeze)
                 continue

             if not node.is_in_port_connected(1):
                 raise Error('Unknown axis to squeeze for node {}'.format(node_name))

             # share the axis producer of the node with the Squeeze
             squeeze = Squeeze(graph, dict(name=node_name + '/Squeeze_')).create_node()
             port.get_connection().insert_node(squeeze)
             node.in_port(1).get_connection().add_destination(squeeze.in_port(1))
Beispiel #13
0
    def test_squeeze_empty_squeeze_dims(self):
        """Checks that Squeeze.infer with an empty 'squeeze_dims' value turns shape [1, 2, 1, 4] into [2, 4]."""
        edges = [('data', 'squeeze'),
                 ('squeeze_dims', 'squeeze_dims_data'),
                 ('squeeze_dims_data', 'squeeze'),
                 ('squeeze', 'data_out')]
        attrs_update = {
            'data': {'shape': np.array([1, 2, 1, 4])},
            'squeeze_dims': {'value': np.array([]), 'shape': np.array([1])},
            'squeeze_dims_data': {'value': np.array([]), 'shape': np.array([1])},
        }
        graph = build_graph(nodes_attributes, edges, attrs_update)

        squeeze = Node(graph, 'squeeze')
        Squeeze.infer(squeeze)

        inferred_shape = squeeze.out_port(0).data.get_shape()
        self.assertTrue(np.all(inferred_shape == [2, 4]))
Beispiel #14
0
    def replace_op(self, graph: Graph, node: Node):
        """
        Replaces 'node' with a DFT or an IDFT based sub-graph, depending on 'node.module.inverse'.

        Forward case: the input is packed along the last axis with a copy of itself multiplied by 0.0 and passed
        to DFT over the axes 2 .. num_axes - 1.
        Inverse case: the input is passed to IDFT over the axes 2 .. num_axes - 2, after which a StridedSlice
        followed by a Squeeze over the last axis is applied.

        :param graph: graph being transformed
        :param node: node to replace
        :return: list with the id of the node producing the result
        """
        if not node.module.inverse:
            zero = Const(graph, dict(value=0.0)).create_node()
            imag = Mul(graph, {'name': node.name + '/imag'}).create_node([node.in_node(0), zero])
            cmplx = PackOp(graph, {'name': node.name + '/complex', 'axis': -1}).create_node(
                [node.in_node(0), imag])

            dft_axes = Const(graph, dict(value=int64_array(range(2, node.module.num_axes)))).create_node()
            dft = DFT(graph, {'name': node.name, 'in_ports_count': 2}).create_node([cmplx, dft_axes])
            return [dft.id]

        idft_axes = Const(graph, dict(value=int64_array(range(2, node.module.num_axes - 1)))).create_node()
        idft = IDFT(graph, {'name': node.name, 'in_ports_count': 2}).create_node([node.in_node(0), idft_axes])

        # Slice a real part
        slice_begin = Const(graph, dict(value=int64_array([0, 0]))).create_node()
        slice_end = Const(graph, dict(value=int64_array([0, 1]))).create_node()
        slice_attrs = {
            'name': node.name + '/real',
            'begin_mask': [0, 0],
            'end_mask': [0, 1],
            'shrink_axis_mask': [0, 0],
            'new_axis_mask': [0],
            'ellipsis_mask': [1, 0],
        }
        real_part = StridedSlice(graph, slice_attrs).create_node([idft, slice_begin, slice_end])

        last_axis = Const(graph, dict(value=-1)).create_node()
        result = Squeeze(graph, {'name': node.name + '/squeeze'}).create_node([real_part, last_axis])

        return [result.id]
Beispiel #15
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replaces the matched RNN sequence layer with a TensorIterator whose body holds a single cell.

        Layers with op 'LSTM' are skipped. For the others the body graph is built as
        Squeeze -> cell -> Unsqueeze, where both the Squeeze and the Unsqueeze work over
        'rnn_layer.sequence_dim'. A TensorIterator is then created in 'graph' with the input/output port maps
        and the back edge for the hidden state, and the original layer node is removed from the graph.

        :param graph: graph that contains the matched layer and receives the TensorIterator
        :param match: matched sub-graph; the sequence layer is stored under the 'rnn_layer' key
        """
        if match['rnn_layer']['op'] == 'LSTM':
            return

        rnn_layer = match['rnn_layer']

        # Build TensorIterator body first
        body = Graph(name=rnn_layer.name + '/sub_graph')
        # the body shares the graph-level attributes dictionary with the main graph
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # data nodes of the body mirror the layer inputs; values are copied only for the inputs 1 and 2
        inputs = [
            Op._create_data_node(
                body, rnn_layer.name + '/inport/' + str(inp), {
                    'shape':
                    rnn_layer.in_node(inp).shape.copy(),
                    'value':
                    rnn_layer.in_node(inp).value.copy()
                    if rnn_layer.in_node(inp).value is not None
                    and inp in [1, 2] else None
                }) for inp in [0, 4, 1, 2]
        ]  # X, h_init, WR, B

        # the body input has size 1 along the sequence dimension; the Squeeze removes that dimension
        inputs[0].shape[rnn_layer.sequence_dim] = 1
        input_squeeze = Squeeze(
            body,
            dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0))
        input_squeeze_dim = Const(
            body,
            dict(name=rnn_layer.name + '/input_squeeze_dim',
                 value=rnn_layer.sequence_dim)).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data(
            [inputs[0], input_squeeze_dim],
            edge_attrs=[{
                'internal_port_id': 0
            }])

        # 2. Output unsqueeze Reshape
        outputs = [
            Op._create_data_node(
                body, rnn_layer.name + '/outport/' + str(out), {
                    'shape':
                    rnn_layer.out_node(out).shape.copy()
                    if out in rnn_layer.out_nodes() else None
                }) for out in [0]
        ]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        # the cell output has no sequence dimension; the Unsqueeze created below adds it back
        outputs[0].shape = np.delete(outputs[0].shape.copy(),
                                     rnn_layer.sequence_dim)
        output_unsqueeze_dim = Const(
            body,
            dict(name=rnn_layer.name + '/output_unsqueeze_dim',
                 value=rnn_layer.sequence_dim)).create_node_with_data()
        # NOTE(review): this name ends with '/' unlike the other body node names - confirm it is intended
        output_unsqueeze = Unsqueeze(
            body,
            dict(name=rnn_layer.name + '/output_unsqueeze/',
                 internal_layer_id=2))

        # attributes forwarded from the sequence layer to the cell
        additional_attrs = dict(activations=rnn_layer.activations,
                                activation_alpha=rnn_layer.activation_alpha,
                                activation_beta=rnn_layer.activation_beta,
                                clip=rnn_layer.clip)
        if rnn_layer.op == 'GRU':
            additional_attrs[
                'linear_before_reset'] = rnn_layer.linear_before_reset

        # 3. ***Cell
        # the cell operation class is selected by the op type of the layer
        rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(
            body,
            dict(hidden_size=rnn_layer.hidden_size,
                 name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                 **additional_attrs,
                 internal_layer_id=1))

        gru_cell = rnn_cell_op.create_node_with_data(inputs,
                                                     data_nodes=outputs,
                                                     edge_attrs=[{}, {
                                                         'internal_port_id':
                                                         1
                                                     }, {
                                                         'internal_port_id':
                                                         2
                                                     }, {
                                                         'bin':
                                                         'weights'
                                                     }, {
                                                         'bin':
                                                         'biases'
                                                     }])

        # internal ports for outputs of cell
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

        # from here on 'gru_cell' refers to the data node produced by the Unsqueeze
        gru_cell = output_unsqueeze.create_node_with_data(
            [gru_cell, output_unsqueeze_dim])
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, gru_cell.id, 0, False)

        # 4. TensorIterator layer creating
        assert rnn_layer.direction in ['forward', 'reverse']
        if rnn_layer.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert rnn_layer.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        # stacked h_state
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': rnn_layer.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding last h_state to outputs
        if len(rnn_layer.out_nodes()) == 2:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }])

        # NOTE(review): 'in_ports_count' is 4 while only two inputs are passed to the node below - confirm
        ti_op = TensorIterator(
            graph,
            {
                'name':
                rnn_layer.name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                4,
                'out_ports_count':
                len(rnn_layer.out_nodes()),
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': rnn_layer.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                ],
                'output_port_map':
                output_port_map,
                # only for h state
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                ]
            })

        # output port indices of the layer must be exactly 0 .. N-1
        assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
            "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

        # the TensorIterator reuses the input and output data nodes of the original layer
        outs = ti_op.create_node_with_data(
            [rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
            data_nodes=[
                rnn_layer.out_node(i)
                for i in range(len(rnn_layer.out_nodes()))
            ],
            edge_attrs=[{
                'external_port_id': 0
            }, {
                'external_port_id': 1
            }])

        # normalize the result to a list, as a non-list value is handled here too
        if not isinstance(outs, list):
            outs = list([outs])

        graph.remove_node(rnn_layer.id)
        # external port ids of the outputs continue the numbering used in 'output_port_map'
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id

        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replaces the matched LSTM sequence layer with a TensorIterator whose body holds a single LSTMCell.

        The body graph is built as Squeeze -> LSTMCell -> Unsqueeze, where both the Squeeze and the Unsqueeze
        work over 'lstm.sequence_dim'. A TensorIterator is then created in 'graph' with the input/output port
        maps and the back edges for the two states, and the original layer node is removed from the graph.

        :param graph: graph that contains the matched layer and receives the TensorIterator
        :param match: matched sub-graph; the sequence layer is stored under the 'lstm' key
        """
        lstm = match['lstm']

        # Build TensorIterator body first
        body = Graph(name=lstm.name + '/sub_graph')
        # the body shares the graph-level attributes dictionary with the main graph
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # data nodes of the body mirror the layer inputs; values are copied only for the inputs 1 and 2
        inputs = [
            Op._create_data_node(
                body, lstm.name + '/inport/' + str(inp), {
                    'shape':
                    lstm.in_node(inp).shape.copy(),
                    'value':
                    lstm.in_node(inp).value.copy() if lstm.in_node(inp).value
                    is not None and inp in [1, 2] else None
                }) for inp in [0, 4, 5, 1, 2]
        ]  # X, h_init, c_init, WR, B

        # the body input has size 1 along the sequence dimension; the Squeeze removes that dimension
        inputs[0].shape[lstm.sequence_dim] = 1
        input_squeeze = Squeeze(
            body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
        squeeze_dim_data = Const(body, {
            'name': lstm.name + '/input_squeeze_dim',
            'value': [lstm.sequence_dim]
        }).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data(
            [inputs[0], squeeze_dim_data],
            edge_attrs=[{
                'internal_port_id': 0
            }])

        # 2. Output unsqueeze Reshape
        # an output missing in the layer takes its shape from the layer input 4
        outputs = [
            Op._create_data_node(
                body, lstm.name + '/outport/' + str(out), {
                    'shape':
                    lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                    else lstm.in_node(4).shape.copy()
                }) for out in [0, 1]
        ]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        # the cell output has no sequence dimension; the Unsqueeze created below adds it back
        outputs[0].shape = shape_delete(outputs[0].shape, lstm.sequence_dim)
        # NOTE(review): this name has no '/' separator unlike the other body node names - confirm it is intended
        output_unsqueeze = Unsqueeze(
            body, dict(name=lstm.name + 'output_unsqueeze',
                       internal_layer_id=2))
        unsqueeze_dim_data = Const(
            body, {
                'name': lstm.name + '/output_unsqueeze_dim',
                'value': [lstm.sequence_dim]
            }).create_node_with_data()

        # 3. LSTMCell
        # the cell attributes are forwarded from the sequence layer
        lstm_cell_op = LSTMCell(
            body,
            dict(hidden_size=lstm.hidden_size,
                 activations=lstm.activations,
                 activation_alpha=lstm.activation_alpha,
                 activation_beta=lstm.activation_beta,
                 clip=lstm.clip,
                 input_forget=lstm.input_forget,
                 name=lstm.name + '/LSTMCell',
                 internal_layer_id=1))
        lstm_cell_node = lstm_cell_op.create_node_with_data(
            inputs,
            data_nodes=outputs,
            edge_attrs=[{}, {
                'internal_port_id': 1
            }, {
                'internal_port_id': 2
            }, {
                'bin': 'weights'
            }, {
                'bin': 'biases'
            }])
        # internal port ids of the two cell outputs; they are referenced by the back edges below
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        # from here on element 0 refers to the data node produced by the Unsqueeze
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data(
            [lstm_cell_node[0], unsqueeze_dim_data])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, lstm_cell_node[0].id, 0, False)

        # 4. TensorIterator layer creating
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        # output of the Unsqueeze, mapped along the sequence dimension
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding h_state, c_state to outputs
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

        ti_op = TensorIterator(
            graph, {
                'name':
                lstm.name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                3,
                'out_ports_count':
                len(lstm.out_nodes()),
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': lstm.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                    {
                        'external_port_id': 2,
                        'internal_layer_id': 1,
                        'internal_port_id': 2,
                    },
                ],
                'output_port_map':
                output_port_map,
                # cell output ports 4 and 5 are fed back to the cell input ports 1 and 2
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                    {
                        'from_layer': 1,
                        'from_port': 5,
                        'to_layer': 1,
                        'to_port': 2,
                    },
                ]
            })

        # output port indices of the layer must be exactly 0 .. N-1
        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

        # the TensorIterator reuses the input and output data nodes of the original layer
        outs = ti_op.create_node_with_data(
            [lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
            data_nodes=[
                lstm.out_node(i) for i in range(len(lstm.out_nodes()))
            ],
            edge_attrs=[{
                'external_port_id': 0
            }, {
                'external_port_id': 1
            }, {
                'external_port_id': 2
            }])

        # normalize the result to a list, as a non-list value is handled here too
        if not isinstance(outs, list):
            outs = list([outs])

        graph.remove_node(lstm.id)
        # external port ids of the outputs continue the numbering used in 'output_port_map'
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id

        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
 def extract(cls, node: Node):
     """
     Fills the Squeeze attributes of 'node' from the 'squeeze_dims' list attribute of its protobuf.

     :param node: node whose 'pb' holds the attributes
     :return: the 'enabled' flag of the class
     """
     dims = tf_int_list(node.pb.attr['squeeze_dims'].list)
     Squeeze.update_node_stat(node, dict(squeeze_dims=dims))
     return cls.enabled
Beispiel #18
0
    def replace_pattern(graph, match: dict):
        """
        Replace a matched loop 'condition' sub-graph with a single TensorIterator operation.

        The function:
          1. collects the TensorIteratorInput / TensorIteratorOutput / TensorIteratorBackEdge nodes that consume
             the output data nodes of the condition;
          2. removes the condition with its data nodes and moves the loop body into a separate graph named 'body';
          3. assigns 'internal_layer_id' / 'internal_port_id' / 'external_port_id' values from one shared counter
             to the body nodes and edges that take part in back edges, inputs and outputs;
          4. creates the TensorIterator node in ``graph`` with the computed port maps and back edges.

        :param graph: graph being transformed; it is modified in place.
        :param match: pattern match dictionary; only the 'condition' entry is read here.
        :return: None
        """
        # Find all parts of the TI (condition, inputs/outputs, back edges, body), create the TensorIterator op
        # and make all the checks needed for TensorIterator to work
        cond_data = match['condition'].out_node(
            0) if not match['condition'].out_port(0).disconnected() else None
        # NOTE(review): the guard only requires one output node but output 1 is read -- confirm this is the
        # intended check when only output 0 of the condition is connected
        time_data = match['condition'].out_node(1) if len(
            match['condition'].out_nodes()) >= 1 else None
        name = match['condition'].name

        # ids of the TensorIterator* helper nodes found among consumers of the condition outputs
        back_edges = []
        inputs = []
        outputs = []

        if cond_data is not None:
            # consumers of the condition data: back edges, inputs and outputs; other consumers are ignored
            for node in cond_data.out_nodes():
                if node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorBackEdge':
                    back_edges.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)

        if time_data is not None:
            # consumers of the time data may only be TensorIteratorInput or TensorIteratorOutput nodes
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # any other consumer is unexpected here
                    assert False
        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)

        # the condition, its output data nodes and its first input node are dropped from the main graph
        nodes_to_remove = [
            n.id
            for n in (condition, cond_data, time_data, tensor_sequence_length)
            if n is not None
        ]
        graph.remove_nodes_from(nodes_to_remove)

        # get_body is defined elsewhere; its results are used below as the body node ids and additional input ids
        body_nodes, extra_inputs = get_body(graph, inputs, outputs)

        if cond_data is not None:
            # NOTE(review): cond_data is a Node while body_nodes is used below as a collection of node ids --
            # this subtraction may be a no-op (cond_data.id intended?); going through set() also drops the order
            body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        # wrap the collected ids into Node objects of the main graph
        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        # Description of every TI input: data node outside of the body, data node inside of the body and the
        # slicing attributes. The external data is taken from input 1 when the node has a valid 'axis', else from 0
        external_inputs = [{
            'external_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id':
            node.out_node(0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in inputs]

        # Description of every TI output; mirrors the inputs: the internal data is on the input side of the node
        external_outputs = [{
            'external_data_id':
            node.out_node(0),
            'internal_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in outputs]

        # Back edge description: input 1 is the data produced by an iteration, output 0 is the data consumed by
        # the next iteration, input 0 is the initial data
        back_edges_data = [{
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges]

        # Build the body graph from the body nodes and the edges that have both ends among them;
        # the graph-level attribute dictionary is shared with the main graph (not copied)
        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from([
            (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
            if u in body_nodes and v in body_nodes
        ])

        # remove the body and the TensorIteratorInput/TensorIteratorOutput nodes from the main graph
        graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                                [inp.id for inp in inputs] +
                                [out.id for out in outputs])
        # single counter that hands out layer ids, port ids and external port ids, so the order of the
        # statements below defines the resulting numbering
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            # re-create the Node wrappers on top of the body graph
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            # add_opoutput is defined elsewhere; presumably it marks this data node as a body output -- TODO confirm
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid(
                    'internal_layer_id'):
                edge['from_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node(
            )['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge(
            )['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(
                    edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(
                    edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                if real_edge['to_data_id'].in_node().has_valid(
                        'internal_layer_id'):
                    # NOTE(review): the two statements after this assert are unreachable unless asserts are
                    # disabled (python -O) -- confirm whether they can be removed
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                # the consumer edge and the edge from the initial data must not have port ids yet
                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert not 'internal_port_id' in real_edge[
                    'init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                # keep a copy of the attributes: the original edge disappears when to_data_id is removed below
                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (real_edge['init_data_id'].id, real_edge['consumer'].id,
                 real_edge['consumer_key'], real_edge['attrs'])
                for real_edge in current_real_back_edges
            ])

            # the data node that ended the back edge and its producer are not needed in the body any more
            body.remove_nodes_from(
                [edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        # one record per (external input, consumer inside of the body) pair
        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body,
                                               ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning:
                # a new data node with a 1 inserted into the shape at 'axis' becomes the body input and a Squeeze
                # over that axis produces the original internal data node
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                new_input_data = Op._create_data_node(
                    body,
                    ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                    dict(shape=shape_insert(shape, ext_inp['axis'], 1)))

                reshape_op = Squeeze(
                    body,
                    dict(name=ext_inp['internal_data_id'].name +
                         '/InputSqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_inp['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_inp['axis']
                    }).create_node_with_data()
                reshape_op.create_node_with_data(
                    [new_input_data, reshape_dim_data],
                    data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            # one external port id per external input, shared by all of its consumers inside of the body
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            for _, consumer, edge_attrs in body.out_edges(
                    ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                # reuse ids assigned earlier, otherwise take the next values of the counter
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if not 'internal_port_id' in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer[
                    'internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs[
                    'internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body,
                                               ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning;
                # the data node created for the Unsqueeze becomes the internal data of this output
                reshape_op = Unsqueeze(
                    body,
                    dict(name=ext_out['internal_data_id'].name +
                         '/OutputUnsqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_out['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_out['axis']
                    }).create_node_with_data()
                ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                    [ext_out['internal_data_id'], reshape_dim_data])

            # TODO: add here working with simple outputs

            # call add_opoutput only when no consumer of the data node is a 'Result' op yet
            if not any([
                    out_node.soft_get('op', None) == 'Result'
                    for out_node in ext_out['internal_data_id'].out_nodes()
            ]):
                add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            # ids of the producer of the output data and of the edge from it; reuse existing ids if any
            if not 'internal_layer_id' in ext_out['internal_data_id'].in_node(
            ):
                ext_out['internal_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node(
            )['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge(
            )['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        # create TensorIterator layer with pre-computed components
        # (the input port map has one entry per body consumer, while 'in_ports_count' counts external inputs)
        ti_op = TensorIterator(
            graph, {
                'name':
                name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                len(external_inputs),
                'out_ports_count':
                len(external_outputs),
                'input_port_map': [{
                    field: external_input[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_input in real_external_inputs],
                'output_port_map': [{
                    field: external_output[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_output in external_outputs],
                'back_edges': [{
                    field: edge[field]
                    for field in
                    ['from_layer', 'from_port', 'to_layer', 'to_port']
                } for edge in real_back_edges],
            })

        # connect the external data nodes to the new node; input edges carry their 'external_port_id'
        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{
                'external_port_id': inp['external_port_id']
            } for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs])

        # create_node_with_data may return a single data node instead of a list
        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        # store 'external_port_id' on the edges from the TensorIterator to its output data nodes
        for i, out in enumerate(ti_outs):
            out.in_edge(
            )['external_port_id'] = external_outputs[i]['external_port_id']

        # post-process the created node with the TensorIterator helpers (see their definitions for details)
        ti = ti_outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)