def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir, graph_hook_fn=None):
    """Training function for detection models.

    Builds a multi-clone / multi-replica training graph (inputs, clone losses,
    optimizer, optional bias-gradient scaling and gradient clipping, an
    exponential moving average over the model variables, summaries) and runs
    it with ``slim.learning.train``.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
        losses.
      train_config: a train_pb2.TrainConfig protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with. Also used as
        ``total_num_replicas`` when ``train_config.sync_replicas`` is set.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
      graph_hook_fn: Optional function that is called after the training graph
        is completely built. This is helpful to perform additional changes to
        the training graph such as optimizing batchnorm. The function should
        modify the default graph.
    """
    detection_model = create_model_fn()

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(create_tensor_dict_fn)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                # These summaries are created outside any clone scope, so the
                # first-clone collection filter below would drop them when
                # num_clones > 1; track them explicitly (the set de-duplicates
                # them in the single-clone case).
                global_summaries.add(tf.summary.scalar(var.op.name, var))

        sync_optimizer = None
        if train_config.sync_replicas:
            # Use the worker_replicas argument (the same value handed to
            # DeploymentConfig above); TrainConfig has no worker_replicas field.
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        # fine_tune_checkpoint may hold several comma-separated paths.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            restore_checkpoints = [
                path.strip()
                for path in train_config.fine_tune_checkpoint.split(',')
            ]

            restorers = get_restore_checkpoint_ops(restore_checkpoints,
                                                   detection_model,
                                                   train_config)

            def initializer_fn(sess):
                # Restorer i restores from checkpoint path i.
                for i, restorer in enumerate(restorers):
                    restorer.restore(sess, restore_checkpoints[i])

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            # None lets optimize_clones gather regularization losses itself;
            # an empty list disables them.
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by
            # train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Maintain an exponential moving average (fixed decay 0.9999) of
            # all model variables; the update runs with every train step.
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                0.9999, global_step)
            update_ops.append(
                variable_averages.apply(moving_average_variables))

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            # Evaluating train_op forces all updates to run first.
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, 'critic_loss'))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
def main(_):
    """Train an image classifier configured entirely through FLAGS.

    Builds a slim training graph: a queue-based input pipeline, one network
    clone per device with softmax cross-entropy losses, optional moving
    averages or synchronous replicas, and summaries. It then hands the
    resulting train op to ``slim.learning.train``.

    Raises:
      ValueError: if ``--dataset_dir`` was not supplied.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Deployment description: clones per machine, replicas, PS tasks.
        deployment = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # The global step is placed next to the model variables.
        with tf.device(deployment.variables_device()):
            step = slim.create_global_step()

        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
        # Classes the network predicts once the label offset is removed.
        class_count = dataset.num_classes - FLAGS.labels_offset

        net = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=class_count,
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        preprocess = preprocessing_factory.get_preprocessing(
            FLAGS.preprocessing_name or FLAGS.model_name,
            is_training=True)

        # Input pipeline: readers -> preprocessing -> batching -> prefetching.
        with tf.device(deployment.inputs_device()):
            data_provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            raw_image, raw_label = data_provider.get(['image', 'label'])
            raw_label = raw_label - FLAGS.labels_offset

            image_size = FLAGS.train_image_size or net.default_image_size
            processed = preprocess(raw_image, image_size, image_size)

            image_batch, label_batch = tf.train.batch(
                [processed, raw_label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            label_batch = slim.one_hot_encoding(label_batch, class_count)
            prefetched = slim.prefetch_queue.prefetch_queue(
                [image_batch, label_batch],
                capacity=2 * deployment.num_clones)

        def build_clone(queue):
            """Build one replica of the network and register its losses."""
            clone_images, clone_labels = queue.dequeue()
            logits, end_points = net(clone_images)

            # An auxiliary head, when present, adds a down-weighted loss.
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'], clone_labels,
                    label_smoothing=FLAGS.label_smoothing, weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits, clone_labels,
                label_smoothing=FLAGS.label_smoothing, weights=1.0)
            return end_points

        # Summaries that already exist before the clones are built.
        summary_set = set(tf.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deployment, build_clone,
                                            [prefetched])
        lead_scope = deployment.clone_scope(0)
        # Only the first clone's update ops (e.g. batch-norm statistics).
        update_ops = tf.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,
                                       lead_scope)

        # Activation histograms and sparsity for every end point.
        for name, activation in clones[0].outputs.items():
            summary_set.add(
                tf.summary.histogram('activations/' + name, activation))
            summary_set.add(
                tf.summary.scalar('sparsity/' + name,
                                  tf.nn.zero_fraction(activation)))

        # Per-loss scalars of the first clone.
        for clone_loss in tf.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                            lead_scope):
            summary_set.add(
                tf.summary.scalar('losses/%s' % clone_loss.op.name,
                                  clone_loss))

        # Histograms of every model variable.
        for model_var in slim.get_model_variables():
            summary_set.add(tf.summary.histogram(model_var.op.name, model_var))

        # Optional exponential moving average of the model variables.
        if FLAGS.moving_average_decay:
            averaged_vars = slim.get_model_variables()
            averager = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, step)
        else:
            averaged_vars = averager = None

        # Quantization-aware training is currently disabled:
        #if FLAGS.quantize_delay >= 0:
        #  tf.contrib.quantize.create_training_graph(
        #      quant_delay=FLAGS.quantize_delay)

        with tf.device(deployment.optimizer_device()):
            lr = _configure_learning_rate(dataset.num_samples, step)
            optimizer = _configure_optimizer(lr)
            summary_set.add(tf.summary.scalar('learning_rate', lr))

        sync_optimizer = None
        if FLAGS.sync_replicas:
            # The chief's queue runner performs the averaging in this mode.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=averager,
                variables_to_average=averaged_vars)
            sync_optimizer = optimizer
        elif FLAGS.moving_average_decay:
            # Otherwise every trainer applies the averages locally.
            update_ops.append(averager.apply(averaged_vars))

        total_loss, clone_grads = model_deploy.optimize_clones(
            clones, optimizer, var_list=_get_variables_to_train())
        summary_set.add(tf.summary.scalar('total_loss', total_loss))

        update_ops.append(
            optimizer.apply_gradients(clone_grads, global_step=step))
        # Fetching train_op runs every update first.
        with tf.control_dependencies([tf.group(*update_ops)]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Summaries created inside the first clone (model and loss gathering).
        summary_set.update(
            tf.get_collection(tf.compat.v1.GraphKeys.SUMMARIES, lead_scope))
        summary_op = tf.summary.merge(list(summary_set), name='summary_op')

        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=sync_optimizer)
def main(_):
    """Train an image classifier from FLAGS, with Habana bf16 and Horovod hooks.

    Sets the ``TF_BF16_CONVERSION`` environment variable from ``FLAGS.dtype``.
    Optionally initializes Horovod. Builds the slim multi-clone training graph
    and runs ``slim.learning.train`` with a custom step function
    (``train_step1``) that receives an ``ExamplesPerSecondKerasHook`` via
    ``train_step_kwargs['EPS']``.

    Side effects: writes ``os.environ['TF_BF16_CONVERSION']``, and may set
    ``FLAGS.precision = 'bf16'``.

    Raises:
      ValueError: if ``--dataset_dir`` was not supplied.
    """
    #tf.disable_v2_behavior() ###
    # The rest of this function builds a TF1-style graph (queues,
    # slim.learning.train), so eager mode is turned off up front.
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.enable_resource_variables()

    # Enable habana bf16 conversion pass
    if FLAGS.dtype == 'bf16':
        os.environ['TF_BF16_CONVERSION'] = flags.FLAGS.bf16_config_path
        FLAGS.precision = 'bf16'
    else:
        os.environ['TF_BF16_CONVERSION'] = "0"

    if FLAGS.use_horovod:
        hvd_init()

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(
            FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True,
            use_grayscale=FLAGS.use_grayscale)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size, train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            # The auxiliary head, if the network exposes one, is weighted 0.4.
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'], labels,
                    label_smoothing=FLAGS.label_smoothing, weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # Quantization-aware training is disabled in this variant:
        #if FLAGS.quantize_delay >= 0:
        #  quantize.create_training_graph(quant_delay=FLAGS.quantize_delay) #for debugging!!

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        #  and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        # Fetching train_op forces every update op to run first.
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # NOTE(review): the value returned by broadcast_global_variables is
        # discarded here. In TF1 graph mode Horovod presumably returns an op
        # that only takes effect when it is run (e.g. via sess.run or a hook).
        # Confirm it is executed elsewhere (train_step1 / _get_init_fn?).
        # Otherwise the ranks may start from different weights.
        if horovod_enabled():
            hvd.broadcast_global_variables(0)

        ###########################
        # Kicks off the training. #
        ###########################
        with dump_callback():
            with logger.benchmark_context(FLAGS):
                eps1 = ExamplesPerSecondKerasHook(FLAGS.log_every_n_steps,
                                                  output_dir=FLAGS.train_dir,
                                                  batch_size=FLAGS.batch_size)

                # Record every flag as a hyper-parameter. The FLAGS entries are
                # unpacked after 'batch_size', so a flag of that name
                # overrides the explicit key.
                write_hparams_v1(
                    eps1.writer, {
                        'batch_size': FLAGS.batch_size,
                        **{x: getattr(FLAGS, x) for x in FLAGS}
                    })

                # Build the should_stop / should_log tensors the custom step
                # function receives.
                train_step_kwargs = {}
                if FLAGS.max_number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, FLAGS.max_number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                if FLAGS.log_every_n_steps > 0:
                    train_step_kwargs['should_log'] = math_ops.equal(
                        math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)
                eps1.on_train_begin()
                # train_step1 (defined elsewhere) receives the throughput hook.
                train_step_kwargs['EPS'] = eps1

                slim.learning.train(
                    train_tensor,
                    logdir=FLAGS.train_dir,
                    train_step_fn=train_step1,
                    train_step_kwargs=train_step_kwargs,
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    init_fn=_get_init_fn(),
                    summary_op=summary_op,
                    summary_writer=None,
                    number_of_steps=FLAGS.max_number_of_steps,
                    log_every_n_steps=FLAGS.log_every_n_steps,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    save_interval_secs=FLAGS.save_interval_secs,
                    sync_optimizer=optimizer if FLAGS.sync_replicas else None)
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir, graph_hook_fn=None):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
        losses.
      train_config: a train_pb2.TrainConfig protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
      graph_hook_fn: Optional function that is called after the inference
        graph is built (before optimization). This is helpful to perform
        additional changes to the training graph such as adding FakeQuant ops.
        The function should modify the default graph.

    Raises:
      ValueError: If both num_clones > 1 and train_config.sync_replicas is true.
    """
    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        if num_clones != 1 and train_config.sync_replicas:
            # Adjacent literals are joined at compile time, so .format applies
            # to the full message (previously they were two separate args and
            # the exception rendered as a tuple).
            raise ValueError('In Synchronous SGD mode num_clones must '
                             'be 1. Found num_clones: {}'.format(num_clones))
        # Per-clone batch size; in sync mode the global batch is further split
        # across the aggregated replicas.
        batch_size = train_config.batch_size // num_clones
        if train_config.sync_replicas:
            batch_size //= train_config.replicas_to_aggregate

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity, data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                # Created outside any clone scope, so the first-clone filter
                # below would drop these when num_clones > 1; track them
                # explicitly (the set de-duplicates the single-clone case).
                global_summaries.add(
                    tf.summary.scalar(var.op.name, var, family='LearningRate'))

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        with tf.device(deploy_config.optimizer_device()):
            # None lets optimize_clones gather regularization losses itself;
            # an empty list disables them.
            regularization_losses = (None if train_config.add_regularization_loss
                                     else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by
            # train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            # Fetching train_op forces every update op to run first.
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram('ModelVars/' + model_var.op.name,
                                     model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                  loss_tensor))
        global_summaries.add(
            tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated.
                # For backward compatibility, fine_tune_checkpoint_type is set
                # based on from_detection_checkpoint.
                if train_config.from_detection_checkpoint:
                    train_config.fine_tune_checkpoint_type = 'detection'
                else:
                    train_config.fine_tune_checkpoint_type = 'classification'
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (variables_helper.
                                 get_variables_available_in_checkpoint(
                                     var_map,
                                     train_config.fine_tune_checkpoint,
                                     include_global_step=False))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(
                train_config.num_steps if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
def main(model_root, datasets_dir, model_name):
    """Train ``model_name`` with WS-DAN style attention cropping and dropping.

    Each batch is passed through the network three times: on the original
    images, on attention-guided crops, and on attention-dropped images. The
    three cross-entropy losses (weight 1/3 each) are combined with a center
    loss on ``end_points['embeddings']``. Most hyper-parameters (num_clones,
    batch_size, weight_decay, ...) are module-level names defined elsewhere
    in this file.

    Args:
      model_root: directory under which ``<model_name>/`` is used as train_dir.
      datasets_dir: dataset directory passed to ``convert_data.get_datasets``.
      model_name: network / preprocessing name; also names the checkpoint.
    """
    # TensorFlow documents that this may only be called before any graphs,
    # ops or tensors are created; it used to run after the whole training
    # graph had been built.
    tf.compat.v1.disable_eager_execution()
    # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    # Training-related parameter setup.
    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=False,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=num_ps_tasks)

        global_step = slim.create_global_step()

        train_dir = os.path.join(model_root, model_name)

        dataset = convert_data.get_datasets('train', dataset_dir=datasets_dir)

        network_fn = net_select.get_network_fn(
            model_name,
            num_classes=dataset.num_classes,
            weight_decay=weight_decay,
            is_training=True)

        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=True)

        print("the data_sources:", dataset.data_sources)

        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=num_readers,
                common_queue_capacity=20 * batch_size,
                common_queue_min=10 * batch_size)
            [image, label] = provider.get(['image', 'label'])

            train_image_size = network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=num_preprocessing_threads,
                capacity=5 * batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        def calculate_pooling_center_loss(features, label, alfa, nrof_classes,
                                          weights, name):
            """Center loss between features and per-class centers.

            The (non-trainable) center variable is moved toward the batch
            features by ``scatter_sub`` with step ``1 - alfa``. The loss is the
            mean squared distance to the L2-normalized centers, times
            ``weights``. ``label`` is one-hot and converted via argmax.
            """
            # Requires a statically known batch dimension.
            features = tf.reshape(features, [features.shape[0], -1])
            label = tf.argmax(label, 1)
            nrof_features = features.get_shape()[1]
            centers = tf.compat.v1.get_variable(
                name, [nrof_classes, nrof_features], dtype=tf.float32,
                initializer=tf.constant_initializer(0), trainable=False)
            label = tf.reshape(label, [-1])
            centers_batch = tf.gather(centers, label)
            centers_batch = tf.nn.l2_normalize(centers_batch, axis=-1)

            diff = (1 - alfa) * (centers_batch - features)
            centers = tf.compat.v1.scatter_sub(centers, label, diff)

            # Make the loss depend on the center update so it runs every step.
            with tf.control_dependencies([centers]):
                distance = tf.square(features - centers_batch)
                distance = tf.reduce_sum(distance, axis=-1)
                center_loss = tf.reduce_mean(distance)

            center_loss = tf.identity(center_loss * weights,
                                      name=name + '_loss')
            return center_loss

        def _sample_part_index(attention_map, num_parts):
            """Pick one attention channel, weighted by sqrt of its mean response.

            Falls back to a uniform choice when the weights cannot form a
            probability vector (all-zero map, or NaN from negative means).
            Previously np.random.choice raised in that case and aborted the
            py_func.
            """
            part_weights = np.sqrt(attention_map.mean(axis=0).mean(axis=0))
            total = np.sum(part_weights)
            if np.isfinite(total) and total > 0:
                probabilities = part_weights / total
            else:
                probabilities = None  # uniform
            return np.random.choice(np.arange(0, num_parts), 1,
                                    p=probabilities)[0]

        def attention_crop(attention_maps):
            """Attention cropping (the paper's crop mask), used for augmentation.

            :param attention_maps: numpy array indexed as
                (batch, height, width, num_parts); here the attention maps
                produced by the network, resized to the training image size.
            :return: float32 array of shape (batch, 4) with normalized
                [ymin, xmin, ymax, xmax] boxes, padded by 0.1 on each side
                (may fall outside [0, 1]).
            """
            batch_size, height, width, num_parts = attention_maps.shape
            bboxes = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                selected_index = _sample_part_index(attention_map, num_parts)

                mask = attention_map[:, :, selected_index]
                threshold = random.uniform(0.4, 0.6)
                itemindex = np.where(mask >= mask.max() * threshold)
                ymin = itemindex[0].min() / height - 0.1
                ymax = itemindex[0].max() / height + 0.1
                xmin = itemindex[1].min() / width - 0.1
                xmax = itemindex[1].max() / width + 0.1
                bbox = np.asarray([ymin, xmin, ymax, xmax], dtype=np.float32)
                bboxes.append(bbox)
            bboxes = np.asarray(bboxes, np.float32)
            return bboxes

        def attention_drop(attention_maps):
            """Attention dropping: hide one attended region per image.

            This pushes the model to attend to other parts of the object,
            since different attention maps may focus on the same part.

            :param attention_maps: numpy array indexed as
                (batch, height, width, num_parts).
            :return: float32 array of shape (batch, height, width, 1) that is 0
                where the selected map exceeds a random fraction of its max and
                1 elsewhere.
            """
            batch_size, height, width, num_parts = attention_maps.shape
            masks = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                selected_index = _sample_part_index(attention_map, num_parts)
                mask = attention_map[:, :, selected_index:selected_index + 1]

                # soft mask
                threshold = random.uniform(0.2, 0.5)
                mask = (mask < threshold * mask.max()).astype(np.float32)
                masks.append(mask)
            masks = np.asarray(masks, dtype=np.float32)
            return masks

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits_1, end_points_1 = network_fn(images)

            attention_maps = end_points_1['attention_maps']
            attention_maps = tf.image.resize(
                attention_maps, [train_image_size, train_image_size],
                method=tf.image.ResizeMethod.BILINEAR)

            # attention crop
            bboxes = tf.compat.v1.py_func(attention_crop, [attention_maps],
                                          [tf.float32])
            bboxes = tf.reshape(bboxes, [batch_size, 4])
            box_ind = tf.range(batch_size, dtype=tf.int32)
            images_crop = tf.image.crop_and_resize(
                images, bboxes, box_ind,
                crop_size=[train_image_size, train_image_size])

            # attention drop
            masks = tf.compat.v1.py_func(attention_drop, [attention_maps],
                                         [tf.float32])
            masks = tf.reshape(
                masks, [batch_size, train_image_size, train_image_size, 1])
            images_drop = images * masks

            # The crop and drop passes share weights with the first pass.
            logits_2, end_points_2 = network_fn(images_crop, reuse=True)
            logits_3, end_points_3 = network_fn(images_drop, reuse=True)

            slim.losses.softmax_cross_entropy(logits_1, labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_1')
            slim.losses.softmax_cross_entropy(logits_2, labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_2')
            slim.losses.softmax_cross_entropy(logits_3, labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_3')

            embeddings = end_points_1['embeddings']
            center_loss = calculate_pooling_center_loss(
                features=embeddings,
                label=labels,
                alfa=0.95,
                nrof_classes=dataset.num_classes,
                weights=1.0,
                name='center_loss')
            slim.losses.add_loss(center_loss)

            return end_points_1

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                                first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step)
            optimizer = configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = get_variables_to_train(trainable_scopes)

        # and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        # Fetching train_op forces every update op to run first.
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                        first_clone_scope))

        # Merge all summaries together.
        # NOTE(review): merge_all() merges the whole SUMMARIES collection and
        # ignores the `summaries` set assembled above. Confirm this is
        # intended.
        summary_op = tf.compat.v1.summary.merge_all()

        config = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"

        save_model_path = os.path.join(checkpoint_path, model_name,
                                       "%s.ckpt" % model_name)
        print(save_model_path)
        # saver = tf.compat.v1.train.import_meta_graph('%s.meta'%save_model_path, clear_devices=True)

        # train the model
        slim.learning.train(
            train_op=train_tensor,
            logdir=train_dir,
            is_chief=(task == 0),
            init_fn=_get_init_fn(save_model_path, train_dir=train_dir),
            summary_op=summary_op,
            number_of_steps=max_number_of_steps,
            log_every_n_steps=log_every_n_steps,
            save_summaries_secs=save_summaries_secs,
            save_interval_secs=save_interval_secs,
            # sync_optimizer=None,
            session_config=config)
def main(_):
    """Build the image-classification training graph and run TF-Slim training.

    All configuration is read from the module-level ``FLAGS``. The function:

    * builds a (possibly multi-clone) input pipeline and model with
      ``model_deploy``;
    * attaches the cross-entropy losses, a per-batch training accuracy,
      summaries, an optional exponential moving average and the optimizer;
    * hands control to ``slim.learning.train`` with a custom ``train_step``
      that also logs the training accuracy.

    Side effects: may overwrite ``FLAGS.max_number_of_steps`` (derived from
    ``FLAGS.num_epochs`` when it is unset) and ``FLAGS.train_dir`` (suffixed
    with an experiment sub-directory when ``FLAGS.experiment_name`` is set).

    Args:
        _: Unused positional argument (presumably argv passed by the app
            runner).

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step on the device that stores the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = (FLAGS.train_image_size
                                or network_fn.default_image_size)

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn.

            Dequeues one batch, runs the network, registers the
            cross-entropy loss(es) with slim, and records the batch accuracy
            under ``end_points['train_accuracy']``.
            """
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'], labels,
                    label_smoothing=FLAGS.label_smoothing, weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits, labels, label_smoothing=FLAGS.label_smoothing,
                weights=1.0)

            # Accuracy of this batch; labels are one-hot, so compare argmaxes.
            accuracy = slim.metrics.accuracy(
                tf.cast(tf.argmax(input=logits, axis=1), dtype=tf.int32),
                tf.cast(tf.argmax(input=labels, axis=1), dtype=tf.int32))
            tf.compat.v1.add_to_collection('accuracy', accuracy)
            end_points['train_accuracy'] = accuracy
            return end_points

        # Gather initial summaries. This snapshot is taken before the clones
        # are built; per-clone summaries are merged in further below.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points (accuracy is a scalar and is
        # summarised separately below).
        end_points = clones[0].outputs
        for end_point in end_points:
            if 'accuracy' in end_point:
                continue
            x = end_points[end_point]
            summaries.add(
                tf.compat.v1.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.compat.v1.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))

        # First clone's batch accuracy; also fetched by train_step below.
        train_acc = end_points['train_accuracy']
        summaries.add(
            tf.compat.v1.summary.scalar('train_accuracy', train_acc))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        # TODO something to remove some of these from tensorboard scalars
        for variable in slim.get_model_variables():
            summaries.add(
                tf.compat.v1.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(
                tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the
            # chief queue runner.
            optimizer = tf.compat.v1.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Combine the clones' losses and gradients into one total loss and
        # one gradient list.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # Create gradient updates: a single apply_gradients per global step.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        # Evaluating train_op runs every update op first, then yields the loss.
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                        first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge(list(summaries),
                                                name='summary_op')

        # @philkuz
        # Derive max_number_of_steps from num_epochs when only the latter is
        # given.
        print('FLAGS.num_epochs', FLAGS.num_epochs)
        if FLAGS.num_epochs is not None and FLAGS.max_number_of_steps is None:
            # Every clone dequeues its own batch of FLAGS.batch_size examples,
            # and each global step applies the combined gradients of all
            # clones once, so one step consumes batch_size * num_clones
            # examples.
            # NOTE(review): with sync_replicas a global step also aggregates
            # several replicas' gradients; that factor is not included here.
            examples_per_step = FLAGS.batch_size * deploy_config.num_clones
            FLAGS.max_number_of_steps = int(
                FLAGS.num_epochs * dataset.num_samples / examples_per_step)

        # @philkuz: place this run under an experiment-specific sub-directory.
        if FLAGS.experiment_name is not None:
            experiment_dir = 'bs={},lr={},epochs={}/{}'.format(
                FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs,
                FLAGS.experiment_name)
            print(experiment_dir)
            FLAGS.train_dir = os.path.join(FLAGS.train_dir, experiment_dir)
            print(FLAGS.train_dir)

        # @philkuz overriding train_step
        def train_step(sess, train_op, global_step, train_step_kwargs):
            """Take one gradient step, logging loss and training accuracy.

            Passed to ``slim.learning.train`` as ``train_step_fn``. It fetches
            ``train_acc`` in the same ``sess.run`` as the train op, so the
            logged accuracy comes from the batch that was just trained on.

            Args:
                sess: The current session.
                train_op: A tensor that applies the gradient updates and
                    evaluates to the total loss.
                global_step: A `Tensor` representing the global training step.
                train_step_kwargs: A dictionary of keyword arguments. The keys
                    read here are ``should_trace``, ``logdir``,
                    ``summary_writer``, ``should_log`` and ``should_stop``.

            Returns:
                The total loss and a boolean indicating whether or not to stop
                training.

            Raises:
                ValueError: if 'should_trace' is in `train_step_kwargs` but
                    `logdir` is not.
            """
            start_time = time.time()

            trace_run_options = None
            run_metadata = None
            if 'should_trace' in train_step_kwargs:
                if 'logdir' not in train_step_kwargs:
                    raise ValueError(
                        'logdir must be present in train_step_kwargs when '
                        'should_trace is present')
                if sess.run(train_step_kwargs['should_trace']):
                    trace_run_options = config_pb2.RunOptions(
                        trace_level=config_pb2.RunOptions.FULL_TRACE)
                    run_metadata = config_pb2.RunMetadata()

            total_loss, acc, np_global_step = sess.run(
                [train_op, train_acc, global_step],
                options=trace_run_options,
                run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if run_metadata is not None:
                # Dump a Chrome-trace timeline of this step.
                tl = timeline.Timeline(run_metadata.step_stats)
                trace = tl.generate_chrome_trace_format()
                trace_filename = os.path.join(
                    train_step_kwargs['logdir'],
                    'tf_trace-%d.json' % np_global_step)
                tf.compat.v1.logging.info('Writing trace to %s',
                                          trace_filename)
                file_io.write_string_to_file(trace_filename, trace)
                if 'summary_writer' in train_step_kwargs:
                    train_step_kwargs['summary_writer'].add_run_metadata(
                        run_metadata, 'run_metadata-%d' % np_global_step)

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    tf.compat.v1.logging.info(
                        'global step %d: loss = %.4f train_acc = %.4f (%.3f sec/step)',
                        np_global_step, total_loss, acc, time_elapsed)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            train_step_fn=train_step,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)