예제 #1
0
def setup(train_path, fixed_input_height=0):
    """Seed everything, generate dummy MNIST-lines data and create a CRNN model.

    Args:
        train_path: Directory (a path object supporting ``/``) where the
            symbols table is saved; also passed to ``CommonArgs``.
        fixed_input_height: Forwarded unchanged to ``model``.

    Returns:
        A ``(seed, data_module, syms)`` tuple, where ``syms`` is the string
        path of the saved symbols table.
    """
    seed = 31102020
    seed_everything(seed)

    num_train = 10**4
    data_module = DummyMNISTLines(
        tr_n=num_train,
        va_n=int(0.1 * num_train),
        samples_per_space=5,
    )
    print("Generating data...")
    data_module.prepare_data()

    # Build the symbols table from the data module's mapping: each entry
    # (key -> value) is registered as add(value, key).
    syms = str(train_path / "syms")
    table = SymbolsTable()
    for sym_key, sym_value in data_module.syms.items():
        table.add(sym_value, sym_key)
    table.save(syms)

    common_args = CommonArgs(train_path=train_path)
    crnn_args = CreateCRNNArgs(
        cnn_num_features=[16, 32, 48, 64],
        # The data is random, so there are no long-term dependencies to
        # model and a single small RNN layer is enough.
        rnn_units=32,
        rnn_layers=1,
        rnn_dropout=0,
    )
    model(
        syms,
        adaptive_pooling="avgpool-3",
        fixed_input_height=fixed_input_height,
        save_model=True,
        common=common_args,
        crnn=crnn_args,
    )

    return seed, data_module, syms
예제 #2
0
def test_netout_on_dummy_mnist_lines_data(tmpdir, nprocs):
    """Run the netout script end to end on dummy MNIST-lines validation data.

    A ``DummyModel`` definition and a checkpoint of a fresh instance's weights
    are written into ``tmpdir``. The script is run over the validation images,
    using DDP on CPU when ``nprocs > 1``. The lattice and matrix files it
    writes are then checked.
    """
    seed_everything(0x12345)

    # Validation-only data: 5 images served in batches of 3.
    data_module = DummyMNISTLines(tr_n=0, va_n=5, batch_size=3, samples_per_space=3)
    data_module.prepare_data()
    num_images = data_module.n["va"]

    # Saved model definition plus a checkpoint of its initial weights.
    final_size, classes = 3, 12  # 12 == 10 digits + space + ctc
    model_args = [(final_size,) * 2, classes]
    ModelSaver(tmpdir).save(DummyModel, *model_args)
    ckpt = tmpdir / "model.ckpt"
    torch.save(DummyModel(*model_args).state_dict(), str(ckpt))

    # Image list: one validation image id per line.
    img_list = tmpdir / "img_list"
    ids_text = "\n".join(f"va-{i}" for i in range(num_images))
    img_list.write_text(ids_text, "utf-8")

    va_dir = str(data_module.root / "va")
    args = [
        img_list,
        f"--img_dirs=[{va_dir}]",
        f"--common.train_path={tmpdir}",
        f"--common.checkpoint={ckpt}",
        f"--common.experiment_dirname={tmpdir}",
        f"--data.batch_size={data_module.batch_size}",
        "--netout.output_transform=softmax",
        "--netout.digits=3",
        "--netout.lattice=lattice",
        "--netout.matrix=matrix",
    ]
    if nprocs > 1:
        args.extend(
            ["--trainer.accelerator=ddp_cpu", f"--trainer.num_processes={nprocs}"]
        )

    stdout, stderr = call_script(script.__file__, args)
    print(f"Script stdout:\n{stdout}")
    print(f"Script stderr:\n{stderr}")

    assert "Using checkpoint" in stderr

    lattice = tmpdir / "lattice"
    assert lattice.exists()
    lattice_lines = [
        line.strip() for line in lattice.read_text("utf-8").split("\n") if line
    ]
    # NOTE(review): presumably one line per (frame, class) arc plus two extra
    # lines per image -- confirm against the lattice writer.
    lines_per_image = final_size * classes + 2
    assert len(lattice_lines) == num_images * lines_per_image

    # The matrix output is harder to test, so only check something was written.
    matrix = tmpdir / "matrix"
    assert matrix.exists()
    assert len(matrix.read_binary()) > 0
예제 #3
0
def prepare_data(dir, image_sequencer="avgpool-8"):
    """Generate dummy MNIST-lines data, a model and a symbols table under ``dir``.

    Args:
        dir: Output directory (a path object supporting ``/``).
        image_sequencer: Forwarded to ``prepare_model``.

    Returns:
        ``(syms, img_dirs, data_module)``: the symbols-table path, the string
        paths of the ``tr`` and ``va`` image directories, and the data module.
    """
    seed_everything(0x12345)
    data_module = DummyMNISTLines(samples_per_space=5)
    data_module.prepare_data()
    prepare_model(dir, image_sequencer)

    # Symbols table: each (key -> value) entry is registered as add(value, key).
    syms = dir / "syms"
    table = SymbolsTable()
    for sym_key, sym_value in data_module.syms.items():
        table.add(sym_value, sym_key)
    table.save(syms)

    # Training and validation image directories, as strings.
    root = data_module.root
    img_dirs = [str(root / "tr"), str(root / "va")]
    return syms, img_dirs, data_module
def test_concatenate():
    """Check the image, text and mask produced by ``DummyMNISTLines.concatenate``.

    Each entry of ``indices`` occupies one ``w``-wide column slot. Dataset
    indices copy that sample's image. The ``"sp"`` slot is blank (0) in the
    image, set (1) in the mask, and rendered as ``space_sym`` in the text.
    """
    h, w = 10, 15
    dataset = [
        (np.full((h, w), fill), label)
        for fill, label in ((11, "a"), (12, "b"), (13, "foo"), (14, "bar"))
    ]
    indices = [0, 3, 1, "sp", 2]
    img, txt, mask = DummyMNISTLines.concatenate(
        dataset, h, w, indices, space_sym="test"
    )

    # Slot end columns; the final slot is open-ended (runs to the last column).
    slot_ends = [w * k for k in range(1, 5)] + [None]
    expected_pixels = [11, 14, 12, 0, 13]
    expected_mask = [0, 0, 0, 1, 0]

    # check image values
    start = 0
    for end, pixel in zip(slot_ends, expected_pixels):
        assert np.all(img[:, start:end] == pixel)
        start = end

    # check label text
    assert txt == "a bar b test foo"

    # check mask values
    start = 0
    for end, flag in zip(slot_ends, expected_mask):
        assert np.all(mask[:, start:end] == flag)
        start = end
def test_get_indices_with_spaces():
    """With spaces enabled, ``get_indices`` interleaves ``"sp"`` markers.

    ``numpy.random.choice`` is patched to return the digit indices first and
    then the space positions. The expected output places a ``"sp"`` before
    the digits at positions 1, 5, 6 and 7.
    """
    digits = [5, 10, 15, 20, 25, 30, 35, 40]
    space_positions = [1, 5, 6, 7]
    patched_randint = patch("numpy.random.randint", return_value=len(digits))
    patched_choice = patch(
        "numpy.random.choice", side_effect=[digits, space_positions]
    )
    with patched_randint, patched_choice:
        result = DummyMNISTLines.get_indices(10, 0, samples_per_space=3)
    assert result == [5, "sp", 10, 15, 20, 25, "sp", 30, "sp", 35, "sp", 40]
예제 #6
0
def test_decode_on_dummy_mnist_lines_data(tmpdir, nprocs):
    """Run the decode script end to end and check every image id is decoded.

    A ``DummyModel`` definition, a checkpoint of its initial weights, a
    symbols table and an image list are written into ``tmpdir``. The script
    is then run, using DDP on CPU when ``nprocs > 1``.
    """
    seed_everything(0x12345)

    # Validation-only data: 5 images served in batches of 3.
    data_module = DummyMNISTLines(
        tr_n=0, va_n=5, batch_size=3, samples_per_space=3
    )
    data_module.prepare_data()
    num_images = data_module.n["va"]

    # Saved model definition plus a checkpoint of its initial weights.
    model_args = [(3, 3), 12]
    ModelSaver(tmpdir).save(DummyModel, *model_args)
    ckpt = tmpdir / "model.ckpt"
    torch.save(DummyModel(*model_args).state_dict(), str(ckpt))

    # Symbols table: each (key -> value) entry is registered as add(value, key).
    syms = tmpdir / "syms"
    table = SymbolsTable()
    for sym_key, sym_value in data_module.syms.items():
        table.add(sym_value, sym_key)
    table.save(syms)

    # Image list: one validation image id per line.
    expected_ids = [f"va-{i}" for i in range(num_images)]
    img_list = tmpdir / "img_list"
    img_list.write_text("\n".join(expected_ids), "utf-8")

    img_dirs = [str(data_module.root / "va")]
    args = [
        syms,
        img_list,
        f"--img_dirs={img_dirs}",
        f"--common.checkpoint={ckpt}",
        f"--common.train_path={tmpdir}",
        f"--data.batch_size={data_module.batch_size}",
    ]
    if nprocs > 1:
        args.extend(
            ["--trainer.accelerator=ddp_cpu", f"--trainer.num_processes={nprocs}"]
        )

    stdout, stderr = call_script(script.__file__, args)
    print(f"Script stdout:\n{stdout}")
    print(f"Script stderr:\n{stderr}")

    # The first space-separated token of each output line is the image id.
    decoded_ids = [
        line.split(" ", maxsplit=1)[0] for line in stdout.strip().split("\n")
    ]
    assert sorted(decoded_ids) == expected_ids
    assert "Using checkpoint" in stderr
예제 #7
0
def test_netout_callback(tmpdir, num_processes):
    """Run ``trainer.test`` with a ``Netout`` callback backed by ``__TestWriter``.

    Uses DDP on CPU when ``num_processes > 1``.
    """
    data_module = DummyMNISTLines(batch_size=2, va_n=12)
    netout = Netout(writers=[__TestWriter()])
    accelerator = "ddp_cpu" if num_processes > 1 else None
    trainer = DummyTrainer(
        default_root_dir=tmpdir,
        limit_test_batches=3,
        callbacks=[netout],
        accelerator=accelerator,
        num_processes=num_processes,
    )
    trainer.test(DummyEvaluator(), datamodule=data_module)
예제 #8
0
def test_segmentation_callback(tmpdir, num_processes, kwargs, img_id, segm):
    """Run ``trainer.test`` with the ``__TestSegmentation`` callback.

    The callback receives ``img_id``, ``segm``, the data module's symbols and
    any extra ``kwargs``. Uses DDP on CPU when ``num_processes > 1``.
    """
    data_module = DummyMNISTLines(batch_size=2, va_n=12)
    segmentation = __TestSegmentation(img_id, segm, data_module.syms, **kwargs)
    accelerator = "ddp_cpu" if num_processes > 1 else None
    trainer = DummyTrainer(
        default_root_dir=tmpdir,
        limit_test_batches=3,
        callbacks=[segmentation],
        accelerator=accelerator,
        num_processes=num_processes,
    )
    trainer.test(DummyEvaluator(), datamodule=data_module)
예제 #9
0
def test_decode(tmpdir, num_processes, kwargs, img_id, hyp):
    """Run ``trainer.test`` with the ``__TestDecode`` callback.

    The callback receives ``img_id``, ``hyp``, the data module's symbols and
    any extra ``kwargs``. Uses DDP on CPU when ``num_processes > 1``.
    """
    evaluator = DummyEvaluator()
    data_module = DummyMNISTLines(batch_size=2, va_n=12, samples_per_space=10)
    callbacks = [__TestDecode(img_id, hyp, syms=data_module.syms, **kwargs)]
    accelerator = "ddp_cpu" if num_processes > 1 else None
    trainer = DummyTrainer(
        default_root_dir=tmpdir,
        limit_test_batches=3,
        callbacks=callbacks,
        accelerator=accelerator,
        num_processes=num_processes,
    )
    trainer.test(evaluator, datamodule=data_module)
예제 #10
0
def test_prepare_data(tmpdir):
    """Check the images, ground-truth and indices files written by ``prepare_data``.

    ``get_indices`` is mocked to always return the same indices, so every
    sample of a partition is expected to carry the same labels.
    """
    data_module = DummyMNISTLines(max_length=5, tr_n=5, va_n=3)
    indices = [5, "sp", "sp", 10, 25, "sp", 30]
    expected_labels = {
        "tr": "2 <space> <space> 3 2 <space> 3",
        "va": "1 <space> <space> 0 0 <space> 3",
    }
    data_module.get_indices = MagicMock(return_value=indices)
    data_module.prepare_data()

    for partition in ("tr", "va"):
        # check generated images
        image_ids = [
            f"{partition}-{i}" for i in range(data_module.n[partition])
        ]
        assert set(os.listdir(data_module.root / partition)) == {
            f"{img}.jpg" for img in image_ids
        }

        # check generated ground-truth file
        gt_file = data_module.root / f"{partition}.gt"
        assert gt_file.exists()
        gt_lines = [
            stripped
            for stripped in (raw.strip() for raw in gt_file.read_text().split("\n"))
            if stripped
        ]
        # Without this, an empty or truncated file would pass the loop below.
        assert len(gt_lines) == len(image_ids)
        for expected_id, line in zip(image_ids, gt_lines):
            img_id, labels = line.split(maxsplit=1)
            assert img_id == expected_id
            assert labels == expected_labels[partition]

        # check generated indices file
        indices_file = data_module.root / f"{partition}.indices"
        assert indices_file.exists()
        indices_lines = [
            stripped
            for stripped in (
                raw.strip() for raw in indices_file.read_text().split("\n")
            )
            if stripped
        ]
        # Same guard: one indices line per generated image.
        assert len(indices_lines) == len(image_ids)
        for expected_id, line in zip(image_ids, indices_lines):
            img_id, indices_str = line.split(maxsplit=1)
            assert img_id == expected_id
            assert indices_str == repr([str(x) for x in indices])
예제 #11
0
def test_get_indices_without_spaces():
    """With no spaces requested, ``get_indices`` returns the sampled digits as is.

    ``numpy.random.randint`` and ``numpy.random.choice`` are patched so the
    sampled length and digit indices are fixed.
    """
    digits = [5, 10, 15, 20, 25, 30, 35, 40]
    with patch("numpy.random.randint", return_value=len(digits)), patch(
        "numpy.random.choice", return_value=digits
    ):
        result = DummyMNISTLines.get_indices(10, 0)
        assert result == digits