Exemplo n.º 1
0
 def __init__(self, config: AppConfig):
     """Wire up all parser services from the application config.

     Sets up the download manager, the pdfalto binary wrapper, the shared
     app context (including a lazily-installed wapiti binary), the full
     text models (optionally preloaded), the TEI-to-JATS XSLT transformer
     and the doc-to-pdf converter.

     :param config: application configuration mapping
     """
     self.config = config
     self.download_manager = DownloadManager(
         download_dir=get_download_dir(config))
     self.pdfalto_wrapper = PdfAltoWrapper(
         self.download_manager.download_if_url(config['pdfalto']['path']))
     self.pdfalto_wrapper.ensure_executable()
     self.app_context = AppContext(
         app_config=config,
         download_manager=self.download_manager,
         lazy_wapiti_binary_wrapper=LazyWapitiBinaryWrapper(
             install_url=config.get('wapiti', {}).get('install_source'),
             download_manager=self.download_manager))
     self.fulltext_processor_config = FullTextProcessorConfig.from_app_config(
         app_config=config)
     self.fulltext_models = load_models(
         config,
         app_context=self.app_context,
         fulltext_processor_config=self.fulltext_processor_config)
     if config.get('preload_on_startup'):
         self.fulltext_models.preload()
     self.app_features_context = load_app_features_context(
         config, download_manager=self.download_manager)
     tei_to_jats_config = config.get('xslt', {}).get('tei_to_jats', {})
     self.tei_to_jats_xslt_transformer = XsltTransformerWrapper.from_template_file(
         TEI_TO_JATS_XSLT_FILE,
         xslt_template_parameters=tei_to_jats_config.get('parameters', {}))
     # look up the doc-to-pdf config section once instead of three times
     doc_to_pdf_config = config.get('doc_to_pdf', {})
     self.doc_to_pdf_enabled = doc_to_pdf_config.get('enabled', True)
     self.doc_to_pdf_convert_parameters = doc_to_pdf_config.get('convert', {})
     self.doc_converter_wrapper = DocConverterWrapper(
         **doc_to_pdf_config.get('listener', {}))
 def load_from(model_path: str,
               download_manager: DownloadManager,
               wapiti_binary_path: Optional[str] = None) -> 'WapitiModelAdapter':
     """Load a wapiti model from ``model_path`` and wrap it in an adapter.

     Tries ``model.wapiti.gz`` first; if that cannot be downloaded or is
     not a file, falls back to the name without the ``.gz`` extension.
     A ``.gz`` local file is copied to an uncompressed path (wapiti reads
     the uncompressed model file).

     :param model_path: directory (or URL prefix) containing the model file
     :param download_manager: used to fetch the model if it is remote
     :param wapiti_binary_path: optional path to the wapiti executable
       (fixed: previously annotated as plain ``str`` with a ``None`` default)
     :return: a WapitiModelAdapter wrapping the local model file
     """
     model_file_path = os.path.join(model_path, 'model.wapiti.gz')
     local_model_file_path = None
     try:
         local_model_file_path = download_manager.download_if_url(
             model_file_path)
     except FileNotFoundError:
         # fall through to trying the uncompressed file name below
         pass
     if not local_model_file_path or not os.path.isfile(
             str(local_model_file_path)):
         model_file_path = os.path.splitext(model_file_path)[0]
         local_model_file_path = download_manager.download_if_url(
             model_file_path)
     LOGGER.debug('local_model_file_path: %s', local_model_file_path)
     if local_model_file_path.endswith('.gz'):
         # keep an uncompressed copy next to the .gz for wapiti to read
         local_uncompressed_file_path = os.path.splitext(
             local_model_file_path)[0]
         copy_file(local_model_file_path,
                   local_uncompressed_file_path,
                   overwrite=False)
         local_model_file_path = local_uncompressed_file_path
     return WapitiModelAdapter(
         WapitiWrapper(wapiti_binary_path=wapiti_binary_path),
         model_file_path=local_model_file_path,
         model_path=model_path)
Exemplo n.º 3
0
def download_if_url_from_alternatives(
        download_manager: DownloadManager,
        alternative_file_url_or_path_list: Sequence[str]) -> str:
    """Return the first usable local file among the given alternatives.

    First checks, without any network access, whether any alternative
    already exists locally (or has a cached local copy for remote URLs).
    Only when none do, each alternative is downloaded in turn until one
    yields an existing file.

    :raises FileNotFoundError: if no alternative resolves to an existing file
    """
    # pass 1: look for files already on disk, no downloads
    for candidate in alternative_file_url_or_path_list:
        if is_external_location(candidate):
            cached_file = download_manager.get_local_file(candidate)
            if os.path.exists(cached_file):
                return cached_file
        elif os.path.exists(candidate):
            return candidate
        else:
            LOGGER.debug('local file doesnt exist: %r', candidate)
    LOGGER.debug('no existing local files found, downloading: %r',
                 alternative_file_url_or_path_list)
    # pass 2: try to download each alternative until one succeeds
    for candidate in alternative_file_url_or_path_list:
        try:
            downloaded_file = download_manager.download_if_url(candidate)
        except FileNotFoundError:
            LOGGER.debug('remote file not found: %r', candidate)
            continue
        if os.path.exists(downloaded_file):
            return downloaded_file
        LOGGER.debug('local file for %r not found: %r', candidate,
                     downloaded_file)
    raise FileNotFoundError('no file found for %r' %
                            alternative_file_url_or_path_list)
Exemplo n.º 4
0
 def test_should_download_using_passed_in_local_file(
         self, copy_file_mock: MagicMock, download_dir: Path):
     """Downloading should honour an explicitly passed-in target file."""
     target_file = str(download_dir.joinpath('custom.file'))
     manager = DownloadManager(download_dir=str(download_dir))
     result = manager.download(EXTERNAL_TXT_URL_1, local_file=target_file)
     assert result == target_file
     copy_file_mock.assert_called_with(EXTERNAL_TXT_URL_1, target_file)
Exemplo n.º 5
0
def _pdfalto_wrapper(sciencebeam_parser_config: dict) -> PdfAltoWrapper:
    """Create a ready-to-run PdfAltoWrapper from the parser config."""
    manager = DownloadManager(
        download_dir=get_download_dir(sciencebeam_parser_config))
    binary_path = manager.download_if_url(
        sciencebeam_parser_config['pdfalto']['path'])
    wrapper = PdfAltoWrapper(binary_path)
    wrapper.ensure_executable()
    return wrapper
 def do_run(self, args: argparse.Namespace):
     """Train a classification model from the configured input files."""
     LOGGER.info('train')
     manager = DownloadManager()
     input_paths = _flatten_input_paths(args.train_input)
     texts, labels, list_classes = load_input_data(
         input_paths,
         download_manager=manager,
         limit=args.train_input_limit
     )
     LOGGER.info('list_classes: %s', list_classes)
     embedding_name = self.preload_and_validate_embedding(
         args.embeddings,
         use_word_embeddings=True
     )
     # build the configs up-front, then hand everything to train()
     model_config = ModelConfig(
         embeddings_name=embedding_name,
         model_type=args.architecture,
         list_classes=list_classes
     )
     training_config = TrainingConfig(
         batch_size=args.batch_size,
         max_epoch=args.max_epoch,
         log_dir=args.checkpoint
     )
     train(
         app_config=self.app_config,
         model_config=model_config,
         training_config=training_config,
         train_input_texts=texts,
         train_input_labels=labels,
         model_path=args.model_path
     )
Exemplo n.º 7
0
def get_downloaded_input_paths(
        input_paths: List[str],
        download_manager: DownloadManager) -> List[str]:
    """Resolve each input path to a local file, downloading URLs as needed.

    Order of the returned list matches the order of ``input_paths``.
    """
    return list(map(download_manager.download_if_url, input_paths))
Exemplo n.º 8
0
    def run(self, args: argparse.Namespace):
        """Entry point: set up managers, seed RNGs, then run the sub-command."""
        save_path = args.save_input_to_and_exit
        if save_path:
            # only capture the input for later inspection, don't run anything
            save_input_to(args.input, save_path)
            return

        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager
        )
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()

        set_random_seeds(args.random_seed)

        self.do_run(args)

        # clearing the keras session works around
        # https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()
Exemplo n.º 9
0
 def __init__(
         self,
         download_manager: DownloadManager = None,
         embedding_registry_path: str = None,
         embedding_manager: EmbeddingManager = None,
         **kwargs):
     """Initialise download and embedding managers, creating defaults as needed.

     Remaining keyword arguments are forwarded to the base class.
     """
     self.embedding_registry_path = (
         embedding_registry_path or DEFAULT_EMBEDDINGS_PATH
     )
     self.download_manager = (
         download_manager if download_manager is not None
         else DownloadManager()
     )
     # default embedding manager shares the (possibly defaulted) download manager
     self.embedding_manager = (
         embedding_manager if embedding_manager is not None
         else EmbeddingManager(
             path=self.embedding_registry_path,
             download_manager=self.download_manager
         )
     )
     super().__init__(**kwargs)
def load_data_and_labels(input_paths: Optional[List[str]] = None,
                         limit: Optional[int] = None,
                         shuffle_input: bool = False,
                         clean_features: bool = True,
                         random_seed: int = DEFAULT_RANDOM_SEED,
                         download_manager: Optional[DownloadManager] = None):
    """Load CRF training data (tokens, labels, features) from input files.

    :param input_paths: required list of local paths or URLs
      (fixed: previously annotated as non-optional ``List[str]`` with
      a ``None`` default)
    :param limit: optional maximum number of records to load
    :param shuffle_input: shuffle all three arrays in unison if True
    :param clean_features: drop inconsistent feature rows if True
    :param random_seed: seed used when shuffling
    :param download_manager: required, used to fetch remote inputs
    :return: tuple of (tokens, labels, features)
    """
    # NOTE(review): assert is stripped under `python -O`; these act as
    # required-argument checks and may deserve an explicit ValueError
    assert download_manager
    assert input_paths
    LOGGER.info('loading data from: %s', input_paths)
    downloaded_input_paths = [
        download_manager.download_if_url(input_path)
        for input_path in input_paths
    ]
    x_all, y_all, f_all = _load_data_and_labels_crf_files(
        downloaded_input_paths, limit=limit)
    if shuffle_input:
        # shuffled together so rows stay aligned across tokens/labels/features
        shuffle_arrays([x_all, y_all, f_all], random_seed=random_seed)
    log_data_info(x_all, y_all, f_all)
    if clean_features:
        (x_all, y_all, f_all) = get_clean_x_y_features(x_all, y_all, f_all)
    return x_all, y_all, f_all
    def run(self, args: argparse.Namespace):
        """Set up managers and app config, optionally preload embeddings, then run."""
        download_manager = DownloadManager()
        embedding_manager = EmbeddingManager(
            download_manager=download_manager
        )
        self.download_manager = download_manager
        self.embedding_manager = embedding_manager
        self.app_config = AppConfig(
            download_manager=download_manager,
            embedding_manager=embedding_manager
        )
        if args.no_use_lmdb:
            embedding_manager.disable_embedding_lmdb_cache()

        if args.preload_embedding:
            self.preload_and_validate_embedding(
                args.preload_embedding,
                use_word_embeddings=True
            )

        self.do_run(args)

        # clearing the keras session works around
        # https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()
 def do_run(self, args: argparse.Namespace):
     """Evaluate a trained model on the configured evaluation inputs."""
     LOGGER.info('eval')
     manager = DownloadManager()
     input_paths = _flatten_input_paths(args.eval_input)
     label_input_paths = _flatten_input_paths(args.eval_label_input)
     texts, labels, list_classes = load_input_data(
         input_paths,
         download_manager=manager,
         limit=args.eval_input_limit
     )
     # labels may optionally come from separate label files
     if label_input_paths:
         labels, _ = load_label_data(
             label_input_paths,
             download_manager=manager,
             limit=args.eval_input_limit
         )
     LOGGER.info('list_classes: %s', list_classes)
     result = evaluate(
         app_config=self.app_config,
         eval_input_texts=texts,
         eval_input_labels=labels,
         model_path=args.model_path
     )
     print(result.text_formatted_report)
 def do_run(self, args: argparse.Namespace):
     """Run predictions on the configured inputs and write or print results.

     Loads the prediction input data frame, runs the model, then joins the
     per-class predictions back onto the first two input columns. Output is
     written to ``args.predict_output`` when given, otherwise printed as JSON.
     """
     # fixed: previously logged 'train' (copy/paste from the train command)
     LOGGER.info('predict')
     download_manager = DownloadManager()
     predict_input_paths = _flatten_input_paths(args.predict_input)
     predict_df = load_input_data_frame(
         predict_input_paths,
         download_manager=download_manager,
         limit=args.predict_input_limit
     )
     predict_input_texts, _, _ = get_texts_and_classes_from_data_frame(
         predict_df
     )
     result = predict(
         app_config=self.app_config,
         eval_input_texts=predict_input_texts,
         model_path=args.model_path
     )
     list_classes = result['labels']
     prediction = result['prediction']
     LOGGER.info('list_classes: %s', list_classes)
     # keep the first two (identifier) columns and append one column per class
     result_df = pd.concat([
         predict_df[predict_df.columns[:2]],
         pd.DataFrame(
             prediction,
             columns=list_classes,
             index=predict_df.index
         )
     ], axis=1)
     if args.predict_output:
         LOGGER.info('writing output to: %s', args.predict_output)
         save_data_frame(result_df, args.predict_output)
     else:
         print(json.dumps(
             result_df.to_dict(orient='records'),
             indent=2
         ))
Exemplo n.º 14
0
def install_wapiti_and_get_path_or_none(
        install_url: Optional[str],
        download_manager: DownloadManager) -> Optional[str]:
    """Download, extract and build wapiti, returning the built binary path.

    Returns None when no install URL is configured.

    :raises ValueError: if the URL is not a ``.tar.gz`` archive
    """
    if not install_url:
        return None
    if not install_url.endswith(TAR_GZ_EXT):
        raise ValueError('only supporting %s' % TAR_GZ_EXT)
    archive_file = download_manager.download_if_url(install_url,
                                                    auto_uncompress=False)
    target_directory = archive_file[:-len(TAR_GZ_EXT)]
    LOGGER.debug('local_file: %s', archive_file)
    LOGGER.debug('extracting to: %s', target_directory)
    # NOTE(review): extractall on a downloaded archive is vulnerable to path
    # traversal for untrusted URLs - confirm install_url is trusted config
    with tarfile.open(archive_file, mode='r') as tar:
        tar.extractall(target_directory)
    entries = os.listdir(target_directory)
    # a tarball usually unpacks to a single top-level directory; build there
    source_directory = (
        os.path.join(target_directory, entries[0])
        if len(entries) == 1
        else target_directory
    )
    LOGGER.info('running make in %s', source_directory)
    subprocess.check_output('make', cwd=source_directory)
    wapiti_binary = os.path.join(source_directory, 'wapiti')
    LOGGER.info('done, binary: %s', wapiti_binary)
    return wapiti_binary
Exemplo n.º 15
0
 def test_should_unzip_embedding(self, copy_file_mock: MagicMock,
                                 download_manager: DownloadManager):
     """A .gz download should end up under the uncompressed file name."""
     result_file = download_manager.download(EXTERNAL_TXT_GZ_URL_1)
     copy_file_mock.assert_called_with(EXTERNAL_TXT_GZ_URL_1, result_file)
     expected_basename = os.path.basename(EXTERNAL_TXT_URL_1)
     assert str(result_file).endswith(expected_basename)
Exemplo n.º 16
0
 def __init__(self, download_manager: DownloadManager = None):
     """Store the given download manager, creating a default one if omitted."""
     self.download_manager = (
         download_manager if download_manager is not None
         else DownloadManager()
     )
Exemplo n.º 17
0
 def test_get_local_file_should_return_different_path_for_different_url_paths(
         self, download_manager: DownloadManager):
     """Same file name on different hosts must map to distinct local paths."""
     path_1 = download_manager.get_local_file('http://host1/file.txt')
     path_2 = download_manager.get_local_file('http://host2/file.txt')
     assert path_1 != path_2
Exemplo n.º 18
0
 def test_get_local_file_should_return_same_path_for_same_urls(
         self, download_manager: DownloadManager):
     """The URL to local path mapping must be deterministic."""
     path_1 = download_manager.get_local_file(EXTERNAL_TXT_URL_1)
     path_2 = download_manager.get_local_file(EXTERNAL_TXT_URL_1)
     assert path_1 == path_2
Exemplo n.º 19
0
def _download_manager(download_dir: Path):
    """Create a DownloadManager that stores files under ``download_dir``."""
    download_dir_str = str(download_dir)
    return DownloadManager(download_dir=download_dir_str)
 def __init__(
         self, *args,
         use_features: bool = False,
         features_indices: Optional[List[int]] = None,
         features_embedding_size: Optional[int] = None,
         multiprocessing: bool = False,
         embedding_registry_path: Optional[str] = None,
         embedding_manager: Optional[EmbeddingManager] = None,
         config_props: Optional[dict] = None,
         training_props: Optional[dict] = None,
         max_sequence_length: Optional[int] = None,
         input_window_stride: Optional[int] = None,
         eval_max_sequence_length: Optional[int] = None,
         eval_input_window_stride: Optional[int] = None,
         batch_size: Optional[int] = None,
         eval_batch_size: Optional[int] = None,
         stateful: Optional[bool] = None,
         transfer_learning_config: Optional[TransferLearningConfig] = None,
         tag_transformed: bool = False,
         **kwargs):
     """Initialise the sequence model wrapper.

     Resolves embedding management, batching/windowing defaults and the
     model/training configurations, then defers to the base class.
     Unset numeric options fall back to the module's ``get_default_*``
     helpers. ``config_props`` / ``training_props`` are merged on top of
     the base class's configs after ``super().__init__`` has run.
     """
     # initialise logging if not already initialised
     logging.basicConfig(level='INFO')
     LOGGER.debug('Sequence, args=%s, kwargs=%s', args, kwargs)
     self.embedding_registry_path = embedding_registry_path or DEFAULT_EMBEDDINGS_PATH
     # create a default embedding manager (with its own download manager)
     # when none was passed in; the download manager is always taken from
     # the embedding manager so the two stay consistent
     if embedding_manager is None:
         embedding_manager = EmbeddingManager(
             path=self.embedding_registry_path,
             download_manager=DownloadManager()
         )
     self.download_manager = embedding_manager.download_manager
     self.embedding_manager = embedding_manager
     self.embeddings: Optional[Embeddings] = None
     # fall back to module-level defaults for batching / sequence windowing;
     # note `not x` also replaces an explicit 0 with the default
     if not batch_size:
         batch_size = get_default_batch_size()
     if not max_sequence_length:
         max_sequence_length = get_default_max_sequence_length()
     self.max_sequence_length = max_sequence_length
     if not input_window_stride:
         input_window_stride = get_default_input_window_stride()
     self.input_window_stride = input_window_stride
     self.eval_max_sequence_length = eval_max_sequence_length
     self.eval_input_window_stride = eval_input_window_stride
     self.eval_batch_size = eval_batch_size
     self.model_path: Optional[str] = None
     if stateful is None:
         # use a stateful model, if supported
         stateful = get_default_stateful()
     self.stateful = stateful
     self.transfer_learning_config = transfer_learning_config
     self.dataset_transformer_factory = DummyDatasetTransformer
     self.tag_transformed = tag_transformed
     super().__init__(
         *args,
         max_sequence_length=max_sequence_length,
         batch_size=batch_size,
         **kwargs
     )
     LOGGER.debug('use_features=%s', use_features)
     # merge the base model config with config_props overrides; the explicit
     # feature arguments take precedence over any config_props entries
     self.model_config: ModelConfig = ModelConfig(
         **{  # type: ignore
             **vars(self.model_config),
             **(config_props or {}),
             'features_indices': features_indices,
             'features_embedding_size': features_embedding_size
         },
         use_features=use_features
     )
     self.update_model_config_word_embedding_size()
     updated_implicit_model_config_props(self.model_config)
     self.update_dataset_transformer_factor()
     # likewise merge training_props on top of the base training config
     self.training_config: TrainingConfig = TrainingConfig(
         **vars(cast(DelftTrainingConfig, self.training_config)),
         **(training_props or {})
     )
     LOGGER.info('training_config: %s', vars(self.training_config))
     self.multiprocessing = multiprocessing
     self.tag_debug_reporter = get_tag_debug_reporter_if_enabled()
     # state populated later by load/train; kept as None/empty until then
     self._load_exception = None
     self.p: Optional[WordPreprocessor] = None
     self.model: Optional[BaseModel] = None
     self.models: List[BaseModel] = []