Example no. 1
def test_print_section(capsys, drop_end, add_ts):
    # a multi-character fill_char is invalid and should raise
    with pytest.raises(ValueError):
        with print_section("name",
                           fill_char="*-",
                           drop_end=drop_end,
                           add_ts=add_ts):
            pass

    # discard the output captured so far
    _ = capsys.readouterr()

    fill_char = "@"
    with print_section("name",
                       fill_char=fill_char,
                       drop_end=drop_end,
                       add_ts=add_ts):
        pass

    captured = capsys.readouterr()
    assert f"{fill_char}name" in captured.out
Example no. 2
def main_replay(replay_file, force_perfect, overwrite, instant_death,
                target_wpm):
    """Run main curses loop with a replay.

    Parameters
    ----------
    force_perfect : bool
        If True, then one cannot finish typing before all characters
        are typed without any mistakes.


    overwrite : bool
        If True, the replay file will be overwritten in case
        we are faster than it.

    replay_file : str or pathlib.Path
        Typed text in this file from some previous game.

    instant_death : bool
        If active, the first mistake will end the game.

    target_wpm : None or int
        The desired speed to be shown as guide.
    """
    replay_file = pathlib.Path(replay_file)
    replay_tt = TypedText.load(replay_file)

    tt = curses.wrapper(
        run_loop,
        replay_tt.text,
        force_perfect=force_perfect,
        replay_tt=replay_tt,
        instant_death=instant_death,
        target_wpm=target_wpm,
    )

    wpm_replay = replay_tt.compute_wpm()
    wpm_new = tt.compute_wpm()

    with print_section(" Statistics ", fill_char="=", add_ts=False):
        print(f"Old WPM: {wpm_replay:.1f}\nNew WPM: {wpm_new:.1f}")

        if wpm_new > wpm_replay:
            print("Congratulations!")
            if overwrite:
                print("Updating the checkpoint file")
                tt.save(replay_file)
        else:
            print("You lost!")
Example no. 3
def main_basic(text, force_perfect, output_file, instant_death, target_wpm):
    """Run main curses loop with no previous replay.

    Parameters
    ----------
    force_perfect : bool
        If True, then one cannot finish typing before all characters are
        typed without any mistakes.

    output_file : str or pathlib.Path or None
        If ``pathlib.Path`` then we store the typed text in this file.
        If None, no saving is taking place.

    instant_death : bool
        If active, the first mistake will end the game.

    target_wpm : int or None
        The desired speed to be displayed as a guide.
    """
    text_stripped = text.rstrip()

    tt = curses.wrapper(
        run_loop,
        text_stripped,
        force_perfect=force_perfect,
        replay_tt=None,
        instant_death=instant_death,
        target_wpm=target_wpm,
    )

    if output_file is not None:
        tt.save(pathlib.Path(output_file))

    with print_section(" Statistics ", fill_char="=", add_ts=False):
        print(f"Accuracy: {tt.compute_accuracy():.1%}\n"
              f"WPM: {tt.compute_wpm():.1f}")
Example no. 4
def train(
    path,
    model_name,
    checkpoint_path,
    extensions,
    fill_strategy,
    illegal_chars,
    gpus,
    batch_size,
    dense_size,
    hidden_size,
    max_epochs,
    early_stopping,
    n_layers,
    output_path,
    train_test_split,
    use_mlflow,
    vocab_size,
    window_size,
):  # noqa: D400
    """Train a language"""
    params = locals()
    from mltype.data import file2text, folder2text
    from mltype.ml import run_train
    from mltype.utils import print_section

    with print_section(" Parameters ", drop_end=True):
        pprint.pprint(params)

    all_texts = []
    with print_section(" Reading file(s) ", drop_end=True):
        for p in path:
            path_p = pathlib.Path(str(p))

            if not path_p.exists():
                raise ValueError(
                    "The provided path does not exist")  # pragma: no cover

            if path_p.is_file():
                texts = [file2text(path_p)]
            elif path_p.is_dir():
                valid_extensions = (extensions.split(",")
                                    if extensions is not None else None)
                texts = folder2text(path_p, valid_extensions=valid_extensions)
            else:
                raise ValueError("Unrecognized object")  # pragma: no cover

            all_texts.extend(texts)

    if not all_texts:
        raise ValueError("Did not manage to read any text")  # pragma: no cover

    run_train(
        all_texts,
        model_name,
        max_epochs=max_epochs,
        window_size=window_size,
        batch_size=batch_size,
        vocab_size=vocab_size,
        fill_strategy=fill_strategy,
        illegal_chars=illegal_chars,
        train_test_split=train_test_split,
        hidden_size=hidden_size,
        dense_size=dense_size,
        n_layers=n_layers,
        use_mlflow=use_mlflow,
        early_stopping=early_stopping,
        gpus=gpus,
        checkpoint_path=checkpoint_path,
        output_path=output_path,
    )
    print("Done")
Example no. 5
def run_train(
    texts,
    name,
    max_epochs=10,
    window_size=50,
    batch_size=32,
    vocab_size=None,
    fill_strategy="skip",
    illegal_chars="",
    train_test_split=0.5,
    hidden_size=32,
    dense_size=32,
    n_layers=1,
    checkpoint_path=None,
    output_path=None,
    use_mlflow=True,
    early_stopping=True,
    gpus=None,
):
    """Run the training loop.

    Note that the parameters are also explained in the CLI of `mlt train`.

    Parameters
    ----------
    texts : list
        List of str representing all texts we would like to train on.

    name : str
        Name of the model. This name is only used when we save the model -
        it is not hardcoded anywhere in the serialization.

    max_epochs : int
        Maximum number of epochs. Note that the number of actual epochs
        can be lower if we activate the `early_stopping` flag.

    window_size : int
        Number of previous characters to consider when predicting the next
        character. The higher the number, the longer the memory we are
        enforcing. However, at the same time, the training becomes slower.

    batch_size : int
        Number of samples in one batch.

    vocab_size : int or None
        Maximum number of characters to be put in the vocabulary. Note that
        one can explicitly exclude characters via `illegal_chars`. The higher
        this number, the bigger the feature vectors and the slower the
        training.

    fill_strategy : str, {"zeros", "skip"}
        Determines how to deal with out of vocabulary characters. When
        "zeros" then we simply encode them as zero vectors. If "skip", we
        skip a given sample if any of the characters in the window or the
        predicted character are not in the vocabulary.

    illegal_chars : str or None
        If specified, then each character of the str represents a forbidden
        character that we do not put in the vocabulary.

    train_test_split : float
        Float in the range (0, 1) representing the fraction of the entire
        dataset that is used for the training set.

    hidden_size : int
        Hidden size of LSTM cells (equal in all layers).

    dense_size : int
        Size of the dense layer that bridges the hidden state output by
        the LSTM and the final output probabilities over the vocabulary.

    n_layers : int
        Number of layers inside of the LSTM.

    checkpoint_path : None or pathlib.Path or str
        If specified, it is pointing to a checkpoint file (generated
        by Pytorch-lightning). This file does not contain the vocabulary.
        It can be used to continue the training.

    output_path : None or pathlib.Path or str
        If specified, it is an alternative output folder where the trained
        models and logging information will be stored. If not specified,
        the output folder defaults to `~/.mltype`.

    use_mlflow : bool
        If active, then we use MLflow for logging the training and validation
        loss. Additionally, at the end of each epoch we generate a few
        sample texts to demonstrate how good/bad the current network is.

    early_stopping : bool
        If True, then we monitor the validation loss and if it does not
        improve for a certain number of epochs, we stop the training.

    gpus : int or None
        If None or 0, no GPUs are used (only CPUs). Otherwise, it represents
        the number of GPUs to be used (using the data parallelization
        strategy).
    """
    illegal_chars = illegal_chars or ""

    cache_dir = get_cache_dir(output_path)
    languages_path = cache_dir / "languages" / name
    checkpoints_path = cache_dir / "checkpoints" / name

    if languages_path.exists():
        raise FileExistsError(f"The model {name} already exists")

    with print_section(" Computing vocabulary ", drop_end=True):
        vocabulary = sorted([
            x[0] for x in Counter("".join(texts)).most_common()
            if x[0] not in illegal_chars
        ][:vocab_size])  # works for None
        vocab_size = len(vocabulary)
        print(f"# characters: {vocab_size}")
        print(vocabulary)

    with print_section(" Creating training set ", drop_end=True):
        X_list = []
        y_list = []
        for text in tqdm.tqdm(texts):
            X_, y_, _ = create_data_language(
                text,
                vocabulary,
                window_size=window_size,
                verbose=False,
                fill_strategy=fill_strategy,
            )
            X_list.append(X_)
            y_list.append(y_)
        X = np.concatenate(X_list, axis=0) if len(X_list) != 1 else X_list[0]
        y = np.concatenate(y_list, axis=0) if len(y_list) != 1 else y_list[0]

        print(f"X.dtype={X.dtype}, y.dtype={y.dtype}")

        split_ix = int(len(X) * train_test_split)
        indices = np.random.permutation(len(X))
        train_indices = indices[:split_ix]
        val_indices = indices[split_ix:]
        print(f"Train: {len(train_indices)}\nValidation: {len(val_indices)}")

    dataset = LanguageDataset(X, y, vocabulary=vocabulary)

    dataloader_t = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(train_indices),
    )

    dataloader_v = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(val_indices),
    )

    if checkpoint_path is None:
        network = SingleCharacterLSTM(
            vocab_size,
            hidden_size=hidden_size,
            dense_size=dense_size,
            n_layers=n_layers,
        )
    else:
        print(f"Loading a checkpointed network: {checkpoint_path}")
        network = SingleCharacterLSTM.load_from_checkpoint(
            str(checkpoint_path))

    chp_name_template = str(checkpoints_path / "{epoch}-{val_loss:.3f}")
    chp_callback = pl.callbacks.ModelCheckpoint(
        filepath=chp_name_template,
        save_last=True,  # last epoch always there
        save_top_k=1,
        verbose=True,
        monitor="val_loss",
        mode="min",
        save_weights_only=False,
    )
    callbacks = []

    if use_mlflow:
        print("Logging with MLflow")
        logger = pl.loggers.MLFlowLogger("mltype",
                                         save_dir=get_cache_dir(output_path) /
                                         "logs" / "mlruns")
        print(f"Run ID: {logger.run_id}")

        logger.log_hyperparams({
            "fill_strategy": fill_strategy,
            "model_name": name,
            "train_test_split": train_test_split,
            "vocab_size": vocab_size,
            "window_size": window_size,
        })
    else:
        logger = None

    if early_stopping:
        print("Activating early stopping")
        callbacks.append(
            pl.callbacks.EarlyStopping(monitor="val_loss", verbose=True))

    with print_section(" Training ", drop_end=True):
        trainer = pl.Trainer(
            gpus=gpus,
            max_epochs=max_epochs,
            logger=logger,
            callbacks=callbacks,
            checkpoint_callback=chp_callback,
        )
        trainer.fit(network, dataloader_t, dataloader_v)

    with print_section(" Saving the model ", drop_end=False):
        if chp_callback.best_model_path:
            print(f"Using the checkpoint {chp_callback.best_model_path}")
            network = SingleCharacterLSTM.load_from_checkpoint(
                chp_callback.best_model_path)
        else:
            print("No checkpoint found, using the current network")

        print(f"The final model is saved to: {languages_path}")
        save_model(network, vocabulary, languages_path)
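A hypothetical minimal invocation, training a tiny model on an in-memory corpus with most of the defaults documented above; the model name and corpus are illustrative only:

run_train(
    ["hello world, hello typing, hello again"],  # toy corpus, illustrative
    "tiny_demo",                                 # illustrative model name
    max_epochs=1,
    window_size=10,
    batch_size=8,
    use_mlflow=False,       # skip MLflow logging for the toy run
    early_stopping=False,
    gpus=None,              # CPU only
)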