def test_train_can_resume_training(tmpdir, caplog):
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")
    args = [
        syms,
        img_dirs,
        data_module.root / "tr.gt",
        data_module.root / "va.gt",
    ]
    kwargs = {
        "common": CommonArgs(train_path=tmpdir),
        "data": DataArgs(batch_size=3),
        "optimizer": OptimizerArgs(name="SGD"),
        "train": TrainArgs(augment_training=True),
        "trainer": TrainerArgs(
            progress_bar_refresh_rate=0, weights_summary=None, max_epochs=1
        ),
    }
    # run to have a checkpoint
    script.run(*args, **kwargs)
    assert "Model has been trained for 1 epochs (11 steps)" in caplog.messages
    caplog.clear()
    # train for one more epoch
    kwargs["train"] = TrainArgs(resume=1)
    script.run(*args, **kwargs)
    assert "Model has been trained for 2 epochs (21 steps)" in caplog.messages
def test_train_can_overfit_one_image(tmpdir, caplog):
    syms, img_dirs, data_module = prepare_data(tmpdir)
    # manually select a specific image
    txt_file = data_module.root / "tr.gt"
    line = "tr-6 9 2 0 1"
    assert txt_file.read_text().splitlines()[6] == line
    txt_file.write_text(line)
    caplog.set_level("INFO")
    script.run(
        syms,
        img_dirs,
        txt_file,
        txt_file,
        common=CommonArgs(
            train_path=tmpdir,
            seed=0x12345,
            experiment_dirname="",
            monitor="va_loss",
        ),
        data=DataArgs(batch_size=1),
        # After some manual runs, this learning rate seems to be the fastest
        # one to reliably learn this toy example.
        # RMSProp performed considerably better than Adam or SGD.
        optimizer=OptimizerArgs(learning_rate=0.01, name="RMSProp"),
        train=TrainArgs(
            checkpoint_k=0,  # disable checkpoints
            early_stopping_patience=100,  # disable early stopping
        ),
        trainer=TrainerArgs(
            weights_summary=None,
            overfit_batches=1,
            max_epochs=70,
            check_val_every_n_epoch=100,  # disable validation
        ),
    )
    assert sum("cer=0.0%" in m and "wer=0.0%" in m for m in caplog.messages)
def test_raises(tmpdir):
    with pytest.raises(AssertionError, match="Could not find the model"):
        script.run("", [], "", "")

    syms, img_dirs, data_module = prepare_data(tmpdir)
    with pytest.raises(AssertionError, match='The delimiter "TEST" is not available'):
        script.run(
            syms,
            [],
            "",
            "",
            common=CommonArgs(train_path=tmpdir),
            train=TrainArgs(delimiters=["<space>", "TEST"]),
        )
def get_args(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    parser = jsonargparse.ArgumentParser(parse_as_dict=True)
    parser.add_argument(
        "--config", action=jsonargparse.ActionConfigFile, help="Configuration file"
    )
    parser.add_argument(
        "syms",
        type=str,
        help=(
            "Mapping from strings to integers. "
            "The CTC symbol must be mapped to integer 0"
        ),
    )
    parser.add_argument(
        "img_dirs",
        type=List[str],
        default=[],
        help="Directories containing segmented line images",
    )
    parser.add_argument(
        "tr_txt_table",
        type=str,
        help="Character transcription of each training image",
    )
    parser.add_argument(
        "va_txt_table",
        type=str,
        help="Character transcription of each validation image",
    )
    parser.add_class_arguments(CommonArgs, "common")
    parser.add_class_arguments(DataArgs, "data")
    parser.add_class_arguments(TrainArgs, "train")
    parser.add_function_arguments(log.config, "logging")
    parser.add_class_arguments(OptimizerArgs, "optimizer")
    parser.add_class_arguments(SchedulerArgs, "scheduler")
    parser.add_class_arguments(TrainerArgs, "trainer")

    args = parser.parse_args(argv, with_meta=False)

    args["common"] = CommonArgs(**args["common"])
    args["train"] = TrainArgs(**args["train"])
    args["data"] = DataArgs(**args["data"])
    args["optimizer"] = OptimizerArgs(**args["optimizer"])
    args["scheduler"] = SchedulerArgs(**args["scheduler"])
    args["trainer"] = TrainerArgs(**args["trainer"])

    return args
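# A minimal sketch of how the dictionary returned by get_args() could be wired
# into run() below. This is an assumption for illustration, not necessarily the
# project's real entry point: dropping the "config" key added by
# ActionConfigFile and forwarding the "logging" group to log.config are
# choices made here, not confirmed by the source.
def main() -> None:
    args = get_args()
    args.pop("config", None)            # run() has no "config" parameter
    log.config(**args.pop("logging"))   # configure logging before training
    run(**args)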
def test_train_early_stops(tmpdir, caplog):
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")
    script.run(
        syms,
        img_dirs,
        data_module.root / "tr.gt",
        data_module.root / "va.gt",
        common=CommonArgs(train_path=tmpdir),
        data=DataArgs(batch_size=3),
        train=TrainArgs(early_stopping_patience=2),
        trainer=TrainerArgs(
            progress_bar_refresh_rate=0, weights_summary=None, max_epochs=5
        ),
    )
    assert (
        sum(
            m.startswith(
                "Early stopping triggered after epoch 3 (waited for 2 epochs)"
            )
            for m in caplog.messages
        )
        == 1
    )
def run(
    syms: str,
    img_dirs: List[str],
    tr_txt_table: str,
    va_txt_table: str,
    common: CommonArgs = CommonArgs(),
    train: TrainArgs = TrainArgs(),
    optimizer: OptimizerArgs = OptimizerArgs(),
    scheduler: SchedulerArgs = SchedulerArgs(),
    data: DataArgs = DataArgs(),
    trainer: TrainerArgs = TrainerArgs(),
):
    pl.seed_everything(common.seed)

    loader = ModelLoader(
        common.train_path, filename=common.model_filename, device="cpu"
    )
    # maybe load a checkpoint
    checkpoint = None
    if train.resume:
        checkpoint = loader.prepare_checkpoint(
            common.checkpoint, common.experiment_dirpath, common.monitor
        )
        trainer.max_epochs = torch.load(checkpoint)["epoch"] + train.resume
        log.info(f'Using checkpoint "{checkpoint}"')
        log.info(f"Max epochs set to {trainer.max_epochs}")

    # load the non-pytorch_lightning model
    model = loader.load()
    assert (
        model is not None
    ), "Could not find the model. Have you run pylaia-htr-create-model?"

    # prepare the symbols
    syms = SymbolsTable(syms)
    for d in train.delimiters:
        assert d in syms, f'The delimiter "{d}" is not available in the symbols file'

    # prepare the engine
    engine_module = HTREngineModule(
        model,
        [syms[d] for d in train.delimiters],
        optimizer=optimizer,
        scheduler=scheduler,
        batch_input_fn=Compose([ItemFeeder("img"), ImageFeeder()]),
        batch_target_fn=ItemFeeder("txt"),
        batch_id_fn=ItemFeeder("id"),  # used to print image ids on exception
    )

    # prepare the data
    data_module = DataModule(
        syms=syms,
        img_dirs=img_dirs,
        tr_txt_table=tr_txt_table,
        va_txt_table=va_txt_table,
        batch_size=data.batch_size,
        color_mode=data.color_mode,
        shuffle_tr=not bool(trainer.limit_train_batches),
        augment_tr=train.augment_training,
        stage="fit",
    )

    # prepare the training callbacks
    # TODO: save on lowest_va_wer and every k epochs
    # https://github.com/PyTorchLightning/pytorch-lightning/issues/2908
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=common.experiment_dirpath,
        filename="{epoch}-lowest_" + common.monitor,
        monitor=common.monitor,
        verbose=True,
        save_top_k=train.checkpoint_k,
        mode="min",
        save_last=True,
    )
    checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
    early_stopping_callback = pl.callbacks.EarlyStopping(
        monitor=common.monitor,
        patience=train.early_stopping_patience,
        verbose=True,
        mode="min",
        strict=False,  # training_step may return None
    )
    callbacks = [
        ProgressBar(refresh_rate=trainer.progress_bar_refresh_rate),
        checkpoint_callback,
        early_stopping_callback,
    ]
    if train.gpu_stats:
        callbacks.append(ProgressBarGPUStats())
    if scheduler.active:
        callbacks.append(LearningRate(logging_interval="epoch"))

    # prepare the trainer
    trainer = pl.Trainer(
        default_root_dir=common.train_path,
        resume_from_checkpoint=checkpoint,
        callbacks=callbacks,
        logger=EpochCSVLogger(common.experiment_dirpath),
        checkpoint_callback=True,
        **vars(trainer),
    )

    # train!
    trainer.fit(engine_module, datamodule=data_module)

    # training is over
    if early_stopping_callback.stopped_epoch:
        log.info(
            "Early stopping triggered after epoch"
            f" {early_stopping_callback.stopped_epoch + 1} (waited for"
            f" {early_stopping_callback.wait_count} epochs). The best score was"
            f" {early_stopping_callback.best_score}"
        )
    log.info(
        f"Model has been trained for {trainer.current_epoch + 1} epochs"
        f" ({trainer.global_step + 1} steps)"
    )
    log.info(
        f"Best {checkpoint_callback.monitor}={checkpoint_callback.best_model_score} "
        f"obtained with model={checkpoint_callback.best_model_path}"
    )
# 1e-3 was the best learning rate with batch size 128
k = batch_size / 128
learning_rate = 1e-3 * sqrt(k)
train(
    syms,
    [str(data_module.root / p) for p in ("tr", "va")],
    *[str(data_module.root / f"{p}.gt") for p in ("tr", "va")],
    common=CommonArgs(
        train_path=train_path,
        seed=seed,
        experiment_dirname="",
    ),
    data=DataArgs(batch_size=batch_size),
    optimizer=OptimizerArgs(learning_rate=learning_rate),
    train=TrainArgs(
        # disable checkpointing
        checkpoint_k=0,
        # disable early stopping
        early_stopping_patience=epochs,
        gpu_stats=True,
    ),
    trainer=TrainerArgs(
        max_epochs=epochs,
        weights_summary=None,
        gpus=1,
        # training is still not deterministic on GPU
        deterministic=True,
    ),
)
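# The learning rate above follows the square-root scaling heuristic: starting
# from the reference pair (lr=1e-3, batch_size=128) taken from the comment
# above, the rate is scaled by sqrt(batch_size / 128). A quick standalone check
# with illustrative batch sizes (the values below are examples, not results
# reported by the source):
from math import sqrt

for bs in (128, 256, 512):
    k = bs / 128
    print(bs, 1e-3 * sqrt(k))  # 128 -> 1e-3, 256 -> ~1.41e-3, 512 -> 2e-3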