Code example #1
0
def test_train_can_overfit_one_image(tmpdir, caplog):
    """Training on a single image must reach 0% CER/WER within 70 epochs."""
    syms, img_dirs, data_module = prepare_data(tmpdir)
    # Reduce the training ground-truth file to one known line so the model
    # only has a single sample to memorize.
    txt_file = data_module.root / "tr.gt"
    line = "tr-6 9 2 0 1"
    assert txt_file.read_text().splitlines()[6] == line
    txt_file.write_text(line)

    caplog.set_level("INFO")
    script.run(
        syms,
        img_dirs,
        txt_file,
        txt_file,
        common=CommonArgs(train_path=tmpdir,
                          seed=0x12345,
                          experiment_dirname="",
                          monitor="va_loss"),
        data=DataArgs(batch_size=1),
        # after some manual runs, this lr seems to be the
        # fastest one to reliably learn for this toy example.
        # RMSProp performed considerably better than Adam|SGD
        optimizer=OptimizerArgs(learning_rate=0.01, name="RMSProp"),
        train=TrainArgs(
            checkpoint_k=0,  # disable checkpoints
            early_stopping_patience=100,  # disable early stopping
        ),
        trainer=TrainerArgs(
            weights_summary=None,
            overfit_batches=1,
            max_epochs=70,
            check_val_every_n_epoch=100,  # disable validation
        ),
    )
    # `any` states the intent directly (at least one log line reports 0%
    # error) and short-circuits, unlike summing the whole generator.
    assert any("cer=0.0%" in m and "wer=0.0%" in m for m in caplog.messages)
Code example #2
0
def test_train_can_resume_training(tmpdir, caplog):
    """Train for one epoch, then resume from the checkpoint for a second."""
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")

    run_args = (
        syms,
        img_dirs,
        data_module.root / "tr.gt",
        data_module.root / "va.gt",
    )
    run_kwargs = dict(
        common=CommonArgs(train_path=tmpdir),
        data=DataArgs(batch_size=3),
        optimizer=OptimizerArgs(name="SGD"),
        train=TrainArgs(augment_training=True),
        trainer=TrainerArgs(
            progress_bar_refresh_rate=0,
            weights_summary=None,
            max_epochs=1,
        ),
    )

    # run to have a checkpoint
    script.run(*run_args, **run_kwargs)
    assert "Model has been trained for 1 epochs (11 steps)" in caplog.messages
    caplog.clear()

    # train for one more epoch
    run_kwargs["train"] = TrainArgs(resume=1)
    script.run(*run_args, **run_kwargs)
    assert "Model has been trained for 2 epochs (21 steps)" in caplog.messages
Code example #3
0
def test_raises(tmpdir):
    """Invalid arguments to ``script.run`` must fail with clear errors."""
    # No model file exists yet at the given path.
    with pytest.raises(AssertionError, match="Could not find the model"):
        script.run("", [], "", "")

    # An unknown word delimiter must be rejected.
    syms, img_dirs, data_module = prepare_data(tmpdir)
    with pytest.raises(
        AssertionError, match='The delimiter "TEST" is not available'
    ):
        script.run(
            syms,
            [],
            "",
            "",
            common=CommonArgs(train_path=tmpdir),
            train=TrainArgs(delimiters=["<space>", "TEST"]),
        )
Code example #4
0
def test_train_early_stops(tmpdir, caplog):
    """Early stopping with patience 2 must trigger exactly once, at epoch 3."""
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")
    script.run(
        syms,
        img_dirs,
        data_module.root / "tr.gt",
        data_module.root / "va.gt",
        common=CommonArgs(train_path=tmpdir),
        data=DataArgs(batch_size=3),
        train=TrainArgs(early_stopping_patience=2),
        trainer=TrainerArgs(
            progress_bar_refresh_rate=0,
            weights_summary=None,
            max_epochs=5,
        ),
    )
    expected = "Early stopping triggered after epoch 3 (waited for 2 epochs)"
    matches = [m for m in caplog.messages if m.startswith(expected)]
    assert len(matches) == 1
Code example #5
0
def test_train_with_scheduler(tmpdir, caplog):
    """With patience 0 and factor 0.5, the LR must halve after each epoch."""
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")
    script.run(
        syms,
        img_dirs,
        data_module.root / "tr.gt",
        data_module.root / "va.gt",
        common=CommonArgs(train_path=tmpdir),
        data=DataArgs(batch_size=3),
        optimizer=OptimizerArgs(learning_rate=1),
        scheduler=SchedulerArgs(
            active=True,
            patience=0,
            monitor="va_wer",
            factor=0.5,
        ),
        trainer=TrainerArgs(
            progress_bar_refresh_rate=0,
            weights_summary=None,
            max_epochs=5,
        ),
    )
    # Each reduction is announced in the logs with old and new LR values.
    for expected in (
        "E1: lr-RMSprop 1.000e+00 ⟶ 5.000e-01",
        "E2: lr-RMSprop 5.000e-01 ⟶ 2.500e-01",
    ):
        assert expected in caplog.messages