Example #1
0
    def test1(self):
        nodes_attributes = {
            'logits': {
                'shape': int64_array([2, 6, 100]),
                'type': 'Parameter',
                'kind': 'op',
                'op': 'Parameter'
            },
            'seq_mask': {
                'shape': int64_array([2, 100]),
                'data_type': np.int32,
                'kind': 'op',
                'op': 'Parameter'
            },
            'reduce_seq_mask': {
                'kind': 'op',
                'op': 'ReduceSum'
            },
            's_cast_seq_mask': {
                'kind': 'op',
                'op': 'Cast'
            },
            'transpose_cast_seq_mask': {
                'kind': 'op',
                'op': 'Transpose'
            },
            'transpose': {
                'kind': 'op',
                'op': 'Transpose'
            },
            'ctc_greedy_decoder': {
                'kind': 'op',
                'op': 'CTCGreedyDecoder'
            },
            'cast': {
                'kind': 'op',
                'op': 'Cast'
            },
            'sparse_to_dense': {
                'kind': 'op',
                'op': 'SparseToDense'
            },
            'const': {
                'kind': 'op',
                'op': 'Const'
            },
            'ctc_loss': {
                'kind': 'op',
                'op': 'CTCLoss',
                'preprocess_collapse_repeated': False,
                'ctc_merge_repeated': True,
                'unique': False
            },
            'equal_op': {
                'kind': 'op',
                'op': 'Equal'
            },
            'ctc_greedy_decoder_op': {
                'kind': 'op',
                'op': 'CTCGreedyDecoder'
            },
            'ctc_loss_op': {
                'kind': 'op',
                'op': 'CTCLoss'
            },
            'squeeze_op': {
                'kind': 'op',
                'op': 'Squeeze'
            },
            'cast_labels_op': {
                'kind': 'op',
                'op': 'Cast',
                'type': 'Convert'
            },
            'labels_shape_op': {
                'kind': 'op',
                'op': 'ShapeOf'
            },
            'broadcast_one_op': {
                'kind': 'op',
                'op': 'Broadcast'
            },
            'broadcast_zero_op': {
                'kind': 'op',
                'op': 'Broadcast'
            },
            'select_op': {
                'kind': 'op',
                'op': 'Select'
            },
            'label_length_op': {
                'kind': 'op',
                'op': 'ReduceSum'
            },
            **const('reduce_indices', int64_array(1)),
            **const('permute_order', int64_array([1, 0])),
            **const('default_value', int64_array(-1)),
            **const('squeeze_axis', int64_array([2, 3])),
            **const('minus_one', np.array([-1], dtype=np.int32)),
            **const('one', np.array([1], dtype=np.int32)),
            **const('zero', np.array([0], dtype=np.int32)),
            **const('reduce_sum_axis', int64_array([1])),
            'last': {
                'type': None,
                'value': None,
                'kind': 'op',
                'op': 'Result'
            },
        }

        graph = build_graph(nodes_attributes, [
            ('logits', 'transpose', {
                'out': 0,
                'in': 0
            }),
            ('transpose', 'ctc_greedy_decoder', {
                'out': 0,
                'in': 0
            }),
            ('seq_mask', 'ctc_greedy_decoder', {
                'out': 0,
                'in': 1
            }),
            ('transpose', 'ctc_loss', {
                'out': 0,
                'in': 0
            }),
            ('seq_mask', 'ctc_loss', {
                'out': 0,
                'in': 3
            }),
            ('ctc_greedy_decoder', 'sparse_to_dense', {
                'out': 0,
                'in': 0
            }),
            ('ctc_greedy_decoder', 'sparse_to_dense', {
                'out': 2,
                'in': 1
            }),
            ('ctc_greedy_decoder', 'sparse_to_dense', {
                'out': 1,
                'in': 2
            }),
            ('default_value', 'sparse_to_dense', {
                'out': 0,
                'in': 3
            }),
            ('ctc_greedy_decoder', 'cast', {
                'out': 1,
                'in': 0
            }),
            ('ctc_greedy_decoder', 'ctc_loss', {
                'out': 0,
                'in': 1
            }),
            ('cast', 'ctc_loss', {
                'out': 0,
                'in': 2
            }),
            ('ctc_loss', 'last', {
                'out': 0,
                'in': 0
            }),
        ],
                            nodes_with_edges_only=True)
        graph.graph['cmd_params'] = Namespace(data_type='FP32')
        graph.stage = 'front'
        CTCLossReplacement().find_and_replace_pattern(graph)

        graph_ref = build_graph(
            nodes_attributes,
            [('seq_mask', 'reduce_seq_mask', {
                'out': 0,
                'in': 0
            }), ('reduce_indices', 'reduce_seq_mask', {
                'out': 0,
                'in': 1
            }), ('seq_mask', 's_cast_seq_mask', {
                'out': 0,
                'in': 0
            }),
             ('s_cast_seq_mask', 'transpose_cast_seq_mask', {
                 'out': 0,
                 'in': 0
             }),
             ('permute_order', 'transpose_cast_seq_mask', {
                 'out': 0,
                 'in': 1
             }), ('logits', 'transpose', {
                 'out': 0,
                 'in': 0
             }), ('transpose', 'ctc_greedy_decoder_op', {
                 'out': 0,
                 'in': 0
             }),
             ('transpose_cast_seq_mask', 'ctc_greedy_decoder_op', {
                 'out': 0,
                 'in': 1
             }), ('ctc_greedy_decoder_op', 'squeeze_op', {
                 'out': 0,
                 'in': 0
             }), ('squeeze_axis', 'squeeze_op', {
                 'out': 0,
                 'in': 1
             }), ('squeeze_op', 'cast_labels_op', {
                 'in': 0
             }), ('minus_one', 'equal_op', {
                 'out': 0,
                 'in': 1
             }), ('equal_op', 'labels_shape_op', {
                 'out': 0,
                 'in': 0
             }), ('one', 'broadcast_one_op', {
                 'out': 0,
                 'in': 0
             }), ('labels_shape_op', 'broadcast_one_op', {
                 'out': 0,
                 'in': 1
             }), ('zero', 'broadcast_zero_op', {
                 'out': 0,
                 'in': 0
             }), ('labels_shape_op', 'broadcast_zero_op', {
                 'out': 0,
                 'in': 1
             }), ('equal_op', 'select_op', {
                 'out': 0,
                 'in': 0
             }), ('broadcast_zero_op', 'select_op', {
                 'out': 0,
                 'in': 1
             }), ('broadcast_one_op', 'select_op', {
                 'out': 0,
                 'in': 2
             }), ('select_op', 'label_length_op', {
                 'out': 0,
                 'in': 0
             }), ('reduce_sum_axis', 'label_length_op', {
                 'out': 0,
                 'in': 1
             }), ('logits', 'ctc_loss_op', {
                 'out': 0,
                 'in': 0
             }), ('reduce_seq_mask', 'ctc_loss_op', {
                 'out': 0,
                 'in': 1
             }), ('cast_labels_op', 'ctc_loss_op', {
                 'out': 0,
                 'in': 2
             }), ('label_length_op', 'ctc_loss_op', {
                 'out': 0,
                 'in': 3
             }), ('cast_labels_op', 'equal_op', {
                 'out': 0,
                 'in': 0
             }), ('ctc_loss_op', 'last', {
                 'out': 0,
                 'in': 0
             })],
            nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'last',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
Example #2
0
    def test1(self):
        nodes_attributes = {
            'logits': {
                'shape': int64_array([2, 6, 100]),
                'type': 'Parameter',
                'kind': 'op',
                'op': 'Parameter'
            },
            'seq_mask': {
                'shape': int64_array([2]),
                'data_type': np.int32,
                'kind': 'op',
                'op': 'Parameter'
            },
            'transpose': {
                'kind': 'op',
                'op': 'Transpose'
            },
            'ctc_greedy_decoder': {
                'kind': 'op',
                'op': 'CTCGreedyDecoderSeqLen',
                'merge_repeated': True
            },
            'cast': {
                'kind': 'op',
                'op': 'Cast'
            },
            'sparse_to_dense': {
                'kind': 'op',
                'op': 'SparseToDense'
            },
            'tf_ctc_loss': {
                'kind': 'op',
                'op': 'CTCLoss',
                'preprocess_collapse_repeated': False,
                'ctc_merge_repeated': True,
                'unique': False,
                'logits_time_major': True
            },
            'ctc_loss': {
                'kind': 'op',
                'op': 'CTCLoss',
                'preprocess_collapse_repeated': False,
                'ctc_merge_repeated': True,
                'unique': False
            },
            **const('default_value', int64_array(-1)),
            'last': {
                'type': None,
                'value': None,
                'kind': 'op',
                'op': 'Result'
            },
            'transpose2': {
                'kind': 'op',
                'op': 'Transpose'
            },
            **const('transpose2_axis', int64_array([1, 0, 2])),
        }
        graph = build_graph(nodes_attributes, [('logits', 'transpose', {
            'out': 0,
            'in': 0
        }), ('transpose', 'ctc_greedy_decoder', {
            'out': 0,
            'in': 0
        }), ('seq_mask', 'ctc_greedy_decoder', {
            'out': 0,
            'in': 1
        }), ('transpose', 'tf_ctc_loss', {
            'out': 0,
            'in': 0
        }), ('seq_mask', 'tf_ctc_loss', {
            'out': 0,
            'in': 3
        }), ('ctc_greedy_decoder', 'sparse_to_dense', {
            'out': 0,
            'in': 0
        }), ('ctc_greedy_decoder', 'sparse_to_dense', {
            'out': 2,
            'in': 1
        }), ('ctc_greedy_decoder', 'sparse_to_dense', {
            'out': 1,
            'in': 2
        }), ('default_value', 'sparse_to_dense', {
            'out': 0,
            'in': 3
        }), ('ctc_greedy_decoder', 'cast', {
            'out': 1,
            'in': 0
        }), ('ctc_greedy_decoder', 'tf_ctc_loss', {
            'out': 0,
            'in': 1
        }), ('cast', 'tf_ctc_loss', {
            'out': 0,
            'in': 2
        }), ('tf_ctc_loss', 'last', {
            'out': 0,
            'in': 0
        })],
                            nodes_with_edges_only=True)
        graph.graph['cmd_params'] = Namespace(data_type='FP32')
        graph.stage = 'front'
        CTCLossReplacement().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes_attributes, [('logits', 'transpose', {
            'out': 0,
            'in': 0
        }), ('transpose', 'transpose2', {
            'out': 0,
            'in': 0
        }), ('transpose2_axis', 'transpose2', {
            'out': 0,
            'in': 1
        }), ('transpose2', 'ctc_greedy_decoder', {
            'out': 0,
            'in': 0
        }), ('seq_mask', 'ctc_greedy_decoder', {
            'out': 0,
            'in': 1
        }), ('transpose2', 'ctc_loss', {
            'out': 0,
            'in': 0
        }), ('ctc_greedy_decoder', 'ctc_loss', {
            'out': 0,
            'in': 2
        }), ('ctc_greedy_decoder', 'ctc_loss', {
            'out': 1,
            'in': 3
        }), ('seq_mask', 'ctc_loss', {
            'out': 0,
            'in': 1
        }), ('ctc_loss', 'last', {
            'out': 0,
            'in': 0
        })],
                                nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'last',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
Example #3
0
    def CTCLossReplacement_test_true_logits(self):
        graph = build_graph(
            self.nodes_attributes,
            [('logits', 'transpose', {
                'out': 0,
                'in': 0
            }), ('transpose', 'ctc_greedy_decoder', {
                'out': 0,
                'in': 0
            }), ('seq_mask', 'ctc_greedy_decoder', {
                'out': 0,
                'in': 1
            }),
             ('transpose', 'tf_ctc_loss_true_logits', {
                 'out': 0,
                 'in': 0
             }), ('seq_mask', 'tf_ctc_loss_true_logits', {
                 'out': 0,
                 'in': 3
             }), ('ctc_greedy_decoder', 'sparse_to_dense', {
                 'out': 0,
                 'in': 0
             }), ('ctc_greedy_decoder', 'sparse_to_dense', {
                 'out': 2,
                 'in': 1
             }), ('ctc_greedy_decoder', 'sparse_to_dense', {
                 'out': 1,
                 'in': 2
             }), ('default_value', 'sparse_to_dense', {
                 'out': 0,
                 'in': 3
             }), ('ctc_greedy_decoder', 'cast', {
                 'out': 1,
                 'in': 0
             }),
             ('ctc_greedy_decoder', 'tf_ctc_loss_true_logits', {
                 'out': 0,
                 'in': 1
             }), ('cast', 'tf_ctc_loss_true_logits', {
                 'out': 0,
                 'in': 2
             }), ('tf_ctc_loss_true_logits', 'last', {
                 'out': 0,
                 'in': 0
             })],
            nodes_with_edges_only=True)
        graph.graph['cmd_params'] = Namespace(data_type='FP32')
        graph.stage = 'front'
        CTCLossReplacement().find_and_replace_pattern(graph)

        graph_ref = build_graph(self.nodes_attributes,
                                [('logits', 'transpose', {
                                    'out': 0,
                                    'in': 0
                                }),
                                 ('transpose', 'transpose2', {
                                     'out': 0,
                                     'in': 0
                                 }),
                                 ('transpose2_axis', 'transpose2', {
                                     'out': 0,
                                     'in': 1
                                 }),
                                 ('transpose2', 'new_ctc_greedy_decoder', {
                                     'out': 0,
                                     'in': 0
                                 }),
                                 ('seq_mask', 'new_ctc_greedy_decoder', {
                                     'out': 0,
                                     'in': 1
                                 }),
                                 ('transpose2', 'ctc_loss', {
                                     'out': 0,
                                     'in': 0
                                 }),
                                 ('new_ctc_greedy_decoder', 'ctc_loss', {
                                     'out': 0,
                                     'in': 2
                                 }),
                                 ('new_ctc_greedy_decoder', 'ctc_loss', {
                                     'out': 1,
                                     'in': 3
                                 }),
                                 ('seq_mask', 'ctc_loss', {
                                     'out': 0,
                                     'in': 1
                                 }), ('ctc_loss', 'last', {
                                     'out': 0,
                                     'in': 0
                                 })],
                                nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'last',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)