def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
                       losses.
      train_config: a train_pb2.TrainConfig protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
    """
    detection_model = create_model_fn()
    # Build the configured augmentation steps (applied inside the input queue).
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.create_global_step()

        # Batched, prefetched input pipeline; batch size is split across clones.
        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                train_config.batch_size // num_clones, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        # Optionally wrap the optimizer for synchronous distributed training.
        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            # Only restore variables that actually exist in the checkpoint.
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.
                from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            # Fail fast if the loss diverges.
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            # train_tensor evaluates the loss only after all updates have run.
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        # Grow GPU memory as needed instead of reserving it all up front.
        session_config.gpu_options.allow_growth = True

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Hand the assembled graph to slim's training loop.
        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
def general_train(make_loss, hparams, make_hooks=None):
    """Trains a general model with a loss.

    Args:
      make_loss: Function which creates loss (and possibly registers accuracy
        summaries and other features).
      hparams: Hyperparameters (see default_hparams() for details).
      make_hooks: Optional, function which creates additional hooks for
        training.

    Returns:
      Final loss.

    Raises:
      ValueError: If flags are missing or invalid.
    """
    # NOTE(review): despite the docstring, this function has no return
    # statement and therefore returns None — confirm intended contract.
    train_dir = mode_dir('train')
    if not tf.gfile.Exists(train_dir):
        tf.gfile.MakeDirs(train_dir)
    if hparams.seed:
        tf.set_random_seed(hparams.seed)

    # Configure keras
    keras.backend.set_learning_phase(1)
    keras.backend.manual_variable_initialization(True)

    with tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks,
                                           merge_devices=True)):
        # Set the caching device to prevent hangs during distributed training
        vs = tf.get_variable_scope()
        if vs.caching_device is None:
            vs.set_caching_device(lambda op: op.device)

        # Grab loss and global step
        total_loss = make_loss()
        global_step = slim.get_or_create_global_step()

        # Set up Polyak averaging if desired
        if hparams.use_averages:
            moving_average_variables = tf.trainable_variables()
            moving_average_variables.extend(slim.losses.get_losses())
            moving_average_variables.append(total_loss)
            variable_averages = tf.train.ExponentialMovingAverage(
                hparams.moving_average_decay, global_step)
            # For sync_replicas, averaging happens in the chief queue runner
            if not hparams.sync_replicas:
                tf.add_to_collection(
                    tf.GraphKeys.UPDATE_OPS,
                    variable_averages.apply(moving_average_variables))
        else:
            variable_averages = None
            moving_average_variables = None

        # Decay learning rate exponentially
        learning_rate = tf.train.exponential_decay(
            hparams.learning_rate,
            global_step,
            hparams.decay_steps,
            hparams.learning_rate_decay_factor,
            staircase=True)
        tf.contrib.deprecated.scalar_summary('learning rate', learning_rate)

        # Create optimizer
        if hparams.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1e-3)
        elif hparams.optimizer == 'rmsprop':
            optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                                  decay=0.9,
                                                  momentum=0.9,
                                                  epsilon=1e-5)
        else:
            raise ValueError('Unknown optimizer %s' % hparams.optimizer)

        is_chief = FLAGS.task == 0
        chief_only_hooks = []
        # Standard hooks: step/loss logging, NaN guard, step-count stop.
        hooks = [
            tf.train.LoggingTensorHook(
                {
                    'global_step': global_step,
                    'total_loss': total_loss
                },
                every_n_iter=FLAGS.log_every_n_iter),
            tf.train.NanTensorHook(total_loss),
            tf.train.StopAtStepHook(hparams.max_steps),
        ]
        if make_hooks is not None:
            hooks.extend(make_hooks())

        # If desired, optimize synchronously
        if hparams.sync_replicas:
            optimizer = tf.SyncReplicasOptimizer(
                optimizer=optimizer,
                replicas_to_aggregate=FLAGS.worker_replicas -
                hparams.backup_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=FLAGS.task,
                total_num_replicas=FLAGS.worker_replicas)
            sync_replicas_hook = optimizer.make_session_run_hook(is_chief)
            hooks.append(sync_replicas_hook)

        # Train
        train_tensor = slim.learning.create_train_op(
            total_loss,
            optimizer,
            clip_gradient_norm=hparams.gradient_clipping_norm)
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)
        scaffold = tf.train.Scaffold(saver=saver)
        # Summaries are saved either on a time or a step interval, not both.
        if FLAGS.save_summaries_secs > 0:
            save_summaries_secs = FLAGS.save_summaries_secs
            save_summaries_steps = None
        else:
            save_summaries_steps = FLAGS.save_summaries_steps
            save_summaries_secs = None
        with tf.train.MonitoredTrainingSession(
                master=FLAGS.master,
                is_chief=is_chief,
                hooks=hooks,
                chief_only_hooks=chief_only_hooks,
                checkpoint_dir=train_dir,
                scaffold=scaffold,
                save_checkpoint_secs=FLAGS.save_checkpoint_secs,
                save_summaries_secs=save_summaries_secs,
                save_summaries_steps=save_summaries_steps) as mon_sess:
            # Run until a hook (NaN, StopAtStep, ...) requests a stop.
            while not mon_sess.should_stop():
                mon_sess.run(train_tensor)
def resnet_model_fn(features, labels, mode, params):
    """Estimator model function for the ResNet bird classifier.

    Args:
      features: dict of input tensors; reads features['feature'] (image batch).
      labels: dict of label tensors; reads labels['label'] (integer classes).
      mode: a tf.estimator.ModeKeys value (TRAIN / EVAL / PREDICT).
      params: hyperparameter dict; reads 'batch_size', 'label_smoothing'
        and 'momentum'.

    Returns:
      A tf.estimator.EstimatorSpec with loss, train_op (TRAIN only) and
      eval metrics (EVAL only).
    """
    global_step = tf.train.get_global_step()
    feature = features['feature']
    labels = labels['label']
    one_hot_labels = model_utils.get_label(labels,
                                           params,
                                           bird_num_classes,
                                           batch_size=params['batch_size'])

    def get_logits():
        """Return the logits."""
        end_points, aux_logits = None, None
        if FLAGS.model_type == 'resnet':
            avg_pool = model.resnet_v1_model(feature, labels, mode, params)
        else:
            # Only the ResNet backbone is wired up in this model_fn.
            assert False
        name = 'final_dense_dst'
        with tf.variable_scope('target_CLS'):
            logits = tf.layers.dense(
                inputs=avg_pool,
                units=bird_num_classes,
                kernel_initializer=tf.random_normal_initializer(stddev=.01),
                name=name)
            if end_points is not None:
                # Auxiliary classification head; only built when the backbone
                # exposes end points (never for the ResNet path above).
                aux_pool = end_points['AuxLogits_Pool']
                aux_logits = tf.layers.dense(
                    inputs=aux_pool,
                    units=bird_num_classes,
                    kernel_initializer=tf.random_normal_initializer(
                        stddev=.001),
                    name='Aux{}'.format(name))
        return logits, aux_logits, end_points

    logits, _, _ = get_logits()
    logits = tf.cast(logits, tf.float32)

    if FLAGS.model_type == 'resnet':
        # Label-smoothed cross-entropy plus L2 weight decay; batch-norm
        # parameters are excluded from the decay term.
        dst_loss = tf.losses.softmax_cross_entropy(
            logits=logits,
            weights=1.,
            onehot_labels=one_hot_labels,
            label_smoothing=params['label_smoothing'])
        dst_l2_loss = FLAGS.weight_decay * tf.add_n([
            tf.nn.l2_loss(v) for v in tf.trainable_variables()
            if 'batch_normalization' not in v.name
        ])
        loss = dst_loss + dst_l2_loss

    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        cur_finetune_step = tf.train.get_global_step()
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # BUG FIX: the original branched on FLAGS.model_type but both
            # branches were identical; collapsed to a single assignment.
            finetune_learning_rate = rampcosine()
            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(
                    learning_rate=finetune_learning_rate,
                    momentum=params['momentum'],
                    use_nesterov=True)
            elif FLAGS.optimizer == 'RMS':
                optimizer = tf.train.RMSPropOptimizer(
                    finetune_learning_rate,
                    RMSPROP_DECAY,
                    momentum=RMSPROP_MOMENTUM,
                    epsilon=RMSPROP_EPSILON)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(finetune_learning_rate)
            # NOTE(review): SyncReplicasOptimizer is applied unconditionally
            # and aggregates FLAGS.sync_replicas replicas — confirm the flag
            # really carries a replica count and that the sync hook is
            # registered by the caller.
            optimizer = tf.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=FLAGS.sync_replicas,
                total_num_replicas=run_config.num_worker_replicas)
            # BUG FIX: removed the dead
            #   train_op = tf.contrib.training.create_train_op(loss, optimizer)
            # whose result was immediately overwritten below; it added a
            # duplicate, never-run set of gradient ops to the graph.
            with tf.variable_scope('finetune'):
                train_op = optimizer.minimize(loss, cur_finetune_step)
        if FLAGS.moving_average:
            # Maintain EMA shadows of all trainable (and already-averaged)
            # variables; the EMA update becomes the train op.
            ema = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY,
                                                    num_updates=global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            with tf.control_dependencies([train_op]):
                with tf.name_scope('moving_average'):
                    train_op = ema.apply(variables_to_average)

    batch_size = params['batch_size']  # pylint: disable=unused-variable

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = model_utils.metric_fn(labels, logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Summaries are created after the train op so they observe its value.
        with tf.control_dependencies([train_op]):
            tf.summary.scalar('classifier/finetune_loss', loss)
            tf.summary.scalar('classifier/finetune_lr',
                              finetune_learning_rate)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metrics,
    )
def main(unused_argv):
    """Builds the skip-thoughts training graph and runs the training loop."""
    # Make sure the checkpoint/summary directory exists up front.
    if not tf.gfile.IsDirectory(FLAGS.train_dir):
        tf.logging.info("Creating training directory: %s", FLAGS.train_dir)
        tf.gfile.MakeDirs(FLAGS.train_dir)

    # Model configuration, with optional JSON overrides from the command line.
    model_config = configuration.model_config(
        input_file_pattern=FLAGS.input_file_pattern)
    if FLAGS.model_config_overrides:
        model_config.parse_json(FLAGS.model_config_overrides)
    _log_config(model_config, "model_config")

    # Training configuration, with optional JSON overrides.
    training_config = configuration.training_config()
    if FLAGS.training_config_overrides:
        training_config.parse_json(FLAGS.training_config_overrides)
    _log_config(training_config, "training_config")

    tf.logging.info("Building training graph.")
    graph = tf.Graph()
    device_fn = tf.train.replica_device_setter(FLAGS.ps_tasks)
    with graph.as_default(), graph.device(device_fn):
        # Construct the model and log where its variables were placed.
        seq_model = skip_thoughts_model.SkipThoughtsModel(model_config,
                                                          mode="train")
        seq_model.build()
        _log_variable_device_placement()

        session_hooks = [
            # Abort training on a NaN loss.
            tf.train.NanTensorHook(seq_model.total_loss),
            # Log step and loss on every iteration.
            tf.train.LoggingTensorHook(
                {
                    "global_step": seq_model.global_step,
                    "total_loss": seq_model.total_loss
                },
                every_n_iter=1)
        ]

        # Learning-rate schedule and optimizer.
        learning_rate = training.create_learning_rate(training_config,
                                                      seq_model.global_step)
        optimizer = training.create_optimizer(training_config, learning_rate)

        # Distributed setup: synchronous aggregation or async with a
        # startup delay for non-chief workers.
        is_chief = FLAGS.task == 0
        if FLAGS.sync_replicas:
            optimizer = tf.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.total_num_replicas)
            session_hooks.append(optimizer.make_session_run_hook(is_chief))
        elif not is_chief and training_config.startup_delay_steps:
            session_hooks.append(
                tf.train.GlobalStepWaiterHook(
                    training_config.startup_delay_steps))

        train_tensor = training.create_train_op(training_config, optimizer,
                                                seq_model)

        # Checkpointing policy.
        checkpoint_saver = tf.train.Saver(
            max_to_keep=training_config.max_checkpoints_to_keep,
            keep_checkpoint_every_n_hours=(
                training_config.keep_checkpoint_every_n_hours),
            save_relative_paths=True)
        scaffold = tf.train.Scaffold(saver=checkpoint_saver)

        # Optional hard step limit.
        if training_config.number_of_steps:
            session_hooks.append(
                tf.train.StopAtStepHook(
                    last_step=training_config.number_of_steps))

        # MonitoredTrainingSession handles checkpoints, summaries and
        # recovery; loop until a hook asks us to stop.
        with tf.train.MonitoredTrainingSession(
                master=FLAGS.master,
                is_chief=is_chief,
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=session_hooks,
                save_checkpoint_secs=training_config.save_model_secs,
                save_summaries_steps=None,
                save_summaries_secs=training_config.save_summaries_secs
        ) as session:
            while not session.should_stop():
                session.run(train_tensor)
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
                       losses.
      train_config: a train_pb2.TrainConfig protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
    """
    detection_model = create_model_fn()
    # Build the configured augmentation steps (e.g. random_horizontal_flip,
    # ssd_random_crop) to be applied inside the input queue.
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Input queue producing batched images/boxes/targets; the configured
        # batch size is split evenly across clones.
        with tf.device(deploy_config.inputs_device()):
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones,
                create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        # Loss function for each clone; needs the model factory as argument.
        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            # Builds the configured optimizer (e.g. RMSProp, Adam).
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        # Optionally wrap the optimizer for synchronous replica training.
        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            # Restores feature-extractor weights from the checkpoint.
            init_fn = detection_model.restore_fn(
                train_config.fine_tune_checkpoint,
                from_detection_checkpoint=train_config.
                from_detection_checkpoint)

        with tf.device(deploy_config.optimizer_device()):
            # Total loss plus (gradient, variable) pairs across all clones.
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            # BUG FIX: removed two debug print statements here; the second one
            # referenced an undefined name (`grad_and_vars` instead of
            # `grads_and_vars`) and raised NameError whenever
            # train_config.freeze_variables was set.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            # train_tensor yields the loss only after all updates have run.
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Run the training loop via slim's helper.
        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          num_examples,
          total_configs,
          model_config,
          is_first_training=True):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
                       losses.
      train_config: a train_pb2.TrainConfig protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
      num_examples: The number of examples in dataset for training.
      total_configs: config list
      model_config: model configuration; its `mtl` field drives the
        multi-task (window/closeness/edgemask) behavior below.
      is_first_training: if False, the global step is re-created from the
        step number embedded in the fine-tune checkpoint filename.
    """
    detection_model = create_model_fn()
    # Build the configured augmentation steps (applied in the input queue).
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            if is_first_training:
                global_step = slim.create_global_step()
            else:
                # Resuming: recover the step count from the checkpoint name
                # (assumes a '<name>-<step>' checkpoint path — TODO confirm).
                prev_global_step = int(
                    train_config.fine_tune_checkpoint.split('-')[-1])
                global_step = variable_scope.get_variable(
                    ops.GraphKeys.GLOBAL_STEP,
                    dtype=dtypes.int64,
                    initializer=tf.constant(prev_global_step,
                                            dtype=dtypes.int64),
                    trainable=False,
                    collections=[
                        ops.GraphKeys.GLOBAL_VARIABLES,
                        ops.GraphKeys.GLOBAL_STEP
                    ])

        # Input queue; batch size is split evenly across clones.
        with tf.device(deploy_config.inputs_device()):
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones,
                create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options,
                ignore_options=train_config.ignore_options,
                mtl_window=model_config.mtl.window,
                mtl_edgemask=model_config.mtl.edgemask)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        kwargs = {}
        kwargs['mtl'] = model_config.mtl
        update_schedule = None
        model_fn = functools.partial(
            _create_losses,
            create_model_fn=create_model_fn,
            show_image_summary=train_config.show_image_summary,
            update_schedule=update_schedule,
            **kwargs)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        # Optionally wrap the optimizer for synchronous replica training.
        sync_optimizer = None
        if train_config.sync_replicas:
            # TODO: support syncrhonous update for manual loss update
            training_optimizer = tf.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            # Restrict the restore map to variables actually present in the
            # checkpoint.
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.
                from_detection_checkpoint,
                restore_box_predictor=train_config.restore_box_predictor,
                restore_window=train_config.restore_window,
                restore_edgemask=train_config.restore_edgemask,
                restore_closeness=train_config.restore_closeness,
                restore_mtl_refine=train_config.restore_mtl_refine,
            )
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            mtl = model_config.mtl
            mtl_init_saver_list = []

            def _get_mtl_init_saver(scope_name):
                # Saver restoring one MTL predictor scope from a
                # classification checkpoint; returns None if nothing matches.
                # NOTE(review): _var_map.iteritems() is Python-2-only — this
                # function breaks under Python 3 (use .items()); confirm the
                # target interpreter.
                _var_map = detection_model._feature_extractor.mtl_restore_from_classification_checkpoint_fn(
                    scope_name)
                if train_config.from_detection_checkpoint:
                    # Re-key variables under the second-stage feature
                    # extractor scope.
                    _var_map_new = dict()
                    for name, val in _var_map.iteritems():
                        _var_map_new[detection_model.
                                     second_stage_feature_extractor_scope +
                                     '/' + name] = val
                    _var_map = _var_map_new
                _available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        _var_map, train_config.fine_tune_checkpoint))
                if _available_var_map:
                    return tf.train.Saver(_available_var_map)
                else:
                    return None

            # if mtl.share_second_stage_init and mtl.shared_feature == 'proposal_feature_maps':
            if mtl.share_second_stage_init and train_config.from_detection_checkpoint == False:
                if mtl.window:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.window_box_predictor_scope))
                if mtl.closeness:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.closeness_box_predictor_scope))
                if mtl.edgemask:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.edgemask_predictor_scope))

            def initializer_fn(sess):
                # Restore the main map first, then each non-empty MTL saver.
                init_saver.restore(sess, train_config.fine_tune_checkpoint)
                for mtl_init_saver in mtl_init_saver_list:
                    if not mtl_init_saver == None:
                        mtl_init_saver.restore(
                            sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        def _get_trainable_variables(except_scopes=None):
            # Trainable variables, minus any whose name matches a scope in
            # except_scopes.
            trainable_variables = tf.trainable_variables()
            if except_scopes is None:
                return trainable_variables
            for var in tf.trainable_variables():
                if any([scope in var.name for scope in except_scopes]):
                    trainable_variables.remove(var)
            return trainable_variables

        def _get_update_ops(except_scopes=None):
            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)
            if except_scopes is None:
                return update_ops
            for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                         first_clone_scope):
                if any([scope in var.name for scope in except_scopes]):
                    update_ops.remove(var)
            return update_ops

        with tf.device(deploy_config.optimizer_device()):

            def _single_update():
                # Builds one combined gradient-update step over all clones
                # and returns the train tensor (total loss after updates).
                kwargs = {}
                _training_optimizer = training_optimizer
                kwargs['var_list'] = None
                update_ops = _get_update_ops()
                total_loss, grads_and_vars = model_deploy.optimize_clones(
                    clones,
                    _training_optimizer,
                    regularization_losses=None,
                    **kwargs)

                # Optionaly multiply gradients by train_config.{grad_multiplier,
                # divide_grad_by_batch}.
                if train_config.grad_multiplier or train_config.divide_grad_by_batch:
                    base_multiplier = train_config.grad_multiplier \
                        if train_config.grad_multiplier else 1.0
                    batch_divider = float(train_config.batch_size) \
                        if train_config.divide_grad_by_batch else 1.0
                    total_multiplier = base_multiplier / batch_divider
                    grads_and_vars = variables_helper.multiply_gradients_by_scalar_multiplier(
                        grads_and_vars, multiplier=total_multiplier)

                # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
                if train_config.bias_grad_multiplier:
                    biases_regex_list = ['.*/biases']
                    grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                        grads_and_vars,
                        biases_regex_list,
                        multiplier=train_config.bias_grad_multiplier)

                # Optionally freeze some layers by setting their gradients to be zero.
                if train_config.freeze_variables:
                    grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                        grads_and_vars, train_config.freeze_variables)

                # Optionally clip gradients
                if train_config.gradient_clipping_by_norm > 0:
                    with tf.name_scope('clip_grads'):
                        grads_and_vars = slim.learning.clip_gradient_norms(
                            grads_and_vars,
                            train_config.gradient_clipping_by_norm)

                # Create gradient updates.
                grad_updates = _training_optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                # update_ops.append(grad_updates)
                total_update_ops = update_ops + [grad_updates]

                update_op = tf.group(*total_update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name=('train_op'))
                return train_tensor

            train_tensor = _single_update()

        # Add summaries.
        def _get_total_loss_with_collection(collection,
                                            add_regularization_losses=True,
                                            name="total_loss"):
            # Sum of losses in `collection`, optionally plus regularization.
            losses = tf.losses.get_losses(loss_collection=collection)
            if add_regularization_losses:
                losses += tf.losses.get_regularization_losses()
            return math_ops.add_n(losses, name=name)

        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # not contained in global_summaries
        config_summary_list = select_config_summary_list(total_configs,
                                                         as_matrix=False)

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Project-specific training loop (extends slim.learning.train with
        # epoch bookkeeping and config summaries).
        custom_learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            global_step=(None if is_first_training else global_step),
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            log_every_n_steps=(train_config.log_every_n_steps
                               if train_config.log_every_n_steps else None),
            save_summaries_secs=train_config.save_summaries_secs,
            save_interval_secs=train_config.save_interval_secs,
            sync_optimizer=sync_optimizer,
            saver=saver,
            batch_size=train_config.batch_size,
            num_examples=num_examples,
            config_summary_list=config_summary_list)