Exemple #1
0
def test_train_can_resume_training(tmpdir, caplog):
    """Check that a run with ``resume=1`` continues from an earlier run.

    A first run limited to one epoch is executed so that a checkpoint
    exists. A second run with the same positional arguments, but with the
    train options replaced by ``TrainArgs(resume=1)``, must then report
    two trained epochs. Both runs are verified through the summary message
    captured at INFO level.
    """
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")
    positional = (
        syms,
        img_dirs,
        data_module.root / "tr.gt",
        data_module.root / "va.gt",
    )
    options = dict(
        common=CommonArgs(train_path=tmpdir),
        data=DataArgs(batch_size=3),
        optimizer=OptimizerArgs(name="SGD"),
        train=TrainArgs(augment_training=True),
        trainer=TrainerArgs(
            progress_bar_refresh_rate=0,
            weights_summary=None,
            max_epochs=1,
        ),
    )

    # first pass: produces the checkpoint to resume from
    script.run(*positional, **options)
    first_summary = "Model has been trained for 1 epochs (11 steps)"
    assert first_summary in caplog.messages
    caplog.clear()

    # second pass: same setup, resuming for one additional epoch
    options["train"] = TrainArgs(resume=1)
    script.run(*positional, **options)
    second_summary = "Model has been trained for 2 epochs (21 steps)"
    assert second_summary in caplog.messages
Exemple #2
0
def test_train_can_overfit_one_image(tmpdir, caplog):
    """Check that training on a single sample reaches 0% CER and WER.

    The training table is overwritten so it only holds one known line, and
    that same file is used as both the training and the validation table.
    The test passes when at least one captured INFO message contains both
    ``cer=0.0%`` and ``wer=0.0%``.
    """
    syms, img_dirs, data_module = prepare_data(tmpdir)

    # keep a single, hand-picked sample in the training table
    txt_file = data_module.root / "tr.gt"
    line = "tr-6 9 2 0 1"
    table_lines = txt_file.read_text().splitlines()
    assert table_lines[6] == line
    txt_file.write_text(line)

    caplog.set_level("INFO")
    common_args = CommonArgs(
        train_path=tmpdir,
        seed=0x12345,
        experiment_dirname="",
        monitor="va_loss",
    )
    data_args = DataArgs(batch_size=1)
    # This learning rate was picked after some manual runs as the fastest
    # one that reliably learns this toy example; RMSProp performed
    # considerably better than Adam|SGD.
    optimizer_args = OptimizerArgs(learning_rate=0.01, name="RMSProp")
    train_args = TrainArgs(
        checkpoint_k=0,  # disable checkpoints
        early_stopping_patience=100,  # disable early stopping
    )
    trainer_args = TrainerArgs(
        weights_summary=None,
        overfit_batches=1,
        max_epochs=70,
        check_val_every_n_epoch=100,  # disable validation
    )
    script.run(
        syms,
        img_dirs,
        txt_file,
        txt_file,
        common=common_args,
        data=data_args,
        optimizer=optimizer_args,
        train=train_args,
        trainer=trainer_args,
    )

    # at least one logged message must report a perfect score
    assert any(
        "cer=0.0%" in message and "wer=0.0%" in message
        for message in caplog.messages
    )
Exemple #3
0
def test_raises(tmpdir):
    """Check the assertion errors raised by ``script.run`` on bad input.

    Two cases are covered: calling the script with empty arguments, which
    must report that the model could not be found, and requesting a
    delimiter ("TEST") that must be reported as not available.
    """
    # case 1: empty arguments
    model_not_found = pytest.raises(
        AssertionError,
        match="Could not find the model",
    )
    with model_not_found:
        script.run("", [], "", "")

    # case 2: a delimiter that must be rejected
    syms, img_dirs, data_module = prepare_data(tmpdir)
    unknown_delimiter = pytest.raises(
        AssertionError,
        match='The delimiter "TEST" is not available',
    )
    with unknown_delimiter:
        script.run(
            syms,
            [],
            "",
            "",
            common=CommonArgs(train_path=tmpdir),
            train=TrainArgs(delimiters=["<space>", "TEST"]),
        )
Exemple #4
0
def get_args(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Parse the options of the training script into a dictionary.

    Args:
        argv: Arguments to parse, forwarded as-is to the parser's
            ``parse_args``.

    Returns:
        The parsed arguments as a dictionary. The "common", "train",
        "data", "optimizer", "scheduler" and "trainer" entries are replaced
        by instances of their argument classes; the remaining entries
        (including "logging") are left as produced by the parser.
    """
    parser = jsonargparse.ArgumentParser(parse_as_dict=True)
    parser.add_argument(
        "--config",
        action=jsonargparse.ActionConfigFile,
        help="Configuration file",
    )

    # positional arguments, in declaration order
    positionals = (
        (
            "syms",
            dict(
                type=str,
                help=("Mapping from strings to integers. "
                      "The CTC symbol must be mapped to integer 0"),
            ),
        ),
        (
            "img_dirs",
            dict(
                type=List[str],
                default=[],
                help="Directories containing segmented line images",
            ),
        ),
        (
            "tr_txt_table",
            dict(
                type=str,
                help="Character transcription of each training image",
            ),
        ),
        (
            "va_txt_table",
            dict(
                type=str,
                help="Character transcription of each validation image",
            ),
        ),
    )
    for name, options in positionals:
        parser.add_argument(name, **options)

    # option groups derived from class/function signatures
    parser.add_class_arguments(CommonArgs, "common")
    parser.add_class_arguments(DataArgs, "data")
    parser.add_class_arguments(TrainArgs, "train")
    parser.add_function_arguments(log.config, "logging")
    parser.add_class_arguments(OptimizerArgs, "optimizer")
    parser.add_class_arguments(SchedulerArgs, "scheduler")
    parser.add_class_arguments(TrainerArgs, "trainer")

    parsed = parser.parse_args(argv, with_meta=False)

    # turn each parsed group into an instance of its argument class
    group_classes = (
        ("common", CommonArgs),
        ("train", TrainArgs),
        ("data", DataArgs),
        ("optimizer", OptimizerArgs),
        ("scheduler", SchedulerArgs),
        ("trainer", TrainerArgs),
    )
    for group, args_cls in group_classes:
        parsed[group] = args_cls(**parsed[group])

    return parsed
Exemple #5
0
def test_train_early_stops(tmpdir, caplog):
    """Check that training stops early with a patience of two epochs.

    The run is allowed up to five epochs. Exactly one captured INFO message
    must start with the early stopping notice for epoch 3 after waiting
    2 epochs.
    """
    syms, img_dirs, data_module = prepare_data(tmpdir)
    caplog.set_level("INFO")

    tr_table = data_module.root / "tr.gt"
    va_table = data_module.root / "va.gt"
    script.run(
        syms,
        img_dirs,
        tr_table,
        va_table,
        common=CommonArgs(train_path=tmpdir),
        data=DataArgs(batch_size=3),
        train=TrainArgs(early_stopping_patience=2),
        trainer=TrainerArgs(
            progress_bar_refresh_rate=0,
            weights_summary=None,
            max_epochs=5,
        ),
    )

    expected_prefix = (
        "Early stopping triggered after epoch 3 (waited for 2 epochs)")
    matches = [m for m in caplog.messages if m.startswith(expected_prefix)]
    assert len(matches) == 1
Exemple #6
0
def run(
        syms: str,
        img_dirs: List[str],
        tr_txt_table: str,
        va_txt_table: str,
        common: CommonArgs = CommonArgs(),
        train: TrainArgs = TrainArgs(),
        optimizer: OptimizerArgs = OptimizerArgs(),
        scheduler: SchedulerArgs = SchedulerArgs(),
        data: DataArgs = DataArgs(),
        trainer: TrainerArgs = TrainerArgs(),
):
    """Train a previously created model, optionally resuming a checkpoint.

    The function seeds the run, loads the model found under
    ``common.train_path``, builds the engine module, the data module and the
    training callbacks, fits the model and finally logs a summary of the
    training.

    Args:
        syms: Symbols table source, passed to ``SymbolsTable``.
        img_dirs: Image directories, forwarded to ``DataModule``.
        tr_txt_table: Training transcriptions, forwarded to ``DataModule``.
        va_txt_table: Validation transcriptions, forwarded to ``DataModule``.
        common: Common options (seed, paths, monitored metric, ...).
        train: Training options (resume, delimiters, checkpointing, early
            stopping, ...).
        optimizer: Optimizer options, forwarded to ``HTREngineModule``.
        scheduler: Scheduler options, forwarded to ``HTREngineModule``.
        data: Data options (batch size, color mode).
        trainer: Options expanded into ``pl.Trainer``. This object is never
            modified: when resuming, the ``max_epochs`` override is applied
            to a private copy.

    Raises:
        AssertionError: If the model could not be loaded, or if one of
            ``train.delimiters`` is not present in the symbols table.
    """
    import copy

    pl.seed_everything(common.seed)

    loader = ModelLoader(common.train_path,
                         filename=common.model_filename,
                         device="cpu")
    # maybe load a checkpoint
    checkpoint = None
    if train.resume:
        checkpoint = loader.prepare_checkpoint(common.checkpoint,
                                               common.experiment_dirpath,
                                               common.monitor)
        # Override `max_epochs` on a shallow copy: `trainer` is either owned
        # by the caller or is the default instance shared by every call, so
        # writing to it directly would leak the override into later runs.
        trainer = copy.copy(trainer)
        # train for `train.resume` epochs on top of the checkpoint's epoch
        trainer.max_epochs = torch.load(checkpoint)["epoch"] + train.resume
        log.info(f'Using checkpoint "{checkpoint}"')
        log.info(f"Max epochs set to {trainer.max_epochs}")

    # load the non-pytorch_lightning model
    model = loader.load()
    assert (
        model is not None
    ), "Could not find the model. Have you run pylaia-htr-create-model?"

    # prepare the symbols
    syms = SymbolsTable(syms)
    for d in train.delimiters:
        assert d in syms, f'The delimiter "{d}" is not available in the symbols file'

    # prepare the engine
    engine_module = HTREngineModule(
        model,
        [syms[d] for d in train.delimiters],
        optimizer=optimizer,
        scheduler=scheduler,
        batch_input_fn=Compose([ItemFeeder("img"),
                                ImageFeeder()]),
        batch_target_fn=ItemFeeder("txt"),
        batch_id_fn=ItemFeeder("id"),  # Used to print image ids on exception
    )

    # prepare the data
    data_module = DataModule(
        syms=syms,
        img_dirs=img_dirs,
        tr_txt_table=tr_txt_table,
        va_txt_table=va_txt_table,
        batch_size=data.batch_size,
        color_mode=data.color_mode,
        shuffle_tr=not bool(trainer.limit_train_batches),
        augment_tr=train.augment_training,
        stage="fit",
    )

    # prepare the training callbacks
    # TODO: save on lowest_va_wer and every k epochs https://github.com/PyTorchLightning/pytorch-lightning/issues/2908
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=common.experiment_dirpath,
        filename="{epoch}-lowest_" + common.monitor,
        monitor=common.monitor,
        verbose=True,
        save_top_k=train.checkpoint_k,
        mode="min",
        save_last=True,
    )
    checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
    early_stopping_callback = pl.callbacks.EarlyStopping(
        monitor=common.monitor,
        patience=train.early_stopping_patience,
        verbose=True,
        mode="min",
        strict=False,  # training_step may return None
    )
    # NOTE(review): `checkpoint_callback` is listed both before and after
    # the early stopping callback. Kept as-is; presumably intentional so it
    # also runs after early stopping -- confirm before removing either entry.
    callbacks = [
        ProgressBar(refresh_rate=trainer.progress_bar_refresh_rate),
        checkpoint_callback,
        early_stopping_callback,
        checkpoint_callback,
    ]
    if train.gpu_stats:
        callbacks.append(ProgressBarGPUStats())
    if scheduler.active:
        callbacks.append(LearningRate(logging_interval="epoch"))

    # prepare the trainer
    # (from here on `trainer` is the `pl.Trainer`, built from the options)
    trainer = pl.Trainer(
        default_root_dir=common.train_path,
        resume_from_checkpoint=checkpoint,
        callbacks=callbacks,
        logger=EpochCSVLogger(common.experiment_dirpath),
        checkpoint_callback=True,
        **vars(trainer),
    )

    # train!
    trainer.fit(engine_module, datamodule=data_module)

    # training is over
    if early_stopping_callback.stopped_epoch:
        log.info(
            "Early stopping triggered after epoch"
            f" {early_stopping_callback.stopped_epoch + 1} (waited for"
            f" {early_stopping_callback.wait_count} epochs). The best score was"
            f" {early_stopping_callback.best_score}")
    log.info(f"Model has been trained for {trainer.current_epoch + 1} epochs"
             f" ({trainer.global_step + 1} steps)")
    log.info(
        f"Best {checkpoint_callback.monitor}={checkpoint_callback.best_model_score} "
        f"obtained with model={checkpoint_callback.best_model_path}")
Exemple #7
0
# The reference point is a learning rate of 1e-3, which was the best one
# with batch size 128; it is scaled by the square root of the batch-size
# ratio.
k = batch_size / 128
learning_rate = sqrt(k) * 1e-3

# image directories and transcription tables of both partitions
image_dirs = [str(data_module.root / "tr"), str(data_module.root / "va")]
tr_table = str(data_module.root / "tr.gt")
va_table = str(data_module.root / "va.gt")

common_args = CommonArgs(
    train_path=train_path,
    seed=seed,
    experiment_dirname="",
)
data_args = DataArgs(batch_size=batch_size)
optimizer_args = OptimizerArgs(learning_rate=learning_rate)
train_args = TrainArgs(
    checkpoint_k=0,  # disable checkpointing
    early_stopping_patience=epochs,  # disable early stopping
    gpu_stats=True,
)
trainer_args = TrainerArgs(
    max_epochs=epochs,
    weights_summary=None,
    gpus=1,
    # training is still not deterministic on GPU
    deterministic=True,
)

train(
    syms,
    image_dirs,
    tr_table,
    va_table,
    common=common_args,
    data=data_args,
    optimizer=optimizer_args,
    train=train_args,
    trainer=trainer_args,
)