def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
  """Training function for detection models.

  Builds the deployment config, input queue, model clones, optimizer and
  summaries inside a fresh graph, then hands the resulting train tensor to
  `slim.learning.train`.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
      losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """
  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options
  ]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      # Each clone receives an equal share of the configured batch size.
      input_queue = create_input_queue(
          train_config.batch_size // num_clones,
          create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity,
          data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set()

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer = optimizer_builder.build(
          train_config.optimizer, global_summaries)

    sync_optimizer = None
    if train_config.sync_replicas:
      # SyncReplicasOptimizer is exposed under tf.train. The replica count
      # comes from this function's `worker_replicas` argument, the same value
      # that configures DeploymentConfig above.
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      var_map = detection_model.restore_map(
          from_detection_checkpoint=train_config.from_detection_checkpoint)
      available_var_map = (
          variables_helper.get_variables_available_in_checkpoint(
              var_map, train_config.fine_tune_checkpoint))
      init_saver = tf.train.Saver(available_var_map)

      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)

      init_fn = initializer_fn

    with tf.device(deploy_config.optimizer_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer, regularization_losses=None)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)

      # Evaluating train_tensor forces every update op to run first.
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(
          tf.summary.histogram(model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(
          tf.summary.scalar(loss_tensor.op.name, loss_tensor))
    global_summaries.add(
        tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        # A falsy num_steps is passed on as None.
        number_of_steps=(
            train_config.num_steps if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
def main(unused_argv):
  """Builds the DeepLab training graph from FLAGS and runs slim training.

  Creates the deployment config, input pipeline, model clones (using
  `_build_deeplab` as the clone model function), summaries, optimizer and
  train op, then calls `slim.learning.train`.

  Args:
    unused_argv: Ignored; all configuration is read from FLAGS.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisble by number of clones (GPUs).')

  # Divisibility is asserted above, so floor division is exact.
  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables. This is the only place histograms
    # are registered; adding them a second time would duplicate each one.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      summary_label = tf.cast(
          graph.get_tensor_by_name(
              ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/')),
          tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      # argmax over axis 3, then add a trailing axis for the image summary.
      predictions = tf.cast(
          tf.expand_dims(
              tf.argmax(
                  graph.get_tensor_by_name(
                      ('%s/%s:0' % (first_clone_scope,
                                    common.OUTPUT_TYPE)).strip('/')),
                  3),
              -1),
          tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.OUTPUT_TYPE, predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    # The startup delay grows with the task index.
    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)

      # Evaluating train_tensor forces every update op to run first.
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
def train_model(candidate, N, F):
  """Builds a slim classification training graph and runs training.

  All configuration other than the three arguments is read from FLAGS.

  Args:
    candidate: Forwarded unchanged to `nets_factory.get_network_fn`.
    N: Forwarded unchanged to `nets_factory.get_network_fn`.
    F: Forwarded unchanged to `nets_factory.get_network_fn`.
      (NOTE(review): the meaning of candidate/N/F is defined by the network
      factory, which is not visible here.)

  Raises:
    ValueError: If FLAGS.dataset_dir is not set.
  """
  print("train model")
  print(FLAGS.dataset_name)
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # ---- Deployment configuration (clones / replicas / parameter servers).
    deployment = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # ---- Global step lives on the variables device.
    with tf.device(deployment.variables_device()):
      step = slim.create_global_step()

    # ---- Dataset.
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    class_count = dataset.num_classes - FLAGS.labels_offset

    # ---- Network.
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        candidate,
        N,
        F,
        num_classes=class_count,
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    # ---- Preprocessing; falls back to the model name when no explicit
    # preprocessing name is given.
    preprocess = preprocessing_factory.get_preprocessing(
        FLAGS.preprocessing_name or FLAGS.model_name, is_training=True)

    # ---- Input pipeline: provider -> preprocess -> batch -> prefetch queue.
    with tf.device(deployment.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      image, label = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      side = FLAGS.train_image_size or network_fn.default_image_size
      image = preprocess(image, side, side)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deployment.num_clones)

    # ---- Model definition, invoked once per clone.
    def clone_fn(queue):
      """Dequeues a batch, runs network_fn and registers the losses."""
      batch_images, batch_labels = queue.dequeue()
      logits, end_points = network_fn(batch_images)

      # The auxiliary head, when present, gets a 0.4-weighted loss.
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'],
            batch_labels,
            label_smoothing=FLAGS.label_smoothing,
            weights=0.4,
            scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits,
          batch_labels,
          label_smoothing=FLAGS.label_smoothing,
          weights=1.0)
      return end_points

    # Summaries that exist before any clone is built.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deployment, clone_fn, [batch_queue])
    first_scope = deployment.clone_scope(0)

    # Update ops (e.g. batch_norm statistics) are taken from the first clone
    # only.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_scope)

    # Activation histograms and sparsity for every end point of clone 0.
    first_outputs = clones[0].outputs
    for key in first_outputs:
      activation = first_outputs[key]
      summaries.add(tf.summary.histogram('activations/' + key, activation))
      summaries.add(
          tf.summary.scalar('sparsity/' + key,
                            tf.nn.zero_fraction(activation)))

    # Per-loss scalars for clone 0.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Histograms for every model variable.
    for var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(var.op.name, var))

    # ---- Moving averages (optional).
    averaged_vars = None
    averager = None
    if FLAGS.moving_average_decay:
      averaged_vars = slim.get_model_variables()
      averager = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, step)

    # ---- Optimizer.
    with tf.device(deployment.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # With synchronous replicas the averaging objects are handed to the
      # sync optimizer instead of being applied locally.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          total_num_replicas=FLAGS.worker_replicas,
          variable_averages=averager,
          variables_to_average=averaged_vars)
    elif FLAGS.moving_average_decay:
      # Otherwise the trainer applies the moving-average update itself.
      update_ops.append(averager.apply(averaged_vars))

    # ---- Gradients over the selected variables.
    trainable = _get_variables_to_train()
    total_loss, clone_grads = model_deploy.optimize_clones(
        clones, optimizer, var_list=trainable)
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # ---- Train op: total_loss gated on all update ops.
    update_ops.append(
        optimizer.apply_gradients(clone_grads, global_step=step))
    grouped_updates = tf.group(*update_ops)
    with tf.control_dependencies([grouped_updates]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Pick up the summaries created inside the first clone's scope.
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, first_scope))
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # ---- Run training.
    sync_optimizer = optimizer if FLAGS.sync_replicas else None
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=sync_optimizer)
def main(unused_argv):
  """Builds the DeepLab (semantic + instance + offset) graph and trains it.

  All configuration is read from FLAGS. The graph is assembled from a
  `data_generator.Dataset`, clones of `_build_deeplab`, image/loss summaries
  and a momentum or Adam optimizer, and is run with `slim.learning.train`
  inside a `contrib_tfprof.ProfileContext`.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If FLAGS.optimizer is neither 'momentum' nor 'adam', or if
      quantization is requested together with more than one clone.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Deployment setup (multi-GPU and/or multi-replica).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Each clone gets an equal share of the batch.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisble by number of clones (GPUs).')
  per_clone_batch = FLAGS.train_batch_size // config.num_clones

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=per_clone_batch,
          crop_size=[int(dim) for dim in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=4,
          is_training=True,
          should_shuffle=False,
          should_repeat=True)

    # The global step is created on the variables device.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Model outputs: semantic logits plus 1-channel instance and 2-channel
      # offset heads, as given by the dict values below.
      outputs_to_num_classes = {
          common.OUTPUT_TYPE: dataset.num_of_classes,
          common.INSTANCE: 1,
          common.OFFSET: 2
      }
      model_args = (dataset.get_one_shot_iterator(), outputs_to_num_classes,
                    dataset.ignore_label)
      clones = model_deploy.create_clones(
          config, _build_deeplab, args=model_args)

      # Update ops (e.g. batch_norm statistics) come from the first clone.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    def clone_tensor(key):
      """Looks up output 0 of the op named `key` in the first clone scope."""
      return graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, key)).strip('/'))

    def stretch_to_255(tensor):
      """Divides by the tensor's maximum and multiplies by 255."""
      peak = tf.reduce_max(tensor)
      return tf.multiply(tf.divide(tensor, peak), 255.0)

    # Summaries already registered in the graph.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Histograms of model variables.
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Image summaries for inputs, labels and predictions of the first clone.
    if FLAGS.save_summaries_images:
      input_image = clone_tensor(common.IMAGE)
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, input_image))

      # -- Semantic segmentation label and prediction. Class ids are scaled
      # up so that they are distinguishable in the rendered image.
      seg_label = clone_tensor(common.LABEL)
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      seg_label_image = tf.cast(seg_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, seg_label_image))

      seg_logits = clone_tensor(common.OUTPUT_TYPE)
      seg_prediction = tf.expand_dims(tf.argmax(seg_logits, 3), -1)
      seg_prediction_image = tf.cast(seg_prediction * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                           seg_prediction_image))

      # -- Instance centre label (cast as-is) and prediction (rescaled).
      instance_label = clone_tensor(common.LABEL_INSTANCE)
      instance_label_image = tf.cast(instance_label, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL_INSTANCE,
                           instance_label_image))

      instance_output = clone_tensor(common.INSTANCE)
      instance_output_image = tf.cast(
          stretch_to_255(instance_output), tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.INSTANCE,
                           instance_output_image))

      # -- Offset label: three channels on axis 3; the third is discarded
      # and replaced by zeros for display.
      offset_label = clone_tensor(common.LABEL_OFFSET)
      label_x, label_y, _ = tf.split(offset_label, 3, 3)
      label_x = stretch_to_255(label_x)
      label_y = stretch_to_255(label_y)
      label_pad = tf.zeros_like(label_x)
      offset_label_rgb = tf.concat([label_x, label_y, label_pad], 3)
      # NOTE(review): the channels were already scaled to 0-255 above and
      # convert_image_dtype may rescale float input again, unlike the
      # tf.cast used for the predicted offsets below - confirm this is
      # intended.
      offset_label_image = tf.image.convert_image_dtype(
          offset_label_rgb, dtype=tf.uint8, saturate=True)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL_OFFSET,
                           offset_label_image))

      # -- Offset prediction: two channels on axis 3 plus a zero channel.
      offset_output = clone_tensor(common.OFFSET)
      pred_x, pred_y = tf.split(offset_output, 2, 3)
      pred_x = stretch_to_255(pred_x)
      pred_y = stretch_to_255(pred_y)
      pred_pad = tf.zeros_like(pred_x)
      offset_output_rgb = tf.concat([pred_x, pred_y, pred_pad], 3)
      offset_output_image = tf.cast(offset_output_rgb, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.OFFSET,
                           offset_output_image))

    # Scalar summaries for each loss of the first clone.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Optimizer, built on the optimizer device.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy,
          FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step,
          FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps,
          FLAGS.learning_power,
          FLAGS.slow_start_step,
          FLAGS.slow_start_learning_rate,
          decay_steps=FLAGS.decay_steps,
          end_learning_rate=FLAGS.end_learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

      optimizer_name = FLAGS.optimizer
      if optimizer_name == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif optimizer_name == 'adam':
        # Adam takes its own fixed rate from FLAGS, not the schedule above.
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate,
            epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    # Quantization-aware training is enabled for non-negative delays.
    if FLAGS.quantize_delay_step >= 0:
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Apply gradient multipliers for biases and last-layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Train op: total_loss gated on the gradient and update ops.
      update_ops.append(
          optimizer.apply_gradients(grads_and_vars, global_step=global_step))
      all_updates = tf.group(*update_ops)
      with tf.control_dependencies([all_updates]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Pick up summaries created within the first clone's scope by model_fn
    # and optimize_clones().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement lets ops without a GPU kernel fall back to CPU.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Profiling is enabled only when a profile directory is configured.
    profile_dir = FLAGS.profile_logdir
    profiling_enabled = profile_dir is not None
    if profiling_enabled:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(enabled=profiling_enabled,
                                       profile_dir=profile_dir):
      # Restore from an initial checkpoint only when one is given.
      init_fn = (
          train_utils.get_model_init_fn(
              FLAGS.train_logdir,
              FLAGS.tf_initial_checkpoint,
              FLAGS.initialize_last_layer,
              last_layers,
              ignore_missing_vars=True)
          if FLAGS.tf_initial_checkpoint else None)

      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
def run_transfer_learning(root_model_dir,
                          bot_model_dir,
                          protobuf_dir,
                          model_name='inception_v4',
                          dataset_split_name='train',
                          dataset_name='bot',
                          checkpoint_exclude_scopes=None,
                          trainable_scopes=None,
                          max_train_time_sec=None,
                          max_number_of_steps=None,
                          log_every_n_steps=None,
                          save_summaries_secs=None,
                          optimization_params=None):
  """
  Starts the transfer learning of a model in a tensorflow session

  Falsy optional arguments are replaced by the module-level defaults.

  :param root_model_dir: Directory containing the root models pretrained checkpoint files
  :param bot_model_dir: Directory where the transfer learned model's checkpoint files are written to
  :param protobuf_dir: Directory for the dataset factory to load the bot's training data from
  :param model_name: name of the network model for the net factory to provide the correct network and
      preprocessing fn
  :param dataset_split_name: 'train' or 'validation'
  :param dataset_name: triggers the dataset factory to load a bot dataset
  :param checkpoint_exclude_scopes: Layers to exclude when restoring the models variables
  :param trainable_scopes: Layers to train from the restored model
  :param max_train_time_sec: time boundary to stop training after in seconds
  :param max_number_of_steps: maximum number of steps to run.
      NOTE(review): this value is defaulted but not forwarded to slim.learning.train
      (the original left that kwarg commented out) - confirm whether that is intended.
  :param log_every_n_steps: write a log after every nth optimization step
  :param save_summaries_secs: save summaries to disc every n seconds
  :param optimization_params: parameters for the optimization. Keys read here are 'weight_decay',
      'dropout_keep_prob' and 'moving_average_decay'; missing keys fall back to OPTIMIZATION_PARAMS.
      NOTE(review): _configure_learning_rate and _configure_optimizer are defined elsewhere and do
      not receive this dict - confirm where they take their settings from.
  :return: None
  """
  # Overlay caller-supplied values on the module defaults so that the
  # argument is actually honoured and partial dicts keep working.
  params = dict(OPTIMIZATION_PARAMS)
  if optimization_params:
    params.update(optimization_params)

  if not max_number_of_steps:
    max_number_of_steps = _MAX_NUMBER_OF_STEPS
  if not checkpoint_exclude_scopes:
    checkpoint_exclude_scopes = _CHECKPOINT_EXCLUDE_SCOPES
  if not trainable_scopes:
    trainable_scopes = _TRAINABLE_SCOPES
  if not max_train_time_sec:
    max_train_time_sec = _MAX_TRAIN_TIME_SECONDS
  if not log_every_n_steps:
    log_every_n_steps = _LOG_EVERY_N_STEPS
  if not save_summaries_secs:
    save_summaries_secs = _SAVE_SUMMARRIES_SECS

  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=_NUM_CLONES,
        clone_on_cpu=_CLONE_ON_CPU,
        replica_id=_TASK,
        num_replicas=_WORKER_REPLICAS,
        num_ps_tasks=_NUM_PS_TASKS)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        dataset_name, dataset_split_name, protobuf_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        model_name,
        num_classes=(dataset.num_classes - _LABELS_OFFSET),
        weight_decay=params['weight_decay'],
        is_training=True,
        dropout_keep_prob=params['dropout_keep_prob'])

    #####################################
    # Select the preprocessing function #
    #####################################
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        model_name, is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=_NUM_READERS,
          common_queue_capacity=20 * _BATCH_SIZE,
          common_queue_min=10 * _BATCH_SIZE)
      [image, label] = provider.get(['image', 'label'])
      label -= _LABELS_OFFSET

      train_image_size = network_fn.default_image_size
      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=_BATCH_SIZE,
          num_threads=_NUM_PREPROCESSING_THREADS,
          capacity=5 * _BATCH_SIZE)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - _LABELS_OFFSET)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'],
            onehot_labels=labels,
            label_smoothing=_LABEL_SMOOTHING,
            weights=0.4,
            scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits,
          onehot_labels=labels,
          label_smoothing=_LABEL_SMOOTHING,
          weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    moving_average_decay = params['moving_average_decay']
    if moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if _SYNC_REPLICAS:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=_REPLICAS_TO_AGGREGATE,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(_TASK, tf.int32, shape=()),
          total_num_replicas=_WORKER_REPLICAS)
    elif moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train(trainable_scopes)

    # and returns a train_tensor and summary_op
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    # Evaluating train_tensor forces every update op to run first.
    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                      name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=bot_model_dir,
        # Manually added a custom train step to stop after max_time
        train_step_fn=train_step,
        train_step_kwargs=_train_step_kwargs(
            logdir=bot_model_dir,
            max_train_time_seconds=max_train_time_sec),
        master=_MASTER,
        is_chief=(_TASK == 0),
        init_fn=_get_init_fn(root_model_dir, bot_model_dir,
                             checkpoint_exclude_scopes),
        summary_op=summary_op,
        # number_of_steps is deliberately left unset here, as in the
        # original; see the docstring note on max_number_of_steps.
        log_every_n_steps=log_every_n_steps,
        save_summaries_secs=save_summaries_secs,
        save_interval_secs=_SAVE_INTERNAL_SECS,
        sync_optimizer=optimizer if _SYNC_REPLICAS else None)