def test_embedding_segments_sum(self):
        """Run EmbeddingSegmentsSum.infer on a four-input test graph and check the inferred output shape."""
        # input ports 0..3 of 'embedding_segments' are fed by data, indices, segment ids and number of segments
        edges = [
            *connect('data', '0:embedding_segments'),
            *connect('indices1d', '1:embedding_segments'),
            *connect('segment_ids', '2:embedding_segments'),
            *connect('num_segments', '3:embedding_segments'),
            ('embedding_segments', 'embedding_segments_d', {'out': 0}),
            ('embedding_segments_d', 'output'),
        ]
        graph = build_graph(nodes, edges, nodes_with_edges_only=True)

        segments_node = Node(graph, 'embedding_segments')
        EmbeddingSegmentsSum.infer(segments_node)

        inferred_shape = segments_node.out_port(0).data.get_shape()
        expected_shape = int64_array([30, 8])
        self.assertTrue(np.array_equal(inferred_shape, expected_shape))
Ejemplo n.º 2
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Decompose each EmbeddingSegmentsMean node of the graph into EmbeddingSegmentsSum followed by Div.

        The divisor is the number of indices that fall into each segment; a zero count is replaced with one
        before the division. The resulting Div node takes over the name and the output connection of the
        original node, and the original node is removed from the graph.

        A node whose input port 5 is connected (specified weights vector, see the TODO below) is not
        supported and is left untouched. Only that node is skipped: the remaining EmbeddingSegmentsMean
        nodes of the graph are still decomposed.

        :param graph: graph to transform in place
        :return: None
        """
        for embedding_segments_mean in graph.get_op_nodes(op='EmbeddingSegmentsMean'):
            embedding_segments_mean_name = embedding_segments_mean.soft_get('name', embedding_segments_mean.id)
            embedding_table_input = embedding_segments_mean.in_port(0)
            segment_ids_input = embedding_segments_mean.in_port(2)
            num_segments_input = embedding_segments_mean.in_port(3)

            # TODO: support EmbeddingSegmentsMean with specified weights vector.
            # now this case has not appeared in models so far so EmbeddingSegmentsOperation fusion
            # transformations do not handle it either
            if embedding_segments_mean.is_in_port_connected(5):
                # skip only this node; leaving the loop here would silently keep all the
                # following EmbeddingSegmentsMean nodes undecomposed as well
                continue

            # 1. compute indices membership matrix, i.e. which indices belong to some object
            # the shape of this matrix is [num_segments, num_indices]
            non_norm_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Range, {0: int64_array(0), 2: int64_array(1)},
                {'name': embedding_segments_mean_name + '/Range1ToNumSegments', 'output_type': np.int64})
            num_segments_input.get_connection().add_destination(non_norm_range_1_to_num_segments.in_port(1))

            # convert the range to the type of the num_segments input
            range_1_to_num_segments = ConvertLike(
                graph, {'name': embedding_segments_mean_name + '/Range1ToNumSegmentsNorm'}).create_node()
            range_1_to_num_segments.in_port(0).connect(non_norm_range_1_to_num_segments.out_port(0))
            num_segments_input.get_connection().add_destination(range_1_to_num_segments.in_port(1))

            # the range is unsqueezed along axis 1 and segment ids along axis 0 so that Equal compares
            # every range value with every segment id
            unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(1)},
                {'name': embedding_segments_mean_name + '/Range1ToNumSegmentsUnsqueeze'})
            unsqueeze_range_1_to_num_segments.in_port(0).connect(range_1_to_num_segments.out_port(0))
            unsqueeze_segment_ids = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(0)},
                {'name': embedding_segments_mean_name + '/SegmentIdsUnsqueeze'})
            segment_ids_input.get_connection().add_destination(unsqueeze_segment_ids.in_port(0))
            boolean_membership_matrix = Equal(
                graph, {'name': embedding_segments_mean_name + '/BooleanMembershipMatrix'}).create_node()
            boolean_membership_matrix.in_port(0).connect(unsqueeze_range_1_to_num_segments.out_port(0))
            boolean_membership_matrix.in_port(1).connect(unsqueeze_segment_ids.out_port(0))

            # turn the boolean matrix into an integer one (1 where equal, 0 otherwise) to be able to sum it up
            shape_of_membership_matrix = Shape(
                graph, {'name': embedding_segments_mean_name + '/ShapeOfMembershipMatrix'}
            ).create_node([boolean_membership_matrix])
            one_scalar_constant = Const(
                graph, {'name': embedding_segments_mean_name + '/OneScalar', 'value': int64_array([1])}
            ).create_node()
            one_constant = Broadcast(
                graph, {'name': embedding_segments_mean_name + '/One'}
            ).create_node([one_scalar_constant, shape_of_membership_matrix])
            zero_constant = Const(
                graph, {'name': embedding_segments_mean_name + '/Zero', 'value': int64_array(0)}
            ).create_node()
            membership_matrix = Select(
                graph, {'name': embedding_segments_mean_name + '/MembershipMatrix', 'auto_broadcast': 'numpy'}
            ).create_node([boolean_membership_matrix, one_constant, zero_constant])

            # 2. compute a number of indices belong to each object from the batch
            # it computes the normalization coefficients
            num_indices_per_object = create_op_with_const_inputs(
                graph, ReduceSum, {1: int64_array(1)},
                {'name': embedding_segments_mean_name + '/NumIndicesPerObject'})
            num_indices_per_object.in_port(0).connect(membership_matrix.out_port(0))

            # 3. replace zero coefficient (zero number of indices belong to an object) with one
            # because for such object the single default embedding vector is used
            where_zero_number = Equal(
                graph, {'name': embedding_segments_mean_name + '/WhereZeroIndicesNumber'}
            ).create_node([num_indices_per_object, zero_constant])
            normalized_num_indices_per_object = Select(
                graph, {'name': embedding_segments_mean_name + '/NormNumIndicesPerObject',
                        'auto_broadcast': 'numpy'}
            ).create_node([where_zero_number, one_scalar_constant, num_indices_per_object])

            # 4. cast normalized_num_indices_per_object to the same type as embedding vector table
            norm_coefficients = ConvertLike(
                graph, {'name': embedding_segments_mean_name + '/NormCoefficients'}).create_node()
            norm_coefficients.in_port(0).connect(normalized_num_indices_per_object.out_port(0))
            embedding_table_input.get_connection().add_destination(norm_coefficients.in_port(1))

            # 5. replace EmbeddingSegmentMean with EmbeddingSegmentSum
            # the extra consumers above were attached through the input ports of the original node,
            # so its inputs are re-wired to the new node only after that
            embedding_segments_sum = EmbeddingSegmentsSum(
                graph, {'name': embedding_segments_mean_name + '/EmbeddingSegmentsSum'}).create_node()
            for in_port in embedding_segments_mean.in_ports():
                if embedding_segments_mean.is_in_port_connected(in_port):
                    embedding_segments_mean.in_port(in_port).get_connection().set_destination(
                        embedding_segments_sum.in_port(in_port))

            # 6. normalize EmbeddingSegmentSum results by computed coefficients
            result_node = Div(
                graph, {'name': embedding_segments_mean_name + '/Div'}
            ).create_node([embedding_segments_sum, norm_coefficients])
            embedding_segments_mean.out_port(0).get_connection().set_source(result_node.out_port(0))

            # the Div node inherits the original name, the replaced node is renamed and removed
            rename_nodes([(embedding_segments_mean, embedding_segments_mean_name + '/AbandonedName'),
                          (result_node, embedding_segments_mean_name)])
            graph.remove_nodes_from([embedding_segments_mean.id])
Ejemplo n.º 3
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the matched sub-graph with a single EmbeddingSegmentsSum or EmbeddingSegmentsMean node.

        The new node takes the name and the output connection of the matched 'select' node. Its inputs are
        re-wired from the inputs of the matched nodes; indices, segment ids, number of segments and default
        value go through Cast nodes to int32.

        :param graph: graph to transform in place
        :param match: dictionary with the matched nodes keyed by their pattern names
        :return: None
        """
        identity_spw = match['identity_spw']
        gather0_1 = match['gather0_1']
        gather0_2 = match['gather0_2']
        greaterequal0 = match['greaterequal0']
        fill_empty_rows = match['sparse_fill_empty_rows']
        gather = match['gather']
        select = match['select']
        where0 = match['where0']
        segment_op = match['sparse_segment_op']
        base_name = select.soft_get('name', select.id)

        log.debug('Found EmbeddingSparseSegmentsSingleFeature pattern after {} with name {}'.format(
            fill_empty_rows.op,
            fill_empty_rows.name))

        def make_int32_cast(suffix):
            # Cast node to int32 named after the output node
            return Cast(graph, {'name': base_name + suffix, 'dst_type': np.int32}).create_node()

        indices_split = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, {'num_splits': 2})
        indices_squeeze = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
        dense_shape_split = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
        scalar_squeeze = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})

        # TODO: remove Cast nodes once we start to support EmbeddingSegmentSum (new version) with segment_ids,
        #  indices, and num_segments of different integer type.
        #  Because the real cases show that it is possible to have it in TensorFlow
        indices_cast = make_int32_cast('/CastIndices')
        segment_ids_cast = make_int32_cast('/CastSegmentIds')
        default_value_cast = make_int32_cast('/CastDefaultValue')
        num_segments_cast = make_int32_cast('/CastSegmentsNumber')

        # anything but SparseSegmentSum is turned into the Mean flavour of the operation
        embedding_cls = EmbeddingSegmentsSum if segment_op.op == 'SparseSegmentSum' else EmbeddingSegmentsMean
        embedding_op = embedding_cls(graph, {'name': base_name}).create_node()
        rename_nodes([(select, base_name + '/AbandonedName'), (embedding_op, base_name)])

        # input 0: parameters table
        gather.in_port(0).get_connection().set_destination(embedding_op.in_port(0))

        # input 1: indices values
        greaterequal0.in_port(0).get_connection().set_destination(indices_cast.in_port(0))
        embedding_op.in_port(1).connect(indices_cast.out_port(0))

        # input 2: segment ids, the first output of the split of the gather0_1 input
        gather0_1.in_port(0).get_connection().set_destination(indices_split.in_port(0))
        indices_squeeze.in_port(0).connect(indices_split.out_port(0))
        segment_ids_cast.in_port(0).connect(indices_squeeze.out_port(0))
        embedding_op.in_port(2).connect(segment_ids_cast.out_port(0))

        # input 3: number of segments, the first output of the split of the identity_spw input
        identity_spw.in_port(0).get_connection().set_destination(dense_shape_split.in_port(0))
        scalar_squeeze.in_port(0).connect(dense_shape_split.out_port(0))
        num_segments_cast.in_port(0).connect(scalar_squeeze.out_port(0))
        embedding_op.in_port(3).connect(num_segments_cast.out_port(0))

        # input 4: default value
        fill_empty_rows.in_port(3).get_connection().set_destination(default_value_cast.in_port(0))
        embedding_op.in_port(4).connect(default_value_cast.out_port(0))
        # no input port for per_sample_weight

        # detach the inputs of the replaced nodes
        stale_inputs = (
            (identity_spw, 0),
            (gather0_1, 0),
            (gather0_2, 0),
            (greaterequal0, 0),
            (fill_empty_rows, 2),
            (gather, 0),
        )
        for stale_node, port_idx in stale_inputs:
            stale_node.in_port(port_idx).disconnect()

        select.out_port(0).get_connection().set_source(embedding_op.out_port(0))
        obsolete_ids = [gather0_1.id, gather0_2.id, greaterequal0.id, fill_empty_rows.id, select.id, where0.id]
        graph.remove_nodes_from(obsolete_ids)
Ejemplo n.º 4
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the matched sub-graph with a single EmbeddingSegmentsSum node.

        The new node takes the name and the output connection of the matched 'select' node. Its inputs are
        re-wired from the inputs of the matched nodes; segment ids, number of segments and default value go
        through Cast nodes to int32, indices values are connected directly.

        :param graph: graph to transform in place
        :param match: dictionary with the matched nodes keyed by their pattern names
        :return: None
        """
        identity_spw = match['identity_spw']
        gather0_1 = match['gather0_1']
        gather0_2 = match['gather0_2']
        greaterequal0 = match['greaterequal0']
        fill_empty_rows = match['sparse_fill_empty_rows']
        gather = match['gather']
        select = match['select']
        where0 = match['where0']
        base_name = select.soft_get('name', select.id)

        log.debug('Found EmbeddingSegmentsSum2 pattern after {} with name {}'.format(
            fill_empty_rows.op, fill_empty_rows.name))

        def make_int32_cast(suffix):
            # Cast node to int32 named after the output node
            return Cast(graph, {'name': base_name + suffix, 'dst_type': np.int32}).create_node()

        indices_split = create_op_with_const_inputs(
            graph, Split, {1: int64_array(1)}, {'num_splits': 2, 'name': base_name + '/SplitForIndices'})
        indices_squeeze = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
        dense_shape_split = create_op_with_const_inputs(
            graph, Split, {1: int64_array(0)}, {'num_splits': 2, 'name': base_name + '/SplitForDenseShape'})
        scalar_squeeze = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})

        # TODO: remove casting once we start to support I64 model input
        segment_ids_cast = make_int32_cast('/CastSegmentIds')
        default_value_cast = make_int32_cast('/CastDefaultValue')
        num_segments_cast = make_int32_cast('/CastSegmentsNumber')

        embedding_op = EmbeddingSegmentsSum(graph, {'name': base_name}).create_node()
        rename_nodes([(select, base_name + '/AbandonedName'), (embedding_op, base_name)])

        # input 0: parameters table
        gather.in_port(0).get_connection().set_destination(embedding_op.in_port(0))

        # input 1: indices values
        greaterequal0.in_port(0).get_connection().set_destination(embedding_op.in_port(1))

        # input 2: segment ids, the first output of the split of the gather0_1 input
        gather0_1.in_port(0).get_connection().set_destination(indices_split.in_port(0))
        indices_squeeze.in_port(0).connect(indices_split.out_port(0))
        segment_ids_cast.in_port(0).connect(indices_squeeze.out_port(0))
        embedding_op.in_port(2).connect(segment_ids_cast.out_port(0))

        # input 3: number of segments, the first output of the split of the identity_spw input
        identity_spw.in_port(0).get_connection().set_destination(dense_shape_split.in_port(0))
        scalar_squeeze.in_port(0).connect(dense_shape_split.out_port(0))
        num_segments_cast.in_port(0).connect(scalar_squeeze.out_port(0))
        embedding_op.in_port(3).connect(num_segments_cast.out_port(0))

        # input 4: default value
        fill_empty_rows.in_port(3).get_connection().set_destination(default_value_cast.in_port(0))
        embedding_op.in_port(4).connect(default_value_cast.out_port(0))
        # no input port for per_sample_weight

        # detach the inputs of the replaced nodes
        stale_inputs = (
            (identity_spw, 0),
            (gather0_1, 0),
            (gather0_2, 0),
            (greaterequal0, 0),
            (fill_empty_rows, 2),
            (gather, 0),
        )
        for stale_node, port_idx in stale_inputs:
            stale_node.in_port(port_idx).disconnect()

        select.out_port(0).get_connection().set_source(embedding_op.out_port(0))
        obsolete_ids = [gather0_1.id, gather0_2.id, greaterequal0.id, fill_empty_rows.id, select.id, where0.id]
        graph.remove_nodes_from(obsolete_ids)