def train(dataset_frame1, dataset_frame2, dataset_frame3):
    """Trains a model.

    Builds the graph (input/target placeholders, ``Voxel_flow_model``, an
    Adam update op and summaries) and then runs ``FLAGS.max_steps``
    feed-dict training steps. Frames 1 and 3 are concatenated along the
    channel axis to form the network input; frame 2 is the target.

    Args:
        dataset_frame1: Dataset object for the first input frame. Must
            provide ``read_data_list_file()`` and ``process_func(line)``.
        dataset_frame2: Dataset object (same interface) for the target frame.
        dataset_frame3: Dataset object (same interface) for the second
            input frame.

    Raises:
        ValueError: If the first dataset's file list is empty.
    """

    def _reshuffle(data_list):
        """Shuffle ``data_list`` in place after resetting the RNG seed."""
        # Re-seeding before every shuffle is what keeps the three frame lists
        # aligned: equally long lists receive the same permutation
        # (assumes `shuffle` is `random.shuffle` -- TODO confirm the import).
        random.seed(1)
        shuffle(data_list)

    def _load_shuffled(dataset):
        """Read a dataset's file list and shuffle it with the fixed seed."""
        data_list = dataset.read_data_list_file()
        _reshuffle(data_list)
        return data_list

    def _load_batch(dataset, data_list, start, stop):
        """Decode ``data_list[start:stop]`` into one stacked numpy array."""
        return np.array(
            [dataset.process_func(line) for line in data_list[start:stop]])

    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, 3))

        # Prepare model.
        model = Voxel_flow_model()
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        total_loss = reproduction_loss

        # The learning rate is constant; no schedule is applied.
        learning_rate = FLAGS.initial_learning_rate

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(learning_rate)
        grads = opt.compute_gradients(total_loss)
        update_op = opt.apply_gradients(grads)

        # Create summaries. Each call registers its op in the graph's
        # summaries collection, which merge_all_summaries() picks up below.
        tf.scalar_summary('total_loss', total_loss)
        tf.scalar_summary('reproduction_loss', reproduction_loss)
        tf.image_summary('Input Image', input_placeholder, 3)
        tf.image_summary('Output Image', prediction, 3)
        tf.image_summary('Target Image', target_placeholder, 3)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Build the summary operation.
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation and run it.
        init = tf.initialize_all_variables()
        sess = tf.Session()
        sess.run(init)

        # Summary writer. Constructing it records the graph in train_dir even
        # though per-step summary writing is disabled in the loop below.
        summary_writer = tf.train.SummaryWriter(
            FLAGS.train_dir, graph=sess.graph)

        # Training loop using feed dict method.
        data_list_frame1 = _load_shuffled(dataset_frame1)
        data_list_frame2 = _load_shuffled(dataset_frame2)
        data_list_frame3 = _load_shuffled(dataset_frame3)

        data_size = len(data_list_frame1)
        if data_size == 0:
            raise ValueError('Training data list is empty.')
        # Number of full batches per pass over the data. Clamped to 1 so that
        # a dataset smaller than one batch does not make `step % epoch_num`
        # divide by zero; in that case every step uses the whole dataset.
        epoch_num = max(1, int(data_size // FLAGS.batch_size))

        for step in xrange(0, FLAGS.max_steps):
            batch_idx = step % epoch_num
            batch_start = int(batch_idx * FLAGS.batch_size)
            batch_stop = int((batch_idx + 1) * FLAGS.batch_size)

            # Load batch data.
            batch_data_frame1 = _load_batch(
                dataset_frame1, data_list_frame1, batch_start, batch_stop)
            batch_data_frame2 = _load_batch(
                dataset_frame2, data_list_frame2, batch_start, batch_stop)
            batch_data_frame3 = _load_batch(
                dataset_frame3, data_list_frame3, batch_start, batch_stop)

            feed_dict = {
                input_placeholder:
                np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                target_placeholder:
                batch_data_frame2,
            }

            # Run single step update.
            _, loss_value = sess.run(
                [update_op, total_loss], feed_dict=feed_dict)

            if batch_idx == 0:
                # Shuffle data at each epoch (takes effect from the next
                # batch; the current one was already sliced above).
                _reshuffle(data_list_frame1)
                _reshuffle(data_list_frame2)
                _reshuffle(data_list_frame3)
                print('Epoch Number: %d' % int(step / epoch_num))

            # Output Summary
            if step % 10 == 0:
                # Summary writing is currently disabled:
                # summary_str = sess.run(summary_op, feed_dict = feed_dict)
                # summary_writer.add_summary(summary_str, step)
                print("Loss at step %d: %f" % (step, loss_value))

            if step % 500 == 0:
                # Run a batch of images and dump prediction / ground truth.
                # File names only carry the in-batch index, so each dump
                # overwrites the previous one.
                prediction_np, target_np = sess.run(
                    [prediction, target_placeholder], feed_dict=feed_dict)
                for i in range(0, prediction_np.shape[0]):
                    file_name = FLAGS.train_image_dir + str(i) + '_out.png'
                    file_name_label = (
                        FLAGS.train_image_dir + str(i) + '_gt.png')
                    imwrite(file_name, prediction_np[i, :, :, :])
                    imwrite(file_name_label, target_np[i, :, :, :])

            # Save checkpoint
            if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
def test(dataset_frame1, dataset_frame2, dataset_frame3):
    """Perform test on a trained model.

    Every test triplet is evaluated as the last element of a 64-image batch
    whose first 63 entries are always the first 63 dataset items. The
    prediction and ground truth of that last element are written to
    ``FLAGS.test_image_dir`` and a PSNR value is accumulated per image.

    Args:
        dataset_frame1: Dataset object for the first input frame. Must
            provide ``read_data_list_file()`` and ``process_func(line)``.
        dataset_frame2: Dataset object (same interface) for the target frame.
        dataset_frame3: Dataset object (same interface) for the second
            input frame.
    """
    # Number of leading dataset entries used to pad each evaluation batch.
    num_padding = 63

    def _padded_batch(dataset, data_list, image):
        """Stack the fixed padding images followed by ``image``."""
        # NOTE(review): the padding images are re-decoded for every test
        # image. If process_func is deterministic this could be done once
        # before the loop -- confirm before hoisting.
        batch = [dataset.process_func(ll) for ll in data_list[0:num_padding]]
        batch.append(image)
        return np.array(batch)

    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, 3))

        # Prepare model.
        # NOTE(review): built with is_train=True although this is the test
        # path; presumably this is why each image is evaluated inside a
        # padded batch -- confirm against Voxel_flow_model.
        model = Voxel_flow_model(is_train=True)
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        total_loss = reproduction_loss

        sess = tf.Session()

        # Restore checkpoint from file.
        if FLAGS.pretrained_model_checkpoint_path:
            assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
            ckpt = tf.train.get_checkpoint_state(
                FLAGS.pretrained_model_checkpoint_path)
            restorer = tf.train.Saver()
            restorer.restore(sess, ckpt.model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), ckpt.model_checkpoint_path))

        # Process on test dataset.
        data_list_frame1 = dataset_frame1.read_data_list_file()
        data_size = len(data_list_frame1)

        data_list_frame2 = dataset_frame2.read_data_list_file()
        data_list_frame3 = dataset_frame3.read_data_list_file()

        PSNR = 0
        for id_img in range(0, data_size):
            # Load single data.
            line_image_frame1 = dataset_frame1.process_func(
                data_list_frame1[id_img])
            line_image_frame2 = dataset_frame2.process_func(
                data_list_frame2[id_img])
            line_image_frame3 = dataset_frame3.process_func(
                data_list_frame3[id_img])

            batch_data_frame1 = _padded_batch(
                dataset_frame1, data_list_frame1, line_image_frame1)
            batch_data_frame2 = _padded_batch(
                dataset_frame2, data_list_frame2, line_image_frame2)
            batch_data_frame3 = _padded_batch(
                dataset_frame3, data_list_frame3, line_image_frame3)

            feed_dict = {
                input_placeholder:
                np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                target_placeholder:
                batch_data_frame2,
            }

            # Run a forward pass.
            prediction_np, target_np, loss_value = sess.run(
                [prediction, target_placeholder, total_loss],
                feed_dict=feed_dict)
            print("Loss for image %d: %f" % (id_img, loss_value))

            # Only the last batch entry is the image under test.
            file_name = FLAGS.test_image_dir + str(id_img) + '_out.png'
            file_name_label = FLAGS.test_image_dir + str(id_img) + '_gt.png'
            imwrite(file_name, prediction_np[-1, :, :, :])
            imwrite(file_name_label, target_np[-1, :, :, :])

            # NOTE(review): the denominator is the *sum* of squared errors
            # over the whole padded batch, not the mean squared error of the
            # image under test, and a 255 peak is assumed -- confirm this is
            # the intended PSNR definition.
            PSNR += 10 * np.log10(
                255.0 * 255.0 / np.sum(np.square(prediction_np - target_np)))

        if data_size:
            print("Overall PSNR: %f db" % (PSNR / data_size))
        else:
            print("No test data found; PSNR not computed.")
def test(dataset_frame1, dataset_frame2, dataset_frame3):
    """Perform test on a trained model.

    For every entry of the first dataset's file list the two input frames
    are fed through two weight-sharing ``Vgg16`` edge networks and a
    ``Voxel_flow_model`` built under the ``Cycle_DVF`` variable scope. The
    interpolated frame is written to
    ``ucf101_interp_ours/<index>/frame_01_CyclicGen.png``. PSNR (restricted
    to the motion mask) and SSIM (on grayscale images) are accumulated over
    the images whose motion mask is non-empty.

    Args:
        dataset_frame1: Dataset object for the first input frame. Must
            provide ``read_data_list_file()`` and ``process_func(path)``.
        dataset_frame2: Dataset object (same interface) for the ground-truth
            frame.
        dataset_frame3: Dataset object (same interface) for the second
            input frame; its ``process_func`` is also used to load the
            motion mask.
    """

    def rgb2gray(rgb):
        """Collapse the first three channels into a weighted gray image."""
        return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])

    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, 3))

        # Edge maps for both input frames; the second network reuses the
        # variables of the first.
        edge_vgg_1 = Vgg16(input_placeholder[:, :, :, :3], reuse=None)
        edge_vgg_3 = Vgg16(input_placeholder[:, :, :, 3:6], reuse=True)

        height, width = input_placeholder.get_shape().as_list()[1:3]
        edge_1 = tf.reshape(
            tf.nn.sigmoid(edge_vgg_1.fuse), [-1, height, width, 1])
        edge_3 = tf.reshape(
            tf.nn.sigmoid(edge_vgg_3.fuse), [-1, height, width, 1])

        with tf.variable_scope("Cycle_DVF"):
            # Prepare model.
            model = Voxel_flow_model(is_train=False)
            prediction = model.inference(
                tf.concat([input_placeholder, edge_1, edge_3], 3))

        sess = tf.Session()

        # Restore checkpoint from file.
        if FLAGS.pretrained_model_checkpoint_path:
            restorer = tf.train.Saver()
            restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

        # Process on test dataset.
        data_list_frame1 = dataset_frame1.read_data_list_file()
        data_size = len(data_list_frame1)

        data_list_frame2 = dataset_frame2.read_data_list_file()
        data_list_frame3 = dataset_frame3.read_data_list_file()

        num_evaluated = 0
        PSNR = 0
        SSIM = 0

        for id_img in range(0, data_size):
            # Directory part of the list entry (fixed-width suffix removed).
            UCF_index = data_list_frame1[id_img][:-12]

            # Load single data. The list entries are rewritten to the file
            # names used inside the 'ucf101_interp_ours' tree.
            batch_data_frame1 = np.array([
                dataset_frame1.process_func(
                    os.path.join('ucf101_interp_ours',
                                 data_list_frame1[id_img])[:-5] + '00.png')
            ])
            batch_data_frame2 = np.array([
                dataset_frame2.process_func(
                    os.path.join('ucf101_interp_ours',
                                 data_list_frame2[id_img])[:-5] + '01_gt.png')
            ])
            batch_data_frame3 = np.array([
                dataset_frame3.process_func(
                    os.path.join('ucf101_interp_ours',
                                 data_list_frame3[id_img])[:-5] + '02.png')
            ])
            batch_data_mask = np.array([
                dataset_frame3.process_func(
                    os.path.join('motion_masks_ucf101_interp',
                                 data_list_frame3[id_img])[:-11] +
                    'motion_mask.png')
            ])
            # Rescale the mask with (x + 1) / 2
            # (presumably maps [-1, 1] to [0, 1] -- confirm process_func).
            batch_data_mask = (batch_data_mask + 1.0) / 2.0

            feed_dict = {
                input_placeholder:
                np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                target_placeholder:
                batch_data_frame2,
            }

            # Run a forward pass. `prediction` is indexed with [0] below, so
            # model.inference returns a sequence whose first item is the
            # interpolated frame batch.
            prediction_np, target_np = sess.run(
                [prediction, target_placeholder], feed_dict=feed_dict)

            imwrite(
                'ucf101_interp_ours/' + str(UCF_index) +
                '/frame_01_CyclicGen.png', prediction_np[0][-1, :, :, :])

            mask_sum = np.sum(batch_data_mask)
            print(mask_sum)
            # Images with an empty motion mask are excluded from the metrics.
            if mask_sum > 0:
                mask = np.expand_dims(batch_data_mask[0], -1)
                img_pred_mask = mask * (prediction_np[0][-1] + 1.0) / 2.0
                img_target_mask = mask * (target_np[-1] + 1.0) / 2.0
                # Mean squared error over masked pixels, 3 channels each.
                mse = np.sum(
                    (img_pred_mask - img_target_mask)**2) / (3. * mask_sum)
                # PSNR with a peak value of 1.0.
                psnr_cur = 20.0 * np.log10(1.0) - 10.0 * np.log10(mse)

                img_pred_gray = rgb2gray((prediction_np[0][-1] + 1.0) / 2.0)
                img_target_gray = rgb2gray((target_np[-1] + 1.0) / 2.0)
                ssim_cur = ssim(
                    img_pred_gray, img_target_gray, data_range=1.0)

                PSNR += psnr_cur
                SSIM += ssim_cur

                num_evaluated += 1

        if num_evaluated == 0:
            # Avoid dividing by zero when nothing qualified for the metrics.
            print("No image with a non-empty motion mask was evaluated.")
        else:
            print("Overall PSNR: %f db" % (PSNR / num_evaluated))
            # SSIM is a unitless index, so no "db" suffix.
            print("Overall SSIM: %f" % (SSIM / num_evaluated))
print('Copy variables from % s' % ckpt_path)

#--test--#
# Collect the training images of both domains for the current dataset.
dataset_root = './Datasets/' + dataset
b_list = glob(dataset_root + '/bounding_box_train-Market/*.jpg')
a_list = glob(dataset_root + '/bounding_box_train-Duke/*.jpg')

# One output directory per translation direction.
prediction_root = './test_predictions/' + dataset + '_spgan'
b_save_dir = prediction_root + '/bounding_box_train_market2duke/'
a_save_dir = prediction_root + '/bounding_box_train_duke2market/'
utils.mkdir([a_save_dir, b_save_dir])

# Translate every image matched by the Duke glob with the a2b generator.
for a_path in a_list:
    a_real_ipt = im.imresize(im.imread(a_path), [crop_size, crop_size])
    # Add a leading batch axis of size 1 in place.
    a_real_ipt.shape = 1, crop_size, crop_size, 3
    a2b_opt = sess.run(a2b, feed_dict={a_real: a_real_ipt})
    img_name = 'market_' + os.path.basename(a_path)  # market_style
    save_path = a_save_dir + img_name
    im.imwrite(im.immerge(a2b_opt, 1, 1), save_path)
    print('Save %s' % save_path)

# Translate every image matched by the Market glob with the b2a generator.
for b_path in b_list:
    b_real_ipt = im.imresize(im.imread(b_path), [crop_size, crop_size])
    b_real_ipt.shape = 1, crop_size, crop_size, 3
    b2a_opt = sess.run(b2a, feed_dict={b_real: b_real_ipt})
    img_name = 'duke_' + os.path.basename(b_path)  # duke_style
    save_path = b_save_dir + img_name
    im.imwrite(im.immerge(b2a_opt, 1, 1), save_path)
    print('Save %s' % save_path)
def train(dataset_frame1, dataset_frame2, dataset_frame3):
    """Trains a model.

    Builds three ``tf.data`` input pipelines (one per frame), a
    ``Voxel_flow_model`` with an Adam update op, optionally restores a
    checkpoint, and runs ``FLAGS.max_steps`` training steps. Frames 1 and 3
    are concatenated along the channel axis as the network input; frame 2
    is the target.

    Each step issues exactly one ``sess.run`` so that exactly one batch is
    pulled from the iterators per step; summaries and image dumps are
    fetched in that same call and therefore describe the batch that was
    trained on.

    Args:
        dataset_frame1: Dataset object for the first input frame; must
            provide ``read_data_list_file()``.
        dataset_frame2: Dataset object (same interface) for the target frame.
        dataset_frame3: Dataset object (same interface) for the second
            input frame.
    """

    def _make_batch_iterator(data_list):
        """Build the shuffled, decoded, batched iterator for one frame."""
        # All three pipelines use the same shuffle seed and buffer size.
        # NOTE(review): alignment of the three frames relies on these
        # identical settings producing the same order -- confirm, or zip the
        # datasets before shuffling.
        frames = tf.data.Dataset.from_tensor_slices(tf.constant(data_list))
        frames = frames.repeat().shuffle(
            buffer_size=1000000, seed=1).map(_read_image)
        frames = frames.prefetch(100)
        return frames.batch(FLAGS.batch_size).make_initializable_iterator()

    with tf.Graph().as_default():
        # Create input.
        data_list_frame1 = dataset_frame1.read_data_list_file()
        batch_frame1 = _make_batch_iterator(data_list_frame1)
        batch_frame2 = _make_batch_iterator(
            dataset_frame2.read_data_list_file())
        batch_frame3 = _make_batch_iterator(
            dataset_frame3.read_data_list_file())
        iterator_inits = [
            batch_frame1.initializer,
            batch_frame2.initializer,
            batch_frame3.initializer,
        ]

        # Input and target tensors (iterator outputs, not placeholders; the
        # names are kept for continuity).
        input_placeholder = tf.concat(
            [batch_frame1.get_next(), batch_frame3.get_next()], 3)
        target_placeholder = batch_frame2.get_next()

        # Prepare model.
        model = Voxel_flow_model(is_train=True)
        prediction, flow_motion, flow_mask = model.inference(
            input_placeholder)
        reproduction_loss = model.loss(prediction, flow_motion, flow_mask,
                                       target_placeholder,
                                       FLAGS.lambda_motion, FLAGS.lambda_mask)
        total_loss = reproduction_loss

        # The learning rate is constant; no schedule is applied.
        learning_rate = FLAGS.initial_learning_rate

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(learning_rate)
        grads = opt.compute_gradients(total_loss)
        update_op = opt.apply_gradients(grads)

        # Create summaries. Each call registers its op in the graph's
        # summaries collection, which merge_all() picks up below.
        tf.summary.scalar('total_loss', total_loss)
        tf.summary.scalar('reproduction_loss', reproduction_loss)
        tf.summary.image('Input Image (before)',
                         input_placeholder[:, :, :, 0:3], 3)
        tf.summary.image('Input Image (after)',
                         input_placeholder[:, :, :, 3:6], 3)
        tf.summary.image('Output Image', prediction, 3)
        tf.summary.image('Target Image', target_placeholder, 3)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation.
        summary_op = tf.summary.merge_all()

        sess = tf.Session()

        # Restore checkpoint from file if one is configured and present,
        # otherwise initialize all variables from scratch.
        ckpt = None
        if FLAGS.pretrained_model_checkpoint_path:
            ckpt = tf.train.get_checkpoint_state(
                FLAGS.pretrained_model_checkpoint_path)
        if ckpt:
            restorer = tf.train.Saver()
            restorer.restore(sess, ckpt.model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), ckpt.model_checkpoint_path))
            sess.run(iterator_inits)
        else:
            print('No existing checkpoints.')
            init = tf.global_variables_initializer()
            sess.run([init] + iterator_inits)

        # Summary writer
        summary_writer = tf.summary.FileWriter(
            FLAGS.train_dir, graph=sess.graph)

        data_size = len(data_list_frame1)
        # Steps per pass over the data; only used for the epoch printout.
        # Clamped to 1 so a dataset smaller than one batch cannot cause a
        # division by zero.
        epoch_num = max(1, int(data_size // FLAGS.batch_size))

        for step in range(0, FLAGS.max_steps):
            batch_idx = step % epoch_num
            write_summary = step % 100 == 0
            dump_images = step % 500 == 0

            # Every sess.run that touches the iterator outputs consumes a
            # fresh batch, so everything needed for this step is fetched in
            # one call.
            fetches = {'update': update_op, 'loss': total_loss}
            if write_summary:
                fetches['summary'] = summary_op
            if dump_images:
                fetches['prediction'] = prediction
                fetches['target'] = target_placeholder

            # Run single step update.
            results = sess.run(fetches)
            loss_value = results['loss']

            if batch_idx == 0:
                print('Epoch Number: %d' % int(step / epoch_num))

            if step % 10 == 0:
                print("Loss at step %d: %f" % (step, loss_value))

            if write_summary:
                # Output Summary
                summary_writer.add_summary(results['summary'], step)

            if dump_images:
                # Dump prediction / ground truth of the current batch. File
                # names only carry the in-batch index, so each dump
                # overwrites the previous one.
                prediction_np = results['prediction']
                target_np = results['target']
                for i in range(0, prediction_np.shape[0]):
                    file_name = FLAGS.train_image_dir + str(i) + '_out.png'
                    file_name_label = (
                        FLAGS.train_image_dir + str(i) + '_gt.png')
                    imwrite(file_name, prediction_np[i, :, :, :])
                    imwrite(file_name_label, target_np[i, :, :, :])

            # Save checkpoint
            if step % 500 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
def train(dataset_objects):
    """Train the two-network SloMo frame-interpolation model.

    The first and last frames of a 9-frame sequence are the input and the
    seven frames in between are the targets. A flow-computation network and
    an interpolation network (both ``SloMo_model``) are trained jointly with
    reconstruction, perceptual (``vgg16``), visibility-constraint, warping
    and smoothness losses, each weighted by the matching ``FLAGS.lambda_*``.

    Each step issues exactly one ``sess.run`` so that exactly one batch is
    pulled from the iterators per step; summaries and image dumps are
    fetched in that same call.

    Args:
        dataset_objects: Sequence of dataset objects, one per frame position.
            Indices 0 and 8 are used as input and 1..7 as targets, so at
            least nine entries are required. Each must provide
            ``read_data_list_file()``.
    """
    with tf.Graph().as_default():
        # Read and shuffle images.
        # NOTE(review): alignment of the per-frame pipelines relies on the
        # identical shuffle seed and buffer size -- confirm.
        data_lists = [
            dataset_obj.read_data_list_file()
            for dataset_obj in dataset_objects
        ]
        dataset_frames = [
            tf.data.Dataset.from_tensor_slices(tf.constant(data_list))
            for data_list in data_lists
        ]
        dataset_frames = [
            frame.repeat().shuffle(buffer_size=int(1e5),
                                   seed=1).map(_read_image)
            for frame in dataset_frames
        ]
        dataset_frames = [frame.prefetch(100) for frame in dataset_frames]

        # 9 sets of frames in total, 1 for each frame in a 9-frame sequence.
        batch_frames = [
            frame.batch(FLAGS.batch_size).make_initializable_iterator()
            for frame in dataset_frames
        ]
        iterator_inits = [
            batch_frame.initializer for batch_frame in batch_frames
        ]

        # Grab the first and last frames for input.
        input_placeholder = tf.concat(
            [batch_frames[0].get_next(), batch_frames[8].get_next()], axis=3)

        # The middle 7 frames for ground truth, stacked along channels.
        target_placeholder = tf.concat(
            [frame.get_next() for frame in batch_frames[1:8]], axis=3)

        # The session is created before the models because vgg16 takes it
        # as a constructor argument.
        sess = tf.Session()

        # The first network (flow computation).
        computer = SloMo_model(for_interpolation=False)

        # The second network (arbitrary-time interpolation).
        interpolater = SloMo_model(for_interpolation=True)

        # VGG for perceptual loss.
        vgg_mod = vgg16(sess=sess)

        # Flow computations between the first and last frames.
        flow_01, flow_10 = computer.inference(input_placeholder)
        image_0 = input_placeholder[:, :, :, :3]
        image_1 = input_placeholder[:, :, :, 3:]

        # Warp to approximate image_0, image_1 in inputs. These do not depend
        # on t, so they are built once instead of once per intermediate frame.
        approx_img_0 = computer.warp(flow_10, image_1)
        approx_img_1 = computer.warp(flow_01, image_0)

        total_loss = 0
        pred_imgs_t = []

        # For each intermediate frame (t = 1/8 ... 7/8).
        for idx, t in enumerate(np.arange(1.0 / 8, .999999, 1.0 / 8)):
            # Intermediate flow approximation at t; paper calls this F hat.
            flow_t0_hat = t * (-(1 - t) * flow_01 + t * flow_10)
            flow_t1_hat = (1 - t) * ((1 - t) * flow_01 - t * flow_10)

            # Interpolate intermediate frame.
            interp_input = tf.concat([
                input_placeholder, approx_img_0, approx_img_1, flow_t0_hat,
                flow_t1_hat
            ], axis=3)
            flow_t0, flow_t1, vis_mask_0, vis_mask_1 = \
                interpolater.inference(interp_input)

            # Compute intermediate frame via equation (1).
            z = (1 - t) * tf.abs(vis_mask_0) + t * tf.abs(vis_mask_1)
            pred_img_t = (1 / z) * (
                (1 - t) * tf.abs(vis_mask_0) * computer.warp(
                    -flow_t0, image_0) +
                t * tf.abs(vis_mask_1) * computer.warp(-flow_t1, image_1))
            pred_imgs_t += [pred_img_t]

            # Reconstruction loss @ equation (7).
            target = target_placeholder[:, :, :, idx * 3:(idx + 1) * 3]
            loss_recons = l1_loss(pred_img_t, target)
            total_loss += FLAGS.lambda_reconstruction * loss_recons

            # Perceptual loss @ equation (8).
            phi_true = vgg_mod.inference(target)
            phi_pred = vgg_mod.inference(pred_img_t)
            loss_percept = l2_loss(phi_true, phi_pred)
            total_loss += FLAGS.lambda_perceptual * loss_percept

            # Lagrangian penalty to enforce constraint @ equation (5).
            loss_constraint = l1_loss(
                tf.abs(vis_mask_0), 1 - tf.abs(vis_mask_1))
            total_loss += FLAGS.lambda_penalty * loss_constraint

        # Warping and smoothness losses @ equations (9) and (10).
        loss_warping = l1_loss(image_0, approx_img_0) \
            + l1_loss(image_1, approx_img_1)
        loss_smooth = l1_regularizer(flow_01) + l1_regularizer(flow_10)

        # All losses.
        total_loss += FLAGS.lambda_warping * loss_warping \
            + FLAGS.lambda_smoothness * loss_smooth

        # All predicted frames stacked along channels, used for image dumps.
        # Built here, once, so the training loop never adds ops to the graph.
        prediction = tf.concat(pred_imgs_t, axis=3)

        learning_rate = FLAGS.initial_learning_rate

        # Backprop operation; collect gradient norms for tensorboard.
        opt = tf.train.AdamOptimizer(learning_rate)
        update_op = slim.learning.create_train_op(
            total_loss, opt, summarize_gradients=True)

        # Collect losses for tensorboard. loss_recons, loss_percept and
        # loss_constraint hold the values of the last intermediate frame
        # only, since they are overwritten on every loop iteration above.
        tf.summary.scalar('total_loss', total_loss)
        tf.summary.scalar('loss_recons', loss_recons)
        tf.summary.scalar('loss_percept', loss_percept)
        tf.summary.scalar('loss_constraint', loss_constraint)
        tf.summary.scalar('loss_warping', loss_warping)
        tf.summary.scalar('loss_smooth', loss_smooth)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(
            FLAGS.train_dir, graph=sess.graph)

        # Saver for periodic checkpoints.
        saver = tf.train.Saver(tf.global_variables())

        # Restore model if it exists, otherwise initialize from scratch.
        ckpt = None
        if FLAGS.pretrained_model_checkpoint_path:
            ckpt = tf.train.get_checkpoint_state(
                FLAGS.pretrained_model_checkpoint_path)
        if ckpt:
            restorer = tf.train.Saver()
            restorer.restore(sess, ckpt.model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), ckpt.model_checkpoint_path))
            sess.run(iterator_inits)
        else:
            print('No existing checkpoints.')
            # NOTE(review): if vgg16(sess=sess) loads pretrained weights into
            # global variables, this initializer would reset them -- confirm.
            init = tf.global_variables_initializer()
            sess.run([init] + iterator_inits)

        data_size = len(data_lists[0])
        # Steps per pass over the data; only used for the epoch printout.
        # Clamped to 1 so a dataset smaller than one batch cannot cause a
        # division by zero.
        epoch_num = max(1, int(data_size // FLAGS.batch_size))

        for step in range(0, FLAGS.max_steps):
            batch_idx = step % epoch_num
            write_summary = step % 200 == 0
            dump_images = step % 500 in {0, 1, 2}

            # Every sess.run that touches the iterator outputs consumes a
            # fresh batch, so everything needed for this step is fetched in
            # one call.
            fetches = {'loss': total_loss, 'update': update_op}
            if write_summary:
                fetches['summary'] = summary_op
            if dump_images:
                fetches['prediction'] = prediction
                fetches['input'] = input_placeholder
                fetches['target'] = target_placeholder
                fetches['approx_0'] = approx_img_0
                fetches['approx_1'] = approx_img_1

            results = sess.run(fetches)
            loss_value = results['loss']

            if batch_idx == 0:
                print('Epoch Number: %d' % int(step / epoch_num))

            if step % 10 == 0:
                print("Loss at step %d: %f" % (step, loss_value))

            if write_summary:
                # Save summary.
                summary_writer.add_summary(results['summary'], step)

            if dump_images:
                # Dump intermediate results for one example of the batch.
                # File names carry the step but no example index, so a single
                # example is written consistently for all nine frames.
                example_idx = 0
                prediction_np = results['prediction']
                target_np = results['target']
                input_np = results['input']
                input_1 = input_np[:, :, :, :3]
                input_2 = input_np[:, :, :, 3:]

                out_dir = FLAGS.train_image_dir
                step_tag = '_step' + str(step)

                # Warped approximations of the two inputs and the inputs
                # themselves as their ground truth. Names are built directly
                # so that an "out" inside the directory is never rewritten.
                imwrite(out_dir + 'comp_img0_out' + step_tag + '.png',
                        results['approx_0'][example_idx, :, :, :])
                imwrite(out_dir + 'comp_img1_out' + step_tag + '.png',
                        results['approx_1'][example_idx, :, :, :])
                imwrite(out_dir + 'comp_img0_gt' + step_tag + '.png',
                        input_1[example_idx, :, :, :])
                imwrite(out_dir + 'comp_img1_gt' + step_tag + '.png',
                        input_2[example_idx, :, :, :])

                # Frames 0 and 8 are inputs; they are stored under both the
                # 'out' and the 'gt' prefix to complete each sequence.
                for inputs in ['out', 'gt']:
                    imwrite(out_dir + inputs + step_tag + '_frame0' + '.png',
                            input_1[example_idx, :, :, :])
                    imwrite(out_dir + inputs + step_tag + '_frame8' + '.png',
                            input_2[example_idx, :, :, :])

                # Frames 1..7: prediction and matching ground truth.
                for frame_idx in range(prediction_np.shape[3] // 3):
                    channels = slice(frame_idx * 3, (frame_idx + 1) * 3)
                    frame_tag = '_frame' + str(frame_idx + 1) + '.png'
                    imwrite(out_dir + 'out' + step_tag + frame_tag,
                            prediction_np[example_idx, :, :, channels])
                    imwrite(out_dir + 'gt' + step_tag + frame_tag,
                            target_np[example_idx, :, :, channels])

            # Save model weights.
            if step % 200 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)