Example #1
 def test_should_not_download_if_not_downloaded_but_has_lmdb_cache(
         self, download_manager: MagicMock,
         embedding_manager: EmbeddingManager):
     _create_dummy_lmdb_cache_file(embedding_manager, EMBEDDING_NAME_1)
     assert embedding_manager.ensure_available(
         EXTERNAL_TXT_URL_1) == EMBEDDING_NAME_1
     download_manager.download_if_url.assert_not_called()
     embedding_manager.validate_embedding(EMBEDDING_NAME_1)
Example #2
        def test_should_download_if_config_exists_but_not_downloaded(
                self, download_manager: MagicMock,
                embedding_manager: EmbeddingManager):
            embedding_manager.download_and_install_embedding(
                EXTERNAL_TXT_URL_1)
            download_manager.reset_mock()

            assert embedding_manager.ensure_available(
                EXTERNAL_TXT_URL_1) == EMBEDDING_NAME_1
            download_manager.download_if_url.assert_called_with(
                EXTERNAL_TXT_URL_1)
Example #3
 def test_should_download_registered_embedding(
         self, download_manager: MagicMock,
         embedding_manager: EmbeddingManager, download_path_1: Path):
     embedding_manager.add_embedding_config({
         'name': EMBEDDING_NAME_1,
         'path': str(download_path_1),
         'url': EXTERNAL_TXT_URL_1
     })
     assert embedding_manager.ensure_available(
         EMBEDDING_NAME_1) == EMBEDDING_NAME_1
     download_manager.download.assert_called_with(
         EXTERNAL_TXT_URL_1, local_file=str(download_path_1))
Example #4
 def test_should_skip_if_lmdb_cache_is_disabled(
         self, embedding_manager: EmbeddingManager):
     embedding_manager.add_embedding_config(EMBEDDING_1)
     embedding_manager.disable_embedding_lmdb_cache()
     assert embedding_manager.ensure_available(
         EMBEDDING_NAME_1) == EMBEDDING_NAME_1
     assert not embedding_manager.has_lmdb_cache(EMBEDDING_NAME_1)
     embedding_manager.ensure_lmdb_cache_if_enabled(EMBEDDING_NAME_1)
     assert not embedding_manager.has_lmdb_cache(EMBEDDING_NAME_1)
Example #5
        def test_should_download_and_install_embedding(
                self, download_manager: MagicMock,
                embedding_manager: EmbeddingManager, download_path_1: Path):
            embedding_manager.download_and_install_embedding(
                EXTERNAL_TXT_URL_1)
            download_manager.download_if_url.assert_called_with(
                EXTERNAL_TXT_URL_1)

            embedding_config = embedding_manager.get_embedding_config(
                EMBEDDING_NAME_1)
            assert embedding_config
            assert embedding_config['name'] == EMBEDDING_NAME_1
            assert embedding_config['path'] == str(download_path_1)
Example #6
 def test_should_download_registered_mdb_embedding(
         self, download_manager: MagicMock,
         embedding_manager: EmbeddingManager):
     embedding_manager.add_embedding_config({
         'name': EMBEDDING_NAME_1,
         'url': EXTERNAL_MDB_GZ_URL_1
     })
     assert embedding_manager.ensure_available(
         EMBEDDING_NAME_1) == EMBEDDING_NAME_1
     download_manager.download.assert_called_with(
         EXTERNAL_MDB_GZ_URL_1,
         local_file=str(
             embedding_manager.get_embedding_lmdb_cache_data_path(
                 EMBEDDING_NAME_1)))
Example #7
        def test_should_not_download_if_already_downloaded(
                self, download_manager: MagicMock,
                embedding_manager: EmbeddingManager):
            embedding_manager.download_and_install_embedding(
                EXTERNAL_TXT_URL_1)
            download_manager.reset_mock()
            embedding_config = embedding_manager.get_embedding_config(
                EMBEDDING_NAME_1)
            assert embedding_config
            Path(embedding_config['path']).touch()

            assert embedding_manager.ensure_available(
                EXTERNAL_TXT_URL_1) == EMBEDDING_NAME_1
            download_manager.download_if_url.assert_not_called()
Example #8
        def test_should_unzip_mdb_embedding(
                self, download_manager: MagicMock,
                embedding_manager: EmbeddingManager):
            embedding_manager.download_and_install_embedding(
                EXTERNAL_MDB_GZ_URL_1)
            download_manager.download.assert_called_with(
                EXTERNAL_MDB_GZ_URL_1,
                local_file=str(
                    embedding_manager.get_embedding_lmdb_cache_data_path(
                        EMBEDDING_NAME_1)))

            embedding_config = embedding_manager.get_embedding_config(
                EMBEDDING_NAME_1)
            assert embedding_config
            assert embedding_config['name'] == EMBEDDING_NAME_1
Example #9
def _create_dummy_lmdb_cache_file(embedding_manager: EmbeddingManager,
                                  embedding_name: str):
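    # simulate an existing LMDB cache by creating an empty 'data.mdb' file
    # under <embedding_lmdb_path>/<embedding_name>/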
    embedding_cache_file = Path(
        embedding_manager.get_embedding_lmdb_path()).joinpath(
            embedding_name).joinpath('data.mdb')
    embedding_cache_file.parent.mkdir(parents=True, exist_ok=True)
    embedding_cache_file.touch()
Example #10
    def run(self, args: argparse.Namespace):
        if args.save_input_to_and_exit:
            save_input_to(args.input, args.save_input_to_and_exit)
            return

        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager)
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()

        set_random_seeds(args.random_seed)

        self.do_run(args)

        # see https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()
Example #11
def _embedding_manager(download_manager: MagicMock,
                       embedding_registry_path: Path,
                       embedding_lmdb_path: Path):
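    # build an EmbeddingManager against test-provided registry and LMDB paths;
    # min_lmdb_cache_size=0 presumably disables the minimum-size check so the
    # tiny dummy cache files used in the tests still count as a valid cache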
    embedding_manager = EmbeddingManager(
        str(embedding_registry_path),
        download_manager=download_manager,
        default_embedding_lmdb_path=str(embedding_lmdb_path),
        min_lmdb_cache_size=0)
    return embedding_manager
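
A minimal sketch of how the download_manager and embedding_manager arguments
used by the tests above might be provided as pytest fixtures; the fixture
names, the use of tmp_path, and the file names are assumptions rather than
part of the original test module (EmbeddingManager and the _embedding_manager
helper above are assumed to be importable in the same module):

import pytest
from pathlib import Path
from unittest.mock import MagicMock


@pytest.fixture(name='download_manager')
def _download_manager_fixture() -> MagicMock:
    # the tests only inspect recorded calls, so a bare MagicMock is sufficient
    return MagicMock(name='download_manager')


@pytest.fixture(name='embedding_manager')
def _embedding_manager_fixture(
        download_manager: MagicMock, tmp_path: Path) -> EmbeddingManager:
    # reuse the helper above with per-test temporary paths
    return _embedding_manager(
        download_manager,
        tmp_path / 'embedding-registry.json',
        tmp_path / 'lmdb'
    )
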
Example #12
class BaseSubCommand(SubCommand):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download_manager = None
        self.embedding_manager = None
        self.app_config = None

    @abstractmethod
    def do_run(self, args: argparse.Namespace):
        pass

    def preload_and_validate_embedding(
            self,
            embedding_name: str,
            use_word_embeddings: bool = True) -> Optional[str]:
        if not use_word_embeddings:
            return None
        embedding_name = self.embedding_manager.ensure_available(embedding_name)
        LOGGER.info('embedding_name: %s', embedding_name)
        self.embedding_manager.validate_embedding(embedding_name)
        return embedding_name

    def run(self, args: argparse.Namespace):
        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager
        )
        self.app_config = AppConfig(
            download_manager=self.download_manager,
            embedding_manager=self.embedding_manager
        )
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()

        if args.preload_embedding:
            self.preload_and_validate_embedding(
                args.preload_embedding,
                use_word_embeddings=True
            )

        self.do_run(args)

        # see https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()
Example #14
 def test_should_disable_lmdb_cache(
         self, embedding_manager: EmbeddingManager):
     embedding_manager.disable_embedding_lmdb_cache()
     assert embedding_manager.get_embedding_lmdb_path() is None
Example #15
 def test_should_resolve_embedding_alias(
         self, embedding_manager: EmbeddingManager):
     embedding_manager.set_embedding_aliases(
         {EMBEDDING_ALIAS_1: EMBEDDING_NAME_1})
     assert embedding_manager.resolve_alias(
         EMBEDDING_ALIAS_1) == EMBEDDING_NAME_1
Example #16
 def test_should_return_passed_in_embedding_name_by_default(
         self, embedding_manager: EmbeddingManager):
     assert embedding_manager.resolve_alias(
         EMBEDDING_NAME_1) == EMBEDDING_NAME_1
Example #17
 def __init__(
         self, *args,
         use_features: bool = False,
         features_indices: List[int] = None,
         features_embedding_size: int = None,
         multiprocessing: bool = False,
         embedding_registry_path: str = None,
         embedding_manager: EmbeddingManager = None,
         config_props: dict = None,
         training_props: dict = None,
         max_sequence_length: int = None,
         input_window_stride: int = None,
         eval_max_sequence_length: int = None,
         eval_input_window_stride: int = None,
         batch_size: int = None,
         eval_batch_size: int = None,
         stateful: bool = None,
         transfer_learning_config: TransferLearningConfig = None,
         tag_transformed: bool = False,
         **kwargs):
     # initialise logging if not already initialised
     logging.basicConfig(level='INFO')
     LOGGER.debug('Sequence, args=%s, kwargs=%s', args, kwargs)
     self.embedding_registry_path = embedding_registry_path or DEFAULT_EMBEDDINGS_PATH
     if embedding_manager is None:
         embedding_manager = EmbeddingManager(
             path=self.embedding_registry_path,
             download_manager=DownloadManager()
         )
     self.download_manager = embedding_manager.download_manager
     self.embedding_manager = embedding_manager
     self.embeddings: Optional[Embeddings] = None
     if not batch_size:
         batch_size = get_default_batch_size()
     if not max_sequence_length:
         max_sequence_length = get_default_max_sequence_length()
     self.max_sequence_length = max_sequence_length
     if not input_window_stride:
         input_window_stride = get_default_input_window_stride()
     self.input_window_stride = input_window_stride
     self.eval_max_sequence_length = eval_max_sequence_length
     self.eval_input_window_stride = eval_input_window_stride
     self.eval_batch_size = eval_batch_size
     self.model_path: Optional[str] = None
     if stateful is None:
         # use a stateful model, if supported
         stateful = get_default_stateful()
     self.stateful = stateful
     self.transfer_learning_config = transfer_learning_config
     self.dataset_transformer_factory = DummyDatasetTransformer
     self.tag_transformed = tag_transformed
     super().__init__(
         *args,
         max_sequence_length=max_sequence_length,
         batch_size=batch_size,
         **kwargs
     )
     LOGGER.debug('use_features=%s', use_features)
     self.model_config: ModelConfig = ModelConfig(
         **{  # type: ignore
             **vars(self.model_config),
             **(config_props or {}),
             'features_indices': features_indices,
             'features_embedding_size': features_embedding_size
         },
         use_features=use_features
     )
     self.update_model_config_word_embedding_size()
     updated_implicit_model_config_props(self.model_config)
     self.update_dataset_transformer_factor()
     self.training_config: TrainingConfig = TrainingConfig(
         **vars(cast(DelftTrainingConfig, self.training_config)),
         **(training_props or {})
     )
     LOGGER.info('training_config: %s', vars(self.training_config))
     self.multiprocessing = multiprocessing
     self.tag_debug_reporter = get_tag_debug_reporter_if_enabled()
     self._load_exception = None
     self.p: Optional[WordPreprocessor] = None
     self.model: Optional[BaseModel] = None
     self.models: List[BaseModel] = []
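
A minimal usage sketch for the constructor above, assuming the class is named
Sequence (as the debug log message suggests) and that the first positional
argument is the model name; the model name, registry path, and parameter
values shown are hypothetical:

embedding_manager = EmbeddingManager(
    path='embedding-registry.json',  # hypothetical registry path
    download_manager=DownloadManager()
)
model = Sequence(
    'dummy-model',  # hypothetical model name, passed through *args
    embedding_manager=embedding_manager,
    batch_size=16,
    max_sequence_length=500
)
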
Example #18
class GrobidTrainerSubCommand(SubCommand):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download_manager = None
        self.embedding_manager = None

    @abstractmethod
    def do_run(self, args: argparse.Namespace):
        pass

    def preload_and_validate_embedding(
            self,
            embedding_name: str,
            use_word_embeddings: bool = True) -> Optional[str]:
        if not use_word_embeddings:
            return None
        embedding_name = self.embedding_manager.ensure_available(
            embedding_name)
        LOGGER.info('embedding_name: %s', embedding_name)
        self.embedding_manager.validate_embedding(embedding_name)
        return embedding_name

    def get_common_args(self, args: argparse.Namespace) -> dict:
        return dict(model_name=args.model,
                    input_paths=args.input,
                    limit=args.limit,
                    shuffle_input=args.shuffle_input,
                    random_seed=args.random_seed,
                    batch_size=args.batch_size,
                    max_sequence_length=args.max_sequence_length,
                    multiprocessing=args.multiprocessing,
                    embedding_manager=self.embedding_manager,
                    download_manager=self.download_manager)

    def get_train_args(self, args: argparse.Namespace) -> dict:
        return dict(
            architecture=args.architecture,
            use_ELMo=args.use_ELMo,
            output_path=args.output,
            log_dir=args.checkpoint,
            char_emb_size=args.char_embedding_size,
            char_lstm_units=args.char_lstm_units,
            word_lstm_units=args.word_lstm_units,
            dropout=args.dropout,
            recurrent_dropout=args.recurrent_dropout,
            max_epoch=args.max_epoch,
            use_features=args.use_features,
            features_indices=args.features_indices,
            features_embedding_size=args.features_embedding_size,
            patience=args.early_stopping_patience,
            config_props=dict(
                max_char_length=args.max_char_length,
                char_input_mask_zero=args.char_input_mask_zero,
                char_input_dropout=args.char_input_dropout,
                char_lstm_dropout=args.char_lstm_dropout,
                additional_token_feature_indices=(
                    args.additional_token_feature_indices),
                text_feature_indices=args.text_feature_indices,
                unroll_text_feature_index=args.unroll_text_feature_index,
                concatenated_embeddings_token_count=(
                    args.concatenated_embeddings_token_count),
                use_word_embeddings=args.use_word_embeddings,
                use_features_indices_input=args.use_features_indices_input,
                continuous_features_indices=args.continuous_features_indices,
                features_lstm_units=args.features_lstm_units,
                stateful=args.stateful),
            training_props=dict(
                initial_epoch=args.initial_epoch,
                input_window_stride=args.input_window_stride,
                checkpoint_epoch_interval=args.checkpoint_epoch_interval),
            resume_train_model_path=args.resume_train_model_path,
            auto_resume=args.auto_resume,
            transfer_learning_config=(
                get_transfer_learning_config_for_parsed_args(args)),
            train_notification_manager=get_train_notification_manager(args),
            **self.get_common_args(args))

    def run(self, args: argparse.Namespace):
        if args.save_input_to_and_exit:
            save_input_to(args.input, args.save_input_to_and_exit)
            return

        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager)
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()

        set_random_seeds(args.random_seed)

        self.do_run(args)

        # see https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()
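
A minimal sketch of how a concrete training sub-command might build on
GrobidTrainerSubCommand above; the class name, the args.embeddings argument,
and the train() helper are assumptions, not part of the original code:

class TrainSubCommand(GrobidTrainerSubCommand):
    def do_run(self, args: argparse.Namespace):
        # resolve, download and validate the embedding before training
        embedding_name = self.preload_and_validate_embedding(
            args.embeddings,  # hypothetical argument name
            use_word_embeddings=args.use_word_embeddings
        )
        LOGGER.info('using embedding: %s', embedding_name)
        # hypothetical train() helper receiving the keyword arguments
        # assembled by get_train_args()
        train(embedding_name=embedding_name, **self.get_train_args(args))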