# NOTE: Imports reconstructed from the symbols used below. The standard
# library and TensorFlow imports are certain; the project-local modules
# (model_deploy, optimizer_builder, preprocessor_builder, mem_util) and the
# helpers create_training_input / create_training_model_losses are defined
# elsewhere in this project, and their import paths are assumptions.
import functools
import time

import tensorflow as tf

slim = tf.contrib.slim


# Queue-based input variant of the trainer.
def train_segmentation_model(create_model_fn,
                             create_input_fn,
                             train_config,
                             master,
                             task,
                             is_chief,
                             startup_delay_steps,
                             train_dir,
                             num_clones,
                             num_worker_replicas,
                             num_ps_tasks,
                             clone_on_cpu,
                             replica_id,
                             num_replicas,
                             max_checkpoints_to_keep,
                             save_interval_secs,
                             image_summaries,
                             log_memory=False,
                             gradient_checkpoints=None,
                             sync_bn_accross_gpu=False):
    """Builds a FastSegmentationModel and runs the training loop."""
    _, segmentation_model = create_model_fn()
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=num_worker_replicas,
        num_ps_tasks=num_ps_tasks)
    startup_delay_steps = task * startup_delay_steps
    per_clone_batch_size = train_config.batch_size // num_clones

    preprocess_fn = None
    if train_config.preprocessor_step:
        preprocess_fn = functools.partial(
            preprocessor_builder.build,
            preprocessor_config_list=train_config.preprocessor_step)

    with tf.Graph().as_default():
        # CPU of the common ps server.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

        # CPU of each worker.
        with tf.device(deploy_config.inputs_device()):
            input_queue = create_training_input(
                create_input_fn,
                preprocess_fn,
                per_clone_batch_size,
                batch_queue_capacity=train_config.batch_queue_capacity,
                batch_queue_threads=train_config.num_batch_queue_threads,
                prefetch_queue_capacity=train_config.prefetch_queue_capacity)

        # Create the model clones on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            # Note: it is assumed that any loss created by `model_fn`
            # is collected at the tf.GraphKeys.LOSSES collection.
            model_fn = functools.partial(
                create_training_model_losses,
                create_model_fn=create_model_fn,
                train_config=train_config,
                train_dir=train_dir,
                gradient_checkpoints=gradient_checkpoints)
            clones = model_deploy.create_clones(deploy_config, model_fn,
                                                [input_queue])
            first_clone_scope = deploy_config.clone_scope(0)

        if sync_bn_accross_gpu:
            # Attempt to sync BN updates across all GPUs in a tower.
            # Caution: this is very slow and might not be needed.
            update_ops = []
            for idx in range(num_clones):
                nth_clone_scope = deploy_config.clone_scope(idx)
                update_ops.extend(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                      nth_clone_scope))
        else:
            # Gather updates from the first GPU only.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Init variable to collect summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('Losses/%s' % loss.op.name, loss))

        # CPU of each worker.
        with tf.device(deploy_config.optimizer_device()):
            (training_optimizer,
             optimizer_summary_vars) = optimizer_builder.build(
                 train_config.optimizer)
            for var in optimizer_summary_vars:
                summaries.add(
                    tf.summary.scalar(var.op.name, var,
                                      family='LearningRate'))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))

        # Fine-tune from classification or segmentation checkpoints.
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        init_fn = None
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                raise ValueError('Must specify `fine_tune_checkpoint_type`.')
            tf.logging.info('Initializing %s model from checkpoint %s',
                            train_config.fine_tune_checkpoint_type,
                            train_config.fine_tune_checkpoint)
            variables_to_restore = segmentation_model.restore_map(
                fine_tune_checkpoint_type=train_config
                .fine_tune_checkpoint_type)
            init_fn = slim.assign_from_checkpoint_fn(
                train_config.fine_tune_checkpoint,
                variables_to_restore,
                ignore_missing_vars=True)

            if train_config.freeze_fine_tune_backbone:
                tf.logging.info(
                    'Freezing %s scope from checkpoint.',
                    segmentation_model.shared_feature_extractor_scope)
                non_frozen_vars = []
                for var in trainable_vars:
                    if not var.op.name.startswith(
                            segmentation_model
                            .shared_feature_extractor_scope):
                        non_frozen_vars.append(var)
                        tf.logging.info('Training variable: %s', var.op.name)
                trainable_vars = non_frozen_vars
        else:
            tf.logging.info('Not initializing the model from a checkpoint. '
                            'Initializing from scratch!')

        # TODO(@oandrien): we might want to add a gradient multiplier here
        # for the last layer if we have trouble with training.

        # CPU of the common ps server.
        with tf.device(deploy_config.optimizer_device()):
            reg_losses = (None if train_config.add_regularization_loss
                          else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=reg_losses,
                var_list=trainable_vars)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')
            summaries.add(tf.summary.scalar('Losses/TotalLoss', total_loss))

            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_op = tf.identity(total_loss, name='train_op')

        # TODO: this ideally should not be hardcoded like this;
        # there should be a way to access the prediction and GT tensors.
        if image_summaries:
            graph = tf.get_default_graph()
            pixel_scaling = max(1, 255 // 19)
            summ_first_clone_scope = (first_clone_scope + '/'
                                      if first_clone_scope else '')
            main_labels = graph.get_tensor_by_name(
                '%sSegmentationLoss/Reshape:0' % summ_first_clone_scope)
            main_preds = graph.get_tensor_by_name(
                '%sSegmentationLoss/Reshape_1:0' % summ_first_clone_scope)
            main_preds = tf.cast(main_preds * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Predictions', main_preds))
            main_labels = tf.cast(main_labels * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Groundtruths',
                                 main_labels))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones()
        # or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=True)

        # Save checkpoints regularly.
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)

        # HACK to see memory usage.
        # TODO: Clean up, pretty messy.
        def train_step_mem(sess, train_op, global_step, train_step_kwargs):
            # Run the train op with full tracing so peak memory can be
            # read back from this step's RunMetadata.
            start_time = time.time()
            run_metadata = tf.RunMetadata()
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            total_loss, np_global_step = sess.run(
                [train_op, global_step],
                options=options,
                run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    tf.logging.info(
                        'global step %d: loss = %.4f (%.3f sec/step)',
                        np_global_step, total_loss, time_elapsed)
                    if log_memory:
                        mem_use = mem_util.peak_memory(
                            run_metadata)['/gpu:0'] / 1e6
                        tf.logging.info('Memory used: %.2f MB', mem_use)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        # Main training loop.
        slim.learning.train(
            train_op,
            train_step_fn=train_step_mem,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            number_of_steps=train_config.num_steps,
            startup_delay_steps=startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            save_summaries_secs=120,
            save_interval_secs=save_interval_secs,
            saver=saver)
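

# A minimal, self-contained sketch of the RunMetadata-based memory probe
# that train_step_mem uses above. It assumes `mem_util` provides
# peak_memory(run_metadata) -> {device_name: peak_bytes}, which is how this
# file already uses it; the toy matmul graph is purely illustrative.
def _example_peak_memory_probe():
    with tf.Graph().as_default():
        x = tf.random_normal([1024, 1024])
        y = tf.matmul(x, x)
        with tf.Session() as sess:
            # Full tracing is required for per-device memory stats.
            run_metadata = tf.RunMetadata()
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            sess.run(y, options=options, run_metadata=run_metadata)
            peaks = mem_util.peak_memory(run_metadata)
            for device, peak_bytes in peaks.items():
                print('%s: %.2f MB' % (device, peak_bytes / 1e6))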
# Variant of the trainer that feeds clones from a tf.data pipeline instead
# of the slim batch queues. Note this definition shadows the one above when
# both are present in the same module.
def train_segmentation_model(create_model_fn,
                             create_input_fn,
                             train_config,
                             model_config,
                             master,
                             task,
                             is_chief,
                             startup_delay_steps,
                             train_dir,
                             num_clones,
                             num_worker_replicas,
                             num_ps_tasks,
                             clone_on_cpu,
                             replica_id,
                             num_replicas,
                             max_checkpoints_to_keep,
                             save_interval_secs,
                             image_summaries,
                             log_memory=False,
                             gradient_checkpoints=None,
                             sync_bn_accross_gpu=False):
    """Builds a SegmentationModel and runs the training loop."""
    _, segmentation_model = create_model_fn()
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=num_worker_replicas,
        num_ps_tasks=num_ps_tasks)
    startup_delay_steps = task * startup_delay_steps
    # Unlike the queue-based variant, batch_size is used per clone as-is
    # (not divided by num_clones).
    per_clone_batch_size = train_config.batch_size

    preprocess_fn = None
    if train_config.preprocessor_step:
        preprocess_fn = functools.partial(
            preprocessor_builder.build,
            preprocessor_config_list=train_config.preprocessor_step)

    with tf.Graph().as_default():
        # CPU of the common ps server.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

        # CPU of each worker.
        with tf.device(deploy_config.inputs_device()):
            dataset = create_training_input(create_input_fn, preprocess_fn,
                                            per_clone_batch_size, num_clones)
            # Skip records whose input ops raise, rather than crashing
            # the training job.
            dataset = dataset.apply(tf.data.experimental.ignore_errors())
            data_iterator = dataset.make_one_shot_iterator()

        # Create the model clones on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            # Note: it is assumed that any loss created by `model_fn`
            # is collected at the tf.GraphKeys.LOSSES collection.
            model_fn = functools.partial(
                create_training_model_losses,
                create_model_fn=create_model_fn,
                train_config=train_config,
                train_dir=train_dir,
                gradient_checkpoints=gradient_checkpoints)
            # Each clone calls `get_next` to build its own input op.
            clones = model_deploy.create_clones(deploy_config, model_fn,
                                                [data_iterator.get_next])
            first_clone_scope = deploy_config.clone_scope(0)

        if sync_bn_accross_gpu:
            # Attempt to sync BN updates across all GPUs in a tower.
            # Caution: this is very slow and might not be needed.
            update_ops = []
            for idx in range(num_clones):
                nth_clone_scope = deploy_config.clone_scope(idx)
                update_ops.extend(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                      nth_clone_scope))
        else:
            # Gather updates from the first GPU only.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Init variable to collect summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('Losses/%s' % loss.op.name, loss))

        # CPU of each worker.
        with tf.device(deploy_config.optimizer_device()):
            (training_optimizer,
             optimizer_summary_vars) = optimizer_builder.build(
                 train_config.optimizer, num_clones)
            for var in optimizer_summary_vars:
                summaries.add(
                    tf.summary.scalar(var.op.name, var,
                                      family='LearningRate'))

        # Add summaries for model variables.
        # (Disabled: histogram summaries for all model variables.)
        # for model_var in slim.get_model_variables():
        #     summaries.add(
        #         tf.summary.histogram(model_var.op.name, model_var))

        # Fine-tune from classification or segmentation checkpoints.
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                raise ValueError('Must specify `fine_tune_checkpoint_type`.')
            tf.logging.info('Initializing %s model from checkpoint %s',
                            train_config.fine_tune_checkpoint_type,
                            train_config.fine_tune_checkpoint)
            variables_to_restore = segmentation_model.restore_map(
                train_config.fine_tune_checkpoint,
                fine_tune_checkpoint_type=train_config
                .fine_tune_checkpoint_type)
            # Ensure the training directory exists for event files.
            writer = tf.summary.FileWriter(train_dir)
            writer.close()
            init_fn = slim.assign_from_checkpoint_fn(
                train_config.fine_tune_checkpoint,
                variables_to_restore,
                ignore_missing_vars=True)

            if train_config.freeze_fine_tune_backbone:
                tf.logging.info(
                    'Freezing %s scope from checkpoint.',
                    segmentation_model.shared_feature_extractor_scope)
                non_frozen_vars = []
                for var in trainable_vars:
                    if not var.op.name.startswith(
                            segmentation_model
                            .shared_feature_extractor_scope):
                        non_frozen_vars.append(var)
                        tf.logging.info('Training variable: %s', var.op.name)
                trainable_vars = non_frozen_vars
        else:
            init_fn = None
            tf.logging.info('Not initializing the model from a checkpoint. '
                            'Initializing from scratch!')

        # TODO(@oandrien): we might want to add a gradient multiplier here
        # for the last layer if we have trouble with training.

        # CPU of the common ps server.
        with tf.device(deploy_config.optimizer_device()):
            reg_losses = (None if train_config.add_regularization_loss
                          else [])
            if model_config.pspnet.train_reduce and reg_losses is None:
                regularization_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                reg_losses = [
                    r for r in regularization_losses
                    if 'dim_reduce' in r.name
                ]
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=reg_losses,
                var_list=trainable_vars)
            # total_loss = tf.check_numerics(total_loss,
            #                                'total_loss is inf or nan.')
            summaries.add(tf.summary.scalar('Losses/TotalLoss', total_loss))

            # (Disabled: per-gradient norm clipping.)
            # with tf.variable_scope('grad_clip'):
            #     grads_and_vars = [(tf.clip_by_norm(grad, 1.), var)
            #                       for grad, var in grads_and_vars]
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_op = tf.identity(total_loss, name='train_op')

        # TODO: this ideally should not be hardcoded like this;
        # there should be a way to access the prediction and GT tensors.
        if image_summaries:
            graph = tf.get_default_graph()
            pixel_scaling = max(1, 255 // 19)
            summ_first_clone_scope = (first_clone_scope + '/'
                                      if first_clone_scope else '')
            input_img = graph.get_tensor_by_name(
                '%sInputs:0' % summ_first_clone_scope)
            main_labels = graph.get_tensor_by_name(
                '%sSegmentationLoss/ScaledLabels:0' % summ_first_clone_scope)
            main_preds = graph.get_tensor_by_name(
                '%sSegmentationLoss/ScaledPreds:0' % summ_first_clone_scope)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Inputs', input_img))
            main_preds = tf.cast(
                tf.expand_dims(tf.argmax(main_preds, -1), -1) * pixel_scaling,
                tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Predictions', main_preds))
            main_labels = tf.cast(main_labels * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Groundtruths',
                                 main_labels))
        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones()
        # or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = True

        # Save checkpoints regularly.
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)

        # HACK to see memory usage.
        # TODO: Clean up, pretty messy.
        def train_step_mem(sess, train_op, global_step, train_step_kwargs):
            # Only pay for full tracing when memory logging is requested.
            start_time = time.time()
            if log_memory:
                run_metadata = tf.RunMetadata()
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            else:
                run_metadata = None
                options = None
            total_loss, np_global_step = sess.run(
                [train_op, global_step],
                options=options,
                run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    tf.logging.info(
                        'global step %d: loss = %.4f (%.3f sec/step)',
                        np_global_step, total_loss, time_elapsed)
                    if log_memory:
                        peaks = mem_util.peak_memory(run_metadata)
                        for device in peaks:
                            if '/gpu' in device:
                                tf.logging.info('Memory used (%s): %.2f MB',
                                                device, peaks[device] / 1e6)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        # Main training loop.
        slim.learning.train(
            train_op,
            train_step_fn=train_step_mem,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            number_of_steps=train_config.num_steps,
            startup_delay_steps=startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            save_summaries_secs=60,
            save_interval_secs=save_interval_secs,
            saver=saver)
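

# The tf.data variant above relies on ignore_errors() to drop records whose
# input ops fail instead of killing the job. Below is a minimal,
# self-contained sketch of that input pattern; the toy dataset and the
# division-by-zero "corrupt record" are illustrative stand-ins.
def _example_ignore_errors_pipeline():
    with tf.Graph().as_default():
        dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
        # The 0.0 element produces inf, so check_numerics raises for it,
        # standing in for a corrupt training record.
        dataset = dataset.map(
            lambda x: tf.check_numerics(1. / x, 'bad record'))
        # Silently skip elements whose ops raised an error.
        dataset = dataset.apply(tf.data.experimental.ignore_errors())
        dataset = dataset.batch(2)
        # One-shot iterator, as in the trainer; each clone would build its
        # own get_next() op.
        next_batch = dataset.make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            while True:
                try:
                    print(sess.run(next_batch))  # [1.0, 0.5] then [0.25]
                except tf.errors.OutOfRangeError:
                    break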