def _Train():
  params = dict()
  params['batch_size'] = FLAGS.batch_size
  params['seq_len'] = FLAGS.sequence_length
  params['image_size'] = FLAGS.image_size
  params['is_training'] = True
  params['norm_scale'] = FLAGS.norm_scale
  params['scale'] = FLAGS.scale
  params['learning_rate'] = FLAGS.learning_rate
  params['l2_loss'] = FLAGS.l2_loss
  params['reconstr_loss'] = FLAGS.reconstr_loss
  params['kl_loss'] = FLAGS.kl_loss

  train_dir = os.path.join(FLAGS.log_root, 'train')

  images = reader.ReadInput(
      FLAGS.data_filepattern, shuffle=True, params=params)
  # Increasing the scale makes training much faster.
  images *= params['scale']
  image_diff_list = reader.SequenceToImageAndDiff(images)
  model = cross_conv_model.CrossConvModel(image_diff_list, params)
  model.Build()
  tf.contrib.tfprof.model_analyzer.print_model_analysis(
      tf.get_default_graph())

  summary_writer = tf.summary.FileWriter(train_dir)
  sv = tf.train.Supervisor(
      logdir=FLAGS.log_root,
      summary_op=None,
      is_chief=True,
      save_model_secs=60,
      global_step=model.global_step)
  sess = sv.prepare_or_wait_for_session(
      FLAGS.master, config=tf.ConfigProto(allow_soft_placement=True))

  total_loss = 0.0
  step = 0
  # Running sums of z statistics; their averages are saved for eval.
  sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
  sample_z_stddev_log = np.zeros(model.z_stddev_log.get_shape().as_list())
  sample_step = 0

  while True:
    _, loss_val, total_steps, summaries, z_mean, z_stddev_log = sess.run(
        [model.train_op, model.loss, model.global_step, model.summary_op,
         model.z_mean, model.z_stddev_log])
    sample_z_mean += z_mean
    sample_z_stddev_log += z_stddev_log
    total_loss += loss_val
    step += 1
    sample_step += 1

    if step % 100 == 0:
      summary_writer.add_summary(summaries, total_steps)
      sys.stderr.write('step: %d, loss: %f\n' %
                       (total_steps, total_loss / step))
      total_loss = 0.0
      step = 0

    # The sampled z is used for eval.
    # Averaging over 10k steps seems better than 1k; 100k may be worth trying.
    if sample_step % 10000 == 0:
      with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy'), 'w') as f:
        np.save(f, sample_z_mean / sample_step)
      with tf.gfile.Open(
          os.path.join(FLAGS.log_root, 'z_stddev_log.npy'), 'w') as f:
        np.save(f, sample_z_stddev_log / sample_step)
      sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
      sample_z_stddev_log = np.zeros(
          model.z_stddev_log.get_shape().as_list())
      sample_step = 0
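
# A small convenience sketch, not part of the original script: loads the
# averaged z statistics that _Train() periodically writes under
# FLAGS.log_root, mirroring the reading logic used in _Eval() below.
def _LoadSampledZ(log_root):
  """Loads z_mean.npy and z_stddev_log.npy written by _Train()."""
  with tf.gfile.Open(os.path.join(log_root, 'z_mean.npy')) as f:
    z_mean = np.load(io.BytesIO(f.read()))
  with tf.gfile.Open(os.path.join(log_root, 'z_stddev_log.npy')) as f:
    z_stddev_log = np.load(io.BytesIO(f.read()))
  return z_mean, z_stddev_log
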
def _Eval():
  params = dict()
  params['batch_size'] = FLAGS.batch_size
  params['seq_len'] = FLAGS.sequence_length
  params['image_size'] = FLAGS.image_size
  params['is_training'] = False
  params['norm_scale'] = FLAGS.norm_scale
  params['scale'] = FLAGS.scale
  params['l2_loss'] = FLAGS.l2_loss
  params['reconstr_loss'] = FLAGS.reconstr_loss
  params['kl_loss'] = FLAGS.kl_loss

  eval_dir = os.path.join(FLAGS.log_root, 'eval')

  images = reader.ReadInput(
      FLAGS.data_filepattern, shuffle=False, params=params)
  # Scale the inputs the same way as in training.
  images *= params['scale']
  image_diff_list = reader.SequenceToImageAndDiff(images)
  model = cross_conv_model.CrossConvModel(image_diff_list, params)
  model.Build()

  summary_writer = tf.summary.FileWriter(eval_dir)
  saver = tf.train.Saver()
  sess = tf.Session('', config=tf.ConfigProto(allow_soft_placement=True))
  tf.train.start_queue_runners(sess)

  while True:
    time.sleep(60)
    try:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
    except tf.errors.OutOfRangeError as e:
      sys.stderr.write('Cannot restore checkpoint: %s\n' % e)
      continue
    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      sys.stderr.write('No model to eval yet at %s\n' % FLAGS.log_root)
      continue
    sys.stderr.write('Loading checkpoint %s\n' %
                     ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)

    # Use the empirical distribution of z from the training set.
    if not tf.gfile.Exists(os.path.join(FLAGS.log_root, 'z_mean.npy')):
      sys.stderr.write('No z at %s\n' % FLAGS.log_root)
      continue
    with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy')) as f:
      sample_z_mean = np.load(io.BytesIO(f.read()))
    with tf.gfile.Open(
        os.path.join(FLAGS.log_root, 'z_stddev_log.npy')) as f:
      sample_z_stddev_log = np.load(io.BytesIO(f.read()))

    total_loss = 0.0
    for _ in xrange(FLAGS.eval_batch_count):
      loss_val, total_steps, summaries = sess.run(
          [model.loss, model.global_step, model.summary_op],
          feed_dict={model.z_mean: sample_z_mean,
                     model.z_stddev_log: sample_z_stddev_log})
      total_loss += loss_val

    summary_writer.add_summary(summaries, total_steps)
    sys.stderr.write('steps: %d, loss: %f\n' %
                     (total_steps, total_loss / FLAGS.eval_batch_count))
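
# A minimal sketch of an entry point, assuming the standard tf.app.flags /
# tf.app.run pattern. The 'run_mode' flag is hypothetical; the original
# script's actual flag definitions and main() may differ.
tf.app.flags.DEFINE_string('run_mode', 'train',
                           'Whether to run "train" or "eval" (hypothetical).')


def main(_):
  if FLAGS.run_mode == 'train':
    _Train()
  else:
    _Eval()


if __name__ == '__main__':
  tf.app.run()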