def test_checkpoint(self, tmpdir, capsys):
    """Resuming from a checkpoint prints the load message only on resume."""
    out_dir = pathlib.Path(str(tmpdir))
    corpus = ["abcd", "yxz"]
    model_name = "test_model"

    # Fresh training run: nothing to resume from yet.
    run_train(
        corpus,
        model_name,
        max_epochs=1,
        output_path=out_dir,
        window_size=1,
    )

    resume_msg = "Loading a checkpointed network"
    assert resume_msg not in capsys.readouterr().out

    # Second run restarts from the checkpoint written by the first run.
    ckpt_dir = out_dir / "checkpoints" / model_name
    run_train(
        corpus,
        model_name + "_cont",
        max_epochs=0,
        output_path=out_dir,
        checkpoint_path=ckpt_dir / "last.ckpt",
        window_size=1,
    )

    assert resume_msg in capsys.readouterr().out
def test_error(self, tmpdir):
    """Training refuses to overwrite an existing model file."""
    out_dir = pathlib.Path(str(tmpdir))
    existing_model = out_dir / "languages" / "a"
    existing_model.parent.mkdir(parents=True)
    existing_model.touch()

    with pytest.raises(FileExistsError):
        run_train(["some text"], "a", output_path=out_dir)
def test_zero_epochs(self, tmpdir, capsys):
    """With max_epochs=0 no checkpoint is produced and a warning is printed."""
    out_dir = pathlib.Path(str(tmpdir))

    run_train(
        ["abcd", "yxz"],
        "test_model",
        max_epochs=0,
        output_path=out_dir,
        window_size=1,
    )

    assert "No checkpoint found" in capsys.readouterr().out
    assert not (out_dir / "checkpoints" / "test_model").exists()
def test_overall(
    self,
    monkeypatch,
    capsys,
    tmpdir,
    illegal_chars,
    use_mlflow,
    early_stopping,
):
    """End-to-end training produces checkpoints, a model file and (optionally) mlflow logs."""
    out_dir = pathlib.Path(str(tmpdir))
    model_name = "test_model"

    run_train(
        ["abcd", "yxz"],
        model_name,
        early_stopping=early_stopping,
        illegal_chars=illegal_chars,
        max_epochs=2,
        output_path=out_dir,
        use_mlflow=use_mlflow,
        window_size=1,
    )

    assert "Using the checkpoint " in capsys.readouterr().out

    ckpt_dir = out_dir / "checkpoints" / model_name
    assert ckpt_dir.exists()

    ckpt_names = {entry.name for entry in ckpt_dir.iterdir()}
    assert len(ckpt_names) == 2  # best and last
    assert "last.ckpt" in ckpt_names

    assert (out_dir / "languages" / model_name).exists()
    # mlruns directory exists exactly when mlflow logging was requested.
    assert (not use_mlflow) ^ (out_dir / "logs" / "mlruns").exists()
def train(
    path,
    model_name,
    checkpoint_path,
    extensions,
    fill_strategy,
    illegal_chars,
    gpus,
    batch_size,
    dense_size,
    hidden_size,
    max_epochs,
    early_stopping,
    n_layers,
    output_path,
    train_test_split,
    use_mlflow,
    vocab_size,
    window_size,
):  # noqa: D400
    """Train a language"""
    # Capture all CLI arguments before any other locals are created.
    params = locals()

    from mltype.data import file2text, folder2text
    from mltype.ml import run_train
    from mltype.utils import print_section

    with print_section(" Parameters ", drop_end=True):
        pprint.pprint(params)

    # Collect the raw text from every provided path (files and/or folders).
    all_texts = []
    with print_section(" Reading file(s) ", drop_end=True):
        for p in path:
            path_p = pathlib.Path(str(p))

            if not path_p.exists():
                raise ValueError(
                    "The provided path does not exist"
                )  # pragma: no cover
            if path_p.is_file():
                texts = [file2text(path_p)]
            elif path_p.is_dir():
                valid_extensions = (
                    extensions.split(",") if extensions is not None else None
                )
                texts = folder2text(path_p, valid_extensions=valid_extensions)
            else:
                # BUG FIX: the original only *constructed* this exception
                # without raising it, so an unrecognized path silently
                # reused `texts` from the previous iteration (or crashed
                # with NameError on the first one).
                raise ValueError("Unrecognized object")  # pragma: no cover

            all_texts.extend(texts)

    if not all_texts:
        raise ValueError(
            "Did not manage to read any text"
        )  # pragma: no cover

    run_train(
        all_texts,
        model_name,
        max_epochs=max_epochs,
        window_size=window_size,
        batch_size=batch_size,
        vocab_size=vocab_size,
        fill_strategy=fill_strategy,
        illegal_chars=illegal_chars,
        train_test_split=train_test_split,
        hidden_size=hidden_size,
        dense_size=dense_size,
        n_layers=n_layers,
        use_mlflow=use_mlflow,
        early_stopping=early_stopping,
        gpus=gpus,
        checkpoint_path=checkpoint_path,
        output_path=output_path,
    )
    print("Done")