def setUpClass(cls):
    """Prepare the shared export configuration, checkpoint and Exporter.

    Steps performed, in order:
      1. Load the 'export' section of ``config_file`` and copy the
         'dataset' entry of the 'eval' section into it.
      2. Add ``expected_outputs`` under the 'expected_outputs' key.
      3. Re-point 'model_path' and 'res_model_name' into a freshly created
         temporary directory, keeping only the checkpoint's file name.
      4. Download the checkpoint from 'model_url' if nothing exists at the
         new model path.
      5. Build the Exporter used by the tests.

    NOTE(review): ``config_file`` and ``expected_outputs`` are not defined
    in this block; presumably they come from an enclosing scope - confirm.
    """
    settings = get_config(config_file, section='export')
    eval_settings = get_config(config_file, section='eval')
    settings['dataset'] = eval_settings['dataset']
    cls.config = settings
    settings.update({'expected_outputs': expected_outputs})

    # Create the scratch directory first, then derive the checkpoint file
    # name from the configured model path.
    scratch_dir = mkdtemp()
    checkpoint_name = os.path.split(settings.get('model_path'))[1]
    cls.model_path = os.path.join(scratch_dir, checkpoint_name)

    # The exported model is placed in the same directory as the checkpoint.
    checkpoint_dir = os.path.dirname(cls.model_path)
    cls.res_model_name = os.path.join(checkpoint_dir, settings.get('res_model_name'))

    settings['res_model_name'] = cls.res_model_name
    settings['model_path'] = cls.model_path

    checkpoint_present = os.path.exists(cls.model_path)
    if not checkpoint_present:
        download_checkpoint(cls.model_path, settings.get('model_url'))

    cls.exporter = Exporter(settings)
import argparse

from text_recognition.utils.get_config import get_config
from text_recognition.utils.exporter import Exporter

# Head types this script knows how to export. Any other value used to fall
# through both if/elif chains without exporting anything.
_SPLIT_HEAD_TYPE = 'AttentionBasedLSTM'
_COMPLETE_HEAD_TYPE = 'LSTMEncoderDecoder'
_SUPPORTED_HEAD_TYPES = (_SPLIT_HEAD_TYPE, _COMPLETE_HEAD_TYPE)


def parse_args():
    """Parse the command line.

    Returns:
        argparse.Namespace with a single attribute ``config``: the value of
        the ``--config`` option (``None`` when the option is omitted).
    """
    args = argparse.ArgumentParser()
    args.add_argument('--config',
                      help="configuration file; its 'export' section is used")
    return args.parse_args()


def main():
    """Export the model described by the 'export' config section.

    For an 'AttentionBasedLSTM' head the exporter's encoder and decoder
    export methods are called; for an 'LSTMEncoderDecoder' head the
    complete-model export method is called. When the config has a truthy
    'export_ir' entry, the matching ``*_ir`` methods are called as well.

    Raises:
        ValueError: if the configured head type is missing or is not one of
            the supported types. Previously such a config exported nothing
            but still printed the success messages.
    """
    arguments = parse_args()
    export_config = get_config(arguments.config, section='export')

    # Tolerate a missing/empty 'head' section so the error below is raised
    # instead of an AttributeError on None.
    head_config = export_config.get('head') or {}
    head_type = head_config.get('type')
    if head_type not in _SUPPORTED_HEAD_TYPES:
        raise ValueError(
            'Unsupported head type {!r} in export config; expected one of: {}'.format(
                head_type, ', '.join(_SUPPORTED_HEAD_TYPES)))

    split_export = head_type == _SPLIT_HEAD_TYPE
    exporter = Exporter(export_config)

    if split_export:
        exporter.export_encoder()
        exporter.export_decoder()
    else:
        exporter.export_complete_model()
    # NOTE: message text (including its spelling) is kept unchanged in case
    # anything matches on the script's output.
    print('Model succesfully exported to ONNX')

    if export_config.get('export_ir'):
        if split_export:
            exporter.export_encoder_ir()
            exporter.export_decoder_ir()
        else:
            exporter.export_complete_model_ir()
        print('Model succesfully exported to OpenVINO IR')


if __name__ == '__main__':
    main()