def export(model_params, checkpoint_file, config=None):
    """Export the model as a frozen graph (.graph.pb) plus a JSON file with the model parameters

    Args:
        model_params: Model parameters
        checkpoint_file: Path of the checkpoint to freeze; outputs are written next to it
        config: Unused; a Session config with graph rewriting options is built internally
    Returns:
    """
    # Input data
    batch_size = 1
    im_size = model_params.im_size
    guide_image = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
    gb_image = tf.placeholder(tf.float32, [batch_size, im_size[1], im_size[0], 1])
    input_image = tf.placeholder(tf.float32, [batch_size, im_size[1], im_size[0], 3])

    # Create model
    model_func = get_model_func(model_params.base_model)
    # Split the model into the visual modulator and the other parts;
    # the visual modulator only needs to run once per sequence
    if model_params.use_visual_modulator:
        if model_params.base_model == 'lite':
            v_m_params = visual_modulator_lite(guide_image, model_params, is_training=False)
        else:
            v_m_params = visual_modulator(guide_image, model_params, is_training=False)
    else:
        v_m_params = None
    net, end_points = model_func([guide_image, gb_image, input_image], model_params,
            visual_modulator_params=v_m_params, is_training=False)
    probabilities = tf.nn.sigmoid(net, name='prob')
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Enable graph rewriting optimizations for the exported graph
    rewrite_options = rewriter_config_pb2.RewriterConfig()
    rewrite_options.optimizers.append('pruning')
    rewrite_options.optimizers.append('constfold')
    rewrite_options.optimizers.append('layout')
    graph_options = tf.GraphOptions(rewrite_options=rewrite_options, infer_shapes=True)
    config = tf.ConfigProto(graph_options=graph_options, allow_soft_placement=True)

    # Expose the modulator parameters as named outputs so they can be computed once
    # and fed back in at inference time
    output_names = ['prob']
    if v_m_params is not None:
        for i, v_m_param in enumerate(v_m_params):
            visual_mod_name = 'visual_mod_params_%d' % (i + 1)
            tf.identity(v_m_param, name=visual_mod_name)
            output_names.append(visual_mod_name)

    # Create a saver to load the network
    saver = tf.train.Saver([v for v in tf.global_variables()]) #if '-up' not in v.name and '-cr' not in v.name])
    save_name = checkpoint_file + '.graph.pb'
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, checkpoint_file)
        if not model_params.base_model == 'lite':
            sess.run(interp_surgery(tf.global_variables()))
        # Freeze the variables into constants and serialize the graph
        output_graph_def = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_names)
        with open(save_name, 'wb') as writer:
            writer.write(output_graph_def.SerializeToString())
        model_params.output_names = output_names
        with open(save_name + '.json', 'w') as writer:
            json.dump(vars(model_params), writer)
        print('Model saved in', save_name)
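
# --- Usage sketch (not part of the original pipeline) ---
# A minimal example of how the frozen graph written by export() could be loaded for
# inference, assuming TF 1.x. It relies only on the output names defined above
# ('prob' and 'visual_mod_params_%d'); the input placeholders are NOT named explicitly
# in export(), so they keep their default names and have to be looked up from the
# imported graph (here they are simply listed for the caller to pick). The function
# name is hypothetical.
def load_frozen_graph_example(graph_pb_path):
    graph_def = tf.GraphDef()
    with open(graph_pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    # The sigmoid output exported under the name 'prob'
    prob_tensor = graph.get_tensor_by_name('prob:0')
    # Input placeholders keep default names ('Placeholder', 'Placeholder_1', ...);
    # list them so they can be mapped to guide/gb/input images by shape.
    for op in graph.get_operations():
        if op.type == 'Placeholder':
            print(op.name, op.outputs[0].shape)
    return graph, prob_tensor
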
def test(dataset, model_params, checkpoint_file, result_path, batch_size=1, config=None):
    """Test one sequence

    Args:
        dataset: Reference to a Dataset object instance
        model_params: Model parameters
        checkpoint_file: Path of the checkpoint to use for the evaluation
        result_path: Path to save the output images
        batch_size: Size of the test batch; only 1 is supported
        config: Reference to a Configuration object used in the creation of a Session
    Returns:
    """
    if config is None:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # config.log_device_placement = True
        config.allow_soft_placement = True
    tf.logging.set_verbosity(tf.logging.INFO)
    assert batch_size == 1, "only batch size equal to 1 is allowed for testing"

    # Input data
    guide_image = tf.placeholder(tf.float32, [batch_size, None, None, 3])
    gb_image = tf.placeholder(tf.float32, [batch_size, None, None, 1])
    input_image = tf.placeholder(tf.float32, [batch_size, None, None, 3])

    # Create model
    model_func = get_model_func(model_params.base_model)
    # Split the model into the visual modulator and the other parts;
    # the visual modulator only needs to run once per sequence
    if model_params.use_visual_modulator:
        if model_params.base_model == 'lite':
            v_m_params = visual_modulator_lite(guide_image, model_params, is_training=False)
        else:
            v_m_params = visual_modulator(guide_image, model_params, is_training=False)
    else:
        v_m_params = None
    net, end_points = model_func([guide_image, gb_image, input_image], model_params,
            visual_modulator_params=v_m_params, is_training=False)
    probabilities = tf.nn.sigmoid(net)
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Create a saver to load the network
    saver = tf.train.Saver([v for v in tf.global_variables()]) #if '-up' not in v.name and '-cr' not in v.name])
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, checkpoint_file)
        if not model_params.base_model == 'lite':
            sess.run(interp_surgery(tf.global_variables()))
        if not os.path.exists(result_path):
            os.makedirs(result_path)
        print('start testing process')
        time_start = time.time()
        for frame in range(dataset.get_test_size()):
            guide_images, gb_images, images, save_names = dataset.next_batch(batch_size, 'test')
            # create folder for results
            if len(save_names[0].split('/')) > 1:
                save_path = os.path.join(result_path, *(save_names[0].split('/')[:-1]))
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
            if images is None or gb_images is None:
                # First frame of a sequence: compute the visual modulator parameters once
                if model_params.use_visual_modulator:
                    curr_v_m_params = sess.run(v_m_params, feed_dict={guide_image: guide_images})
                # Create a black dummy image as the result of the first frame,
                # to be compatible with the DAVIS evaluation toolkit
                scipy.misc.imsave(os.path.join(result_path, save_names[0]),
                        np.zeros(guide_images.shape[1:3]))
            else:
                feed_dict = {gb_image: gb_images, input_image: images}
                # Reuse the cached modulator parameters for all later frames
                if model_params.use_visual_modulator:
                    if model_params.base_model == 'lite':
                        for v_m_param, curr_param in zip(v_m_params, curr_v_m_params):
                            feed_dict[v_m_param] = curr_param
                    else:
                        feed_dict[v_m_params] = curr_v_m_params
                res_all = sess.run([probabilities], feed_dict=feed_dict)
                res = res_all[0]
                if model_params.crf_postprocessing:
                    res_np = np.zeros(res.shape[:-1])
                    res_np[0] = dataset.crf_processing(dataset.images[0],
                            (res[0, :, :, 0] > 0.5).astype(np.int32))
                else:
                    res_np = res.astype(np.float32)[:, :, :, 0] > 0.5
                print('Saving ' + os.path.join(result_path, save_names[0]))
                scipy.misc.imsave(os.path.join(result_path, save_names[0]),
                        res_np[0].astype(np.float32))
                curr_score_name = save_names[0][:-4]
                if model_params.save_score:
                    print('Saving ' + os.path.join(result_path, curr_score_name) + '.npy')
                    np.save(os.path.join(result_path, curr_score_name),
                            res.astype(np.float32)[0, :, :, 0])
        time_finish = time.time()
        time_elapsed = time_finish - time_start
        print('Total time elapsed: %.3f seconds' % time_elapsed)
        print('Each frame takes %.3f seconds' % (time_elapsed / dataset.get_test_size()))
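
# --- Illustration (hypothetical, standalone toy graph) ---
# test() relies on a "compute once, feed back" pattern: the visual modulator tensors are
# evaluated a single time on the guide frame, and the cached values are then supplied via
# feed_dict for every remaining frame, which is what keeps per-frame inference cheap.
# The sketch below reproduces only that pattern with made-up tensors; it is not part of
# the OSMN model and the names used here are illustrative.
def _feed_back_pattern_demo():
    x = tf.placeholder(tf.float32, [1, 4])
    mod = tf.layers.dense(x, 2, name='toy_modulator')   # stands in for visual_modulator()
    y = tf.placeholder(tf.float32, [1, 3])
    out = tf.layers.dense(y, 3) * tf.reduce_mean(mod)   # stands in for the segmentation head
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Run the "modulator" once on the guide input ...
        cached_mod = sess.run(mod, feed_dict={x: np.ones((1, 4))})
        # ... then reuse the cached value for later frames by feeding the tensor directly
        for _ in range(3):
            sess.run(out, feed_dict={mod: cached_mod, y: np.zeros((1, 3))})
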
def train_finetune(dataset, model_params, learning_rate, logs_path, max_training_iters, save_step,
        display_step, global_step, iter_mean_grad=1, batch_size=1, resume_training=False,
        config=None, use_image_summary=True, ckpt_name="osmn"):
    """Train OSMN

    Args:
        dataset: Reference to a Dataset object instance
        model_params: Model parameters; they include the checkpoint paths used for initialization
            (vis_mod_model_path and seg_model_path to initialize the visual modulator and the
            segmentation network separately, or whole_model_path to initialize the whole network)
        learning_rate: Value for the learning rate. It can be a number or an instance of a learning rate object.
        logs_path: Path to store the checkpoints
        max_training_iters: Number of training iterations
        save_step: A checkpoint will be created every save_step iterations
        display_step: Information of the training will be displayed every display_step iterations
        global_step: Reference to a Variable that keeps track of the training steps
        iter_mean_grad: Number of gradient computations that are averaged before updating the weights
        batch_size: Size of the training batch
        resume_training: Boolean to try to restore from a previous checkpoint (True) or not (False)
        config: Reference to a Configuration object used in the creation of a Session
        use_image_summary: Boolean to use image summaries during training in TensorBoard
        ckpt_name: Checkpoint name for saving
    Returns:
    """
    model_name = os.path.join(logs_path, ckpt_name + ".ckpt")
    if config is None:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # config.log_device_placement = True
        config.allow_soft_placement = True
    tf.logging.set_verbosity(tf.logging.INFO)

    # Prepare the input data
    guide_image = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
    input_image = tf.placeholder(tf.float32, [batch_size, None, None, 3])
    gb_image = tf.placeholder(tf.float32, [batch_size, None, None, 1])
    input_label = tf.placeholder(tf.float32, [batch_size, None, None, 1])

    # Create the network
    model_func = get_model_func(model_params.base_model)
    net, end_points = model_func([guide_image, gb_image, input_image], model_params, is_training=True)

    # Define loss
    with tf.name_scope('losses'):
        main_loss = class_balanced_cross_entropy_loss(net, input_label)
        tf.summary.scalar('main_loss', main_loss)
        total_loss = main_loss + tf.add_n(tf.losses.get_regularization_losses())
        tf.summary.scalar('total_loss', total_loss)

    # Define optimization method
    with tf.name_scope('optimization'):
        tf.summary.scalar('learning_rate', learning_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        grads_and_vars = optimizer.compute_gradients(total_loss)
        # Accumulate gradients over iter_mean_grad mini-batches before applying them
        with tf.name_scope('grad_accumulator'):
            grad_accumulator = {}
            for ind in range(0, len(grads_and_vars)):
                if grads_and_vars[ind][0] is not None:
                    grad_accumulator[ind] = tf.ConditionalAccumulator(grads_and_vars[ind][0].dtype)
        with tf.name_scope('apply_gradient'):
            grad_accumulator_ops = []
            for var_ind, grad_acc in grad_accumulator.items():
                var_name = str(grads_and_vars[var_ind][1].name).split(':')[0]
                var_grad = grads_and_vars[var_ind][0]
                grad_accumulator_ops.append(grad_acc.apply_grad(var_grad, local_step=global_step))
        with tf.name_scope('take_gradients'):
            mean_grads_and_vars = []
            for var_ind, grad_acc in grad_accumulator.items():
                mean_grads_and_vars.append(
                        (grad_acc.take_grad(iter_mean_grad), grads_and_vars[var_ind][1]))
            apply_gradient_op = optimizer.apply_gradients(mean_grads_and_vars, global_step=global_step)

    # Log training info
    merged_summary_op = tf.summary.merge_all()

    # Log results on training images
    if use_image_summary:
        probabilities = tf.nn.sigmoid(net)
        input_image_orig = input_image / model_params.scale_value + model_params.mean_value
        guide_image_orig = guide_image / model_params.scale_value + model_params.mean_value
        img_summary = binary_seg_summary(input_image_orig, probabilities, gb_image, input_label)
        vg_summary = visual_guide_summary(guide_image_orig)

    # Initialize variables
    init = tf.global_variables_initializer()

    with tf.Session(config=config) as sess:
        print('Init variable')
        sess.run(init)
        tvars = tf.trainable_variables()

        # op to write logs to TensorBoard
        summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())

        # Create saver to manage checkpoints
        saver = tf.train.Saver(max_to_keep=40)

        last_ckpt_path = tf.train.latest_checkpoint(logs_path)
        if last_ckpt_path is not None and resume_training:
            # Load last checkpoint
            print('Initializing from previous checkpoint...')
            saver.restore(sess, last_ckpt_path)
            step = global_step.eval() + 1
        elif model_params.whole_model_path == '':
            print('Initializing from pre-trained imagenet model...')
            if model_params.use_visual_modulator:
                load_model(model_params.vis_mod_model_path, 'osmn/modulator')(sess)
            if model_params.seg_model_path != '':
                load_model(model_params.seg_model_path, 'osmn/seg')(sess)
            step = 1
        else:
            print('Initializing from pre-trained model...')
            load_model(model_params.whole_model_path, 'osmn')(sess)
            step = 1
        #if model_params.base_model != 'lite':
        sess.run(interp_surgery(tf.global_variables()))
        print('Weights initialized')

        print('Start training')
        while step < max_training_iters + 1:
            # Average the gradient
            for _ in range(0, iter_mean_grad):
                batch_g_image, batch_gb_image, batch_image, batch_label = \
                        dataset.next_batch(batch_size, 'train')
                run_res = sess.run([total_loss, merged_summary_op] + grad_accumulator_ops,
                        feed_dict={guide_image: batch_g_image, gb_image: batch_gb_image,
                                   input_image: batch_image, input_label: batch_label})
                batch_loss = run_res[0]
                summary = run_res[1]

            # Apply the gradients; the optimizer updates its statistics here
            sess.run(apply_gradient_op)

            # Save summary reports
            summary_writer.add_summary(summary, step)

            # Display training status
            if step % display_step == 0:
                if use_image_summary:
                    #test_g_image, test_gb_image, test_image, _ = dataset.next_batch(batch_size, 'test')
                    curr_img_summary = sess.run([img_summary, vg_summary],
                            feed_dict={guide_image: batch_g_image, gb_image: batch_gb_image,
                                       input_image: batch_image, input_label: batch_label})
                    for s in curr_img_summary:
                        summary_writer.add_summary(s, step)
                print("{} Iter {}: Training Loss = {:.4f}".format(datetime.now(), step, batch_loss),
                        file=sys.stderr)

            # Save a checkpoint
            if step % save_step == 0:
                save_path = saver.save(sess, model_name, global_step=global_step)
                print("Model saved in file: %s" % save_path)

            step += 1

        if (step - 1) % save_step != 0:
            save_path = saver.save(sess, model_name, global_step=global_step)
            print("Model saved in file: %s" % save_path)

        print('Finished training.')
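
# --- Usage sketch (hypothetical driver, not one of the original training scripts) ---
# train_finetune() expects the caller to create global_step and the learning rate tensor;
# a typical invocation could look like the following. `dataset` and `model_params` are
# assumed to be built elsewhere with the repo's Dataset class and argument parsing, and
# the step-wise learning rate decay is an assumption made for illustration.
def run_training_example(dataset, model_params, logs_path, max_training_iters=50000):
    with tf.Graph().as_default():
        global_step = tf.Variable(0, name='global_step', trainable=False)
        # Decay the learning rate by 10x after 60% of the iterations (illustrative values)
        learning_rate = tf.train.piecewise_constant(
                global_step,
                boundaries=[int(max_training_iters * 0.6)],
                values=[1e-5, 1e-6])
        train_finetune(dataset, model_params, learning_rate, logs_path,
                max_training_iters, save_step=5000, display_step=10,
                global_step=global_step, iter_mean_grad=1, batch_size=4,
                resume_training=False, use_image_summary=True, ckpt_name='osmn')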