Beispiel #1
0
    def test_tf_concat_infer(self):
        """Check the output shape set by ``ctc_greedy_decoder_infer``.

        The graph has 'node_1' (shape [88, 2, 71]) and 'node_2' (shape [88, 2])
        feeding 'ctc', which feeds the output 'node_3'. After inference on
        'ctc' the shape stored on 'node_3' must be exactly [2, 88, 1, 1].
        """
        graph = build_graph(
            nodes_attributes,
            [('node_1', 'ctc'), ('node_2', 'ctc'), ('ctc', 'node_3')],
            {
                'node_3': {'is_output': True, 'shape': None},
                'node_1': {'shape': np.array([88, 2, 71])},
                'node_2': {'shape': np.array([88, 2])},
                'ctc': {'ctc_merge_repeated': 1},
            })

        ctc_node = Node(graph, 'ctc')
        CTCGreedyDecoderOp.ctc_greedy_decoder_infer(ctc_node)
        exp_shape = np.array([2, 88, 1, 1])
        res_shape = graph.node['node_3']['shape']
        # Compare the arrays as a whole: an element-wise loop over exp_shape
        # would accept a result with extra trailing dimensions, and would
        # raise TypeError instead of failing the assertion if no shape was set.
        self.assertTrue(
            np.array_equal(exp_shape, res_shape),
            'shapes do not match expected: {} and given: {}'.format(
                exp_shape, res_shape))
Beispiel #2
0
    def extract(cls, node):
        """Copy the CTC decoder parameters of the layer proto onto ``node``.

        Reads ``node.pb.ctc_decoder_param``, passes ``ctc_merge_repeated`` as
        an integer, adds the layout attributes and stores the result on the
        node through ``CTCGreedyDecoderOp.update_node_stat``.

        Returns ``cls.enabled``.
        """
        decoder_param = node.pb.ctc_decoder_param

        # the flag is handed over explicitly as an int
        overrides = {'ctc_merge_repeated': int(decoder_param.ctc_merge_repeated)}
        node_attrs = merge_attrs(decoder_param, overrides)
        node_attrs.update(layout_attrs())

        # write the collected attributes to the node
        CTCGreedyDecoderOp.update_node_stat(node, node_attrs)
        return cls.enabled
    def test_infer1(self):
        """``CTCGreedyDecoderOp.infer`` must set the 'output' shape to [4, 100, 1, 1]."""
        graph = build_graph(nodes_attributes, edges1, inputs1)
        CTCGreedyDecoderOp.infer(Node(graph, 'ctcgreedydecoder_node'))

        # reference shape versus the shape written to the graph by infer
        expected_shape = int64_array([4, 100, 1, 1])
        actual_shape = graph.node['output']['shape']

        failure_message = 'shapes do not match expected: {} and given: {}'.format(
            expected_shape, actual_shape)
        self.assertTrue(np.array_equal(expected_shape, actual_shape),
                        failure_message)
Beispiel #4
0
 def extract(node):
     """Store the 'merge_repeated' proto attribute on ``node`` as ``ctc_merge_repeated``.

     The boolean field ``node.pb.attr['merge_repeated'].b`` is converted to
     an int before being written through ``CTCGreedyDecoderOp.update_node_stat``.

     Returns the ``enabled`` attribute of the enclosing class.
     """
     merge_repeated = int(node.pb.attr['merge_repeated'].b)
     CTCGreedyDecoderOp.update_node_stat(
         node, {'ctc_merge_repeated': merge_repeated})
     return __class__.enabled
Beispiel #5
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Rebuild the matched CTCGreedyDecoder/CTCLoss sub-graph with new nodes.

        The method creates a new ``CTCGreedyDecoderOp`` node and a new
        ``CTCLoss`` node, feeds them from a new ``Parameter`` that takes over
        the name, shape and data type of the sequence-length node, derives the
        label and label-length inputs of the loss from the decoder output, and
        finally removes the matched decoder, loss, cast and sparse-to-dense
        nodes from the graph.

        :param graph: graph being transformed (modified in place)
        :param match: matched nodes; the keys 'seq_len', 'transpose',
            'ctc_greedy_decoder', 'cast', 'ctc_loss' and 'sparse_to_dense'
            are read
        :raises Error: if the sequence-length node has no shape or its shape
            is not of rank 2
        :raises AssertionError: if the matched CTCLoss node lacks one of the
            'preprocess_collapse_repeated', 'ctc_merge_repeated' or 'unique'
            attributes
        """
        # nodes of the matched pattern
        seq_len_tf = match['seq_len']
        transpose_tf = match['transpose']
        ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
        cast_tf = match['cast']
        ctc_loss_tf = match['ctc_loss']
        sparse_to_dense_tf = match['sparse_to_dense']

        # names of the matched nodes; the node id is passed as the fallback value
        output_sparse_to_dense_name = sparse_to_dense_tf.soft_get(
            'name', sparse_to_dense_tf.id)
        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get(
            'name', ctc_greedy_decoder_tf.id)

        log.debug(
            'Found CTCLossFrontReplacer pattern after {} with name {}'.format(
                ctc_greedy_decoder_tf.op, ctc_greedy_decoder_tf.name))

        # create sequence mask node, sub-graph for transforming into sequence length and connect with consumers
        seq_len_tf_shape = seq_len_tf.soft_get('shape', None)
        # only a known rank-2 shape is accepted for the sequence-length input
        if seq_len_tf_shape is None or len(seq_len_tf_shape) != 2:
            raise Error(
                'The sequence length that is the second input to the CTCGreedyDecoder node "{}"'
                ' must be specified in a mask format.'.format(
                    ctc_greedy_decoder_tf_name))
        # emitted through log.error but tagged as a warning via the 'is_warning' extra
        log.error(
            'The format of input sequence length has been changed to a mask format',
            extra={'is_warning': True})
        seq_len_tf_type = seq_len_tf.soft_get('data_type', None)
        seq_len_tf_name = seq_len_tf.soft_get('name', seq_len_tf.id)
        # new Parameter reusing the name, shape and data type of the sequence-length node
        seq_mask_placeholder = Parameter(
            graph, {
                'name': seq_len_tf_name,
                'shape': seq_len_tf_shape,
                'data_type': seq_len_tf_type
            }).create_node()
        # ReduceSum with constant 1 on input port 1 (presumably the reduction axis - TODO confirm)
        reduce_to_seq_len_node = create_op_with_const_inputs(
            graph, ReduceSum, {1: np.array(1, dtype=np.int32)}, {
                'name': seq_len_tf_name + '/ReduceToSeqLen',
                'keep_dims': False
            })
        reduce_to_seq_len_node.in_port(0).connect(
            seq_mask_placeholder.out_port(0))
        # existing consumers of the sequence-length node now read the ReduceSum output
        seq_len_tf.out_port(0).get_connection().set_source(
            reduce_to_seq_len_node.out_port(0))

        # second branch from the mask: Cast -> Transpose([1, 0]), later fed to decoder input 1
        # NOTE(review): the node is named '/CastToFP32' although the destination type is taken
        # from cmd_params.data_type - confirm the name is still accurate for other data types
        cast_fp_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        casted_seq_mask_node = Cast(graph, {
            'name': seq_len_tf_name + '/CastToFP32',
            'dst_type': cast_fp_type
        }).create_node()
        casted_seq_mask_node.in_port(0).connect(
            seq_mask_placeholder.out_port(0))
        permuted_casted_seq_mask = create_op_with_const_inputs(
            graph, Transpose, {1: int64_array([1, 0])},
            {'name': seq_len_tf_name + '/Permute'})
        permuted_casted_seq_mask.in_port(0).connect(
            casted_seq_mask_node.out_port(0))
        # the new Parameter takes over the original name of the sequence-length node
        rename_nodes([(seq_len_tf, seq_len_tf_name + '/AbandonedName'),
                      (seq_mask_placeholder, seq_len_tf_name)])

        # create CTCGreedyDecoder node and set mask node
        # NOTE(review): the fallback passed to soft_get is the node id, which looks unintended
        # for the 'ctc_merge_repeated' attribute - confirm
        ctc_merge_repeated_i = ctc_greedy_decoder_tf.soft_get(
            'ctc_merge_repeated', ctc_greedy_decoder_tf.id)
        ctc_greedy_decoder = CTCGreedyDecoderOp(
            graph, {
                'name': output_sparse_to_dense_name,
                'ctc_merge_repeated': ctc_merge_repeated_i
            }).create_node()
        ctc_greedy_decoder.in_port(1).connect(
            permuted_casted_seq_mask.out_port(0))
        # the new decoder takes over the name of the matched sparse-to-dense node
        rename_nodes([(sparse_to_dense_tf,
                       output_sparse_to_dense_name + '/AbandonedName'),
                      (ctc_greedy_decoder, output_sparse_to_dense_name)])

        # create CTCLoss node and set attributes
        assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
            'The CTCLoss node "{}" misses "preprocess_collapse_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
            'The CTCLoss node "{}" misses "ctc_merge_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('unique'), \
            'The CTCLoss node "{}" misses "unique" attribute'.format(output_ctc_loss_name)
        # attributes are copied unchanged from the matched CTCLoss node
        preprocess_collapse_repeated = ctc_loss_tf.preprocess_collapse_repeated
        ctc_merge_repeated = ctc_loss_tf.ctc_merge_repeated
        unique = ctc_loss_tf.unique
        ctc_loss = CTCLoss(
            graph, {
                'name': output_ctc_loss_name,
                'preprocess_collapse_repeated': preprocess_collapse_repeated,
                'ctc_merge_repeated': ctc_merge_repeated,
                'unique': unique
            }).create_node()
        rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'),
                      (ctc_loss, output_ctc_loss_name)])

        # connect logits
        # the connection feeding input 0 of the matched decoder is redirected to the new decoder
        ctc_greedy_decoder_tf.in_port(0).get_connection().set_destination(
            ctc_greedy_decoder.in_port(0))
        ctc_loss.in_port(0).disconnect()
        # the source feeding input 0 of the matched transpose additionally feeds the new loss
        transpose_tf.in_port(0).get_connection().add_destination(
            ctc_loss.in_port(0))

        # connect logit lengths
        ctc_greedy_decoder_tf.in_port(1).disconnect()
        ctc_loss.in_port(1).connect(reduce_to_seq_len_node.out_port(0))

        # connect labels to ctc_loss
        # decoder output -> Squeeze (constant [2, 3] on port 1) -> Cast to int32 -> loss input 2
        squeeze_op = create_op_with_const_inputs(graph, Squeeze,
                                                 {1: int64_array([2, 3])})
        cast_labels_op = Cast(
            graph, {
                'name': output_sparse_to_dense_name + '/CastLabels',
                'dst_type': np.int32
            }).create_node()
        squeeze_op.in_port(0).connect(ctc_greedy_decoder.out_port(0))
        cast_labels_op.in_port(0).connect(squeeze_op.out_port(0))
        ctc_loss.in_port(2).connect(cast_labels_op.out_port(0))

        # connect label lengths
        # labels are compared with -1; the comparison result drives the Select below
        equal_op = create_op_with_const_inputs(
            graph, Equal, {1: np.array([-1], dtype=np.int32)},
            {'name': output_sparse_to_dense_name + '/Equal'})
        equal_op.in_port(0).connect(cast_labels_op.out_port(0))
        labels_shape_op = Shape(
            graph, {
                'name': output_sparse_to_dense_name + '/ShapeOf'
            }).create_node()
        labels_shape_op.in_port(0).connect(equal_op.out_port(0))
        # constants 1 and 0 broadcast to the shape of the comparison result
        broadcast_one = create_op_with_const_inputs(
            graph, Broadcast, {0: np.array([1], dtype=np.int32)}, {
                'mode': 'numpy',
                'name': output_sparse_to_dense_name + '/One'
            })
        broadcast_one.in_port(1).connect(labels_shape_op.out_port(0))
        broadcast_zero = create_op_with_const_inputs(
            graph, Broadcast, {0: np.array([0], dtype=np.int32)}, {
                'mode': 'numpy',
                'name': output_sparse_to_dense_name + '/Zero'
            })
        broadcast_zero.in_port(1).connect(labels_shape_op.out_port(0))

        # Select: port 0 <- Equal result, port 1 <- zeros, port 2 <- ones
        # (presumably yields 0 where the label equals -1 and 1 elsewhere, so the ReduceSum
        # below counts the labels that are not -1 - TODO confirm Select port semantics)
        select_node = Select(graph, {
            'name': output_sparse_to_dense_name + '/Select'
        }).create_node()
        select_node.in_port(0).connect(equal_op.out_port(0))
        select_node.in_port(1).connect(broadcast_zero.out_port(0))
        select_node.in_port(2).connect(broadcast_one.out_port(0))
        label_length_node = create_op_with_const_inputs(
            graph,
            ReduceSum, {1: int64_array([1])},
            op_attrs={
                'name': output_sparse_to_dense_name + '/LabelLength',
                'keep_dims': False
            })
        label_length_node.in_port(0).connect(select_node.out_port(0))
        ctc_loss.in_port(3).connect(label_length_node.out_port(0))

        # set source for output of new sub-graph and remove old nodes
        ctc_loss_tf.out_port(0).get_connection().set_source(
            ctc_loss.out_port(0))
        graph.remove_nodes_from([
            ctc_greedy_decoder_tf.id, ctc_loss_tf.id, cast_tf.id,
            sparse_to_dense_tf.id
        ])