def extract(cls, node):
    """Populate CTCGreedyDecoderSeqLen attributes from the ONNX node.

    The ONNX integer attribute ``merge_repeated`` (1 when absent) is stored on the node as a bool.
    """
    merge_repeated = onnx_attr(node, 'merge_repeated', 'i', default=1)
    CTCGreedyDecoderSeqLenOp.update_node_stat(node, {'merge_repeated': bool(merge_repeated)})
    return cls.enabled
def extract(cls, node):
    """Populate CTCGreedyDecoderSeqLen attributes from the TF NodeDef.

    ``merge_repeated`` comes from the boolean field of the NodeDef attribute of the same name.
    ``output_sparse_format`` is always set to True.
    """
    pb_attrs = node.pb.attr
    attrs = dict(
        merge_repeated=bool(pb_attrs['merge_repeated'].b),
        output_sparse_format=True,  # Special argument for TF CTCGreedyDecoder replacement transformations
    )
    CTCGreedyDecoderSeqLenOp.update_node_stat(node, attrs)
    return cls.enabled
def test_infer1(self):
    # Run CTCGreedyDecoderSeqLen shape inference on the graph described by
    # nodes_attributes/edges1/inputs1 and check the shape recorded on 'output1'.
    graph = build_graph(nodes_attributes, edges1, inputs1)
    decoder = Node(graph, 'ctcgreedydecoder_node')
    CTCGreedyDecoderSeqLenOp.infer(decoder)

    expected_shape = int64_array([4, 100])
    inferred_shape = graph.node['output1']['shape']
    message = 'shapes do not match expected: {} and given: {}'.format(expected_shape, inferred_shape)
    self.assertTrue(np.array_equal(expected_shape, inferred_shape), message)
def replace_ctc_greedy_decoder(graph: Graph, match: dict):
    """Replace a matched TF CTCGreedyDecoder / Cast / SparseToDense group with one CTCGreedyDecoderSeqLen node.

    The TF decoder's data input is routed through a new Transpose (perm [1, 0, 2]) into the new decoder.
    The TF decoder's input 1 feeds the new decoder's input 1. Every consumer of the SparseToDense output
    is re-attached to output 0 of the new decoder. The new decoder takes over the SparseToDense name.
    The matched TF decoder, Cast and SparseToDense nodes are then removed.

    :param graph: graph that is modified in place
    :param match: pattern match holding the 'decoder', 'cast' and 'sparse_to_dense' nodes
    """
    tf_decoder = match['decoder']
    cast_node = match['cast']
    dense_node = match['sparse_to_dense']
    dense_name = dense_node.soft_get('name', dense_node.id)
    tf_decoder_name = tf_decoder.soft_get('name', tf_decoder.id)

    # Swap the first two axes of the decoder data: [T, N, C] -> [N, T, C], the layout
    # CTCGreedyDecoderSeqLen works with.
    permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                          {'name': tf_decoder_name + '/ctc_data_permute'})

    assert tf_decoder.has_valid('merge_repeated'), \
        'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(tf_decoder_name)

    tf_decoder.in_port(0).get_source().connect(permute.in_port(0))

    seq_len_decoder = CTCGreedyDecoderSeqLenOp(
        graph, {'name': dense_name, 'merge_repeated': tf_decoder.merge_repeated}).create_node()
    # The new decoder takes the SparseToDense name; the old node is parked under '/AbandonedName'.
    rename_nodes([(dense_node, dense_name + '/AbandonedName'), (seq_len_decoder, dense_name)])

    seq_len_decoder.in_port(0).connect(permute.out_port(0))
    tf_decoder.in_port(1).get_source().connect(seq_len_decoder.in_port(1))

    # Former SparseToDense consumers now read the first output of the new decoder.
    dense_node.out_port(0).get_connection().set_source(seq_len_decoder.out_port(0))

    graph.remove_nodes_from([dense_node.id, cast_node.id, tf_decoder.id])
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched TF CTCGreedyDecoder + CTCLoss sub-graph with CTCGreedyDecoderSeqLen and CTCLoss.

    A new Transpose (perm [1, 0, 2]) is applied to the output of the matched 'transpose' node and feeds the
    new CTCGreedyDecoderSeqLen. The source of the TF decoder's input 1 feeds input 1 of both new ops.
    Consumers of SparseToDense are moved to output 0 of the new decoder, and consumers of the TF CTCLoss
    are moved to the new CTCLoss. The matched SparseToDense, Cast, CTCLoss and CTCGreedyDecoder nodes are
    then removed.

    :param graph: graph that is modified in place
    :param match: pattern match with the 'transpose', 'ctc_greedy_decoder', 'cast', 'ctc_loss'
        and 'sparse_to_dense' nodes
    """
    transpose_tf = match['transpose']
    ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
    cast_tf = match['cast']
    ctc_loss_tf = match['ctc_loss']
    sparse_to_dense_tf = match['sparse_to_dense']

    output_sparse_to_dense_name = sparse_to_dense_tf.soft_get('name', sparse_to_dense_tf.id)
    ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)
    output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)

    # Validate every required attribute before touching the graph, so a failed check
    # cannot leave a half-rewritten sub-graph behind.
    assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
        'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)
    for attr_name in ('preprocess_collapse_repeated', 'ctc_merge_repeated', 'unique'):
        assert ctc_loss_tf.has_valid(attr_name), \
            'The CTCLoss node "{}" misses "{}" attribute'.format(output_ctc_loss_name, attr_name)

    # Swap the first two axes of the data (perm [1, 0, 2]); presumably time-major [T, N, C] -> [N, T, C]
    # as CTCGreedyDecoderSeqLen expects -- TODO confirm against the matched pattern.
    # Use the soft_get name: a bare `.name` access fails on a node without a 'name' attribute.
    ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                                   {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})
    ctc_data_permute.in_port(0).connect(transpose_tf.out_port(0))

    merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
    ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(graph, {'name': output_sparse_to_dense_name,
                                                          'merge_repeated': merge_repeated_tf}).create_node()
    rename_nodes([(sparse_to_dense_tf, output_sparse_to_dense_name + '/AbandonedName'),
                  (ctc_greedy_decoder, output_sparse_to_dense_name)])

    # Source of the TF decoder's input 1; shared by the new decoder and the new CTCLoss.
    seq_len_source = ctc_greedy_decoder_tf.in_port(1).get_connection().get_source()
    ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
    ctc_greedy_decoder.in_port(1).connect(seq_len_source)

    # set output of the new sub-graph as a source for SparseToDense consumer
    # (without this, any consumer besides the TF CTCLoss would be left dangling once SparseToDense is removed)
    sparse_to_dense_tf.out_port(0).get_connection().set_source(ctc_greedy_decoder.out_port(0))

    ctc_loss = CTCLoss(graph, {'name': output_ctc_loss_name,
                               'preprocess_collapse_repeated': ctc_loss_tf.preprocess_collapse_repeated,
                               'ctc_merge_repeated': ctc_loss_tf.ctc_merge_repeated,
                               'unique': ctc_loss_tf.unique}).create_node()
    rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'), (ctc_loss, output_ctc_loss_name)])
    # Consumers of the TF CTCLoss now read the new CTCLoss.
    ctc_loss_tf.out_port(0).get_connection().set_source(ctc_loss.out_port(0))

    # Input 0: the permuted data when the TF CTCLoss had time-major logits, otherwise the Transpose output.
    if ctc_loss_tf.logits_time_major:
        ctc_loss.in_port(0).connect(ctc_data_permute.out_port(0))
    else:
        ctc_loss.in_port(0).connect(transpose_tf.out_port(0))
    ctc_loss.in_port(1).connect(seq_len_source)
    ctc_loss.in_port(2).connect(ctc_greedy_decoder.out_port(0))
    ctc_loss.in_port(3).connect(ctc_greedy_decoder.out_port(1))

    # remove no longer needed nodes
    graph.remove_nodes_from([sparse_to_dense_tf.id, cast_tf.id, ctc_loss_tf.id, ctc_greedy_decoder_tf.id])
def extract(cls, node):
    """Populate CTCGreedyDecoderSeqLen attributes from the TF NodeDef.

    Only ``merge_repeated`` is extracted, read from the boolean field of the NodeDef attribute.
    """
    merge_repeated = node.pb.attr['merge_repeated'].b
    CTCGreedyDecoderSeqLenOp.update_node_stat(node, dict(merge_repeated=bool(merge_repeated)))
    return cls.enabled