def train_func(config, train, dev=None, test=None, load=False, verbose=True,
               first=1, each=1, eval_file=None, no_shell=False, **kwargs):
    """ Main CLI Interface (training)

    :param config: Path to the configuration file
    :type config: str
    :param train: Path to directory containing train files
    :type train: str
    :param dev: Path to directory containing dev files
    :type dev: str
    :param test: Path to directory containing test files
    :type test: str
    :param embed: Path to directory containing files for embeddings (forwarded through kwargs)
    :type embed: str
    :param load: Whether to load an existing model to train on top of it (default: False)
    :type load: bool
    :param nb_epochs: Number of epochs (forwarded through kwargs)
    :type nb_epochs: int
    :param verbose: If False, overrides the next few options and reports only the first and last epochs
    :param first: Evaluate the first N epochs
    :param each: Evaluate every Nth epoch
    :param eval_file: Store evaluation results in a file
    :param no_shell: Do not print to the shell
    :param kwargs: Other arguments
    :type kwargs: dict
    :return:
    """
    tagger = Tagger.setup_from_disk(config, train, dev, test, verbose=True, load=load, **kwargs)
    nb_epochs = tagger.nb_epochs

    # Set up the logger
    logger_params = dict(
        shell=not no_shell,
        file=eval_file,
        first=first,
        nb_epochs=nb_epochs,
        each=each
    )
    if verbose is False:
        # Set `each` past the total number of epochs so that no intermediate epoch is reported
        logger_params = dict(
            shell=True,
            file=logger_params["file"],
            first=1,
            nb_epochs=nb_epochs,
            each=nb_epochs + 1
        )
    tagger.logger = Logger(**logger_params)

    for _ in range(nb_epochs):
        tagger.epoch(autosave=True, eval_test=tagger.include_test)

    tagger.save()
    print('::: ended :::')
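# A minimal invocation sketch for train_func above. The paths, config name and
# option values are illustrative placeholders, not files shipped with the
# repository; the sketch only shows how the logging options combine (evaluate
# the first 10 epochs, then every 5th, mirroring scores to a file).
if __name__ == "__main__":
    train_func(
        config="./config.txt",         # hypothetical configuration file
        train="./data/train",          # hypothetical train directory
        dev="./data/dev",              # hypothetical dev directory
        test="./data/test",            # hypothetical test directory
        first=10,                      # evaluate each of the first 10 epochs
        each=5,                        # then evaluate every 5th epoch
        eval_file="./eval_scores.txt"  # also write evaluations to this file
    )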
def test_load_after_save(self):
    """ Ensure params are correctly saved and reloaded """
    tagger = Tagger.setup_from_disk(
        config_path="./tests/test_configs/config_chrestien.txt",
        train_data=TRAIN,
        dev_data=DEV,
        test_data=TEST
    )
    tagger.include_pos = False
    tagger.curr_nb_epochs = 10
    tagger.save_params()
    self.assertEqual(tagger.pretrainer.nb_workers, 1, "Pretrainer workers should be correctly loaded")

    del tagger

    tagger = Tagger(config_path="./fake_model/config.txt")
    self.assertEqual(tagger.nb_encoding_layers, 2, "nb_encoding_layers should be correctly loaded")
    self.assertEqual(tagger.nb_epochs, 3, "nb_epochs should be correctly loaded")
    self.assertEqual(tagger.nb_dense_dims, 1000, "nb_dense_dims should be correctly loaded")
    self.assertEqual(tagger.batch_size, 100, "batch_size should be correctly loaded")
    self.assertEqual(tagger.nb_left_tokens, 2, "nb_left_tokens should be correctly loaded")
    self.assertEqual(tagger.nb_right_tokens, 1, "nb_right_tokens should be correctly loaded")
    self.assertEqual(tagger.nb_context_tokens, 3, "nb_context_tokens should be correctly computed")
    self.assertEqual(tagger.nb_embedding_dims, 100, "nb_embedding_dims should be correctly loaded")
    self.assertEqual(tagger.model_dir, "fake_model", "model_dir should be correctly loaded")
    self.assertEqual(tagger.postcorrect, False, "postcorrect should be correctly loaded")
    self.assertEqual(tagger.nb_filters, 100, "nb_filters should be correctly loaded")
    self.assertEqual(tagger.filter_length, 3, "filter_length should be correctly loaded")
    self.assertEqual(tagger.focus_repr, "convolutions", "focus_repr should be correctly loaded")
    self.assertEqual(tagger.dropout_level, 0.15, "dropout_level should be correctly loaded")
    self.assertEqual(tagger.include_token, True, "include_token should be correctly loaded")
    self.assertEqual(tagger.include_context, True, "include_context should be correctly loaded")
    self.assertEqual(tagger.include_lemma, "label", "include_lemma should be correctly loaded")
    self.assertEqual(tagger.include_pos, False, "include_pos should be correctly loaded")
    self.assertEqual(tagger.include_morph, False, "include_morph should be correctly loaded")
    self.assertEqual(tagger.include_dev, True, "include_dev should be correctly loaded")
    self.assertEqual(tagger.include_test, True, "include_test should be correctly loaded")
    self.assertEqual(tagger.min_token_freq_emb, 5, "min_token_freq_emb should be correctly loaded")
    self.assertEqual(tagger.halve_lr_at, 75, "halve_lr_at should be correctly loaded")
    self.assertEqual(tagger.max_token_len, 20, "max_token_len should be correctly loaded")
    self.assertEqual(tagger.min_lem_cnt, 1, "min_lem_cnt should be correctly loaded")
    self.assertEqual(tagger.curr_nb_epochs, 10, "Current number of epochs should be correctly loaded")
    self.assertEqual(tagger.model, "PyTorch", "Model implementation name should be correctly loaded")

    tagger = Tagger(config_path="./fake_model/config.txt", load=True)
    self.assertIsInstance(tagger.model, MODELS["PyTorch"], "PyTorch implementation should be loaded")
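# Assumed module-level scaffolding for the test above, as a sketch only: the
# package name and fixture locations are guesses inferred from the names used
# in the test body (Tagger, MODELS, TRAIN/DEV/TEST) and are not confirmed
# against the repository layout. The method itself belongs to a
# unittest.TestCase subclass.
import unittest

from mypackage.tagger import Tagger  # placeholder package name
from mypackage.impl import MODELS    # placeholder package name

# Assumed fixture paths; the real suite defines these next to its test data.
TRAIN = "./tests/test_data/train"
DEV = "./tests/test_data/dev"
TEST = "./tests/test_data/test"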