Пример #1
0
 def extract(cls, node):
     """Populate CTCGreedyDecoderSeqLen attributes on *node* from its ONNX attributes.

     The integer ``merge_repeated`` attribute is read through ``onnx_attr``
     (falling back to 1 via its ``default``) and stored as a bool.

     :return: ``cls.enabled``.
     """
     merge_repeated_flag = onnx_attr(node, 'merge_repeated', 'i', default=1)
     CTCGreedyDecoderSeqLenOp.update_node_stat(node, {'merge_repeated': bool(merge_repeated_flag)})
     return cls.enabled
Пример #2
0
 def extract(cls, node):
     """Populate CTCGreedyDecoderSeqLen attributes on *node* from its TF protobuf.

     ``merge_repeated`` is taken from the boolean field of
     ``node.pb.attr['merge_repeated']``. ``output_sparse_format`` is always set
     to True; it is a marker for the TF CTCGreedyDecoder replacement
     transformations.

     :return: ``cls.enabled``.
     """
     merge_repeated = bool(node.pb.attr['merge_repeated'].b)
     CTCGreedyDecoderSeqLenOp.update_node_stat(node, dict(
         merge_repeated=merge_repeated,
         # Special argument for TF CTCGreedyDecoder replacement transformations
         output_sparse_format=True,
     ))
     return cls.enabled
    def test_infer1(self):
        # Shape inference of CTCGreedyDecoderSeqLenOp must give 'output1' the shape [4, 100].
        graph = build_graph(nodes_attributes, edges1, inputs1)
        decoder_node = Node(graph, 'ctcgreedydecoder_node')
        CTCGreedyDecoderSeqLenOp.infer(decoder_node)

        expected_shape = int64_array([4, 100])
        actual_shape = graph.node['output1']['shape']
        self.assertTrue(np.array_equal(expected_shape, actual_shape),
                        'shapes do not match expected: {} and given: {}'.format(expected_shape, actual_shape))
def replace_ctc_greedy_decoder(graph: Graph, match: dict):
    """Replace a matched TF CTCGreedyDecoder/Cast/SparseToDense chain with CTCGreedyDecoderSeqLen.

    :param graph: graph transformed in place.
    :param match: pattern match providing the 'decoder', 'cast' and
        'sparse_to_dense' nodes.

    The new CTCGreedyDecoderSeqLen node takes over the SparseToDense node's name
    (the old node is renamed with an '/AbandonedName' suffix) and its output
    consumers. Its first input is the TF decoder's first input passed through a
    Transpose with order [1, 0, 2]. Its second input is the TF decoder's second
    input. The three matched nodes are then removed from the graph.

    :raises AssertionError: if the TF decoder has no valid 'merge_repeated'
        attribute. The check runs before the graph is modified.
    """
    ctc_greedy_decoder_tf = match['decoder']
    cast = match['cast']
    sparse_to_dense = match['sparse_to_dense']
    sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)
    ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get(
        'name', ctc_greedy_decoder_tf.id)

    # Validate before touching the graph so a failed check leaves no orphan nodes behind.
    assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
        'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)
    merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated

    # CTCGreedyDecoderSeqLen expects its input as [N, T, C], so the TF decoder's
    # [T, N, C] input is transposed by swapping its first two axes.
    ctc_data_permute = create_op_with_const_inputs(
        graph, Transpose, {1: int64_array([1, 0, 2])},
        {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})
    ctc_greedy_decoder_tf.in_port(0).get_source().connect(
        ctc_data_permute.in_port(0))

    ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(
        graph, {
            'name': sparse_to_dense_name,
            'merge_repeated': merge_repeated_tf
        }).create_node()
    # The new node inherits SparseToDense's name; the old node is renamed out of the way.
    rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'),
                  (ctc_greedy_decoder, sparse_to_dense_name)])
    ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
    ctc_greedy_decoder_tf.in_port(1).get_source().connect(
        ctc_greedy_decoder.in_port(1))

    # Set output of the new sub-graph as a source for SparseToDense consumer
    sparse_to_dense.out_port(0).get_connection().set_source(
        ctc_greedy_decoder.out_port(0))

    # Remove no longer needed nodes
    graph.remove_nodes_from(
        [sparse_to_dense.id, cast.id, ctc_greedy_decoder_tf.id])
Пример #5
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace a matched TF CTCGreedyDecoder + CTCLoss sub-graph with CTCGreedyDecoderSeqLen and CTCLoss.

        :param graph: graph transformed in place.
        :param match: pattern match providing the 'transpose', 'ctc_greedy_decoder',
            'cast', 'ctc_loss' and 'sparse_to_dense' nodes.

        The output of 'transpose' goes through a new Transpose (order [1, 0, 2]) into
        a CTCGreedyDecoderSeqLen node. That node takes over the SparseToDense node's name.

        A new CTCLoss node takes over the TF CTCLoss node's name and output consumers.
        Its inputs are:
          0. the logits: the permuted data when the TF node has ``logits_time_major``
             set, otherwise the 'transpose' output;
          1. the TF decoder's second input;
          2-3. the two outputs of the new decoder.

        The matched SparseToDense, Cast, CTCLoss and CTCGreedyDecoder nodes are then removed.

        :raises AssertionError: if 'merge_repeated' is missing on the TF decoder, or if
            'preprocess_collapse_repeated', 'ctc_merge_repeated' or 'unique' is missing
            on the TF CTCLoss. All checks run before the graph is modified.
        """
        transpose_tf = match['transpose']
        ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
        cast_tf = match['cast']
        ctc_loss_tf = match['ctc_loss']
        sparse_to_dense_tf = match['sparse_to_dense']
        output_sparse_to_dense_name = sparse_to_dense_tf.soft_get(
            'name', sparse_to_dense_tf.id)
        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get(
            'name', ctc_greedy_decoder_tf.id)
        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)

        # Validate and collect every required attribute before modifying the graph,
        # so a failed check does not leave a half-rewired graph behind.
        assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
            'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)
        merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
        ctc_loss_attrs = {'name': output_ctc_loss_name}
        for attr_name in ('preprocess_collapse_repeated', 'ctc_merge_repeated', 'unique'):
            assert ctc_loss_tf.has_valid(attr_name), \
                'The CTCLoss node "{}" misses "{}" attribute'.format(output_ctc_loss_name, attr_name)
            ctc_loss_attrs[attr_name] = getattr(ctc_loss_tf, attr_name)

        # Swap the first two axes of the data for CTCGreedyDecoderSeqLen
        # (presumably time-major [T, N, C] -> batch-major [N, T, C]).
        ctc_data_permute = create_op_with_const_inputs(
            graph, Transpose, {1: int64_array([1, 0, 2])},
            {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})
        ctc_data_permute.in_port(0).connect(transpose_tf.out_port(0))

        ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(
            graph, {
                'name': output_sparse_to_dense_name,
                'merge_repeated': merge_repeated_tf
            }).create_node()
        rename_nodes([(sparse_to_dense_tf,
                       output_sparse_to_dense_name + '/AbandonedName'),
                      (ctc_greedy_decoder, output_sparse_to_dense_name)])
        ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
        # The TF decoder's second input (sequence_length in TF's CTCGreedyDecoder)
        # feeds both new ops; look its producer up once.
        seq_len_source = ctc_greedy_decoder_tf.in_port(1).get_connection().get_source()
        ctc_greedy_decoder.in_port(1).connect(seq_len_source)

        # The new CTCLoss takes over the TF CTCLoss node's name and output consumers.
        ctc_loss = CTCLoss(graph, ctc_loss_attrs).create_node()
        rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'),
                      (ctc_loss, output_ctc_loss_name)])
        ctc_loss_tf.out_port(0).get_connection().set_source(
            ctc_loss.out_port(0))
        if ctc_loss_tf.logits_time_major:
            ctc_loss.in_port(0).connect(ctc_data_permute.out_port(0))
        else:
            ctc_loss.in_port(0).connect(transpose_tf.out_port(0))
        ctc_loss.in_port(1).connect(seq_len_source)
        ctc_loss.in_port(2).connect(ctc_greedy_decoder.out_port(0))
        ctc_loss.in_port(3).connect(ctc_greedy_decoder.out_port(1))

        # Remove no longer needed nodes
        graph.remove_nodes_from([
            sparse_to_dense_tf.id, cast_tf.id, ctc_loss_tf.id,
            ctc_greedy_decoder_tf.id
        ])