    def replace_sub_graph(self, graph: Graph, match: dict):
        transpose_tf = match['transpose']
        ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
        cast_tf = match['cast']
        ctc_loss_tf = match['ctc_loss']
        sparse_to_dense_tf = match['sparse_to_dense']
        output_sparse_to_dense_name = sparse_to_dense_tf.soft_get('name', sparse_to_dense_tf.id)
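        # permute the logits produced by the original Transpose into the layout expected by CTCGreedyDecoderSeqLen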
        ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                                       {'name': ctc_greedy_decoder_tf.name + '/ctc_data_permute'})
        ctc_data_permute.in_port(0).connect(transpose_tf.out_port(0))

        # create CTCGreedyDecoderSeqLen node preserving the merge_repeated attribute of the original decoder
        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)
        assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
            'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)
        merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
        ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(graph, {'name': output_sparse_to_dense_name,
                                                              'merge_repeated': merge_repeated_tf}).create_node()
        rename_nodes([(sparse_to_dense_tf, output_sparse_to_dense_name + '/AbandonedName'),
                      (ctc_greedy_decoder, output_sparse_to_dense_name)])
        ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
        ctc_greedy_decoder.in_port(1).connect(ctc_greedy_decoder_tf.in_port(1).get_connection().get_source())

        # create CTCLoss node and carry over attributes from the original CTCLoss
        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
        assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
            'The CTCLoss node "{}" misses "preprocess_collapse_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
            'The CTCLoss node "{}" misses "ctc_merge_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('unique'), \
            'The CTCLoss node "{}" misses "unique" attribute'.format(output_ctc_loss_name)
        preprocess_collapse_repeated = ctc_loss_tf.preprocess_collapse_repeated
        ctc_merge_repeated = ctc_loss_tf.ctc_merge_repeated
        unique = ctc_loss_tf.unique
        ctc_loss = CTCLoss(graph, {'name': output_ctc_loss_name,
                                   'preprocess_collapse_repeated': preprocess_collapse_repeated,
                                   'ctc_merge_repeated': ctc_merge_repeated,
                                   'unique': unique}).create_node()
        rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'), (ctc_loss, output_ctc_loss_name)])
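        # set the new CTCLoss node as the source for consumers of the original CTCLoss output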
        ctc_loss_tf.out_port(0).get_connection().set_source(ctc_loss.out_port(0))
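        # choose the logits source depending on whether the original CTCLoss consumed time-major data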
        if ctc_loss_tf.logits_time_major:
            ctc_loss.in_port(0).connect(ctc_data_permute.out_port(0))
        else:
            ctc_loss.in_port(0).connect(transpose_tf.out_port(0))
        ctc_loss.in_port(1).connect(ctc_greedy_decoder_tf.in_port(1).get_connection().get_source())
        ctc_loss.in_port(2).connect(ctc_greedy_decoder.out_port(0))
        ctc_loss.in_port(3).connect(ctc_greedy_decoder.out_port(1))

        # remove no longer needed nodes
        graph.remove_nodes_from([sparse_to_dense_tf.id, cast_tf.id, ctc_loss_tf.id, ctc_greedy_decoder_tf.id])
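
The match dictionary unpacked at the top of replace_sub_graph comes from the Model Optimizer's sub-graph matcher. Below is a minimal sketch of the kind of class that hosts such a method, assuming the FrontReplacementSubgraph base class and its pattern() hook; the import path, the op names, and the edge conditions are illustrative assumptions, not the verbatim upstream pattern.

# A minimal sketch (assumed, not the verbatim upstream class) of a front
# replacement hosting replace_sub_graph; the pattern() node names define
# the keys available in the match dictionary used above.
from mo.front.common.replacement import FrontReplacementSubgraph


class CTCLossFrontReplacementSketch(FrontReplacementSubgraph):
    enabled = True

    def pattern(self):
        # node names ('transpose', 'ctc_greedy_decoder', ...) become the match keys;
        # the op attributes and edge port conditions here are assumptions
        return dict(
            nodes=[
                ('transpose', dict(op='Transpose')),
                ('ctc_greedy_decoder', dict(op='CTCGreedyDecoder')),
                ('cast', dict(op='Cast')),
                ('sparse_to_dense', dict(op='SparseToDense')),
                ('ctc_loss', dict(op='CTCLoss')),
            ],
            edges=[
                ('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
                ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
                ('ctc_greedy_decoder', 'cast', {'out': 1, 'in': 0}),
                ('ctc_greedy_decoder', 'ctc_loss', {'out': 0, 'in': 2}),
                ('cast', 'ctc_loss', {'out': 0, 'in': 3}),
            ],
        )
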
Example #2

    def replace_sub_graph(self, graph: Graph, match: dict):
        seq_len_tf = match['seq_len']
        transpose_tf = match['transpose']
        ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
        cast_tf = match['cast']
        ctc_loss_tf = match['ctc_loss']
        sparse_to_dense_tf = match['sparse_to_dense']

        output_sparse_to_dense_name = sparse_to_dense_tf.soft_get(
            'name', sparse_to_dense_tf.id)
        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get(
            'name', ctc_greedy_decoder_tf.id)

        log.debug(
            'Found CTCLossFrontReplacer pattern after {} with name {}'.format(
                ctc_greedy_decoder_tf.op, ctc_greedy_decoder_tf.name))

        # create sequence mask node, sub-graph for transforming into sequence length and connect with consumers
        seq_len_tf_shape = seq_len_tf.soft_get('shape', None)
        if seq_len_tf_shape is None or len(seq_len_tf_shape) != 2:
            raise Error(
                'The sequence length that is the second input to the CTCGreedyDecoder node "{}"'
                ' must be specified in a mask format.'.format(
                    ctc_greedy_decoder_tf_name))
        log.error(
            'The format of input sequence length has been changed to a mask format',
            extra={'is_warning': True})
        seq_len_tf_type = seq_len_tf.soft_get('data_type', None)
        seq_len_tf_name = seq_len_tf.soft_get('name', seq_len_tf.id)
        seq_mask_placeholder = Parameter(
            graph, {
                'name': seq_len_tf_name,
                'shape': seq_len_tf_shape,
                'data_type': seq_len_tf_type
            }).create_node()
        reduce_to_seq_len_node = create_op_with_const_inputs(
            graph, ReduceSum, {1: np.array(1, dtype=np.int32)}, {
                'name': seq_len_tf_name + '/ReduceToSeqLen',
                'keep_dims': False
            })
        reduce_to_seq_len_node.in_port(0).connect(
            seq_mask_placeholder.out_port(0))
        seq_len_tf.out_port(0).get_connection().set_source(
            reduce_to_seq_len_node.out_port(0))

        cast_fp_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        casted_seq_mask_node = Cast(graph, {
            'name': seq_len_tf_name + '/CastToFP32',
            'dst_type': cast_fp_type
        }).create_node()
        casted_seq_mask_node.in_port(0).connect(
            seq_mask_placeholder.out_port(0))
        permuted_casted_seq_mask = create_op_with_const_inputs(
            graph, Transpose, {1: int64_array([1, 0])},
            {'name': seq_len_tf_name + '/Permute'})
        permuted_casted_seq_mask.in_port(0).connect(
            casted_seq_mask_node.out_port(0))
        rename_nodes([(seq_len_tf, seq_len_tf_name + '/AbandonedName'),
                      (seq_mask_placeholder, seq_len_tf_name)])

        # create CTCGreedyDecoder node and set mask node
        ctc_merge_repeated_i = ctc_greedy_decoder_tf.soft_get(
            'ctc_merge_repeated', ctc_greedy_decoder_tf.id)
        ctc_greedy_decoder = CTCGreedyDecoderOp(
            graph, {
                'name': output_sparse_to_dense_name,
                'ctc_merge_repeated': ctc_merge_repeated_i
            }).create_node()
        ctc_greedy_decoder.in_port(1).connect(
            permuted_casted_seq_mask.out_port(0))
        rename_nodes([(sparse_to_dense_tf,
                       output_sparse_to_dense_name + '/AbandonedName'),
                      (ctc_greedy_decoder, output_sparse_to_dense_name)])

        # create CTCLoss node and set attributes
        assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
            'The CTCLoss node "{}" misses "preprocess_collapse_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
            'The CTCLoss node "{}" misses "ctc_merge_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('unique'), \
            'The CTCLoss node "{}" misses "unique" attribute'.format(output_ctc_loss_name)
        preprocess_collapse_repeated = ctc_loss_tf.preprocess_collapse_repeated
        ctc_merge_repeated = ctc_loss_tf.ctc_merge_repeated
        unique = ctc_loss_tf.unique
        ctc_loss = CTCLoss(
            graph, {
                'name': output_ctc_loss_name,
                'preprocess_collapse_repeated': preprocess_collapse_repeated,
                'ctc_merge_repeated': ctc_merge_repeated,
                'unique': unique
            }).create_node()
        rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'),
                      (ctc_loss, output_ctc_loss_name)])

        # connect logits
        ctc_greedy_decoder_tf.in_port(0).get_connection().set_destination(
            ctc_greedy_decoder.in_port(0))
        ctc_loss.in_port(0).disconnect()
        transpose_tf.in_port(0).get_connection().add_destination(
            ctc_loss.in_port(0))

        # connect logit lengths
        ctc_greedy_decoder_tf.in_port(1).disconnect()
        ctc_loss.in_port(1).connect(reduce_to_seq_len_node.out_port(0))

        # connect labels to ctc_loss
        squeeze_op = create_op_with_const_inputs(graph, Squeeze,
                                                 {1: int64_array([2, 3])})
        cast_labels_op = Cast(
            graph, {
                'name': output_sparse_to_dense_name + '/CastLabels',
                'dst_type': np.int32
            }).create_node()
        squeeze_op.in_port(0).connect(ctc_greedy_decoder.out_port(0))
        cast_labels_op.in_port(0).connect(squeeze_op.out_port(0))
        ctc_loss.in_port(2).connect(cast_labels_op.out_port(0))

        # connect label lengths
        equal_op = create_op_with_const_inputs(
            graph, Equal, {1: np.array([-1], dtype=np.int32)},
            {'name': output_sparse_to_dense_name + '/Equal'})
        equal_op.in_port(0).connect(cast_labels_op.out_port(0))
        labels_shape_op = Shape(
            graph, {
                'name': output_sparse_to_dense_name + '/ShapeOf'
            }).create_node()
        labels_shape_op.in_port(0).connect(equal_op.out_port(0))
        broadcast_one = create_op_with_const_inputs(
            graph, Broadcast, {0: np.array([1], dtype=np.int32)}, {
                'mode': 'numpy',
                'name': output_sparse_to_dense_name + '/One'
            })
        broadcast_one.in_port(1).connect(labels_shape_op.out_port(0))
        broadcast_zero = create_op_with_const_inputs(
            graph, Broadcast, {0: np.array([0], dtype=np.int32)}, {
                'mode': 'numpy',
                'name': output_sparse_to_dense_name + '/Zero'
            })
        broadcast_zero.in_port(1).connect(labels_shape_op.out_port(0))

        select_node = Select(graph, {
            'name': output_sparse_to_dense_name + '/Select'
        }).create_node()
        select_node.in_port(0).connect(equal_op.out_port(0))
        select_node.in_port(1).connect(broadcast_zero.out_port(0))
        select_node.in_port(2).connect(broadcast_one.out_port(0))
        label_length_node = create_op_with_const_inputs(
            graph,
            ReduceSum, {1: int64_array([1])},
            op_attrs={
                'name': output_sparse_to_dense_name + '/LabelLength',
                'keep_dims': False
            })
        label_length_node.in_port(0).connect(select_node.out_port(0))
        ctc_loss.in_port(3).connect(label_length_node.out_port(0))

        # set source for output of new sub-graph and remove old nodes
        ctc_loss_tf.out_port(0).get_connection().set_source(
            ctc_loss.out_port(0))
        graph.remove_nodes_from([
            ctc_greedy_decoder_tf.id, ctc_loss_tf.id, cast_tf.id,
            sparse_to_dense_tf.id
        ])
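
The Equal/Broadcast/Select/ReduceSum chain above counts the non-padding entries (values other than -1) of the decoded labels to obtain the per-sequence label lengths fed into CTCLoss. A minimal NumPy sketch of the same computation, with made-up label values for illustration:

import numpy as np

# hypothetical decoded labels, padded with -1 as produced by the decoder sub-graph
labels = np.array([[3, 7, 7, -1, -1],
                   [1, -1, -1, -1, -1]], dtype=np.int32)

mask = labels == -1                                      # Equal(labels, -1)
ones = np.broadcast_to(np.int32(1), mask.shape)          # Broadcast of 1 over the mask shape
zeros = np.broadcast_to(np.int32(0), mask.shape)         # Broadcast of 0 over the mask shape
label_length = np.where(mask, zeros, ones).sum(axis=1)   # Select + ReduceSum(axis=1, keep_dims=False)
print(label_length)  # [3 1]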