Example #1
 def setUpClass(cls):
     # build the export config and borrow the dataset description from the 'eval' section
     export_config = get_config(config_file, section='export')
     export_config['dataset'] = get_config(config_file, section='eval')['dataset']
     cls.config = export_config
     cls.config.update({'expected_outputs': expected_outputs})
     cls.model_path = os.path.join(mkdtemp(), os.path.split(cls.config.get('model_path'))[1])
     cls.res_model_name = os.path.join(os.path.dirname(cls.model_path), cls.config.get('res_model_name'))
     cls.config['res_model_name'] = cls.res_model_name
     cls.config['model_path'] = cls.model_path
     # download the checkpoint into the temporary directory if it is not already there
     if not os.path.exists(cls.model_path):
         download_checkpoint(cls.model_path, cls.config.get('model_url'))
     cls.exporter = Exporter(cls.config)
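
Examples #1 through #5 are setUpClass hooks shown without their surrounding test classes. Presumably each sits inside a unittest.TestCase subclass under a @classmethod decorator; a minimal, self-contained sketch of that assumed wrapper follows (the class name and test method are hypothetical, only the structure is implied by the snippets):

import unittest


class ExporterTestCase(unittest.TestCase):  # hypothetical name, not taken from the source
    @classmethod
    def setUpClass(cls):
        # In the real test this is the body of Example #1: build cls.config,
        # download the checkpoint and create cls.exporter = Exporter(cls.config).
        cls.exporter = None  # placeholder so the sketch runs on its own

    def test_exporter_is_created(self):
        # Illustrative check only; the real tests exercise the Exporter API
        # shown in Example #8 (export_encoder, export_decoder, export_complete_model).
        self.assertTrue(hasattr(self, 'exporter'))


if __name__ == '__main__':
    unittest.main()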
Example #2
 def setUpClass(cls):
     train_config = get_config(config_file, section='train')
     cls.config = train_config
     cls.config['epochs'] = 1
     cls.config['_test_steps'] = 40
     cls.work_dir = mkdtemp()
     cls.trainer = Trainer(work_dir=cls.work_dir, config=cls.config)
Example #3
 def setUpClass(cls):
     test_config = get_config(config_file, section='eval')
     cls.config = test_config
     cls.config.update({'expected_outputs': expected_outputs})
     if not os.path.exists(cls.config.get("model_path")):
         download_checkpoint(cls.config.get("model_path"),
                             cls.config.get("model_url"))
     cls.validator = Evaluator(config=cls.config)
Example #4
 def setUpClass(cls):
     train_config = get_config(config_file, section='train')
     cls.config = train_config
     cls.config['epochs'] = 1
     # workaround for training test without downloading language model (~4 Gb)
     if cls.config['head'].get('use_semantics'):
         cls.config['head']['use_semantics'] = False
     cls.config['_test_steps'] = 40
     cls.work_dir = mkdtemp()
     cls.trainer = Trainer(work_dir=cls.work_dir, config=cls.config)
Example #5
 def setUpClass(cls):
     train_config = get_config(config_file, section='train')
     cls.config = train_config
     cls.config['epochs'] = 1
     # workaround for training test without downloading language model (~4 Gb)
     if cls.config['head'].get('use_semantics'):
         cls.config['head']['use_semantics'] = False
     # workaround for training test without running it via `python -m torch.distributed.launch`
     if cls.config.get('multi_gpu'):
         cls.config['multi_gpu'] = False
     cls.config['_test_steps'] = 40
     cls.config['batch_size'] = 2  # only for this test
     cls.work_dir = mkdtemp()
     cls.trainer = Trainer(work_dir=cls.work_dir, config=cls.config)
Example #6
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import argparse
import sys

from text_recognition.utils.get_config import get_config
from text_recognition.utils.trainer import Trainer


def parse_args():
    args = argparse.ArgumentParser()
    args.add_argument('--config')
    args.add_argument('--work_dir')
    return args.parse_args()


if __name__ == '__main__':
    assert sys.version_info[0] == 3
    arguments = parse_args()

    train_config = get_config(arguments.config, section='train')

    experiment = Trainer(work_dir=arguments.work_dir, config=train_config)
    experiment.train()
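
For orientation, get_config(path, section='train') evidently returns a plain dict of training options. A hedged stub built only from the keys that Examples #2, #4 and #5 read or override (the values are placeholders, not project defaults):

# Hypothetical stand-in for the dict returned by get_config(config_path, section='train').
# Only keys visible in the examples above are listed; values are illustrative.
train_section_stub = {
    'epochs': 1,                       # forced to a single epoch in the unit tests
    'batch_size': 2,                   # reduced "only for this test" in Example #5
    'multi_gpu': False,                # disabled unless launched via torch.distributed.launch
    'head': {'use_semantics': False},  # turned off to avoid the ~4 GB language-model download
    '_test_steps': 40,                 # internal cap on training steps used by the tests
}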
Example #7
            # compute per-frame log-probabilities and decode them with a greedy CTC search
            pred = torch.nn.functional.log_softmax(logits.detach(), dim=2)
            pred = ctc_greedy_search(pred, 0)

        return self.vocab.construct_phrase(pred[0], ignore_end_token=self.use_ctc)


def parse_args():
    args = argparse.ArgumentParser()
    args.add_argument('--config')
    args.add_argument('-i', '--input', help='Path to a folder with images or a path to an image file', required=True)
    return args.parse_args()


if __name__ == '__main__':
    arguments = parse_args()
    demo_config = get_config(arguments.config, section='demo')
    demo = TextRecognitionDemo(demo_config)
    try:
        check_environment()
    except EnvironmentError:
        print('Warning: cannot render images because some rendering tools are not installed')
        print('Check that pdflatex, ghostscript and imagemagick are installed')
        print('For details, please refer to README.md')
        render_images = False
    else:
        render_images = True
    if os.path.isdir(arguments.input):
        inputs = sorted(os.path.join(arguments.input, inp)
                        for inp in os.listdir(arguments.input))
    else:
        inputs = [arguments.input]
Example #8
import argparse

from text_recognition.utils.get_config import get_config
from text_recognition.utils.exporter import Exporter


def parse_args():
    args = argparse.ArgumentParser()
    args.add_argument('--config')
    return args.parse_args()


if __name__ == '__main__':
    arguments = parse_args()
    export_config = get_config(arguments.config, section='export')
    head_type = export_config.get('head').get('type')
    exporter = Exporter(export_config)
    # attention-based heads are exported as separate encoder and decoder models
    if head_type == 'AttentionBasedLSTM':
        exporter.export_encoder()
        exporter.export_decoder()
    elif head_type == 'LSTMEncoderDecoder':
        exporter.export_complete_model()
    print('Model successfully exported to ONNX')
    if export_config.get('export_ir'):
        if head_type == 'AttentionBasedLSTM':
            exporter.export_encoder_ir()
            exporter.export_decoder_ir()
        elif head_type == 'LSTMEncoderDecoder':
            exporter.export_complete_model_ir()
        print('Model successfully exported to OpenVINO IR')
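
In the same spirit, the 'export' section consumed by this script needs at least the head type and the export_ir switch; together with the keys touched in Example #1, a hedged stub might look like this (key names come from the examples, values and the URL are placeholders):

# Hypothetical 'export' section; only keys read in Examples #1 and #8 are shown.
export_section_stub = {
    'head': {'type': 'AttentionBasedLSTM'},        # or 'LSTMEncoderDecoder'
    'export_ir': True,                             # also convert the ONNX model to OpenVINO IR
    'model_path': 'model.pth',                     # checkpoint location, downloaded if missing
    'model_url': 'https://example.com/model.pth',  # placeholder URL, not from the source
    'res_model_name': 'model.onnx',                # name of the exported model file
}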
Example #9
import argparse
import os

from text_recognition.utils.get_config import get_config
from text_recognition.utils.evaluator import Evaluator


def parse_args():
    args = argparse.ArgumentParser()
    args.add_argument('--config')
    return args.parse_args()


if __name__ == '__main__':
    arguments = parse_args()
    test_config = get_config(arguments.config, section='eval')
    validator = Evaluator(test_config)
    if 'model_folder' in test_config.keys():
        model_folder = test_config.get('model_folder')
        best_model, best_result = None, 0
        # score every checkpoint in the folder and keep the one with the best metric
        for model in os.listdir(model_folder):
            validator.runner.reload_model(os.path.join(model_folder, model))
            result = validator.validate()
            if result > best_result:
                best_result = result
                best_model = os.path.join(model_folder, model)
        print('model = {}'.format(best_model))
        result = best_result
    else:
        result = validator.validate()
    print('Result metric is: {}'.format(result))
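
Finally, the 'eval' section drives Examples #3 and #9. A hedged stub with only the keys those snippets access (values and the URL are placeholders):

# Hypothetical 'eval' section; only keys read in Examples #1, #3 and #9 are shown.
eval_section_stub = {
    'dataset': 'data/val_annotations.json',        # also borrowed by the export test in Example #1
    'model_path': 'model.pth',                     # single checkpoint to evaluate
    'model_url': 'https://example.com/model.pth',  # placeholder URL, not from the source
    'model_folder': 'checkpoints/',                # optional: score every checkpoint and keep the best
}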