def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with a Broadcast that feeds a ScatterNDUpdate.

        Input wiring performed here (ports refer to ``node``'s original inputs):

        * input 1 -> Broadcast port 1 (the required shape);
        * input 3 -> Broadcast port 0 (the default value); when input 3 is not
          connected a ``np.float32(0)`` constant is used instead;
        * Broadcast output -> ScatterNDUpdate port 0;
        * input 0 -> ScatterNDUpdate port 1;
        * input 2 -> ScatterNDUpdate port 2.

        The ScatterNDUpdate node takes over the original node name, while the
        original node is renamed to ``<name>/AbandonedName``.

        :param graph: graph the new nodes are created in
        :param node: node being replaced
        :return: single-element list holding the id of the ScatterNDUpdate node
        """
        base_name = node.soft_get('name', node.id)

        # Canvas of the required shape filled with the default value.
        canvas = Broadcast(graph, {'name': base_name + '/Broadcast_'}).create_node()
        node.in_port(1).get_connection().set_destination(canvas.in_port(1))
        if node.in_port(3).disconnected():
            # No explicit default value was supplied: fall back to a float32 zero.
            zero_fill = Const(graph, {'name': canvas.name + '/FillValue_',
                                      'value': np.float32(0)}).create_node()
            canvas.in_port(0).connect(zero_fill.out_port(0))
        else:
            node.in_port(3).get_connection().set_destination(canvas.in_port(0))

        # Write the required values into the canvas at the required locations.
        scatter = ScatterNDUpdate(graph, {'name': base_name + '/ScatterNDUpdate_'}).create_node()
        scatter.in_port(0).connect(canvas.out_port(0))
        node.in_port(0).get_connection().set_destination(scatter.in_port(1))
        node.in_port(2).get_connection().set_destination(scatter.in_port(2))

        # Hand the original name over to the node that now produces the result.
        rename_nodes([(node, base_name + "/AbandonedName"), (scatter, base_name)])

        return [scatter.id]
 def find_and_replace_pattern(self, graph: Graph):
     """Re-route every 'ConstantOfShape' op node of ``graph`` through a Broadcast.

     For each such node the connection on its input 0 is moved to Broadcast
     port 1, a Const holding the node's ``fill_value`` attribute is connected to
     Broadcast port 0, and all consumers of the node's output 0 are switched to
     the Broadcast output.
     """
     for shape_filler in graph.get_op_nodes(op='ConstantOfShape'):
         bcast = Broadcast(graph, {'name': shape_filler.name + '/Broadcast'}).create_node()
         shape_filler.in_port(0).get_connection().set_destination(bcast.in_port(1))

         # The value being replicated comes from the node attribute, not from an input.
         fill_attrs = {'name': bcast.name + '/FillValue', 'value': shape_filler.fill_value}
         fill_const = Const(graph, fill_attrs).create_node()
         bcast.in_port(0).connect(fill_const.out_port(0))

         shape_filler.out_port(0).get_connection().set_source(bcast.out_port(0))
Example #3
0
    def find_and_replace_pattern(self, graph: Graph):
        """Rewrite every 'Reverse' op node of ``graph`` as a ReverseSequence sub-graph.

        Each Reverse node must have input port 1 disconnected and a valid
        ``axis`` attribute (both are asserted). Rank-1 inputs are wrapped in
        Unsqueeze/Squeeze on axis 0 so that the batch axis differs from the
        sequence axis. All processed Reverse nodes are removed from the graph
        at the end.
        """
        reverse_nodes = graph.get_op_nodes(op='Reverse')
        for rev in reverse_nodes:
            base_name = rev.soft_get('name', rev.id)

            assert rev.in_port(1).disconnected()
            assert rev.has_valid('axis')

            is_rank_one = len(rev.in_port(0).data.get_shape()) == 1

            # 1. A rank-1 input gets an extra leading batch dimension so that batch != seq_axis
            if is_rank_one:
                unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                             {'name': base_name + "/Unsqueeze"})
                rev.in_port(0).get_source().connect(unsqueeze.in_port(0))
                data_port = unsqueeze.out_port(0)
                batch_axis, seq_axis = 0, 1
            else:
                data_port = rev.in_port(0).get_source()
                seq_axis = rev['axis']
                batch_axis = 1 if seq_axis == 0 else 0

            # 2. ReverseSequence takes seq_lengths on input port 1; build it as
            # shape[seq_axis] broadcasted to shape[batch_axis]:
            # in ---> ShapeOf ----> Gather(seq_axis)  ----> Broadcast----->
            #            |                                      |
            #            | -------> Gather(batch_axis)----------|
            shape_of = Shape(graph, {'name': base_name + "/Shape"}).create_node()
            data_port.connect(shape_of.in_port(0))
            seq_dim = node_to_get_shape_value_of_indices(shape_of, [seq_axis])
            batch_dim = node_to_get_shape_value_of_indices(shape_of, [batch_axis])
            seq_lengths = Broadcast(graph, {'name': base_name + "/Broadcast"}).create_node()
            seq_lengths.in_port(0).connect(seq_dim.out_port(0))
            seq_lengths.in_port(1).connect(batch_dim.out_port(0))

            # 3. Free the original name and give it to the new ReverseSequence node
            rename_node(rev, base_name + '/to_delete')
            rev_seq = ReverseSequence(graph, {'name':  base_name, 'seq_axis': seq_axis,
                                              'batch_axis': batch_axis}).create_node()
            rev_seq.in_port(0).connect(data_port)
            rev_seq.in_port(1).connect(seq_lengths.out_port(0))

            # 4. For rank 1 drop the dimension added in step 1; the Squeeze then carries the original name
            if is_rank_one:
                rename_node(rev_seq, base_name + '/ReverseSequence')
                squeeze = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                           {'name': base_name})
                squeeze.in_port(0).connect(rev_seq.out_port(0))
                result_port = squeeze.out_port(0)
            else:
                result_port = rev_seq.out_port(0)
            rev.out_port(0).get_connection().set_source(result_port)

        # 5. Delete the old Reverse nodes
        graph.remove_nodes_from([rev.id for rev in reverse_nodes])
Example #4
0
    def append_variances(priors_scale_node: Node, variance: list):
        """Build a sub-graph that appends variances to the scaled prior boxes.

        The output of ``priors_scale_node`` is reshaped to ``[-1, 4]`` and
        concatenated along axis 0 with ``variance`` broadcast to a shape built
        from the second-to-last dimension of the priors shape and the constant
        4. The concatenation result is finally reshaped to ``[1, 2, -1]``.

        :param priors_scale_node: node whose output 0 provides the prior boxes
        :param variance: variance values, stored as a float32 constant
        :return: the final Reshape node (``<name>/3D_priors_wth_variances``)
        """
        graph = priors_scale_node.graph
        base_name = priors_scale_node.name

        priors_shape = Shape(graph, {'name': base_name + '/shape'}).create_node()
        priors_scale_node.out_port(0).connect(priors_shape.in_port(0))

        # Slice [-2:-1] of the shape, i.e. its second-to-last dimension.
        slice_begin = Const(graph, {'value': int64_array([-2])}).create_node()
        slice_end = Const(graph, {'value': int64_array([-1])}).create_node()
        slice_stride = Const(graph, {'value': int64_array([1])}).create_node()
        slice_attrs = {
            'name': base_name + '/get_-2_dim',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0]),
        }
        dim_slice = StridedSlice(graph, slice_attrs).create_node()

        priors_shape.out_port(0).connect(dim_slice.in_port(0))
        for port_idx, operand in enumerate((slice_begin, slice_end, slice_stride), start=1):
            operand.out_port(0).connect(dim_slice.in_port(port_idx))

        # Target shape for the variances: the sliced dimension concatenated with the constant 4.
        tiling_shape = create_op_node_with_second_input(
            graph, Concat, int64_array([4]),
            {'name': base_name + '/shape_for_tiling', 'in_ports_count': 2, 'axis': int64_array(0)},
            dim_slice)

        variance_const = Const(graph, {'name': base_name + '/variance',
                                       'value': float32_array(variance)}).create_node()
        variance_tile = Broadcast(graph, {'name': base_name + '/variance_tile'}).create_node()
        variance_const.out_port(0).connect(variance_tile.in_port(0))
        tiling_shape.out_port(0).connect(variance_tile.in_port(1))

        flat_dims = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        flat_priors = Reshape(graph, {'name': base_name + '/reshape'}).create_node()
        flat_priors.in_port(0).connect(priors_scale_node.out_port(0))
        flat_priors.in_port(1).connect(flat_dims.out_port(0))

        # Priors first, variances second.
        priors_and_variances = Concat(
            graph, {'name': base_name + '/priors_concat', 'axis': int64_array(0), 'in_ports_count': 2}
        ).create_node()
        flat_priors.out_port(0).connect(priors_and_variances.in_port(0))
        variance_tile.out_port(0).connect(priors_and_variances.in_port(1))

        final_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        result = Reshape(graph, {'name': base_name + '/3D_priors_wth_variances'}).create_node()
        priors_and_variances.out_port(0).connect(result.in_port(0))
        final_dims.out_port(0).connect(result.in_port(1))

        return result
Example #5
0
    def find_and_replace_pattern(self, graph: Graph):
        """Replace every fully connected 'TFScatterND' op node with Broadcast + ScatterNDUpdate.

        Nodes with any of the input ports 0, 1 or 2 disconnected are skipped.
        For the others a zero constant is converted (ConvertLike) to the type
        of the updates input, broadcast to the shape coming from input 2 and
        used as the data input of a ScatterNDUpdate that receives the original
        indices (input 0) and updates (input 1). The new node takes over the
        original name; the old node is renamed to ``<name>/TBD`` and
        disconnected from its inputs.
        """
        for tf_node in graph.get_op_nodes(op='TFScatterND'):
            if not all(tf_node.is_in_port_connected(idx) for idx in range(3)):
                continue

            base_name = tf_node.soft_get('name', tf_node.soft_get('id'))
            indices_src = tf_node.in_port(0).get_source()
            updates_src = tf_node.in_port(1).get_source()
            shape_src = tf_node.in_port(2).get_source()

            # The constant is created as int64; its final type is taken from the updates below.
            zero = Const(graph, {'value': int64_array(0.0), 'name': base_name + '/zero_const'}).create_node()

            # Cast the zero to the type of the updates input.
            typed_zero = ConvertLike(graph, {'name': base_name + '/convert_like'}).create_node()
            typed_zero.in_port(0).connect(zero.out_port(0))
            typed_zero.in_port(1).connect(updates_src)

            zeros = Broadcast(graph, {'name': base_name + '/broadcast'}).create_node()
            zeros.in_port(0).connect(typed_zero.out_port(0))
            zeros.in_port(1).connect(shape_src)

            scatter = ScatterNDUpdate(graph, {'name': base_name + '/replaced'}).create_node()
            scatter.in_port(0).connect(zeros.out_port(0))
            scatter.in_port(1).connect(indices_src)
            scatter.in_port(2).connect(updates_src)

            rename_nodes([(tf_node, base_name + '/TBD'), (scatter, base_name)])

            # Switch consumers to the new node, then detach the old one from its inputs.
            tf_node.out_port(0).get_connection().set_source(scatter.out_port(0))
            for idx in range(3):
                tf_node.in_port(idx).disconnect()
    def find_and_replace_pattern(self, graph: Graph):
        """Re-route 'Fill' and 'ConstantFill' op nodes of ``graph`` through Broadcast nodes.

        * 'Fill': input 0 is moved to Broadcast port 1 and input 1 to Broadcast
          port 0.
        * 'ConstantFill': input 0 is moved to Broadcast port 1 and a Const built
          from the ``fill_value`` attribute is connected to Broadcast port 0.
          The node must have a valid ``fill_value`` and ``input_as_shape`` set
          (both are asserted).

        In both cases consumers of the node's output 0 are switched to the
        Broadcast output.
        """
        for fill in graph.get_op_nodes(op='Fill'):
            base_name = fill.soft_get('name', fill.id)
            bcast = Broadcast(graph, {'name': base_name + '/Broadcast'}).create_node()

            # Port order is swapped between the two operations.
            fill.in_port(0).get_connection().set_destination(bcast.in_port(1))
            fill.in_port(1).get_connection().set_destination(bcast.in_port(0))
            fill.out_port(0).get_connection().set_source(bcast.out_port(0))

        for const_fill in graph.get_op_nodes(op='ConstantFill'):
            base_name = const_fill.soft_get('name', const_fill.id)

            assert const_fill.has_valid('fill_value')
            assert const_fill.has_and_set('input_as_shape')

            # Here the value comes from an attribute, so it is materialised as a constant.
            value_attrs = {'value': mo_array(const_fill.fill_value), 'name': base_name + '/value'}
            value_const = Const(graph, value_attrs).create_node()
            bcast = Broadcast(graph, {'name': base_name + '/Broadcast'}).create_node()

            const_fill.in_port(0).get_connection().set_destination(bcast.in_port(1))
            value_const.out_port(0).connect(bcast.in_port(0))
            const_fill.out_port(0).get_connection().set_source(bcast.out_port(0))
Example #7
0
    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """Build the DetectionOutput-based post-processing sub-graph for the matched pattern.

        The sub-graph is assembled from the three matched inputs:

        * input 1 (classes) is reshaped to ``[0, -1]``;
        * input 2 (prior boxes) is sliced to its first batch element, multiplied
          by the scales returned by ``self.placeholder_scales`` and extended
          with variances by ``self.append_variances``;
        * input 0 (box regressions) is multiplied by the prior boxes
          widths/heights and reshaped to ``[0, -1]``.

        ``nms_threshold`` of the DetectionOutput node is read from the value on
        input 3 of the first 'NonMaxSuppression' node found in the graph.

        :param graph: graph being transformed
        :param match: matched sub-graph description; its
            ``custom_replacement_desc.custom_attributes`` must contain ``variance``
        :return: dict with the single key ``'detection_output_node'``
        :raises Error: when the ``variance`` attribute cannot be read or when no
            ``iou_threshold`` could be retrieved from the graph
        :raises AssertionError: when the graph does not have exactly one Parameter
        """
        reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                dict(name='do_reshape_classes'),
                                                                match.single_input_node(1)[0])

        initial_priors_node = match.single_input_node(2)[0]
        priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
        # model calculates identical prior boxes for each batch, so we take first slice of them
        begin = Const(graph, {'value': mo_array([0, 0, 0], dtype=np.int32)}).create_node()
        end = Const(graph, {'value': mo_array([1, 0, 0], dtype=np.int32)}).create_node()
        stride = Const(graph, {'value': mo_array([1, 1, 1], dtype=np.int32)}).create_node()

        priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                           'begin_mask': int64_array([1, 1, 1]),
                                           'end_mask': int64_array([1, 0, 0]),
                                           'new_axis_mask': int64_array([0]),
                                           'shrink_axis_mask': int64_array([0]),
                                           'ellipsis_mask': int64_array([0])}).create_node()

        initial_priors_node.out_port(0).connect(priors_node.in_port(0))
        begin.out_port(0).connect(priors_node.in_port(1))
        end.out_port(0).connect(priors_node.in_port(2))
        stride.out_port(0).connect(priors_node.in_port(3))

        placeholders = graph.get_op_nodes(type='Parameter')
        assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                       "{} placeholders".format(self.replacement_id, len(placeholders))
        placeholder = placeholders[0]

        # scale prior boxes to the [0, 1] interval
        node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
        priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

        # the scales are broadcast to the shape of the sliced priors before the multiplication
        broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
        shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
        priors_node.out_port(0).connect(shape_of_priors.in_port(0))
        broadcast.in_port(1).connect(shape_of_priors.out_port(0))
        broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

        priors_scale_node.in_port(0).connect(priors_node.out_port(0))
        priors_scale_node.in_port(1).connect(broadcast.out_port(0))

        # Only ordinary lookup failures are translated into the user-facing Error; a bare `except:`
        # would also swallow KeyboardInterrupt/SystemExit. The original exception is kept as the cause.
        try:
            variance = match.custom_replacement_desc.custom_attributes['variance']
        except Exception as err:
            raise Error('There is no variance attribute in {} replacement config file `custom_attributes`'
                        ''.format(self.replacement_id)) from err

        priors = self.append_variances(priors_scale_node, variance)

        # calculate prior boxes widths and heights
        split_node = create_op_with_const_inputs(
            graph, VariadicSplit, {1: int64_array(2), 2: int64_array([1, 1, 1, 1])}, {'out_ports_count': 4},
            priors_scale_node)

        priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                                ).create_node([(split_node, 2), (split_node, 0)])
        priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                                 ).create_node([(split_node, 3), (split_node, 1)])

        # concat widths and heights into a single tensor and multiply it by the box coordinates regression values
        # WA with 3 Concats instead of 1 for keeping model reshapable
        # concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
        #                                           'in_ports_count': 4}).create_node(
        # [priors_width_node, priors_height_node, priors_width_node, priors_height_node])

        concat_1 = Concat(graph, {'name': 'concat_width_height',
                                  'axis': -1, 'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
        concat_2 = Concat(graph, {'name': 'concat_width_height_width',
                                  'axis': -1, 'in_ports_count': 2}).create_node([concat_1, priors_width_node])
        concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1, 'in_ports_count': 2}
                                          ).create_node([concat_2, priors_height_node])

        applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
            [concat_width_height_node, match.single_input_node(0)[0]])

        # reshape to 2D tensor as Inference Engine Detection Output layer expects
        reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                   dict(name='reshape_regression'),
                                                                   applied_width_height_regressions_node)

        detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
        # get nms from the original network
        iou_threshold = None
        nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
        if len(nms_nodes) > 0:
            # only the first NMS node is consulted: it is highly unlikely that NMS differs between classes,
            # moreover DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
            iou_threshold = nms_nodes[0].in_node(3).value
        if iou_threshold is None:
            raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))

        detection_output_node = detection_output_op.create_node(
            [reshape_regression_node, reshape_classes_node, priors],
            dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
                 variance_encoded_in_target=0, background_label_id=1000))

        # As outputs are replaced with a postprocessing node, outgoing tensor names no longer
        # correspond to original tensors and should be removed from output->Result edges
        out_nodes = [match.output_node(out)[0] for out in range(match.outputs_count())]
        clear_tensor_names_info(out_nodes)

        return {'detection_output_node': detection_output_node}
Example #8
0
    def mxrepeat_decomposition(node: Node):
        """Decompose ``node`` into an Unsqueeze -> Tile -> Reshape chain.

        The node is renamed to ``<name>/to_be_removed``; the final Reshape is
        given the original name. The connection on the node's input 0 is moved
        to the Unsqueeze and consumers of its output 0 are switched to the
        Reshape output. ``node.axis`` and ``node.repeats`` are read to build
        the Tile repeats and the Reshape target shape.

        :param node: node to decompose; modified in place together with its graph
        """
        graph = node.graph
        base_name = node.soft_get('name', node.id)

        rename_node(node, base_name + '/to_be_removed')

        # Unsqueeze
        rank = Rank(graph, {'name': base_name + '/Rank'}).create_node()
        node.in_port(0).get_source().connect(rank.in_port(0))

        canonical_axis = get_canonical_axis_index_node(rank, node.axis)
        axis_plus_one = create_op_node_with_second_input(graph, Add, int64_array([1]),
                                                         {'name': base_name + '/Unsqueeze/Axis'},
                                                         input_node=canonical_axis)

        unsqueeze = Unsqueeze(graph, {'name': base_name + '/Unsqueeze'}).create_node()
        unsqueeze.in_port(1).connect(axis_plus_one.out_port(0))

        # Tile (1, 1, ..., repeats, ..., 1)
        # the tile array is generated according to the following table:

        # parts:       |      first      |  repeats |  second     |
        # i:           | 0, 1, ..., axis,| axis + 1,| ..., rank+1 |
        # tile_array:  | 1, 1, ...,  1  ,| repeats ,| ...,   1    |

        one_const = Const(graph, {'name': base_name + '/Broadcast/One',
                                  'value': int64_array([1])}).create_node()
        leading_ones = Broadcast(graph, {'name': base_name + '/Broadcast/Ones_first_part'}).create_node()
        leading_ones.in_port(0).connect(one_const.out_port(0))
        leading_ones.in_port(1).connect(axis_plus_one.out_port(0))

        repeats_const = Const(graph, {'name': base_name + '/repeats',
                                      'value': int64_array([node.repeats])}).create_node()

        trailing_ones = Broadcast(graph, {'name': base_name + '/Broadcast/Ones_second_part'}).create_node()
        # Length of the trailing part: rank - (axis + 1)
        trailing_len = Sub(graph, {'name': base_name + '/Broadcast/Shape/second_part'}).create_node()
        trailing_len.in_port(0).connect(rank.out_port(0))
        trailing_len.in_port(1).connect(axis_plus_one.out_port(0))
        trailing_ones.in_port(0).connect(one_const.out_port(0))
        trailing_ones.in_port(1).connect(trailing_len.out_port(0))

        tile_repeats = new_shape_node_from_shape_nodes([leading_ones, repeats_const, trailing_ones])
        tile = Tile(graph, {'name': base_name + '/Tile'}).create_node()
        tile.in_port(1).connect(tile_repeats.out_port(0))

        # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
        # the reshape dim array is generated according to the following table:

        # parts:       |    first   |                rep           |  second   |
        # i:           | 0, 1, ... ,|               axis,          | ..., rank |
        # dim_array:   | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i] |

        shape_of_input = Shape(graph, {'name': base_name + '/Shape'}).create_node()
        node.in_port(0).get_source().connect(shape_of_input.in_port(0))

        dims_before_axis = get_shape_values_by_range_idxs(shape_of_input, rank,
                                                          begin=0, end=node.axis,
                                                          include_begin=True, include_end=False)

        axis_dim = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)},
                                               {'name': base_name + '/OriginalDim'},
                                               input_node=shape_of_input)
        axis_dim.in_port(1).connect(canonical_axis.out_port(0))

        repeated_dim = Mul(graph, {'name': base_name + '/RepeatedDim'}).create_node()
        repeated_dim.in_port(0).connect(axis_dim.out_port(0))
        repeated_dim.in_port(1).connect(repeats_const.out_port(0))

        dims_after_axis = get_shape_values_by_range_idxs(shape_of_input, rank,
                                                         begin=node.axis, end=-1,
                                                         include_begin=False, include_end=True)

        target_shape = new_shape_node_from_shape_nodes([dims_before_axis, repeated_dim, dims_after_axis])

        reshape = Reshape(graph, {'name': base_name}).create_node()
        rename_node(reshape, base_name)
        reshape.in_port(1).connect(target_shape.out_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        tile.in_port(0).connect(unsqueeze.out_port(0))
        reshape.in_port(0).connect(tile.out_port(0))
        node.out_port(0).get_connection().set_source(reshape.out_port(0))