def test_should_not_download_if_not_downloaded_but_has_lmdb_cache(
        self, download_manager: MagicMock, embedding_manager: EmbeddingManager):
    _create_dummy_lmdb_cache_file(embedding_manager, EMBEDDING_NAME_1)
    assert embedding_manager.ensure_available(
        EXTERNAL_TXT_URL_1) == EMBEDDING_NAME_1
    download_manager.download_if_url.assert_not_called()
    embedding_manager.validate_embedding(EMBEDDING_NAME_1)

def test_should_download_if_config_exists_but_not_downloaded(
        self, download_manager: MagicMock, embedding_manager: EmbeddingManager):
    embedding_manager.download_and_install_embedding(
        EXTERNAL_TXT_URL_1)
    download_manager.reset_mock()
    assert embedding_manager.ensure_available(
        EXTERNAL_TXT_URL_1) == EMBEDDING_NAME_1
    download_manager.download_if_url.assert_called_with(
        EXTERNAL_TXT_URL_1)

def test_should_download_registered_embedding(
        self, download_manager: MagicMock, embedding_manager: EmbeddingManager,
        download_path_1: Path):
    embedding_manager.add_embedding_config({
        'name': EMBEDDING_NAME_1,
        'path': str(download_path_1),
        'url': EXTERNAL_TXT_URL_1
    })
    assert embedding_manager.ensure_available(
        EMBEDDING_NAME_1) == EMBEDDING_NAME_1
    download_manager.download.assert_called_with(
        EXTERNAL_TXT_URL_1, local_file=str(download_path_1))

def test_should_skip_if_lmdb_cache_is_disabled(
        self, embedding_manager: EmbeddingManager):
    embedding_manager.add_embedding_config(EMBEDDING_1)
    embedding_manager.disable_embedding_lmdb_cache()
    assert embedding_manager.ensure_available(
        EMBEDDING_NAME_1) == EMBEDDING_NAME_1
    assert not embedding_manager.has_lmdb_cache(EMBEDDING_NAME_1)
    embedding_manager.ensure_lmdb_cache_if_enabled(EMBEDDING_NAME_1)
    assert not embedding_manager.has_lmdb_cache(EMBEDDING_NAME_1)

def test_should_download_and_install_embedding(
        self, download_manager: MagicMock, embedding_manager: EmbeddingManager,
        download_path_1: Path):
    embedding_manager.download_and_install_embedding(
        EXTERNAL_TXT_URL_1)
    download_manager.download_if_url.assert_called_with(
        EXTERNAL_TXT_URL_1)
    embedding_config = embedding_manager.get_embedding_config(
        EMBEDDING_NAME_1)
    assert embedding_config
    assert embedding_config['name'] == EMBEDDING_NAME_1
    assert embedding_config['path'] == str(download_path_1)

def test_should_download_registered_mdb_embedding(
        self, download_manager: MagicMock, embedding_manager: EmbeddingManager):
    embedding_manager.add_embedding_config({
        'name': EMBEDDING_NAME_1,
        'url': EXTERNAL_MDB_GZ_URL_1
    })
    assert embedding_manager.ensure_available(
        EMBEDDING_NAME_1) == EMBEDDING_NAME_1
    download_manager.download.assert_called_with(
        EXTERNAL_MDB_GZ_URL_1,
        local_file=str(
            embedding_manager.get_embedding_lmdb_cache_data_path(
                EMBEDDING_NAME_1)))

def test_should_not_download_if_already_downloaded(
        self, download_manager: MagicMock, embedding_manager: EmbeddingManager):
    embedding_manager.download_and_install_embedding(
        EXTERNAL_TXT_URL_1)
    download_manager.reset_mock()
    embedding_config = embedding_manager.get_embedding_config(
        EMBEDDING_NAME_1)
    assert embedding_config
    Path(embedding_config['path']).touch()
    assert embedding_manager.ensure_available(
        EXTERNAL_TXT_URL_1) == EMBEDDING_NAME_1
    download_manager.download_if_url.assert_not_called()

def test_should_unzip_mdb_embedding(
        self, download_manager: MagicMock, embedding_manager: EmbeddingManager):
    embedding_manager.download_and_install_embedding(
        EXTERNAL_MDB_GZ_URL_1)
    download_manager.download.assert_called_with(
        EXTERNAL_MDB_GZ_URL_1,
        local_file=str(
            embedding_manager.get_embedding_lmdb_cache_data_path(
                EMBEDDING_NAME_1)))
    embedding_config = embedding_manager.get_embedding_config(
        EMBEDDING_NAME_1)
    assert embedding_config
    assert embedding_config['name'] == EMBEDDING_NAME_1

def _create_dummy_lmdb_cache_file(
        embedding_manager: EmbeddingManager, embedding_name: str):
    # create an empty <lmdb_path>/<embedding_name>/data.mdb file,
    # so the manager sees an existing lmdb cache for the embedding
    embedding_cache_file = Path(
        embedding_manager.get_embedding_lmdb_path()
    ).joinpath(embedding_name).joinpath('data.mdb')
    embedding_cache_file.parent.mkdir(parents=True, exist_ok=True)
    embedding_cache_file.touch()

def _embedding_manager(
        download_manager: MagicMock,
        embedding_registry_path: Path,
        embedding_lmdb_path: Path):
    embedding_manager = EmbeddingManager(
        str(embedding_registry_path),
        download_manager=download_manager,
        default_embedding_lmdb_path=str(embedding_lmdb_path),
        min_lmdb_cache_size=0)
    return embedding_manager

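# A hedged sketch of how the factory above might be exposed as pytest
# fixtures for the tests in this file. The fixture names and the use of
# MagicMock/tmp_path are assumptions for illustration; only EmbeddingManager
# and _embedding_manager come from the code above.
import pytest
from pathlib import Path
from unittest.mock import MagicMock


@pytest.fixture(name='download_manager')
def _download_manager_fixture() -> MagicMock:
    # a plain mock; tests assert on download / download_if_url calls
    return MagicMock(name='download_manager')


@pytest.fixture(name='embedding_manager')
def _embedding_manager_fixture(
        download_manager: MagicMock, tmp_path: Path) -> EmbeddingManager:
    return _embedding_manager(
        download_manager,
        embedding_registry_path=tmp_path / 'embedding-registry.json',
        embedding_lmdb_path=tmp_path / 'lmdb'
    )
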
class BaseSubCommand(SubCommand):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download_manager = None
        self.embedding_manager = None
        self.app_config = None

    @abstractmethod
    def do_run(self, args: argparse.Namespace):
        pass

    def preload_and_validate_embedding(
            self,
            embedding_name: str,
            use_word_embeddings: bool = True) -> Optional[str]:
        if not use_word_embeddings:
            return None
        embedding_name = self.embedding_manager.ensure_available(embedding_name)
        LOGGER.info('embedding_name: %s', embedding_name)
        self.embedding_manager.validate_embedding(embedding_name)
        return embedding_name

    def run(self, args: argparse.Namespace):
        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager
        )
        self.app_config = AppConfig(
            download_manager=self.download_manager,
            embedding_manager=self.embedding_manager
        )
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()
        if args.preload_embedding:
            self.preload_and_validate_embedding(
                args.preload_embedding, use_word_embeddings=True
            )
        self.do_run(args)
        # see https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()

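# A minimal, hypothetical subclass to illustrate the BaseSubCommand contract:
# do_run is the only abstract method, and by the time it is called run() has
# already set up the download manager, embedding manager and app config
# (and preloaded embeddings if args.preload_embedding was given). The
# ExampleSubCommand name, its body and args.model are illustrative
# assumptions, not taken from the source.
class ExampleSubCommand(BaseSubCommand):
    def do_run(self, args: argparse.Namespace):
        # self.embedding_manager and self.app_config are ready to use here
        LOGGER.info('running example sub command for model: %s', args.model)
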
def test_should_disable_lmdb_cache(
        self, embedding_manager: EmbeddingManager):
    embedding_manager.disable_embedding_lmdb_cache()
    assert embedding_manager.get_embedding_lmdb_path() is None

def test_should_resolve_embedding_alias(
        self, embedding_manager: EmbeddingManager):
    embedding_manager.set_embedding_aliases(
        {EMBEDDING_ALIAS_1: EMBEDDING_NAME_1})
    assert embedding_manager.resolve_alias(
        EMBEDDING_ALIAS_1) == EMBEDDING_NAME_1

def test_should_return_passed_in_embedding_name_by_default(
        self, embedding_manager: EmbeddingManager):
    assert embedding_manager.resolve_alias(
        EMBEDDING_NAME_1) == EMBEDDING_NAME_1

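# A hedged recap of the alias behaviour exercised by the two tests above:
# set_embedding_aliases registers alternative names, resolve_alias maps an
# alias to the registered embedding name, and otherwise returns the input
# unchanged. The 'glove-alias'/'glove-embedding' names are illustrative,
# not taken from the source.
embedding_manager.set_embedding_aliases({'glove-alias': 'glove-embedding'})
assert embedding_manager.resolve_alias('glove-alias') == 'glove-embedding'
assert embedding_manager.resolve_alias('other-name') == 'other-name'
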
def __init__(
        self, *args,
        use_features: bool = False,
        features_indices: List[int] = None,
        features_embedding_size: int = None,
        multiprocessing: bool = False,
        embedding_registry_path: str = None,
        embedding_manager: EmbeddingManager = None,
        config_props: dict = None,
        training_props: dict = None,
        max_sequence_length: int = None,
        input_window_stride: int = None,
        eval_max_sequence_length: int = None,
        eval_input_window_stride: int = None,
        batch_size: int = None,
        eval_batch_size: int = None,
        stateful: bool = None,
        transfer_learning_config: TransferLearningConfig = None,
        tag_transformed: bool = False,
        **kwargs):
    # initialise logging if not already initialised
    logging.basicConfig(level='INFO')
    LOGGER.debug('Sequence, args=%s, kwargs=%s', args, kwargs)
    self.embedding_registry_path = embedding_registry_path or DEFAULT_EMBEDDINGS_PATH
    if embedding_manager is None:
        embedding_manager = EmbeddingManager(
            path=self.embedding_registry_path,
            download_manager=DownloadManager()
        )
    self.download_manager = embedding_manager.download_manager
    self.embedding_manager = embedding_manager
    self.embeddings: Optional[Embeddings] = None
    if not batch_size:
        batch_size = get_default_batch_size()
    if not max_sequence_length:
        max_sequence_length = get_default_max_sequence_length()
    self.max_sequence_length = max_sequence_length
    if not input_window_stride:
        input_window_stride = get_default_input_window_stride()
    self.input_window_stride = input_window_stride
    self.eval_max_sequence_length = eval_max_sequence_length
    self.eval_input_window_stride = eval_input_window_stride
    self.eval_batch_size = eval_batch_size
    self.model_path: Optional[str] = None
    if stateful is None:
        # use a stateful model, if supported
        stateful = get_default_stateful()
    self.stateful = stateful
    self.transfer_learning_config = transfer_learning_config
    self.dataset_transformer_factory = DummyDatasetTransformer
    self.tag_transformed = tag_transformed
    super().__init__(
        *args,
        max_sequence_length=max_sequence_length,
        batch_size=batch_size,
        **kwargs
    )
    LOGGER.debug('use_features=%s', use_features)
    self.model_config: ModelConfig = ModelConfig(
        **{  # type: ignore
            **vars(self.model_config),
            **(config_props or {}),
            'features_indices': features_indices,
            'features_embedding_size': features_embedding_size
        },
        use_features=use_features
    )
    self.update_model_config_word_embedding_size()
    updated_implicit_model_config_props(self.model_config)
    self.update_dataset_transformer_factor()
    self.training_config: TrainingConfig = TrainingConfig(
        **vars(cast(DelftTrainingConfig, self.training_config)),
        **(training_props or {})
    )
    LOGGER.info('training_config: %s', vars(self.training_config))
    self.multiprocessing = multiprocessing
    self.tag_debug_reporter = get_tag_debug_reporter_if_enabled()
    self._load_exception = None
    self.p: Optional[WordPreprocessor] = None
    self.model: Optional[BaseModel] = None
    self.models: List[BaseModel] = []

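# A hedged usage sketch: constructing the wrapper with an injected
# EmbeddingManager so that callers (or tests) control where embeddings are
# registered and cached. The 'Sequence' class name comes from the debug log
# message above; the model name positional argument and the registry path
# shown here are illustrative assumptions.
download_manager = DownloadManager()
embedding_manager = EmbeddingManager(
    path='embedding-registry.json',  # assumed registry path, for illustration
    download_manager=download_manager
)
model = Sequence(
    'my-model',  # hypothetical model name, passed through *args
    embedding_manager=embedding_manager,
    batch_size=16,
    max_sequence_length=500,
    use_features=False
)
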
class GrobidTrainerSubCommand(SubCommand):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download_manager = None
        self.embedding_manager = None

    @abstractmethod
    def do_run(self, args: argparse.Namespace):
        pass

    def preload_and_validate_embedding(
            self,
            embedding_name: str,
            use_word_embeddings: bool = True) -> Optional[str]:
        if not use_word_embeddings:
            return None
        embedding_name = self.embedding_manager.ensure_available(
            embedding_name)
        LOGGER.info('embedding_name: %s', embedding_name)
        self.embedding_manager.validate_embedding(embedding_name)
        return embedding_name

    def get_common_args(self, args: argparse.Namespace) -> dict:
        return dict(
            model_name=args.model,
            input_paths=args.input,
            limit=args.limit,
            shuffle_input=args.shuffle_input,
            random_seed=args.random_seed,
            batch_size=args.batch_size,
            max_sequence_length=args.max_sequence_length,
            multiprocessing=args.multiprocessing,
            embedding_manager=self.embedding_manager,
            download_manager=self.download_manager
        )

    def get_train_args(self, args: argparse.Namespace) -> dict:
        return dict(
            architecture=args.architecture,
            use_ELMo=args.use_ELMo,
            output_path=args.output,
            log_dir=args.checkpoint,
            char_emb_size=args.char_embedding_size,
            char_lstm_units=args.char_lstm_units,
            word_lstm_units=args.word_lstm_units,
            dropout=args.dropout,
            recurrent_dropout=args.recurrent_dropout,
            max_epoch=args.max_epoch,
            use_features=args.use_features,
            features_indices=args.features_indices,
            features_embedding_size=args.features_embedding_size,
            patience=args.early_stopping_patience,
            config_props=dict(
                max_char_length=args.max_char_length,
                char_input_mask_zero=args.char_input_mask_zero,
                char_input_dropout=args.char_input_dropout,
                char_lstm_dropout=args.char_lstm_dropout,
                additional_token_feature_indices=args.additional_token_feature_indices,
                text_feature_indices=args.text_feature_indices,
                unroll_text_feature_index=args.unroll_text_feature_index,
                concatenated_embeddings_token_count=args.concatenated_embeddings_token_count,
                use_word_embeddings=args.use_word_embeddings,
                use_features_indices_input=args.use_features_indices_input,
                continuous_features_indices=args.continuous_features_indices,
                features_lstm_units=args.features_lstm_units,
                stateful=args.stateful
            ),
            training_props=dict(
                initial_epoch=args.initial_epoch,
                input_window_stride=args.input_window_stride,
                checkpoint_epoch_interval=args.checkpoint_epoch_interval
            ),
            resume_train_model_path=args.resume_train_model_path,
            auto_resume=args.auto_resume,
            transfer_learning_config=get_transfer_learning_config_for_parsed_args(args),
            train_notification_manager=get_train_notification_manager(args),
            **self.get_common_args(args)
        )

    def run(self, args: argparse.Namespace):
        if args.save_input_to_and_exit:
            save_input_to(args.input, args.save_input_to_and_exit)
            return
        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager)
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()
        set_random_seeds(args.random_seed)
        self.do_run(args)
        # see https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()