def test_should_skip_if_lmdb_cache_is_disabled(
        self, embedding_manager: EmbeddingManager):
    embedding_manager.add_embedding_config(EMBEDDING_1)
    embedding_manager.disable_embedding_lmdb_cache()
    assert embedding_manager.ensure_available(
        EMBEDDING_NAME_1
    ) == EMBEDDING_NAME_1
    assert not embedding_manager.has_lmdb_cache(EMBEDDING_NAME_1)
    embedding_manager.ensure_lmdb_cache_if_enabled(EMBEDDING_NAME_1)
    assert not embedding_manager.has_lmdb_cache(EMBEDDING_NAME_1)

class BaseSubCommand(SubCommand):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download_manager = None
        self.embedding_manager = None
        self.app_config = None

    @abstractmethod
    def do_run(self, args: argparse.Namespace):
        pass

    def preload_and_validate_embedding(
            self,
            embedding_name: str,
            use_word_embeddings: bool = True) -> Optional[str]:
        if not use_word_embeddings:
            return None
        # resolve the embedding (downloading it if necessary) before validating it
        embedding_name = self.embedding_manager.ensure_available(embedding_name)
        LOGGER.info('embedding_name: %s', embedding_name)
        self.embedding_manager.validate_embedding(embedding_name)
        return embedding_name

    def run(self, args: argparse.Namespace):
        # shared wiring for all subcommands: managers, app config,
        # optional LMDB cache opt-out and embedding preloading
        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager
        )
        self.app_config = AppConfig(
            download_manager=self.download_manager,
            embedding_manager=self.embedding_manager
        )
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()
        if args.preload_embedding:
            self.preload_and_validate_embedding(
                args.preload_embedding, use_word_embeddings=True
            )
        self.do_run(args)
        # see https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()

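# A minimal sketch (not part of the original source) of how a concrete
# subcommand might build on BaseSubCommand: only do_run needs to be
# implemented, since run() already wires up the managers, honours the
# no_use_lmdb flag and handles embedding preloading before delegating here.
# The TagSubCommand name and tag_input() helper are hypothetical.
class TagSubCommand(BaseSubCommand):
    def do_run(self, args: argparse.Namespace):
        # self.app_config was populated by BaseSubCommand.run()
        tag_input(input_paths=args.input, app_config=self.app_config)
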
def test_should_disable_lmdb_cache(
        self, embedding_manager: EmbeddingManager):
    embedding_manager.disable_embedding_lmdb_cache()
    assert embedding_manager.get_embedding_lmdb_path() is None

class GrobidTrainerSubCommand(SubCommand):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download_manager = None
        self.embedding_manager = None

    @abstractmethod
    def do_run(self, args: argparse.Namespace):
        pass

    def preload_and_validate_embedding(
            self,
            embedding_name: str,
            use_word_embeddings: bool = True) -> Optional[str]:
        if not use_word_embeddings:
            return None
        # resolve the embedding (downloading it if necessary) before validating it
        embedding_name = self.embedding_manager.ensure_available(
            embedding_name)
        LOGGER.info('embedding_name: %s', embedding_name)
        self.embedding_manager.validate_embedding(embedding_name)
        return embedding_name

    def get_common_args(self, args: argparse.Namespace) -> dict:
        # arguments shared between train and evaluate style subcommands
        return dict(
            model_name=args.model,
            input_paths=args.input,
            limit=args.limit,
            shuffle_input=args.shuffle_input,
            random_seed=args.random_seed,
            batch_size=args.batch_size,
            max_sequence_length=args.max_sequence_length,
            multiprocessing=args.multiprocessing,
            embedding_manager=self.embedding_manager,
            download_manager=self.download_manager
        )

    def get_train_args(self, args: argparse.Namespace) -> dict:
        return dict(
            architecture=args.architecture,
            use_ELMo=args.use_ELMo,
            output_path=args.output,
            log_dir=args.checkpoint,
            char_emb_size=args.char_embedding_size,
            char_lstm_units=args.char_lstm_units,
            word_lstm_units=args.word_lstm_units,
            dropout=args.dropout,
            recurrent_dropout=args.recurrent_dropout,
            max_epoch=args.max_epoch,
            use_features=args.use_features,
            features_indices=args.features_indices,
            features_embedding_size=args.features_embedding_size,
            patience=args.early_stopping_patience,
            config_props=dict(
                max_char_length=args.max_char_length,
                char_input_mask_zero=args.char_input_mask_zero,
                char_input_dropout=args.char_input_dropout,
                char_lstm_dropout=args.char_lstm_dropout,
                additional_token_feature_indices=args.additional_token_feature_indices,
                text_feature_indices=args.text_feature_indices,
                unroll_text_feature_index=args.unroll_text_feature_index,
                concatenated_embeddings_token_count=args.concatenated_embeddings_token_count,
                use_word_embeddings=args.use_word_embeddings,
                use_features_indices_input=args.use_features_indices_input,
                continuous_features_indices=args.continuous_features_indices,
                features_lstm_units=args.features_lstm_units,
                stateful=args.stateful
            ),
            training_props=dict(
                initial_epoch=args.initial_epoch,
                input_window_stride=args.input_window_stride,
                checkpoint_epoch_interval=args.checkpoint_epoch_interval
            ),
            resume_train_model_path=args.resume_train_model_path,
            auto_resume=args.auto_resume,
            transfer_learning_config=get_transfer_learning_config_for_parsed_args(args),
            train_notification_manager=get_train_notification_manager(args),
            **self.get_common_args(args)
        )

    def run(self, args: argparse.Namespace):
        if args.save_input_to_and_exit:
            save_input_to(args.input, args.save_input_to_and_exit)
            return
        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager)
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()
        set_random_seeds(args.random_seed)
        self.do_run(args)
        # see https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()

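# A hedged sketch (assumption, not the repository's actual subcommand): a
# concrete training subcommand could combine preload_and_validate_embedding
# with get_train_args, skipping the embedding lookup when word embeddings
# are disabled (preload_and_validate_embedding then returns None).
# The TrainSubCommand name and train() helper are hypothetical.
class TrainSubCommand(GrobidTrainerSubCommand):
    def do_run(self, args: argparse.Namespace):
        embedding_name = self.preload_and_validate_embedding(
            args.embedding, use_word_embeddings=args.use_word_embeddings
        )
        train(embedding_name=embedding_name, **self.get_train_args(args))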