Example 1
    def test_broadcast(self, data, target_shape, axes_mapping=None, mode='numpy', ref_out=None, test_raising=False):
        if ref_out is not None:
            input = valued_const_with_data('data', int64_array(data))
        else:
            input = shaped_data('data', int64_array(data))

        nodes = {
            **input,
            **valued_const_with_data('target_shape', int64_array(target_shape)),
            **regular_op_with_empty_data('broadcast', {'op': 'Broadcast', 'mode': mode}),
        }

        edges = [('data', 'broadcast'),
                 ('target_shape', 'broadcast'),
                 ('broadcast', 'broadcast_d')]

        if axes_mapping is not None:
            nodes.update(**valued_const_with_data('axes_mapping', int64_array(axes_mapping)))
            edges.append(('axes_mapping', 'broadcast'))
        graph = build_graph(nodes, edges)

        broadcast_node = Node(graph, 'broadcast')
        if test_raising:
            self.assertRaises(AssertionError, Broadcast.infer, broadcast_node)
            return

        Broadcast.infer(broadcast_node)
        if ref_out is not None:
            self.assertTrue(np.array_equal(broadcast_node.out_node().value, np.array(ref_out)))
        else:
            self.assertTrue(np.array_equal(broadcast_node.out_node().shape, np.array(target_shape)))
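For reference, the 'numpy' broadcast mode exercised by this test follows NumPy's own broadcasting rules, so its shape behavior can be sketched with np.broadcast_to (a minimal illustration, not the Model Optimizer implementation itself):

    import numpy as np

    # numpy-mode Broadcast: trailing dimensions must match or be 1
    out = np.broadcast_to(np.zeros((16, 1, 1)), (1, 16, 50, 50))
    assert out.shape == (1, 16, 50, 50)

    # incompatible shapes raise, mirroring the test_raising=True cases
    try:
        np.broadcast_to(np.zeros((2, 3)), (2, 4))
    except ValueError:
        pass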
Example 2
    def find_and_replace_pattern(self, graph: Graph):
        for fill_node in graph.get_op_nodes(op='Fill'):
            name = fill_node.soft_get('name', fill_node.id)

            broadcast_node = Broadcast(graph, {
                'name': name + '/Broadcast'
            }).create_node()
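            # Fill(dims, value) maps to Broadcast(value, dims), so the ports swap below:
            # Fill input 0 (dims) -> Broadcast input 1; Fill input 1 (value) -> Broadcast input 0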
            fill_node.in_port(0).get_connection().set_destination(
                broadcast_node.in_port(1))
            fill_node.in_port(1).get_connection().set_destination(
                broadcast_node.in_port(0))
            fill_node.out_port(0).get_connection().set_source(
                broadcast_node.out_port(0))

        for fill_node in graph.get_op_nodes(op='ConstantFill'):
            name = fill_node.soft_get('name', fill_node.id)

            assert fill_node.has_valid('fill_value')
            assert fill_node.has_and_set('input_as_shape')

            const = Const(graph, {
                'value': np.array(fill_node.fill_value),
                'name': name + '/value'
            }).create_node()
            broadcast_node = Broadcast(graph, {
                'name': name + '/Broadcast'
            }).create_node()
            fill_node.in_port(0).get_connection().set_destination(
                broadcast_node.in_port(1))
            const.out_port(0).connect(broadcast_node.in_port(0))
            fill_node.out_port(0).get_connection().set_source(
                broadcast_node.out_port(0))
Example 3
    def find_and_replace_pattern(self, graph: Graph):
        for fill_node in graph.get_op_nodes(op='Fill'):
            broadcast_node = Broadcast(graph, {}).create_node()
            fill_node.in_port(0).get_connection().set_destination(
                broadcast_node.in_port(1))
            fill_node.in_port(1).get_connection().set_destination(
                broadcast_node.in_port(0))
            fill_node.out_port(0).get_connection().set_source(
                broadcast_node.out_port(0))
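The equivalence these two replacements rely on can be sketched in plain NumPy (a minimal illustration of the semantics, not the transformation code):

    import numpy as np

    dims, value = (2, 3), 7.0
    # Fill(dims, value) == Broadcast(value, dims)
    assert np.array_equal(np.full(dims, value), np.broadcast_to(value, dims))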
Example 4
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)

        # broadcast default value to required shape
        broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
        node.in_port(1).get_connection().set_destination(broadcast_node.in_port(1))
        if not node.in_port(3).disconnected():
            # TODO: remove casting once we start to support I64 model input
            # cast the default value to I32 due to a limitation in I64 input support,
            # so that the input parameter and the default value have the same I32 type, as required by ScatterNDUpdate
            cast_default_value = Cast(graph, {'name': node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
            node.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
            broadcast_node.in_port(0).connect(cast_default_value.out_port(0))
        else:
            broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue_',
                                                            'value': np.float32(0)}
                                                    ).create_node().out_port(0))

        # update broadcasted tensor with required values at required locations
        scatternd_node = ScatterNDUpdate(graph, {'name': node_name + '/ScatterNDUpdate_'}).create_node()
        scatternd_node.in_port(0).connect(broadcast_node.out_port(0))
        node.in_port(0).get_connection().set_destination(scatternd_node.in_port(1))
        node.in_port(2).get_connection().set_destination(scatternd_node.in_port(2))

        rename_nodes([(node, node_name + "/AbandonedName"), (scatternd_node, node_name)])

        return [scatternd_node.id]
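The replacement above rebuilds a dense tensor from sparse data: Broadcast fills an output-shaped tensor with the default value, then ScatterNDUpdate writes the sparse values at their indices. A minimal NumPy sketch with hypothetical inputs:

    import numpy as np

    indices = np.array([[0, 1], [2, 3]])   # positions to update
    dense_shape = (3, 4)
    values = np.array([10, 20], dtype=np.int32)
    default_value = np.int32(0)

    dense = np.broadcast_to(default_value, dense_shape).copy()  # Broadcast step
    dense[tuple(indices.T)] = values                            # ScatterNDUpdate step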
Example 5
    def find_and_replace_pattern(self, graph: Graph):
        for const_of_shape_node in graph.get_op_nodes(op='ConstantOfShape'):
            broadcast_node = Broadcast(graph, {'name': const_of_shape_node.name + '/Broadcast'}).create_node()
            const_of_shape_node.in_port(0).get_connection().set_destination(broadcast_node.in_port(1))
            broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue',
                                                            'value': const_of_shape_node.fill_value}
                                                    ).create_node().out_port(0))
            const_of_shape_node.out_port(0).get_connection().set_source(broadcast_node.out_port(0))
Example 6
    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        shapes = [in_node.shape for _, in_node in node.in_nodes().items()]
        out_shape = node.out_node().shape
        broadcast_name = node.name + '/Broadcast/'

        for i, shape in enumerate(shapes):
            if not np.array_equal(shape, out_shape):
                # Add Broadcast op for this input
                # Need to create additional Const op for shape
                new_shape = Const(graph, {'name': broadcast_name + 'Shape', 'value': out_shape.copy()}).create_node()
                broadcast_axis = Const(graph, {
                    'name': broadcast_name + 'Axis',
                    'value': np.array(range(len(out_shape)), dtype=np.int64)}
                ).create_node()
                broadcast = Broadcast(graph, {'name': broadcast_name}).create_node()
                node.in_port(i).get_connection().set_destination(broadcast.in_port(0))
                broadcast.in_port(1).connect(new_shape.out_port(0))
                broadcast.in_port(2).connect(broadcast_axis.out_port(0))
                broadcast.out_port(0).connect(node.in_port(i))
Example 7
def create_zero_value_with_batch_from_input(input_out_port: Port,
                                            second_dim,
                                            precision=np.float32):  # np.float was removed from modern NumPy
    # create init_graph connected to ReadValue
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name
    shape_of_input = Shape(graph, {
        'name': 'shape/' + input_name
    }).create_node()
    shape_of_input.in_port(0).connect(input_out_port)
    dim_for_get_batch = Const(
        graph, {
            'name': 'dim/crop_batch/' + shape_of_input.name,
            'value': int64_array([1]),
            'shape': int64_array([1])
        }).create_node()
    get_batch = Crop(
        graph, {
            'name': 'crop_batch/' + shape_of_input.name,
            'axis': int64_array([0]),
            'offset': int64_array([0])
        }).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
    mem_shape_2nd_dim = Const(
        graph, {
            'name': 'gifo_r_weights_shape/' + input_name,
            'value': int64_array([second_dim]),
            'shape': int64_array([1])
        }).create_node()
    mem_shape = Concat(
        graph, {
            'name': 'gather_memory_shape/' + input_name,
            'axis': 0,
            'in_ports_count': 2
        }).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
    fill_value = Const(
        graph, {
            'name': 'fill_value/' + input_name,
            'value': np.array([0.0], precision),
            'shape': int64_array([1])
        }).create_node()
    init_value_prev_lstm_output = Broadcast(graph, {
        'name': 'init_value/' + input_name,
    }).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
    return init_value_prev_lstm_output
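The subgraph assembled above ends in a tensor of zeros with shape [batch, second_dim], where batch is read at runtime from the input's shape. In NumPy terms (batch shown as a concrete value only for illustration):

    import numpy as np

    batch, second_dim = 8, 128   # batch is dynamic in the real graph
    mem_shape = np.concatenate([[batch], [second_dim]])                   # Concat step
    init_value = np.broadcast_to(np.array([0.0], np.float32), mem_shape)  # Broadcast step
    assert init_value.shape == (8, 128) and not init_value.any()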
Example 8
    @staticmethod
    def append_variances(priors_scale_node: Node, variance: list):
        graph = priors_scale_node.graph
        name = priors_scale_node.name

        sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
        priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

        begin = Const(graph, {'value': np.array([-2])}).create_node()
        end = Const(graph, {'value': np.array([-1])}).create_node()
        stride = Const(graph, {'value': np.array([1])}).create_node()
        shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
                                                     'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                                     'shrink_axis_mask': np.array([0]),
                                                     'ellipsis_mask': np.array([0])}).create_node()

        sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
        begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
        end.out_port(0).connect(shape_part_for_tiling.in_port(2))
        stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

        concat_value = Const(graph, {'value': np.array([4])}).create_node()
        shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                      'axis': np.array(0)}).create_node()
        shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
        concat_value.out_port(0).connect(shape_concat.in_port(1))

        variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
        tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
        variance.out_port(0).connect(tile.in_port(0))
        shape_concat.out_port(0).connect(tile.in_port(1))

        reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
        sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
        sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

        concat = Concat(graph,
                        {'name': name + '/priors_concat', 'axis': np.array(0), 'in_ports_count': 2}).create_node()
        sp_reshape.out_port(0).connect(concat.in_port(0))
        tile.out_port(0).connect(concat.in_port(1))

        output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
        concat.out_port(0).connect(output_node.in_port(0))
        output_dims.out_port(0).connect(output_node.in_port(1))

        return output_node
Example 9
    def replace_op(self, graph: Graph, node: Node):
        node_name = node.soft_get('name', node.id)

        # broadcast default value to required shape
        broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
        node.in_port(1).get_connection().set_destination(broadcast_node.in_port(1))
        if not node.in_port(3).disconnected():
            node.in_port(3).get_connection().set_destination(broadcast_node.in_port(0))
        else:
            broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue_',
                                                            'value': np.float32(0)}
                                                    ).create_node().out_port(0))

        # update broadcasted tensor with required values at required locations
        scatternd_node = ScatterNDUpdate(graph, {'name': node_name + '/ScatterNDUpdate_'}).create_node()
        scatternd_node.in_port(0).connect(broadcast_node.out_port(0))
        node.in_port(0).get_connection().set_destination(scatternd_node.in_port(1))
        node.in_port(2).get_connection().set_destination(scatternd_node.in_port(2))

        rename_nodes([(node, node_name + "/AbandonedName"), (scatternd_node, node_name)])

        return [scatternd_node.id]
Example 10
    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        reshape_classes_node = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            dict(name='do_reshape_classes'),
            match.single_input_node(1)[0])

        initial_priors_node = match.single_input_node(2)[0]
        priors_name = initial_priors_node.soft_get('name',
                                                   initial_priors_node.id)
        # the model calculates identical prior boxes for each batch, so we take the first slice of them
        begin = Const(graph, {
            'value': np.array([0, 0, 0], dtype=np.int32)
        }).create_node()
        end = Const(graph, {
            'value': np.array([1, 0, 0], dtype=np.int32)
        }).create_node()
        stride = Const(graph, {
            'value': np.array([1, 1, 1], dtype=np.int32)
        }).create_node()

        priors_node = StridedSlice(
            graph, {
                'name': priors_name + '/0_batch_slice',
                'begin_mask': np.array([1, 1, 1], dtype=np.int32),
                'end_mask': np.array([1, 0, 0], dtype=np.int32),
                'new_axis_mask': np.array([0], dtype=np.int32),
                'shrink_axis_mask': np.array([0], dtype=np.int32),
                'ellipsis_mask': np.array([0], dtype=np.int32)
            }).create_node()

        initial_priors_node.out_port(0).connect(priors_node.in_port(0))
        begin.out_port(0).connect(priors_node.in_port(1))
        end.out_port(0).connect(priors_node.in_port(2))
        stride.out_port(0).connect(priors_node.in_port(3))

        placeholders = graph.get_op_nodes(type='Parameter')
        assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                       "{} placeholders".format(self.replacement_id, len(placeholders))
        placeholder = placeholders[0]

        # scale prior boxes to the [0, 1] interval
        node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
        priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

        broadcast = Broadcast(graph, {
            'name': 'scales_broadcast'
        }).create_node()
        shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
        priors_node.out_port(0).connect(shape_of_priors.in_port(0))
        broadcast.in_port(1).connect(shape_of_priors.out_port(0))
        broadcast.in_port(0).connect(
            node_with_scales_for_prior_boxes.out_port(0))

        priors_scale_node.in_port(0).connect(priors_node.out_port(0))
        priors_scale_node.in_port(1).connect(broadcast.out_port(0))

        try:
            variance = match.custom_replacement_desc.custom_attributes[
                'variance']
        except KeyError:
            raise Error(
                'There is no variance attribute in {} replacement config file `custom_attributes`'
                ''.format(self.replacement_id))

        priors = self.append_variances(priors_scale_node, variance)

        # calculate prior boxes widths and heights
        split_node = create_op_with_const_inputs(graph, VariadicSplit, {
            1: int64_array(2),
            2: int64_array([1, 1, 1, 1])
        }, {'out_ports_count': 4}, priors_scale_node)

        priors_width_node = Sub(
            graph, dict(name=split_node.name + '/sub_2-0_')).create_node([
                (split_node, 2), (split_node, 0)
            ])
        priors_height_node = Sub(graph, dict(name=split_node.name +
                                             '/sub_3-1_')).create_node([
                                                 (split_node, 3),
                                                 (split_node, 1)
                                             ])

        # concatenate widths and heights into a single tensor and multiply with the box coordinate regression values
        # WA with 3 Concats instead of 1 to keep the model reshape-able
        # concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
        #                                           'in_ports_count': 4}).create_node(
        # [priors_width_node, priors_height_node, priors_width_node, priors_height_node])

        concat_1 = Concat(graph, {
            'name': 'concat_width_height',
            'axis': -1,
            'in_ports_count': 2
        }).create_node([priors_width_node, priors_height_node])
        concat_2 = Concat(graph, {
            'name': 'concat_width_height_width',
            'axis': -1,
            'in_ports_count': 2
        }).create_node([concat_1, priors_width_node])
        concat_width_height_node = Concat(graph, {
            'name': 'concat_priors_width_height',
            'axis': -1,
            'in_ports_count': 2
        }).create_node([concat_2, priors_height_node])

        applied_width_height_regressions_node = Mul(graph, {
            'name': 'final_regressions'
        }).create_node(
            [concat_width_height_node,
             match.single_input_node(0)[0]])

        # reshape to 2D tensor as Inference Engine Detection Output layer expects
        reshape_regression_node = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            dict(name='reshape_regression'),
            applied_width_height_regressions_node)

        detection_output_op = DetectionOutput(
            graph, match.custom_replacement_desc.custom_attributes)
        # get nms from the original network
        iou_threshold = None
        nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
        if len(nms_nodes) > 0:
            # it is highly unlikely that NMS uses different thresholds for different classes;
            # moreover, DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
            iou_threshold = nms_nodes[0].in_node(3).value
        if iou_threshold is None:
            raise Error(
                'During {} `iou_threshold` was not retrieved from RetinaNet graph'
                .format(self.replacement_id))

        detection_output_node = detection_output_op.create_node(
            [reshape_regression_node, reshape_classes_node, priors],
            dict(name=detection_output_op.attrs['type'],
                 nms_threshold=iou_threshold,
                 clip_after_nms=1,
                 normalized=1,
                 variance_encoded_in_target=0,
                 background_label_id=1000))

        return {'detection_output_node': detection_output_node}
Example 11
    def mxrepeat_decomposition(node: Node):
        graph = node.graph
        name = node.soft_get('name', node.id)

        rename_node(node, name + '/to_be_removed')

        # Unsqueeze
        input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        node.in_port(0).get_source().connect(input_rank.in_port(0))

        axis = get_canonical_axis_index_node(input_rank, node.axis)
        unsqueeze_axis = create_op_node_with_second_input(
            graph,
            Add,
            int64_array([1]), {'name': name + '/Unsqueeze/Axis'},
            input_node=axis)

        unsqueeze = Unsqueeze(graph, {
            'name': name + '/Unsqueeze'
        }).create_node()
        unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

        # Tile (1, 1, ..., repeats, ..., 1)
        # we generate tile array according to the following table:

        # parts:       |      first      |  repeats |  second     |
        # i:           | 0, 1, ..., axis,| axis + 1,| ..., rank+1 |
        # tile_array:  | 1, 1, ...,  1  ,| repeats ,| ...,   1    |
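        # e.g. for an input of rank 3, axis = 1, repeats = 2 (the input was
        # unsqueezed to rank 4 above), tile_array = [1, 1, 2, 1]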

        one = Const(graph, {
            'name': name + '/Broadcast/One',
            'value': int64_array([1])
        }).create_node()
        first_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_first_part'
        }).create_node()
        first_ones.in_port(0).connect(one.out_port(0))
        first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

        repeats = Const(graph, {
            'name': name + '/repeats',
            'value': int64_array([node.repeats])
        }).create_node()

        second_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_second_part'
        }).create_node()
        second_part_broadcast_shape = Sub(
            graph, {
                'name': name + '/Broadcast/Shape/second_part'
            }).create_node()
        second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
        second_part_broadcast_shape.in_port(1).connect(
            unsqueeze_axis.out_port(0))
        second_ones.in_port(0).connect(one.out_port(0))
        second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

        tile_repeats = new_shape_node_from_shape_nodes(
            [first_ones, repeats, second_ones])
        tile = Tile(graph, {'name': name + '/Tile'}).create_node()
        tile.in_port(1).connect(tile_repeats.out_port(0))

        # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
        # we generate reshape dim array according to the following table:

        # parts:       |    first   |                rep           |  second   |
        # i:           | 0, 1, ... ,|               axis,          | ..., rank |
        # dim_array:   | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i] |
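        # e.g. for input_shape = [2, 3, 4], axis = 1, repeats = 2,
        # dim_array = [2, 6, 4]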

        input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
        node.in_port(0).get_source().connect(input_shape.in_port(0))

        first_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=0,
            end=node.axis,
            include_begin=True,
            include_end=False)

        original_axis_dim = create_op_with_const_inputs(
            graph,
            Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'},
            input_node=input_shape)
        original_axis_dim.in_port(1).connect(axis.out_port(0))

        repeated_dimension = Mul(graph, {
            'name': name + '/RepeatedDim'
        }).create_node()
        repeated_dimension.in_port(0).connect(original_axis_dim.out_port(0))
        repeated_dimension.in_port(1).connect(repeats.out_port(0))

        second_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=node.axis,
            end=-1,
            include_begin=False,
            include_end=True)

        output_shape = new_shape_node_from_shape_nodes([
            first_input_shape_part, repeated_dimension, second_input_shape_part
        ])

        reshape = Reshape(graph, {'name': name}).create_node()
        rename_node(reshape, name)
        reshape.in_port(1).connect(output_shape.out_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        tile.in_port(0).connect(unsqueeze.out_port(0))
        reshape.in_port(0).connect(tile.out_port(0))
        node.out_port(0).get_connection().set_source(reshape.out_port(0))
Example 12
    @classmethod
    def extract(cls, node):
        Broadcast.update_node_stat(node, {'mode': 'bidirectional'})
        return cls.enabled
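Bidirectional mode broadcasts in both directions, so the output shape is the NumPy-style broadcast of the data shape and the target shape. A minimal sketch of the shape rule (np.broadcast_shapes requires NumPy >= 1.20):

    import numpy as np

    data, target_shape = np.zeros((3, 1)), (1, 4)
    # bidirectional: both operands may be expanded, giving (3, 4) here
    out = np.broadcast_to(data, np.broadcast_shapes(data.shape, target_shape))
    assert out.shape == (3, 4)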
Example 13
    @staticmethod
    def extract(node):
        Broadcast.update_node_stat(node)
        return __class__.enabled
Example 14
    @classmethod
    def extract(cls, node: Node):
        Broadcast.update_node_stat(node, attrs={'mode': 'numpy'})
        return cls.enabled
Example 15
    def find_and_replace_pattern(self, graph: Graph):
        for embedding_segments_mean in graph.get_op_nodes(
                op='EmbeddingSegmentsMean'):
            embedding_segments_mean_name = embedding_segments_mean.soft_get(
                'name', embedding_segments_mean.id)
            embedding_table_input = embedding_segments_mean.in_port(0)
            segment_ids_input = embedding_segments_mean.in_port(2)
            num_segments_input = embedding_segments_mean.in_port(3)

            # TODO: support EmbeddingSegmentsMean with a specified weights vector;
            # this case has not appeared in models so far, so the EmbeddingSegmentsOperation fusion
            # transformations do not handle it either
            if embedding_segments_mean.is_in_port_connected(5):
                continue  # 'return' here would skip all remaining nodes in the loop

            # 1. compute the indices membership matrix, i.e. which indices belong to which object;
            # the shape of this matrix is [num_segments, num_indices]
            non_norm_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Range, {
                    0: int64_array(0),
                    2: int64_array(1)
                }, {
                    'name':
                    embedding_segments_mean_name + '/Range1ToNumSegments',
                    'output_type': np.int64
                })
            num_segments_input.get_connection().add_destination(
                non_norm_range_1_to_num_segments.in_port(1))

            range_1_to_num_segments = ConvertLike(graph, {
                'name':
                embedding_segments_mean_name + '/Range1ToNumSegmentsNorm'
            }).create_node()
            range_1_to_num_segments.in_port(0).connect(
                non_norm_range_1_to_num_segments.out_port(0))
            num_segments_input.get_connection().add_destination(
                range_1_to_num_segments.in_port(1))

            unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name +
                    '/Range1ToNumSegmentsUnsqueeze'
                })
            unsqueeze_range_1_to_num_segments.in_port(0).connect(
                range_1_to_num_segments.out_port(0))
            unsqueeze_segment_ids = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(0)}, {
                    'name':
                    embedding_segments_mean_name + '/SegmentIdsUnsqueeze'
                })
            segment_ids_input.get_connection().add_destination(
                unsqueeze_segment_ids.in_port(0))
            boolean_membership_matrix = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/BooleanMembershipMatrix'
            }).create_node()
            boolean_membership_matrix.in_port(0).connect(
                unsqueeze_range_1_to_num_segments.out_port(0))
            boolean_membership_matrix.in_port(1).connect(
                unsqueeze_segment_ids.out_port(0))
            shape_of_membership_matrix = Shape(graph, {
                'name':
                embedding_segments_mean_name + '/ShapeOfMembershipMatrix'
            }).create_node([boolean_membership_matrix])
            one_scalar_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/OneScalar',
                    'value': int64_array([1])
                }).create_node()
            one_constant = Broadcast(graph, {
                'name':
                embedding_segments_mean_name + '/One'
            }).create_node([one_scalar_constant, shape_of_membership_matrix])
            zero_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/Zero',
                    'value': int64_array(0)
                }).create_node()
            membership_matrix = Select(
                graph, {
                    'name': embedding_segments_mean_name + '/MembershipMatrix',
                    'auto_broadcast': 'numpy'
                }).create_node(
                    [boolean_membership_matrix, one_constant, zero_constant])

            # 2. compute the number of indices belonging to each object in the batch;
            # this gives the normalization coefficients
            num_indices_per_object = create_op_with_const_inputs(
                graph, ReduceSum, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name + '/NumIndicesPerObject'
                })
            num_indices_per_object.in_port(0).connect(
                membership_matrix.out_port(0))

            # 3. replace zero coefficients (an object to which zero indices belong) with one,
            # because for such an object the single default embedding vector is used
            where_zero_number = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/WhereZeroIndicesNumber'
            }).create_node([num_indices_per_object, zero_constant])
            normalized_num_indices_per_object = Select(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/NormNumIndicesPerObject',
                    'auto_broadcast': 'numpy'
                }).create_node([
                    where_zero_number, one_scalar_constant,
                    num_indices_per_object
                ])

            # 4. cast normalized_num_indices_per_object to the same type as embedding vector table
            norm_coefficients = ConvertLike(
                graph, {
                    'name': embedding_segments_mean_name + '/NormCoefficients'
                }).create_node()
            norm_coefficients.in_port(0).connect(
                normalized_num_indices_per_object.out_port(0))
            embedding_table_input.get_connection().add_destination(
                norm_coefficients.in_port(1))

            # 5. replace EmbeddingSegmentsMean with EmbeddingSegmentsSum
            embedding_segments_sum = EmbeddingSegmentsSum(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/EmbeddingSegmentsSum'
                }).create_node()
            for in_port in embedding_segments_mean.in_ports():
                if embedding_segments_mean.is_in_port_connected(in_port):
                    embedding_segments_mean.in_port(
                        in_port).get_connection().set_destination(
                            embedding_segments_sum.in_port(in_port))

            # 6. normalize EmbeddingSegmentsSum results by the computed coefficients
            result_node = Div(graph, {
                'name': embedding_segments_mean_name + '/Div'
            }).create_node([embedding_segments_sum, norm_coefficients])
            embedding_segments_mean.out_port(0).get_connection().set_source(
                result_node.out_port(0))

            rename_nodes([(embedding_segments_mean,
                           embedding_segments_mean_name + '/AbandonedName'),
                          (result_node, embedding_segments_mean_name)])
            graph.remove_nodes_from([embedding_segments_mean.id])
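A compact NumPy sketch of steps 1-3 above, using hypothetical segment data, shows the membership matrix and the zero-count normalization:

    import numpy as np

    segment_ids = np.array([0, 0, 2])   # hypothetical per-index segment assignment
    num_segments = 3
    # step 1: membership[s, i] == 1 iff index i belongs to segment s
    membership = (np.arange(num_segments)[:, None] == segment_ids[None, :]).astype(np.int64)
    # step 2: the number of indices per object
    num_indices = membership.sum(axis=1)             # [2, 0, 1]
    # step 3: empty segments keep the default embedding, so divide by 1 instead of 0
    norm = np.where(num_indices == 0, 1, num_indices)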
Example 16
    def find_and_replace_pattern(self, graph: Graph):
        reverse_nodes = graph.get_op_nodes(op='Reverse')
        for reverse in reverse_nodes:
            reverse_name = reverse.soft_get('name', reverse.id)

            assert reverse.in_port(1).disconnected()
            assert reverse.has_valid('axis')

            in_shape_rank = len(reverse.in_port(0).data.get_shape())
            # 1. Add new dimension as batch for rank = 1 to have batch != seq_axis
            if in_shape_rank == 1:
                unsq_node = create_op_node_with_second_input(
                    graph, Unsqueeze, int64_array([0]),
                    {'name': reverse_name + "/Unsqueeze"})
                reverse.in_port(0).get_source().connect(unsq_node.in_port(0))
                new_in = unsq_node.out_port(0)
                batch_axis = 0
                seq_axis = 1
            else:
                new_in = reverse.in_port(0).get_source()
                seq_axis = reverse['axis']
                batch_axis = 0 if seq_axis != 0 else 1

            # 2. For ReverseSequence, input port 1 is seq_lengths => create this input node as
            # shape[seq_axis] broadcast to a tensor of length shape[batch_axis]
            # in ---> ShapeOf ----> Gather(seq_axis)  ----> Broadcast----->
            #            |                                      |
            #            | -------> Gather(batch_axis)----------|
            shape_node = Shape(graph, {
                'name': reverse_name + "/Shape"
            }).create_node()
            new_in.connect(shape_node.in_port(0))
            seq_axis_node = node_to_get_shape_value_of_indices(
                shape_node, [seq_axis])
            batch_node = node_to_get_shape_value_of_indices(
                shape_node, [batch_axis])
            broadcast_node = Broadcast(graph, {
                'name': reverse_name + "/Broadcast"
            }).create_node()
            broadcast_node.in_port(0).connect(seq_axis_node.out_port(0))
            broadcast_node.in_port(1).connect(batch_node.out_port(0))

            # 3. Create new ReverseSequence node and reconnect all inputs/outputs to it
            rename_node(reverse, reverse_name + '/to_delete')
            reverse_sequence = ReverseSequence(
                graph, {
                    'name': reverse_name,
                    'seq_axis': seq_axis,
                    'batch_axis': batch_axis
                }).create_node()
            reverse_sequence.in_port(0).connect(new_in)
            reverse_sequence.in_port(1).connect(broadcast_node.out_port(0))

            # 4. remove added dimension for rank = 1
            if in_shape_rank == 1:
                rename_node(reverse_sequence,
                            reverse_name + '/ReverseSequence')
                squeeze_node = create_op_node_with_second_input(
                    graph, Squeeze, int64_array([0]), {'name': reverse_name})
                squeeze_node.in_port(0).connect(reverse_sequence.out_port(0))
                reverse.out_port(0).get_connection().set_source(
                    squeeze_node.out_port(0))
            else:
                reverse.out_port(0).get_connection().set_source(
                    reverse_sequence.out_port(0))

        # 5. Delete old Reverse node
        graph.remove_nodes_from([reverse.id for reverse in reverse_nodes])
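The Broadcast in step 2 builds the seq_lengths input by repeating shape[seq_axis] once per batch element; in NumPy terms (a concrete shape only for illustration):

    import numpy as np

    in_shape = np.array([4, 10])   # batch_axis = 0, seq_axis = 1
    seq_lengths = np.broadcast_to(in_shape[[1]], in_shape[[0]])   # Gather + Broadcast
    assert np.array_equal(seq_lengths, [10, 10, 10, 10])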
Example 17
    @classmethod
    def extract(cls, node):
        Broadcast.update_node_stat(node)
        return cls.enabled