def train(dataset_frame1, dataset_frame2, dataset_frame3):
    """Trains a model on (frame1, frame3) -> frame2 triplets fed via feed_dict.

    Args:
      dataset_frame1: Dataset for the first frame of each triplet. Must provide
        ``read_data_list_file()`` and ``process_func(line)``.
      dataset_frame2: Dataset for the middle (target) frame; same interface.
      dataset_frame3: Dataset for the last frame; same interface.

    Raises:
      ValueError: if the frame-1 list is empty.
    """
    with tf.Graph().as_default():
        # Input: first and last frame stacked on the channel axis (2 x 3
        # channels). Target: the middle frame.
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, 256, 256, 3))

        # Prepare model.
        model = Voxel_flow_model()
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        total_loss = reproduction_loss

        # Constant learning rate; no schedule is applied.
        learning_rate = FLAGS.initial_learning_rate
        opt = tf.train.AdamOptimizer(learning_rate)
        grads = opt.compute_gradients(total_loss)
        update_op = opt.apply_gradients(grads)

        # Summaries. tf.summary.* replaces tf.scalar_summary/tf.image_summary,
        # which no longer exist in TF 1.x. Image summaries only accept 1, 3 or
        # 4 channels, so the 6-channel input is split into its two frames.
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        summaries.append(tf.summary.scalar('total_loss', total_loss))
        summaries.append(
            tf.summary.scalar('reproduction_loss', reproduction_loss))
        summaries.append(tf.summary.image(
            'Input Image (before)', input_placeholder[:, :, :, 0:3], 3))
        summaries.append(tf.summary.image(
            'Input Image (after)', input_placeholder[:, :, :, 3:6], 3))
        summaries.append(tf.summary.image('Output Image', prediction, 3))
        summaries.append(
            tf.summary.image('Target Image', target_placeholder, 3))

        saver = tf.train.Saver(tf.global_variables())
        # Built for optional summary logging; not run in the loop below.
        summary_op = tf.summary.merge_all()
        init = tf.global_variables_initializer()

        # Read the three file lists. Re-seeding before every shuffle applies
        # the same permutation to each list (given equal lengths), keeping the
        # triplets aligned.
        data_lists = []
        for dataset in (dataset_frame1, dataset_frame2, dataset_frame3):
            data_list = dataset.read_data_list_file()
            random.seed(1)
            shuffle(data_list)
            data_lists.append(data_list)
        data_list_frame1, data_list_frame2, data_list_frame3 = data_lists

        data_size = len(data_list_frame1)
        if data_size == 0:
            raise ValueError('Training list for frame 1 is empty.')
        # At least one batch per epoch, so `step % epoch_num` cannot divide by
        # zero when the list is shorter than a single batch.
        epoch_num = max(1, int(data_size / FLAGS.batch_size))

        with tf.Session() as sess:
            sess.run(init)
            # Writes the graph definition to train_dir.
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                                   graph=sess.graph)

            for step in range(FLAGS.max_steps):
                batch_idx = step % epoch_num
                if batch_idx == 0:
                    if step > 0:
                        # New epoch: reshuffle BEFORE drawing its first batch
                        # so each epoch walks a single permutation.
                        for data_list in data_lists:
                            random.seed(1)
                            shuffle(data_list)
                    print('Epoch Number: %d' % int(step / epoch_num))

                start = int(batch_idx * FLAGS.batch_size)
                stop = int((batch_idx + 1) * FLAGS.batch_size)

                # Load batch data.
                batch_data_frame1 = np.array([
                    dataset_frame1.process_func(line)
                    for line in data_list_frame1[start:stop]
                ])
                batch_data_frame2 = np.array([
                    dataset_frame2.process_func(line)
                    for line in data_list_frame2[start:stop]
                ])
                batch_data_frame3 = np.array([
                    dataset_frame3.process_func(line)
                    for line in data_list_frame3[start:stop]
                ])

                feed_dict = {
                    input_placeholder:
                        np.concatenate((batch_data_frame1, batch_data_frame3),
                                       3),
                    target_placeholder: batch_data_frame2,
                }

                # Run single step update.
                _, loss_value = sess.run([update_op, total_loss],
                                         feed_dict=feed_dict)

                if step % 10 == 0:
                    print("Loss at step %d: %f" % (step, loss_value))

                if step % 500 == 0:
                    # Dump the current batch's predictions and targets.
                    prediction_np, target_np = sess.run(
                        [prediction, target_placeholder], feed_dict=feed_dict)
                    for i in range(prediction_np.shape[0]):
                        file_name = FLAGS.train_image_dir + str(i) + '_out.png'
                        file_name_label = (FLAGS.train_image_dir + str(i) +
                                           '_gt.png')
                        imwrite(file_name, prediction_np[i, :, :, :])
                        imwrite(file_name_label, target_np[i, :, :, :])

                # Save checkpoint.
                if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

            summary_writer.close()
def test(dataset_frame1, dataset_frame2, dataset_frame3):
    """Evaluates a trained model on every (frame1, frame3) -> frame2 triplet.

    Each test triplet is appended to a fixed set of padding triplets (the
    first 63 entries of each list), so every forward pass sees a batch of up
    to 64 samples. Only the appended last row is written to disk and scored.

    Args:
      dataset_frame1: Dataset for the first frame; must provide
        ``read_data_list_file()`` and ``process_func(line)``.
      dataset_frame2: Dataset for the middle (ground-truth) frame.
      dataset_frame3: Dataset for the last frame.

    Raises:
      IOError: if ``FLAGS.pretrained_model_checkpoint_path`` is set but does
        not exist or contains no checkpoint.
      ValueError: if the frame-1 test list is empty.
    """
    with tf.Graph().as_default():
        # Input: first and last frame stacked on the channel axis.
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, 256, 256, 3))

        # Prepare model.
        model = Voxel_flow_model(is_train=True)
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        total_loss = reproduction_loss

        with tf.Session() as sess:
            # Restore checkpoint from file.
            if FLAGS.pretrained_model_checkpoint_path:
                checkpoint_dir = FLAGS.pretrained_model_checkpoint_path
                if not tf.gfile.Exists(checkpoint_dir):
                    raise IOError(
                        'Checkpoint path %s does not exist.' % checkpoint_dir)
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt is None or not ckpt.model_checkpoint_path:
                    raise IOError('No checkpoint found in %s.' % checkpoint_dir)
                restorer = tf.train.Saver()
                restorer.restore(sess, ckpt.model_checkpoint_path)
                print('%s: Pre-trained model restored from %s' %
                      (datetime.now(), ckpt.model_checkpoint_path))

            # Process on test dataset.
            data_list_frame1 = dataset_frame1.read_data_list_file()
            data_list_frame2 = dataset_frame2.read_data_list_file()
            data_list_frame3 = dataset_frame3.read_data_list_file()
            data_size = len(data_list_frame1)
            if data_size == 0:
                raise ValueError('Test list for frame 1 is empty.')

            # The padding rows are the same for every test image, so decode
            # them once instead of on every iteration.
            pad_frame1 = [dataset_frame1.process_func(ll)
                          for ll in data_list_frame1[0:63]]
            pad_frame2 = [dataset_frame2.process_func(ll)
                          for ll in data_list_frame2[0:63]]
            pad_frame3 = [dataset_frame3.process_func(ll)
                          for ll in data_list_frame3[0:63]]

            psnr_sum = 0.0
            for id_img in range(data_size):
                # Padding rows first, the image under test last.
                batch_data_frame1 = np.array(pad_frame1 + [
                    dataset_frame1.process_func(data_list_frame1[id_img])])
                batch_data_frame2 = np.array(pad_frame2 + [
                    dataset_frame2.process_func(data_list_frame2[id_img])])
                batch_data_frame3 = np.array(pad_frame3 + [
                    dataset_frame3.process_func(data_list_frame3[id_img])])

                feed_dict = {
                    input_placeholder:
                        np.concatenate((batch_data_frame1, batch_data_frame3),
                                       3),
                    target_placeholder: batch_data_frame2,
                }
                prediction_np, target_np, loss_value = sess.run(
                    [prediction, target_placeholder, total_loss],
                    feed_dict=feed_dict)
                print("Loss for image %d: %f" % (id_img, loss_value))

                file_name = FLAGS.test_image_dir + str(id_img) + '_out.png'
                file_name_label = (FLAGS.test_image_dir + str(id_img) +
                                   '_gt.png')
                imwrite(file_name, prediction_np[-1, :, :, :])
                imwrite(file_name_label, target_np[-1, :, :, :])

                # PSNR = 10 * log10(MAX^2 / MSE), computed on the scored (last)
                # image only, so padding rows do not leak into the metric.
                # NOTE(review): MAX = 255 assumes pixel values in [0, 255];
                # confirm against process_func's output range.
                mse = np.mean(np.square(prediction_np[-1] - target_np[-1]))
                psnr_sum += 10 * np.log10(255.0 * 255.0 / mse)

            print("Overall PSNR: %f db" % (psnr_sum / data_size))
def test(dataset_frame1, dataset_frame2, dataset_frame3, dataset_frame4,
         dataset_frame5, dataset_frame6, dataset_frame7, dataset_frame8,
         dataset_frame9, dataset_frame10, dataset_frame11, dataset_frame12,
         dataset_frame13, dataset_frame14, dataset_frame15, dataset_frame16,
         dataset_frame17, dataset_frame18, dataset_frame19, dataset_frame20,
         dataset_frame21, dataset_frame22, dataset_frame23, dataset_frame24,
         dataset_frame25, dataset_frame26, dataset_frame27, dataset_frame28,
         dataset_frame29, dataset_frame30, dataset_frame31, dataset_frame32,
         dataset_frame33, dataset_frame34, dataset_frame35, dataset_frame36,
         dataset_frame37, dataset_frame38, dataset_frame39, dataset_frame40,
         dataset_frame41):
    """Runs rolling multi-frame prediction on a fixed batch with a trained model.

    The first ``FLAGS.num_in`` datasets supply input frames; the next
    ``FLAGS.num_out`` datasets supply target frames. One batch is built from
    list entries 73..80 of every frame list. Then, ten times, the model
    predicts on it, the last sample's prediction is saved to
    ``FLAGS.test_image_dir``, and that sample's frame window is shifted by one
    with the prediction written into the final frame slot.

    Each ``dataset_frameK`` must provide ``read_data_list_file()`` and
    ``process_func(line)``.

    Raises:
      ValueError: if ``num_in + num_out`` exceeds the datasets passed, or a
        frame list is too short for the fixed evaluation batch.
      IOError: if the checkpoint path is set but missing or empty.
    """
    datasets = [
        dataset_frame1, dataset_frame2, dataset_frame3, dataset_frame4,
        dataset_frame5, dataset_frame6, dataset_frame7, dataset_frame8,
        dataset_frame9, dataset_frame10, dataset_frame11, dataset_frame12,
        dataset_frame13, dataset_frame14, dataset_frame15, dataset_frame16,
        dataset_frame17, dataset_frame18, dataset_frame19, dataset_frame20,
        dataset_frame21, dataset_frame22, dataset_frame23, dataset_frame24,
        dataset_frame25, dataset_frame26, dataset_frame27, dataset_frame28,
        dataset_frame29, dataset_frame30, dataset_frame31, dataset_frame32,
        dataset_frame33, dataset_frame34, dataset_frame35, dataset_frame36,
        dataset_frame37, dataset_frame38, dataset_frame39, dataset_frame40,
        dataset_frame41,
    ]
    num_in = FLAGS.num_in
    num_frames = FLAGS.num_in + FLAGS.num_out
    if num_frames > len(datasets):
        raise ValueError('num_in + num_out = %d exceeds the %d datasets passed.'
                         % (num_frames, len(datasets)))

    # Fixed evaluation window: list entries [first_sample, first_sample + 8).
    first_sample = 73
    batch_len = 8
    num_iterations = 10

    with tf.Graph().as_default():
        input_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, FLAGS.num_in * 3))
        target_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, FLAGS.num_out * 3))

        # Prepare model.
        model = Voxel_flow_model(is_train=True)
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        total_loss = reproduction_loss

        with tf.Session() as sess:
            # Restore checkpoint from file.
            if FLAGS.pretrained_model_checkpoint_path:
                checkpoint_dir = FLAGS.pretrained_model_checkpoint_path
                if not tf.gfile.Exists(checkpoint_dir):
                    raise IOError(
                        'Checkpoint path %s does not exist.' % checkpoint_dir)
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt is None or not ckpt.model_checkpoint_path:
                    raise IOError('No checkpoint found in %s.' % checkpoint_dir)
                restorer = tf.train.Saver()
                restorer.restore(sess, ckpt.model_checkpoint_path)
                print('%s: Pre-trained model restored from %s' %
                      (datetime.now(), ckpt.model_checkpoint_path))

            # Only the frames the model consumes are loaded.
            data_lists = [datasets[k].read_data_list_file()
                          for k in range(num_frames)]
            if not data_lists or any(len(dl) <= first_sample
                                     for dl in data_lists):
                raise ValueError('Every frame list needs more than %d entries.'
                                 % first_sample)

            # Each frame is decoded with ITS OWN dataset's process_func.
            batch_frames = [
                np.array([
                    datasets[k].process_func(line)
                    for line in data_lists[k][first_sample:first_sample +
                                              batch_len]
                ]) for k in range(num_frames)
            ]

            psnr_sum = 0.0
            for id_img in range(num_iterations):
                feed_dict = {
                    input_placeholder: np.concatenate(batch_frames[:num_in], 3),
                    target_placeholder: np.concatenate(batch_frames[num_in:], 3),
                }
                prediction_np, target_np, _ = sess.run(
                    [prediction, target_placeholder, total_loss],
                    feed_dict=feed_dict)

                # imsave here takes (image, filename), as elsewhere in this file.
                file_name = FLAGS.test_image_dir + str(id_img) + '_out.png'
                imsave(prediction_np[-1, :, :, :], file_name)
                print(id_img)

                # PSNR = 10 * log10(MAX^2 / MSE) on the last (rolled) sample.
                # NOTE(review): MAX = 255 assumes pixel values in [0, 255];
                # confirm against process_func's output range.
                mse = np.mean(np.square(prediction_np[-1] - target_np[-1]))
                psnr_sum += 10 * np.log10(255.0 * 255.0 / mse)

                # Roll the last sample's window forward by one frame and put
                # the prediction in the final slot. Ascending order matters:
                # slot k reads slot k+1 before slot k+1 is overwritten.
                # NOTE(review): the (ground-truth) target frame is shifted into
                # the inputs before the prediction replaces it; confirm this
                # ordering is intended.
                for k in range(num_frames - 1):
                    batch_frames[k][-1] = batch_frames[k + 1][-1]
                # Only fits when num_out == 1 (prediction has num_out*3 channels).
                batch_frames[-1][-1] = prediction_np[-1, :, :, :]

            print("Overall PSNR: %f db" % (psnr_sum / num_iterations))
def train(dataset_frame1, dataset_frame2, dataset_frame3, dataset_frame4,
          dataset_frame5, dataset_frame6, dataset_frame7, dataset_frame8):
    """Trains a multi-frame model: FLAGS.num_in inputs -> FLAGS.num_out targets.

    The first ``FLAGS.num_in`` datasets supply input frames; the next
    ``FLAGS.num_out`` supply target frames, stacked on the channel axis. Loss
    every 10 steps is collected and written to
    ``FLAGS.train_image_dir + "results.csv"`` when training finishes.

    Each ``dataset_frameK`` must provide ``read_data_list_file()`` and
    ``process_func(line)``.

    Raises:
      ValueError: if ``num_in + num_out`` exceeds the datasets passed, or the
        first frame list is empty.
    """
    datasets = [
        dataset_frame1, dataset_frame2, dataset_frame3, dataset_frame4,
        dataset_frame5, dataset_frame6, dataset_frame7, dataset_frame8,
    ]
    num_in = FLAGS.num_in
    num_frames = FLAGS.num_in + FLAGS.num_out
    if num_frames > len(datasets):
        raise ValueError('num_in + num_out = %d exceeds the %d datasets passed.'
                         % (num_frames, len(datasets)))

    with tf.Graph().as_default():
        input_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, FLAGS.num_in * 3))
        target_placeholder = tf.placeholder(
            tf.float32, shape=(None, 256, 256, FLAGS.num_out * 3))

        # Prepare model.
        model = Voxel_flow_model()
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        total_loss = reproduction_loss

        # Constant learning rate; no schedule is applied.
        learning_rate = FLAGS.initial_learning_rate
        opt = tf.train.AdamOptimizer(learning_rate)
        grads = opt.compute_gradients(total_loss)
        update_op = opt.apply_gradients(grads)

        # Summaries. Image summaries only accept 1, 3 or 4 channels, so only
        # the first frame of each stacked tensor is logged.
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        summaries.append(tf.summary.scalar('total_loss', total_loss))
        summaries.append(
            tf.summary.scalar('reproduction_loss', reproduction_loss))
        summaries.append(tf.summary.image(
            'Input Image', input_placeholder[:, :, :, 0:3], 3))
        summaries.append(tf.summary.image(
            'Output Image', prediction[:, :, :, 0:3], 3))
        summaries.append(tf.summary.image(
            'Target Image', target_placeholder[:, :, :, 0:3], 3))

        saver = tf.train.Saver(tf.global_variables())
        # Built for optional summary logging; not run in the loop below.
        summary_op = tf.summary.merge_all()
        init = tf.global_variables_initializer()

        # Re-seeding before every shuffle applies the same permutation to each
        # list (given equal lengths), keeping frames of a sample aligned.
        data_lists = []
        for dataset in datasets[:num_frames]:
            data_list = dataset.read_data_list_file()
            random.seed(1)
            shuffle(data_list)
            data_lists.append(data_list)

        data_size = len(data_lists[0]) if data_lists else 0
        if data_size == 0:
            raise ValueError('Training list for frame 1 is empty.')
        # At least one batch per epoch so `step % epoch_num` cannot divide by 0.
        epoch_num = max(1, int(data_size / FLAGS.batch_size))
        results = []  # (step, loss) pairs, logged every 10 steps.

        with tf.Session(config=tf.ConfigProto(
                log_device_placement=True)) as sess:
            sess.run(init)
            # Writes the graph definition to train_dir.
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                                   graph=sess.graph)

            for step in range(FLAGS.max_steps):
                batch_idx = step % epoch_num
                if batch_idx == 0:
                    if step > 0:
                        # New epoch: reshuffle BEFORE drawing its first batch
                        # so each epoch walks a single permutation.
                        for data_list in data_lists:
                            random.seed(1)
                            shuffle(data_list)
                    print('Epoch Number: %d' % int(step / epoch_num))

                start = int(batch_idx * FLAGS.batch_size)
                stop = int((batch_idx + 1) * FLAGS.batch_size)
                # Each frame is decoded with ITS OWN dataset's process_func.
                batch_frames = [
                    np.array([dataset.process_func(line)
                              for line in data_list[start:stop]])
                    for dataset, data_list in zip(datasets, data_lists)
                ]
                feed_dict = {
                    input_placeholder: np.concatenate(batch_frames[:num_in], 3),
                    target_placeholder: np.concatenate(batch_frames[num_in:], 3),
                }

                # Run single step update.
                _, loss_value = sess.run([update_op, total_loss],
                                         feed_dict=feed_dict)

                if step % 10 == 0:
                    print("Loss at step %d: %f" % (step, loss_value))
                    results.append((step, loss_value))

                if step % 100 == 0:
                    prediction_np, target_np = sess.run(
                        [prediction, target_placeholder], feed_dict=feed_dict)
                    try:
                        for i in range(prediction_np.shape[0]):
                            for j in range(FLAGS.num_out):
                                file_name = FLAGS.train_image_dir + str(
                                    i) + '_out_{}.png'.format(j)
                                file_name_label = FLAGS.train_image_dir + str(
                                    i) + '_gt_{}.png'.format(j)
                                # imsave here takes (image, filename).
                                imsave(prediction_np[i, :, :, j * 3:(j + 1) * 3],
                                       file_name)
                                imsave(target_np[i, :, :, j * 3:(j + 1) * 3],
                                       file_name_label)
                    except ValueError as err:
                        # Sample dumps are best-effort: report the shapes and
                        # keep training instead of aborting the loop.
                        print('Could not save sample images at step %d: %s'
                              % (step, err))
                        print(prediction_np[0, :, :, 0:3].shape)
                        print(target_np[0, :, :, 0:3].shape)

                # Save checkpoint.
                if step % 200 == 0 or (step + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

            summary_writer.close()

        np.savetxt(FLAGS.train_image_dir + "results.csv",
                   np.array(results, dtype=float).reshape(-1, 2),
                   delimiter=",",
                   header="iter,loss",
                   comments='')
def train(dataset_frame1, dataset_frame2, dataset_frame3):
    """Trains the model on (frame1, frame3) -> frame2 triplets streamed via tf.data.

    If ``FLAGS.pretrained_model_checkpoint_path`` holds a checkpoint, training
    resumes from it; otherwise all variables are freshly initialized.

    Args:
      dataset_frame1: Dataset for the first frame; must provide
        ``read_data_list_file()``.
      dataset_frame2: Dataset for the middle (target) frame.
      dataset_frame3: Dataset for the last frame.

    Raises:
      ValueError: if the lists are empty or differ in length.
    """
    with tf.Graph().as_default():
        # Create input.
        data_list_frame1 = dataset_frame1.read_data_list_file()
        data_list_frame2 = dataset_frame2.read_data_list_file()
        data_list_frame3 = dataset_frame3.read_data_list_file()
        data_size = len(data_list_frame1)
        if data_size == 0:
            raise ValueError('Training list for frame 1 is empty.')
        if not data_size == len(data_list_frame2) == len(data_list_frame3):
            raise ValueError(
                'Frame lists differ in length (%d, %d, %d); triplets would be '
                'misaligned.' % (data_size, len(data_list_frame2),
                                 len(data_list_frame3)))

        # Zip the filename lists so one shuffle moves each triplet as a unit.
        # This keeps frames aligned by construction instead of relying on three
        # independent shuffles that happen to share a seed.
        filenames = tf.data.Dataset.zip((
            tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1)),
            tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2)),
            tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3)),
        ))
        triplets = filenames.repeat().shuffle(buffer_size=1000000, seed=1).map(
            lambda name1, name2, name3: (_read_image(name1),
                                         _read_image(name2),
                                         _read_image(name3)))
        triplets = triplets.prefetch(100)
        batch_iterator = triplets.batch(
            FLAGS.batch_size).make_initializable_iterator()
        batch_frame1, batch_frame2, batch_frame3 = batch_iterator.get_next()

        # Input: first and last frame stacked on the channel axis.
        # Target: the middle frame.
        input_placeholder = tf.concat([batch_frame1, batch_frame3], 3)
        target_placeholder = batch_frame2

        # Prepare model.
        model = Voxel_flow_model(is_train=True)
        prediction, flow_motion, flow_mask = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, flow_motion, flow_mask,
                                       target_placeholder, FLAGS.lambda_motion,
                                       FLAGS.lambda_mask)
        total_loss = reproduction_loss

        # Constant learning rate; no schedule is applied.
        learning_rate = FLAGS.initial_learning_rate
        opt = tf.train.AdamOptimizer(learning_rate)
        grads = opt.compute_gradients(total_loss)
        update_op = opt.apply_gradients(grads)

        # Create summaries.
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        summaries.append(tf.summary.scalar('total_loss', total_loss))
        summaries.append(
            tf.summary.scalar('reproduction_loss', reproduction_loss))
        summaries.append(tf.summary.image(
            'Input Image (before)', input_placeholder[:, :, :, 0:3], 3))
        summaries.append(tf.summary.image(
            'Input Image (after)', input_placeholder[:, :, :, 3:6], 3))
        summaries.append(tf.summary.image('Output Image', prediction, 3))
        summaries.append(
            tf.summary.image('Target Image', target_placeholder, 3))

        saver = tf.train.Saver(tf.global_variables())
        summary_op = tf.summary.merge_all()
        init = tf.global_variables_initializer()

        ckpt = None
        if FLAGS.pretrained_model_checkpoint_path:
            ckpt = tf.train.get_checkpoint_state(
                FLAGS.pretrained_model_checkpoint_path)

        with tf.Session() as sess:
            if ckpt and ckpt.model_checkpoint_path:
                # Resume: restore variables, then (re)start the input pipeline.
                restorer = tf.train.Saver()
                restorer.restore(sess, ckpt.model_checkpoint_path)
                print('%s: Pre-trained model restored from %s' %
                      (datetime.now(), ckpt.model_checkpoint_path))
                sess.run(batch_iterator.initializer)
            else:
                print('No existing checkpoints.')
                sess.run([init, batch_iterator.initializer])

            # Summary writer (also writes the graph definition).
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                                   graph=sess.graph)

            # At least one batch per epoch so `step % epoch_num` cannot divide
            # by zero when the list is shorter than a single batch.
            epoch_num = max(1, int(data_size / FLAGS.batch_size))
            for step in range(FLAGS.max_steps):
                batch_idx = step % epoch_num

                # Run single step update.
                _, loss_value = sess.run([update_op, total_loss])

                if batch_idx == 0:
                    print('Epoch Number: %d' % int(step / epoch_num))

                if step % 10 == 0:
                    print("Loss at step %d: %f" % (step, loss_value))

                if step % 100 == 0:
                    # Each of these runs pulls a fresh batch from the iterator.
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)

                if step % 500 == 0:
                    prediction_np, target_np = sess.run(
                        [prediction, target_placeholder])
                    for i in range(prediction_np.shape[0]):
                        file_name = FLAGS.train_image_dir + str(i) + '_out.png'
                        file_name_label = (FLAGS.train_image_dir + str(i) +
                                           '_gt.png')
                        imwrite(file_name, prediction_np[i, :, :, :])
                        imwrite(file_name_label, target_np[i, :, :, :])

                # Save checkpoint.
                if step % 500 == 0 or (step + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

            summary_writer.close()