def test_train_can_overfit_one_image(tmpdir, caplog):
    syms, img_dirs, data_module = prepare_data(tmpdir)
    # manually select a specific image
    txt_file = data_module.root / "tr.gt"
    line = "tr-6 9 2 0 1"
    assert txt_file.read_text().splitlines()[6] == line
    txt_file.write_text(line)
    caplog.set_level("INFO")
    script.run(
        syms,
        img_dirs,
        txt_file,
        txt_file,
        common=CommonArgs(
            train_path=tmpdir,
            seed=0x12345,
            experiment_dirname="",
            monitor="va_loss",
        ),
        data=DataArgs(batch_size=1),
        # After some manual runs, this learning rate seems to be the
        # fastest one to reliably learn this toy example. RMSProp
        # performed considerably better than Adam or SGD.
        optimizer=OptimizerArgs(learning_rate=0.01, name="RMSProp"),
        train=TrainArgs(
            checkpoint_k=0,  # disable checkpoints
            early_stopping_patience=100,  # disable early stopping
        ),
        trainer=TrainerArgs(
            weights_summary=None,
            overfit_batches=1,
            max_epochs=70,
            check_val_every_n_epoch=100,  # disable validation
        ),
    )
    assert any("cer=0.0%" in m and "wer=0.0%" in m for m in caplog.messages)


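# The assertions in these tests all scan ``caplog.messages`` for expected log
# output. A minimal sketch of that pattern as a standalone helper (the helper
# name is hypothetical and not part of this suite; ``caplog.messages`` is
# standard pytest API):
def _count_log_matches(caplog, predicate):
    # Count captured log messages for which ``predicate`` returns True.
    return sum(1 for m in caplog.messages if predicate(m))


# Example: the overfit assertion above is equivalent to
#   assert _count_log_matches(
#       caplog, lambda m: "cer=0.0%" in m and "wer=0.0%" in m
#   ) > 0

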
def test_train_can_resume_training(tmpdir, caplog):
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")
    args = [
        syms,
        img_dirs,
        data_module.root / "tr.gt",
        data_module.root / "va.gt",
    ]
    kwargs = {
        "common": CommonArgs(train_path=tmpdir),
        "data": DataArgs(batch_size=3),
        "optimizer": OptimizerArgs(name="SGD"),
        "train": TrainArgs(augment_training=True),
        "trainer": TrainerArgs(
            progress_bar_refresh_rate=0,
            weights_summary=None,
            max_epochs=1,
        ),
    }

    # run once to produce a checkpoint
    script.run(*args, **kwargs)
    assert "Model has been trained for 1 epochs (11 steps)" in caplog.messages

    caplog.clear()

    # train for one more epoch
    kwargs["train"] = TrainArgs(resume=1)
    script.run(*args, **kwargs)
    assert "Model has been trained for 2 epochs (21 steps)" in caplog.messages


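# Note on the resume mechanics exercised above: ``TrainArgs(resume=1)`` is
# expected to pick up the checkpoint written by the first one-epoch run and
# continue training from it, which is why the second run reports a cumulative
# total of 2 epochs instead of starting from scratch.

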
def test_raises(tmpdir):
    with pytest.raises(AssertionError, match="Could not find the model"):
        script.run("", [], "", "")

    syms, img_dirs, data_module = prepare_data(tmpdir)
    with pytest.raises(AssertionError, match='The delimiter "TEST" is not available'):
        script.run(
            syms,
            [],
            "",
            "",
            common=CommonArgs(train_path=tmpdir),
            train=TrainArgs(delimiters=["<space>", "TEST"]),
        )


def test_train_early_stops(tmpdir, caplog):
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")
    script.run(
        syms,
        img_dirs,
        data_module.root / "tr.gt",
        data_module.root / "va.gt",
        common=CommonArgs(train_path=tmpdir),
        data=DataArgs(batch_size=3),
        train=TrainArgs(early_stopping_patience=2),
        trainer=TrainerArgs(
            progress_bar_refresh_rate=0,
            weights_summary=None,
            max_epochs=5,
        ),
    )
    assert (
        sum(
            m.startswith("Early stopping triggered after epoch 3 (waited for 2 epochs)")
            for m in caplog.messages
        )
        == 1
    )


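# With ``early_stopping_patience=2`` the monitored metric may stagnate for two
# consecutive validation epochs before training is aborted, so stopping after
# epoch 3 (rather than running to ``max_epochs=5``) is the expected outcome on
# this toy dataset.

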
def test_train_with_scheduler(tmpdir, caplog):
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")
    script.run(
        syms,
        img_dirs,
        data_module.root / "tr.gt",
        data_module.root / "va.gt",
        common=CommonArgs(train_path=tmpdir),
        data=DataArgs(batch_size=3),
        optimizer=OptimizerArgs(learning_rate=1),
        scheduler=SchedulerArgs(active=True, patience=0, monitor="va_wer", factor=0.5),
        trainer=TrainerArgs(
            progress_bar_refresh_rate=0,
            weights_summary=None,
            max_epochs=5,
        ),
    )
    assert "E1: lr-RMSprop 1.000e+00 ⟶ 5.000e-01" in caplog.messages
    assert "E2: lr-RMSprop 5.000e-01 ⟶ 2.500e-01" in caplog.messages
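

# Expected learning-rate trace, given ``factor=0.5`` and ``patience=0``: a
# ReduceLROnPlateau-style scheduler halves the rate at every epoch in which
# ``va_wer`` fails to improve, hence 1.000e+00 ⟶ 5.000e-01 after epoch 1 and
# 5.000e-01 ⟶ 2.500e-01 after epoch 2.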