def test_print_section(capsys, drop_end, add_ts):
    # a multi-character fill_char is not allowed and raises a ValueError
    with pytest.raises(ValueError):
        with print_section("name", fill_char="*-", drop_end=drop_end, add_ts=add_ts):
            pass

    # drain anything captured so far
    _ = capsys.readouterr()

    fill_char = "@"
    with print_section("name", fill_char=fill_char, drop_end=drop_end, add_ts=add_ts):
        pass

    captured = capsys.readouterr()

    assert f"{fill_char}name" in captured.out
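
# A minimal usage sketch for ``print_section`` as exercised above. The exact
# decoration around the header is an implementation detail of the utility;
# this only illustrates the intended call pattern (the body text is made up).
def example_print_section_usage():
    from mltype.utils import print_section

    with print_section("demo", fill_char="#", drop_end=True, add_ts=False):
        print("section body")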
def main_replay(replay_file, force_perfect, overwrite, instant_death, target_wpm):
    """Run main curses loop with a replay.

    Parameters
    ----------
    replay_file : str or pathlib.Path
        Path to a file with typed text from some previous game.

    force_perfect : bool
        If True, then one cannot finish typing before all characters
        are typed without any mistakes.

    overwrite : bool
        If True, the replay file will be overwritten in case we are
        faster than it.

    instant_death : bool
        If active, the first mistake will end the game.

    target_wpm : None or int
        The desired speed to be shown as a guide.
    """
    replay_file = pathlib.Path(replay_file)
    replay_tt = TypedText.load(replay_file)

    tt = curses.wrapper(
        run_loop,
        replay_tt.text,
        force_perfect=force_perfect,
        replay_tt=replay_tt,
        instant_death=instant_death,
        target_wpm=target_wpm,
    )

    wpm_replay = replay_tt.compute_wpm()
    wpm_new = tt.compute_wpm()

    with print_section(" Statistics ", fill_char="=", add_ts=False):
        print(f"Old WPM: {wpm_replay:.1f}\nNew WPM: {wpm_new:.1f}")

        if wpm_new > wpm_replay:
            print("Congratulations!")
            if overwrite:
                print("Updating the checkpoint file")
                tt.save(replay_file)
        else:
            print("You lost!")
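
# Illustrative sketch of racing against a previous session via ``main_replay``.
# The file name below is hypothetical and must point to a ``TypedText`` saved
# by an earlier game (e.g. via the ``output_file`` argument of ``main_basic``).
def example_replay_usage():
    main_replay(
        "previous_session",  # hypothetical path to a saved TypedText
        force_perfect=False,
        overwrite=True,
        instant_death=False,
        target_wpm=None,
    )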
def main_basic(text, force_perfect, output_file, instant_death, target_wpm):
    """Run main curses loop with no previous replay.

    Parameters
    ----------
    text : str
        Text to be typed.

    force_perfect : bool
        If True, then one cannot finish typing before all characters
        are typed without any mistakes.

    output_file : str or pathlib.Path or None
        If a str or ``pathlib.Path``, the typed text is stored in this file.
        If None, no saving takes place.

    instant_death : bool
        If active, the first mistake will end the game.

    target_wpm : int or None
        The desired speed to be displayed as a guide.
    """
    text_stripped = text.rstrip()

    tt = curses.wrapper(
        run_loop,
        text_stripped,
        force_perfect=force_perfect,
        replay_tt=None,
        instant_death=instant_death,
        target_wpm=target_wpm,
    )

    if output_file is not None:
        tt.save(pathlib.Path(output_file))

    with print_section(" Statistics ", fill_char="=", add_ts=False):
        print(
            f"Accuracy: {tt.compute_accuracy():.1%}\n"
            f"WPM: {tt.compute_wpm():.1f}"
        )
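
# Illustrative sketch of a plain run via ``main_basic``. The sample text,
# target speed and output handling are made up for demonstration.
def example_basic_usage():
    main_basic(
        "The quick brown fox jumps over the lazy dog.",
        force_perfect=True,
        output_file=None,  # pass a path here to be able to replay this run later
        instant_death=False,
        target_wpm=80,
    )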
def train(
    path,
    model_name,
    checkpoint_path,
    extensions,
    fill_strategy,
    illegal_chars,
    gpus,
    batch_size,
    dense_size,
    hidden_size,
    max_epochs,
    early_stopping,
    n_layers,
    output_path,
    train_test_split,
    use_mlflow,
    vocab_size,
    window_size,
):  # noqa: D400
    """Train a language"""
    params = locals()

    from mltype.data import file2text, folder2text
    from mltype.ml import run_train
    from mltype.utils import print_section

    with print_section(" Parameters ", drop_end=True):
        pprint.pprint(params)

    all_texts = []
    with print_section(" Reading file(s) ", drop_end=True):
        for p in path:
            path_p = pathlib.Path(str(p))

            if not path_p.exists():
                raise ValueError(
                    "The provided path does not exist"
                )  # pragma: no cover

            if path_p.is_file():
                texts = [file2text(path_p)]
            elif path_p.is_dir():
                valid_extensions = (
                    extensions.split(",") if extensions is not None else None
                )
                texts = folder2text(path_p, valid_extensions=valid_extensions)
            else:
                raise ValueError("Unrecognized object")  # pragma: no cover

            all_texts.extend(texts)

    if not all_texts:
        raise ValueError("Did not manage to read any text")  # pragma: no cover

    run_train(
        all_texts,
        model_name,
        max_epochs=max_epochs,
        window_size=window_size,
        batch_size=batch_size,
        vocab_size=vocab_size,
        fill_strategy=fill_strategy,
        illegal_chars=illegal_chars,
        train_test_split=train_test_split,
        hidden_size=hidden_size,
        dense_size=dense_size,
        n_layers=n_layers,
        use_mlflow=use_mlflow,
        early_stopping=early_stopping,
        gpus=gpus,
        checkpoint_path=checkpoint_path,
        output_path=output_path,
    )
    print("Done")
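
# Sketch of feeding a folder of text files to ``run_train`` directly,
# bypassing the CLI wrapper above. The folder path, model name and extension
# format are hypothetical (the CLI passes the comma-separated --extensions
# value split into a list).
def example_train_from_folder():
    import pathlib

    from mltype.data import folder2text
    from mltype.ml import run_train

    texts = list(folder2text(pathlib.Path("my_corpus"), valid_extensions=["py"]))
    run_train(texts, "my_model", max_epochs=2, use_mlflow=False)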
def run_train(
    texts,
    name,
    max_epochs=10,
    window_size=50,
    batch_size=32,
    vocab_size=None,
    fill_strategy="skip",
    illegal_chars="",
    train_test_split=0.5,
    hidden_size=32,
    dense_size=32,
    n_layers=1,
    checkpoint_path=None,
    output_path=None,
    use_mlflow=True,
    early_stopping=True,
    gpus=None,
):
    """Run the training loop.

    Note that the parameters are also explained in the CLI of `mlt train`.

    Parameters
    ----------
    texts : list
        List of str representing all texts we would like to train on.

    name : str
        Name of the model. This name is only used when we save the model -
        it is not hardcoded anywhere in the serialization.

    max_epochs : int
        Maximum number of epochs. Note that the number of actual epochs
        can be lower if we activate the `early_stopping` flag.

    window_size : int
        Number of previous characters to consider when predicting the next
        character. The higher the number, the longer the memory we are
        enforcing. However, at the same time, the training becomes slower.

    batch_size : int
        Number of samples in one batch.

    vocab_size : int
        Maximum number of characters to be put in the vocabulary. Note that
        one can explicitly exclude characters via `illegal_chars`. The higher
        this number, the bigger the feature vectors are and the slower the
        training.

    fill_strategy : str, {"zeros", "skip"}
        Determines how to deal with out-of-vocabulary characters. When
        "zeros", we simply encode them as zero vectors. If "skip", we skip
        a given sample if any of the characters in the window or the
        predicted character are not in the vocabulary.

    illegal_chars : str or None
        If specified, then each character of the str represents a forbidden
        character that we do not put in the vocabulary.

    train_test_split : float
        Float in the range (0, 1) representing the percentage of the
        training set with respect to the entire dataset.

    hidden_size : int
        Hidden size of LSTM cells (equal in all layers).

    dense_size : int
        Size of the dense layer that bridges the hidden state output by the
        LSTM and the final output probabilities over the vocabulary.

    n_layers : int
        Number of layers inside of the LSTM.

    checkpoint_path : None or pathlib.Path or str
        If specified, it points to a checkpoint file (generated by
        PyTorch Lightning). This file does not contain the vocabulary.
        It can be used to continue the training.

    output_path : None or pathlib.Path or str
        If specified, an alternative output folder where the trained models
        and logging information will be stored. If not specified, the output
        folder is by default set to `~/.mltype`.

    use_mlflow : bool
        If active, then we use MLflow for logging of the training and
        validation loss. Additionally, at the end of each epoch we generate
        a few sample texts to demonstrate how good/bad the current network is.

    early_stopping : bool
        If True, then we monitor the validation loss and if it does not
        improve for a certain number of epochs, we stop the training.

    gpus : int or None
        If None or 0, no GPUs are used (only CPUs). Otherwise, it represents
        the number of GPUs to be used (using the data parallelization
        strategy).
""" illegal_chars = illegal_chars or "" cache_dir = get_cache_dir(output_path) languages_path = cache_dir / "languages" / name checkpoints_path = cache_dir / "checkpoints" / name if languages_path.exists(): raise FileExistsError(f"The model {name} already exists") with print_section(" Computing vocabulary ", drop_end=True): vocabulary = sorted([ x[0] for x in Counter("".join(texts)).most_common() if x[0] not in illegal_chars ][:vocab_size]) # works for None vocab_size = len(vocabulary) print(f"# characters: {vocab_size}") print(vocabulary) with print_section(" Creating training set ", drop_end=True): X_list = [] y_list = [] for text in tqdm.tqdm(texts): X_, y_, _ = create_data_language( text, vocabulary, window_size=window_size, verbose=False, fill_strategy=fill_strategy, ) X_list.append(X_) y_list.append(y_) X = np.concatenate(X_list, axis=0) if len(X_list) != 1 else X_list[0] y = np.concatenate(y_list, axis=0) if len(y_list) != 1 else y_list[0] print(f"X.dtype={X.dtype}, y.dtype={y.dtype}") split_ix = int(len(X) * train_test_split) indices = np.random.permutation(len(X)) train_indices = indices[:split_ix] val_indices = indices[split_ix:] print(f"Train: {len(train_indices)}\nValidation: {len(val_indices)}") dataset = LanguageDataset(X, y, vocabulary=vocabulary) dataloader_t = torch.utils.data.DataLoader( dataset, batch_size=batch_size, sampler=torch.utils.data.SubsetRandomSampler(train_indices), ) dataloader_v = torch.utils.data.DataLoader( dataset, batch_size=batch_size, sampler=torch.utils.data.SubsetRandomSampler(val_indices), ) if checkpoint_path is None: network = SingleCharacterLSTM( vocab_size, hidden_size=hidden_size, dense_size=dense_size, n_layers=n_layers, ) else: print(f"Loading a checkpointed network: {checkpoint_path}") network = SingleCharacterLSTM.load_from_checkpoint( str(checkpoint_path)) chp_name_template = str(checkpoints_path / "{epoch}-{val_loss:.3f}") chp_callback = pl.callbacks.ModelCheckpoint( filepath=chp_name_template, save_last=True, # last epoch always there save_top_k=1, verbose=True, monitor="val_loss", mode="min", save_weights_only=False, ) callbacks = [] if use_mlflow: print("Logging with MLflow") logger = pl.loggers.MLFlowLogger("mltype", save_dir=get_cache_dir(output_path) / "logs" / "mlruns") print(f"Run ID: {logger.run_id}") logger.log_hyperparams({ "fill_strategy": fill_strategy, "model_name": name, "train_test_split": train_test_split, "vocab_size": vocab_size, "window_size": window_size, }) else: logger = None if early_stopping: print("Activating early stopping") callbacks.append( pl.callbacks.EarlyStopping(monitor="val_loss", verbose=True)) with print_section(" Training ", drop_end=True): trainer = pl.Trainer( gpus=gpus, max_epochs=max_epochs, logger=logger, callbacks=callbacks, checkpoint_callback=chp_callback, ) trainer.fit(network, dataloader_t, dataloader_v) with print_section(" Saving the model ", drop_end=False): if chp_callback.best_model_path: print(f"Using the checkpoint {chp_callback.best_model_path}") network = SingleCharacterLSTM.load_from_checkpoint( chp_callback.best_model_path) else: print("No checkpoint found, using the current network") print(f"The final model is saved to: {languages_path}") save_model(network, vocabulary, languages_path)