import glob
import os
import pickle
import time
from copy import deepcopy

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

# Project-local modules (import paths assumed from the surrounding repo layout).
from nets import resnet_v2
from preprocessing import inception_preprocessing
import bn_input
import bn_models
import bn_utils


def load_vgg_imagenet(ckpt_path):
    """Initialize the network parameters from the VGG-16 pre-trained model
    provided by TF-SLIM.

    Args:
        ckpt_path: Path to the checkpoint.
    Returns:
        Function that takes a session and initializes the network.
    """
    reader = tf.train.NewCheckpointReader(ckpt_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    vars_corresp = dict()
    for v in var_to_shape_map:
        # Map each VGG-16 conv variable onto its renamed OSVOS counterpart.
        if "conv" in v:
            vars_corresp[v] = slim.get_model_variables(
                v.replace("vgg_16", "osvos"))[0]
    init_fn = slim.assign_from_checkpoint_fn(ckpt_path, vars_corresp)
    return init_fn
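
# A minimal usage sketch (illustrative only): the OSVOS-style graph must already
# be built so slim.get_model_variables() can resolve the renamed "osvos/..."
# conv variables; 'vgg_16.ckpt' stands in for the TF-SLIM VGG-16 checkpoint path.
#
#     init_fn = load_vgg_imagenet('vgg_16.ckpt')
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         init_fn(sess)  # copies the pre-trained conv weights into the graph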
def _get_init_fn():
    """Returns a function run by the chief worker to warm-start the training.

    Note that the init_fn is only run when initializing the model during the
    very first global step.

    Returns:
        An init function run by the supervisor.
    """
    if FLAGS.checkpoint_path is None:
        return None

    # Warn the user if a checkpoint exists in the train_dir. Then we'll be
    # ignoring the checkpoint anyway.
    if tf.train.latest_checkpoint(FLAGS.train_dir):
        tf.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists in %s'
            % FLAGS.train_dir)
        return None

    exclusions = []
    if FLAGS.checkpoint_exclude_scopes:
        exclusions = [scope.strip()
                      for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
        checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=FLAGS.ignore_missing_vars)
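
# A hedged sketch of how _get_init_fn() is typically wired into TF-SLIM
# training; slim.learning.train is the standard entry point, while train_op
# and FLAGS.max_number_of_steps are assumed to be defined elsewhere in this
# script (they are not part of the excerpt above):
#
#     slim.learning.train(
#         train_op,
#         logdir=FLAGS.train_dir,
#         init_fn=_get_init_fn(),  # runs once, at the very first global step
#         number_of_steps=FLAGS.max_number_of_steps)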
def resnet_process_data_dir(raw_data_dir):
    ## Setup ##################################################################
    video_list = sorted(next(os.walk(raw_data_dir))[1])

    ## Script #################################################################
    for video_name in video_list:
        pickle_out = os.path.join(raw_data_dir, video_name,
                                  'ResNet_preprocess.pk')
        if not os.path.isfile(pickle_out):
            print('ResNet preprocessing for ' + video_name)

            # Image directory info.
            img_dir = os.path.join(raw_data_dir, video_name, 'src')
            img_list = sorted(glob.glob(os.path.join(img_dir, '*')))

            # Pre-process using ResNet.
            img_size = resnet_v2.resnet_v2.default_image_size
            with tf.Graph().as_default():
                processed_images = []
                for img in img_list:
                    image = tf.image.decode_jpeg(tf.read_file(img), channels=3)
                    processed_images.append(
                        inception_preprocessing.preprocess_image(
                            image, img_size, img_size, is_training=False))
                processed_images = tf.convert_to_tensor(processed_images)
                with slim.arg_scope(resnet_v2.resnet_arg_scope()):
                    # No classifier head: returns the 2048-d ResNet features.
                    logits, _ = resnet_v2.resnet_v2_50(processed_images,
                                                       num_classes=None,
                                                       is_training=False)
                init_fn = slim.assign_from_checkpoint_fn(
                    './methods/annotate_suggest/ResNet/resnet_v2_50.ckpt',
                    slim.get_variables_to_restore())
                with tf.Session() as sess:
                    init_fn(sess)
                    resnet_vectors = sess.run(logits)
                    # Squeeze the [N, 1, 1, 2048] output down to [N, 2048].
                    resnet_vectors = resnet_vectors[:, 0, 0, :]

            # Save preprocessed data to pickle file.
            pickle_data = {'frame_resnet_vectors': resnet_vectors}
            with open(pickle_out, 'wb') as f:
                pickle.dump(pickle_data, f)
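
# Expected on-disk layout (inferred from the paths above): each video lives at
# raw_data_dir/<video_name>/src/ as individual JPEG frames, and the function
# writes raw_data_dir/<video_name>/ResNet_preprocess.pk next to it. The
# directory name below is illustrative:
#
#     resnet_process_data_dir('./data/my_videos')
#     # -> ./data/my_videos/<video>/ResNet_preprocess.pk, a pickle holding
#     #    {'frame_resnet_vectors': float32 array of shape [num_frames, 2048]}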
def BubbleNets_sort(raw_data_dir, model='BNLF'):
    # Sorting parameters.
    n_frames = 5
    n_ref = n_frames - 2
    n_batch = 5
    n_sorts = 1

    # Generate video list for frame selection.
    video_list = sorted(next(os.walk(raw_data_dir))[1])

    # Prepare the tf input data.
    tf.logging.set_verbosity(tf.logging.INFO)
    input_vector = tf.placeholder(tf.float32, [None, (2048 + 1) * n_frames])
    input_label = tf.placeholder(tf.float32, [None, 1])

    # Select network model.
    if model == 'BNLF':
        ckpt_filename = './methods/annotate_suggest/BubbleNets/BNLF_181030.ckpt-10000000'
        predict, end_pts = bn_models.BNLF(input_vector, is_training=False,
                                          n_frames=n_frames)
    else:
        ckpt_filename = './methods/annotate_suggest/BubbleNets/BN0_181029.ckpt-10000000'
        predict, end_pts = bn_models.BN0(input_vector, is_training=False,
                                         n_frames=n_frames)

    # Initialize network and select frame.
    init = tf.global_variables_initializer()
    tic = time.time()
    with tf.Session() as sess:
        init_fn = slim.assign_from_checkpoint_fn(
            ckpt_filename, slim.get_variables_to_restore())
        init_fn(sess)

        # Go through each video in list.
        for j, vid_name in enumerate(video_list):
            # Check if sort selection has already been made.
            select_dir = os.path.join(raw_data_dir, vid_name, 'frame_selection')
            if not os.path.isdir(select_dir):
                os.makedirs(select_dir)
            text_out = os.path.join(select_dir, '%s.txt' % model)
            if os.path.isfile(text_out):
                print('%s already has %s frame selection!' % (vid_name, model))
                continue
            print('\nRunning BubbleNets %s for video %i %s'
                  % (model, j, vid_name))

            # Load ResNet vectors for network input.
            vector_file = os.path.join(raw_data_dir, vid_name,
                                       'ResNet_preprocess.pk')
            input_data = bn_input.BN_Input(vector_file, n_ref=n_ref)
            num_frames = input_data.n_frames
            # Use a list (not a range) so elements can be swapped in place.
            rank_bn = list(range(num_frames))

            # BubbleNets Deep Sort: bubble-sort-style passes where the network
            # predicts which of two frames to carry forward.
            bubble_step = 1
            while bubble_step < num_frames * n_sorts:
                a = deepcopy(rank_bn[0])
                for i in range(1, num_frames):
                    b = deepcopy(rank_bn[i])
                    batch_vector = input_data.video_batch_n_ref_no_label(
                        a, b, batch=n_batch)
                    frame_select = sess.run(
                        predict, feed_dict={input_vector: batch_vector})
                    # If frame b is preferred, use frame b for next comparison.
                    if np.mean(frame_select[0]) < 0:
                        rank_bn[i - 1] = a
                        rank_bn[i] = b
                        a = deepcopy(b)
                    else:
                        rank_bn[i - 1] = b
                        rank_bn[i] = a
                bubble_step += 1

            # Write out frame selection to text file.
            select_idx = rank_bn[-1]
            img_file = os.path.basename(sorted(glob.glob(os.path.join(
                raw_data_dir, vid_name, 'src', '*')))[select_idx])
            statements = [model, '\n', str(select_idx), '\n', img_file, '\n']
            bn_utils.print_statements(text_out, statements)

    tf.reset_default_graph()
    toc = time.time()
    print('Finished selecting all %s frames on list!' % model)
    print('Runtime is ' + str(toc - tic))
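
# End-to-end sketch (the directory name is illustrative): preprocess the frames
# first, then run the sort. For each video, the selected frame's index and
# filename are written, along with the model name, to
# raw_data_dir/<video>/frame_selection/<model>.txt.
#
#     resnet_process_data_dir('./data/my_videos')
#     BubbleNets_sort('./data/my_videos', model='BNLF')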