def setup(train_path, fixed_input_height=0):
    """Build a dummy MNIST-lines fixture: seed RNG, generate data, write syms, create a model.

    Returns a tuple ``(seed, data_module, syms_path)``.
    """
    rng_seed = 31102020
    seed_everything(rng_seed)
    n_train = 10 ** 4
    dm = DummyMNISTLines(tr_n=n_train, va_n=int(0.1 * n_train), samples_per_space=5)
    print("Generating data...")
    dm.prepare_data()
    # write the symbols table next to the training artifacts
    syms_path = str(train_path / "syms")
    table = SymbolsTable()
    for idx, sym in dm.syms.items():
        table.add(sym, idx)
    table.save(syms_path)
    model(
        syms_path,
        adaptive_pooling="avgpool-3",
        fixed_input_height=fixed_input_height,
        save_model=True,
        common=CommonArgs(train_path=train_path),
        crnn=CreateCRNNArgs(
            cnn_num_features=[16, 32, 48, 64],
            # data is random so minimal RNN layer
            # because there are no long term dependencies
            rnn_units=32,
            rnn_layers=1,
            rnn_dropout=0,
        ),
    )
    return rng_seed, dm, syms_path
def test_call_with_symbols_table(caplog):
    """ToTensor maps known symbols to their ids; unknown symbols become None and are logged."""
    table = SymbolsTable()
    for sym, idx in {"a": 0, "b": 1, "<space>": 2, "<": 3}.items():
        table.add(sym, idx)
    transform = ToTensor(table)
    result = transform("a < b <space> a ö")
    assert result == [0, 3, 1, 2, 0, None]
    # exactly one warning should have been emitted for the unknown symbol
    missing_msg = 'Could not find "ö" in the symbols table'
    assert caplog.messages.count(missing_msg) == 1
def prepare_data(dir, image_sequencer="avgpool-8"):
    """Generate dummy data, a model, and a syms file under *dir*.

    Returns ``(syms, img_dirs, data_module)``.
    """
    seed_everything(0x12345)
    dm = DummyMNISTLines(samples_per_space=5)
    dm.prepare_data()
    prepare_model(dir, image_sequencer)
    # prepare syms file
    syms = dir / "syms"
    table = SymbolsTable()
    for idx, sym in dm.syms.items():
        table.add(sym, idx)
    table.save(syms)
    # prepare img dirs (train + validation splits)
    img_dirs = [str(dm.root / split) for split in ("tr", "va")]
    return syms, img_dirs, dm
def test_decode_on_dummy_mnist_lines_data(tmpdir, nprocs):
    """End-to-end run of the decode script over dummy MNIST-lines validation data."""
    # prepare data
    seed_everything(0x12345)
    dm = DummyMNISTLines(tr_n=0, va_n=5, batch_size=3, samples_per_space=3)
    dm.prepare_data()
    # prepare model file
    model_args = [(3, 3), 12]
    ModelSaver(tmpdir).save(DummyModel, *model_args)
    # prepare ckpt file
    ckpt = tmpdir / "model.ckpt"
    torch.save(DummyModel(*model_args).state_dict(), str(ckpt))
    # prepare syms file
    syms = tmpdir / "syms"
    table = SymbolsTable()
    for idx, sym in dm.syms.items():
        table.add(sym, idx)
    table.save(syms)
    # prepare img list
    img_list = tmpdir / "img_list"
    img_list.write_text("\n".join(f"va-{i}" for i in range(dm.n["va"])), "utf-8")
    args = [
        syms,
        img_list,
        f"--img_dirs={[str(dm.root / 'va')]}",
        f"--common.checkpoint={ckpt}",
        f"--common.train_path={tmpdir}",
        f"--data.batch_size={dm.batch_size}",
    ]
    if nprocs > 1:
        args.append("--trainer.accelerator=ddp_cpu")
        args.append(f"--trainer.num_processes={nprocs}")
    stdout, stderr = call_script(script.__file__, args)
    print(f"Script stdout:\n{stdout}")
    print(f"Script stderr:\n{stderr}")
    # every validation image id should appear exactly once in the script output
    img_ids = [line.split(" ", maxsplit=1)[0] for line in stdout.strip().split("\n")]
    assert sorted(img_ids) == [f"va-{i}" for i in range(dm.n["va"])]
    assert "Using checkpoint" in stderr