コード例 #1
0
 def __init__(
         self,
         tfrecords_path: str,
         weights_path: str,
         config: GlobalConfig):
     """Capture the input paths and the test-related settings from *config*.

     Args:
         tfrecords_path: Stored as-is on ``_tfrecords_path``.
         weights_path: Stored as-is on ``_weights_path``.
         config: Global configuration; its test section provides
             ``batch_size`` and ``merge_repeated_chars``, and its GPU
             section is kept whole on ``_gpu_config``.
     """
     self._log = LogFactory.get_logger()

     # Caller-supplied locations, kept verbatim.
     self._tfrecords_path = tfrecords_path
     self._weights_path = weights_path

     # Values pulled out of the test section of the configuration.
     batch_size = config.get_test_config().batch_size
     self._batch_size = batch_size
     merge_repeated = config.get_test_config().merge_repeated_chars
     self._merge_repeated = merge_repeated

     self._gpu_config = config.get_gpu_config()

     feature_io = TextFeatureIO()
     self._decoder = feature_io.reader

     # No timing recorded yet at construction.
     self._recognition_time = None
コード例 #2
0
ファイル: test.py プロジェクト: HubertLegec/CRNN_Tensorflow
def test_crnn(dataset_dir: str, weights_path: str, config: GlobalConfig):
    """Run the CRNN test pass and log the aggregate metrics.

    Selects ``RecursiveCrnnTester`` when the test config's ``is_recursive``
    value is truthy and ``BasicCrnnTester`` otherwise, runs it against
    ``test_feature.tfrecords`` inside ``dataset_dir``, and logs the three
    values returned by ``run()``.

    Args:
        dataset_dir: Directory joined with ``test_feature.tfrecords``.
        weights_path: Passed through unchanged to the tester.
        config: Global configuration; its test section supplies
            ``is_recursive`` and the object is also handed to the tester.
    """
    logger = LogFactory.get_logger()
    recursive = config.get_test_config().is_recursive
    records_file = ops.join(dataset_dir, 'test_feature.tfrecords')

    # Only the selected class name is looked up, same as the if/else form.
    tester_cls = RecursiveCrnnTester if recursive else BasicCrnnTester
    runner = tester_cls(records_file, weights_path, config)

    # run() must yield exactly three values; the unpack enforces that.
    accuracy, distance, avg_time = runner.run()

    summary_template = '\n* Mean test accuracy is {:.3f}\n* Mean Levenshtein edit distance is {:.3f}\n* Mean detection time for batch is {:.3f} s'
    logger.info(summary_template.format(accuracy, distance, avg_time))
コード例 #3
0
 def __init__(
         self,
         tfrecords_path: str,
         weights_path: str,
         config: GlobalConfig):
     """Initialise the base class, then record the plotting setting.

     Args:
         tfrecords_path: Forwarded to the parent constructor.
         weights_path: Forwarded to the parent constructor.
         config: Forwarded to the parent constructor; its test section
             also supplies the value stored on ``_show_plot``.
     """
     super().__init__(tfrecords_path, weights_path, config)

     test_settings = config.get_test_config()
     # NOTE(review): show_plot is *called* here; confirm it is a method on
     # the test config rather than a plain attribute/property.
     self._show_plot = test_settings.show_plot()