Пример #1
0
    def extract(cls, node):
        """
        Extract the attributes of an ONNX Crop node and store them on the node.

        The 'border' attribute is read as (leftBorder, topBorder, rightBorder,
        bottomBorder). When the optional 'scale' attribute is present the crop is
        described by 'dim'/'offset', otherwise by 'crop_begin'/'crop_end'. In both
        cases the crop is applied along axes 2 and 3.

        :param node: node whose ONNX attributes are read and whose Crop attributes are updated.
        :return: False when 'border' is missing or does not contain exactly 4 values,
                 otherwise CropFrontExtractor.enabled.
        """
        # borders: leftBorder, topBorder, rightBorder, bottomBorder
        borders = onnx_attr(node,
                            'border',
                            'ints',
                            default=None,
                            dst_type=int64_array)
        scale = onnx_attr(node,
                          'scale',
                          'ints',
                          default=None,
                          dst_type=int64_array)

        # Crop reference: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Crop
        # 'border' defaults to None when the attribute is absent, so it must be
        # checked before len() is taken; a missing attribute is reported the same
        # way as a wrong number of values.
        borders_count = 0 if borders is None else len(borders)
        if borders_count != 4:
            log.error(
                'ONNX Crop layer {} should take exactly 4 borders instead of {}'
                .format(node.name, borders_count))
            return False

        attrs = {'axis': int64_array([2, 3])}
        if scale is not None:
            attrs.update({
                'dim': scale,
                'offset': int64_array([borders[1], borders[0]])
            })
        else:
            attrs.update({
                'crop_begin': int64_array([borders[1], borders[0]]),
                'crop_end': int64_array([borders[3], borders[2]])
            })

        Crop.update_node_stat(node, attrs)
        return CropFrontExtractor.enabled
Пример #2
0
    def test_crop_type3_infer_neg3(self):
        """Expect Crop.infer to raise Error when 'offset' on the crop node is None."""
        type3_graph = self._create_graph_type3()

        node_under_test = Node(type3_graph, 'crop_node')
        node_under_test['offset'] = None

        self.assertRaisesRegex(Error, "offset attribute is missing.*",
                               Crop.infer, node_under_test)
Пример #3
0
    def test_crop_type3_infer_neg2(self):
        """Expect Crop.infer to raise Error when 'axis' on the crop node is None."""
        type3_graph = self._create_graph_type3()

        node_under_test = Node(type3_graph, 'crop_node')
        node_under_test['axis'] = None

        self.assertRaisesRegex(Error, "axis attribute is missing for .*",
                               Crop.infer, node_under_test)
Пример #4
0
    def test_crop_type2_infer_neg1(self):
        """Expect Crop.infer to raise Error after 'dim' is replaced by a 3-element array."""
        type2_graph = self._create_graph_type2()

        node_under_test = Node(type2_graph, 'crop_node')
        node_under_test['dim'] = int64_array([1, 2, 3])

        self.assertRaisesRegex(Error, "Number of axis.*",
                               Crop.infer, node_under_test)
Пример #5
0
    def test_crop_type1_infer_neg2(self):
        """Expect Crop.infer to raise Error after 'crop_begin' is replaced by a 3-element array."""
        type1_graph = self._create_graph_type1()

        node_under_test = Node(type1_graph, 'crop_node')
        node_under_test['crop_begin'] = int64_array([1, 2, 3])

        self.assertRaisesRegex(Error, "number of crop_begin.*",
                               Crop.infer, node_under_test)
Пример #6
0
    def test_crop_type3_infer_neg4(self):
        """Expect Crop.infer to raise Error once the 'crop_input2' shape is overridden."""
        type3_graph = self._create_graph_type3()

        node_under_test = Node(type3_graph, 'crop_node')

        # Override the shape of the second input before inference is run.
        second_input = Node(type3_graph, 'crop_input2')
        second_input.shape = int64_array([1, 4, 423, 563])

        self.assertRaisesRegex(Error, "The crop for dimension is out of bounds.*",
                               Crop.infer, node_under_test)
Пример #7
0
    def test_crop_type3_infer_neg1(self):
        """Expect Crop.infer to raise Error when the 'crop_input2' shape is None."""
        type3_graph = self._create_graph_type3()

        node_under_test = Node(type3_graph, 'crop_node')

        # Drop the shape of the second input before inference is run.
        second_input = Node(type3_graph, 'crop_input2')
        second_input.shape = None

        self.assertRaisesRegex(Error, "Not all input shapes were defined.*",
                               Crop.infer, node_under_test)
Пример #8
0
    def test_crop_type2_infer_neg2(self):
        """Expect Crop.infer to raise Error when both 'dim' and 'crop_begin' are None."""
        type2_graph = self._create_graph_type2()

        node_under_test = Node(type2_graph, 'crop_node')
        node_under_test['dim'] = None
        node_under_test['crop_begin'] = None

        self.assertRaisesRegex(Error, "Crop node crop_node should have either.*",
                               Crop.infer, node_under_test)
Пример #9
0
 def extract(cls, node):
     """
     Read Crop attributes from the node's MXNet symbol and store them on the node.

     'offset' comes from the "offset" attribute (empty tuple by default) and
     'axis' from "num_args" (0 by default); 'dim' is stored as None.

     :return: cls.enabled
     """
     layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
     crop_offset = layer_attrs.tuple("offset", int, ())
     crop_axis = layer_attrs.int("num_args", 0)
     Crop.update_node_stat(
         node, dict(axis=crop_axis, offset=list(crop_offset), dim=None))
     return cls.enabled
Пример #10
0
    def extract(cls, node):
        """
        Copy 'dim', 'offset' and 'axis' from node.parameters onto the node as Crop
        attributes and set the layout to 'NCHW'.

        :return: cls.enabled
        """
        params = node.parameters

        crop_attrs = {key: params[key] for key in ('dim', 'offset', 'axis')}
        crop_attrs['layout'] = 'NCHW'

        Crop.update_node_stat(node, attrs=crop_attrs)
        return cls.enabled
Пример #11
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched node with a Crop node wired between its former
        producer and consumer.

        The Crop is named after the matched node's 'variable_id' and uses
        dim=[1], offset=[0], axis=[0].
        """
        matched = match['op']
        variable_id = matched['variable_id']

        # Remember both neighbours before the matched node is detached.
        consumer_port = matched.out_port(0).get_destination()
        producer_port = matched.in_port(0).get_source()
        matched.in_port(0).disconnect()
        matched.out_port(0).disconnect()

        crop_attrs = {
            'name': 'Result_for_' + variable_id,
            'dim': np.array([1]),
            'offset': np.array([0]),
            'axis': np.array([0]),
        }
        replacement = Crop(graph, crop_attrs).create_node()
        producer_port.connect(replacement.in_port(0))
        replacement.out_port(0).connect(consumer_port)
Пример #12
0
 def extract(cls, node):
     """
     Read 'axis' and 'offset' from the layer's crop_param and store them on the
     node together with the crop_infer shape inference function.

     :return: cls.enabled
     """
     crop_param = node.pb.crop_param
     crop_attrs = dict(
         axis=crop_param.axis,
         offset=crop_param.offset,
         dim=None,  # set in infer
         infer=crop_infer,
     )
     # store the collected attributes on the node
     Crop.update_node_stat(node, crop_attrs)
     return cls.enabled
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched node and its pair (looked up by 'pair_name') with a
        ReadValue/Assign construction.

        A zero-filled Const of shape [in_shape[0], in_shape[1] * |t|] initializes
        the ReadValue node. For |t| > 1 the stored value is shifted on every
        iteration: a Crop drops the first in_shape[1] elements, the current input
        is concatenated after it and the result is assigned back; consumers read
        the first in_shape[1] elements of the stored value through a second Crop.
        For |t| == 1 the input is assigned directly and consumers read the
        ReadValue output.

        :param graph: graph to modify.
        :param match: pattern match; match['op'] is the node to replace.
        :raises Error: if node.t >= 0.
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        # NOTE(review): the guard also rejects t == 0 while the message only mentions t > 0 - confirm intent
        if node.t >= 0:
            raise Error('Does not support IfDefined with t > 0')

        # Exactly one node of the pair has a connected input; pick the data source from it and
        # the consumers from the other one.
        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)

        # zero initial value holding node_t frames of in_shape[1] elements each
        init_value_memory_out = Const(graph, {'name': 'init_value_' + pair_name,
                                              'value': np.zeros(int64_array([in_shape[0], in_shape[1]*node_t]), dtype=np.float32),
                                              'shape': int64_array([in_shape[0], in_shape[1]*node_t])}).create_node()
        memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

        if node_t > 1:
            # drop the first frame of the stored value and append the current input
            crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': mo_array([in_shape[1]*(node_t-1)]),
                                       'offset': mo_array([in_shape[1]]), 'axis': mo_array([1])}).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            # consumers receive the first frame of the stored value
            crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': mo_array([in_shape[1]]),
                                    'offset': mo_array([0]), 'axis': mo_array([1])}).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            # single frame: store the input as is and let consumers read the stored value directly
            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        # remove the old output node and both nodes of the replaced pair
        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
Пример #14
0
    def test_crop_type3_infer(self):
        """Run Crop.infer on the type3 graph and compare the 'crop_output' shape with [1, 3, 100, 150]."""
        type3_graph = self._create_graph_type3()

        Crop.infer(Node(type3_graph, 'crop_node'))

        expected = int64_array([1, 3, 100, 150])
        actual = type3_graph.node['crop_output']['shape']

        shapes_match = np.array_equal(expected, actual)
        failure_text = 'shapes do not match expected: {} and given: {}'.format(
            expected, actual)
        self.assertTrue(shapes_match, failure_text)
Пример #15
0
def add_fake_background_loc(graph: Graph, input_node: Node):
    r"""
    Prepend a fake "background" class slice to the box coordinates produced by 'input_node'.

    DetectionOutput layer expects that box coordinates contain coordinates of boxes for the "background" class also,
    but in the TensorFlow\* Object Detection API the tensor contains information about real object classes only.
    The function crops a slice of size 1 at offset 0 along axis 1 from the output of 'input_node' and concatenates it
    in front of that output along the same axis. The data in this slice is not used by the Detection Output layer so
    the actual values are not important. This approach allows the model to be reshape-able and does not introduce
    many layers.
    :param graph: graph to operate on.
    :param input_node: node producing the boxes coordinates.
    :return: Concat node whose output is the fake "background" slice followed by the original data.
    """
    background_slice = Crop(
        graph,
        dict(axis=mo_array([1]), offset=mo_array([0]), dim=mo_array([1]), nchw_layout=True),
    ).create_node([input_node], dict(name='crop_locs'))

    concat_with_background = Concat(graph, dict(axis=1, in_ports_count=2, nchw_layout=True))
    result_name = input_node.id + '/locs_with_fake_background'
    return concat_with_background.create_node([background_slice, input_node], dict(name=result_name))
    def find_and_replace_pattern(self, graph: Graph):
        """
        Adapt the inputs and outputs of every NonMaxSuppression node in the graph.

        Input 0 is passed through an Unsqueeze with constant [0] and input 1
        through a Reshape with constant [1, 1, -1]. Output 0, and output 1 when
        more than one output is connected, is followed by a Crop (axis 1,
        offset 2, dim 1) and a Squeeze with constant [1].

        :param graph: graph to modify in place.
        """

        def crop_and_squeeze(nms_node, port_idx, crop_name):
            # Insert Crop(axis=1, offset=2, dim=1) and then Squeeze([1]) after the given output port.
            crop = Crop(
                graph, {
                    'name': crop_name,
                    'axis': int64_array([1]),
                    'offset': int64_array([2]),
                    'dim': int64_array([1])
                }).create_node()
            nms_node.out_port(port_idx).get_connection().insert_node(crop)
            squeeze = create_op_node_with_second_input(
                graph, Squeeze, int64_array([1]),
                {'name': crop_name + '/Squeeze'})
            crop.out_port(0).get_connection().insert_node(squeeze)

        for nms in graph.get_op_nodes(op='NonMaxSuppression'):
            # reshape both inputs before they reach the NonMaxSuppression node
            boxes_unsqueeze = create_op_node_with_second_input(
                graph, Unsqueeze, int64_array([0]),
                {'name': nms.soft_get('name') + '/Unsqueeze_0'})
            nms.in_port(0).get_connection().insert_node(boxes_unsqueeze)

            scores_reshape = create_op_node_with_second_input(
                graph, Reshape, int64_array([1, 1, -1]),
                {'name': nms.soft_get('name') + '/Unsqueeze_1'})
            nms.in_port(1).get_connection().insert_node(scores_reshape)

            base_name = nms.soft_get('name', nms.id)

            # output #0
            crop_and_squeeze(nms, 0, base_name + '/Crop_boxes_')

            connected_outputs = sum(
                1 for port in nms.out_ports().values()
                if not port.disconnected())
            if connected_outputs == 1:
                continue

            # output #1
            crop_and_squeeze(nms, 1, base_name + '/Crop_scores_')
Пример #17
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Merge sibling Splice nodes into the matched Splice when their context is
        a subset of the matched one.

        Every other Splice fed by the same source port whose 'context' is a
        subset of the matched node's 'context' is removed. Its consumers are
        reconnected to the matched Splice: an existing Crop consumer gets its
        'offset' shifted, any other consumer gets a new Crop inserted in front
        of it.

        :param graph: graph to modify.
        :param match: pattern match; match['op'] is the Splice to keep.
        """
        mem = match['op']
        mem_shape = mem.in_port(0).data.get_shape()
        mem_parent = mem.in_port(0).get_source()
        context = mem['context']

        for child_port in mem_parent.get_destinations():
            child = child_port.node
            # look for another Splice whose context is contained in 'context'
            if child['op'] == 'Splice' and child.id != mem.id and set(
                    child['context']).issubset(set(context)):
                left_cont_out = child['context'][0]
                left_cont = context[0]

                for child_of_child in child.out_port(0).get_destinations():
                    out_transfer = child_of_child.node
                    out_transfer_port = child_of_child
                    if out_transfer['op'] == 'Crop':
                        # modify existing Crop to get right data from larger Splice
                        out_transfer['offset'] = out_transfer['offset'] + (
                            left_cont_out - left_cont) * mem_shape[-1]
                    else:
                        # insert Crop if we have not one
                        child_of_child.disconnect()
                        # NOTE(review): 'offset' is a plain scalar here while 'dim' and 'axis' are arrays - confirm
                        crop_node = Crop(
                            graph, {
                                'name':
                                graph.unique_id(prefix='Splice_crop_'),
                                'offset':
                                (left_cont_out - left_cont) * mem_shape[-1],
                                'dim':
                                mo_array(
                                    [len(child['context']) * mem_shape[-1]]),
                                'axis':
                                mo_array([-1])
                            }).create_node()
                        child.out_port(0).connect(crop_node.in_port(0))
                        crop_node.out_port(0).connect(child_of_child)
                        # the new Crop output keeps the shape the smaller Splice used to produce
                        crop_node.out_port(0).data.set_shape(
                            child.out_port(0).data.get_shape())

                        # from here on the new Crop input is the port to be re-sourced
                        out_transfer_port = crop_node.in_port(0)

                    # move edge to child from old Splice to larger
                    out_transfer_port.disconnect()
                    mem.out_port(0).connect(out_transfer_port)

                graph.remove_node(child.id)
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched node and its pair (looked up by 'pair_name') with a
        Splice followed by a Crop.

        Nothing is done when the paired node has 'has_default' set. The Splice
        context is range(t, 1) for negative t and range(0, t + 1) otherwise; the
        Crop selects one frame of in_shape[1] elements from the Splice output.
        If the data source also feeds a Concat that is the first consumer of the
        replaced pair, that Concat input is re-routed through a second Crop of
        the Splice output.

        :param graph: graph to modify.
        :param match: pattern match; match['op'] is the node to replace.
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if pair_node.has_default:
            return

        # Exactly one node of the pair has a connected input; take the data source from it and
        # the consumers from the other one.
        if node.in_port(0).get_source() is not None:
            input_node_out_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_node_in_ports = pair_node.out_port(0).get_destinations()
        else:
            input_node_out_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_node_in_ports = node.out_port(0).get_destinations()

        in_shape = input_node_out_port.data.get_shape().copy()

        node_id = node.id
        node_name = node.name
        node_t = node.t

        splice = Splice(graph, {'name': node_name,
                                'id': node_id,
                                'context': int64_array(range(node_t, 1))
                                if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
        splice.in_port(0).connect(input_node_out_port)

        # offset of Crop will be 0 (first element) if node_t < 0 and in_shape[1]*node_t (last element) if node_t > 0
        crop = Crop(graph, {'name': 'Splice_Crop',
                            'axis': int64_array([1]),
                            'offset': int64_array([max(0, in_shape[1] * node_t)]),
                            'dim': int64_array([in_shape[1]])}).create_node()

        splice.out_port(0).connect(crop.in_port(0))
        # Splice output holds abs(node_t) + 1 frames of in_shape[1] elements each
        splice.out_port(0).data.set_shape(int64_array([in_shape[0], (abs(node_t) + 1) * in_shape[1]]))

        outs = input_node_out_port.get_destinations()
        for in_port in outs:
            out_ = in_port.node
            # the source also feeds the Concat that consumes the replaced pair: feed it from the Splice
            if out_.op == 'Concat' and out_ == out_node_in_ports[0].node:
                # offset is the opposite end of the Splice output compared to 'crop' above
                crop_input = Crop(graph, {'name': 'Splice_Crop',
                                          'axis': int64_array([1]),
                                          'offset': int64_array([-min(0, in_shape[1] * node_t)]),
                                          'dim': int64_array([in_shape[1]])}).create_node()
                splice.out_port(0).connect(crop_input.in_port(0))

                in_port.disconnect()
                crop_input.out_port(0).connect(in_port)
                crop_input.out_port(0).data.set_shape(in_shape)

        # all former consumers of the pair now read from the Crop
        for dest_port in out_node_in_ports:
            dest_port.connect(crop.out_port(0))

        # remove the old output node and both nodes of the replaced pair
        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
Пример #19
0
    def insert_select(graph: Graph, node: Node):
        """
        Insert a Select node in front of input 0 of 'node'.

        The Select passes the original input through when an iteration counter
        value equals one and passes a constant created by
        create_const_with_batch_from_input otherwise. The counter is a
        ReadValue/Crop/Concat/Assign loop of length frame_time + 1; an already
        existing counter of the same length is reused when the pattern search
        finds one. Nothing is done when frame_time + 1 == 1.

        :param graph: graph to modify.
        :param node: node whose input 0 is routed through the new Select.
        """
        context_len = node.frame_time + 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = create_const_with_batch_from_input(in_node_port,
                                                       in_node_shape[1])
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check if we have already appropriate iteration counter
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()), ('mem_out', dict(op='Assign')),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        # only the first match (if any) is used
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            # reuse the constant and the Crop output of the counter that was found
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            # build a new counter: ReadValue -> Crop(drop first element) -> Concat(with ones) -> Assign -> Result
            init_value_mem_out = create_const_with_batch_from_input(
                in_node_port, context_len, precision=np.int32)
            mem_out = ReadValue(
                graph, {
                    'name': 'iteration_number',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = create_const_with_batch_from_input(in_node_port, 1, 1,
                                                      np.int32)
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(
                graph, {
                    'name': 'iteration_number_out',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            # the first element of the updated counter is what gets compared with one below
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check if data from memory is 1
        # if it is True, we have correct data and should proceed with saving it to memory
        # else we have not gathered context and have garbage here, shouldn't change initial state of memory
        cast_in = Equal(graph, {
            'name': input_port.node.name + '/cast_to_bool'
        }).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched node with an explicit ReadValue/Crop/Concat/Assign
        construction.

        The stored value holds len(context) frames of
        memory_element = in_shape[1] - const_dim elements. On every iteration the
        first frame is cropped away, the current input is concatenated after the
        remainder and the result is both assigned back and used as the output.

        When const_dim != 0 the input is first split by a VariadicSplit into a
        part of memory_element elements and a part of const_dim elements. The
        second part gets its own ReadValue/Crop/Concat/Assign loop, the first
        const_dim elements of that loop's Concat output are cropped out and
        concatenated after the main Concat output to form the final output.

        :param graph: graph to modify.
        :param match: pattern match; match['op'] is the node to replace.
        """
        node = match['op']
        in_shape = node.in_port(0).data.get_shape().copy()
        memory_element = in_shape[1] - node.const_dim
        memory_size = memory_element * len(node.context)

        memory_pair_id = unique_id('id')
        # Memory(in)
        input_memory = ReadValue(graph, {
            'name': 'prev_splice_memory',
            'variable_id': memory_pair_id
        }).create_node()

        # Memory(in)  \
        #             Crop
        # Input(temp) /
        crop = Crop(
            graph, {
                'name': 'Splice_Crop',
                'axis': int64_array([1]),
                'offset': int64_array([memory_element]),
                'dim': int64_array([memory_size - memory_element])
            }).create_node()
        crop.in_port(0).connect(input_memory.out_port(0))

        # Crop   \
        #         Concat
        # Input  /
        concat_node = Concat(graph, {
            'name': 'Splice_Concat',
            'in_ports_count': 2,
            'axis': 1
        }).create_node()
        concat_node.in_port(0).connect(crop.out_port(0))

        # Concat -> Memory(out)
        mem_out = Assign(graph, {
            'name': 'out_splice_memory',
            'variable_id': memory_pair_id
        }).create_node()
        mem_out.in_port(0).connect(concat_node.out_port(0))
        Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

        if node.const_dim != 0:
            memory_element_constdim = node.const_dim
            memory_size_constdim = memory_element_constdim * len(node.context)

            # presumably input 1 is the split axis and input 2 the split lengths - verify against VariadicSplit
            split = create_op_with_const_inputs(
                graph, VariadicSplit, {
                    1: int64_array(1),
                    2: int64_array([memory_element, memory_element_constdim])
                }, {
                    'name': node.id + '_split_const',
                    'out_ports_count': 2
                })

            # the first split output is the new frame appended to the main memory
            split.out_port(0).connect(concat_node.in_port(1))

            # create separate splice construction for const_dim
            memory_pair_id = unique_id('memory_for_const_dim')
            init_value_input_memory_const_dim = Const(
                graph, {
                    'name':
                    'init_value_const_dim_in_memory',
                    'value':
                    np.zeros(int64_array([in_shape[0], memory_size_constdim]),
                             dtype=np.float32),
                    'shape':
                    int64_array([in_shape[0], memory_size_constdim])
                }).create_node()
            input_memory_const_dim = ReadValue(graph, {
                'name': 'const_dim_in_memory',
                'variable_id': memory_pair_id
            }).create_node()
            init_value_input_memory_const_dim.out_port(0).connect(
                input_memory_const_dim.in_port(0))

            # drop the first const_dim frame of the stored const part
            crop_const_dim = Crop(
                graph, {
                    'name':
                    'const_dim_crop',
                    'axis':
                    int64_array([1]),
                    'offset':
                    int64_array([memory_element_constdim]),
                    'dim':
                    int64_array(
                        [memory_size_constdim - memory_element_constdim])
                }).create_node()
            crop_const_dim.in_port(0).connect(
                input_memory_const_dim.out_port(0))

            concat_node_const_dim = Concat(graph, {
                'name': 'const_dim_concat',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat_node_const_dim.in_port(0).connect(
                crop_const_dim.out_port(0))

            mem_out_const_dim = Assign(graph, {
                'name': 'const_dim_out_memory',
                'variable_id': memory_pair_id
            }).create_node()
            mem_out_const_dim.in_port(0).connect(
                concat_node_const_dim.out_port(0))
            Result(graph).create_node().in_port(0).connect(
                mem_out_const_dim.out_port(0))

            # connect splice to Split as begin and Concat as the end
            split.out_port(1).connect(concat_node_const_dim.in_port(1))
            # only the first const_dim elements of the const memory go to the output
            crop_first = Crop(
                graph, {
                    'name': 'const_dim_crop_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([memory_element_constdim])
                }).create_node()
            crop_first.in_port(0).connect(concat_node_const_dim.out_port(0))

            # final output: main Concat output followed by the cropped const part
            concat_const = Concat(graph, {
                'name': node.id + '_concat_const',
                'axis': 1,
                'in_ports_count': 2
            }).create_node()
            concat_const.in_port(1).connect(crop_first.out_port(0))
            concat_const.in_port(0).connect(concat_node.out_port(0))

            # zero initial value for the main memory
            init_value_input_memory = Const(
                graph, {
                    'name':
                    'init_value_' + node.name,
                    'value':
                    np.zeros(int64_array([in_shape[0], memory_size]),
                             dtype=np.float32),
                    'shape':
                    int64_array([in_shape[0], memory_size])
                }).create_node()
            init_value_input_memory.out_port(0).connect(
                input_memory.in_port(0))
            # re-route the original input to the split and the original consumers to the final Concat
            node.in_port(0).get_connection().set_destination(split.in_port(0))
            node.out_port(0).get_connection().set_source(
                concat_const.out_port(0))
        else:
            # zero initial value for the main memory
            init_value_input_memory = Const(
                graph, {
                    'name':
                    'init_value_' + node.name,
                    'value':
                    np.zeros(int64_array([in_shape[0], memory_size]),
                             dtype=np.float32),
                    'shape':
                    int64_array([in_shape[0], memory_size])
                }).create_node()
            init_value_input_memory.out_port(0).connect(
                input_memory.in_port(0))
            # no const part: the input goes straight into the Concat, whose output replaces the node output
            node.in_port(0).get_connection().set_destination(
                concat_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                concat_node.out_port(0))

        # to avoid re-inference of shape and touching in next replacements
        graph.remove_node(node.id)
Пример #21
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Merge the matched Splice with a sibling Splice whose context starts where
        the matched context ends.

        The node that is kept ('new_node') receives the sorted union of both
        contexts and an enlarged output shape; the other one ('rem_node') is
        removed. Consumers of the removed node are reconnected to the kept node
        through a Crop (an existing Crop consumer only gets its 'offset'
        shifted), and non-Crop consumers of the kept node get a Crop with
        offset 0 inserted in front of them.

        :param graph: graph to modify.
        :param match: pattern match; match['op'] is the Splice being examined.
        """
        mem = match['op']
        mem_shape = mem.in_port(0).data.get_shape()
        mem_parent = mem.in_port(0).get_source()
        context = mem['context']

        for child_port in mem_parent.get_destinations():
            child = child_port.node
            # NOTE(review): both operands of the 'or' below are the same comparison, so the
            # 'else' branch that selects 'child' as the kept node can never run - confirm whether
            # the second operand was meant to test the opposite adjacency.
            if child['op'] == 'Splice' and child.id != mem.id and \
               (child['context'][0] == context[-1] or child['context'][0] == context[-1]):

                # sorted union of both contexts without duplicates
                new_context = list(context)
                new_context.extend(list(child['context']))
                new_context = list(set(new_context))
                new_context.sort()
                if child['context'][0] == context[-1]:
                    new_node = mem
                    rem_node = child
                else:
                    new_node = child
                    rem_node = mem

                # reset edges from rem_node to new_node
                for out_port_rem in rem_node.out_port(0).get_destinations():
                    out_transfer = out_port_rem.node
                    out_transfer_shape = out_port_rem.data.get_shape().copy()

                    out_port_rem.disconnect()

                    if out_transfer['op'] == 'Crop':
                        # modify existing Crop to get right data from larger Splice
                        out_transfer['offset'] = out_transfer['offset'] + (
                            len(new_context) -
                            len(rem_node.context)) * mem_shape[-1]
                        out_port_rem.connect(new_node.out_port(0))
                    else:
                        # insert Crop if we have not one
                        # NOTE(review): 'offset' is a plain scalar here while 'dim' and 'axis' are arrays - confirm
                        crop_node = Crop(
                            graph, {
                                'name':
                                graph.unique_id(prefix='Splice_crop_'),
                                'offset':
                                (len(new_context) - len(rem_node.context)) *
                                mem_shape[-1],
                                'dim':
                                mo_array([
                                    len(rem_node['context']) * mem_shape[-1]
                                ]),
                                'axis':
                                mo_array([-1])
                            }).create_node()
                        new_node.out_port(0).connect(crop_node.in_port(0))
                        crop_node.out_port(0).connect(out_port_rem)
                        # the consumer keeps seeing the shape it had before the merge
                        crop_node.out_port(0).data.set_shape(
                            out_transfer_shape)

                # consumers of the kept node must keep reading only its original part of the merged output
                for out_port_rem in new_node.out_port(0).get_destinations():
                    out_transfer = out_port_rem.node
                    out_transfer_shape = out_port_rem.data.get_shape().copy()

                    if out_transfer['op'] != 'Crop':
                        # insert Crop if we have not one
                        crop_node = Crop(
                            graph, {
                                'name':
                                graph.unique_id(prefix='Splice_crop_'),
                                'offset':
                                mo_array([0]),
                                'dim':
                                mo_array([
                                    len(new_node['context']) * mem_shape[-1]
                                ]),
                                'axis':
                                mo_array([-1])
                            }).create_node()
                        new_node.out_port(0).connect(crop_node.in_port(0))
                        out_port_rem.disconnect()
                        crop_node.out_port(0).connect(out_port_rem)
                        crop_node.out_port(0).data.set_shape(
                            out_transfer_shape)

                # grow the kept node's output by what the removed node added on top of its input
                new_shape = new_node.out_port(0).data.get_shape()
                new_shape[1] += rem_node.out_port(0).data.get_shape(
                )[1] - rem_node.in_port(0).data.get_shape()[1]
                new_node.out_port(0).data.set_shape(new_shape)
                new_node.context = new_context

                graph.remove_node(rem_node.id)