Ejemplo n.º 1
0
 def do_run(self, args: argparse.Namespace):
     """Validate the parsed arguments, prepare embeddings and call ``train``.

     Args:
         args: parsed command line arguments. This method reads ``model``,
             ``preload_embedding``, ``embedding``, ``use_word_embeddings``
             and ``resume_train_model_path``; the namespace is also handed
             to ``self.get_train_args``.

     Raises:
         ValueError: if ``args.model`` is falsy.
     """
     if not args.model:
         raise ValueError("model required")

     # Optional extra embedding, always validated as a word embedding.
     preload_name = args.preload_embedding
     if preload_name:
         self.preload_and_validate_embedding(
             preload_name, use_word_embeddings=True)

     # The main embedding is only treated as a word embedding when
     # no resume model path was given.
     requested_embedding = args.embedding
     with_word_embeddings = (
         args.use_word_embeddings and not args.resume_train_model_path
     )
     embedding_name = self.preload_and_validate_embedding(
         requested_embedding, use_word_embeddings=with_word_embeddings)

     LOGGER.info('get_tf_info: %s', get_tf_info())
     train_kwargs = self.get_train_args(args)
     train(embeddings_name=embedding_name, **train_kwargs)
Ejemplo n.º 2
0
 def test_should_be_able_to_train_without_features(
         self, default_args: dict, default_model_directory: str):
     """Train with ``use_features=False``, check the config, then tag.

     The model config is loaded from ``default_model_directory`` and must
     report ``use_features`` as falsy; afterwards ``tag_input`` is run
     against the same directory.
     """
     train(use_features=False, **default_args)

     saved_config = load_model_config(default_model_directory)
     assert not saved_config.use_features

     # Tagging takes its model name, inputs and resources from the same
     # default arguments that were passed to train.
     tag_kwargs = {
         'model_name': default_args['model_name'],
         'model_path': default_model_directory,
         'input_paths': default_args['input_paths'],
         'download_manager': default_args['download_manager'],
         'embedding_registry_path': default_args['embedding_registry_path'],
     }
     tag_input(**tag_kwargs)
Ejemplo n.º 3
0
 def test_should_be_able_to_train_without_word_embeddings(
         self, default_args: dict, default_model_directory: str):
     """Train with no embeddings name and word embeddings switched off.

     The config loaded from ``default_model_directory`` must have
     ``embeddings_name`` set to ``None``; ``tag_input`` is then run
     against the same directory.
     """
     # Keep any default config props, only overriding the word embedding flag.
     config_props = {
         **default_args.get('config_props', {}),
         'use_word_embeddings': False
     }
     train_args = cast(TrainArgsDict, {
         **default_args,
         'embeddings_name': None,
         'config_props': config_props
     })
     train(**train_args)

     saved_config = load_model_config(default_model_directory)
     assert saved_config.embeddings_name is None

     tag_kwargs = {
         'model_name': default_args['model_name'],
         'model_path': default_model_directory,
         'input_paths': default_args['input_paths'],
         'download_manager': default_args['download_manager'],
         'embedding_registry_path': default_args['embedding_registry_path'],
     }
     tag_input(**tag_kwargs)
Ejemplo n.º 4
0
 def test_should_be_able_to_train_with_features_and_features_embeddings(
         self, default_args: dict, default_model_directory: str):
     """Train with features plus a features embedding, check config, tag.

     The config loaded from ``default_model_directory`` must echo back
     ``FEATURE_INDICES_1`` and ``FEATURES_EMBEDDING_SIZE_1``; ``tag_input``
     is then run against the same directory.
     """
     train_args = cast(TrainArgsDict, {
         **default_args,
         'features_indices': FEATURE_INDICES_1,
         'features_embedding_size': FEATURES_EMBEDDING_SIZE_1
     })
     train(use_features=True, **train_args)

     saved_config = load_model_config(default_model_directory)
     assert saved_config.use_features
     assert saved_config.features_indices == FEATURE_INDICES_1
     assert saved_config.features_embedding_size == FEATURES_EMBEDDING_SIZE_1

     tag_kwargs = {
         'model_name': default_args['model_name'],
         'model_path': default_model_directory,
         'input_paths': default_args['input_paths'],
         'download_manager': default_args['download_manager'],
         'embedding_registry_path': default_args['embedding_registry_path'],
     }
     tag_input(**tag_kwargs)
Ejemplo n.º 5
0
 def test_should_be_able_to_copy_weights_from_previous_model(
     self,
     tmp_path: Path,
     default_args: dict,
     copy_preprocessor: bool
 ):
     """Train a source model, then train again with a transfer config.

     The first ``train`` call writes to ``tmp_path / 'source_model'``. The
     second call passes a ``TransferLearningConfig`` whose
     ``source_model_path`` is that output path joined with the model name.
     The test makes no explicit assertions; it passes when both calls
     complete without raising.

     Args:
         tmp_path: directory under which the source model is written.
         default_args: base keyword arguments for ``train``.
         copy_preprocessor: forwarded to ``TransferLearningConfig``
             (presumably supplied by test parametrization outside this
             view; confirm against the decorator).
     """
     source_model_output_path = tmp_path / 'source_model'
     # Both runs disable word embeddings in the same way.
     config_props_without_word_embeddings = {
         **default_args.get('config_props', {}),
         'use_word_embeddings': False
     }

     source_train_args = cast(TrainArgsDict, {
         **default_args,
         'output_path': str(source_model_output_path),
         'embeddings_name': None,
         'config_props': config_props_without_word_embeddings
     })
     train(**source_train_args)

     transfer_learning_config = TransferLearningConfig(
         source_model_path=(
             str(source_model_output_path / default_args['model_name'])
         ),
         copy_layers={
             'char_embeddings': 'char_embeddings',
             'char_lstm': 'char_lstm'
         },
         copy_preprocessor=copy_preprocessor,
         copy_preprocessor_fields=['vocab_char'],
         freeze_layers=['char_embeddings']
     )
     # Cast to TrainArgsDict (not DefaultArgsDict): the dict is unpacked into
     # train(), matching the first call here. cast() has no runtime effect.
     transfer_train_args = cast(TrainArgsDict, {
         **default_args,
         'transfer_learning_config': transfer_learning_config,
         'embeddings_name': None,
         'config_props': config_props_without_word_embeddings
     })
     train(**transfer_train_args)
Ejemplo n.º 6
0
 def test_should_be_able_to_train_with_additional_token_feature_indices(
         self, default_args: dict, default_model_directory: str):
     """Train with extra token feature indices, check the config, then tag.

     ``max_char_length`` and ``additional_token_feature_indices`` are passed
     through ``config_props`` and must be present on the config loaded from
     ``default_model_directory``.
     """
     # Keep any default config props and add the two settings under test.
     config_props = {
         **default_args.get('config_props', {}),
         'max_char_length': 60,
         'additional_token_feature_indices': [0]
     }
     train_args = cast(TrainArgsDict, {
         **default_args,
         'config_props': config_props
     })
     train(**train_args)

     saved_config = load_model_config(default_model_directory)
     assert saved_config.max_char_length == 60
     assert saved_config.additional_token_feature_indices == [0]

     tag_kwargs = {
         'model_name': default_args['model_name'],
         'model_path': default_model_directory,
         'input_paths': default_args['input_paths'],
         'download_manager': default_args['download_manager'],
         'embedding_registry_path': default_args['embedding_registry_path'],
     }
     tag_input(**tag_kwargs)
Ejemplo n.º 7
0
 def test_should_be_able_to_train_CustomBidLSTM_CRF_FEATURES(
         self, default_args: DefaultArgsDict, default_model_directory: str):
     """Train the ``CustomBidLSTM_CRF_FEATURES`` architecture, then tag.

     The config loaded from ``default_model_directory`` must report the
     requested model type, features embedding size and features LSTM units.
     """
     # NOTE: config_props is replaced outright here; any 'config_props' in
     # default_args is not merged in.
     feature_config_props = {
         'features_lstm_units': 4
     }
     train_args = cast(TrainArgsDict, {
         **default_args,
         'architecture': 'CustomBidLSTM_CRF_FEATURES',
         'features_embedding_size': 4,
         'config_props': feature_config_props
     })
     train(**train_args)

     saved_config = load_model_config(default_model_directory)
     assert saved_config.model_type == 'CustomBidLSTM_CRF_FEATURES'
     assert saved_config.features_embedding_size == 4
     assert saved_config.features_lstm_units == 4

     tag_kwargs = {
         'model_name': default_args['model_name'],
         'model_path': default_model_directory,
         'input_paths': default_args['input_paths'],
         'download_manager': default_args['download_manager'],
         'embedding_registry_path': default_args['embedding_registry_path'],
     }
     tag_input(**tag_kwargs)