Exemplo n.º 1
0
def build_model(config):
    """Build a VietOCR model and its vocabulary from a config mapping.

    Reads ``config['vocab']`` (passed to ``Vocab``), ``config['device']``
    (target of ``.to``), and the ``'backbone'``, ``'cnn'``, ``'transformer'``
    and ``'seq_modeling'`` entries, which are forwarded positionally to
    ``VietOCR`` after the vocabulary size.

    Returns:
        tuple: ``(model, vocab)`` where ``model`` is the result of
        ``VietOCR(...).to(config['device'])`` and ``vocab`` is the ``Vocab``.
    """
    vocab = Vocab(config['vocab'])
    target_device = config['device']

    net = VietOCR(
        len(vocab),
        config['backbone'],
        config['cnn'],
        config['transformer'],
        config['seq_modeling'],
    )
    return net.to(target_device), vocab
Exemplo n.º 2
0
def test_loader():
    """Smoke-test ``DataGen`` on the bundled sample set with a fixed charset.

    Builds a ``Vocab`` from a hard-coded Vietnamese/ASCII character string,
    iterates ``DataGen(...).gen(30)`` on CPU, asserts each image batch has
    size 3 on axis 1 (channels) and 32 on axis 2 (height), and prints the
    shapes of the image and target tensors of every batch.
    """
    chars = 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '

    vocab = Vocab(chars)
    data_gen = DataGen('./vietocr/tests/', 'sample.txt', vocab, 'cpu')

    batches = data_gen.gen(30)
    for batch in batches:
        images = batch['img']
        assert images.shape[1] == 3, 'image must have 3 channels'
        assert images.shape[2] == 32, 'the height must be 32'
        print(
            images.shape,
            batch['tgt_input'].shape,
            batch['tgt_output'].shape,
            batch['tgt_padding_mask'].shape,
        )
Exemplo n.º 3
0
def build_model(config):
    """Build a VietOCR model and its vocabulary from a config mapping.

    ``config['vocab']`` is passed to ``Vocab``; the model is constructed with
    the vocabulary size, ``ss``/``ks`` taken from
    ``config['cnn']['pooling_stride_size']`` and
    ``config['cnn']['pooling_kernel_size']``, plus every key of
    ``config['transformer']`` as keyword arguments. The model is then moved to
    ``config['device']``.

    Returns:
        tuple: ``(model, vocab)``.
    """
    vocab = Vocab(config['vocab'])
    target_device = config['device']

    num_classes = len(vocab)
    cnn_cfg = config['cnn']
    net = VietOCR(
        num_classes,
        ss=cnn_cfg['pooling_stride_size'],
        ks=cnn_cfg['pooling_kernel_size'],
        **config['transformer']
    )

    return net.to(target_device), vocab
Exemplo n.º 4
0
def test_loader():
    """Smoke-test ``DataGen`` on the bundled sample set with a file-based charset.

    Loads one character entry per line from ``table_ocr/dict.txt``, adds
    U+2028, builds a ``Vocab`` from the resulting set, iterates
    ``DataGen(...).gen(30)``, asserts each image batch has size 3 on axis 1
    (channels) and 32 on axis 2 (height), and prints the batch tensor shapes.

    Runs on ``cuda:0`` when CUDA is available, otherwise on CPU.
    """
    import torch  # local import: only used to choose the device below

    # Decode explicitly as UTF-8: the dictionary may hold non-ASCII characters,
    # and relying on the platform default encoding can fail or mis-decode them.
    with open("table_ocr/dict.txt", "r", encoding="utf-8") as f:
        # Only the newline is stripped, so whitespace entries (e.g. ' ') survive.
        character = {line.strip('\n') for line in f}
    # add(), not update(): update() iterates its argument character by character.
    character.add('\u2028')

    # Previously hard-coded to 'cuda:0', which made the test error out on
    # CPU-only machines.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    vocab = Vocab(chars=character)
    # 32, 512: presumably image height and max width — TODO confirm against DataGen.
    s_gen = DataGen('./vietocr/tests/', 'sample.txt', vocab, device, 32, 512)

    iterator = s_gen.gen(30)
    for batch in iterator:
        assert batch['img'].shape[1]==3, 'image must have 3 channels'
        assert batch['img'].shape[2]==32, 'the height must be 32'
        print(batch['img'].shape, batch['tgt_input'].shape, batch['tgt_output'].shape, batch['tgt_padding_mask'].shape)