Beispiel #1
0
 def extract(cls, node):
     """Extract the ``merge_repeated`` attribute of an ONNX node.

     The attribute is requested through ``onnx_attr`` with the ``'i'`` type code
     and a default of 1, converted to ``bool`` and stored on the node via
     ``CTCGreedyDecoderSeqLenOp.update_node_stat``.

     :param node: node whose attributes are being extracted
     :return: value of ``cls.enabled``
     """
     merge_repeated = bool(onnx_attr(node, 'merge_repeated', 'i', default=1))
     CTCGreedyDecoderSeqLenOp.update_node_stat(node, {'merge_repeated': merge_repeated})
     return cls.enabled
 def extract(cls, node):
     """Extract CTCGreedyDecoder attributes from the node's protobuf definition.

     Reads the boolean ``merge_repeated`` attribute from ``node.pb`` and additionally
     sets ``output_sparse_format`` to ``True`` before updating the node through
     ``CTCGreedyDecoderSeqLenOp.update_node_stat``.

     :param node: node whose attributes are being extracted
     :return: value of ``cls.enabled``
     """
     merge_repeated = bool(node.pb.attr['merge_repeated'].b)
     node_attrs = dict(
         merge_repeated=merge_repeated,
         # Special argument for TF CTCGreedyDecoder replacement transformations
         output_sparse_format=True,
     )
     CTCGreedyDecoderSeqLenOp.update_node_stat(node, node_attrs)
     return cls.enabled
    def test_infer1(self):
        """Run ``CTCGreedyDecoderSeqLenOp.infer`` and check the shape stored on 'output1'."""
        graph = build_graph(nodes_attributes, edges1, inputs1)
        CTCGreedyDecoderSeqLenOp.infer(Node(graph, 'ctcgreedydecoder_node'))

        # reference shape versus the shape written to the graph by infer()
        expected_shape = int64_array([4, 100])
        actual_shape = graph.node['output1']['shape']

        shapes_equal = np.array_equal(expected_shape, actual_shape)
        failure_message = 'shapes do not match expected: {} and given: {}'.format(expected_shape, actual_shape)
        self.assertTrue(shapes_equal, failure_message)
Beispiel #4
0
def replace_ctc_greedy_decoder(graph: Graph, match: dict):
    """Replace the matched 'decoder', 'cast' and 'sparse_to_dense' nodes with one CTCGreedyDecoderSeqLen node.

    A Transpose with order [1, 0, 2] is inserted in front of the new decoder, the new decoder takes over
    the name of the SparseToDense node and its consumers, and the three matched nodes are removed.

    :param graph: graph being transformed
    :param match: matched nodes; must contain the 'decoder', 'cast' and 'sparse_to_dense' keys
    :raises AssertionError: if the matched decoder has no valid 'merge_repeated' attribute
    """
    tf_decoder = match['decoder']
    cast_node = match['cast']
    dense_node = match['sparse_to_dense']
    dense_name = dense_node.soft_get('name', dense_node.id)
    tf_decoder_name = tf_decoder.soft_get('name', tf_decoder.id)

    # The CTCGreedyDecoderSeqLen op expects [N, T, C] input, so the [T, N, C] data
    # is transposed to normalize the layout.
    permute_order = {1: int64_array([1, 0, 2])}
    permute_attrs = {'name': tf_decoder_name + '/ctc_data_permute'}
    data_permute = create_op_with_const_inputs(graph, Transpose, permute_order, permute_attrs)

    assert tf_decoder.has_valid('merge_repeated'), \
        'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(tf_decoder_name)

    # feed the Transpose from whatever produced input 0 of the original decoder
    data_source = tf_decoder.in_port(0).get_source()
    data_source.connect(data_permute.in_port(0))

    merge_repeated = tf_decoder.merge_repeated
    new_decoder_attrs = {'name': dense_name, 'merge_repeated': merge_repeated}
    new_decoder = CTCGreedyDecoderSeqLenOp(graph, new_decoder_attrs).create_node()

    # the new decoder inherits the SparseToDense name; the old node gets a throw-away one
    rename_nodes([
        (dense_node, dense_name + '/AbandonedName'),
        (new_decoder, dense_name),
    ])

    new_decoder.in_port(0).connect(data_permute.out_port(0))
    second_input_source = tf_decoder.in_port(1).get_source()
    second_input_source.connect(new_decoder.in_port(1))

    # consumers of SparseToDense now read from the first output of the new decoder
    dense_connection = dense_node.out_port(0).get_connection()
    dense_connection.set_source(new_decoder.out_port(0))

    # drop the nodes that are no longer needed
    obsolete_ids = [dense_node.id, cast_node.id, tf_decoder.id]
    graph.remove_nodes_from(obsolete_ids)
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace a matched TF CTCGreedyDecoder + CTCLoss sub-graph with CTCGreedyDecoderSeqLen and CTCLoss nodes.

        A Transpose with order [1, 0, 2] is placed after the matched 'transpose' node and feeds a new
        CTCGreedyDecoderSeqLen node, which takes over the name of the matched SparseToDense node. A new CTCLoss
        node takes over the name and the output consumers of the matched CTCLoss node and is wired as follows:
        input 0 - the permuted data if ``logits_time_major`` is set on the original CTCLoss, otherwise the output
        of the matched 'transpose' node; input 1 - the source of input 1 of the original decoder;
        inputs 2 and 3 - outputs 0 and 1 of the new decoder. The matched SparseToDense, Cast, CTCLoss and
        CTCGreedyDecoder nodes are removed afterwards.

        :param graph: graph being transformed
        :param match: matched nodes; must contain the 'transpose', 'ctc_greedy_decoder', 'cast', 'ctc_loss'
                      and 'sparse_to_dense' keys
        :raises AssertionError: if 'merge_repeated' is missing on the decoder, or 'preprocess_collapse_repeated',
                                'ctc_merge_repeated' or 'unique' is missing on the CTCLoss node
        """
        transpose_tf = match['transpose']
        ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
        cast_tf = match['cast']
        ctc_loss_tf = match['ctc_loss']
        sparse_to_dense_tf = match['sparse_to_dense']
        output_sparse_to_dense_name = sparse_to_dense_tf.soft_get('name', sparse_to_dense_tf.id)

        # Resolve the decoder name once, falling back to the node id, so that a node without
        # a 'name' attribute does not fail when the Transpose name is built below.
        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)

        ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                                       {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})
        ctc_data_permute.in_port(0).connect(transpose_tf.out_port(0))

        assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
            'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)
        merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
        ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(graph, {'name': output_sparse_to_dense_name,
                                                              'merge_repeated': merge_repeated_tf}).create_node()
        # the new decoder inherits the SparseToDense name; the old node gets a throw-away one
        rename_nodes([(sparse_to_dense_tf, output_sparse_to_dense_name + '/AbandonedName'),
                      (ctc_greedy_decoder, output_sparse_to_dense_name)])
        ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
        ctc_greedy_decoder.in_port(1).connect(ctc_greedy_decoder_tf.in_port(1).get_connection().get_source())

        # create the new CTCLoss node with the attributes of the original one
        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
        assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
            'The CTCLoss node "{}" misses "preprocess_collapse_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
            'The CTCLoss node "{}" misses "ctc_merge_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('unique'), \
            'The CTCLoss node "{}" misses "unique" attribute'.format(output_ctc_loss_name)
        preprocess_collapse_repeated = ctc_loss_tf.preprocess_collapse_repeated
        ctc_merge_repeated = ctc_loss_tf.ctc_merge_repeated
        unique = ctc_loss_tf.unique
        ctc_loss = CTCLoss(graph, {'name': output_ctc_loss_name,
                                   'preprocess_collapse_repeated': preprocess_collapse_repeated,
                                   'ctc_merge_repeated': ctc_merge_repeated,
                                   'unique': unique}).create_node()
        rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'), (ctc_loss, output_ctc_loss_name)])

        # consumers of the original CTCLoss output now read from the new CTCLoss node
        ctc_loss_tf.out_port(0).get_connection().set_source(ctc_loss.out_port(0))

        # pick the data input of the new CTCLoss depending on the original logits_time_major flag
        if ctc_loss_tf.logits_time_major:
            ctc_loss.in_port(0).connect(ctc_data_permute.out_port(0))
        else:
            ctc_loss.in_port(0).connect(transpose_tf.out_port(0))
        ctc_loss.in_port(1).connect(ctc_greedy_decoder_tf.in_port(1).get_connection().get_source())
        ctc_loss.in_port(2).connect(ctc_greedy_decoder.out_port(0))
        ctc_loss.in_port(3).connect(ctc_greedy_decoder.out_port(1))

        # remove no longer needed nodes
        graph.remove_nodes_from([sparse_to_dense_tf.id, cast_tf.id, ctc_loss_tf.id, ctc_greedy_decoder_tf.id])
 def extract(cls, node):
     """Extract the ``merge_repeated`` attribute from the node's protobuf definition.

     The boolean field of ``node.pb.attr['merge_repeated']`` is converted to ``bool``
     and stored on the node via ``CTCGreedyDecoderSeqLenOp.update_node_stat``.

     :param node: node whose attributes are being extracted
     :return: value of ``cls.enabled``
     """
     merge_repeated = bool(node.pb.attr['merge_repeated'].b)
     CTCGreedyDecoderSeqLenOp.update_node_stat(node, {'merge_repeated': merge_repeated})
     return cls.enabled