def eval_h5(conf, ckpt):
    """Evaluate a model on the training and test streams.

    Rebuilds the per-tower inference graph (one tower per GPU, variables
    shared across towers, no optimizer), sets up the session through
    ``tools.tf_boilerplate`` with ``ckpt`` forwarded to it (presumably the
    checkpoint is restored there -- TODO confirm), and prints the approximate
    PSNR returned by ``eval_epoch`` for both streams. Both streams are closed
    even if graph construction or evaluation raises.

    Args:
        conf: configuration dictionary; reads the keys 'cropw' and 'cw' and
            is passed on to ``tools.prepare_data`` / ``model.inference`` /
            ``model.loss`` / ``tools.tf_boilerplate``.
        ckpt: checkpoint argument forwarded to ``tools.tf_boilerplate``.
    """
    cropw = conf['cropw']
    cw = conf['cw']
    n_gpus = FLAGS.num_gpus

    # Prepare data
    tr_stream, te_stream = tools.prepare_data(conf)
    try:
        with tf.Graph().as_default(), \
                tf.device('/cpu:0' if FLAGS.dev_assign else None):
            # One input/target placeholder pair per tower; targets are the
            # inputs' spatial size minus the cropped border on each side.
            Xs = [tf.placeholder(tf.float32, [None, cw, cw, 1],
                                 name='X_%02d' % i)
                  for i in range(n_gpus)]
            Ys = [tf.placeholder(tf.float32,
                                 [None, cw - 2 * cropw, cw - 2 * cropw, 1],
                                 name='Y_%02d' % i)
                  for i in range(n_gpus)]

            # Build the inference graph for every tower; variables are
            # shared across towers via reuse_variables().
            y_splits = []
            for i in range(n_gpus):
                with tf.device(('/gpu:%d' % i) if FLAGS.dev_assign else None):
                    with tf.name_scope('%s_%02d' % (FLAGS.tower_name, i)) as scope:
                        y_split = model.inference(Xs[i], conf)
                        y_splits.append(y_split)
                        # The loss value is not used for evaluation; the op is
                        # still built so the graph mirrors the one in train().
                        # NOTE(review): drop if model.loss adds nothing that the
                        # checkpoint restore relies on.
                        model.loss(y_split, Ys[i], conf, scope)
                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()
            # Stitch the per-tower outputs back together along the batch axis.
            y = tf.concat(0, y_splits, name='y')

            # Tensorflow boilerplate (no summaries are needed for evaluation).
            sess, _, _, _ = tools.tf_boilerplate(None, conf, ckpt)

            # Evaluation
            psnr_tr = eval_epoch(Xs, Ys, y, sess, tr_stream, cropw)
            psnr_te = eval_epoch(Xs, Ys, y, sess, te_stream, cropw)
            print('approx psnr_tr=%.3f' % psnr_tr)
            print('approx psnr_te=%.3f' % psnr_te)
    finally:
        # Always release the data streams, even on failure.
        tr_stream.close()
        te_stream.close()
def _crop_border(arr, cropw):
    """Drop ``cropw`` elements from both ends of axes 1 and 2 of ``arr``.

    Unlike ``arr[:, cropw:-cropw, cropw:-cropw]`` this is also correct for
    ``cropw == 0``, where ``-0 == 0`` would produce an empty slice.
    """
    height, width = arr.shape[1], arr.shape[2]
    return arr[:, cropw:height - cropw, cropw:width - cropw]


def _tower_feed_pairs(placeholders, arr):
    """Pair each placeholder with its contiguous slice of ``arr`` along axis 0.

    ``arr`` is split into ``len(placeholders)`` chunks of
    ``len(arr) // len(placeholders)`` rows; the last placeholder also
    receives any remainder rows.
    """
    n_towers = len(placeholders)
    n_rows = arr.shape[0]
    per_tower = n_rows // n_towers
    pairs = []
    for i, placeholder in enumerate(placeholders):
        stop = n_rows if i == n_towers - 1 else (i + 1) * per_tower
        pairs.append((placeholder, arr[i * per_tower:stop]))
    return pairs


def train(conf, ckpt=None):
    """Train the model for ``conf['n_epochs']`` epochs across GPU towers.

    Builds one model tower per GPU (``FLAGS.num_gpus``) with shared
    variables, computes (optionally norm-clipped) gradients per tower,
    averages them with ``tools.average_gradients`` and applies them with Adam
    using the learning rate from ``tools.exp_decay_lr``. During training it
    prints PSNR/throughput every 10 steps, writes summaries every 25 steps,
    and saves a checkpoint to ``<path_tmp>/ckpt`` every 100 steps, after each
    epoch, and once more at the end. Both data streams are closed even if
    training raises.

    Args:
        conf: configuration dictionary; reads the keys 'cropw', 'path_tmp',
            'n_epochs', 'cw' and 'grad_norm_thresh' and is passed on to the
            ``tools`` / ``model`` helpers.
        ckpt: checkpoint argument forwarded to ``tools.tf_boilerplate``
            (None by default).
    """
    cropw = conf['cropw']
    path_tmp = conf['path_tmp']
    n_epochs = conf['n_epochs']
    cw = conf['cw']
    grad_norm_thresh = conf['grad_norm_thresh']
    n_gpus = FLAGS.num_gpus
    ckpt_path = os.path.join(path_tmp, 'ckpt')
    tools.reset_tmp(path_tmp)

    # Prepare data
    tr_stream, te_stream = tools.prepare_data(conf)
    try:
        n_tr = tr_stream.dataset.num_examples

        with tf.Graph().as_default(), \
                tf.device('/cpu:0' if FLAGS.dev_assign else None):
            # Exponentially decaying learning rate driven by global_step.
            global_step = tf.get_variable(
                'global_step', [], initializer=tf.constant_initializer(0),
                dtype=tf.int32, trainable=False)
            lr = tools.exp_decay_lr(global_step, n_tr, conf)
            opt = tf.train.AdamOptimizer(lr)

            # One input/target placeholder pair per tower; targets are the
            # inputs' spatial size minus the cropped border on each side.
            Xs = [tf.placeholder(tf.float32, [None, cw, cw, 1],
                                 name='X_%02d' % i)
                  for i in range(n_gpus)]
            Ys = [tf.placeholder(tf.float32,
                                 [None, cw - 2 * cropw, cw - 2 * cropw, 1],
                                 name='Y_%02d' % i)
                  for i in range(n_gpus)]

            # Build every tower and collect its gradients.
            tower_grads = []
            y_splits = []
            summs = []
            for i in range(n_gpus):
                with tf.device(('/gpu:%d' % i) if FLAGS.dev_assign else None):
                    with tf.name_scope('%s_%02d' % (FLAGS.tower_name, i)) as scope:
                        # Constructs the whole model; variables are shared
                        # across towers.
                        y_split = model.inference(Xs[i], conf)
                        y_splits.append(y_split)
                        total_loss = model.loss(y_split, Ys[i], conf, scope)
                        gvs = opt.compute_gradients(total_loss)
                        if grad_norm_thresh > 0:
                            gvs = tools.clip_by_norm(gvs, grad_norm_thresh)
                        tower_grads.append(gvs)
                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()
                        # Retain the summaries from the final tower.
                        summs = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
            y = tf.concat(0, y_splits, name='y')

            # Averaging the tower gradients is the synchronization point
            # across all towers.
            gvs = tools.average_gradients(tower_grads)
            apply_grad_op = opt.apply_gradients(gvs, global_step=global_step)

            summs.append(tf.scalar_summary('learning_rate', lr))
            # Strip the per-tower name-scope prefix from variable names.
            tower_prefix = re.compile('%s_[0-9]*/' % re.escape(FLAGS.tower_name))
            for g, v in gvs:
                # A variable without a gradient has g == None. A Tensor cannot
                # be used as a Python bool, so compare against None explicitly.
                if g is not None:
                    v_name = tower_prefix.sub('', v.op.name)
                    summs.append(tf.histogram_summary(v_name + '/gradients', g))

            # Tensorflow boilerplate
            sess, saver, summ_writer, summ_op = tools.tf_boilerplate(summs, conf, ckpt)

            # Baseline error
            bpsnr_tr = tools.baseline_psnr(tr_stream)
            bpsnr_te = tools.baseline_psnr(te_stream)
            print('approx baseline psnr_tr=%.3f' % bpsnr_tr)
            print('approx baseline psnr_te=%.3f' % bpsnr_te)

            # Train
            format_str = ('%s| %04d PSNR=%.3f (Tr: %.1fex/s; %.1fs/batch)'
                          '(Te: %.1fex/s; %.1fs/batch)')
            step = 0
            for epoch in range(n_epochs):
                print('--- Epoch %d ---' % epoch)
                for X_c, y_c in tr_stream.get_epoch_iterator():
                    y_c = _crop_border(y_c, cropw)
                    # Actual number of examples in this batch (the last batch
                    # of an epoch may be smaller than the nominal size).
                    n_examples = X_c.shape[0]
                    x_pairs = _tower_feed_pairs(Xs, X_c)
                    feed = dict(x_pairs + _tower_feed_pairs(Ys, y_c))

                    start_time = time.time()
                    sess.run(apply_grad_op, feed_dict=feed)
                    duration_tr = time.time() - start_time

                    if step % 10 == 0:
                        start_time = time.time()
                        y_eval = sess.run(y, feed_dict=dict(x_pairs))
                        duration_eval = time.time() - start_time
                        psnr = tools.eval_psnr(y_c, y_eval)
                        print(format_str % (datetime.now().time(), step, psnr,
                                            n_examples / duration_tr,
                                            float(duration_tr / n_gpus),
                                            n_examples / duration_eval,
                                            float(duration_eval / n_gpus)))

                    if step % 25 == 0:
                        summ_str = sess.run(summ_op, feed_dict=feed)
                        summ_writer.add_summary(summ_str, step)

                    if step % 100 == 0:
                        saver.save(sess, ckpt_path, global_step=step)

                    step += 1

                # End-of-epoch checkpoint.
                saver.save(sess, ckpt_path, global_step=step)

            # Final checkpoint (also covers n_epochs == 0).
            saver.save(sess, ckpt_path, global_step=step)
    finally:
        # Always release the data streams, even on failure.
        tr_stream.close()
        te_stream.close()