Example #1
def test_should_skip_if_lmdb_cache_is_disabled(
        self, embedding_manager: EmbeddingManager):
    embedding_manager.add_embedding_config(EMBEDDING_1)
    embedding_manager.disable_embedding_lmdb_cache()
    assert embedding_manager.ensure_available(
        EMBEDDING_NAME_1) == EMBEDDING_NAME_1
    assert not embedding_manager.has_lmdb_cache(EMBEDDING_NAME_1)
    # with the cache disabled, ensure_lmdb_cache_if_enabled is a no-op
    embedding_manager.ensure_lmdb_cache_if_enabled(EMBEDDING_NAME_1)
    assert not embedding_manager.has_lmdb_cache(EMBEDDING_NAME_1)
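The test above receives `embedding_manager` as a pytest fixture. A minimal sketch of such a fixture, assuming the `EmbeddingManager(download_manager=...)` constructor used in Example #2 (the fixture itself is hypothetical; the real test suite may configure additional paths):

import pytest

@pytest.fixture
def embedding_manager() -> EmbeddingManager:
    # Hypothetical fixture: wires an EmbeddingManager the same way
    # BaseSubCommand.run() does in Example #2.
    return EmbeddingManager(download_manager=DownloadManager())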
Example #2
class BaseSubCommand(SubCommand):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download_manager = None
        self.embedding_manager = None
        self.app_config = None

    @abstractmethod
    def do_run(self, args: argparse.Namespace):
        pass

    def preload_and_validate_embedding(
            self,
            embedding_name: str,
            use_word_embeddings: bool = True) -> Optional[str]:
        if not use_word_embeddings:
            return None
        embedding_name = self.embedding_manager.ensure_available(embedding_name)
        LOGGER.info('embedding_name: %s', embedding_name)
        self.embedding_manager.validate_embedding(embedding_name)
        return embedding_name

    def run(self, args: argparse.Namespace):
        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager
        )
        self.app_config = AppConfig(
            download_manager=self.download_manager,
            embedding_manager=self.embedding_manager
        )
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()

        if args.preload_embedding:
            self.preload_and_validate_embedding(
                args.preload_embedding,
                use_word_embeddings=True
            )

        self.do_run(args)

        # see https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()
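Since `do_run` is abstract, `BaseSubCommand` is only usable through a subclass. A minimal sketch of a concrete subcommand (the class name is hypothetical and illustrative only):

class PreloadOnlySubCommand(BaseSubCommand):
    # Hypothetical subclass for illustration.
    def do_run(self, args: argparse.Namespace):
        # By the time do_run() is called, run() has already created
        # download_manager, embedding_manager and app_config, and has
        # applied the no_use_lmdb / preload_embedding options.
        LOGGER.info('app config ready: %r', self.app_config)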
Example #3
def test_should_disable_lmdb_cache(
        self, embedding_manager: EmbeddingManager):
    embedding_manager.disable_embedding_lmdb_cache()
    # disabling the cache clears the configured LMDB path
    assert embedding_manager.get_embedding_lmdb_path() is None
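The test only pins down the observable contract: after `disable_embedding_lmdb_cache()`, `get_embedding_lmdb_path()` returns None. A sketch of a manager satisfying that contract (illustrative only, not the library's actual implementation):

class EmbeddingManagerSketch:
    def __init__(self, embedding_lmdb_path: str = 'data/embedding-lmdb'):
        # Illustrative default path; the real manager takes its
        # configuration elsewhere.
        self._embedding_lmdb_path = embedding_lmdb_path

    def disable_embedding_lmdb_cache(self):
        self._embedding_lmdb_path = None

    def get_embedding_lmdb_path(self):
        return self._embedding_lmdb_path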
Example #4
class GrobidTrainerSubCommand(SubCommand):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download_manager = None
        self.embedding_manager = None

    @abstractmethod
    def do_run(self, args: argparse.Namespace):
        pass

    def preload_and_validate_embedding(
            self,
            embedding_name: str,
            use_word_embeddings: bool = True) -> Optional[str]:
        if not use_word_embeddings:
            return None
        embedding_name = self.embedding_manager.ensure_available(
            embedding_name)
        LOGGER.info('embedding_name: %s', embedding_name)
        self.embedding_manager.validate_embedding(embedding_name)
        return embedding_name

    def get_common_args(self, args: argparse.Namespace) -> dict:
        return dict(model_name=args.model,
                    input_paths=args.input,
                    limit=args.limit,
                    shuffle_input=args.shuffle_input,
                    random_seed=args.random_seed,
                    batch_size=args.batch_size,
                    max_sequence_length=args.max_sequence_length,
                    multiprocessing=args.multiprocessing,
                    embedding_manager=self.embedding_manager,
                    download_manager=self.download_manager)

    def get_train_args(self, args: argparse.Namespace) -> dict:
        return dict(
            architecture=args.architecture,
            use_ELMo=args.use_ELMo,
            output_path=args.output,
            log_dir=args.checkpoint,
            char_emb_size=args.char_embedding_size,
            char_lstm_units=args.char_lstm_units,
            word_lstm_units=args.word_lstm_units,
            dropout=args.dropout,
            recurrent_dropout=args.recurrent_dropout,
            max_epoch=args.max_epoch,
            use_features=args.use_features,
            features_indices=args.features_indices,
            features_embedding_size=args.features_embedding_size,
            patience=args.early_stopping_patience,
            config_props=dict(
                max_char_length=args.max_char_length,
                char_input_mask_zero=args.char_input_mask_zero,
                char_input_dropout=args.char_input_dropout,
                char_lstm_dropout=args.char_lstm_dropout,
                additional_token_feature_indices=(
                    args.additional_token_feature_indices),
                text_feature_indices=args.text_feature_indices,
                unroll_text_feature_index=args.unroll_text_feature_index,
                concatenated_embeddings_token_count=(
                    args.concatenated_embeddings_token_count),
                use_word_embeddings=args.use_word_embeddings,
                use_features_indices_input=args.use_features_indices_input,
                continuous_features_indices=args.continuous_features_indices,
                features_lstm_units=args.features_lstm_units,
                stateful=args.stateful),
            training_props=dict(
                initial_epoch=args.initial_epoch,
                input_window_stride=args.input_window_stride,
                checkpoint_epoch_interval=args.checkpoint_epoch_interval),
            resume_train_model_path=args.resume_train_model_path,
            auto_resume=args.auto_resume,
            transfer_learning_config=(
                get_transfer_learning_config_for_parsed_args(args)),
            train_notification_manager=get_train_notification_manager(args),
            **self.get_common_args(args))

    def run(self, args: argparse.Namespace):
        if args.save_input_to_and_exit:
            save_input_to(args.input, args.save_input_to_and_exit)
            return

        self.download_manager = DownloadManager()
        self.embedding_manager = EmbeddingManager(
            download_manager=self.download_manager)
        if args.no_use_lmdb:
            self.embedding_manager.disable_embedding_lmdb_cache()

        set_random_seeds(args.random_seed)

        self.do_run(args)

        # see https://github.com/tensorflow/tensorflow/issues/3388
        K.clear_session()
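For context, `run()` expects an `argparse.Namespace`. A minimal sketch of the CLI wiring, showing only the flags read directly in `run()` (flag names are inferred from the attribute names above; the concrete subclass and its no-argument constructor are assumptions):

import argparse

def main(argv=None):
    parser = argparse.ArgumentParser()
    # get_common_args() and get_train_args() read many more attributes
    # from the namespace; only the flags used by run() are shown here.
    parser.add_argument('--no-use-lmdb', action='store_true')
    parser.add_argument('--random-seed', type=int, default=42)  # illustrative default
    parser.add_argument('--save-input-to-and-exit')
    args = parser.parse_args(argv)
    MyTrainSubCommand().run(args)  # hypothetical concrete subclass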