import os
import numpy as np
import tensorflow as tf
# SRN, Data, BatchPNG, test_losses, listdir_files, eprint, create_session and
# reset_random are defined elsewhere in this repo


def main(argv=None):
    # arguments parsing
    import argparse
    argp = argparse.ArgumentParser()
    # testing parameters
    argp.add_argument('dataset')
    argp.add_argument('--num-epochs', type=int, default=1)
    argp.add_argument('--random-seed', type=int)
    argp.add_argument('--device', default='/gpu:0')
    argp.add_argument('--postfix', default='')
    argp.add_argument('--train-dir', default='./train{postfix}.tmp')
    argp.add_argument('--test-dir', default='./test{postfix}.tmp')
    argp.add_argument('--log-file', default='test.log')
    argp.add_argument('--batch-size', type=int, default=1)
    # data parameters
    argp.add_argument('--dtype', type=int, default=2)
    argp.add_argument('--data-format', default='NCHW')
    argp.add_argument('--patch-height', type=int, default=512)
    argp.add_argument('--patch-width', type=int, default=512)
    argp.add_argument('--in-channels', type=int, default=3)
    argp.add_argument('--out-channels', type=int, default=3)
    # pre-processing parameters
    Data.add_arguments(argp)
    # model parameters
    SRN.add_arguments(argp)
    argp.add_argument('--scaling', type=int, default=1)
    # parse
    args = argp.parse_args(argv)
    args.train_dir = args.train_dir.format(postfix=args.postfix)
    args.test_dir = args.test_dir.format(postfix=args.postfix)
    args.dtype = [tf.int8, tf.float16, tf.float32, tf.float64][args.dtype]
    args.pre_down = True
    # run testing
    test = Test(args)
    test()
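# Usage sketch (hypothetical paths and values, not from the source): each entry
# point accepts an explicit argv list instead of the real command line, e.g.
#   main(['/path/to/test/set', '--postfix', '01', '--batch-size', '1'])
# would read checkpoints from ./train01.tmp and write PNGs to ./test01.tmp.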
def main(argv=None):
    # arguments parsing
    import argparse
    argp = argparse.ArgumentParser()
    # training parameters
    argp.add_argument('dataset')
    argp.add_argument('--num-epochs', type=int, default=24)
    argp.add_argument('--max-steps', type=int)
    argp.add_argument('--random-seed', type=int)
    argp.add_argument('--device', default='/gpu:0')
    argp.add_argument('--postfix', default='')
    argp.add_argument('--pretrain-dir', default='')
    argp.add_argument('--train-dir', default='./train{postfix}.tmp')
    argp.add_argument('--restore', action='store_true')
    argp.add_argument('--save-steps', type=int, default=5000)
    argp.add_argument('--ckpt-period', type=int, default=600)
    argp.add_argument('--log-frequency', type=int, default=100)
    argp.add_argument('--log-file', default='train.log')
    argp.add_argument('--batch-size', type=int, default=32)
    argp.add_argument('--val-size', type=int, default=256)
    # data parameters
    argp.add_argument('--dtype', type=int, default=2)
    argp.add_argument('--data-format', default='NCHW')
    argp.add_argument('--patch-height', type=int, default=128)
    argp.add_argument('--patch-width', type=int, default=128)
    argp.add_argument('--in-channels', type=int, default=3)
    argp.add_argument('--out-channels', type=int, default=3)
    # pre-processing parameters
    Data.add_arguments(argp)
    # model parameters
    SRN.add_arguments(argp)
    argp.add_argument('--scaling', type=int, default=1)
    # parse
    args = argp.parse_args(argv)
    args.train_dir = args.train_dir.format(postfix=args.postfix)
    args.dtype = [tf.int8, tf.float16, tf.float32, tf.float64][args.dtype]
    # run training
    train = Train(args)
    train()
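# Usage sketch (hypothetical paths and values): resuming an interrupted run
# reuses the same training directory via --restore, e.g.
#   main(['/path/to/train/set', '--postfix', '01', '--restore'])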
def main(argv=None):
    # arguments parsing
    import argparse
    argp = argparse.ArgumentParser()
    # model saving parameters
    argp.add_argument('--postfix', default='')
    argp.add_argument('--train-dir', default='./train{postfix}.tmp')
    argp.add_argument('--model-dir', default='./model{postfix}.tmp')
    # data parameters
    argp.add_argument('--dtype', type=int, default=2)
    argp.add_argument('--data-format', default='NCHW')
    argp.add_argument('--in-channels', type=int, default=3)
    argp.add_argument('--out-channels', type=int, default=3)
    # model parameters
    SRN.add_arguments(argp)
    argp.add_argument('--scaling', type=int, default=1)
    # parse
    args = argp.parse_args(argv)
    args.train_dir = args.train_dir.format(postfix=args.postfix)
    args.model_dir = args.model_dir.format(postfix=args.postfix)
    args.dtype = [tf.int8, tf.float16, tf.float32, tf.float64][args.dtype]
    # save model
    graph = Graph(args)
    graph()
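# Usage sketch (hypothetical value): exporting the latest checkpoint of
# training run 01 into a standalone model directory, e.g.
#   main(['--postfix', '01'])
# would read ./train01.tmp and write model.graphdef, model.meta and the model
# parameters to ./model01.tmp.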
class Graph:
    def __init__(self, config):
        self.postfix = None
        self.train_dir = None
        self.model_dir = None
        # copy all the properties from the config object
        self.config = config
        self.__dict__.update(config.__dict__)

    def initialize(self):
        # arXiv 1509.09308: enable Winograd's minimal filtering algorithms,
        # a class of fast algorithms for convolutional neural networks
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
        # check the training directory and create the model directory
        if not os.path.exists(self.train_dir):
            raise FileNotFoundError('Could not find folder {}'.format(
                self.train_dir))
        if os.path.exists(self.model_dir):
            eprint('Confirm removing {}\n[Y/n]'.format(self.model_dir))
            if input() != 'Y':
                import sys
                sys.exit()
            import shutil
            shutil.rmtree(self.model_dir, ignore_errors=True)
            eprint('Removed: ' + self.model_dir)
        os.makedirs(self.model_dir)

    def build_graph(self):
        self.model = SRN(self.config)
        self.model.build_model()

    def build_saver(self):
        # a Saver object to restore the variables with mappings
        self.saver_r = tf.train.Saver(self.model.rvars)
        # a Saver object to save the variables without mappings
        self.saver_s = tf.train.Saver(self.model.svars)

    def run(self, sess):
        # save the GraphDef
        tf.train.write_graph(tf.get_default_graph(), self.model_dir,
                             'model.graphdef', as_text=True)
        # restore variables from the latest checkpoint
        self.saver_r.restore(sess, tf.train.latest_checkpoint(self.train_dir))
        # save the model parameters
        self.saver_s.export_meta_graph(
            os.path.join(self.model_dir, 'model.meta'), as_text=False,
            clear_devices=True, clear_extraneous_savers=True)
        self.saver_s.save(sess, os.path.join(self.model_dir, 'model'),
                          write_meta_graph=False, write_state=False)

    def __call__(self):
        self.initialize()
        with tf.Graph().as_default():
            self.build_graph()
            self.build_saver()
            with create_session() as sess:
                self.run(sess)
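# `eprint` and `create_session` are helpers assumed to be defined elsewhere in
# this repo (e.g. in a utils module); minimal sketches of plausible
# implementations, not the repo's actual code:
def eprint(*args, **kwargs):
    # print to stderr so progress logs do not mix with stdout output
    import sys
    print(*args, file=sys.stderr, **kwargs)


def create_session():
    # a Session with soft device placement and incremental GPU memory growth
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)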
class Train:
    def __init__(self, config):
        self.random_seed = None
        self.device = None
        self.postfix = None
        self.pretrain_dir = None
        self.train_dir = None
        self.restore = None
        self.save_steps = None
        self.ckpt_period = None
        self.log_frequency = None
        self.log_file = None
        self.batch_size = None
        self.val_size = None
        # copy all the properties from the config object
        self.config = config
        self.__dict__.update(config.__dict__)

    def initialize(self):
        # arXiv 1509.09308: enable Winograd's minimal filtering algorithms,
        # a class of fast algorithms for convolutional neural networks
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
        # create the training directory
        if not self.restore:
            if os.path.exists(self.train_dir):
                eprint('Confirm removing {}\n[Y/n]'.format(self.train_dir))
                if input() != 'Y':
                    import sys
                    sys.exit()
                import shutil
                shutil.rmtree(self.train_dir, ignore_errors=True)
                eprint('Removed: ' + self.train_dir)
            os.makedirs(self.train_dir)
        # set deterministic random seed
        if self.random_seed is not None:
            reset_random(self.random_seed)

    def get_dataset(self):
        self.data = Data(self.config)
        self.epoch_steps = self.data.epoch_steps
        self.max_steps = self.data.max_steps
        # pre-compute the validation set
        self.val_inputs = []
        self.val_labels = []
        for _inputs, _labels in self.data.gen_val():
            self.val_inputs.append(_inputs)
            self.val_labels.append(_labels)

    def build_graph(self):
        with tf.device(self.device):
            self.model = SRN(self.config)
            self.g_loss = self.model.build_train()
            self.global_step = tf.train.get_or_create_global_step()
            self.g_train_op = self.model.train(self.global_step)
            self.train_summary, self.loss_summary = self.model.get_summaries()

    def build_saver(self):
        # a Saver object to restore the variables with mappings,
        # only for restoring from a pre-trained model
        if self.pretrain_dir and not self.restore:
            self.saver_pt = tf.train.Saver(self.model.rvars)
        # a Saver object to save recent checkpoints
        self.saver_ckpt = tf.train.Saver(max_to_keep=5,
                                         save_relative_paths=True)
        # a Saver object to save the variables without mappings,
        # used for saving checkpoints throughout the entire training progress
        self.saver = tf.train.Saver(self.model.svars, max_to_keep=1 << 16,
                                    save_relative_paths=True)
        # save the graph
        self.saver.export_meta_graph(
            os.path.join(self.train_dir, 'model.meta'), as_text=False,
            clear_devices=True, clear_extraneous_savers=True)

    def train_session(self):
        self.train_writer = tf.summary.FileWriter(self.train_dir + '/train',
                                                  tf.get_default_graph(),
                                                  max_queue=20, flush_secs=120)
        self.val_writer = tf.summary.FileWriter(self.train_dir + '/val')
        return create_session()

    def run_sess(self, sess, global_step, data_gen,
                 options=None, run_metadata=None):
        from datetime import datetime
        import time
        epoch = global_step // self.epoch_steps
        last_step = global_step + 1 >= self.max_steps
        logging = last_step or (self.log_frequency > 0 and
                                global_step % self.log_frequency == 0)
        # training - train op
        inputs, labels = next(data_gen)
        feed_dict = {self.model.g_training: True,
                     'Input:0': inputs, 'Label:0': labels}
        if logging:
            fetch = (self.train_summary, self.g_train_op,
                     self.model.g_losses_acc)
            summary, _, _ = sess.run(fetch, feed_dict, options, run_metadata)
            self.train_writer.add_summary(summary, global_step)
        else:
            fetch = (self.g_train_op, self.model.g_losses_acc)
            sess.run(fetch, feed_dict, options, run_metadata)
        # training - log summary
        if logging:
            # loss summary
            fetch = [self.loss_summary] + self.model.g_log_losses
            summary, train_loss = sess.run(fetch)
            self.train_writer.add_summary(summary, global_step)
            # logging
            time_current = time.time()
            duration = time_current - self.log_last
            self.log_last = time_current
            sec_batch = (duration / self.log_frequency
                         if self.log_frequency > 0 else 0)
            # guard against division by zero when logging is only
            # triggered by the last step
            samples_sec = self.batch_size / sec_batch if sec_batch > 0 else 0
            train_log = ('{}: epoch {}, step {}, train loss: {:.5}'
                         ' ({:.1f} samples/sec, {:.3f} sec/batch)'.format(
                             datetime.now(), epoch, global_step, train_loss,
                             samples_sec, sec_batch))
            eprint(train_log)
        # validation
        if logging:
            for inputs, labels in zip(self.val_inputs, self.val_labels):
                feed_dict = {'Input:0': inputs, 'Label:0': labels}
                fetch = [self.model.g_losses_acc]
                sess.run(fetch, feed_dict)
            # loss summary
            fetch = [self.loss_summary] + self.model.g_log_losses
            summary, val_loss = sess.run(fetch)
            self.val_writer.add_summary(summary, global_step)
            # logging
            val_log = ('{}: epoch {}, step {}, val loss: {:.5}'.format(
                datetime.now(), epoch, global_step, val_loss))
            eprint(val_log)
        # log result for the last step
        if self.log_file and last_step:
            last_log = ('epoch {}, step {}, train loss: {:.5},'
                        ' val loss: {:.5}'.format(
                            epoch, global_step, train_loss, val_loss))
            with open(self.log_file, 'a', encoding='utf-8') as fd:
                fd.write('Training No.{}\n'.format(self.postfix))
                fd.write(self.train_dir + '\n')
                fd.write('{}\n'.format(datetime.now()))
                fd.write(last_log + '\n\n')

    def run(self, sess):
        import time
        # restore from checkpoint
        if self.restore and os.path.exists(
                os.path.join(self.train_dir, 'checkpoint')):
            latest_ckpt = tf.train.latest_checkpoint(self.train_dir,
                                                     'checkpoint')
            self.saver_ckpt.restore(sess, latest_ckpt)
        # restore from pre-trained model
        elif self.pretrain_dir:
            self.saver_pt.restore(sess,
                                  os.path.join(self.pretrain_dir, 'model'))
        # otherwise, initialize from scratch
        else:
            initializers = (tf.initializers.global_variables(),
                            tf.initializers.local_variables())
            sess.run(initializers)
        # profiler
        profile_offset = 1000 + self.log_frequency // 2
        profile_step = 10000
        builder = tf.profiler.ProfileOptionBuilder
        profiler = tf.profiler.Profiler(sess.graph)
        # initialization
        self.log_last = time.time()
        ckpt_last = time.time()
        # dataset generator
        global_step = tf.train.global_step(sess, self.global_step)
        data_gen = self.data.gen_main(global_step)
        # run the training session
        while True:
            # global step
            global_step = tf.train.global_step(sess, self.global_step)
            if global_step >= self.max_steps:
                eprint('Training finished at step={}'.format(global_step))
                break
            # run session
            if global_step % profile_step == profile_offset:
                # profiling every few steps
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_meta = tf.RunMetadata()
                self.run_sess(sess, global_step, data_gen, options, run_meta)
                profiler.add_step(global_step, run_meta)
                # profile the parameters
                if global_step == profile_offset:
                    ofile = os.path.join(self.train_dir, 'parameters.log')
                    profiler.profile_name_scope(
                        builder(builder.trainable_variables_parameter())
                        .with_file_output(ofile).build())
                # profile the timing of model operations
                ofile = os.path.join(
                    self.train_dir,
                    'time_and_memory_{:0>7}.log'.format(global_step))
                profiler.profile_operations(
                    builder(builder.time_and_memory())
                    .with_file_output(ofile).build())
                # generate a timeline
                timeline = os.path.join(self.train_dir, 'timeline')
                profiler.profile_graph(
                    builder(builder.time_and_memory())
                    .with_step(global_step)
                    .with_timeline_output(timeline).build())
            else:
                self.run_sess(sess, global_step, data_gen)
            # save checkpoints periodically or when training finishes
            if self.ckpt_period > 0:
                time_current = time.time()
                if (time_current - ckpt_last >= self.ckpt_period
                        or global_step + 1 >= self.max_steps):
                    ckpt_last = time_current
                    self.saver_ckpt.save(
                        sess, os.path.join(self.train_dir, 'model.ckpt'),
                        global_step, 'checkpoint')
            # save the model every few steps
            if self.save_steps > 0 and global_step % self.save_steps == 0:
                self.saver.save(
                    sess,
                    os.path.join(self.train_dir,
                                 'model_{:0>7}'.format(global_step)),
                    write_meta_graph=False, write_state=False)
        # auto-detect problems and generate advice
        ALL_ADVICE = {
            'ExpensiveOperationChecker': {},
            'AcceleratorUtilizationChecker': {},
            'JobChecker': {},
            'OperationChecker': {},
        }
        profiler.advise(ALL_ADVICE)

    def __call__(self):
        self.initialize()
        self.get_dataset()
        with tf.Graph().as_default():
            self.build_graph()
            self.build_saver()
            with self.train_session() as sess:
                self.run(sess)
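# `reset_random` is assumed to be a helper defined elsewhere in this repo;
# a minimal sketch that seeds every RNG this code could depend on:
def reset_random(seed):
    import random
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)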
class Test:
    def __init__(self, config):
        self.random_seed = None
        self.device = None
        self.postfix = None
        self.train_dir = None
        self.test_dir = None
        self.log_file = None
        self.batch_size = None
        # copy all the properties from the config object
        self.config = config
        self.__dict__.update(config.__dict__)

    def initialize(self):
        # arXiv 1509.09308: enable Winograd's minimal filtering algorithms,
        # a class of fast algorithms for convolutional neural networks
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
        # create the testing directory
        if not os.path.exists(self.train_dir):
            raise FileNotFoundError('Could not find folder {}'.format(
                self.train_dir))
        if os.path.exists(self.test_dir):
            eprint('Confirm removing {}\n[Y/n]'.format(self.test_dir))
            if input() != 'Y':
                import sys
                sys.exit()
            import shutil
            shutil.rmtree(self.test_dir, ignore_errors=True)
            eprint('Removed: ' + self.test_dir)
        os.makedirs(self.test_dir)
        # set deterministic random seed
        if self.random_seed is not None:
            reset_random(self.random_seed)

    def get_dataset(self):
        self.data = Data(self.config)
        self.epoch_steps = self.data.epoch_steps
        self.max_steps = self.data.max_steps
        # pre-compute the testing set
        self.test_inputs = []
        self.test_labels = []
        data_gen = self.data.gen_main()
        for _inputs, _labels in data_gen:
            self.test_inputs.append(_inputs)
            self.test_labels.append(_labels)

    def build_graph(self):
        with tf.device(self.device):
            inputs = tf.placeholder(tf.float32, name='inputs')
            labels = tf.placeholder(tf.float32, name='labels')
            self.model = SRN(self.config)
            outputs = self.model.build_model(inputs)
            self.losses = list(test_losses(labels, outputs))
        # post-processing for output
        with tf.device('/cpu:0'):
            # convert to NHWC format
            if self.config.data_format == 'NCHW':
                inputs = tf.transpose(inputs, [0, 2, 3, 1])
                labels = tf.transpose(labels, [0, 2, 3, 1])
                outputs = tf.transpose(outputs, [0, 2, 3, 1])
            # PNG output
            self.pngs = (BatchPNG(inputs, self.batch_size)
                         + BatchPNG(labels, self.batch_size)
                         + BatchPNG(outputs, self.batch_size))

    def build_saver(self):
        # a Saver object to restore the variables with mappings
        self.saver = tf.train.Saver(self.model.rvars)

    def run_last(self, sess):
        # restore from the latest checkpoint
        ckpt = tf.train.latest_checkpoint(self.train_dir)
        self.saver.restore(sess, ckpt)
        # to be fetched
        fetch = self.losses + self.pngs
        losses_sum = [0 for _ in range(len(self.losses))]
        # run session
        for step in range(self.epoch_steps):
            feed_dict = {'inputs:0': self.test_inputs[step],
                         'labels:0': self.test_labels[step]}
            ret = sess.run(fetch, feed_dict)
            ret_losses = ret[0:len(self.losses)]
            ret_pngs = ret[len(self.losses):]
            # sum of losses
            for i in range(len(self.losses)):
                losses_sum[i] += ret_losses[i]
            # save images
            _start = step * self.batch_size
            _stop = _start + self.batch_size
            _range = range(_start, _stop)
            ofiles = (['{:0>5}.0.inputs.png'.format(i) for i in _range]
                      + ['{:0>5}.1.labels.png'.format(i) for i in _range]
                      + ['{:0>5}.2.outputs{}.png'.format(i, self.postfix)
                         for i in _range])
            ofiles = [os.path.join(self.test_dir, f) for f in ofiles]
            for i in range(len(ret_pngs)):
                with open(ofiles[i], 'wb') as fd:
                    fd.write(ret_pngs[i])
        # summary
        if self.log_file:
            from datetime import datetime
            losses_mean = [l / self.epoch_steps for l in losses_sum]
            psnr = (10 * np.log10(1 / losses_mean[0])
                    if losses_mean[0] > 0 else 100)
            test_log = 'PSNR (RGB): {}, MAD (RGB): {}'.format(
                psnr, *losses_mean[1:])
            with open(self.log_file, 'a', encoding='utf-8') as fd:
                fd.write('Testing No.{}\n'.format(self.postfix))
                fd.write(self.test_dir + '\n')
                fd.write('{}\n'.format(datetime.now()))
                fd.write(test_log + '\n\n')

    def run_steps(self, sess):
        import re
        prefix = 'model_'
        # get the checkpoints saved every few steps
        ckpts = listdir_files(self.train_dir, recursive=False,
                              filter_ext=['.index'])
        ckpts = [os.path.splitext(f)[0] for f in ckpts if prefix in f]
        ckpts.sort()
        stats = []
        # test all the checkpoints
        for ckpt in ckpts:
            self.saver.restore(sess, ckpt)
            # to be fetched
            fetch = self.losses
            losses_sum = [0 for _ in range(len(self.losses))]
            # run session
            for step in range(self.epoch_steps):
                feed_dict = {'inputs:0': self.test_inputs[step],
                             'labels:0': self.test_labels[step]}
                ret_losses = sess.run(fetch, feed_dict)
                # sum of losses
                for i in range(len(self.losses)):
                    losses_sum[i] += ret_losses[i]
            # summary
            losses_mean = [l / self.epoch_steps for l in losses_sum]
            # stats
            ckpt_num = re.findall(prefix + r'(\d+)', ckpt)[0]
            stats.append(np.array([float(ckpt_num)] + losses_mean))
        # save stats
        import matplotlib.pyplot as plt
        stats = np.stack(stats)
        np.save(os.path.join(self.test_dir, 'stats.npy'), stats)
        # save plot
        fig, ax = plt.subplots()
        ax.set_title('Test Error with Training Progress')
        ax.set_xlabel('training steps')
        ax.set_ylabel('MAD (RGB)')
        ax.set_xscale('linear')
        ax.set_yscale('log')
        stats = stats[1:]  # drop the first checkpoint for a cleaner plot
        ax.plot(stats[:, 0], stats[:, 2])
        plt.tight_layout()
        plt.savefig(os.path.join(self.test_dir, 'stats.png'))
        plt.close()

    def __call__(self):
        self.initialize()
        self.get_dataset()
        with tf.Graph().as_default():
            self.build_graph()
            self.build_saver()
            with create_session() as sess:
                self.run_last(sess)
                self.run_steps(sess)
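# `test_losses` is assumed to yield one scalar metric per element, MSE first
# (run_last derives PSNR from losses_mean[0]) and MAD second (plotted by
# run_steps); a minimal sketch under that assumption, not the repo's actual
# implementation:
def test_losses(labels, outputs):
    yield tf.reduce_mean(tf.squared_difference(labels, outputs))  # MSE
    yield tf.reduce_mean(tf.abs(labels - outputs))  # MAD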