Example #1
    def concat_outputs(bi_rnn, forward_outputs, reverse_outputs, final_outputs):
        """ Concatenates two set of outputs from bidirectiondl RNNSequence nodes """
        concat_ops = [
            Concat(bi_rnn.graph, {
                'name': bi_rnn.name + '/FinalConcat/Data',
                'axis': 1,
                'in_ports_count': 2,
            }),
            Concat(bi_rnn.graph, {
                'name': bi_rnn.name + '/FinalConcat/HiddenState',
                'axis': 0,
                'in_ports_count': 2,
            }),
            Concat(bi_rnn.graph, {
                'name': bi_rnn.name + '/FinalConcat/CellState',
                'axis': 0,
                'in_ports_count': 2,
            })
        ]

        bi_rnn.graph.remove_node(bi_rnn.id)

        for i in final_outputs:
            concat_ops[i].create_node_with_data(
                [forward_outputs[i], reverse_outputs[i]],
                data_nodes=[final_outputs[i]]
            )
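
The concatenation axes above follow the layout produced by splitting a bidirectional RNNSequence: the data output is joined along axis 1 and the hidden/cell states along axis 0. A minimal NumPy sketch of that intent, with assumed per-direction shapes (a num_directions dimension at axis 1 for data and at axis 0 for states):

import numpy as np

# Assumed per-direction shapes (not taken from the example): data [B, 1, T, H], states [1, B, H].
B, T, H = 2, 5, 3
fwd_out, rev_out = np.zeros((B, 1, T, H)), np.ones((B, 1, T, H))
fwd_h, rev_h = np.zeros((1, B, H)), np.ones((1, B, H))

data = np.concatenate([fwd_out, rev_out], axis=1)   # '/FinalConcat/Data', axis=1 -> (B, 2, T, H)
hidden = np.concatenate([fwd_h, rev_h], axis=0)     # '/FinalConcat/HiddenState', axis=0 -> (2, B, H)
print(data.shape, hidden.shape)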
Example #2
    def extract(cls, node):
        pb = node.pb
        mapping_rule = {
            'axis': pb.concat_param.axis,
        }
        Concat.update_node_stat(node, mapping_rule)
        return cls.enabled
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['mxreshape']

        input_index = 0
        reshape_index = 0
        shape_node = Shape(graph, dict(name=node.id +
                                       '/ShapeMXReshape')).create_node()
        shape_node.in_port(0).connect(node.in_port(0).get_source())
        output_dims_nodes = []
        for d in node.dim:
            if reshape_index < len(node.dim):
                input_index, reshape_index, output_dims_nodes = self.resolve(
                    input_index, reshape_index, node.dim, shape_node,
                    output_dims_nodes)

        concat_node = Concat(
            shape_node.graph,
            dict(name=shape_node.id + '/ConcatMXReshape_',
                 axis=0,
                 in_ports_count=len(output_dims_nodes))).create_node()

        for in_port_index, dim_node in enumerate(output_dims_nodes):
            concat_node.in_port(in_port_index).connect(dim_node.out_port(0))

        reshape_node = Reshape(graph,
                               dict(name=node.id + '/Reshape_')).create_node()
        reshape_node.in_port(1).connect(concat_node.out_port(0))
        node.in_port(0).get_connection().set_destination(
            reshape_node.in_port(0))
        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    def replace_with_split_concat(node):
        graph = node.graph

        name = node.soft_get('name', node.id)
        axis = node.axis
        order = node.order

        split = create_op_with_const_inputs(graph, Split,
                                            {1: int64_array(axis)}, {
                                                'name': name + '/Split',
                                                'num_splits': order.size
                                            })
        concat = Concat(graph, {
            'name': name + '/Concat',
            'axis': axis,
            'in_ports_count': order.size
        }).create_node()

        for out_port_idx, in_port_idx in enumerate(order):
            split.out_port(out_port_idx).connect(concat.in_port(in_port_idx))

        node.out_port(0).get_connection().set_source(concat.out_port(0))
        node.in_port(0).get_connection().set_destination(split.in_port(0))

        graph.remove_node(node.id)
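
A Split followed by a Concat whose input ports are permuted by `order` reorders slices along `axis`; with a reversal permutation this reproduces a Reverse operation. A small NumPy check of that equivalence (shapes and order are illustrative):

import numpy as np

x = np.arange(12).reshape(3, 4)
axis, order = 0, np.array([2, 1, 0])         # reversal permutation along axis 0

pieces = np.split(x, order.size, axis=axis)  # Split with num_splits = order.size
slots = [None] * order.size
for out_port_idx, in_port_idx in enumerate(order):
    slots[in_port_idx] = pieces[out_port_idx]  # split out_port -> concat in_port, as above
result = np.concatenate(slots, axis=axis)

assert np.array_equal(result, np.flip(x, axis=axis))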
    def concat_output_states(graph: Graph, match: dict, new_states: list):
        """ Concatenates output states from multilayer layer. """
        rnn_layer = match['rnn_layer']
        original_states = [
            rnn_layer.out_node(i) if i in rnn_layer.out_nodes() else None
            for i in [1, 2]
        ]

        concat_ops = [
            Concat(rnn_layer.graph, {
                'name': rnn_layer.name + '/FinalLayerSplitConcat/HiddenState',
                'axis': -1
            }),
            Concat(rnn_layer.graph, {
                'name': rnn_layer.name + '/FinalLayerSplitConcat/CellState',
                'axis': -1
            })
        ]

        for i in range(len(original_states)):  # [0] or [0, 1]
            if original_states[i] is None:
                continue
            concat_ops[i].attrs.update({'in_ports_count': len(new_states[i])})
            concat_ops[i].create_node_with_data(
                inputs=new_states[i], data_nodes=[original_states[i]])
Example #6
    def extract(cls, node):
        attrs = get_mxnet_layer_attrs(node.symbol_dict)
        data = {
            'axis': attrs.int("dim", 1),
        }

        # update the attributes of the node
        Concat.update_node_stat(node, data)
        return cls.enabled
Example #7
    def replace_op(self, graph: Graph, node: Node):
        out_node = Concat(graph, {'axis': node.axis, 'in_ports_count': len(node.in_ports())}).create_node()
        pack_name = node.soft_get('name', node.id)

        for ind in node.in_ports():
            unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array([node.axis])},
                                                         {'name': node.soft_get('name', node.id) + '/Unsqueeze'})
            node.in_port(ind).get_connection().set_destination(unsqueeze_node.in_port(0))
            unsqueeze_node.out_port(0).connect(out_node.in_port(ind))

        rename_nodes([(node, pack_name + '/TBR'), (out_node, pack_name)])
        return [out_node.id]
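
This Pack replacement mirrors the NumPy identity that stacking equals unsqueezing each input on the pack axis and then concatenating; a quick sanity check:

import numpy as np

axis = 1
inputs = [np.zeros((2, 3)), np.ones((2, 3)), np.full((2, 3), 2.0)]

stacked = np.stack(inputs, axis=axis)
unsqueezed_concat = np.concatenate([np.expand_dims(a, axis) for a in inputs], axis=axis)

assert np.array_equal(stacked, unsqueezed_concat)   # both are (2, 3, 3)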
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if node.t >= 0:
            raise Error('IfDefined with t >= 0 is not supported')

        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)

        init_value_memory_out = Const(graph, {'name': 'init_value_' + pair_name,
                                              'value': np.zeros(int64_array([in_shape[0], in_shape[1]*node_t]), dtype=np.float32),
                                              'shape': int64_array([in_shape[0], in_shape[1]*node_t])}).create_node()
        memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
        init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

        if node_t > 1:
            crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': mo_array([in_shape[1]*(node_t-1)]),
                                       'offset': mo_array([in_shape[1]]), 'axis': mo_array([1])}).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            concat.in_port(1).connect(input_port)

            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))

            crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': mo_array([in_shape[1]]),
                                    'offset': mo_array([0]), 'axis': mo_array([1])}).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
        else:
            memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            out_port.get_connection().set_source(memory_out.out_port(0))

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)
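
In the node_t > 1 branch above, the Crop/Concat pair maintains a rolling buffer of the last t frames: the oldest frame is read out while the buffer is shifted and the current frame appended. A rough NumPy picture of that buffer (shapes assumed):

import numpy as np

# Assumed shapes: batch B, feature size F, delay t = |node.t| (node_t > 1 branch).
B, F, t = 2, 3, 4
memory = np.zeros((B, F * t), dtype=np.float32)              # ReadValue initial state
frame = np.random.rand(B, F).astype(np.float32)              # current input

delayed = memory[:, :F]                                      # 'Memory_crop_out': the t-delayed frame
new_memory = np.concatenate([memory[:, F:], frame], axis=1)  # 'Memory_crop' + 'Memory_concat', stored by Assign
print(delayed.shape, new_memory.shape)                       # (2, 3) (2, 12)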
def create_ss_interval_border(graph: Graph, slice_border_port: Port,
                              shape: np.ndarray, axes: np.ndarray,
                              node_name: str):
    """
    This function creates "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends"

    :param graph: graph to operate on.
    :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
    :param shape: input shape of the Slice
    :param axes: axes that "starts" and "ends" apply to
    :param node_name: Slice node name
    :return: Concat node that forms "begin"/"end" values for the StridedSlice
    """
    # The value for 'starts' or 'ends' may be the maximum/minimum possible int64 value. Such values must be
    # clamped to the maximum/minimum of int32 because they do not fit into the int32 range supported by the
    # StridedSlice layer.
    clamp = create_op_with_const_inputs(graph,
                                        Clamp,
                                        port_value_dict={
                                            1: np.iinfo(np.int32).min,
                                            2: np.iinfo(np.int32).max
                                        },
                                        op_attrs=dict(name=node_name +
                                                      '/Clamp'))
    clamp.in_port(0).connect(slice_border_port)
    # the "starts"/"ends" values coming from the network have to be converted to the same data type as the constant
    # values created here, to prevent type errors in the Concat node
    cast = Cast(graph, dict(name=node_name + '/CastToI64',
                            dst_type=np.int64)).create_node()
    cast.in_port(0).connect(clamp.out_port(0))
    concat = Concat(graph, dict(name=node_name + '/Concat',
                                axis=0)).create_node()
    for value_idx, port_idx in enumerate(axes):
        concat.add_input_port(port_idx)
        # "axes" may not be sorted, so we need to split "starts"/"ends" values and connect each value to the correct
        # Concat input port
        value = create_op_with_const_inputs(
            graph,
            Gather,
            port_value_dict={
                1: int64_array([value_idx]),
                2: int64_array(0)
            },
            op_attrs={'name': node_name + '/Gather'})
        cast.out_port(0).connect(value.in_port(0))
        value.out_port(0).connect(concat.in_port(port_idx))
    for port_idx in range(len(shape)):
        if not concat.is_in_port_connected(port_idx):
            concat.add_input_port(port_idx)
            # This border value would be ignored in StridedSlice because of the begin_mask\end_mask
            const = Const(
                graph, dict(name=node_name + '/Const',
                            value=int64_array([0]))).create_node()
            const.out_port(0).connect(concat.in_port(port_idx))

    return concat
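
A rough NumPy model of what the Clamp/Cast/Gather/Concat subgraph computes for the 'begin' values; the helper below is illustrative and not part of the original code:

import numpy as np

def build_border(starts, axes, rank):
    # Clamp extreme int64 sentinels into the int32 range, as the Clamp node does.
    starts = np.clip(np.asarray(starts, dtype=np.int64),
                     np.iinfo(np.int32).min, np.iinfo(np.int32).max)
    border = np.zeros(rank, dtype=np.int64)       # unspecified axes get 0 (masked out later)
    for value_idx, axis in enumerate(axes):       # "axes" may be unsorted, hence one Gather per value
        border[axis] = starts[value_idx]
    return border

print(build_border([np.iinfo(np.int64).max, 5], axes=[2, 0], rank=4))
# -> [5 0 2147483647 0]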
Example #10
    def test_concat_op(self):
        graph = build_graph(self.nodes_attributes, [('node_1', 'concat_node'),
                                                    ('concat_node', 'node_3')])
        concat_node = Concat(graph,
                             self.nodes_attributes['concat_node']).add_node()
        self.assertEqual(concat_node.type, 'Concat')
        self.assertEqual(concat_node.op, 'Concat')
        self.assertEqual(concat_node.infer, concat_infer)
Example #11
def add_fake_background_loc(graph: Graph, input_node: Node):
    r"""
    DetectionOutput layer expects that box coordinates contains coordinates of boxes for the "background" class also,
    but in the TensorFlow\* Object Detection API the tensor contains information about real object classes only.
    The function copies a slice of the output data of the node 'input_node' and then concats it to the beginning of the
    data. The data in this slice is not used by the Detection Output layer so the actual values are not important. This
    approach allows the model to be reshape-able and does not introduce many layers.
    "background" class box coordinates.
    :param graph: graph to operate on.
    :param input_node: node producing the boxes coordinates.
    :return convolution node that adds slice of data for the "background" class.
    """
    crop_op = Crop(graph, dict(axis=mo_array([1]), offset=mo_array([0]), dim=mo_array([1]), nchw_layout=True))
    crop_node = crop_op.create_node([input_node], dict(name='crop_locs'))

    concat_op = Concat(graph, dict(axis=1, in_ports_count=2, nchw_layout=True))
    return concat_op.create_node([crop_node, input_node], dict(name=input_node.id + '/locs_with_fake_background'))
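
Semantically, the Crop(axis=1, offset=0, dim=1) plus Concat(axis=1) pair duplicates the first class slice in front of the locations tensor; in NumPy terms (shapes assumed for illustration):

import numpy as np

locs = np.random.rand(1, 4, 8)                        # assumed layout: [batch, num_classes, ...]
fake_background = locs[:, 0:1, :]                     # Crop: axis=1, offset=0, dim=1
locs_with_background = np.concatenate([fake_background, locs], axis=1)
print(locs_with_background.shape)                     # (1, 5, 8)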
    def replace_op(self, graph: Graph, node: Node):
        if node.has_and_set('inputs_preprocessed'):
            log.debug('Node "{}" has already been preprocessed'.format(
                node.soft_get('name')))
            return []
        # reshape tensor with batch indices to 2d
        unsqueeze_node = create_op_node_with_second_input(
            graph, Unsqueeze, int64_array([1]),
            {'name': node.name + '/Unsqueeze'}, node.in_node(2))

        convert_node = Cast(graph, {
            'name': unsqueeze_node.name + '/ToFloat',
            'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
        }).create_node()

        convert_node.in_port(0).connect(unsqueeze_node.out_port(0))

        concat_op = Concat(
            graph, {
                'axis': 1,
                'name': node.name + '/concat_batch_indices_and_boxes',
                'in_ports_count': 2
            })
        concat_node = concat_op.create_node([convert_node, node.in_node(1)])

        # do not remove edge with crop_size because it is needed in the partial infer
        graph.remove_edge(node.in_node(1).id, node.id)

        # the input to the CropAndResize contains box coordinates in the YXYX layout, but the IE ROIPooling layer
        # expects coordinates in the XYXY layout, so a convolution is added here to swap the coordinates
        swapped_box_coordinates_node = add_convolution_to_swap_xy_coordinates(
            graph, concat_node, 5)

        # reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
        reshape_2d_node = create_op_node_with_second_input(
            graph, Reshape, int64_array([-1, 5]),
            dict(name=swapped_box_coordinates_node.id + '/reshape_2d_'),
            swapped_box_coordinates_node)
        graph.create_edge(reshape_2d_node, node, 0, 1)

        # do not replace any output edge
        return []
Example #13
    def replace_pattern(self, graph: Graph, match: dict):
        concat_node = match['concat']
        sources_of_ports = [concat_node.in_port(i).get_connection().get_source() for i in concat_node.in_ports()]
        # If 'concat' is a ConcatV2 layer from TF, then this layer initially had the 'axis' input as its last input.
        # That input was later deleted and the 'axis' attribute was added instead, so the source of the last port
        # can be None in this case.
        sources_of_ports = [s for s in sources_of_ports if s is not None]

        input_nodes = [s.node for s in sources_of_ports]
        if not all(n.has_valid('type') for n in input_nodes):
            return

        saved_ports = []
        disconnected_ports = []

        for port_num, node in enumerate(input_nodes):
            if node.soft_get('type') == 'Const' and len(node.shape) > 1 and any(i == 0 for i in node.shape):
                disconnected_ports.append(port_num)
            else:
                saved_ports.append(port_num)

        if not saved_ports or not disconnected_ports:
            return

        if len(saved_ports) == 1:
            before_concat = concat_node.in_port(saved_ports[0]).get_connection().get_source()
            concat_node.out_port(0).get_connection().set_source(before_concat)
            return

        new_concat_attrs = concat_node.attrs().copy()
        new_concat_attrs['name'] = concat_node.name + '/Concat_'
        new_concat_attrs['in_ports_count'] = len(saved_ports)
        new_concat_node = Concat(graph, attrs=new_concat_attrs).create_node()

        for new_port_num, old_port_num in enumerate(saved_ports):
            concat_node.in_port(old_port_num).get_connection().set_destination(new_concat_node.in_port(new_port_num))

        for p in disconnected_ports:
            concat_node.in_port(p).disconnect()

        concat_node.out_port(0).get_connection().set_source(new_concat_node.out_port(0))
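
The transformation above relies on the fact that concatenating a zero-sized constant is a no-op, so such inputs can be dropped (and the Concat bypassed entirely when only one input remains). In NumPy:

import numpy as np

a = np.arange(6).reshape(2, 3)
empty = np.zeros((0, 3))                      # a Const whose shape contains a zero

assert np.array_equal(np.concatenate([a, empty], axis=0), a)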
    def replace_tdnn(self, graph: Graph, tdnn_node: Node):
        tdnn_name = tdnn_node.soft_get('name', tdnn_node.id)

        concat_node = Concat(graph, {'axis': 1}).create_node()
        rename_nodes([(tdnn_node, tdnn_name + '/to_be_removed'),
                      (concat_node, tdnn_name)])

        for offset_ind, t in enumerate(tdnn_node['time_offsets']):
            concat_node.add_input_port(offset_ind)
            if t != 0:
                memory_name = tdnn_name + '/MemoryOffset/' + str(abs(t))
                memoryoffset_node = MemoryOffset(
                    graph, {
                        'name': memory_name,
                        't': t,
                        'pair_name': memory_name + '_out',
                        'has_default': False,
                        'splitted': False
                    }).create_node()

                tdnn_node.in_port(0).get_source().connect(
                    memoryoffset_node.in_port(0))
                memoryoffset_node.out_port(0).connect(
                    concat_node.in_port(offset_ind))
            else:
                # a zero time delay is not allowed in IE and is meaningless anyway:
                # if the time offset is 0, connect the tdnncomponent input directly to the Concat without a MemoryOffset
                tdnn_node.in_port(0).get_source().connect(
                    concat_node.in_port(offset_ind))

        weights = tdnn_node['weights']
        fc_inputs = {1: weights}

        bias_term = False
        if tdnn_node.has_valid('biases'):
            assert len(tdnn_node['biases']) == weights.shape[0]
            fc_inputs.update({2: tdnn_node['biases']})
            bias_term = True

        fc_node = create_op_with_const_inputs(
            graph, FullyConnected, fc_inputs, {
                'name': tdnn_name + '/FC',
                'out-size': weights.shape[0],
                'transpose_weights': True,
                'bias_term': bias_term
            })

        concat_node.out_port(0).connect(fc_node.in_port(0))
        tdnn_node.in_port(0).disconnect()
        tdnn_node.out_port(0).get_connection().set_source(fc_node.out_port(0))
Example #15
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['node']
        node_name = node.soft_get('name', node.id)

        connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        if len(connected_ports) == 2:
            axis = node.in_port(1).data.get_value()
        else:
            axis = node.axis

        assert axis is not None, 'The "axis" should be defined for node "{}"'.format(node_name)
        assert node.has_and_set('output_type'), 'The data type is not set for node "{}"'.format(node_name)

        topk_mode = 'max' if node.op == 'ArgMax' else 'min'
        topk_node = TopK(graph, {'axis': axis, 'mode': topk_mode, 'sort': 'index',
                                 'remove_values_output': node.has_and_set('remove_values_output'),
                                 'index_element_type': node.output_type}).create_node()
        node.in_port(0).get_connection().set_destination(topk_node.in_port(0))
        if node.has_and_set('out_max_val'):  # in this mode the ArgMax produces tuples (max_ind, max_value)
            concat_node = Concat(graph, {'axis': 1, 'name': node.name + '/Concat'}).create_node()
            concat_node.add_input_port(0, skip_if_exist=True)
            concat_node.add_input_port(1, skip_if_exist=True)
            topk_node.out_port(0).connect(concat_node.in_port(1))  # indices
            topk_node.out_port(1).connect(concat_node.in_port(0))  # values
            if not node.out_port(0).disconnected():
                node.out_port(0).get_connection().set_source(concat_node.out_port(0))
        else:
            if not node.out_port(0).disconnected():
                node.out_port(0).get_connection().set_source(topk_node.out_port(1))

        topk_node.in_port(1).connect(Const(graph, {'name': node.soft_get('name') + '/TopK',
                                                   'value': node.top_k}).create_node().out_port(0))

        graph.remove_nodes_from([node.id, node.out_node(0).id])
Example #16
    def fuse_reduces(first_reduce, second_reduce):
        first_reduce_name = first_reduce.soft_get('name', first_reduce.id)
        second_reduce_name = second_reduce.soft_get('name', second_reduce.id)
        reduce_type = first_reduce.type

        assert first_reduce.type == second_reduce.type

        if len(first_reduce.out_port(0).get_destinations()) != 1:
            # data dependency
            return

        if first_reduce.keep_dims != second_reduce.keep_dims:
            return

        first_axes = first_reduce.in_port(1).data.get_value()
        second_axes = second_reduce.in_port(1).data.get_value()
        if first_axes is None or second_axes is None:
            # dynamic axes merging is not supported
            return

        if not first_reduce.keep_dims:
            if not np.all(first_axes > second_axes):
                # indexing of upper reduce input dimensions changed
                return

        graph = second_reduce.graph

        new_axes = Concat(
            graph, {
                'name': second_reduce_name + '/Axes',
                'axis': int64_array(0),
                'in_ports_count': 2,
                'override_output_shape': True
            }).create_node()
        new_axes.in_port(0).connect(first_reduce.in_port(1).get_source())
        new_axes.in_port(1).connect(second_reduce.in_port(1).get_source())

        first_reduce.in_port(
            0).get_source().node['need_shape_inference'] = True
        first_reduce.in_port(
            0).get_source().node['override_output_shape'] = True

        second_reduce.in_port(1).get_connection().set_source(
            new_axes.out_port(0))

        first_reduce.out_port(0).get_connection().set_source(
            first_reduce.in_port(0).get_connection().get_source())
        first_reduce.in_port(1).disconnect()
        graph.remove_node(first_reduce.id)

        log.debug(
            '{0} nodes {1} and {2} were fused to a single {2} node with updated axes input'
            ''.format(reduce_type, first_reduce_name, second_reduce_name))
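
The fusion is based on the identity that two stacked reductions with keep_dims equal a single reduction over the union of their axes, which is why the axes inputs can simply be concatenated. A NumPy check, using ReduceSum as the reduce type:

import numpy as np

x = np.random.rand(2, 3, 4, 5)
first_axes, second_axes = np.array([1]), np.array([3])

fused_axes = np.concatenate([first_axes, second_axes])            # the new Concat-ed axes input
two_steps = x.sum(axis=tuple(first_axes), keepdims=True).sum(axis=tuple(second_axes), keepdims=True)
one_step = x.sum(axis=tuple(fused_axes), keepdims=True)

assert np.allclose(two_steps, one_step)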
Example #17
    def placeholder_scales(self, placeholder: Node):
        """
        Helper function to get scales for prior boxes out of input image size:
                [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]
        """
        graph = placeholder.graph
        name = placeholder.soft_get('name', placeholder.id)

        shape_value = placeholder.soft_get('shape', None)
        assert shape_value is not None, \
            "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
        assert isinstance(shape_value, np.ndarray), \
            "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
        assert shape_value.size == 4, \
            "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

        shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
        shape.in_port(0).connect(placeholder.out_port(0))

        begin = Const(graph, {'value': int64_array([1])}).create_node()
        end = Const(graph, {'value': int64_array([3])}).create_node()
        stride = Const(graph, {'value': int64_array([1])}).create_node()
        spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': int64_array([1]),
                                       'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
                                       'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0])}).create_node()

        spatial.in_port(0).connect(shape.out_port(0))
        spatial.in_port(1).connect(begin.out_port(0))
        spatial.in_port(2).connect(end.out_port(0))
        spatial.in_port(3).connect(stride.out_port(0))

        power = Const(graph, {'value': float32_array([-1.])}).create_node()
        spatial_scale = Pow(graph, {}).create_node()

        spatial_scale.in_port(0).connect(spatial.out_port(0))
        spatial_scale.in_port(1).connect(power.out_port(0))

        # Power `type_infer` requires inputs to have equal data type
        convert_to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
        spatial_scale.in_port(0).get_connection().insert_node(convert_to_fp32)

        order = Const(graph, {'value': int64_array([1, 0])}).create_node()
        axis_const = Const(graph, {'value': int64_array(0)}).create_node()
        reverse = Gather(graph, {}).create_node()

        reverse.in_port(0).connect(spatial_scale.out_port(0))
        reverse.in_port(1).connect(order.out_port(0))
        axis_const.out_port(0).connect(reverse.in_port(2))

        priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
        priors_scale_node.add_input_port(0, skip_if_exist=True)
        priors_scale_node.add_input_port(1, skip_if_exist=True)

        priors_scale_node.in_port(0).connect(reverse.out_port(0))
        priors_scale_node.in_port(1).connect(reverse.out_port(0))
        return priors_scale_node
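
For an input of static NHWC shape, the subgraph built above reduces to the vector [1/W, 1/H, 1/W, 1/H]; the same computation in plain NumPy (the NHWC layout is assumed):

import numpy as np

input_shape = np.array([1, 300, 600, 3])              # assumed NHWC placeholder shape
spatial = input_shape[1:3]                            # StridedSlice: begin=1, end=3 -> [H, W]
spatial_scale = spatial.astype(np.float32) ** -1.0    # Pow(-1) -> [1/H, 1/W]
reversed_scale = spatial_scale[[1, 0]]                # Gather with order [1, 0] -> [1/W, 1/H]
priors_scale = np.concatenate([reversed_scale, reversed_scale])  # Concat of the same output twice
print(priors_scale)                                   # [1/600 1/300 1/600 1/300]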
Example #18
    def extend_inputs(node: Node, num_insertions: int):
        graph = node.graph
        node_name = node.soft_get('name', node.id)

        for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
            if i == 3 and not node.is_in_port_connected(3):
                continue  # no need to extend strides if they are not connected

            blank_values_arr = (np.zeros(num_insertions) if input_name != 'strides'
                                else np.ones(num_insertions))
            blank_values_node = Const(graph, {
                'name': node_name + '/extend_{}_const'.format(input_name),
                'value': int64_array(blank_values_arr)
            }).create_node()

            if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
                # concat already exists
                concat = node.in_port(i).get_source().node
                # the output data node shape will change when shapes are re-inferred,
                # so there is no need to check consistency here
                concat['override_output_shape'] = True

                last_in_port = max(concat.in_ports().keys())
                assert not concat.in_port(last_in_port).disconnected(), 'The last in_port of Concat node {} ' \
                                                                        'should be connected'. \
                    format(concat.soft_get('name', node.id))

                concat.add_input_port(last_in_port + 1)
                concat.in_port(last_in_port + 1).connect(
                    blank_values_node.out_port(0))
            else:
                # have to create concat
                concat = Concat(
                    graph, {
                        'axis': 0,
                        'name': node_name + '/concat_{}'.format(input_name),
                        'in_ports_count': 2
                    }).create_node()
                node.in_port(i).get_connection().set_destination(
                    concat.in_port(0))
                concat.in_port(1).connect(blank_values_node.out_port(0))
                concat.out_port(0).get_connection().set_destination(
                    node.in_port(i))
Example #19
    def append_variances(priors_scale_node: Node, variance: list):
        graph = priors_scale_node.graph
        name = priors_scale_node.name

        sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
        priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

        begin = Const(graph, {'value': int64_array([-2])}).create_node()
        end = Const(graph, {'value': int64_array([-1])}).create_node()
        stride = Const(graph, {'value': int64_array([1])}).create_node()
        shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': int64_array([1]),
                                                     'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
                                                     'shrink_axis_mask': int64_array([0]),
                                                     'ellipsis_mask': int64_array([0])}).create_node()

        sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
        begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
        end.out_port(0).connect(shape_part_for_tiling.in_port(2))
        stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

        shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]),
                                                        {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                                         'axis': int64_array(0)},
                                                        shape_part_for_tiling)

        variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node()
        tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
        variance.out_port(0).connect(tile.in_port(0))
        shape_concat.out_port(0).connect(tile.in_port(1))

        reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
        sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
        sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

        concat = Concat(graph,
                        {'name': name + '/priors_concat', 'axis': int64_array(0), 'in_ports_count': 2}).create_node()
        sp_reshape.out_port(0).connect(concat.in_port(0))
        tile.out_port(0).connect(concat.in_port(1))

        output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
        concat.out_port(0).connect(output_node.in_port(0))
        output_dims.out_port(0).connect(output_node.in_port(1))

        return output_node
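
The net effect of append_variances is to pair the scaled prior boxes with the tiled variance vector and fold both into a [1, 2, -1] tensor, which is the layout DetectionOutput expects. The same computation in NumPy (prior values are made up):

import numpy as np

priors = np.random.rand(1, 6, 4)                       # assumed: [1, num_priors, 4]
variance = np.array([0.1, 0.1, 0.2, 0.2], dtype=np.float32)

flat_priors = priors.reshape(-1, 4)                    # Reshape to [-1, 4]
tiled_variance = np.broadcast_to(variance, flat_priors.shape)    # Broadcast to [num_priors, 4]
stacked = np.concatenate([flat_priors, tiled_variance], axis=0)  # '/priors_concat', axis=0
priors_with_variances = stacked.reshape(1, 2, -1)      # final Reshape to [1, 2, -1]
print(priors_with_variances.shape)                     # (1, 2, 24)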
Example #20
def new_shape_node_from_shape_nodes(input_shape_nodes: list):
    """
    The function returns a node producing 1D tensor with concatenated shapes produced by nodes from "input_shape_nodes"
    :param input_shape_nodes: list of nodes producing 1D tensors
    :return: the node producing concatenated values of nodes from the "input_shape_nodes"
    """
    assert len(input_shape_nodes) > 0, 'The list of input shape nodes should be non-empty'
    new_shape_node = Concat(input_shape_nodes[0].graph, {
        'axis': 0,
        'name': input_shape_nodes[0].soft_get('name', input_shape_nodes[0].id) + '/shapes_concat'
    }).create_node()

    for ind, input_node in enumerate(input_shape_nodes):
        new_shape_node.add_input_port(ind)
        new_shape_node.in_port(ind).connect(input_node.out_port(0))
    return new_shape_node
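
At runtime this is a plain concatenation of the individual 1D shape vectors:

import numpy as np

tensors = [np.zeros((2, 3)), np.zeros((4,)), np.zeros((5, 6, 7))]
shape_vectors = [np.array(t.shape, dtype=np.int64) for t in tensors]   # outputs of the shape nodes

new_shape = np.concatenate(shape_vectors)    # Concat with axis=0
print(new_shape)                             # [2 3 4 5 6 7]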
Example #21
    def unroll_ellipsis_for_inputs(graph: Graph, node: Node,
                                   ellipsis_start: int, num_insertions: int):
        node_name = node.soft_get('name', node.id)

        for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
            if i == 3 and not node.is_in_port_connected(3):
                continue  # no need to extend strides if they are not connected

            blank_values_arr = (np.zeros(num_insertions) if input_name != 'strides'
                                else np.ones(num_insertions))
            blank_values_node = Const(graph, {
                'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
                'value': int64_array(blank_values_arr)
            }).create_node()

            concat_in_ports_count = 3 if ellipsis_start != 0 else 2
            concat = Concat(
                graph, {
                    'axis': 0,
                    'name': node_name + '/concat_{}'.format(input_name),
                    'in_ports_count': concat_in_ports_count
                }).create_node()

            if ellipsis_start != 0:
                split = create_op_with_const_inputs(
                    graph, VariadicSplit,
                    {1: int64_array(0), 2: int64_array([ellipsis_start, -1])},
                    {'name': node_name + '/split_for_{}_ellipsis'.format(input_name),
                     'out_ports_count': 2})
                node.in_port(i).get_connection().set_destination(
                    split.in_port(0))

                concat.in_port(0).connect(split.out_port(0))
                concat.in_port(1).connect(blank_values_node.out_port(0))
                concat.in_port(2).connect(split.out_port(1))
            else:
                concat.in_port(0).connect(blank_values_node.out_port(0))
                node.in_port(i).get_connection().set_destination(
                    concat.in_port(1))

            concat.out_port(0).get_connection().set_destination(
                node.in_port(i))
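
Unrolling the ellipsis inserts zeros (or ones for 'strides') at position ellipsis_start into each of the begin/end/strides vectors via VariadicSplit plus Concat. The same arithmetic in NumPy (values are illustrative):

import numpy as np

begin = np.array([1, 2, 3], dtype=np.int64)
ellipsis_start, num_insertions = 1, 2
blank = np.zeros(num_insertions, dtype=np.int64)      # np.ones(...) for the 'strides' input

head, tail = begin[:ellipsis_start], begin[ellipsis_start:]   # VariadicSplit with [ellipsis_start, -1]
unrolled = np.concatenate([head, blank, tail])                # 3-input Concat when ellipsis_start != 0
print(unrolled)                                               # [1 0 0 2 3]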
Example #22
    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                dict(name='do_reshape_classes'),
                                                                match.single_input_node(1)[0])

        initial_priors_node = match.single_input_node(2)[0]
        priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
        # the model calculates identical prior boxes for each batch element, so we take the first slice of them
        begin = Const(graph, {'value': mo_array([0, 0, 0], dtype=np.int32)}).create_node()
        end = Const(graph, {'value': mo_array([1, 0, 0], dtype=np.int32)}).create_node()
        stride = Const(graph, {'value': mo_array([1, 1, 1], dtype=np.int32)}).create_node()

        priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                           'begin_mask': int64_array([1, 1, 1]),
                                           'end_mask': int64_array([1, 0, 0]),
                                           'new_axis_mask': int64_array([0]),
                                           'shrink_axis_mask': int64_array([0]),
                                           'ellipsis_mask': int64_array([0])}).create_node()

        initial_priors_node.out_port(0).connect(priors_node.in_port(0))
        begin.out_port(0).connect(priors_node.in_port(1))
        end.out_port(0).connect(priors_node.in_port(2))
        stride.out_port(0).connect(priors_node.in_port(3))

        placeholders = graph.get_op_nodes(type='Parameter')
        assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                       "{} placeholders".format(self.replacement_id, len(placeholders))
        placeholder = placeholders[0]

        # scale prior boxes to the [0, 1] interval
        node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
        priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

        broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
        shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
        priors_node.out_port(0).connect(shape_of_priors.in_port(0))
        broadcast.in_port(1).connect(shape_of_priors.out_port(0))
        broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

        priors_scale_node.in_port(0).connect(priors_node.out_port(0))
        priors_scale_node.in_port(1).connect(broadcast.out_port(0))

        try:
            variance = match.custom_replacement_desc.custom_attributes['variance']
        except:
            raise Error('There is no variance attribute in {} replacement config file `custom_attributes`'
                        ''.format(self.replacement_id))

        priors = self.append_variances(priors_scale_node, variance)

        # calculate prior boxes widths and heights
        split_node = create_op_with_const_inputs(
            graph, VariadicSplit, {1: int64_array(2), 2: int64_array([1, 1, 1, 1])}, {'out_ports_count': 4},
            priors_scale_node)

        priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                                ).create_node([(split_node, 2), (split_node, 0)])
        priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                                 ).create_node([(split_node, 3), (split_node, 1)])

        # concat widths and heights into a single tensor and multiply with the box coordinate regression values;
        # workaround: use 3 Concats instead of 1 to keep the model reshape-able
        # concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
        #                                           'in_ports_count': 4}).create_node(
        # [priors_width_node, priors_height_node, priors_width_node, priors_height_node])

        concat_1 = Concat(graph, {'name': 'concat_width_height',
                                  'axis': -1, 'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
        concat_2 = Concat(graph, {'name': 'concat_width_height_width',
                                  'axis': -1, 'in_ports_count': 2}).create_node([concat_1, priors_width_node])
        concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1, 'in_ports_count': 2}
                                          ).create_node([concat_2, priors_height_node])

        applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
            [concat_width_height_node, match.single_input_node(0)[0]])

        # reshape to 2D tensor as Inference Engine Detection Output layer expects
        reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                   dict(name='reshape_regression'),
                                                                   applied_width_height_regressions_node)

        detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
        # get nms from the original network
        iou_threshold = None
        nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
        if len(nms_nodes) > 0:
            # it is highly unlikely that NMS has different thresholds for different classes;
            # moreover, DetectionOutput accepts only a scalar value for iou_threshold (nms_threshold)
            iou_threshold = nms_nodes[0].in_node(3).value
        if iou_threshold is None:
            raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))

        detection_output_node = detection_output_op.create_node(
            [reshape_regression_node, reshape_classes_node, priors],
            dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
                 variance_encoded_in_target=0, background_label_id=1000))

        # As the outputs are replaced with a postprocessing node, the outgoing tensor names no longer
        # correspond to the original tensors and should be removed from the output->Result edges
        out_nodes = []
        for out in range(match.outputs_count()):
            out_nodes.append(match.output_node(out)[0])
        clear_tensor_names_info(out_nodes)

        return {'detection_output_node': detection_output_node}
Example #23
    def replace_timeheightconv(self, graph: Graph, node: Node):
        req_time_offsets = node.soft_get('time_offsets')
        offsets = node.soft_get("offsets", [[]])
        all_time_offsets = list(set(offsets[:, 0]))
        all_time_offsets.sort()
        in_name = node.soft_get('name', node.id)
        rename_node(node, in_name + '/to_delete')

        # create MemoryOffsets for context gathering;
        # a Concat is needed when there is more than one time offset
        concat = Concat(graph,
                        attrs={
                            'name': in_name + '/Concat',
                            'in_ports_count': len(all_time_offsets)
                        }).create_node()
        i = 0
        for t in all_time_offsets:
            # if the time offset is included in req_time_offsets we don't need a default value
            has_default = t not in req_time_offsets
            memoff = MemoryOffset(graph, attrs={
                'name': in_name + '/MemoryOffset_' + str(i),
                't': t,
                'has_default': has_default,
                'splitted': False,
                'pair_name': in_name + '/MemoryOffset_pair_' + str(i)
            }).create_node()
            concat.in_port(i).connect(memoff.out_port(0))
            memoff.in_port(0).connect(node.in_port(0).get_source())
            i = i + 1

        stride = node.soft_get("height_subsample", 1)

        kernel = int64_array([0, 0])
        kernel[0] = len(set(offsets[:, 0]))
        kernel[1] = len(set(offsets[:, 1]))

        pad_h = int64_array([0, 0])
        pad_h[0] = -min(offsets[:, 1]) if min(offsets[:, 1]) < 0 else 0
        pad_h[1] = stride * node.height_out - (node.height_in -
                                               max([max(offsets[:, 1]), 0]))

        dilation_t = (max(offsets[:, 0]) - min(offsets[:, 0])) / (
            kernel[0] - 1) if kernel[0] > 1 else 1
        dilation_h = (max(offsets[:, 1]) - min(offsets[:, 1])) / (
            kernel[1] - 1) if kernel[1] > 1 else 1

        conv_attrs = {
            'name': in_name,
            'output': node['out_channels'],
            'height_in': node.height_in,
            'bias_term': None,
            'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
            'pad_spatial_shape': int64_array([[0, 0], pad_h]),
            'dilation': int64_array([1, 1, dilation_t, dilation_h]),
            'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
            'stride': int64_array([1, 1, 1, stride]),
            'kernel_spatial': kernel,
            'input_feature_channel': 1,
            'output_feature_channel': 0,
            'channel_dims': int64_array([1]),
            'spatial_dims': int64_array([2, 3]),
            'batch_dims': int64_array([0]),
            'kernel_spatial_idx': int64_array([2, 3]),
            'group': 1,
            'reshape_kernel': True,
            'bias_addable': True,
        }
        conv = Convolution(graph, attrs=conv_attrs).create_node()
        conv.in_port(0).connect(concat.out_port(0))
        conv.in_port(1).connect(node.in_port(1).get_source())

        # change the weights layout from OHWI to OIHW;
        # in the future this should be replaced by the common Permute mechanics
        weights = conv.in_port(1).get_source().node.value
        weights = weights.reshape(
            int64_array([node.out_channels, -1, node.in_channels]))
        weights = weights.transpose(int64_array([0, 2, 1]))
        weights = weights.flatten()
        conv.in_port(1).get_source().node.value = weights

        conv.in_port(2).connect(node.in_port(2).get_source())
        node.out_port(0).get_connection().set_source(conv.out_port(0))
        graph.remove_node(node.id)
Example #24
    def transform_graph(self, graph: Graph, replacement_descriptions: dict):
        parameter_node = graph.get_op_nodes(op='Parameter')[0]
        parameter_node['data_type'] = data_type_str_to_np(
            parameter_node.graph.graph['cmd_params'].data_type)

        # remove existing Result operations to remove unsupported sub-graph
        graph.remove_nodes_from(
            [node.id
             for node in graph.get_op_nodes(op='Result')] + ['detections'])

        # find the op that is the final result of applying the mean value and scale to the input tensor,
        # then connect it to the input of the first convolution of the model; this removes the image pre-processing
        # (padding and resizing) from the model
        preprocessing_input_node_id = replacement_descriptions[
            'preprocessing_input_node']
        assert preprocessing_input_node_id in graph.nodes, 'The node with name "{}" is not found in the graph. This ' \
                                                           'should be a last node before image normalization and is specified' \
                                                           ' in the json file.'.format(preprocessing_input_node_id)
        preprocessing_input_node = Node(graph, preprocessing_input_node_id)
        consumer_node = preprocessing_input_node.out_port(
            0).get_connection().get_destination().node
        consumer_node.in_port(0).get_connection().set_source(
            parameter_node.out_port(0))

        preprocessing_output_node_id = replacement_descriptions[
            'preprocessing_output_node']
        assert preprocessing_output_node_id in graph.nodes, 'The node with name "{}" is not found in the graph. This ' \
                                                            'node should provide scaled image output and is specified' \
                                                            ' in the json file.'.format(preprocessing_output_node_id)
        preprocessing_output_node = Node(graph, preprocessing_output_node_id)
        preprocessing_output_node.out_port(0).disconnect()

        convolution_nodes = [
            n for n in graph.pseudo_topological_sort()
            if n.soft_get('type') == 'Convolution'
        ]
        convolution_nodes[0].in_port(0).get_connection().set_source(
            preprocessing_output_node.out_port(0))

        # create prior boxes (anchors) generator
        aspect_ratios = replacement_descriptions['aspect_ratios']
        assert len(aspect_ratios) % 2 == 0
        aspect_ratios = list(zip(aspect_ratios[::2], aspect_ratios[1::2]))
        priors_generator = self.AnchorGenerator(
            min_level=int(replacement_descriptions['min_level']),
            aspect_ratios=aspect_ratios,
            num_scales=int(replacement_descriptions['num_scales']),
            anchor_scale=replacement_descriptions['anchor_scale'])

        prior_boxes = []
        for i in range(100):
            inp_name = 'box_net/box-predict{}/BiasAdd'.format('_%d' %
                                                              i if i else '')
            if inp_name not in graph:
                break
            widths, heights = priors_generator.get(i)
            prior_box_op = PriorBoxClusteredOp(
                graph, {
                    'width': mo_array(widths),
                    'height': mo_array(heights),
                    'clip': 0,
                    'flip': 0,
                    'variance': replacement_descriptions['variance'],
                    'offset': 0.5
                })
            prior_boxes.append(
                prior_box_op.create_node(
                    [Node(graph, inp_name), parameter_node]))

        # concatenate prior box operations
        concat_prior_boxes = Concat(graph, {'axis': -1}).create_node()
        for idx, node in enumerate(prior_boxes):
            concat_prior_boxes.add_input_port(idx)
            concat_prior_boxes.in_port(idx).connect(node.out_port(0))

        conf = Sigmoid(graph, dict(name='concat/sigmoid')).create_node(
            [Node(graph, 'concat')])
        reshape_size_node = Const(graph, {
            'value': int64_array([0, -1])
        }).create_node([])
        logits = Reshape(graph, dict(name=conf.name + '/Flatten')).create_node(
            [conf, reshape_size_node])
        deltas = Reshape(graph, dict(name='concat_1/Flatten')).create_node(
            [Node(graph, 'concat_1'), reshape_size_node])

        # revert convolution boxes prediction weights from yxYX to xyXY (convolutions share weights and bias)
        weights = Node(graph, 'box_net/box-predict/pointwise_kernel')
        weights.value = weights.value.reshape(-1, 4)[:, [1, 0, 3, 2]].reshape(
            weights.shape)
        bias = Node(graph, 'box_net/box-predict/bias')
        bias.value = bias.value.reshape(-1,
                                        4)[:, [1, 0, 3, 2]].reshape(bias.shape)

        detection_output_node = DetectionOutput(
            graph,
            dict(
                name='detections',
                share_location=1,
                background_label_id=int(
                    replacement_descriptions['num_classes']) + 1,
                nms_threshold=replacement_descriptions['nms_threshold'],
                confidence_threshold=replacement_descriptions[
                    'confidence_threshold'],
                top_k=100,
                keep_top_k=100,
                code_type='caffe.PriorBoxParameter.CENTER_SIZE',
            )).create_node([deltas, logits, concat_prior_boxes])

        output_op = Result(graph, dict(name='output'))
        output_op.create_node([detection_output_node])
Example #25
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
            assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                               'mode is supported for node {}.'.format(node.id)
            node_name = node.soft_get('name', node.id)
            rename_node(node, node_name + '/TBR')
            is_packed = False
            if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
                is_packed = True
                embedding_bag = EmbeddingBagPackedSum(graph, {
                    'name': node_name
                }).create_node()
            else:
                embedding_bag = EmbeddingBagOffsetsSum(graph, {
                    'name': node_name
                }).create_node()
                node.in_port(2).get_connection().set_destination(
                    embedding_bag.in_port(2))
            rename_node(embedding_bag, node_name)
            node.in_port(0).get_connection().set_destination(
                embedding_bag.in_port(0))
            node.in_port(1).get_connection().set_destination(
                embedding_bag.in_port(1))
            node.out_port(0).get_connection().set_source(
                embedding_bag.out_port(0))
            if len(node.in_ports()
                   ) == 4 and not node.in_port(3).disconnected():
                if is_packed:
                    node.in_port(3).get_connection().set_destination(
                        embedding_bag.in_port(2))
                else:
                    # connect per_sample_weights
                    node.in_port(3).get_connection().set_destination(
                        embedding_bag.in_port(4))

                    weights_shape_node = Shape(
                        graph, {
                            'name': node_name + '/WeightsShape'
                        }).create_node()

                    weights_rank_node = Rank(graph, {
                        'name': node_name + '/WeightsRank'
                    }).create_node()
                    last_dim_node = get_canonical_axis_index_node(
                        weights_rank_node, -1)
                    weights_last_dim = get_shape_values_by_indices_node(
                        weights_shape_node, last_dim_node)

                    weights_first_dim = node_to_get_shape_value_of_indices(
                        weights_shape_node, [0])

                    zero_col_node = create_op_with_const_inputs(
                        graph, Broadcast, {0: int64_array([0])},
                        {'name': node_name + '/Broadcast'})
                    zero_col_node.in_port(1).connect(
                        weights_last_dim.out_port(0))

                    default_embeddings_node = create_op_with_const_inputs(
                        graph, Unsqueeze, {1: int64_array(0)},
                        {'name': node_name + '/Unsqueeze'})
                    default_embeddings_node.in_port(0).connect(
                        zero_col_node.out_port(0))

                    # expand embedding table with zeros
                    weights_concat = Concat(
                        graph, {
                            'axis': 0,
                            'in_ports_count': 2,
                            'name': node_name + '/Concat'
                        }).create_node()
                    embedding_bag.in_port(0).get_connection().set_destination(
                        weights_concat.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(
                        weights_shape_node.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(
                        weights_rank_node.in_port(0))
                    weights_concat.in_port(1).connect(
                        default_embeddings_node.out_port(0))
                    weights_concat.out_port(0).connect(
                        embedding_bag.in_port(0))

                    # point default index to expanded part of embedding table
                    weights_first_dim.out_port(0).connect(
                        embedding_bag.in_port(3))
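
When per_sample_weights are present, the branch above appends a zero row to the embedding table so that a default index pointing at that row contributes nothing to the sum; conceptually:

import numpy as np

weights = np.random.rand(10, 4)                          # assumed embedding table [num_rows, emb_dim]
zero_row = np.zeros((1, weights.shape[-1]))              # Broadcast([0], last_dim) + Unsqueeze(axis=0)
extended = np.concatenate([weights, zero_row], axis=0)   # '/Concat' with axis=0

default_index = weights.shape[0]                         # first dim of the original table -> the new zero row
assert np.all(extended[default_index] == 0)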
Example #26
    def extract(cls, node):
        mapping_rule = {'axis': onnx_attr(node, 'axis', 'i', default=0)}
        Concat.update_node_stat(node, mapping_rule)
        return cls.enabled
Example #27
    def extract(cls, node):
        mapping_rule = {'axis': 1}
        Concat.update_node_stat(node, mapping_rule)
        return cls.enabled
Example #28
    def insert_select(graph: Graph, node: Node):
        context_len = node.frame_time + 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = create_const_with_batch_from_input(in_node_port,
                                                       in_node_shape[1])
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check if an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()), ('mem_out', dict(op='Assign')),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_const_with_batch_from_input(
                in_node_port, context_len, precision=np.int32)
            mem_out = ReadValue(
                graph, {
                    'name': 'iteration_number',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = create_const_with_batch_from_input(in_node_port, 1, 1,
                                                      np.int32)
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(
                graph, {
                    'name': 'iteration_number_out',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check whether the value read from memory is 1.
        # If it is, the full context has been gathered and the real data should be saved to memory;
        # otherwise the data is still garbage and the initial state of the memory must not be changed.
        cast_in = Equal(graph, {
            'name': input_port.node.name + '/cast_to_bool'
        }).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
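The counter built above is a shift register of length context_len: on every frame it drops its oldest element and appends a constant 1, and the Select lets data through only once the first element becomes 1. An illustrative-only numpy sketch of that run-time behaviour (plain Python, not part of the transformation):

import numpy as np

context_len = 3
counter = np.zeros(context_len, dtype=np.int32)         # ReadValue initial state

def step(frame):
    global counter
    counter = np.concatenate([counter[1:], [1]])         # Crop(offset=1) + Concat with a constant 1
    ready = counter[0] == 1                              # Crop(offset=0, dim=1) + Equal
    return frame if ready else np.zeros_like(frame)      # Select: real data or zeros

for t in range(5):
    print(t, step(np.full(2, t + 1.0)))                  # the first context_len - 1 frames are masked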
Example #29
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        concat_node = match['concat']
        concat_node['axis'] = 1
        concat_name = concat_node.soft_get('name', concat_node.id)

        concat_reshape = create_op_node_with_second_input(
            graph,
            Reshape,
            int64_array([1, 2, -1]),
            op_attrs=dict(name=concat_name + '/Reshape'))
        split_node = create_op_node_with_second_input(
            graph,
            Split,
            int64_array(1),
            op_attrs=dict(name=concat_name + '/Split', num_splits=2),
            input_node=concat_reshape)
        split_node_reshape = create_op_node_with_second_input(
            graph,
            Reshape,
            int64_array([-1, 4]),
            op_attrs=dict(name=split_node.name + '/Reshape'))
        split_node.out_port(0).connect(split_node_reshape.in_port(0))
        value = create_op_node_with_second_input(
            graph,
            Split,
            int64_array(1),
            op_attrs=dict(name=split_node_reshape.name + '/Split',
                          num_splits=4),
            input_node=split_node_reshape)

        xmin, xmax = calculate_prior_box_value(value,
                                               value_to_div=value.out_port(2),
                                               value_to_add=value.out_port(0))
        ymin, ymax = calculate_prior_box_value(value,
                                               value_to_div=value.out_port(3),
                                               value_to_add=value.out_port(1))

        concat_slice_value = Concat(
            graph, dict(name=value.name + '/Concat', in_ports_count=4,
                        axis=1)).create_node()
        for ind, node in enumerate([xmin, ymin, xmax, ymax]):
            concat_slice_value.in_port(ind).connect(node.out_port(0))

        reshape_concat_values = create_op_node_with_second_input(
            graph,
            Reshape,
            int64_array([1, 1, -1]),
            op_attrs=dict(name=concat_slice_value.name + '/Reshape'),
            input_node=concat_slice_value)
        concat = Concat(
            graph,
            dict(name=reshape_concat_values.name + '/Concat',
                 in_ports_count=2,
                 axis=1)).create_node()
        concat.in_port(0).connect(reshape_concat_values.out_port(0))
        concat.in_port(1).connect(split_node.out_port(1))

        match['detection_output'].in_port(2).get_connection().set_source(
            concat.out_port(0))
        concat_node.out_port(0).get_connection().set_destination(
            concat_reshape.in_port(0))
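This transformation reshapes the concatenated prior boxes to [1, 2, -1], splits the box values from the variances, splits each box into four slices, recombines them through calculate_prior_box_value and concatenates the result back with the variances. An illustrative numpy sketch of that data rearrangement (the centre/size to corner conversion below is only an assumption, since calculate_prior_box_value is not shown in this example):

import numpy as np

boxes = np.random.rand(1, 8).astype(np.float32)       # flattened PriorBox-like output
reshaped = boxes.reshape(1, 2, -1)                    # Reshape to [1, 2, -1]
values, variances = np.split(reshaped, 2, axis=1)     # Split on axis 1, num_splits=2
x, y, w, h = values.reshape(-1, 4).T                  # Reshape to [-1, 4] + Split into 4 slices
xmin, xmax = x - w / 2, x + w / 2                     # assumed behaviour of calculate_prior_box_value
ymin, ymax = y - h / 2, y + h / 2
corners = np.stack([xmin, ymin, xmax, ymax], axis=1)             # Concat of the four slices
out = np.concatenate([corners.reshape(1, 1, -1),
                      variances.reshape(1, 1, -1)], axis=1)      # final Concat with the variances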
Example #30
0
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        in_shape = node.in_port(0).data.get_shape().copy()
        memory_element = in_shape[1] - node.const_dim
        memory_size = memory_element * len(node.context)

        memory_pair_id = unique_id('id')
        # Memory(in)
        input_memory = ReadValue(graph, {
            'name': 'prev_splice_memory',
            'variable_id': memory_pair_id
        }).create_node()

        # Memory(in)  \
        #             Crop
        # Input(temp) /
        crop = Crop(
            graph, {
                'name': 'Splice_Crop',
                'axis': int64_array([1]),
                'offset': int64_array([memory_element]),
                'dim': int64_array([memory_size - memory_element])
            }).create_node()
        crop.in_port(0).connect(input_memory.out_port(0))

        # Crop   \
        #         Concat
        # Input  /
        concat_node = Concat(graph, {
            'name': 'Splice_Concat',
            'in_ports_count': 2,
            'axis': 1
        }).create_node()
        concat_node.in_port(0).connect(crop.out_port(0))

        # Concat -> Memory(out)
        mem_out = Assign(graph, {
            'name': 'out_splice_memory',
            'variable_id': memory_pair_id
        }).create_node()
        mem_out.in_port(0).connect(concat_node.out_port(0))
        Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

        if node.const_dim != 0:
            memory_element_constdim = node.const_dim
            memory_size_constdim = memory_element_constdim * len(node.context)

            split = create_op_with_const_inputs(
                graph, VariadicSplit, {
                    1: int64_array(1),
                    2: int64_array([memory_element, memory_element_constdim])
                }, {
                    'name': node.id + '_split_const',
                    'out_ports_count': 2
                })

            split.out_port(0).connect(concat_node.in_port(1))

            # create separate splice construction for const_dim
            memory_pair_id = unique_id('memory_for_const_dim')
            init_value_input_memory_const_dim = Const(
                graph, {
                    'name': 'init_value_const_dim_in_memory',
                    'value': np.zeros(int64_array([in_shape[0], memory_size_constdim]),
                                      dtype=np.float32),
                    'shape': int64_array([in_shape[0], memory_size_constdim])
                }).create_node()
            input_memory_const_dim = ReadValue(graph, {
                'name': 'const_dim_in_memory',
                'variable_id': memory_pair_id
            }).create_node()
            init_value_input_memory_const_dim.out_port(0).connect(
                input_memory_const_dim.in_port(0))

            crop_const_dim = Crop(
                graph, {
                    'name': 'const_dim_crop',
                    'axis': int64_array([1]),
                    'offset': int64_array([memory_element_constdim]),
                    'dim': int64_array([memory_size_constdim - memory_element_constdim])
                }).create_node()
            crop_const_dim.in_port(0).connect(
                input_memory_const_dim.out_port(0))

            concat_node_const_dim = Concat(graph, {
                'name': 'const_dim_concat',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat_node_const_dim.in_port(0).connect(
                crop_const_dim.out_port(0))

            mem_out_const_dim = Assign(graph, {
                'name': 'const_dim_out_memory',
                'variable_id': memory_pair_id
            }).create_node()
            mem_out_const_dim.in_port(0).connect(
                concat_node_const_dim.out_port(0))
            Result(graph).create_node().in_port(0).connect(
                mem_out_const_dim.out_port(0))

            # connect the splice replacement: the Split is its beginning and the Concat is its end
            split.out_port(1).connect(concat_node_const_dim.in_port(1))
            crop_first = Crop(
                graph, {
                    'name': 'const_dim_crop_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([memory_element_constdim])
                }).create_node()
            crop_first.in_port(0).connect(concat_node_const_dim.out_port(0))

            concat_const = Concat(graph, {
                'name': node.id + '_concat_const',
                'axis': 1,
                'in_ports_count': 2
            }).create_node()
            concat_const.in_port(1).connect(crop_first.out_port(0))
            concat_const.in_port(0).connect(concat_node.out_port(0))

            init_value_input_memory = Const(
                graph, {
                    'name': 'init_value_' + node.name,
                    'value': np.zeros(int64_array([in_shape[0], memory_size]),
                                      dtype=np.float32),
                    'shape': int64_array([in_shape[0], memory_size])
                }).create_node()
            init_value_input_memory.out_port(0).connect(
                input_memory.in_port(0))
            node.in_port(0).get_connection().set_destination(split.in_port(0))
            node.out_port(0).get_connection().set_source(
                concat_const.out_port(0))
        else:
            init_value_input_memory = Const(
                graph, {
                    'name': 'init_value_' + node.name,
                    'value': np.zeros(int64_array([in_shape[0], memory_size]),
                                      dtype=np.float32),
                    'shape': int64_array([in_shape[0], memory_size])
                }).create_node()
            init_value_input_memory.out_port(0).connect(
                input_memory.in_port(0))
            node.in_port(0).get_connection().set_destination(
                concat_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                concat_node.out_port(0))

        # remove the original node to avoid re-inference of its shape and touching it in subsequent replacements
        graph.remove_node(node.id)
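The ReadValue -> Crop -> Concat -> Assign chain built above maintains a sliding window over the last len(node.context) frames. An illustrative-only numpy sketch of that memory update for the const_dim == 0 case (arbitrary sizes, not part of the transformation):

import numpy as np

memory_element = 4                                     # features per frame
context = 3                                            # len(node.context)
memory = np.zeros((1, memory_element * context), dtype=np.float32)   # ReadValue initial state

def splice(frame):
    global memory
    # Crop drops the oldest frame, Concat appends the new one, Assign writes it back
    memory = np.concatenate([memory[:, memory_element:], frame], axis=1)
    return memory                                      # window over the last `context` frames

for t in range(4):
    print(splice(np.full((1, memory_element), t + 1.0, dtype=np.float32)))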