Example #1
0
def test_call_with_symbols_table(caplog):
    """ToTensor built from a SymbolsTable maps known tokens and logs unknown ones.

    The unknown token "ö" must map to None and produce exactly one
    'Could not find ...' log message.
    """
    table = SymbolsTable()
    entries = {"a": 0, "b": 1, "<space>": 2, "<": 3}
    for symbol, index in entries.items():
        table.add(symbol, index)
    to_tensor = ToTensor(table)
    text = "a < b <space> a ö"
    encoded = to_tensor(text)
    expected = [0, 3, 1, 2, 0, None]
    assert encoded == expected
    missing_msg = 'Could not find "ö" in the symbols table'
    assert caplog.messages.count(missing_msg) == 1
Example #2
0
def setup(train_path, fixed_input_height=0):
    """Generate dummy MNIST-lines data, save its symbols table and create a model.

    Args:
        train_path: Directory (path-like supporting ``/``) where the symbols
            table is written and the model is created.
        fixed_input_height: Forwarded to ``model`` as ``fixed_input_height``.

    Returns:
        Tuple ``(seed, data_module, syms)`` where ``syms`` is the symbols
        table path as a string.
    """
    seed = 31102020
    seed_everything(seed)

    num_train = 10**4
    data_module = DummyMNISTLines(
        tr_n=num_train,
        va_n=int(0.1 * num_train),
        samples_per_space=5,
    )
    print("Generating data...")
    data_module.prepare_data()

    syms = str(train_path / "syms")
    table = SymbolsTable()
    # data_module.syms maps index -> symbol, SymbolsTable.add takes (symbol, index)
    for index, symbol in data_module.syms.items():
        table.add(symbol, index)
    table.save(syms)

    common_args = CommonArgs(train_path=train_path)
    crnn_args = CreateCRNNArgs(
        cnn_num_features=[16, 32, 48, 64],
        # data is random so minimal RNN layer
        # because there are no long term dependencies
        rnn_units=32,
        rnn_layers=1,
        rnn_dropout=0,
    )
    model(
        syms,
        adaptive_pooling="avgpool-3",
        fixed_input_height=fixed_input_height,
        save_model=True,
        common=common_args,
        crnn=crnn_args,
    )

    return seed, data_module, syms
Example #3
0
def prepare_data(dir, image_sequencer="avgpool-8"):
    """Prepare dummy data, a model and a symbols table inside ``dir``.

    Calls ``prepare_model(dir, image_sequencer)`` after generating the dummy
    MNIST-lines data.

    Returns:
        Tuple ``(syms, img_dirs, data_module)``: the symbols table path, the
        "tr" and "va" image directories as strings, and the data module.
    """
    seed_everything(0x12345)
    data_module = DummyMNISTLines(samples_per_space=5)
    data_module.prepare_data()
    prepare_model(dir, image_sequencer)

    # symbols file: data_module.syms maps index -> symbol
    syms = dir / "syms"
    table = SymbolsTable()
    for index, symbol in data_module.syms.items():
        table.add(symbol, index)
    table.save(syms)

    # image directories for the training and validation splits
    root = data_module.root
    img_dirs = [str(root / "tr"), str(root / "va")]
    return syms, img_dirs, data_module
Example #4
0
def run(
    syms: str,
    fixed_input_height: Optional[NonNegativeInt] = 0,
    adaptive_pooling: str = "avgpool-16",
    common: CommonArgs = CommonArgs(),
    crnn: CreateCRNNArgs = CreateCRNNArgs(),
    save_model: bool = False,
) -> LaiaCRNN:
    """Build a LaiaCRNN from ``crnn`` and optionally save it with ModelSaver.

    Args:
        syms: Path to the symbols table; its length becomes
            ``crnn.num_output_labels``.
        fixed_input_height: When non-zero, the image sequencer is set to
            ``"none-<size>"``, where ``<size>`` is the CNN output size for a
            square input of this height. When 0/None, ``adaptive_pooling``
            is used instead.
        adaptive_pooling: Image sequencer used when no fixed height is given.
        common: Provides the seed and, when saving, the train path and
            model filename.
        crnn: Model hyper-parameters. NOTE: this object is updated in place
            (output labels, image sequencer, resolved ``nn`` classes). If
            None, default ``CreateCRNNArgs()`` are used.
        save_model: Whether to save the model definition with ModelSaver.

    Returns:
        The newly constructed LaiaCRNN.

    Raises:
        AssertionError: If the fixed input height is too small for the CNN.
    """
    seed_everything(common.seed)

    # Previously `crnn` was dereferenced before its None check, so None always
    # crashed. Fall back to the default hyper-parameters instead.
    if crnn is None:
        crnn = CreateCRNNArgs()

    crnn.num_output_labels = len(SymbolsTable(syms))
    if fixed_input_height:
        conv_output_size = LaiaCRNN.get_conv_output_size(
            size=(fixed_input_height, fixed_input_height),
            cnn_kernel_size=crnn.cnn_kernel_size,
            cnn_stride=crnn.cnn_stride,
            cnn_dilation=crnn.cnn_dilation,
            cnn_poolsize=crnn.cnn_poolsize,
        )
        # pick dimension 1 for vertical text, dimension 0 otherwise
        fixed_size_after_conv = conv_output_size[1 if crnn.vertical_text else 0]
        assert (fixed_size_after_conv >
                0), "The image size is too small for the CNN architecture"
        crnn.image_sequencer = f"none-{fixed_size_after_conv}"
    else:
        crnn.image_sequencer = adaptive_pooling

    # The default `crnn` instance is created once, at definition time, and
    # shared by every call that relies on it. Since it is mutated here, a
    # second call would see `nn` classes instead of names, and
    # getattr(nn, <class>) would raise TypeError. Only resolve entries that
    # are still strings.
    if isinstance(crnn.rnn_type, str):
        crnn.rnn_type = getattr(nn, crnn.rnn_type)
    crnn.cnn_activation = [
        getattr(nn, act) if isinstance(act, str) else act
        for act in crnn.cnn_activation
    ]

    model = LaiaCRNN(**vars(crnn))
    log.info(
        "Model has {} parameters",
        sum(param.numel() for param in model.parameters()),
    )
    if save_model:
        ModelSaver(common.train_path,
                   common.model_filename).save(LaiaCRNN, **vars(crnn))
    return model
Example #5
0
def test_decode_on_dummy_mnist_lines_data(tmpdir, nprocs):
    """Run the decode script on 5 dummy validation lines and check the ids.

    The ids printed by the script (first space-separated field of every
    stdout line) must be exactly the validation image ids, and stderr must
    mention the checkpoint being used. With ``nprocs > 1`` the script runs
    with the ``ddp_cpu`` accelerator.
    """
    # data
    seed_everything(0x12345)
    data_module = DummyMNISTLines(
        tr_n=0, va_n=5, batch_size=3, samples_per_space=3
    )
    data_module.prepare_data()

    # model file
    dummy_args = [(3, 3), 12]
    ModelSaver(tmpdir).save(DummyModel, *dummy_args)

    # checkpoint file
    ckpt = tmpdir / "model.ckpt"
    torch.save(DummyModel(*dummy_args).state_dict(), str(ckpt))

    # symbols file: data_module.syms maps index -> symbol
    syms = tmpdir / "syms"
    table = SymbolsTable()
    for index, symbol in data_module.syms.items():
        table.add(symbol, index)
    table.save(syms)

    # image list with one validation id per line
    img_list = tmpdir / "img_list"
    va_ids = [f"va-{i}" for i in range(data_module.n["va"])]
    img_list.write_text("\n".join(va_ids), "utf-8")

    args = [
        syms,
        img_list,
        f"--img_dirs={[str(data_module.root / 'va')]}",
        f"--common.checkpoint={ckpt}",
        f"--common.train_path={tmpdir}",
        f"--data.batch_size={data_module.batch_size}",
    ]
    if nprocs > 1:
        args += [
            "--trainer.accelerator=ddp_cpu",
            f"--trainer.num_processes={nprocs}",
        ]

    stdout, stderr = call_script(script.__file__, args)
    print(f"Script stdout:\n{stdout}")
    print(f"Script stderr:\n{stderr}")

    output_lines = stdout.strip().split("\n")
    img_ids = [line.split(" ", maxsplit=1)[0] for line in output_lines]
    assert sorted(img_ids) == va_ids
    assert "Using checkpoint" in stderr
        help="Type of the KWS index to process",
    )
    parser.add_argument(
        "index_file", type=argparse.FileType("r"), help="File containing the KWS index"
    )
    parser.add_argument(
        "pgrams_file",
        type=argparse.FileType("r"),
        help="File containing the parallelograms",
    )
    parser.add_argument(
        "img_dir", type=str, help="Directory containing the processed images"
    )
    args = parser.parse_args()
    # Load symbols table
    syms = SymbolsTable(args.symbols_table) if args.symbols_table else None
    # Load resize info file
    resize_info = (
        parse_resize_info_file(args.resize_info_file) if args.resize_info_file else None
    )
    # Load parallelograms
    pgrams = parse_paralellograms_file(args.pgrams_file)

    for n, sample in enumerate(args.index_file):
        m = re.match(r"^([^ ]+) +(.+)$", sample)
        if not m:
            raise ValueError("Wrong index entry at line{}".format(n))

        sample_id = m.group(1)
        pm = re.match(args.page_id_regex, sample_id)
        page_id = pm.group(1)
Example #7
0
                 help='Number of units the recurrent layers')
    add_argument('--rnn_layers',
                 default=2,
                 type=NumberInClosedRange(int, vmin=1),
                 help='Number of recurrent layers')
    add_argument('--rnn_type',
                 choices=['LSTM', 'GRU'],
                 default='LSTM',
                 help='Type of the recurrent layers')
    add_argument('--rnn_dropout',
                 default=0.5,
                 type=NumberInClosedRange(float, vmin=0, vmax=1),
                 help='Dropout before and after the recurrent layers')
    args = args()

    num_output_symbols = len(SymbolsTable(args.syms))

    ModelSaver(args.train_path, args.filename) \
        .save(GatedCRNN,
              in_channels=args.num_input_channels,
              num_outputs=num_output_symbols,
              cnn_num_features=args.cnn_num_features,
              cnn_kernel_sizes=args.cnn_kernel_size,
              cnn_strides=args.cnn_stride,
              cnn_add_gating=args.cnn_add_gating,
              cnn_poolsize=args.cnn_poolsize,
              cnn_activation=[getattr(nn, act) for act in args.cnn_activations],
              sequencer=args.sequencer,
              columnwise=args.columnwise,
              rnn_hidden_size=args.rnn_hidden_size,
              rnn_num_layers=args.rnn_layers,
Example #8
0
def run(
        syms: str,
        img_dirs: List[str],
        tr_txt_table: str,
        va_txt_table: str,
        common: CommonArgs = CommonArgs(),
        train: TrainArgs = TrainArgs(),
        optimizer: OptimizerArgs = OptimizerArgs(),
        scheduler: SchedulerArgs = SchedulerArgs(),
        data: DataArgs = DataArgs(),
        trainer: TrainerArgs = TrainerArgs(),
):
    """Train the model saved under ``common.train_path``.

    Args:
        syms: Path to the symbols table. Every entry of ``train.delimiters``
            must be present in it.
        img_dirs: Directories containing the images (passed to DataModule).
        tr_txt_table: Training text table (passed to DataModule).
        va_txt_table: Validation text table (passed to DataModule).
        common: Seed, paths, model filename, checkpoint and monitored metric.
        train: Training options. When ``train.resume`` is truthy, training
            resumes from a checkpoint and ``trainer.max_epochs`` is set to
            the checkpoint's epoch plus ``train.resume``.
        optimizer: Passed to HTREngineModule.
        scheduler: Passed to HTREngineModule. When active, a LearningRate
            callback is also added.
        data: Batch size and color mode for DataModule.
        trainer: Keyword arguments for ``pl.Trainer``.

    Raises:
        AssertionError: If the model cannot be loaded or a delimiter is
            missing from the symbols table.
    """
    pl.seed_everything(common.seed)

    loader = ModelLoader(common.train_path,
                         filename=common.model_filename,
                         device="cpu")
    # maybe load a checkpoint
    checkpoint = None
    if train.resume:
        checkpoint = loader.prepare_checkpoint(common.checkpoint,
                                               common.experiment_dirpath,
                                               common.monitor)
        # Only the epoch counter is read here. Map to CPU so a checkpoint
        # saved on a GPU can be inspected on a machine without one.
        last_epoch = torch.load(checkpoint, map_location="cpu")["epoch"]
        trainer.max_epochs = last_epoch + train.resume
        log.info(f'Using checkpoint "{checkpoint}"')
        log.info(f"Max epochs set to {trainer.max_epochs}")

    # load the non-pytorch_lightning model
    model = loader.load()
    assert (
        model is not None
    ), "Could not find the model. Have you run pylaia-htr-create-model?"

    # prepare the symbols
    syms = SymbolsTable(syms)
    for d in train.delimiters:
        assert d in syms, f'The delimiter "{d}" is not available in the symbols file'

    # prepare the engine
    engine_module = HTREngineModule(
        model,
        [syms[d] for d in train.delimiters],
        optimizer=optimizer,
        scheduler=scheduler,
        batch_input_fn=Compose([ItemFeeder("img"),
                                ImageFeeder()]),
        batch_target_fn=ItemFeeder("txt"),
        batch_id_fn=ItemFeeder("id"),  # Used to print image ids on exception
    )

    # prepare the data
    data_module = DataModule(
        syms=syms,
        img_dirs=img_dirs,
        tr_txt_table=tr_txt_table,
        va_txt_table=va_txt_table,
        batch_size=data.batch_size,
        color_mode=data.color_mode,
        # don't shuffle when only a limited number of batches is used
        shuffle_tr=not bool(trainer.limit_train_batches),
        augment_tr=train.augment_training,
        stage="fit",
    )

    # prepare the training callbacks
    # TODO: save on lowest_va_wer and every k epochs https://github.com/PyTorchLightning/pytorch-lightning/issues/2908
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=common.experiment_dirpath,
        filename="{epoch}-lowest_" + common.monitor,
        monitor=common.monitor,
        verbose=True,
        save_top_k=train.checkpoint_k,
        mode="min",
        save_last=True,
    )
    checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
    early_stopping_callback = pl.callbacks.EarlyStopping(
        monitor=common.monitor,
        patience=train.early_stopping_patience,
        verbose=True,
        mode="min",
        strict=False,  # training_step may return None
    )
    # Register each callback exactly once. checkpoint_callback used to be
    # listed twice, which could make Lightning run its hooks (and checkpoint
    # saving) twice per event.
    callbacks = [
        ProgressBar(refresh_rate=trainer.progress_bar_refresh_rate),
        checkpoint_callback,
        early_stopping_callback,
    ]
    if train.gpu_stats:
        callbacks.append(ProgressBarGPUStats())
    if scheduler.active:
        callbacks.append(LearningRate(logging_interval="epoch"))

    # prepare the trainer
    trainer = pl.Trainer(
        default_root_dir=common.train_path,
        resume_from_checkpoint=checkpoint,
        callbacks=callbacks,
        logger=EpochCSVLogger(common.experiment_dirpath),
        checkpoint_callback=True,
        **vars(trainer),
    )

    # train!
    trainer.fit(engine_module, datamodule=data_module)

    # training is over
    if early_stopping_callback.stopped_epoch:
        log.info(
            "Early stopping triggered after epoch"
            f" {early_stopping_callback.stopped_epoch + 1} (waited for"
            f" {early_stopping_callback.wait_count} epochs). The best score was"
            f" {early_stopping_callback.best_score}")
    log.info(f"Model has been trained for {trainer.current_epoch + 1} epochs"
             f" ({trainer.global_step + 1} steps)")
    log.info(
        f"Best {checkpoint_callback.monitor}={checkpoint_callback.best_model_score} "
        f"obtained with model={checkpoint_callback.best_model_path}")
        type=str,
        help="Score function",
    )
    add_argument(
        "--source",
        type=str,
        default="experiment",
        choices=["experiment", "model"],
        help="Type of class which generated the checkpoint",
    )
    add_argument("--save_dict_filename", type=str)

    # Loading of models and datasets
    args = args()

    syms = SymbolsTable(args.syms)
    device = torch.device("cuda:{}".format(args.gpu -
                                           1) if args.gpu else "cpu")

    model = ModelLoader(args.train_path,
                        filename=args.model_filename,
                        device=device).load()
    if model is None:
        log.error("Could not find the model")
        exit(1)
    state = CheckpointLoader(device=device).load_by(
        os.path.join(args.train_path, args.checkpoint))
    model.load_state_dict(state if args.source ==
                          "model" else Experiment.get_model_state_dict(state))
    model = model.to(device)
    model.eval()
Example #10
0
def run(
        syms: str,
        img_list: str,
        img_dirs: Optional[List[str]] = None,
        common: CommonArgs = CommonArgs(),
        data: DataArgs = DataArgs(),
        decode: DecodeArgs = DecodeArgs(),
        trainer: TrainerArgs = TrainerArgs(),
):
    """Decode the images listed in ``img_list`` with a trained model.

    The model is loaded (on CPU) from the checkpoint chosen by
    ``ModelLoader.prepare_checkpoint`` using ``common``. Output is produced by
    a ``Segmentation`` callback when ``decode.segmentation`` is truthy, and by
    a ``Decode`` callback otherwise.

    Raises:
        AssertionError: If the model cannot be loaded.
    """
    model_loader = ModelLoader(
        common.train_path, filename=common.model_filename, device="cpu"
    )
    ckpt_path = model_loader.prepare_checkpoint(
        common.checkpoint, common.experiment_dirpath, common.monitor
    )
    model = model_loader.load_by(ckpt_path)
    assert model is not None, (
        "Could not find the model. Have you run pylaia-htr-create-model?"
    )

    # evaluator feeding images and ids of every batch to the model
    input_fn = Compose([ItemFeeder("img"), ImageFeeder()])
    evaluator_module = EvaluatorModule(
        model, batch_input_fn=input_fn, batch_id_fn=ItemFeeder("id")
    )

    syms = SymbolsTable(syms)

    data_module = DataModule(
        syms=syms,
        img_dirs=img_dirs,
        te_img_list=img_list,
        batch_size=data.batch_size,
        color_mode=data.color_mode,
        stage="test",
    )

    # callbacks: a progress bar plus either segmentation or plain decoding
    progress_bar = ProgressBar(refresh_rate=trainer.progress_bar_refresh_rate)
    if bool(decode.segmentation):
        output_callback = Segmentation(
            syms,
            segmentation=decode.segmentation,
            input_space=decode.input_space,
            separator=decode.separator,
            include_img_ids=decode.include_img_ids,
        )
    else:
        output_callback = Decode(
            syms=syms,
            use_symbols=decode.use_symbols,
            input_space=decode.input_space,
            output_space=decode.output_space,
            convert_spaces=decode.convert_spaces,
            join_string=decode.join_string,
            separator=decode.separator,
            include_img_ids=decode.include_img_ids,
        )
    callbacks = [progress_bar, output_callback]

    trainer = pl.Trainer(
        default_root_dir=common.train_path,
        callbacks=callbacks,
        logger=False,
        **vars(trainer),
    )

    # decode!
    trainer.test(evaluator_module, datamodule=data_module, verbose=False)