Пример #1
0
    def replace(node: Node, const: Node):
        """Replace an arange-like constant on input 0 of ``node`` with a ``Range`` sub-graph.

        The rewrite happens only when all of the following hold; otherwise the
        function returns without touching the graph:

        * ``const.shape`` has exactly one dimension that differs from 1;
        * the number of elements is strictly between 5 and 500;
        * ``const.value`` equals ``np.arange(prod(shape))`` reshaped to ``shape``.

        The upper limit of the new ``Range`` comes from a ``Select``: it compares
        an element gathered from the producer of ``node`` input 1 with 1 and
        picks the matching element of ``ShapeOf(const)`` when they are equal,
        or the gathered element itself otherwise.
        (presumably input 1 carries a 1D target shape -- TODO confirm against callers)

        :param node: operation whose input 0 is fed by ``const``
        :param const: constant node that is a candidate for replacement
        """
        graph = node.graph
        const_shape = const.shape
        const_name = const.soft_get('name', const.id)

        varying_axes = np.argwhere(const_shape != 1).flatten()
        unit_axes = np.argwhere(const_shape == 1).flatten()

        if varying_axes.size != 1:
            return
        if not 5 < np.prod(const_shape) < 500:
            # (5;500) range is deduced to affect less models
            return

        const_value = const.value
        arange_like = np.arange(0, np.prod(const_shape), 1).reshape(const_shape)
        if not np.array_equal(arange_like, const_value):
            return

        # index of the only non-unit axis, counted from the end of the shape
        axis_from_end = varying_axes.item(0) - len(const_shape)

        node_name = node.soft_get('name', node.id)
        dim_of_input = create_op_with_const_inputs(
            graph, Gather, {1: int64_array(axis_from_end), 2: int64_array(0)},
            {'name': node_name + '/BroadcastingDim'})
        dim_of_const = create_op_with_const_inputs(
            graph, Gather, {1: int64_array(axis_from_end), 2: int64_array(0)},
            {'name': const_name + '/BroadcastingDim'})
        const_shape_of = Shape(graph, {'name': const_name + '/ShapeOf'}).create_node()
        const_shape_of.out_port(0).connect(dim_of_const.in_port(0))

        is_one = create_op_with_const_inputs(
            graph, Equal, {1: int64_array(1)}, {'name': node_name + '/ConstOne'})
        dim_of_input.out_port(0).connect(is_one.in_port(0))

        select_attrs = {'name': node_name + '/Select', 'auto_broadcast': 'numpy'}
        range_limit = Select(graph, select_attrs).create_node([is_one, dim_of_const, dim_of_input])

        const.out_port(0).connect(const_shape_of.in_port(0))

        range_const_inputs = {0: mo_array(0, dtype=const_value.dtype),
                              2: mo_array(1, dtype=const_value.dtype)}
        range_node = create_op_with_const_inputs(
            graph, Range, range_const_inputs,
            {'name': const_name + '/Range', 'dtype': const_value.dtype})
        range_limit.out_port(0).connect(range_node.in_port(1))

        # the producer of input 1 additionally feeds the Gather
        node.in_port(1).get_connection().add_destination(dim_of_input.in_port(0))

        # input 0 is now produced by the Range instead of the constant
        node.in_port(0).get_connection().set_source(range_node.out_port(0))

        # the node that takes over the original name of the constant
        name_owner = range_node
        if unit_axes.size:
            # restore the unit dimensions of the original constant shape
            unsqueeze = create_op_node_with_second_input(
                graph, Unsqueeze, unit_axes, {'name': const_name + '/KeepShape'})
            range_node.out_port(0).get_connection().insert_node(unsqueeze)
            name_owner = unsqueeze
        rename_nodes([(const, const_name + '/ToBeDeleted'), (name_owner, const_name)])
Пример #2
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched TFSlice node with a ``Slice`` fed by computed ends.

        The producers of TFSlice inputs 0, 1 and 2 are referred to below as
        input, begin and size. ``ends`` is built as ``Add(size, begin)``; for
        elements where ``size`` equals -1 the corresponding element of
        ``ShapeOf(input)`` is used instead.

        :param graph: graph that is being transformed
        :param match: pattern match; ``match['op']`` is the TFSlice node
        """
        tf_slice = match['op']
        base_name = tf_slice.soft_get('name', tf_slice.id)

        new_slice = Slice(graph).create_node()
        rename_nodes([(tf_slice, base_name + '/to_be_removed'), (new_slice, base_name)])
        ends_sum = Add(graph, {'name': base_name + '/ends'}).create_node()

        # move input and begin to the new Slice, size to the first input of Add
        input_connection = tf_slice.in_port(0).get_connection()
        input_connection.set_destination(new_slice.in_port(0))
        begin_connection = tf_slice.in_port(1).get_connection()
        begin_connection.set_destination(new_slice.in_port(1))
        size_connection = tf_slice.in_port(2).get_connection()
        size_connection.set_destination(ends_sum.in_port(0))
        # begin is also the second input of Add
        new_slice.in_port(1).get_connection().add_destination(ends_sum.in_port(1))

        input_shape = Shape(graph, {'name': base_name + '/ShapeOf'}).create_node()
        new_slice.in_port(0).get_connection().add_destination(input_shape.in_port(0))

        # elementwise check size[i] == -1: len(size) = len(begin) = input_rank
        size_is_minus_one = create_op_with_const_inputs(
            graph, Equal, {0: int64_array(-1)},
            {'name': base_name + '/where_max_ends_is_needed'})
        ends_sum.in_port(0).get_connection().add_destination(size_is_minus_one.in_port(1))

        # select requires equal dtypes, so ends are converted to I64
        cast_attrs = {'name': base_name + '/CastToI64', 'dst_type': np.int64}
        ends_i64 = Cast(graph, cast_attrs).create_node([ends_sum])

        # where size[i] == -1 take the ShapeOf values, otherwise the computed ends
        chosen_ends = Select(graph, {'name': base_name + '/chosen_ends'}).create_node(
            [size_is_minus_one, input_shape, ends_i64])
        chosen_ends.out_port(0).connect(new_slice.in_port(2))

        # consumers of TFSlice now read from the new Slice
        tf_slice.out_port(0).get_connection().set_source(new_slice.out_port(0))
Пример #3
0
    def find_and_replace_pattern(self, graph: Graph):
        """Replace a Switch/Merge construct with a single ``Select``.

        For every ``Merge`` the following structure is looked for, trying both
        assignments of its first two inputs:

        * one input is produced by a ``Switch``;
        * the other input is produced by an ``Identity`` whose input 0 comes
          from an operation with both input 0 and input 1 produced by
          ``Switch`` nodes;
        * input 1 of all three ``Switch`` nodes has the same source.

        When found, a ``Select`` is created with that shared source on input 0,
        the source of the first Switch's input 0 on input 1 and the Identity
        output on input 2. It takes over the consumers of ``Merge``. The
        operation is rewired to read directly what its two Switch nodes read on
        input 0, and the three Switch nodes together with the Merge are removed.

        :param graph: graph that is being transformed
        """
        for merge_node in graph.get_op_nodes(op='Merge'):
            for switch_side in range(2):
                identity_side = 1 - switch_side

                if merge_node.in_port(switch_side).disconnected():
                    continue
                if merge_node.in_port(switch_side).get_source().node.op != 'Switch':
                    continue
                value_switch = merge_node.in_port(switch_side).get_source().node

                if merge_node.in_port(identity_side).disconnected():
                    continue
                if merge_node.in_port(identity_side).get_source().node.op != 'Identity':
                    continue
                identity_out_port = merge_node.in_port(identity_side).get_source()

                switch_data_port = value_switch.in_port(0).get_source()
                guarded_op = identity_out_port.node.in_port(0).get_source().node

                if guarded_op.in_port(0).disconnected():
                    continue
                if guarded_op.in_port(0).get_source().node.op != 'Switch':
                    continue
                first_switch = guarded_op.in_port(0).get_source().node

                if guarded_op.in_port(1).disconnected():
                    continue
                if guarded_op.in_port(1).get_source().node.op != 'Switch':
                    continue
                second_switch = guarded_op.in_port(1).get_source().node

                # all three Switch nodes must share the source of their input 1
                shared_condition = first_switch.in_port(1).get_source()
                if not (shared_condition == second_switch.in_port(1).get_source() and
                        shared_condition == value_switch.in_port(1).get_source()):
                    continue

                select_attrs = dict(name=merge_node.soft_get('name') + '/Select/', format='tf')
                select = Select(graph, select_attrs).create_node()
                select.in_port(0).connect(shared_condition)
                select.in_port(1).connect(switch_data_port)
                select.in_port(2).connect(identity_out_port)

                merge_node.out_port(0).get_connection().set_source(select.out_port(0))

                assert 1 in guarded_op.in_ports() and 0 in guarded_op.in_ports()

                guarded_op.in_port(0).disconnect()
                guarded_op.in_port(1).disconnect()

                # bypass the two Switch nodes that were feeding the operation
                first_switch.in_port(0).get_connection().set_destination(guarded_op.in_port(0))
                second_switch.in_port(0).get_connection().set_destination(guarded_op.in_port(1))

                graph.remove_nodes_from(
                    nodes=[second_switch.id, first_switch.id, value_switch.id, merge_node.id])
                # need to exit from the inner for loop because the Merge op has been removed
                break
    def find_and_replace_pattern(self, graph: Graph):
        """Decompose every ``EmbeddingSegmentsMean`` into ``EmbeddingSegmentsSum`` and ``Div``.

        For each such node a sub-graph is built that counts how many indices
        fall into every segment, replaces zero counts with one, converts the
        counts to the type of the embedding table and divides the
        ``EmbeddingSegmentsSum`` result by them. The new ``Div`` takes over the
        name and the consumers of the original node, which is then removed.

        Nodes with input port 5 connected are not supported and are skipped.
        Each node is handled independently, so an unsupported node does not
        stop the remaining nodes from being decomposed.

        :param graph: graph that is being transformed
        """
        for embedding_segments_mean in graph.get_op_nodes(
                op='EmbeddingSegmentsMean'):
            embedding_segments_mean_name = embedding_segments_mean.soft_get(
                'name', embedding_segments_mean.id)
            embedding_table_input = embedding_segments_mean.in_port(0)
            segment_ids_input = embedding_segments_mean.in_port(2)
            num_segments_input = embedding_segments_mean.in_port(3)

            # TODO: support EmbeddingSegmentsMean with specified weights vector.
            # now this case has not appeared in models so far so EmbeddingSegmentsOperation fusion
            # transformations do not handle it either
            if embedding_segments_mean.is_in_port_connected(5):
                # skip only this node; the remaining nodes are still processed
                continue

            # 1. compute indices membership matrix, i.e. which indices belong to some object
            # the shape of this matrix is [num_segments, num_indices]
            non_norm_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Range, {
                    0: int64_array(0),
                    2: int64_array(1)
                }, {
                    'name':
                    embedding_segments_mean_name + '/Range1ToNumSegments',
                    'output_type': np.int64
                })
            num_segments_input.get_connection().add_destination(
                non_norm_range_1_to_num_segments.in_port(1))

            # bring the range to the type of the num_segments input
            range_1_to_num_segments = ConvertLike(graph, {
                'name':
                embedding_segments_mean_name + '/Range1ToNumSegmentsNorm'
            }).create_node()
            range_1_to_num_segments.in_port(0).connect(
                non_norm_range_1_to_num_segments.out_port(0))
            num_segments_input.get_connection().add_destination(
                range_1_to_num_segments.in_port(1))

            # the range is unsqueezed along axis 1 and segment ids along axis 0
            # so that Equal compares every range value with every segment id
            unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name +
                    '/Range1ToNumSegmentsUnsqueeze'
                })
            unsqueeze_range_1_to_num_segments.in_port(0).connect(
                range_1_to_num_segments.out_port(0))
            unsqueeze_segment_ids = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(0)}, {
                    'name':
                    embedding_segments_mean_name + '/SegmentIdsUnsqueeze'
                })
            segment_ids_input.get_connection().add_destination(
                unsqueeze_segment_ids.in_port(0))
            boolean_membership_matrix = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/BooleanMembershipMatrix'
            }).create_node()
            boolean_membership_matrix.in_port(0).connect(
                unsqueeze_range_1_to_num_segments.out_port(0))
            boolean_membership_matrix.in_port(1).connect(
                unsqueeze_segment_ids.out_port(0))

            # turn the boolean matrix into an integer one: 1 where True, 0 where False
            shape_of_membership_matrix = Shape(graph, {
                'name':
                embedding_segments_mean_name + '/ShapeOfMembershipMatrix'
            }).create_node([boolean_membership_matrix])
            one_scalar_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/OneScalar',
                    'value': int64_array([1])
                }).create_node()
            one_constant = Broadcast(graph, {
                'name':
                embedding_segments_mean_name + '/One'
            }).create_node([one_scalar_constant, shape_of_membership_matrix])
            zero_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/Zero',
                    'value': int64_array(0)
                }).create_node()
            membership_matrix = Select(
                graph, {
                    'name': embedding_segments_mean_name + '/MembershipMatrix',
                    'auto_broadcast': 'numpy'
                }).create_node(
                    [boolean_membership_matrix, one_constant, zero_constant])

            # 2. compute a number of indices belong to each object from the batch
            # it computes the normalization coefficients
            num_indices_per_object = create_op_with_const_inputs(
                graph, ReduceSum, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name + '/NumIndicesPerObject'
                })
            num_indices_per_object.in_port(0).connect(
                membership_matrix.out_port(0))

            # 3. replace zero coefficient (zero number of indices belong to an object) with one
            # because for such object the single default embedding vector is used
            where_zero_number = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/WhereZeroIndicesNumber'
            }).create_node([num_indices_per_object, zero_constant])
            normalized_num_indices_per_object = Select(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/NormNumIndicesPerObject',
                    'auto_broadcast': 'numpy'
                }).create_node([
                    where_zero_number, one_scalar_constant,
                    num_indices_per_object
                ])

            # 4. cast normalized_num_indices_per_object to the same type as embedding vector table
            norm_coefficients = ConvertLike(
                graph, {
                    'name': embedding_segments_mean_name + '/NormCoefficients'
                }).create_node()
            norm_coefficients.in_port(0).connect(
                normalized_num_indices_per_object.out_port(0))
            embedding_table_input.get_connection().add_destination(
                norm_coefficients.in_port(1))

            # 5. replace EmbeddingSegmentMean with EmbeddingSegmentSum
            embedding_segments_sum = EmbeddingSegmentsSum(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/EmbeddingSegmentsSum'
                }).create_node()
            # every connected input is moved to the same port of the new node
            for in_port in embedding_segments_mean.in_ports():
                if embedding_segments_mean.is_in_port_connected(in_port):
                    embedding_segments_mean.in_port(
                        in_port).get_connection().set_destination(
                            embedding_segments_sum.in_port(in_port))

            # 6. normalize EmbeddingSegmentSum results by computed coefficients
            result_node = Div(graph, {
                'name': embedding_segments_mean_name + '/Div'
            }).create_node([embedding_segments_sum, norm_coefficients])
            embedding_segments_mean.out_port(0).get_connection().set_source(
                result_node.out_port(0))

            # the Div node inherits the original name before the old node is dropped
            rename_nodes([(embedding_segments_mean,
                           embedding_segments_mean_name + '/AbandonedName'),
                          (result_node, embedding_segments_mean_name)])
            graph.remove_nodes_from([embedding_segments_mean.id])
Пример #5
0
    def insert_select(graph: Graph, node: Node):
        """Insert a ``Select`` in front of input 0 of ``node``, driven by an iteration counter.

        Nothing is done when ``node.frame_time + 1`` equals 1. Otherwise the
        original producer of input 0 becomes input 1 of the new ``Select``,
        a constant built by ``create_const_with_batch_from_input`` becomes
        input 2, and the condition (input 0) is an ``Equal`` between a "ones"
        constant and a one-element crop of the counter.

        The counter is either reused from an existing
        ReadValue -> Crop -> Concat -> Assign chain found in the graph, or
        built here from scratch.

        :param graph: graph that is being transformed
        :param node: node whose input 0 gets the ``Select`` inserted before it
        """
        context_len = node.frame_time + 1

        if context_len == 1:
            return

        # remember the producer and shape of input 0 before detaching it
        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        # NOTE(review): presumably a zero-filled constant, going by the variable name -- confirm in the helper
        zero_else = create_const_with_batch_from_input(in_node_port,
                                                       in_node_shape[1])
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check if we have already appropriate iteration counter
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()), ('mem_out', dict(op='Assign')),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        # only the first match, if any, is used
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            # reuse the constant and the one-element crop of the found counter
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            # build a new counter: ReadValue -> Crop (offset 1) -> Concat with ones -> Assign
            init_value_mem_out = create_const_with_batch_from_input(
                in_node_port, context_len, precision=np.int32)
            mem_out = ReadValue(
                graph, {
                    'name': 'iteration_number',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            # keep context_len - 1 elements along axis 1, starting at offset 1
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = create_const_with_batch_from_input(in_node_port, 1, 1,
                                                      np.int32)
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            # the Assign shares its variable_id with the ReadValue created above
            mem_in = Assign(
                graph, {
                    'name': 'iteration_number_out',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            # take one element at offset 0 along axis 1 of the updated counter
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check if data from memory is 1
        # if it is True, we have correct data and should proceed with saving it to memory
        # else we have not gathered context and have garbage here, shouldn't change initial state of memory
        cast_in = Equal(graph, {
            'name': input_port.node.name + '/cast_to_bool'
        }).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        # the Select output replaces the original producer of input 0 and keeps its shape
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Пример #6
0
    def dequantize_data(fake_quantize: Node, dst_type: type,
                        quantized_type: type) -> Node:
        """Replace ``fake_quantize`` with an explicit dequantization sub-graph.

        With ``in_low``, ``in_high``, ``out_low`` and ``out_high`` denoting the
        producers of inputs 1-4 of ``fake_quantize``, the sub-graph computes::

            scale      = (out_high - out_low) / (in_high - in_low)
            shift      = in_low - out_low / scale
            zero_point = 0 where scale == 0, shift elsewhere
            result     = (Cast(x, dst_type) - zero_point) * scale

        where ``x`` is the former producer of input 0, which is asserted to be
        a ``Convert`` with ``dst_type == quantized_type``.

        NOTE(review): the signature is annotated ``-> Node`` but the body has no
        ``return`` statement, so callers receive ``None`` -- confirm intent.

        :param fake_quantize: node to be replaced
        :param dst_type: type the quantized data is cast to before dequantization
        :param quantized_type: expected ``dst_type`` of the Convert on input 0
        """
        graph = fake_quantize.graph
        compressed_data = fake_quantize.in_port(0).get_source().node
        fq_name = fake_quantize.soft_get('name', fake_quantize.id)

        def connect_binary(op_class, suffix, first, second):
            """Create ``op_class`` named ``fq_name + suffix`` and attach both inputs."""
            created = op_class(graph, {'name': fq_name + suffix}).create_node()
            created.in_port(0).connect(first)
            created.in_port(1).connect(second)
            return created

        assert compressed_data.soft_get('type') == 'Convert' and compressed_data.dst_type == quantized_type, \
            'Weights aren`t compressed as expected for node {}'.format(fq_name)

        cast_name = compressed_data.name + "/to_{}".format(np_data_type_to_destination_type(dst_type))
        to_dst_type = Cast(graph, dict(name=cast_name,
                                       dst_type=dst_type,
                                       stop_value_propagation=True)).create_node()
        fake_quantize.in_port(0).get_connection().set_destination(to_dst_type.in_port(0))

        # limits of dequantize: producers of inputs 1..4
        in_low, in_high, out_low, out_high = (
            fake_quantize.in_port(idx).get_source() for idx in range(1, 5))

        # scale calculation
        out_span = connect_binary(Sub, '/output_range', out_high, out_low)
        in_span = connect_binary(Sub, '/input_range', in_high, in_low)
        scale = connect_binary(Div, '/scale', out_span.out_port(0), in_span.out_port(0))

        # shift calculation
        out_low_descaled = connect_binary(Div, '/descaled_output_low', out_low, scale.out_port(0))
        shift = connect_binary(Sub, '/shift', in_low, out_low_descaled.out_port(0))

        # zero point is forced to zero where the scale is zero
        zero = Const(graph, {'name': fq_name + '/zero',
                             'value': mo_array(0, dtype=dst_type)}).create_node()
        scale_is_zero = connect_binary(Equal, '/scale_eq_zero', scale.out_port(0), zero.out_port(0))

        zero_point = Select(graph, {'name': fq_name + '/zero_point'}).create_node()
        zero_point.in_port(0).connect(scale_is_zero.out_port(0))
        zero_point.in_port(1).connect(zero.out_port(0))
        zero_point.in_port(2).connect(shift.out_port(0))

        # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
        minus_zp = connect_binary(Sub, '/minus_zp', to_dst_type.out_port(0), zero_point.out_port(0))
        dequantized = connect_binary(Mul, '/mulpiply_by_scale', minus_zp.out_port(0), scale.out_port(0))

        fake_quantize.out_port(0).get_connection().set_source(dequantized.out_port(0))

        # NOTE(review): the second entry is whatever out_node(0) returns, not an id -- confirm intent
        graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])