def test(lmdb_path="/share/zhui/reg_dataset/IIIT5K_3000"):
    """Interactively preview the samples of an LMDB recognition dataset.

    Each sample is loaded with batch size 1, converted back to an 8-bit
    image, shown with PIL, and its size, decoded label and label length are
    printed. The function then blocks on ``input()`` until Enter is pressed
    before moving on to the next sample.

    Args:
        lmdb_path: Root directory of the LMDB dataset to inspect. Defaults to
            the IIIT5K_3000 location that used to be hard-coded here
            ("/share/zhui/reg_dataset/NIPS2014" was the other path previously
            kept around as a commented-out alternative).
    """
    train_dataset = LmdbDataset(
        root=lmdb_path, voc_type='ALLCASES_SYMBOLS', max_len=50)
    batch_size = 1
    train_dataloader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=AlignCollate(imgH=64, imgW=256, keep_ratio=False))

    for images, labels, label_lens in train_dataloader:
        # Move axis 1 to the end so each item can be handed to
        # Image.fromarray (assumes the batch is laid out N x C x H x W --
        # TODO confirm against AlignCollate).
        images = images.permute(0, 2, 3, 1)
        images = to_numpy(images)
        # Map values back to the 0..255 range; presumably the collate
        # function normalised pixels to [-1, 1] -- verify against AlignCollate.
        images = images * 0.5 + 0.5
        images = images * 255

        for image, label, label_len in zip(images, labels, label_lens):
            image = Image.fromarray(np.uint8(image))
            image.show()
            print(image.size)
            print(
                labels2strs(label, train_dataset.id2char,
                            train_dataset.char2id))
            print(label_len.item())
            # Pause so the viewer can be inspected before the next sample.
            input()
def test():
    """Dump or interactively preview the samples of the test LMDB dataset.

    Configuration is read from the module-level ``global_args``:
    ``test_data_dir``, ``voc_type``, ``max_len``, ``num_test``, ``workers``
    and ``image_path``.

    * If ``global_args.image_path`` is set, every sample is saved there as
      ``image-<9-digit index>.jpg`` and an ``index.html`` file is written
      next to the images with one table row per sample (index, image,
      label, label length, image name), followed by the maximal label
      length seen.
    * Otherwise each image is shown with PIL, its size, label, label length
      and (if present) name are printed, and the function blocks on
      ``input()`` until Enter is pressed.

    The HTML file is written as UTF-8 and is always closed, even when
    processing a sample raises. Label and name text is HTML-escaped so that
    symbols such as ``<`` or ``&`` cannot corrupt the table.
    """
    # Local import keeps the block self-contained without touching the
    # file-level import section.
    from html import escape

    train_dataset = LmdbDataset(
        root=global_args.test_data_dir,
        voc_type=global_args.voc_type,
        max_len=global_args.max_len,
        num_samples=global_args.num_test,
        with_name=True)
    print(train_dataset.nSamples, 'samples')

    train_dataloader = data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=global_args.workers,
        collate_fn=AlignCollateWithNames(imgH=64, imgW=256, keep_ratio=False))

    out_html = None
    if global_args.image_path:
        # Explicit UTF-8: image names are decoded from UTF-8 below and may
        # not be representable in the platform's default encoding.
        out_html = open(
            os.path.join(global_args.image_path, 'index.html'),
            'w', encoding='utf-8')

    try:
        if out_html:
            out_html.write('''<html>
<head><meta charset="utf-8"></head>
<body>
<table>
<tr><th>No</th><th>Image</th><th>Labels</th><th>Length</th><th>Name</th></tr>
''')

        i = 1  # 1-based running sample number used in file names and rows
        max_len = 0
        for images, labels, label_lens, image_names in train_dataloader:
            # Move axis 1 to the end so each item can be handed to
            # Image.fromarray (assumes the batch is laid out N x C x H x W
            # -- TODO confirm against AlignCollateWithNames).
            images = images.permute(0, 2, 3, 1)
            images = to_numpy(images)
            # Map values back to the 0..255 range; presumably the collate
            # function normalised pixels to [-1, 1] -- verify.
            images = images * 0.5 + 0.5
            images = images * 255

            for image, label, label_len, image_name in zip(
                    images, labels, label_lens, image_names):
                image = Image.fromarray(np.uint8(image))
                label_str = labels2strs(
                    label, train_dataset.id2char, train_dataset.char2id)
                if image_name is not None:
                    image_name = image_name.decode('utf-8')
                else:
                    image_name = ''
                l_len = label_len.item()
                if max_len < l_len:
                    max_len = l_len

                if out_html:
                    image_filename = f'image-{i:09d}.jpg'
                    image.save(
                        os.path.join(global_args.image_path, image_filename))
                    # str() mirrors what the f-string formatting did before,
                    # whatever type labels2strs returns; escaping keeps
                    # '<', '>' and '&' in labels/names from breaking the
                    # markup.
                    label_html = escape(str(label_str), quote=False)
                    name_html = escape(image_name, quote=False)
                    out_html.write(
                        f'<tr><td>{i}</td>'
                        f'<td><img src="{image_filename}" '
                        f'width="{image.width}" height="{image.height}" />'
                        f'</td>'
                        f'<td>{label_html}</td><td>{l_len}</td>'
                        f'<td>{name_html}</td></tr>\n')
                else:
                    image.show()
                    print(image.size)
                    print(label_str, l_len)
                    if image_name:
                        print(image_name)
                    # Pause so the viewer can be inspected before the next
                    # sample.
                    input()
                i += 1

        if out_html:
            out_html.write(
                f'</table>\n<p>The maximal label length is {max_len}.</p>\n'
                f'</body>\n</html>\n')
    finally:
        # Release the file handle even if a sample failed part-way through.
        if out_html:
            out_html.close()