Example #1
    def replace_pattern(self, graph: Graph, match: dict):
        condition = match['cond_data']
        true_value = match['Switch_2_input']
        false_value = match['identity_data']

        select = Select(
            graph, dict(name=match['Merge'].name + '/Select/',
                        format='tf')).create_node(
                            inputs=[condition, true_value, false_value])

        match['Merge'].out_port(0).get_connection().set_source(
            select.out_port(0))

        # Reconnect inputs to some_op
        op = match['some_op']
        assert 1 in op.in_ports() and 0 in op.in_ports()

        op.in_port(0).disconnect()
        op.in_port(1).disconnect()
        match['Switch'].in_port(0).get_connection().set_destination(
            op.in_port(0))
        match['Switch_1'].in_port(0).get_connection().set_destination(
            op.in_port(1))

        graph.remove_nodes_from(nodes=[
            match['Switch_1'].id, match['Switch'].id, match['Switch_2'].id,
            match['Merge'].id
        ])
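
A minimal NumPy sketch (made-up values, not part of the pass) of the semantics this rewrite relies on: a TensorFlow Switch/Merge conditional reduces to an element-wise Select, i.e. np.where:

import numpy as np

cond = np.array([True, False, True])
true_value = np.array([1, 2, 3])
false_value = np.array([10, 20, 30])

# Select(cond, then, else) picks element-wise, exactly like np.where
assert np.array_equal(np.where(cond, true_value, false_value),
                      np.array([1, 20, 3]))
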
Example #2
    def replace_sub_graph(self, graph: Graph, match: dict):
        tf_slice_node = match['op']
        slice_name = tf_slice_node.soft_get('name', tf_slice_node.id)
        slice_node = Slice(graph).create_node()
        rename_nodes([(tf_slice_node, slice_name + '/to_be_removed'),
                      (slice_node, slice_name)])
        ends_node = Add(graph, {'name': slice_name + '/ends'}).create_node()

        # reconnect input, begin, and size from TFSlice to the subgraph with Slice
        tf_slice_node.in_port(0).get_connection().set_destination(
            slice_node.in_port(0))
        tf_slice_node.in_port(1).get_connection().set_destination(
            slice_node.in_port(1))
        tf_slice_node.in_port(2).get_connection().set_destination(
            ends_node.in_port(0))
        slice_node.in_port(1).get_connection().add_destination(
            ends_node.in_port(1))

        max_ends = Shape(graph, {
            'name': slice_name + '/ShapeOf'
        }).create_node()
        slice_node.in_port(0).get_connection().add_destination(
            max_ends.in_port(0))

        # check if size[i] == -1; the check is applied element-wise: len(size) == len(begin) == input rank
        where_max_ends_is_needed = create_op_with_const_inputs(
            graph, Equal, {0: int64_array(-1)},
            {'name': slice_name + '/where_max_ends_is_needed'})
        ends_node.in_port(0).get_connection().add_destination(
            where_max_ends_is_needed.in_port(1))
        # Select requires equal dtypes, so convert ends to int64
        ends_casted_to_i64 = Cast(graph, {
            'name': slice_name + '/CastToI64',
            'dst_type': np.int64
        }).create_node([ends_node])
        # if size[i] == -1 then take the max_ends value
        correct_ends = Select(graph, {
            'name': slice_name + '/chosen_ends'
        }).create_node(
            [where_max_ends_is_needed, max_ends, ends_casted_to_i64])
        correct_ends.out_port(0).connect(slice_node.in_port(2))

        tf_slice_node.out_port(0).get_connection().set_source(
            slice_node.out_port(0))
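
TFSlice takes (begin, size) while Slice takes (begin, ends), with size[i] == -1 meaning "slice to the end of dimension i". A small NumPy sketch (made-up shapes, not part of the transformation) of the ends computation built above from Add, ShapeOf, Equal and Select:

import numpy as np

input_shape = np.array([4, 6, 8], dtype=np.int64)   # ShapeOf result
begin = np.array([0, 1, 2], dtype=np.int64)
size = np.array([-1, 4, -1], dtype=np.int64)        # -1 means "slice to the end"

ends = begin + size                             # the Add node
ends = np.where(size == -1, input_shape, ends)  # Equal + Select against max_ends
assert np.array_equal(ends, np.array([4, 5, 8], dtype=np.int64))
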
Example #3
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']
        slice_name = node.soft_get('name', node.id)
        slice_node = Slice(graph).create_node()
        rename_nodes([(node, slice_name + '/to_be_removed'), (slice_node, slice_name)])

        eq_node = Equal(graph, {'name': slice_name + '/equal'}).create_node()
        minus_one_node = Const(graph, {'name': slice_name + '/minus_one', 'value': np.array(-1)}).create_node()
        int32_max_node = Const(graph, {'name': slice_name + '/int32_max', 'value': np.iinfo(np.int32).max}).create_node()
        select_node = Select(graph, {'name': slice_name + '/select'}).create_node()

        # node to convert sizes to ends
        sum_node = Add(graph, {'name': slice_name + '/end_const'}).create_node()

        # reconnect input from TFSlice to Slice
        node.in_port(0).get_source().connect(slice_node.in_port(0))
        node.in_port(0).disconnect()
        # reconnect begin of TFSlice to start of Slice
        node.in_port(1).get_source().connect(slice_node.in_port(1))
        node.in_port(1).disconnect()

        # (size -> ends) reconnect begin and size to the Add node to evaluate ends for Slice
        # connect begin to port 0 of the Add node
        slice_node.in_port(1).get_source().connect(sum_node.in_port(0))
        node.in_port(2).get_source().connect(sum_node.in_port(1))
        node.in_port(2).disconnect()

        # if size[i] == -1 then take int32_max as end[i]
        sum_node.in_port(1).get_source().connect(eq_node.in_port(0))
        minus_one_node.out_port(0).connect(eq_node.in_port(1))
        # from Equal to port 0 of Select
        eq_node.out_port(0).connect(select_node.in_port(0))
        # from int32_max to port 1 of Select
        int32_max_node.out_port(0).connect(select_node.in_port(1))
        # from Add to port 2 of Select
        sum_node.out_port(0).connect(select_node.in_port(2))
        # output of Select to ends (port 2 of Slice)
        select_node.out_port(0).connect(slice_node.in_port(2))

        cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node()
        select_node.in_port(2).get_connection().insert_node(cast)

        node.out_port(0).get_connection().set_source(slice_node.out_port(0))
Example #4
    def find_and_replace_pattern(self, graph: Graph):
        for merge in graph.get_op_nodes(op='Merge'):
            for merge_switch_in_port in range(2):
                if merge.in_port(merge_switch_in_port).disconnected() or \
                        merge.in_port(merge_switch_in_port).get_source().node.op != 'Switch':
                    continue
                switch_2 = merge.in_port(
                    merge_switch_in_port).get_source().node

                if merge.in_port(1 - merge_switch_in_port).disconnected() or \
                        merge.in_port(1 - merge_switch_in_port).get_source().node.op != 'Identity':
                    continue
                false_value_port = merge.in_port(
                    1 - merge_switch_in_port).get_source()

                true_value_port = switch_2.in_port(0).get_source()
                op = false_value_port.node.in_port(0).get_source().node

                if op.in_port(0).disconnected() or \
                        op.in_port(0).get_source().node.op != 'Switch':
                    continue
                switch = op.in_port(0).get_source().node

                if op.in_port(1).disconnected() or \
                        op.in_port(1).get_source().node.op != 'Switch':
                    continue
                switch_1 = op.in_port(1).get_source().node

                if switch.in_port(1).get_source() == switch_1.in_port(1).get_source() and \
                        switch.in_port(1).get_source() == switch_2.in_port(1).get_source():
                    select = Select(
                        graph,
                        dict(name=merge.soft_get('name') + '/Select/',
                             format='tf')).create_node()
                    select.in_port(0).connect(switch.in_port(1).get_source())
                    select.in_port(1).connect(true_value_port)
                    select.in_port(2).connect(false_value_port)

                    merge.out_port(0).get_connection().set_source(
                        select.out_port(0))

                    assert 1 in op.in_ports() and 0 in op.in_ports()

                    op.in_port(0).disconnect()
                    op.in_port(1).disconnect()

                    switch.in_port(0).get_connection().set_destination(
                        op.in_port(0))
                    switch_1.in_port(0).get_connection().set_destination(
                        op.in_port(1))

                    graph.remove_nodes_from(
                        nodes=[switch_1.id, switch.id, switch_2.id, merge.id])
                    # need to exit from the inner for loop because the Merge op has been removed
                    break
Example #5
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']

        if node.name == 'iteration_number_out':
            return

        # calculate the context length at which the inference state becomes meaningful
        inputs = graph.get_op_nodes(op='Parameter')

        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check whether an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()), ('mem_out', dict(op='Assign')),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_zero_value_with_batch_from_input(
                in_node_port, context_len, np.int32)
            mem_out = ReadValue(
                graph, {
                    'name': 'iteration_number',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(
                graph, {
                    'name': 'iteration_number_out',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check if the value read from memory is 1:
        # if True, the context is gathered and the data should be saved to memory;
        # otherwise the data is still garbage and the initial memory state must not be changed
        cast_in = Equal(graph, {
            'name': input_port.node.name + '/cast_to_bool'
        }).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
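
The counter sub-graph above shifts a 1 into a [1, context_len] memory buffer on each inference and reads the first element back, so the Select starts passing real data only once the whole context has been gathered. A rough NumPy sketch of that mechanism (made-up sizes, plain arrays instead of ReadValue/Assign):

import numpy as np

context_len = 3
counter = np.zeros((1, context_len), dtype=np.int32)  # initial ReadValue state

for step in range(1, 5):
    # Crop(offset=1, dim=context_len-1) + Concat with ones: shift in a 1 per step
    counter = np.concatenate([counter[:, 1:], np.ones((1, 1), np.int32)], axis=1)
    # Crop(offset=0, dim=1) + Equal against ones: ready once the context is full
    ready = counter[0, 0] == 1
    print(step, ready)  # steps 1-2: False (Select emits zeros), step 3 on: True
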
Example #6
    def find_and_replace_pattern(self, graph: Graph):
        for embedding_segments_mean in graph.get_op_nodes(
                op='EmbeddingSegmentsMean'):
            embedding_segments_mean_name = embedding_segments_mean.soft_get(
                'name', embedding_segments_mean.id)
            embedding_table_input = embedding_segments_mean.in_port(0)
            segment_ids_input = embedding_segments_mean.in_port(2)
            num_segments_input = embedding_segments_mean.in_port(3)

            # TODO: support EmbeddingSegmentsMean with a specified weights vector.
            # This case has not appeared in models so far, so the EmbeddingSegments*
            # fusing transformations do not handle it either
            if embedding_segments_mean.is_in_port_connected(5):
                continue

            # 1. compute the indices membership matrix, i.e. which indices belong to which object
            # the shape of this matrix is [num_segments, num_indices]
            non_norm_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Range, {
                    0: int64_array(0),
                    2: int64_array(1)
                }, {
                    'name':
                    embedding_segments_mean_name + '/Range1ToNumSegments',
                    'output_type': np.int64
                })
            num_segments_input.get_connection().add_destination(
                non_norm_range_1_to_num_segments.in_port(1))

            range_1_to_num_segments = ConvertLike(graph, {
                'name':
                embedding_segments_mean_name + '/Range1ToNumSegmentsNorm'
            }).create_node()
            range_1_to_num_segments.in_port(0).connect(
                non_norm_range_1_to_num_segments.out_port(0))
            num_segments_input.get_connection().add_destination(
                range_1_to_num_segments.in_port(1))

            unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name +
                    '/Range1ToNumSegmentsUnsqueeze'
                })
            unsqueeze_range_1_to_num_segments.in_port(0).connect(
                range_1_to_num_segments.out_port(0))
            unsqueeze_segment_ids = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(0)}, {
                    'name':
                    embedding_segments_mean_name + '/SegmentIdsUnsqueeze'
                })
            segment_ids_input.get_connection().add_destination(
                unsqueeze_segment_ids.in_port(0))
            boolean_membership_matrix = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/BooleanMembershipMatrix'
            }).create_node()
            boolean_membership_matrix.in_port(0).connect(
                unsqueeze_range_1_to_num_segments.out_port(0))
            boolean_membership_matrix.in_port(1).connect(
                unsqueeze_segment_ids.out_port(0))
            shape_of_membership_matrix = Shape(graph, {
                'name':
                embedding_segments_mean_name + '/ShapeOfMembershipMatrix'
            }).create_node([boolean_membership_matrix])
            one_scalar_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/OneScalar',
                    'value': int64_array([1])
                }).create_node()
            one_constant = Broadcast(graph, {
                'name':
                embedding_segments_mean_name + '/One'
            }).create_node([one_scalar_constant, shape_of_membership_matrix])
            zero_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/Zero',
                    'value': int64_array(0)
                }).create_node()
            membership_matrix = Select(
                graph, {
                    'name': embedding_segments_mean_name + '/MembershipMatrix',
                    'auto_broadcast': 'numpy'
                }).create_node(
                    [boolean_membership_matrix, one_constant, zero_constant])

            # 2. compute the number of indices belonging to each object in the batch
            # these are the normalization coefficients
            num_indices_per_object = create_op_with_const_inputs(
                graph, ReduceSum, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name + '/NumIndicesPerObject'
                })
            num_indices_per_object.in_port(0).connect(
                membership_matrix.out_port(0))

            # 3. replace zero coefficients (an object with no indices) with ones
            # because the single default embedding vector is used for such objects
            where_zero_number = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/WhereZeroIndicesNumber'
            }).create_node([num_indices_per_object, zero_constant])
            normalized_num_indices_per_object = Select(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/NormNumIndicesPerObject',
                    'auto_broadcast': 'numpy'
                }).create_node([
                    where_zero_number, one_scalar_constant,
                    num_indices_per_object
                ])

            # 4. cast normalized_num_indices_per_object to the same type as embedding vector table
            norm_coefficients = ConvertLike(
                graph, {
                    'name': embedding_segments_mean_name + '/NormCoefficients'
                }).create_node()
            norm_coefficients.in_port(0).connect(
                normalized_num_indices_per_object.out_port(0))
            embedding_table_input.get_connection().add_destination(
                norm_coefficients.in_port(1))

            # 5. replace EmbeddingSegmentsMean with EmbeddingSegmentsSum
            embedding_segments_sum = EmbeddingSegmentsSum(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/EmbeddingSegmentsSum'
                }).create_node()
            for in_port in embedding_segments_mean.in_ports():
                if embedding_segments_mean.is_in_port_connected(in_port):
                    embedding_segments_mean.in_port(
                        in_port).get_connection().set_destination(
                            embedding_segments_sum.in_port(in_port))

            # 6. normalize EmbeddingSegmentsSum results by the computed coefficients
            result_node = Div(graph, {
                'name': embedding_segments_mean_name + '/Div'
            }).create_node([embedding_segments_sum, norm_coefficients])
            embedding_segments_mean.out_port(0).get_connection().set_source(
                result_node.out_port(0))

            rename_nodes([(embedding_segments_mean,
                           embedding_segments_mean_name + '/AbandonedName'),
                          (result_node, embedding_segments_mean_name)])
            graph.remove_nodes_from([embedding_segments_mean.id])
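
The decomposition computes the mean as a segment sum divided by a per-segment index count, with empty-segment counts clamped to one. A NumPy sketch of steps 1-3 and 6 on made-up inputs (not part of the transformation):

import numpy as np

emb_table = np.array([[1., 1.], [2., 2.], [4., 4.]])
indices = np.array([0, 1, 2, 2])
segment_ids = np.array([0, 0, 2, 2])  # segment 1 is empty
num_segments = 3

# step 1: membership matrix [num_segments, num_indices] via Range/Unsqueeze/Equal
membership = (np.arange(num_segments)[:, None] == segment_ids[None, :]).astype(float)
# steps 2-3: per-segment index counts; empty segments keep the default zero embedding
counts = membership.sum(axis=1, keepdims=True)
counts = np.where(counts == 0, 1, counts)
# step 6: EmbeddingSegmentsSum result normalized by the coefficients
mean = (membership @ emb_table[indices]) / counts
assert np.allclose(mean, [[1.5, 1.5], [0., 0.], [4., 4.]])
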
Example #7
    def replace(node: Node, const: Node):
        graph = node.graph
        shape = const.shape
        const_name = const.soft_get('name', const.id)

        non_one_dims = np.argwhere(shape != 1).flatten()
        one_dims = np.argwhere(shape == 1).flatten()

        if not (non_one_dims.size == 1 and 5 < np.prod(shape) < 500):
            # the (5, 500) range was chosen empirically to affect fewer models
            return

        value = const.value
        if not np.array_equal(
                np.arange(0, np.prod(shape), 1).reshape(shape), value):
            return

        positive_idx = non_one_dims.item(0)
        negative_idx = positive_idx - len(shape)

        node_name = node.soft_get('name', node.id)
        gather = create_op_with_const_inputs(
            graph, Gather, {
                1: int64_array(negative_idx),
                2: int64_array(0)
            }, {'name': node_name + '/BroadcastingDim'})
        gather_for_const = create_op_with_const_inputs(
            graph, Gather, {
                1: int64_array(negative_idx),
                2: int64_array(0)
            }, {'name': const_name + '/BroadcastingDim'})
        shapeof_node = Shape(graph, {
            'name': const_name + '/ShapeOf'
        }).create_node()
        shapeof_node.out_port(0).connect(gather_for_const.in_port(0))

        equal_node = create_op_with_const_inputs(
            graph, Equal, {1: int64_array(1)},
            {'name': node_name + '/ConstOne'})
        gather.out_port(0).connect(equal_node.in_port(0))

        select_node = Select(graph, {
            'name': node_name + '/Select',
            'auto_broadcast': 'numpy'
        }).create_node([equal_node, gather_for_const, gather])

        const.out_port(0).connect(shapeof_node.in_port(0))

        range_node = create_op_with_const_inputs(
            graph, Range, {
                0: np.array(0, dtype=value.dtype),
                2: np.array(1, dtype=value.dtype)
            }, {
                'name': const_name + '/Range',
                'dtype': value.dtype
            })
        select_node.out_port(0).connect(range_node.in_port(1))

        node.in_port(1).get_connection().add_destination(gather.in_port(0))

        node.in_port(0).get_connection().set_source(range_node.out_port(0))

        if one_dims.size:
            unsqueeze = create_op_node_with_second_input(
                graph, Unsqueeze, one_dims,
                {'name': const_name + '/KeepShape'})
            range_node.out_port(0).get_connection().insert_node(unsqueeze)
            rename_nodes([(const, const_name + '/ToBeDeleted'),
                          (unsqueeze, const_name)])
        else:
            rename_nodes([(const, const_name + '/ToBeDeleted'),
                          (range_node, const_name)])
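
The pass only fires for constants holding a plain arange over a single non-unit dimension of moderate size. A short NumPy sketch of the detection predicate on a hypothetical constant:

import numpy as np

shape = np.array([1, 16, 1])
value = np.arange(np.prod(shape)).reshape(shape)  # hypothetical Const payload

non_one_dims = np.argwhere(shape != 1).flatten()
is_candidate = (non_one_dims.size == 1
                and 5 < np.prod(shape) < 500
                and np.array_equal(np.arange(0, np.prod(shape), 1).reshape(shape), value))
assert is_candidate  # such a Const is regenerated as Range(0, dim, 1) (+ Unsqueeze)
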
Example #8
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']

        if node.name == 'iteration_number_out':
            return

        # calculate the context length at which the inference state becomes meaningful
        inputs = graph.get_op_nodes(op='Parameter')

        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check whether an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in',
                    dict(op='Memory',
                         index=1,
                         shape=int64_array([context_len]))),
                   ('mem_in_data', dict()),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()),
                   ('mem_out',
                    dict(op='Memory',
                         index=0,
                         shape=int64_array([context_len]))),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            mem_out = Memory(
                graph, {
                    'name': 'iteration_number',
                    'size': 2,
                    'index': 1,
                    'id': 'iteration_' + node.name,
                    'shape': int64_array([context_len]),
                    'dst_type': np.int32
                }).create_node()
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Memory(
                graph, {
                    'name': 'iteration_number_out',
                    'size': 2,
                    'index': 0,
                    'id': 'iteration_' + node.name,
                    'shape': int64_array([context_len])
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        select_node.in_port(0).connect(input_port)
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
Пример #9
0
    def replace_sub_graph(self, graph: Graph, match: Dict[str, Node]):
        node = match['op']
        name = node.name

        # Zero point nudging: scale computation
        f_min = node.in_port(1).get_source()
        node.in_port(1).disconnect()
        f_max = node.in_port(2).get_source()
        node.in_port(2).disconnect()

        f_diff = Sub(graph, {'name': name + '/float_range'}).create_node()
        f_max.connect(f_diff.in_port(0))
        f_min.connect(f_diff.in_port(1))

        quant_min_value = int(node.narrow_range)
        quant_max_value = 2 ** node.num_bits - 1
        i_diff = Const(graph, dict(name=name + '/int_range', value=quant_max_value - quant_min_value)).create_node()

        scale = Div(graph, {'name': name + '/scale'}).create_node()
        f_diff.out_port(0).connect(scale.in_port(0))
        i_diff.out_port(0).connect(scale.in_port(1))

        # Zero point nudging: zero point derived from the min
        descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node()
        f_min.connect(descaled_min.in_port(0))
        scale.out_port(0).connect(descaled_min.in_port(1))

        zero_point_from_min = Sub(graph, {'name': name + '/zero_point_from_min'}).create_node()
        quant_min = Const(graph, {'value': quant_min_value, 'name': name + '/quant_min'}).create_node()
        quant_min.out_port(0).connect(zero_point_from_min.in_port(0))
        descaled_min.out_port(0).connect(zero_point_from_min.in_port(1))

        # Zero point nudging: computing the nudged zero point
        zp_lesser_q_mi = Less(graph, {'name': name + '/zero_point_from_min_less_quant_min'}).create_node()
        zero_point_from_min.out_port(0).connect(zp_lesser_q_mi.in_port(0))
        quant_min.out_port(0).connect(zp_lesser_q_mi.in_port(1))

        zp_greater_q_ma = Greater(graph, {'name': name + '/zero_point_from_min_greater_quant_max'}).create_node()
        zero_point_from_min.out_port(0).connect(zp_greater_q_ma.in_port(0))
        quant_max = Const(graph, {'value': quant_max_value, 'name': name + '/quant_max'}).create_node()
        quant_max.out_port(0).connect(zp_greater_q_ma.in_port(1))

        rounded_zero_point_from_min = Round(graph, {'name': name + '/zero_point_from_min_rounding'}).create_node()
        zero_point_from_min.out_port(0).connect(rounded_zero_point_from_min.in_port(0))

        nudged_zero_point = Select(graph, {'name': name + '/nudging_zp_1_select_less_condition'}).create_node()
        greater_condition = Select(graph, {'name': name + '/nudging_zp_2_select_greater_condition'}).create_node()

        greater_condition.in_port(0).connect(zp_greater_q_ma.out_port(0))
        greater_condition.in_port(1).connect(quant_max.out_port(0))
        greater_condition.in_port(2).connect(rounded_zero_point_from_min.out_port(0))

        nudged_zero_point.in_port(0).connect(zp_lesser_q_mi.out_port(0))
        nudged_zero_point.in_port(1).connect(quant_min.out_port(0))
        nudged_zero_point.in_port(2).connect(greater_condition.out_port(0))

        nudged_i_min = Sub(graph, {'name': name + '/nudged_i_min'}).create_node()
        quant_min.out_port(0).connect(nudged_i_min.in_port(0))
        nudged_zero_point.out_port(0).connect(nudged_i_min.in_port(1))

        nudged_i_max = Sub(graph, {'name': name + '/nudged_i_max'}).create_node()
        quant_max.out_port(0).connect(nudged_i_max.in_port(0))
        nudged_zero_point.out_port(0).connect(nudged_i_max.in_port(1))

        nudged_min = Mul(graph, {'name': name + '/nudged_min'}).create_node()
        nudged_i_min.out_port(0).connect(nudged_min.in_port(0))
        scale.out_port(0).connect(nudged_min.in_port(1))

        nudged_max = Mul(graph, {'name': name + '/nudged_max'}).create_node()
        nudged_i_max.out_port(0).connect(nudged_max.in_port(0))
        scale.out_port(0).connect(nudged_max.in_port(1))

        nudged_min.out_port(0).connect(node.in_port(1))
        nudged_max.out_port(0).connect(node.in_port(2))

        # FakeQuantize has 5 inputs, while the TensorFlow operation has only 3
        node.add_input_port(3, skip_if_exist=True)
        node.add_input_port(4, skip_if_exist=True)

        node.in_port(3).connect(nudged_min.out_port(0))
        node.in_port(4).connect(nudged_max.out_port(0))

        FakeQuantize.update_node_stat(node, {'levels': node['levels']})
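
The sub-graph above reproduces TensorFlow's FakeQuantWithMinMaxVars min/max nudging in graph form. A scalar NumPy sketch of the same arithmetic on made-up range values:

import numpy as np

num_bits, narrow_range = 8, False
f_min, f_max = -0.3, 1.0

quant_min = int(narrow_range)   # 0 or 1
quant_max = 2 ** num_bits - 1   # 255 for 8 bits
scale = (f_max - f_min) / (quant_max - quant_min)

zp_from_min = quant_min - f_min / scale
# clamp to [quant_min, quant_max], otherwise round: the two chained Selects
nudged_zp = np.where(zp_from_min < quant_min, quant_min,
                     np.where(zp_from_min > quant_max, quant_max,
                              np.round(zp_from_min)))
nudged_min = (quant_min - nudged_zp) * scale  # new quantization range low
nudged_max = (quant_max - nudged_zp) * scale  # new quantization range high
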
Example #10
    def dequantize_data(fake_quantize: Node, dst_type: type,
                        quantized_type: type) -> Node:
        graph = fake_quantize.graph
        quantized_data = fake_quantize.in_port(0).get_source().node
        name = fake_quantize.soft_get('name', fake_quantize.id)

        assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
            "Weights aren't compressed as expected for node {}".format(name)

        dequantizing_cast = Cast(
            graph,
            dict(name=quantized_data.name +
                 "/to_{}".format(np_data_type_to_destination_type(dst_type)),
                 dst_type=dst_type,
                 stop_value_propagation=True)).create_node()
        fake_quantize.in_port(0).get_connection().set_destination(
            dequantizing_cast.in_port(0))

        # limits of dequantize
        in_low = fake_quantize.in_port(1).get_source()
        in_high = fake_quantize.in_port(2).get_source()
        out_low = fake_quantize.in_port(3).get_source()
        out_high = fake_quantize.in_port(4).get_source()

        # scale calculation
        output_range = Sub(graph, {
            'name': name + '/output_range'
        }).create_node()
        output_range.in_port(0).connect(out_high)
        output_range.in_port(1).connect(out_low)

        input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
        input_range.in_port(0).connect(in_high)
        input_range.in_port(1).connect(in_low)

        scale = Div(graph, {'name': name + '/scale'}).create_node()
        scale.in_port(0).connect(output_range.out_port(0))
        scale.in_port(1).connect(input_range.out_port(0))

        # shift calculation
        descaled_output_low = Div(graph, {
            'name': name + '/descaled_output_low'
        }).create_node()
        descaled_output_low.in_port(0).connect(out_low)
        descaled_output_low.in_port(1).connect(scale.out_port(0))

        shift = Sub(graph, {'name': name + '/shift'}).create_node()
        shift.in_port(0).connect(in_low)
        shift.in_port(1).connect(descaled_output_low.out_port(0))

        zero = Const(graph, {
            'name': name + '/zero',
            'value': np.array(0, dtype=dst_type)
        }).create_node()
        scale_eq_zero = Equal(graph, {
            'name': name + '/scale_eq_zero'
        }).create_node()
        scale_eq_zero.in_port(0).connect(scale.out_port(0))
        scale_eq_zero.in_port(1).connect(zero.out_port(0))

        zero_point = Select(graph, {
            'name': name + '/zero_point'
        }).create_node()
        zero_point.in_port(0).connect(scale_eq_zero.out_port(0))
        zero_point.in_port(1).connect(zero.out_port(0))
        zero_point.in_port(2).connect(shift.out_port(0))

        # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
        sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
        sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
        sub_zp.in_port(1).connect(zero_point.out_port(0))

        mul_scale = Mul(graph, {
            'name': name + '/multiply_by_scale'
        }).create_node()
        mul_scale.in_port(0).connect(sub_zp.out_port(0))
        mul_scale.in_port(1).connect(scale.out_port(0))

        fake_quantize.out_port(0).get_connection().set_source(
            mul_scale.out_port(0))

        graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0).id])
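
The inserted sub-graph evaluates the affine dequantization DeQuantize(x) = (x - zero_point) * scale from the FakeQuantize limits. A scalar NumPy sketch with made-up 8-bit limits (not part of the transformation):

import numpy as np

in_low, in_high = 0.0, 255.0   # quantized range
out_low, out_high = -1.0, 1.0  # original float range

scale = (out_high - out_low) / (in_high - in_low)
# the Select guards against a degenerate zero scale
zero_point = np.where(scale == 0, 0.0, in_low - out_low / scale)

x_q = np.array([0.0, 128.0, 255.0])  # weights after the dequantizing Cast
x = (x_q - zero_point) * scale
assert np.allclose(x, [-1.0, 0.00392157, 1.0], atol=1e-6)
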
Example #11
    def replace_sub_graph(self, graph: Graph, match: dict):
        seq_len_tf = match['seq_len']
        transpose_tf = match['transpose']
        ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
        cast_tf = match['cast']
        ctc_loss_tf = match['ctc_loss']
        sparse_to_dense_tf = match['sparse_to_dense']

        output_sparse_to_dense_name = sparse_to_dense_tf.soft_get(
            'name', sparse_to_dense_tf.id)
        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get(
            'name', ctc_greedy_decoder_tf.id)

        log.debug(
            'Found CTCLossFrontReplacer pattern after {} with name {}'.format(
                ctc_greedy_decoder_tf.op, ctc_greedy_decoder_tf.name))

        # create a sequence-mask Parameter, a sub-graph converting the mask into sequence lengths,
        # and connect it to the consumers
        seq_len_tf_shape = seq_len_tf.soft_get('shape', None)
        if seq_len_tf_shape is None or len(seq_len_tf_shape) != 2:
            raise Error(
                'The sequence length that is the second input to the CTCGreedyDecoder node "{}"'
                ' must be specified in a mask format.'.format(
                    ctc_greedy_decoder_tf_name))
        log.error(
            'The format of input sequence length has been changed to a mask format',
            extra={'is_warning': True})
        seq_len_tf_type = seq_len_tf.soft_get('data_type', None)
        seq_len_tf_name = seq_len_tf.soft_get('name', seq_len_tf.id)
        seq_mask_placeholder = Parameter(
            graph, {
                'name': seq_len_tf_name,
                'shape': seq_len_tf_shape,
                'data_type': seq_len_tf_type
            }).create_node()
        reduce_to_seq_len_node = create_op_with_const_inputs(
            graph, ReduceSum, {1: np.array(1, dtype=np.int32)}, {
                'name': seq_len_tf_name + '/ReduceToSeqLen',
                'keep_dims': False
            })
        reduce_to_seq_len_node.in_port(0).connect(
            seq_mask_placeholder.out_port(0))
        seq_len_tf.out_port(0).get_connection().set_source(
            reduce_to_seq_len_node.out_port(0))

        cast_fp_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        casted_seq_mask_node = Cast(graph, {
            'name': seq_len_tf_name + '/CastToFP32',
            'dst_type': cast_fp_type
        }).create_node()
        casted_seq_mask_node.in_port(0).connect(
            seq_mask_placeholder.out_port(0))
        permuted_casted_seq_mask = create_op_with_const_inputs(
            graph, Transpose, {1: int64_array([1, 0])},
            {'name': seq_len_tf_name + '/Permute'})
        permuted_casted_seq_mask.in_port(0).connect(
            casted_seq_mask_node.out_port(0))
        rename_nodes([(seq_len_tf, seq_len_tf_name + '/AbandonedName'),
                      (seq_mask_placeholder, seq_len_tf_name)])

        # create CTCGreedyDecoder node and set mask node
        ctc_merge_repeated_i = ctc_greedy_decoder_tf.soft_get(
            'ctc_merge_repeated', ctc_greedy_decoder_tf.id)
        ctc_greedy_decoder = CTCGreedyDecoderOp(
            graph, {
                'name': output_sparse_to_dense_name,
                'ctc_merge_repeated': ctc_merge_repeated_i
            }).create_node()
        ctc_greedy_decoder.in_port(1).connect(
            permuted_casted_seq_mask.out_port(0))
        rename_nodes([(sparse_to_dense_tf,
                       output_sparse_to_dense_name + '/AbandonedName'),
                      (ctc_greedy_decoder, output_sparse_to_dense_name)])

        # create CTCLoss node and set attributes
        assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
            'The CTCLoss node "{}" lacks the "preprocess_collapse_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
            'The CTCLoss node "{}" lacks the "ctc_merge_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('unique'), \
            'The CTCLoss node "{}" lacks the "unique" attribute'.format(output_ctc_loss_name)
        preprocess_collapse_repeated = ctc_loss_tf.preprocess_collapse_repeated
        ctc_merge_repeated = ctc_loss_tf.ctc_merge_repeated
        unique = ctc_loss_tf.unique
        ctc_loss = CTCLoss(
            graph, {
                'name': output_ctc_loss_name,
                'preprocess_collapse_repeated': preprocess_collapse_repeated,
                'ctc_merge_repeated': ctc_merge_repeated,
                'unique': unique
            }).create_node()
        rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'),
                      (ctc_loss, output_ctc_loss_name)])

        # connect logits
        ctc_greedy_decoder_tf.in_port(0).get_connection().set_destination(
            ctc_greedy_decoder.in_port(0))
        ctc_loss.in_port(0).disconnect()
        transpose_tf.in_port(0).get_connection().add_destination(
            ctc_loss.in_port(0))

        # connect logit lengths
        ctc_greedy_decoder_tf.in_port(1).disconnect()
        ctc_loss.in_port(1).connect(reduce_to_seq_len_node.out_port(0))

        # connect labels to ctc_loss
        squeeze_op = create_op_with_const_inputs(graph, Squeeze,
                                                 {1: int64_array([2, 3])})
        cast_labels_op = Cast(
            graph, {
                'name': output_sparse_to_dense_name + '/CastLabels',
                'dst_type': np.int32
            }).create_node()
        squeeze_op.in_port(0).connect(ctc_greedy_decoder.out_port(0))
        cast_labels_op.in_port(0).connect(squeeze_op.out_port(0))
        ctc_loss.in_port(2).connect(cast_labels_op.out_port(0))

        # connect label lengths
        equal_op = create_op_with_const_inputs(
            graph, Equal, {1: np.array([-1], dtype=np.int32)},
            {'name': output_sparse_to_dense_name + '/Equal'})
        equal_op.in_port(0).connect(cast_labels_op.out_port(0))
        labels_shape_op = Shape(
            graph, {
                'name': output_sparse_to_dense_name + '/ShapeOf'
            }).create_node()
        labels_shape_op.in_port(0).connect(equal_op.out_port(0))
        broadcast_one = create_op_with_const_inputs(
            graph, Broadcast, {0: np.array([1], dtype=np.int32)}, {
                'mode': 'numpy',
                'name': output_sparse_to_dense_name + '/One'
            })
        broadcast_one.in_port(1).connect(labels_shape_op.out_port(0))
        broadcast_zero = create_op_with_const_inputs(
            graph, Broadcast, {0: np.array([0], dtype=np.int32)}, {
                'mode': 'numpy',
                'name': output_sparse_to_dense_name + '/Zero'
            })
        broadcast_zero.in_port(1).connect(labels_shape_op.out_port(0))

        select_node = Select(graph, {
            'name': output_sparse_to_dense_name + '/Select'
        }).create_node()
        select_node.in_port(0).connect(equal_op.out_port(0))
        select_node.in_port(1).connect(broadcast_zero.out_port(0))
        select_node.in_port(2).connect(broadcast_one.out_port(0))
        label_length_node = create_op_with_const_inputs(
            graph,
            ReduceSum, {1: int64_array([1])},
            op_attrs={
                'name': output_sparse_to_dense_name + '/LabelLength',
                'keep_dims': False
            })
        label_length_node.in_port(0).connect(select_node.out_port(0))
        ctc_loss.in_port(3).connect(label_length_node.out_port(0))

        # set source for output of new sub-graph and remove old nodes
        ctc_loss_tf.out_port(0).get_connection().set_source(
            ctc_loss.out_port(0))
        graph.remove_nodes_from([
            ctc_greedy_decoder_tf.id, ctc_loss_tf.id, cast_tf.id,
            sparse_to_dense_tf.id
        ])
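
The label-length branch counts non-padding entries per batch row: decoded labels are padded with -1, and the Equal/Broadcast/Select/ReduceSum chain turns that mask into lengths. A NumPy sketch on a made-up decoder output:

import numpy as np

# decoded labels padded with -1, one row per batch element
labels = np.array([[3, 7, 7, -1, -1],
                   [5, -1, -1, -1, -1]], dtype=np.int32)
# Equal(-1) -> Select(zeros, ones) -> ReduceSum(axis=1)
label_length = np.where(labels == -1, 0, 1).sum(axis=1)
assert np.array_equal(label_length, np.array([3, 1]))
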
Example #12
    def insert_select(graph: Graph, node: Node):
        context_len = node.frame_time + 1

        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
        zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check whether an appropriate iteration counter already exists
        existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
                                                               ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                               ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                                    offset=int64_array([1]),
                                                                                    dim=int64_array([context_len - 1]))),
                                                               ('crop_mem_in_data', dict()),
                                                               ('concat', dict(op='Concat', axis=1)),
                                                               ('concat_data', dict()),
                                                               ('const_1', dict(op='Const')),
                                                               ('const_1_data', dict()),
                                                               ('mem_out', dict(op='Assign')),
                                                               ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                                 offset=int64_array([0]),
                                                                                 dim=int64_array([1]))),
                                                               ('crop_out_data', dict()),
                                                               ('select', dict(op='Select'))
                                                               ],
                                                 edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                                                        ('crop_mem_in', 'crop_mem_in_data'),
                                                        ('crop_mem_in_data', 'concat', {'in': 0}),
                                                        ('const_1', 'const_1_data'),
                                                        ('const_1_data', 'concat', {'in': 1}),
                                                        ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                                                        ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                                                        ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
            mem_out = ReadValue(graph, {'name': 'iteration_number',
                                        'variable_id': 'iteration_' + node.name}).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                     'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
            concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            mem_in = Assign(graph, {'name': 'iteration_number_out',
                                    'variable_id': 'iteration_' + node.name}).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                    'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check if the value read from memory is 1:
        # if True, the context is gathered and the data should be saved to memory;
        # otherwise the data is still garbage and the initial memory state must not be changed
        cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)