save.py
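
"""Save a dataset of images to compressed .npz files.

Lists image files under `dataset`, feeds them through the `data` input
pipeline, and writes each batch to `save_dir/<epoch>/<step>.npz` as
`inputs` and `labels` arrays.

Example invocation (illustrative paths and flag values):
    python save.py ./images ./saved --batch-size 16 --dtype 0
"""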
import tensorflow as tf
import numpy as np
import os
from utils import eprint, listdir_files, reset_random, create_session
from data import get_data, data_arguments


class Save:
    def __init__(self, config):
        self.dataset = None
        self.save_dir = None
        self.training = None
        self.testing = None
        self.num_epochs = None
        self.max_steps = None
        self.random_seed = None
        self.batch_size = None
        self.dtype = None
        # copy all the properties from the config object
        self.config = config
        self.__dict__.update(config.__dict__)

    def initialize(self):
        # create the save directory, confirming removal of an existing one
        if os.path.exists(self.save_dir):
            eprint('Confirm removing {}\n[Y/n]'.format(self.save_dir))
            if input() != 'Y':
                import sys
                sys.exit()
            import shutil
            shutil.rmtree(self.save_dir)
            eprint('Removed: ' + self.save_dir)
        os.makedirs(self.save_dir)
        # set deterministic random seed
        if self.random_seed is not None:
            reset_random(self.random_seed)

    def get_dataset(self):
        files = listdir_files(self.dataset, filter_ext=['.jpeg', '.jpg', '.png'])
        # random shuffle
        import random
        random.shuffle(files)
        # size of dataset
        self.epoch_steps = len(files) // self.batch_size
        self.epoch_size = self.epoch_steps * self.batch_size
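        # Worked example with illustrative numbers: 1000 files at
        # batch_size=8 give epoch_steps=125 and epoch_size=1000; passing
        # --max-steps 300 in training mode then makes the branch below
        # compute num_epochs = ceil(300 / 125) = 3.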
        if not self.training:
            self.num_epochs = 1
            self.max_steps = self.epoch_steps
        if self.max_steps is None:
            self.max_steps = self.epoch_steps * self.num_epochs
        else:
            # ceiling division: enough epochs to cover max_steps
            self.num_epochs = (self.max_steps + self.epoch_steps - 1) // self.epoch_steps
            self.config.num_epochs = self.num_epochs
        self.files = files[:self.epoch_size]
        eprint('data set: {}\nepoch steps: {}\nnum epochs: {}\nmax steps: {}\n'
               .format(len(self.files), self.epoch_steps, self.num_epochs, self.max_steps))

    def build_graph(self):
        with tf.device('/cpu:0'):
            self.inputs, self.labels = get_data(self.config, self.files,
                is_training=self.training, is_testing=self.testing)

    def run(self, sess):
        # width of the zero-padded epoch/step numbers used in file names
        epoch_len = len(str(self.num_epochs - 1))
        step_len = len(str(self.epoch_steps - 1))
        for epoch in range(self.num_epochs):
            save_dir = os.path.join(self.save_dir, '{:0>{epoch_len}}'
                .format(epoch, epoch_len=epoch_len))
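            # e.g. num_epochs=100 gives epoch_len=2, so epoch 3 -> '03'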
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            num_steps = min(self.epoch_steps, self.max_steps - self.epoch_steps * epoch)
            for step in range(num_steps):
                ret_inputs, ret_labels = sess.run((self.inputs, self.labels))
                if self.batch_size == 1:
                    ret_inputs = ret_inputs[0]
                    ret_labels = ret_labels[0]
                # convert dtype (values assumed in [0, 1] before integer scaling)
                if self.dtype == np.uint8:
                    ret_inputs *= 255
                    ret_labels *= 255
                elif self.dtype == np.uint16:
                    ret_inputs *= 65535
                    ret_labels *= 65535
                if self.dtype != np.float32:
                    ret_inputs = ret_inputs.astype(self.dtype)
                    ret_labels = ret_labels.astype(self.dtype)
                # save to compressed npz file
                ofile = os.path.join(save_dir, '{:0>{step_len}}'
                    .format(step, step_len=step_len))
                np.savez_compressed(ofile, inputs=ret_inputs, labels=ret_labels)

    def __call__(self):
        self.initialize()
        self.get_dataset()
        with tf.Graph().as_default():
            self.build_graph()
            with create_session() as sess:
                self.run(sess)


def main(argv=None):
    # arguments parsing
    import argparse
    argp = argparse.ArgumentParser()
    # testing parameters
    argp.add_argument('dataset')
    argp.add_argument('save_dir')
    argp.add_argument('--training', action='store_true')
    argp.add_argument('--testing', action='store_true')
    argp.add_argument('--num-epochs', type=int, default=1)
    argp.add_argument('--max-steps', type=int)
    argp.add_argument('--random-seed', type=int)
    argp.add_argument('--batch-size', type=int, default=1)
    # data parameters
    argp.add_argument('--dtype', type=int, default=3)
    argp.add_argument('--data-format', default='NCHW')
    argp.add_argument('--patch-height', type=int, default=128)
    argp.add_argument('--patch-width', type=int, default=128)
    argp.add_argument('--in-channels', type=int, default=3)
    argp.add_argument('--out-channels', type=int, default=3)
    # pre-processing parameters
    data_arguments(argp)
    # model parameters
    argp.add_argument('--scaling', type=int, default=1)
    # parse
    args = argp.parse_args(argv)
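    # map the --dtype index to a numpy dtype
    # (0: uint8, 1: uint16, 2: float16, 3: float32, 4: float64)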
    args.dtype = [np.uint8, np.uint16, np.float16, np.float32, np.float64][args.dtype]
    # save dataset
    save = Save(args)
    save()


if __name__ == '__main__':
    import sys
    main(sys.argv[1:])
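
# Reading a saved batch back (a minimal sketch; the actual file name depends
# on the zero-padded epoch and step computed above):
#     batch = np.load('save_dir/0/000.npz')
#     inputs, labels = batch['inputs'], batch['labels']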