Example #1
def setup(train_path, fixed_input_height=0):
    seed = 31102020
    seed_everything(seed)

    n = 10**4
    data_module = DummyMNISTLines(tr_n=n,
                                  va_n=int(0.1 * n),
                                  samples_per_space=5)
    print("Generating data...")
    data_module.prepare_data()

    # prepare syms file: data_module.syms maps index -> symbol,
    # while SymbolsTable stores the inverse symbol -> index mapping
    syms = str(train_path / "syms")
    syms_table = SymbolsTable()
    for k, v in data_module.syms.items():
        syms_table.add(v, k)
    syms_table.save(syms)

    # build the CRNN and save it under train_path (save_model=True)
    model(
        syms,
        adaptive_pooling="avgpool-3",
        fixed_input_height=fixed_input_height,
        save_model=True,
        common=CommonArgs(train_path=train_path),
        crnn=CreateCRNNArgs(
            cnn_num_features=[16, 32, 48, 64],
            # the data is random, so a single small RNN layer is enough:
            # there are no long-term dependencies to model
            rnn_units=32,
            rnn_layers=1,
            rnn_dropout=0,
        ),
    )

    return seed, data_module, syms
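The syms block above inverts data_module.syms (index -> symbol) into a symbol -> index table before saving it. A minimal stand-in using plain dicts, purely illustrative (the sample mapping is made up; only the inversion pattern comes from the example):

syms = {0: "<space>", 1: "0", 2: "1"}  # hypothetical data_module.syms contents
table = {symbol: index for index, symbol in syms.items()}
assert table["<space>"] == 0           # same effect as SymbolsTable.add(v, k)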
Example #2
def test_call_with_symbols_table(caplog):
    st = SymbolsTable()
    for k, v in {"a": 0, "b": 1, "<space>": 2, "<": 3}.items():
        st.add(k, v)
    t = ToTensor(st)
    x = "a < b <space> a ö"
    y = t(x)
    assert y == [0, 3, 1, 2, 0, None]
    assert caplog.messages.count(
        'Could not find "ö" in the symbols table') == 1
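The assertion above pins down what ToTensor does with a symbols table: split the line on whitespace, look up each token, and map unknown tokens to None while logging a warning. A minimal sketch of that lookup with a plain dict (hypothetical encode_line helper, not PyLaia's implementation):

def encode_line(line, table):
    # unknown tokens come back as None, mirroring the expected output above
    return [table.get(token) for token in line.split()]

table = {"a": 0, "b": 1, "<space>": 2, "<": 3}
assert encode_line("a < b <space> a ö", table) == [0, 3, 1, 2, 0, None]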
Example #3
def prepare_data(dir, image_sequencer="avgpool-8"):
    seed_everything(0x12345)
    data_module = DummyMNISTLines(samples_per_space=5)
    data_module.prepare_data()
    prepare_model(dir, image_sequencer)
    # prepare syms file
    syms = dir / "syms"
    syms_table = SymbolsTable()
    for k, v in data_module.syms.items():
        syms_table.add(v, k)
    syms_table.save(syms)
    # prepare img dirs
    img_dirs = [str(data_module.root / p) for p in ("tr", "va")]
    return syms, img_dirs, data_module
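Both this helper and Example #1 call seed_everything before generating data so the dummy dataset is reproducible. PyTorch Lightning ships the real function; a rough stand-in that seeds the usual RNG sources might look like this (sketch only; the real helper also covers CUDA and the Python hash seed):

import random

import numpy as np
import torch

def seed_everything_sketch(seed):
    # seed the common RNG sources used during data generation
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)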
Example #4
def test_decode_on_dummy_mnist_lines_data(tmpdir, nprocs):
    # prepare data
    seed_everything(0x12345)
    data_module = DummyMNISTLines(tr_n=0,
                                  va_n=5,
                                  batch_size=3,
                                  samples_per_space=3)
    data_module.prepare_data()
    # prepare model file
    model_args = [(3, 3), 12]
    ModelSaver(tmpdir).save(DummyModel, *model_args)
    # prepare ckpt file
    ckpt = tmpdir / "model.ckpt"
    torch.save(DummyModel(*model_args).state_dict(), str(ckpt))
    # prepare syms file
    syms = tmpdir / "syms"
    syms_table = SymbolsTable()
    for k, v in data_module.syms.items():
        syms_table.add(v, k)
    syms_table.save(syms)
    # prepare img list
    img_list = tmpdir / "img_list"
    img_list.write_text(
        "\n".join(f"va-{i}" for i in range(data_module.n["va"])), "utf-8")

    args = [
        syms,
        img_list,
        f"--img_dirs={[str(data_module.root / 'va')]}",
        f"--common.checkpoint={ckpt}",
        f"--common.train_path={tmpdir}",
        f"--data.batch_size={data_module.batch_size}",
    ]
    if nprocs > 1:
        args.append("--trainer.accelerator=ddp_cpu")
        args.append(f"--trainer.num_processes={nprocs}")

    # `script` here is the decode script module imported by the original test
    stdout, stderr = call_script(script.__file__, args)
    print(f"Script stdout:\n{stdout}")
    print(f"Script stderr:\n{stderr}")

    img_ids = [line.split(" ", maxsplit=1)[0]
               for line in stdout.strip().split("\n")]
    assert sorted(img_ids) == [f"va-{i}" for i in range(data_module.n["va"])]
    assert "Using checkpoint" in stderr