def test_net():
    """Build the SKETCHNET test graph, then create a Saver and a session config.

    Test data are the ``*.tfrecords`` files in ``hyper_params['dbTest']`` whose
    file name contains ``'test'``.

    NOTE(review): this definition ends right after constructing the
    ``tf.ConfigProto``. No session is opened and none of the tensors returned by
    ``test_procedure`` are evaluated, so it looks truncated. Confirm whether it
    is still needed.
    """
    logger = logging.getLogger('main.testing')
    logger.info('---Begin testing: ---')

    # Load network
    model = SKETCHNET()

    # Collect the test shards
    test_dir = hyper_params['dbTest']
    tfrecord_names = [name for name in os.listdir(test_dir)
                      if name.endswith('.tfrecords')]
    record_paths = [os.path.join(test_dir, name)
                    for name in tfrecord_names if 'test' in name]

    # test_procedure is expected to return exactly these 21 values.
    (loss, d_loss, n_loss, ds_loss, r_loss, real_dloss,
     real_nloss, omega_loss, gt_normal, f_normal, gt_depth, f_depth, gt_ds,
     gt_lines, reg_mask, f_cfmap, gt_a, gt_b, f_a, f_b,
     input_list) = test_procedure(model, record_paths)

    # Saver / session configuration (not used further in this definition).
    saver = tf.train.Saver()
    session_config = tf.ConfigProto()
def train_net():
    """Train SKETCHNET with periodic validation, checkpointing and logging.

    Data sources (from ``hyper_params``):
      * ``dbTrain``: ``*.tfrecords`` files whose name contains ``'train'``.
      * ``dbEval``: ``*.tfrecords`` files whose name contains ``'eval'``.

    Schedule, for iterations ``0 .. maxIter - 1``:
      * every ``exeValStep``: evaluate the whole eval set. The averaged losses,
        plus the full summary of one randomly chosen batch, go to TensorBoard.
      * every ``saveModelStep``: save a checkpoint to
        ``outDir + '/savedModel/my_model<step>.ckpt'``.
      * every ``dispLossStep``: write the training summary and log the loss.

    Input batches are fetched from the queue-runner pipelines with
    ``sess.run`` and fed back into the graph's named ``*_input:0`` tensors.

    Raises:
        ValueError: if no eval ``.tfrecords`` file is found in ``dbEval``.
    """
    train_logger = logging.getLogger('main.training')
    train_logger.info('---Begin training: ---')

    # Load network
    net = SKETCHNET()

    # Graph input names, in the order their values appear in the *_inputList
    # returned by train_procedure / validation_procedure.
    input_names = ('npr_input:0', 'ds_input:0', 'fLMask_input:0',
                   'fLInvMask_input:0', 'gtField_input:0', 'clIMask_input:0',
                   'shapeMask_input:0', '2dMask_input:0', 'sLMask_input:0',
                   'curvMag_input:0')

    def make_feed(batch):
        """Map a fetched input batch onto the named graph inputs."""
        return {name: batch[i] for i, name in enumerate(input_names)}

    # Training pipeline
    train_dir = hyper_params['dbTrain']
    train_records = [
        os.path.join(train_dir, item) for item in os.listdir(train_dir)
        if item.endswith('.tfrecords') and 'train' in item
    ]
    train_summary, train_step, train_loss, train_inputList = train_procedure(
        net, train_records)

    # Validation pipeline
    eval_dir = hyper_params['dbEval']
    val_records = [
        os.path.join(eval_dir, item) for item in os.listdir(eval_dir)
        if item.endswith('.tfrecords') and 'eval' in item
    ]
    if not val_records:
        raise ValueError(
            'No eval .tfrecords files found in {}'.format(eval_dir))
    # Count samples in *all* eval files, since validation_procedure reads all
    # of them. Round up so one pass covers the set without an extra
    # wrapped-around batch when the count divides evenly.
    num_eval_samples = sum(
        1 for record in val_records
        for _ in tf.python_io.tf_record_iterator(record))
    num_eval_itr = max(1, -(-num_eval_samples // hyper_params['batchSize']))

    val_proto, val_total_loss, val_data_loss, val_smooth_loss, val_inputList = \
        validation_procedure(net, val_records)
    val_loss_ops = [val_total_loss, val_data_loss, val_smooth_loss]

    # Placeholders carrying host-side averaged validation losses into TensorBoard.
    valid_loss = tf.placeholder(tf.float32, name='val_loss')
    valid_loss_proto = tf.summary.scalar('Validating_TotalLoss', valid_loss)
    valid_data_loss = tf.placeholder(tf.float32, name='val_data_loss')
    valid_data_loss_proto = tf.summary.scalar('Validating_DataL1Loss',
                                              valid_data_loss)
    valid_smooth_loss = tf.placeholder(tf.float32, name='val_smooth_loss')
    valid_smooth_loss_proto = tf.summary.scalar('Validating_SmoothL1Loss',
                                                valid_smooth_loss)
    valid_loss_merge = tf.summary.merge(
        [valid_loss_proto, valid_data_loss_proto, valid_smooth_loss_proto])

    # Saver
    tf_saver = tf.train.Saver(max_to_keep=100)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        # TF summary
        train_writer = tf.summary.FileWriter(output_folder + '/train',
                                             sess.graph)

        def run_validation(step):
            """Evaluate the full eval set once and record averaged losses."""
            summary_itr = randint(0, num_eval_itr - 1)
            totals = [0.0] * len(val_loss_ops)
            for eitr in range(num_eval_itr):
                feed = make_feed(sess.run(val_inputList))
                if eitr == summary_itr:
                    # One random batch per pass also emits the full summary.
                    results = sess.run([val_proto] + val_loss_ops,
                                       feed_dict=feed)
                    train_writer.add_summary(results[0], step)
                    losses = results[1:]
                else:
                    losses = sess.run(val_loss_ops, feed_dict=feed)
                totals = [t + l for t, l in zip(totals, losses)]
            avg_loss, avg_data_loss, avg_smooth_loss = [
                t / num_eval_itr for t in totals]
            # Feed by tensor object: robust even if TF uniquified the names.
            valid_summary = sess.run(valid_loss_merge, feed_dict={
                valid_loss: avg_loss,
                valid_data_loss: avg_data_loss,
                valid_smooth_loss: avg_smooth_loss,
            })
            train_writer.add_summary(valid_summary, step)
            train_logger.info('Validation loss at step {} is: {}'.format(
                step, avg_loss))

        try:
            sess.run(tf.global_variables_initializer())

            # Start input enqueue threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            train_logger.info('pre-load data to fill data buffer...')
            try:
                for titr in range(hyper_params['maxIter']):
                    if titr % hyper_params['exeValStep'] == 0:
                        run_validation(titr)

                    if titr % hyper_params['saveModelStep'] == 0:
                        tf_saver.save(
                            sess, hyper_params['outDir'] +
                            '/savedModel/my_model{:d}.ckpt'.format(titr))
                        train_logger.info(
                            'Save model at step: {:d}'.format(titr))

                    # Training step on a freshly fetched batch
                    train_real_input = sess.run(train_inputList)
                    t_summary, _, t_loss = sess.run(
                        [train_summary, train_step, train_loss],
                        feed_dict=make_feed(train_real_input))

                    if titr % hyper_params['dispLossStep'] == 0:
                        train_writer.add_summary(t_summary, titr)
                        train_logger.info(
                            'Training loss at step {} is: {}'.format(
                                titr, t_loss))
            finally:
                # Always stop and join the input threads, even on error.
                coord.request_stop()
                coord.join(threads)
        finally:
            # Release resource
            train_writer.close()
def test_net():
    """Evaluate the latest checkpoint on the test set and dump sample images.

    Steps:
      1. Read the ``*.tfrecords`` files whose name contains ``'test'`` from
         ``hyper_params['dbTest']``.
      2. Restore the newest checkpoint in ``hyper_params['cktDir']``, if any.
      3. Write the graph as text to ``hyper_params['outDir']`` /
         ``hyper_params['graphName']``.
      4. Run the test pipeline until it raises ``OutOfRangeError``.

    Per-case losses are logged. Depth and normal maps of the first 200 cases
    are written as float32 EXR files to ``out_img_dir``. The average total
    loss is logged at the end.
    """
    test_logger = logging.getLogger('main.testing')
    test_logger.info('---Begin testing: ---')

    # Load network
    net = SKETCHNET()

    # Testing data
    test_dir = hyper_params['dbTest']
    test_records = [
        os.path.join(test_dir, item) for item in os.listdir(test_dir)
        if item.endswith('.tfrecords') and 'test' in item
    ]
    (test_loss, test_d_loss, test_n_loss, test_real_dloss, test_real_nloss,
     test_gt_normal, test_f_normal, test_gt_depth, test_f_depth,
     test_gt_ds, test_gt_lines, test_inputList) = test_procedure(
         net, test_records)

    # Graph input names, in the order their values appear in test_inputList.
    input_names = ('npr_input:0', 'ds_input:0', 'fLMask_input:0',
                   'gtN_input:0', 'gtD_input:0', 'clIMask_input:0',
                   'shapeMask_input:0', 'dsMask_input:0', '2dMask_input:0',
                   'sLMask_input:0', 'curvMag_input:0')

    def make_feed(batch):
        """Map a fetched input batch onto the named graph inputs."""
        return {name: batch[i] for i, name in enumerate(input_names)}

    def write_exr(path, batch, swap_channels=False):
        """Write the first image of ``batch`` to ``path`` as float32 EXR."""
        img = batch[0, :, :, :]
        if swap_channels:
            # Reverse channel order; OpenCV writes BGR (presumably the
            # network emits RGB normals, TODO confirm).
            img = img[:, :, [2, 1, 0]]
        # Flip rows, then really convert: ndarray.astype returns a new array,
        # so the previous un-assigned astype calls were no-ops.
        cv2.imwrite(path, np.flip(img, 0).astype(np.float32))

    # Only the tensors actually used below are fetched.
    fetches = {
        'loss': test_loss, 'd_loss': test_d_loss, 'n_loss': test_n_loss,
        'real_dloss': test_real_dloss, 'real_nloss': test_real_nloss,
        'gt_normal': test_gt_normal, 'f_normal': test_f_normal,
        'gt_depth': test_gt_depth, 'f_depth': test_f_depth,
    }

    # Saver
    tf_saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        # initialize (local variables are needed by epoch-limited input queues)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # Restore model
        ckpt = tf.train.latest_checkpoint(hyper_params['cktDir'])
        if ckpt:
            tf_saver.restore(sess, ckpt)
            test_logger.info('restore from the checkpoint {}'.format(ckpt))
        else:
            test_logger.warning(
                'no checkpoint found in {}, testing freshly initialized '
                'weights'.format(hyper_params['cktDir']))

        # write graph:
        tf.train.write_graph(sess.graph_def, hyper_params['outDir'],
                             hyper_params['graphName'], as_text=True)
        test_logger.info('save graph tp pbtxt, done')

        # Start input enqueue threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        titr = 0
        total_loss = 0.0
        try:
            while not coord.should_stop():
                # get real input, then run the network on it
                test_real_input = sess.run(test_inputList)
                out = sess.run(fetches, feed_dict=make_feed(test_real_input))

                # Record loss
                total_loss += out['loss']
                test_logger.info(
                    'Test case {}, loss: {}, {}, {}, {}, {}, 0.0, 0.0, 0.0, 0.0'
                    .format(titr, out['loss'], out['real_dloss'],
                            out['real_nloss'], out['d_loss'], out['n_loss']))

                # Write img out
                if titr < 200:
                    suffix = str(titr) + '.exr'
                    write_exr(os.path.join(out_img_dir, 'gt_depth_' + suffix),
                              out['gt_depth'])
                    write_exr(os.path.join(out_img_dir, 'fwd_depth_' + suffix),
                              out['f_depth'])
                    write_exr(os.path.join(out_img_dir, 'gt_normal_' + suffix),
                              out['gt_normal'], swap_channels=True)
                    write_exr(os.path.join(out_img_dir, 'fwd_normal_' + suffix),
                              out['f_normal'], swap_channels=True)

                titr += 1
                if titr % 100 == 0:
                    print('Iteration: {}'.format(titr))
        except tf.errors.OutOfRangeError:
            print('Test Done.')
        finally:
            coord.request_stop()

        # Finish testing
        coord.join(threads)

        # The loop normally ends via OutOfRangeError, so the average has to be
        # computed here; NaN signals that no test case was processed.
        avg_loss = total_loss / titr if titr else float('nan')
        test_logger.info(
            'Finish test model, average loss is: {}'.format(avg_loss))
def test_net():
    """Evaluate the latest checkpoint on the test set and dump EXR outputs.

    Steps:
      1. Read the ``*.tfrecords`` files whose name contains ``'test'`` from
         ``hyper_params['dbTest']``.
      2. Restore the newest checkpoint in ``hyper_params['cktDir']``, if any.
      3. Write the graph as text to ``hyper_params['outDir']`` /
         ``hyper_params['graphName']``.
      4. Run the test pipeline until it raises ``OutOfRangeError``.

    For every case the losses are logged, and the following are written as
    float32 EXR files to ``out_img_dir``:
      * ground-truth and predicted depth,
      * ground-truth and predicted normals,
      * the predicted confidence map.

    The average total loss is logged at the end.
    """
    test_logger = logging.getLogger('main.testing')
    test_logger.info('---Begin testing: ---')

    # Load network
    net = SKETCHNET()

    # Testing data
    test_dir = hyper_params['dbTest']
    test_records = [
        os.path.join(test_dir, item) for item in os.listdir(test_dir)
        if item.endswith('.tfrecords') and 'test' in item
    ]
    # test_gt_ds, test_gt_lines and test_reg_mask are returned but unused here.
    (test_loss, test_d_loss, test_n_loss, test_ds_loss, test_r_loss,
     test_real_dloss, test_real_nloss, test_omega_loss, test_gt_normal,
     test_f_normal, test_gt_depth, test_f_depth, test_gt_ds, test_gt_lines,
     test_reg_mask, test_f_cfmap) = test_procedure(net, test_records)

    def write_exr(path, batch, swap_channels=False):
        """Write the first image of ``batch`` to ``path`` as float32 EXR."""
        img = batch[0, :, :, :]
        if swap_channels:
            # Reverse channel order; OpenCV writes BGR (presumably the
            # network emits RGB normals, TODO confirm).
            img = img[:, :, [2, 1, 0]]
        # Flip rows, then really convert: ndarray.astype returns a new array,
        # so the previous un-assigned astype calls were no-ops.
        cv2.imwrite(path, np.flip(img, 0).astype(np.float32))

    # Only the tensors actually used below are fetched.
    fetches = {
        'loss': test_loss, 'd_loss': test_d_loss, 'n_loss': test_n_loss,
        'ds_loss': test_ds_loss, 'r_loss': test_r_loss,
        'real_dloss': test_real_dloss, 'real_nloss': test_real_nloss,
        'omega_loss': test_omega_loss, 'gt_normal': test_gt_normal,
        'f_normal': test_f_normal, 'gt_depth': test_gt_depth,
        'f_depth': test_f_depth, 'f_cfmap': test_f_cfmap,
    }

    # Saver (session config kept consistent with the other train/test entry points)
    tf_saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        # initialize (local variables are needed by epoch-limited input queues)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # Restore model
        ckpt = tf.train.latest_checkpoint(hyper_params['cktDir'])
        if ckpt:
            tf_saver.restore(sess, ckpt)
            test_logger.info('restore from the checkpoint {}'.format(ckpt))
        else:
            test_logger.warning(
                'no checkpoint found in {}, testing freshly initialized '
                'weights'.format(hyper_params['cktDir']))

        # writeGraph:
        tf.train.write_graph(sess.graph_def, hyper_params['outDir'],
                             hyper_params['graphName'], as_text=True)
        test_logger.info('save graph tp pbtxt, done')

        # Start input enqueue threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        titr = 0
        total_loss = 0.0
        try:
            while not coord.should_stop():
                out = sess.run(fetches)

                # Record loss
                total_loss += out['loss']
                test_logger.info(
                    'Test case {}, loss: {}, {}, {}, {}, {}, {}, {}, 0.0, {}'.
                    format(titr, out['loss'], out['real_dloss'],
                           out['real_nloss'], out['d_loss'], out['n_loss'],
                           out['ds_loss'], out['r_loss'], out['omega_loss']))

                # Write img out (every case)
                suffix = str(titr) + '.exr'
                write_exr(os.path.join(out_img_dir, 'gt_depth_' + suffix),
                          out['gt_depth'])
                write_exr(os.path.join(out_img_dir, 'fwd_depth_' + suffix),
                          out['f_depth'])
                write_exr(os.path.join(out_img_dir, 'gt_normal_' + suffix),
                          out['gt_normal'], swap_channels=True)
                write_exr(os.path.join(out_img_dir, 'fwd_normal_' + suffix),
                          out['f_normal'], swap_channels=True)
                write_exr(os.path.join(out_img_dir, 'fwd_conf_map_' + suffix),
                          out['f_cfmap'])

                titr += 1
                if titr % 100 == 0:
                    print('Iteration: {}'.format(titr))
        except tf.errors.OutOfRangeError:
            print('Test Done.')
        finally:
            coord.request_stop()

        # Finish testing
        coord.join(threads)

        # The loop normally ends via OutOfRangeError, so the average has to be
        # computed here; NaN signals that no test case was processed.
        avg_loss = total_loss / titr if titr else float('nan')
        test_logger.info(
            'Finish test model, average loss is: {}'.format(avg_loss))
def train_net():
    """Train SKETCHNET with a scheduled regularization weight.

    Data sources (from ``hyper_params``):
      * ``dbTrain``: ``*.tfrecords`` files whose name contains ``'train'``.
      * ``dbEval``: ``*.tfrecords`` files whose name contains ``'eval'``.

    Start-up: variables are initialized, then the variables returned by
    ``validation_procedure`` as the field-network variable list are restored
    from the newest checkpoint in ``hyper_params['fCktDir']``, if one exists.

    Schedule, for iterations ``0 .. maxIter - 1``:
      * the regularization weight starts at ``regWeight`` and follows the
        schedule below; it is fed to every ``sess.run``.
      * every ``exeValStep``: evaluate the whole eval set. Eight averaged
        losses, plus the full summary of one random batch, go to TensorBoard.
      * every ``saveModelStep``: save a checkpoint to
        ``outDir + '/savedModel/my_model<step>.ckpt'``.
      * every ``dispLossStep``: write the training summary and log the loss.

    Raises:
        ValueError: if no eval ``.tfrecords`` file is found in ``dbEval``.
    """
    train_logger = logging.getLogger('main.training')
    train_logger.info('---Begin training: ---')

    # Load network
    net = SKETCHNET()

    # regularization weight (fed on every run so it can follow a schedule)
    reg_weight_value = tf.placeholder(tf.float32, name='reg_weight')

    # Training pipeline
    train_dir = hyper_params['dbTrain']
    train_records = [
        os.path.join(train_dir, item) for item in os.listdir(train_dir)
        if item.endswith('.tfrecords') and 'train' in item
    ]
    train_summary, train_step, train_loss = train_procedure(
        net, train_records, reg_weight_value)

    # Validation pipeline
    eval_dir = hyper_params['dbEval']
    val_records = [
        os.path.join(eval_dir, item) for item in os.listdir(eval_dir)
        if item.endswith('.tfrecords') and 'eval' in item
    ]
    if not val_records:
        raise ValueError(
            'No eval .tfrecords files found in {}'.format(eval_dir))
    # Count samples in *all* eval files, since validation_procedure reads all
    # of them. Round up so one pass covers the set without an extra
    # wrapped-around batch when the count divides evenly.
    num_eval_samples = sum(
        1 for record in val_records
        for _ in tf.python_io.tf_record_iterator(record))
    num_eval_itr = max(1, -(-num_eval_samples // hyper_params['batchSize']))

    (val_proto, val_loss, val_d_loss, val_n_loss, val_ds_loss, val_reg_loss,
     val_real_dloss, val_real_nloss, val_omega_loss,
     field_vars) = validation_procedure(net, val_records, reg_weight_value)

    # (placeholder name, TensorBoard tag, graph loss tensor) per averaged loss.
    val_specs = [
        ('val_loss', 'Validating_TotalLoss', val_loss),
        ('val_d_loss', 'Validating_DepthL2Loss', val_d_loss),
        ('val_n_loss', 'Validating_NormalL2Loss', val_n_loss),
        ('val_ds_loss', 'Validating_DepthSampleL2Loss', val_ds_loss),
        ('val_reg_loss', 'Validating_RegL2Loss', val_reg_loss),
        ('val_real_dloss', 'Validating_RealDLoss', val_real_dloss),
        ('val_real_nloss', 'Validating_RealNLoss', val_real_nloss),
        ('val_omega_loss', 'Validating_OmegaLoss', val_omega_loss),
    ]
    val_loss_ops = [loss_op for _, _, loss_op in val_specs]
    # Placeholders carrying host-side averages into TensorBoard scalars.
    avg_placeholders = []
    avg_protos = []
    for ph_name, tag, _ in val_specs:
        placeholder = tf.placeholder(tf.float32, name=ph_name)
        avg_placeholders.append(placeholder)
        avg_protos.append(tf.summary.scalar(tag, placeholder))
    valid_value_merge = tf.summary.merge(avg_protos)

    # Saver: full model, plus one restricted to the field-network variables.
    tf_saver = tf.train.Saver(max_to_keep=100)
    tf_field_saver = tf.train.Saver(var_list=field_vars)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Regularization-weight schedule (tuned for 2 GPUs). Every `reg_step`
    # iterations the weight is multiplied by `reg_growth` and capped at
    # `reg_cap`. Earlier settings: 4 GPUs -> (5000, 3.985),
    # 1 GPU -> (10000, 1.738).
    reg_step, reg_growth, reg_cap = 10000, 3.275, 2.0

    with tf.Session(config=config) as sess:
        # TF summary
        train_writer = tf.summary.FileWriter(output_folder + '/train',
                                             sess.graph)

        def run_validation(step, weight):
            """Evaluate the full eval set once and record averaged losses."""
            summary_itr = randint(0, num_eval_itr - 1)
            totals = [0.0] * len(val_loss_ops)
            # Feed by tensor object: robust even if TF uniquified the name.
            feed = {reg_weight_value: weight}
            for eitr in range(num_eval_itr):
                if eitr == summary_itr:
                    # One random batch per pass also emits the full summary.
                    results = sess.run([val_proto] + val_loss_ops,
                                       feed_dict=feed)
                    train_writer.add_summary(results[0], step)
                    losses = results[1:]
                else:
                    losses = sess.run(val_loss_ops, feed_dict=feed)
                totals = [t + l for t, l in zip(totals, losses)]
            averages = [t / num_eval_itr for t in totals]
            valid_summary = sess.run(
                valid_value_merge,
                feed_dict=dict(zip(avg_placeholders, averages)))
            train_writer.add_summary(valid_summary, step)
            train_logger.info('Validation loss at step {} is: {}'.format(
                step, averages[0]))

        try:
            # initialize
            sess.run(tf.global_variables_initializer())

            # load direction field network checkpoint
            f_ckpt = tf.train.latest_checkpoint(hyper_params['fCktDir'])
            if f_ckpt:
                tf_field_saver.restore(sess, f_ckpt)
                train_logger.info(
                    'restore from the checkpoint {}'.format(f_ckpt))
            else:
                train_logger.warning(
                    'no field checkpoint found in {}, field variables keep '
                    'their initial values'.format(hyper_params['fCktDir']))

            # Start input enqueue threads
            train_logger.info('pre-load data into data buffer...')
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                cur_weight = hyper_params['regWeight']
                for titr in range(hyper_params['maxIter']):
                    # update regularization weight
                    if titr > 0 and titr % reg_step == 0:
                        cur_weight = min(cur_weight * reg_growth, reg_cap)

                    if titr % hyper_params['exeValStep'] == 0:
                        run_validation(titr, cur_weight)

                    if titr % hyper_params['saveModelStep'] == 0:
                        tf_saver.save(
                            sess, hyper_params['outDir'] +
                            '/savedModel/my_model{:d}.ckpt'.format(titr))
                        train_logger.info(
                            'Save model at step: {:d}, reg weight is: {}'.
                            format(titr, cur_weight))

                    # Training
                    t_summary, _, t_loss = sess.run(
                        [train_summary, train_step, train_loss],
                        feed_dict={reg_weight_value: cur_weight})

                    # Display
                    if titr % hyper_params['dispLossStep'] == 0:
                        train_writer.add_summary(t_summary, titr)
                        train_logger.info(
                            'Training loss at step {} is: {}'.format(
                                titr, t_loss))
            finally:
                # Always stop and join the input threads, even on error.
                coord.request_stop()
                coord.join(threads)
        finally:
            # Release resource
            train_writer.close()