# NOTE(review): removed copy-paste residue (page header and line-number
# gutter) that preceded the module and made the file syntactically invalid.
# -*- coding: utf-8 -*-
"""
Read Images and make pickle for Tensorflow
Created on Tue Mar 8 09:17:02 2016
@author: Aneesh
"""
#%% import and define base variables
import glob
import os
import numpy as np
import skimage
from skimage.io import imread_collection
from six.moves import cPickle as pickle
root_root = "/home/nyx/Desktop/Caffe_test/mult_fold_imgs/ZProj"
img_size = 400
#%% Read Dataset
def ab_load_im_fold(fold_name, size=None):
    """Load every ``*.tif`` image in *fold_name* into one float32 tensor.

    Parameters
    ----------
    fold_name : str
        Directory containing the .tif images.
    size : int, optional
        Square image side length. Defaults to the module-level
        ``img_size`` constant, preserving the original behavior.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_images, size, size)``, dtype float32.
    """
    if size is None:
        size = img_size  # fall back to the module-wide default
    file_list = glob.glob(os.path.join(fold_name, '*.tif'))
    im_coll = imread_collection(file_list)
    dataset = np.ndarray(shape=(len(im_coll), size, size),
                         dtype=np.float32)
    for im_idx, im in enumerate(im_coll):
        # img_as_float rescales integer pixel data to floats in [0, 1]
        dataset[im_idx, :, :] = skimage.img_as_float(im)
    print('Full dataset tensor:', dataset.shape)
    print('Mean:', np.mean(dataset))
    print('Standard deviation:', np.std(dataset))
    return dataset
def an_pickle(root_root, force=False):
    """Pickle each image subfolder of *root_root* into ``<folder>.pickle``.

    Parameters
    ----------
    root_root : str
        Directory whose immediate subdirectories each hold one class's
        images.
    force : bool
        When True, re-pickle even when the pickle file already exists.

    Returns
    -------
    list of str
        The pickle file names (one per subfolder), in ``os.listdir`` order.
    """
    dataset_names = []
    for fold_name in os.listdir(root_root):
        fold_path = os.path.join(root_root, fold_name)
        # Bug fix: skip non-directories — on a re-run, os.listdir also
        # returns the previously written .pickle files, which the original
        # code then tried to load as image folders.
        if not os.path.isdir(fold_path):
            continue
        set_filename = fold_name + '.pickle'
        write_path = os.path.join(root_root, set_filename)
        dataset_names.append(set_filename)
        # Bug fix: test the full write path; the original checked the bare
        # file name, i.e. the current working directory, so the skip
        # branch effectively never fired.
        if os.path.exists(write_path) and not force:
            # You may override by setting force=True.
            print('%s already present - Skipping pickling.' % set_filename)
        else:
            print('Pickling %s.' % set_filename)
            dataset = ab_load_im_fold(fold_path)
            try:
                with open(write_path, 'wb') as f:
                    pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
            except Exception as e:
                print('Unable to save data to', set_filename, ':', e)
    return dataset_names
#
# Build one <class>.pickle per subfolder of root_root (existing pickles
# are skipped unless force=True is passed).
dataset_names = an_pickle(root_root)
#%%
def make_arrays(nb_rows, img_size):
    """Allocate an uninitialised (nb_rows, img_size, img_size) float32
    dataset plus a matching int32 label vector.

    Returns (None, None) when nb_rows is zero/falsy.
    """
    if not nb_rows:
        return None, None
    data = np.ndarray((nb_rows, img_size, img_size), dtype=np.float32)
    lbls = np.ndarray(nb_rows, dtype=np.int32)
    return data, lbls
#%%
def merge_datasets(pickle_files, train_size, valid_size=0):
    """Merge per-class pickle files into training and validation splits.

    Parameters
    ----------
    pickle_files : sequence of str
        Per-class pickle file names; each is resolved against the
        module-level ``root_root`` directory.
    train_size, valid_size : int
        Total rows wanted in each split, divided evenly across classes
        (integer division; remainders are dropped).

    Returns
    -------
    tuple
        ``(valid_dataset, valid_labels, train_dataset, train_labels)``;
        the valid pair is ``(None, None)`` when ``valid_size`` is 0.
    """
    # Bug fix: the class count and the iteration must come from the
    # pickle_files argument; the original read the module-global
    # dataset_names and silently ignored its own parameter.
    num_classes = len(pickle_files)
    valid_dataset, valid_labels = make_arrays(valid_size, img_size)
    train_dataset, train_labels = make_arrays(train_size, img_size)
    vsize_per_class = valid_size // num_classes
    tsize_per_class = train_size // num_classes
    start_v, start_t = 0, 0
    end_v, end_t = vsize_per_class, tsize_per_class
    # Each class contributes rows [0:v) to validation and [v:v+t) to training.
    end_l = vsize_per_class + tsize_per_class
    for label, im_set in enumerate(pickle_files):
        try:
            # NOTE(review): paths are resolved against the module-global
            # root_root — confirm callers always pass bare file names.
            with open(os.path.join(root_root, im_set), "rb") as f:
                img_set = pickle.load(f)
            # Shuffle within the class so the split slices are random.
            np.random.shuffle(img_set)
            if valid_dataset is not None:
                valid_letter = img_set[:vsize_per_class, :, :]
                valid_dataset[start_v:end_v, :, :] = valid_letter
                valid_labels[start_v:end_v] = label
                start_v += vsize_per_class
                end_v += vsize_per_class
                print(valid_dataset.shape)
            train_letter = img_set[vsize_per_class:end_l, :, :]
            train_dataset[start_t:end_t, :, :] = train_letter
            train_labels[start_t:end_t] = label
            start_t += tsize_per_class
            end_t += tsize_per_class
        except Exception as e:
            print('Unable to process data from', im_set, ':', e)
            raise
    return valid_dataset, valid_labels, train_dataset, train_labels
#%%
# Total rows per split; merge_datasets divides these evenly across classes.
train_size = 700
valid_size = 27
valid_dataset, valid_labels, train_dataset, train_labels = merge_datasets(
    dataset_names, train_size, valid_size)
print('Training:', train_dataset.shape, train_labels.shape)
print('Validation:', valid_dataset.shape, valid_labels.shape)
#%%
def randomize(dataset, labels):
    """Shuffle *dataset* and *labels* together with a single random
    permutation, keeping each row paired with its label."""
    order = np.random.permutation(labels.shape[0])
    return dataset[order, :, :], labels[order]
# Shuffle each split so samples of one class are no longer contiguous.
train_dataset, train_labels = randomize(train_dataset, train_labels)
valid_dataset, valid_labels = randomize(valid_dataset, valid_labels)
#%%
# Persist all four arrays in a single pickle for the training script.
pickle_file = os.path.join(root_root, 'all_cells_set.pickle')
try:
    # Bug fix: 'with' guarantees the file handle is closed even when
    # pickling fails; the original leaked the handle on exception.
    with open(pickle_file, 'wb') as f:
        save = {
            'train_dataset': train_dataset,
            'train_labels': train_labels,
            'valid_dataset': valid_dataset,
            'valid_labels': valid_labels,
        }
        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
except Exception as e:
    print('Unable to save data to', pickle_file, ':', e)
    raise
#%%
# Notebook-style cell: echo the validation labels for a quick sanity check
# (has an effect only when run interactively).
valid_labels
#%% Scratch
# Count the total number of images across all per-class pickles.
d = 0
for im_set in dataset_names:
    # Bug fix: close each pickle file; pickle.load(open(...)) leaked
    # the file handle on every iteration.
    with open(os.path.join(root_root, im_set), "rb") as f:
        img_set = pickle.load(f)
    d += len(img_set)
    print(len(img_set))
d