-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loading.py
117 lines (86 loc) · 3.09 KB
/
data_loading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import pylab as plt
import os
import csv
from torch.utils.data import DataLoader, Dataset
import processing
from blob_extraction import find_blobs
class ImageDataset(Dataset):
    '''
    Subclass of pytorch's Dataset that provides extracted characters to a model.

    Attributes:
        path (string): directory containing the data
        raws (list): images as returned by processing.load_img, one per jpg file
        truth (list of strings): ground truth strings of the images
        size (integer): size passed to processing.rescale for every character
        blobs (list): rescaled character images extracted from the raw images
        blob_truth (list of strings): ground truth character for each entry in blobs
    '''

    def __init__(self, path='data', size=28):
        '''
        Initializes an ImageDataset instance.

        Loads all images and ground truth strings from ``path`` and
        immediately runs blob extraction on them.

        Args:
            path (string): path of dataset. Must contain jpg images and a file named truth.txt
            size (int): size of character images handed to the model

        Raises:
            IOError: if ``path`` contains no file named truth.txt
        '''
        # images
        self.path = path
        self.raws = self.load_raws()
        self.truth = self.load_truth()

        # characters
        self.size = size
        self.blobs = []
        self.blob_truth = []
        self.extract_blobs()

    def load_raws(self):
        '''
        Loads the raw images from the data directory.

        Every file whose name ends with ".jpg" is loaded through
        processing.load_img. Files are visited in sorted filename order.

        Returns:
            list of loaded images
        '''
        raws = []
        # os.listdir returns entries in arbitrary order; sort so that the
        # positional pairing with the lines of truth.txt is reproducible.
        # NOTE(review): assumes truth.txt lines follow sorted filename order - confirm.
        for file_name in sorted(os.listdir(self.path)):
            if file_name.endswith('.jpg'):
                file_path = os.path.join(self.path, file_name)
                raws.append(processing.load_img(file_path))
        return raws

    def load_truth(self):
        '''
        Loads ground truth strings from the "truth.txt" file.

        The file is parsed as csv and the first field of every non-empty
        row is used as ground truth string. Blank lines are skipped.

        Returns:
            a list of strings

        Raises:
            IOError: if "truth.txt" does not exist in the data directory
        '''
        truth_path = os.path.join(self.path, 'truth.txt')
        if not os.path.exists(truth_path):
            raise IOError('There is no file "truth.txt" in {}'.format(self.path))

        # newline='' lets the csv module do its own newline handling.
        with open(truth_path, newline='') as truth_file:
            # csv.reader yields [] for blank lines; indexing those would
            # raise IndexError, so they are filtered out.
            return [row[0] for row in csv.reader(truth_file) if row]

    def extract_blobs(self):
        '''
        Performs blob extraction on the raw images.

        Images and ground truth strings are paired by position. An image is
        only used if blob extraction finds exactly as many characters as its
        ground truth string has; otherwise it is skipped entirely. Results
        are appended to ``self.blobs`` and ``self.blob_truth``.
        '''
        for truth, img in zip(self.truth, self.raws):
            chars = find_blobs(img)
            if len(chars) != len(truth):
                # Segmentation disagrees with the label; characters cannot
                # be matched to letters, so drop this image.
                continue
            for j, label in enumerate(truth):
                self.blobs.append(processing.rescale(chars[j], self.size))
                self.blob_truth.append(label)

    def __len__(self):
        ''' overrides function of Dataset class: number of extracted characters '''
        return len(self.blobs)

    def __getitem__(self, idx):
        ''' overrides function of Dataset class: (character image, label) at idx '''
        return self.blobs[idx], self.blob_truth[idx]
if __name__ == "__main__":
    # Visual sanity check: load the default dataset and display up to the
    # first ten extracted characters, each titled with its ground truth label.
    imgs = ImageDataset()
    # Clamp to the dataset length so fewer than ten extracted characters
    # does not end in an IndexError.
    for i in range(min(10, len(imgs))):
        char, label = imgs[i]
        plt.imshow(char)
        plt.title(label)
        plt.show()