# NOTE: consolidated imports for self-containedness. Project-local names used
# below (DummyMNISTLines, DummyModel, DummyTrainer, DummyEvaluator,
# SymbolsTable, ModelSaver, Netout, CommonArgs, CreateCRNNArgs, model,
# prepare_model, script, call_script and the __Test* callback doubles) are
# assumed to be importable from the surrounding test package.
import os
from unittest.mock import MagicMock, patch

import numpy as np
import torch
from pytorch_lightning import seed_everything


def setup(train_path, fixed_input_height=0):
    seed = 31102020
    seed_everything(seed)
    n = 10**4
    data_module = DummyMNISTLines(tr_n=n, va_n=int(0.1 * n), samples_per_space=5)
    print("Generating data...")
    data_module.prepare_data()
    # prepare syms file
    syms = str(train_path / "syms")
    syms_table = SymbolsTable()
    for k, v in data_module.syms.items():
        syms_table.add(v, k)
    syms_table.save(syms)
    # create and save the model inside train_path
    model(
        syms,
        adaptive_pooling="avgpool-3",
        fixed_input_height=fixed_input_height,
        save_model=True,
        common=CommonArgs(train_path=train_path),
        crnn=CreateCRNNArgs(
            cnn_num_features=[16, 32, 48, 64],
            # data is random, so use a minimal RNN layer
            # because there are no long-term dependencies
            rnn_units=32,
            rnn_layers=1,
            rnn_dropout=0,
        ),
    )
    return seed, data_module, syms


def test_netout_on_dummy_mnist_lines_data(tmpdir, nprocs):
    seed_everything(0x12345)
    # prepare data
    data_module = DummyMNISTLines(tr_n=0, va_n=5, batch_size=3, samples_per_space=3)
    data_module.prepare_data()
    # prepare model file
    final_size, classes = 3, 12  # 12 == 10 digits + space + ctc
    model_args = [(final_size,) * 2, classes]
    ModelSaver(tmpdir).save(DummyModel, *model_args)
    # prepare ckpt file
    ckpt = tmpdir / "model.ckpt"
    torch.save(DummyModel(*model_args).state_dict(), str(ckpt))
    # prepare img list
    img_list = tmpdir / "img_list"
    img_list.write_text(
        "\n".join(f"va-{i}" for i in range(data_module.n["va"])), "utf-8"
    )
    args = [
        img_list,
        f"--img_dirs=[{str(data_module.root / 'va')}]",
        f"--common.train_path={tmpdir}",
        f"--common.checkpoint={ckpt}",
        f"--common.experiment_dirname={tmpdir}",
        f"--data.batch_size={data_module.batch_size}",
        "--netout.output_transform=softmax",
        "--netout.digits=3",
        "--netout.lattice=lattice",
        "--netout.matrix=matrix",
    ]
    if nprocs > 1:
        args.append("--trainer.accelerator=ddp_cpu")
        args.append(f"--trainer.num_processes={nprocs}")
    stdout, stderr = call_script(script.__file__, args)
    print(f"Script stdout:\n{stdout}")
    print(f"Script stderr:\n{stderr}")
    assert "Using checkpoint" in stderr

    lattice = tmpdir / "lattice"
    assert lattice.exists()
    lines = [l.strip() for l in lattice.read_text("utf-8").split("\n") if l]
    n = data_module.n["va"]
    # one id line per image, plus final_size * classes + 1 lattice lines each
    assert len(lines) == n * (final_size * classes + 1) + n

    # this is harder to test so do some basic checks
    matrix = tmpdir / "matrix"
    assert matrix.exists()
    assert len(matrix.read_binary()) > 0


def prepare_data(dir, image_sequencer="avgpool-8"):
    seed_everything(0x12345)
    data_module = DummyMNISTLines(samples_per_space=5)
    data_module.prepare_data()
    prepare_model(dir, image_sequencer)
    # prepare syms file
    syms = dir / "syms"
    syms_table = SymbolsTable()
    for k, v in data_module.syms.items():
        syms_table.add(v, k)
    syms_table.save(syms)
    # prepare img dirs
    img_dirs = [str(data_module.root / p) for p in ("tr", "va")]
    return syms, img_dirs, data_module


def test_concatenate():
    h, w = 10, 15
    dataset = [
        (np.full((h, w), 11), "a"),
        (np.full((h, w), 12), "b"),
        (np.full((h, w), 13), "foo"),
        (np.full((h, w), 14), "bar"),
    ]
    indices = [0, 3, 1, "sp", 2]
    img, txt, mask = DummyMNISTLines.concatenate(
        dataset, h, w, indices, space_sym="test"
    )
    # check image values
    assert np.all(img[:, :w] == 11)
    assert np.all(img[:, w : w * 2] == 14)
    assert np.all(img[:, w * 2 : w * 3] == 12)
    assert np.all(img[:, w * 3 : w * 4] == 0)
    assert np.all(img[:, w * 4 :] == 13)
    # check label text
    assert txt == "a bar b test foo"
    # check mask values
    assert np.all(mask[:, :w] == 0)
    assert np.all(mask[:, w : w * 2] == 0)
    assert np.all(mask[:, w * 2 : w * 3] == 0)
    assert np.all(mask[:, w * 3 : w * 4] == 1)
    assert np.all(mask[:, w * 4 :] == 0)


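# For reference, a minimal sketch of the behavior test_concatenate exercises.
# This is a hypothetical reimplementation, not the real
# DummyMNISTLines.concatenate: samples are concatenated left to right
# following `indices`, an "sp" entry contributes a zero-valued block of the
# same width, and the mask is 1 exactly over those space blocks.
def concatenate_sketch(dataset, h, w, indices, space_sym="<space>"):
    img = np.zeros((h, w * len(indices)))
    mask = np.zeros_like(img)
    words = []
    for i, idx in enumerate(indices):
        if idx == "sp":
            # blank (zero) block; mark it in the mask
            mask[:, w * i : w * (i + 1)] = 1
            words.append(space_sym)
        else:
            img[:, w * i : w * (i + 1)] = dataset[idx][0]
            words.append(dataset[idx][1])
    return img, " ".join(words), mask

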
def test_get_indices_with_spaces():
    choices1 = [5, 10, 15, 20, 25, 30, 35, 40]
    choices2 = [1, 5, 6, 7]
    randint = patch("numpy.random.randint", return_value=len(choices1))
    choice = patch("numpy.random.choice", side_effect=[choices1, choices2])
    with randint, choice:
        out = DummyMNISTLines.get_indices(10, 0, samples_per_space=3)
    assert out == [5, "sp", 10, 15, 20, 25, "sp", 30, "sp", 35, "sp", 40]


def test_decode_on_dummy_mnist_lines_data(tmpdir, nprocs):
    # prepare data
    seed_everything(0x12345)
    data_module = DummyMNISTLines(tr_n=0, va_n=5, batch_size=3, samples_per_space=3)
    data_module.prepare_data()
    # prepare model file
    model_args = [(3, 3), 12]
    ModelSaver(tmpdir).save(DummyModel, *model_args)
    # prepare ckpt file
    ckpt = tmpdir / "model.ckpt"
    torch.save(DummyModel(*model_args).state_dict(), str(ckpt))
    # prepare syms file
    syms = tmpdir / "syms"
    syms_table = SymbolsTable()
    for k, v in data_module.syms.items():
        syms_table.add(v, k)
    syms_table.save(syms)
    # prepare img list
    img_list = tmpdir / "img_list"
    img_list.write_text(
        "\n".join(f"va-{i}" for i in range(data_module.n["va"])), "utf-8"
    )
    args = [
        syms,
        img_list,
        f"--img_dirs={[str(data_module.root / 'va')]}",
        f"--common.checkpoint={ckpt}",
        f"--common.train_path={tmpdir}",
        f"--data.batch_size={data_module.batch_size}",
    ]
    if nprocs > 1:
        args.append("--trainer.accelerator=ddp_cpu")
        args.append(f"--trainer.num_processes={nprocs}")
    stdout, stderr = call_script(script.__file__, args)
    print(f"Script stdout:\n{stdout}")
    print(f"Script stderr:\n{stderr}")
    img_ids = [l.split(" ", maxsplit=1)[0] for l in stdout.strip().split("\n")]
    assert sorted(img_ids) == [f"va-{i}" for i in range(data_module.n["va"])]
    assert "Using checkpoint" in stderr


def test_netout_callback(tmpdir, num_processes):
    data_module = DummyMNISTLines(batch_size=2, va_n=12)
    trainer = DummyTrainer(
        default_root_dir=tmpdir,
        limit_test_batches=3,
        callbacks=[Netout(writers=[__TestWriter()])],
        accelerator="ddp_cpu" if num_processes > 1 else None,
        num_processes=num_processes,
    )
    trainer.test(DummyEvaluator(), datamodule=data_module)


def test_segmentation_callback(tmpdir, num_processes, kwargs, img_id, segm):
    data_module = DummyMNISTLines(batch_size=2, va_n=12)
    trainer = DummyTrainer(
        default_root_dir=tmpdir,
        limit_test_batches=3,
        callbacks=[__TestSegmentation(img_id, segm, data_module.syms, **kwargs)],
        accelerator="ddp_cpu" if num_processes > 1 else None,
        num_processes=num_processes,
    )
    trainer.test(DummyEvaluator(), datamodule=data_module)


def test_decode(tmpdir, num_processes, kwargs, img_id, hyp):
    module = DummyEvaluator()
    data_module = DummyMNISTLines(batch_size=2, va_n=12, samples_per_space=10)
    decode_callback = __TestDecode(img_id, hyp, syms=data_module.syms, **kwargs)
    trainer = DummyTrainer(
        default_root_dir=tmpdir,
        limit_test_batches=3,
        callbacks=[decode_callback],
        accelerator="ddp_cpu" if num_processes > 1 else None,
        num_processes=num_processes,
    )
    trainer.test(module, datamodule=data_module)


def test_prepare_data(tmpdir):
    data_module = DummyMNISTLines(max_length=5, tr_n=5, va_n=3)
    indices = [5, "sp", "sp", 10, 25, "sp", 30]
    expected_labels = {
        "tr": "2 <space> <space> 3 2 <space> 3",
        "va": "1 <space> <space> 0 0 <space> 3",
    }
    data_module.get_indices = MagicMock(return_value=indices)
    data_module.prepare_data()
    for partition in ("tr", "va"):
        # check generated images
        image_ids = [f"{partition}-{i}" for i in range(data_module.n[partition])]
        assert set(os.listdir(data_module.root / partition)) == {
            img + ".jpg" for img in image_ids
        }
        # check generated ground-truth file
        gt_file = data_module.root / f"{partition}.gt"
        assert gt_file.exists()
        lines = [l.strip() for l in gt_file.read_text().split("\n") if l]
        for i, l in enumerate(lines):
            img_id, labels = l.split(maxsplit=1)
            assert img_id == image_ids[i]
            assert labels == expected_labels[partition]
        # check generated indices file
        indices_file = data_module.root / f"{partition}.indices"
        assert indices_file.exists()
        lines = [l.strip() for l in indices_file.read_text().split("\n") if l]
        for i, l in enumerate(lines):
            img_id, indices_str = l.split(maxsplit=1)
            assert img_id == image_ids[i]
            assert indices_str == repr([str(x) for x in indices])


def test_get_indices_without_spaces():
    expected = [5, 10, 15, 20, 25, 30, 35, 40]
    randint = patch("numpy.random.randint", return_value=len(expected))
    choice = patch("numpy.random.choice", return_value=expected)
    with randint, choice:
        assert DummyMNISTLines.get_indices(10, 0) == expected
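

# For reference, a minimal sketch of the contract the two get_indices tests
# mock out. This is a hypothetical reconstruction: the parameter names and
# the space-count rule are assumptions, only the call pattern is fixed by the
# mocks. np.random.randint draws the line length, one np.random.choice draws
# the sample indices, and, when samples_per_space > 0, a second
# np.random.choice draws the digit positions *before* which an "sp" marker
# is inserted (the mocks ignore the requested size, so the exact space count
# is not constrained by the tests).
def get_indices_sketch(max_length, min_length, samples_per_space=0):
    length = np.random.randint(min_length, max_length)
    # 60000 is a hypothetical MNIST pool size
    digits = list(np.random.choice(60000, size=length))
    if samples_per_space <= 0:
        return digits
    # assumption: roughly one space per `samples_per_space` digits
    n_spaces = max(1, length // samples_per_space)
    space_positions = set(np.random.choice(range(1, length), size=n_spaces))
    out = []
    for i, d in enumerate(digits):
        if i in space_positions:
            out.append("sp")
        out.append(d)
    return out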