Example #1
    def test_checkpoint(self, tmpdir, capsys):
        tmpdir_ = pathlib.Path(str(tmpdir))
        window_size = 1
        texts = ["abcd", "yxz"]
        name = "test_model"

        run_train(
            texts,
            name,
            max_epochs=1,
            output_path=tmpdir_,
            window_size=window_size,
        )

        chp_message = "Loading a checkpointed network"
        captured = capsys.readouterr()

        assert chp_message not in captured.out

        checkpoints_dir = tmpdir_ / "checkpoints" / name

        run_train(
            texts,
            name + "_cont",
            max_epochs=0,
            output_path=tmpdir_,
            checkpoint_path=checkpoints_dir / "last.ckpt",
            window_size=window_size,
        )

        captured = capsys.readouterr()
        assert chp_message in captured.out
Example #2
    def test_error(self, tmpdir):
        tmpdir_ = pathlib.Path(str(tmpdir))
        model_path = tmpdir_ / "languages" / "a"

        model_path.parent.mkdir(parents=True)
        model_path.touch()

        with pytest.raises(FileExistsError):
            run_train(["some text"], "a", output_path=tmpdir_)
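
The guard exercised above can also be reproduced outside pytest. The sketch below mirrors the test: it pre-creates the file that training would write and expects `run_train` to raise `FileExistsError`. The temporary directory and the model name "a" are illustrative, not part of any documented API.

# Hedged standalone sketch of the overwrite guard shown in Example #2.
import pathlib
import tempfile

from mltype.ml import run_train

tmpdir_ = pathlib.Path(tempfile.mkdtemp())
model_path = tmpdir_ / "languages" / "a"
model_path.parent.mkdir(parents=True)
model_path.touch()  # pre-create the file the training run would write

try:
    run_train(["some text"], "a", output_path=tmpdir_)
except FileExistsError:
    print("refused to overwrite the existing model file")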
Example #3
    def test_zero_epochs(self, tmpdir, capsys):
        tmpdir_ = pathlib.Path(str(tmpdir))
        window_size = 1
        texts = ["abcd", "yxz"]
        name = "test_model"

        run_train(
            texts,
            name,
            max_epochs=0,
            output_path=tmpdir_,
            window_size=window_size,
        )

        captured = capsys.readouterr()
        assert "No checkpoint found" in captured.out
        checkpoints_dir = tmpdir_ / "checkpoints" / name
        assert not checkpoints_dir.exists()
Example #4
    def test_overall(
        self,
        monkeypatch,
        capsys,
        tmpdir,
        illegal_chars,
        use_mlflow,
        early_stopping,
    ):
        tmpdir_ = pathlib.Path(str(tmpdir))
        window_size = 1
        texts = ["abcd", "yxz"]
        name = "test_model"

        run_train(
            texts,
            name,
            early_stopping=early_stopping,
            illegal_chars=illegal_chars,
            max_epochs=2,
            output_path=tmpdir_,
            use_mlflow=use_mlflow,
            window_size=window_size,
        )

        captured = capsys.readouterr()
        assert "Using the checkpoint " in captured.out

        checkpoints_dir = tmpdir_ / "checkpoints" / name
        assert checkpoints_dir.exists()
        checkpoints = {x.name for x in checkpoints_dir.iterdir()}
        assert len(checkpoints) == 2  # best and last
        assert "last.ckpt" in checkpoints

        assert (tmpdir_ / "languages" / name).exists()
        assert (not use_mlflow) ^ (tmpdir_ / "logs" / "mlruns").exists()
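
Read together, the assertions in Examples #1 to #4 suggest the following layout under `output_path` (a hedged reading of the tests, not documented behaviour; the paths below are illustrative):

import pathlib

output_path = pathlib.Path("/tmp/mltype_out")  # illustrative location
name = "test_model"

checkpoints_dir = output_path / "checkpoints" / name  # holds "last.ckpt" plus a best checkpoint
model_file = output_path / "languages" / name         # final saved model
mlflow_dir = output_path / "logs" / "mlruns"          # created only when use_mlflow=True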
Example #5
import pathlib
import pprint


def train(
    path,
    model_name,
    checkpoint_path,
    extensions,
    fill_strategy,
    illegal_chars,
    gpus,
    batch_size,
    dense_size,
    hidden_size,
    max_epochs,
    early_stopping,
    n_layers,
    output_path,
    train_test_split,
    use_mlflow,
    vocab_size,
    window_size,
):  # noqa: D400
    """Train a language"""
    params = locals()
    from mltype.data import file2text, folder2text
    from mltype.ml import run_train
    from mltype.utils import print_section

    with print_section(" Parameters ", drop_end=True):
        pprint.pprint(params)

    all_texts = []
    with print_section(" Reading file(s) ", drop_end=True):
        for p in path:
            path_p = pathlib.Path(str(p))

            if not path_p.exists():
                raise ValueError(
                    "The provided path does not exist")  # pragma: no cover

            if path_p.is_file():
                texts = [file2text(path_p)]
            elif path_p.is_dir():
                valid_extensions = (
                    extensions.split(",") if extensions is not None else None
                )
                texts = folder2text(path_p, valid_extensions=valid_extensions)
            else:
                ValueError("Unrecognized object")  # pragma: no cover

            all_texts.extend(texts)

    if not all_texts:
        raise ValueError("Did not manage to read any text")  # pragma: no cover

    run_train(
        all_texts,
        model_name,
        max_epochs=max_epochs,
        window_size=window_size,
        batch_size=batch_size,
        vocab_size=vocab_size,
        fill_strategy=fill_strategy,
        illegal_chars=illegal_chars,
        train_test_split=train_test_split,
        hidden_size=hidden_size,
        dense_size=dense_size,
        n_layers=n_layers,
        use_mlflow=use_mlflow,
        early_stopping=early_stopping,
        gpus=gpus,
        checkpoint_path=checkpoint_path,
        output_path=output_path,
    )
    print("Done")