def train_loop(pipeline_config_path,
               model_dir,
               val_checkpoint_dir,
               config_override=None,
               train_steps=None,
               use_tpu=False,
               save_final_config=False,
               checkpoint_every_n=1000,
               checkpoint_max_to_keep=7,
               record_summaries=True,
               performance_summary_exporter=None,
               **kwargs):
  """Trains a model using eager + functions.

  This method:
    1. Processes the pipeline configs
    2. (Optionally) saves the as-run config
    3. Builds the model & optimizer
    4. Gets the training input data
    5. Loads a fine-tuning detection or classification checkpoint if requested
    6. Loops over the train data, executing distributed training steps inside
       tf.functions.
    7. Checkpoints the model every `checkpoint_every_n` training steps.
    8. Logs the training metrics as TensorBoard summaries.

  Args:
    pipeline_config_path: A path to a pipeline config file.
    model_dir: The directory to save checkpoints and summaries to.
    val_checkpoint_dir: The directory to save the validation checkpoint.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training
      steps is set from the `TrainConfig` proto.
    use_tpu: Boolean, whether training and evaluation should run on TPU.
    save_final_config: Whether to save the final config (obtained after
      applying overrides) to `model_dir`.
    checkpoint_every_n: Checkpoint every n training steps.
    checkpoint_max_to_keep: int, the number of most recent checkpoints to keep
      in the model directory.
    record_summaries: Boolean, whether or not to record summaries.
    performance_summary_exporter: function for exporting performance metrics.
    **kwargs: Additional keyword arguments for configuration override.
  """
  print('START train loop function ========================')

  ## Parse the configs
  get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
      'get_configs_from_pipeline_file']
  merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
      'merge_external_params_with_configs']
  create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
      'create_pipeline_proto_from_configs']
  steps_per_sec_list = []

  configs = get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      'train_steps': train_steps,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  })
  configs = merge_external_params_with_configs(
      configs, None, kwargs_dict=kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']

  unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors
  add_regularization_loss = train_config.add_regularization_loss
  clip_gradients_value = None
  if train_config.gradient_clipping_by_norm > 0:
    clip_gradients_value = train_config.gradient_clipping_by_norm

  # Update train_steps from the config, but only when a non-zero value is
  # provided.
  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  if kwargs['use_bfloat16']:
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(
        'mixed_bfloat16')

  if train_config.load_all_detection_checkpoint_vars:
    raise ValueError('train_pb2.load_all_detection_checkpoint_vars '
                     'unsupported in TF2')

  config_util.update_fine_tune_checkpoint_type(train_config)
  fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
  fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

  # Write the as-run pipeline config to disk.
  if save_final_config:
    tf.logging.info('Saving pipeline config file to directory {}'.format(
        model_dir))
    pipeline_config_final = create_pipeline_proto_from_configs(configs)
    config_util.save_pipeline_config(pipeline_config_final, model_dir)

  # Build the model, optimizer, and training input
  strategy = tf.compat.v2.distribute.get_strategy()
  with strategy.scope():
    detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
        model_config=model_config, is_training=True)

    def train_dataset_fn(input_context):
      """Callable to create train input."""
      # Create the inputs.
      train_input = inputs.train_input(
          train_config=train_config,
          train_input_config=train_input_config,
          model_config=model_config,
          model=detection_model,
          input_context=input_context)
      train_input = train_input.repeat()
      return train_input

    train_input = strategy.experimental_distribute_datasets_from_function(
        train_dataset_fn)

    global_step = tf.Variable(
        0,
        trainable=False,
        dtype=tf.compat.v2.dtypes.int64,
        name='global_step',
        aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
    optimizer, (learning_rate,) = optimizer_builder.build(
        train_config.optimizer, global_step=global_step)

    # We run the detection_model on dummy inputs in order to ensure that the
    # model and all its variables have been properly constructed. Specifically,
    # this is currently necessary prior to (potentially) creating shadow copies
    # of the model variables for the EMA optimizer.
    if train_config.optimizer.use_moving_average:
      _ensure_model_is_built(detection_model, train_input,
                             unpad_groundtruth_tensors)
      optimizer.shadow_copy(detection_model)

    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      learning_rate_fn = lambda: learning_rate

  ## Train the model
  # Get the appropriate filepath (temporary or not) based on whether the worker
  # is the chief.
  summary_writer_filepath = get_filepath(strategy,
                                         os.path.join(model_dir, 'train'))
  if record_summaries:
    summary_writer = tf.compat.v2.summary.create_file_writer(
        summary_writer_filepath)
  else:
    summary_writer = tf2.summary.create_noop_writer()

  if use_tpu:
    num_steps_per_iteration = 100
  else:
    # TODO(b/135933080) Explore setting to 100 when GPU performance issues
    # are fixed.
    num_steps_per_iteration = 1

  with summary_writer.as_default():
    with strategy.scope():
      with tf.compat.v2.summary.record_if(
          lambda: global_step % num_steps_per_iteration == 0):
        # Load a fine-tuning checkpoint.
        if train_config.fine_tune_checkpoint:
          load_fine_tune_checkpoint(
              detection_model, train_config.fine_tune_checkpoint,
              fine_tune_checkpoint_type, fine_tune_checkpoint_version,
              train_config.run_fine_tune_checkpoint_dummy_computation,
              train_input, unpad_groundtruth_tensors)

        ckpt = tf.compat.v2.train.Checkpoint(
            step=global_step, model=detection_model, optimizer=optimizer)
        val_ckpt = tf.compat.v2.train.Checkpoint(
            step=global_step, model=detection_model, optimizer=optimizer)

        manager_dir = get_filepath(strategy, model_dir)
        val_manager_dir = get_filepath(strategy, val_checkpoint_dir)
        # if not strategy.extended.should_checkpoint:
        #   checkpoint_max_to_keep = 1
        checkpoint_max_to_keep = 1
        manager = tf.compat.v2.train.CheckpointManager(
            ckpt, manager_dir, max_to_keep=checkpoint_max_to_keep)
        val_manager = tf.compat.v2.train.CheckpointManager(
            val_ckpt, val_manager_dir, max_to_keep=checkpoint_max_to_keep)

        model_checkpoint_callback = tfc.ModelCheckpoint(val_manager)
        early_stopping_callback = tfc.EarlyStopping(
            min_delta=0.0001, patience=5, mode='min')
        train_logger_callback = tfc.TrainLogger(model_dir, 'logs.txt')
        cancellation_point = tfc.CancellationPoint()

        # We use the following instead of manager.latest_checkpoint because
        # manager_dir does not point to the model directory when we are running
        # in a worker.
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        ckpt.restore(latest_checkpoint)
        val_ckpt.restore(latest_checkpoint)

        def train_step_fn(features, labels):
          """Single train step."""
          loss = eager_train_step(
              detection_model,
              features,
              labels,
              unpad_groundtruth_tensors,
              optimizer,
              learning_rate=learning_rate_fn(),
              add_regularization_loss=add_regularization_loss,
              clip_gradients_value=clip_gradients_value,
              global_step=global_step,
              num_replicas=strategy.num_replicas_in_sync)
          global_step.assign_add(1)
          return loss

        def _sample_and_train(strategy, train_step_fn, data_iterator):
          features, labels = data_iterator.next()
          if hasattr(tf.distribute.Strategy, 'run'):
            per_replica_losses = strategy.run(
                train_step_fn, args=(features, labels))
          else:
            per_replica_losses = strategy.experimental_run_v2(
                train_step_fn, args=(features, labels))
          # TODO(anjalisridhar): explore if it is safe to remove the
          # num_replicas scaling of the loss and switch this to a
          # ReduceOp.Mean.
          return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                 per_replica_losses, axis=None)

        @tf.function
        def _dist_train_step(data_iterator):
          """A distributed train step."""
          if num_steps_per_iteration > 1:
            for _ in tf.range(num_steps_per_iteration - 1):
              # Following suggestion on yaqs/5402607292645376
              with tf.name_scope(''):
                _sample_and_train(strategy, train_step_fn, data_iterator)
          return _sample_and_train(strategy, train_step_fn, data_iterator)

        train_input_iter = iter(train_input)

        if int(global_step.value()) == 0:
          manager.save()

        checkpointed_step = int(global_step.value())
        logged_step = global_step.value()

        # num_epochs = (train_steps - global_step.value()) // num_steps_per_iteration
        last_step_time = time.time()
        for epoch, _ in enumerate(
            range(global_step.value(), train_steps, num_steps_per_iteration)):
          loss = _dist_train_step(train_input_iter)

          time_taken = time.time() - last_step_time
          last_step_time = time.time()
          steps_per_sec = num_steps_per_iteration * 1.0 / time_taken

          tf.compat.v2.summary.scalar(
              'steps_per_sec', steps_per_sec, step=global_step)
          steps_per_sec_list.append(steps_per_sec)

          if global_step.value() - logged_step >= 100:
            tf.logging.info(
                'Step {} per-step time {:.3f}s loss={:.3f}'.format(
                    global_step.value(), time_taken / num_steps_per_iteration,
                    loss))

            manager.save()
            checkpointed_step = int(global_step.value())

            log_metrics = eval_continuously(
                pipeline_config_path,
                model_dir=model_dir,
                checkpoint_dir=model_dir,
                timeout=20)
            log_metrics['train_total_loss'] = loss

            model_checkpoint_callback.step(epoch,
                                           log_metrics['Loss/total_loss'])
            stop_training = early_stopping_callback.step(
                epoch, log_metrics['Loss/total_loss'])
            train_logger_callback.log(log_metrics)

            if stop_training or cancellation_point.check():
              break

            print(log_metrics)
            logged_step = global_step.value()

  # Remove the checkpoint directories of the non-chief workers that
  # MultiWorkerMirroredStrategy forces us to save during sync distributed
  # training.
  clean_temporary_directories(strategy, manager_dir)
  clean_temporary_directories(strategy, summary_writer_filepath)
  # TODO(pkanwar): add accuracy metrics.
  if performance_summary_exporter is not None:
    metrics = {
        'steps_per_sec': np.mean(steps_per_sec_list),
        'steps_per_sec_p50': np.median(steps_per_sec_list),
        'steps_per_sec_max': max(steps_per_sec_list),
        'last_batch_loss': float(loss)
    }
    mixed_precision = 'bf16' if kwargs['use_bfloat16'] else 'fp32'
    performance_summary_exporter(metrics, mixed_precision)
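
# Hedged usage sketch (not part of the original code): how the train_loop
# variant above might be invoked under a distribution strategy, mirroring the
# pattern used by the TF Object Detection API's model_main_tf2 entry point.
# The paths below are placeholders, not real files.
def _example_run_train_loop():
  strategy = tf.compat.v2.distribute.MirroredStrategy()
  with strategy.scope():
    train_loop(
        pipeline_config_path='path/to/pipeline.config',  # placeholder
        model_dir='path/to/model_dir',                   # placeholder
        val_checkpoint_dir='path/to/val_ckpt_dir',       # placeholder
        train_steps=None,  # fall back to TrainConfig.num_steps
        use_tpu=False,
        checkpoint_every_n=1000,
        record_summaries=True)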
def train_loop(pipeline_config_path,
               model_dir,
               config_override=None,
               train_steps=None,
               use_tpu=False,
               save_final_config=False,
               checkpoint_every_n=1000,
               checkpoint_max_to_keep=7,
               record_summaries=True,
               performance_summary_exporter=None,
               num_steps_per_iteration=NUM_STEPS_PER_ITERATION,
               **kwargs):
  """Trains a model using eager + functions."""
  config_override = None
  configs = config_util.get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      'train_steps': train_steps,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  })
  configs = config_util.merge_external_params_with_configs(
      configs, None, kwargs_dict=kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']

  unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors  # False
  add_regularization_loss = train_config.add_regularization_loss  # True
  clip_gradients_value = None
  if train_config.gradient_clipping_by_norm > 0:  # Not run
    clip_gradients_value = train_config.gradient_clipping_by_norm

  steps_per_sec_list = []

  # Update train_steps from the config, but only when a non-zero value is
  # provided.
  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  tf.compat.v2.keras.mixed_precision.set_global_policy('mixed_bfloat16')

  if train_config.load_all_detection_checkpoint_vars:
    raise ValueError('train_pb2.load_all_detection_checkpoint_vars '
                     'unsupported in TF2')

  config_util.update_fine_tune_checkpoint_type(train_config)
  fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type  # 'detection'
  fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

  # Build the model, optimizer, and training input
  strategy = tf.compat.v2.distribute.get_strategy()

  from object_detection import inputs
  from object_detection.builders import optimizer_builder
  from object_detection.utils import variables_helper

  with strategy.scope():
    detection_model = model_builder.build(
        model_config=model_config, is_training=True)

    def train_dataset_fn(input_context):
      """Callable to create train input."""
      # Create the inputs.
      train_input = inputs.train_input(
          train_config=train_config,
          train_input_config=train_input_config,
          model_config=model_config,
          model=detection_model,
          input_context=input_context)
      train_input = train_input.repeat()
      return train_input

    train_input = strategy.experimental_distribute_datasets_from_function(
        train_dataset_fn)

    global_step = tf.Variable(
        0,
        trainable=False,
        dtype=tf.compat.v2.dtypes.int64,
        name='global_step',
        aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
    optimizer, (learning_rate,) = optimizer_builder.build(
        train_config.optimizer, global_step=global_step)

    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      def learning_rate_fn():
        return learning_rate

  # Train the model
  # Get the appropriate filepath (temporary or not) based on whether the worker
  # is the chief.
  summary_writer_filepath = get_filepath(strategy,
                                         os.path.join(model_dir, 'train'))
  if record_summaries:
    summary_writer = tf.compat.v2.summary.create_file_writer(
        summary_writer_filepath)
  else:
    summary_writer = tf.summary.create_noop_writer()

  with summary_writer.as_default():
    with strategy.scope():
      with tf.compat.v2.summary.record_if(
          lambda: global_step % num_steps_per_iteration == 0):
        # Load a fine-tuning checkpoint.
        if train_config.fine_tune_checkpoint:
          variables_helper.ensure_checkpoint_supported(
              train_config.fine_tune_checkpoint, fine_tune_checkpoint_type,
              model_dir)
          load_fine_tune_checkpoint(
              detection_model, train_config.fine_tune_checkpoint,
              fine_tune_checkpoint_type, fine_tune_checkpoint_version,
              train_config.run_fine_tune_checkpoint_dummy_computation,
              train_input, unpad_groundtruth_tensors)

        ckpt = tf.compat.v2.train.Checkpoint(
            step=global_step, model=detection_model, optimizer=optimizer)

        manager_dir = get_filepath(strategy, model_dir)
        if not strategy.extended.should_checkpoint:
          checkpoint_max_to_keep = 1
        manager = tf.compat.v2.train.CheckpointManager(
            ckpt, manager_dir, max_to_keep=checkpoint_max_to_keep)

        # We use the following instead of manager.latest_checkpoint because
        # manager_dir does not point to the model directory when we are running
        # in a worker.
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        ckpt.restore(latest_checkpoint)

        def train_step_fn(features, labels):
          """Single train step."""
          loss = eager_train_step(
              detection_model,
              features,
              labels,
              unpad_groundtruth_tensors,
              optimizer,
              learning_rate=learning_rate_fn(),
              add_regularization_loss=add_regularization_loss,
              clip_gradients_value=clip_gradients_value,
              global_step=global_step,
              num_replicas=strategy.num_replicas_in_sync)
          global_step.assign_add(1)
          return loss

        def _sample_and_train(strategy, train_step_fn, data_iterator):
          features, labels = data_iterator.next()
          if hasattr(tf.distribute.Strategy, 'run'):
            per_replica_losses = strategy.run(
                train_step_fn, args=(features, labels))
          else:
            per_replica_losses = strategy.experimental_run_v2(
                train_step_fn, args=(features, labels))
          # TODO(anjalisridhar): explore if it is safe to remove the
          # num_replicas scaling of the loss and switch this to a
          # ReduceOp.Mean.
          return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                 per_replica_losses, axis=None)

        @tf.function
        def _dist_train_step(data_iterator):
          """A distributed train step."""
          if num_steps_per_iteration > 1:
            for _ in tf.range(num_steps_per_iteration - 1):
              # Following suggestion on yaqs/5402607292645376
              with tf.name_scope(''):
                _sample_and_train(strategy, train_step_fn, data_iterator)
          return _sample_and_train(strategy, train_step_fn, data_iterator)

        train_input_iter = iter(train_input)

        if int(global_step.value()) == 0:
          manager.save()

        checkpointed_step = int(global_step.value())
        logged_step = global_step.value()

        last_step_time = time.time()
        for _ in range(global_step.value(), train_steps,
                       num_steps_per_iteration):
          loss = _dist_train_step(train_input_iter)

          time_taken = time.time() - last_step_time
          last_step_time = time.time()
          steps_per_sec = num_steps_per_iteration * 1.0 / time_taken

          tf.compat.v2.summary.scalar(
              'steps_per_sec', steps_per_sec, step=global_step)
          steps_per_sec_list.append(steps_per_sec)

          if global_step.value() - logged_step >= 100:
            tf.logging.info(
                'Step {} per-step time {:.3f}s loss={:.3f}'.format(
                    global_step.value(), time_taken / num_steps_per_iteration,
                    loss))
            logged_step = global_step.value()

          if ((int(global_step.value()) - checkpointed_step) >=
              checkpoint_every_n):
            manager.save()
            checkpointed_step = int(global_step.value())
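
# Illustrative sketch (not in the original): the loop in the variant above
# advances `global_step` by num_steps_per_iteration inside each traced
# _dist_train_step call, so the number of outer Python iterations for a given
# step budget is simply the length of the range it iterates over. The helper
# name below is hypothetical.
def _example_num_outer_iterations(start_step, train_steps,
                                  num_steps_per_iteration):
  # e.g. (0, 25000, 100) -> 250 traced-iteration calls
  return len(range(start_step, train_steps, num_steps_per_iteration))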
def train_loop(config_path: str,
               model_dir: str,
               config_override: Optional[
                   pipeline_pb2.TrainEvalPipelineConfig] = None,
               train_steps: Optional[int] = None,
               use_tpu: bool = False,
               save_final_config: bool = False,
               log_every_n: int = 100,
               ckpt_every_n: int = 1000,
               ckpt_max_to_keep: int = 7,
               record_summaries: bool = True,
               **kwargs) -> None:
  """Trains a model using eager + functions.

  This method:
    1. Processes the pipeline configs
    2. (Optionally) saves the as-run config
    3. Builds the model & optimizer
    4. Gets the training input data
    5. Loads a fine-tuning detection or classification checkpoint if requested
    6. Loops over the train data, executing distributed training steps inside
       tf.functions.
    7. Checkpoints the model every `ckpt_every_n` training steps.
    8. Logs the training metrics as TensorBoard summaries.

  Args:
    config_path: A path to a pipeline config file.
    model_dir: The directory to save checkpoints and summaries to.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `config_path`.
    train_steps: Number of training steps. If None, the number of training
      steps from the `TrainConfig` proto is adopted.
    use_tpu: Boolean, whether training and evaluation should run on TPU.
    save_final_config: Whether to save the final config (obtained after
      applying overrides) to `model_dir`.
    log_every_n: Log total loss every n training steps.
    ckpt_every_n: Checkpoint every n training steps.
    ckpt_max_to_keep: int, the number of most recent checkpoints to keep in
      the model directory.
    record_summaries: Boolean, whether or not to record summaries.
    **kwargs: Additional keyword arguments for configuration override.
  """
  # parse config
  configs = config_util.get_configs_from_pipeline_file(
      config_path, config_override=config_override)
  kwargs.update({
      'train_steps': train_steps,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu,
  })
  configs = config_util.merge_external_params_with_configs(
      configs, None, kwargs_dict=kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']

  unpad_gt_tensors = train_config.unpad_groundtruth_tensors
  add_regularization_loss = train_config.add_regularization_loss
  clip_gradient_norm = None
  if train_config.gradient_clipping_by_norm > 0:
    clip_gradient_norm = train_config.gradient_clipping_by_norm

  if kwargs['use_bfloat16']:
    tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')

  if train_config.load_all_detection_checkpoint_vars:
    raise ValueError(
        'train_pb2.load_all_detection_checkpoint_vars unsupported in TF2')

  # base checkpoint to fine-tune from
  config_util.update_fine_tune_checkpoint_type(train_config)
  base_ckpt = train_config.fine_tune_checkpoint
  base_ckpt_type = train_config.fine_tune_checkpoint_type
  base_ckpt_ver = train_config.fine_tune_checkpoint_version

  # write the as-run pipeline config to disk
  if save_final_config:
    pipeline_config_final = config_util.create_pipeline_proto_from_configs(
        configs)
    config_util.save_pipeline_config(pipeline_config_final, model_dir)

  # build model, input, optimizer
  strategy = tf.distribute.get_strategy()
  with strategy.scope():
    # build model
    model = model_builder.build(model_config=model_config, is_training=True)

    # build input
    def train_dataset_fn(
        input_context: tf.distribute.InputContext) -> tf.data.Dataset:
      """Callable to create train input."""
      train_input = inputs.train_input(
          train_config=train_config,
          train_input_config=train_input_config,
          model_config=model_config,
          model=model,
          input_context=input_context,
      )
      train_input = train_input.repeat()
      return train_input
    train_input = strategy.experimental_distribute_datasets_from_function(
        train_dataset_fn)

    # build optimizer
    global_step = tf.Variable(
        0,
        trainable=False,
        dtype=tf.int64,
        name='global_step',
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
    )
    optimizer, (learning_rate,) = optimizer_builder.build(
        train_config.optimizer, global_step=global_step)
    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      learning_rate_fn = lambda: learning_rate

  # prepare for training
  # Get the appropriate filepath (temporary or not) based on whether the
  # worker is the chief.
  summary_log_path = get_filepath(strategy, os.path.join(model_dir, 'train'))
  if record_summaries:
    summary_writer = tf.summary.create_file_writer(summary_log_path)
  else:
    summary_writer = tf.summary.create_noop_writer()

  if use_tpu:
    num_steps_per_iteration = 100
  else:
    num_steps_per_iteration = 1

  with summary_writer.as_default():
    with strategy.scope():
      with tf.summary.record_if(
          lambda: global_step % num_steps_per_iteration == 0):
        # Prepare the checkpoint manager.
        # (Do not use manager.latest_checkpoint, as manager_dir is not
        # model_dir while running in a worker.)
        ckpt = tf.train.Checkpoint(
            model=model, step=global_step, optimizer=optimizer)
        ckpt_max_to_keep = (ckpt_max_to_keep
                            if strategy.extended.should_checkpoint else 1)
        manager_dir = get_filepath(strategy, model_dir)
        manager = tf.train.CheckpointManager(
            ckpt, manager_dir, max_to_keep=ckpt_max_to_keep)
        latest_ckpt = tf.train.latest_checkpoint(model_dir)

        if latest_ckpt:
          # load latest checkpoint being trained
          ckpt.restore(latest_ckpt).expect_partial()
        elif base_ckpt:
          # load a pre-trained checkpoint
          load_base_ckpt(model, base_ckpt, base_ckpt_type, base_ckpt_ver,
                         train_input, unpad_gt_tensors)

        # get trainable variables
        train_vars = get_train_vars(model, train_config)

        # define training step
        def train_step_fn(features: Dict, labels: Dict):
          """Single train step."""
          loss = eager_train_step(
              model,
              train_vars,
              features,
              labels,
              unpad_gt_tensors,
              optimizer,
              learning_rate=learning_rate_fn(),
              add_regularization_loss=add_regularization_loss,
              clip_gradient_norm=clip_gradient_norm,
              global_step=global_step,
              num_replicas=strategy.num_replicas_in_sync,
          )
          global_step.assign_add(1)
          return loss

        def _sample_and_train(strategy, train_step_fn, data_iterator):
          features, labels = data_iterator.next()
          per_replica_losses = strategy.run(
              train_step_fn, args=(features, labels))
          return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                 per_replica_losses, axis=None)

        @tf.function
        def _dist_train_step(data_iterator):
          """A distributed train step."""
          if num_steps_per_iteration > 1:
            for _ in tf.range(num_steps_per_iteration - 1):
              with tf.name_scope(''):
                _sample_and_train(strategy, train_step_fn, data_iterator)
          return _sample_and_train(strategy, train_step_fn, data_iterator)

        train_input_iter = iter(train_input)

        # save initialized version of checkpoint
        if int(global_step.value()) == 0:
          manager.save()

        ckpt_step = int(global_step.value())
        logged_step = global_step.value()

        # proceed with training
        last_step_time = time.time()
        for _ in range(global_step.value(), train_config.num_steps,
                       num_steps_per_iteration):
          # execute a step (forward pass + backward pass)
          loss = _dist_train_step(train_input_iter)

          # log time
          curr_step = global_step.value()
          time_taken = time.time() - last_step_time
          last_step_time = time.time()
          tf.summary.scalar(
              'steps_per_sec',
              num_steps_per_iteration * 1.0 / time_taken,
              step=global_step,
          )

          # log loss
          if curr_step - logged_step >= log_every_n:
            step_time = time_taken / num_steps_per_iteration
            step_msg = 'Step {} per-step time {:.3f}s loss={:.3f}'.format(
                curr_step, step_time, loss)
            v1.logging.info(step_msg)
            logged_step = curr_step

          # save checkpoint regularly
          if (curr_step - ckpt_step) >= ckpt_every_n:
            manager.save()
            ckpt_step = curr_step

  # Remove the checkpoint directories of the non-chief workers that
  # MultiWorkerMirroredStrategy forces us to save during sync distributed
  # training.
  clean_temporary_directories(strategy, manager_dir)
  clean_temporary_directories(strategy, summary_log_path)
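
# Hedged sketch (assumption, not the library's actual code): get_filepath()
# and clean_temporary_directories(), used by every variant in this file, are
# assumed to route non-chief workers to throwaway directories during
# multi-worker sync training and to delete them afterwards, roughly as below.
# `_task_id` is a private attribute and is only an assumption here; `tf` and
# `os` come from the module's existing imports.
def _example_get_filepath(strategy, filepath):
  """Return `filepath` on the chief, a temporary per-worker path otherwise."""
  if strategy.extended.should_checkpoint:
    return filepath
  task_id = getattr(strategy.extended, '_task_id', 0)  # assumption
  return os.path.join(filepath, 'temp_worker_{:03d}'.format(task_id))


def _example_clean_temporary_directories(strategy, filepath):
  """Delete `filepath` only when it is a non-chief worker's temporary copy."""
  if not strategy.extended.should_checkpoint:
    if tf.io.gfile.exists(filepath) and tf.io.gfile.isdir(filepath):
      tf.io.gfile.rmtree(filepath)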
def train_loop(pipeline_config_path,
               model_dir,
               config_override=None,
               train_steps=None,
               use_tpu=False,
               save_final_config=False,
               checkpoint_every_n=1000,
               checkpoint_max_to_keep=7,
               **kwargs):
  """Trains a model using eager + functions.

  This method:
    1. Processes the pipeline configs
    2. (Optionally) saves the as-run config
    3. Builds the model & optimizer
    4. Gets the training input data
    5. Loads a fine-tuning detection or classification checkpoint if requested
    6. Loops over the train data, executing distributed training steps inside
       tf.functions.
    7. Checkpoints the model every `checkpoint_every_n` training steps.
    8. Logs the training metrics as TensorBoard summaries.

  Args:
    pipeline_config_path: A path to a pipeline config file.
    model_dir: The directory to save checkpoints and summaries to.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training
      steps is set from the `TrainConfig` proto.
    use_tpu: Boolean, whether training and evaluation should run on TPU.
    save_final_config: Whether to save the final config (obtained after
      applying overrides) to `model_dir`.
    checkpoint_every_n: Checkpoint every n training steps.
    checkpoint_max_to_keep: int, the number of most recent checkpoints to keep
      in the model directory.
    **kwargs: Additional keyword arguments for configuration override.
  """
  ## Parse the configs
  get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
      'get_configs_from_pipeline_file']
  merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
      'merge_external_params_with_configs']
  create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
      'create_pipeline_proto_from_configs']

  configs = get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      'train_steps': train_steps,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  })
  configs = merge_external_params_with_configs(
      configs, None, kwargs_dict=kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']

  unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors
  add_regularization_loss = train_config.add_regularization_loss
  clip_gradients_value = None
  if train_config.gradient_clipping_by_norm > 0:
    clip_gradients_value = train_config.gradient_clipping_by_norm

  # Update train_steps from the config, but only when a non-zero value is
  # provided.
  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  if kwargs['use_bfloat16']:
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(
        'mixed_bfloat16')

  if train_config.load_all_detection_checkpoint_vars:
    raise ValueError('train_pb2.load_all_detection_checkpoint_vars '
                     'unsupported in TF2')

  config_util.update_fine_tune_checkpoint_type(train_config)
  fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
  fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

  # Write the as-run pipeline config to disk.
  if save_final_config:
    pipeline_config_final = create_pipeline_proto_from_configs(configs)
    config_util.save_pipeline_config(pipeline_config_final, model_dir)

  # Build the model, optimizer, and training input
  strategy = tf.compat.v2.distribute.get_strategy()
  with strategy.scope():
    detection_model = model_builder.build(
        model_config=model_config, is_training=True)

    def train_dataset_fn(input_context):
      """Callable to create train input."""
      # Create the inputs.
      train_input = inputs.train_input(
          train_config=train_config,
          train_input_config=train_input_config,
          model_config=model_config,
          model=detection_model,
          input_context=input_context)
      train_input = train_input.repeat()
      return train_input

    train_input = strategy.experimental_distribute_datasets_from_function(
        train_dataset_fn)

    global_step = tf.Variable(
        0,
        trainable=False,
        dtype=tf.compat.v2.dtypes.int64,
        name='global_step',
        aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
    optimizer, (learning_rate,) = optimizer_builder.build(
        train_config.optimizer, global_step=global_step)

    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      learning_rate_fn = lambda: learning_rate

  ## Train the model
  # Get the appropriate filepath (temporary or not) based on whether the worker
  # is the chief.
  summary_writer_filepath = get_filepath(strategy,
                                         os.path.join(model_dir, 'train'))
  summary_writer = tf.compat.v2.summary.create_file_writer(
      summary_writer_filepath)

  if use_tpu:
    num_steps_per_iteration = 100
  else:
    # TODO(b/135933080) Explore setting to 100 when GPU performance issues
    # are fixed.
    num_steps_per_iteration = 1

  with summary_writer.as_default():
    with strategy.scope():
      with tf.compat.v2.summary.record_if(
          lambda: global_step % num_steps_per_iteration == 0):
        # Load a fine-tuning checkpoint.
        if train_config.fine_tune_checkpoint:
          load_fine_tune_checkpoint(
              detection_model, train_config.fine_tune_checkpoint,
              fine_tune_checkpoint_type, fine_tune_checkpoint_version,
              train_input, unpad_groundtruth_tensors)

        ckpt = tf.compat.v2.train.Checkpoint(
            step=global_step, model=detection_model, optimizer=optimizer)

        manager_dir = get_filepath(strategy, model_dir)
        if not strategy.extended.should_checkpoint:
          checkpoint_max_to_keep = 1
        manager = tf.compat.v2.train.CheckpointManager(
            ckpt, manager_dir, max_to_keep=checkpoint_max_to_keep)

        # We use the following instead of manager.latest_checkpoint because
        # manager_dir does not point to the model directory when we are running
        # in a worker.
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        ckpt.restore(latest_checkpoint)

        def train_step_fn(features, labels):
          """Single train step."""
          loss = eager_train_step(
              detection_model,
              features,
              labels,
              unpad_groundtruth_tensors,
              optimizer,
              learning_rate=learning_rate_fn(),
              add_regularization_loss=add_regularization_loss,
              clip_gradients_value=clip_gradients_value,
              global_step=global_step,
              num_replicas=strategy.num_replicas_in_sync)
          global_step.assign_add(1)
          return loss

        def _sample_and_train(strategy, train_step_fn, data_iterator):
          features, labels = data_iterator.next()
          per_replica_losses = strategy.run(
              train_step_fn, args=(features, labels))
          # TODO(anjalisridhar): explore if it is safe to remove the
          # num_replicas scaling of the loss and switch this to a
          # ReduceOp.Mean.
          return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                 per_replica_losses, axis=None)

        @tf.function
        def _dist_train_step(data_iterator):
          """A distributed train step."""
          if num_steps_per_iteration > 1:
            for _ in tf.range(num_steps_per_iteration - 1):
              _sample_and_train(strategy, train_step_fn, data_iterator)
          return _sample_and_train(strategy, train_step_fn, data_iterator)

        train_input_iter = iter(train_input)

        if int(global_step.value()) == 0:
          manager.save()

        checkpointed_step = int(global_step.value())
        logged_step = global_step.value()

        last_step_time = time.time()
        for _ in range(global_step.value(), train_steps,
                       num_steps_per_iteration):
          loss = _dist_train_step(train_input_iter)

          time_taken = time.time() - last_step_time
          last_step_time = time.time()
          tf.compat.v2.summary.scalar(
              'steps_per_sec',
              num_steps_per_iteration * 1.0 / time_taken,
              step=global_step)

          if global_step.value() - logged_step >= 100:
            tf.logging.info(
                'Step {} per-step time {:.3f}s loss={:.3f}'.format(
                    global_step.value(), time_taken / num_steps_per_iteration,
                    loss))
            logged_step = global_step.value()

          if ((int(global_step.value()) - checkpointed_step) >=
              checkpoint_every_n):
            manager.save()
            checkpointed_step = int(global_step.value())

  # Remove the checkpoint directories of the non-chief workers that
  # MultiWorkerMirroredStrategy forces us to save during sync distributed
  # training.
  clean_temporary_directories(strategy, manager_dir)
  clean_temporary_directories(strategy, summary_writer_filepath)
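
# Simplified sketch (assumption, not the TF Object Detection API's actual
# eager_train_step): the per-replica step used by the loops above is expected
# to run a forward pass under a GradientTape, add regularization losses,
# optionally clip gradients by global norm, and apply the optimizer update.
# `loss_fn` is a hypothetical callable standing in for the model-specific
# detection loss computation.
def _example_eager_train_step(model, features, labels, optimizer, loss_fn,
                              add_regularization_loss=True,
                              clip_gradients_value=None):
  """Illustrative per-replica train step; returns the scalar total loss."""
  with tf.GradientTape() as tape:
    total_loss = loss_fn(model, features, labels)
    if add_regularization_loss and model.losses:
      total_loss += tf.add_n(model.losses)
  variables = model.trainable_variables
  gradients = tape.gradient(total_loss, variables)
  if clip_gradients_value:
    gradients, _ = tf.clip_by_global_norm(gradients, clip_gradients_value)
  optimizer.apply_gradients(zip(gradients, variables))
  return total_loss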
def train_loop(pipeline_config_path,
               model_dir,
               config_override=None,
               train_steps=None,
               use_tpu=False,
               save_final_config=False,
               checkpoint_every_n=1000,
               checkpoint_max_to_keep=7,
               record_summaries=True,
               performance_summary_exporter=None,
               num_steps_per_iteration=NUM_STEPS_PER_ITERATION,
               **kwargs):
  """Trains a model using eager + functions.

  This variant keeps the upstream structure but uses a no-op summary writer
  and leaves checkpoint saving commented out.
  """
  get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
      "get_configs_from_pipeline_file"]
  merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
      "merge_external_params_with_configs"]
  create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
      "create_pipeline_proto_from_configs"]
  steps_per_sec_list = []

  configs = get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      "train_steps": train_steps,
      "use_bfloat16": configs["train_config"].use_bfloat16 and use_tpu,
  })
  configs = merge_external_params_with_configs(
      configs, None, kwargs_dict=kwargs)
  model_config = configs["model"]
  train_config = configs["train_config"]
  train_input_config = configs["train_input_config"]

  unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors
  add_regularization_loss = train_config.add_regularization_loss
  clip_gradients_value = None
  if train_config.gradient_clipping_by_norm > 0:
    clip_gradients_value = train_config.gradient_clipping_by_norm

  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  if kwargs["use_bfloat16"]:
    tf.compat.v2.keras.mixed_precision.set_global_policy("mixed_bfloat16")

  if train_config.load_all_detection_checkpoint_vars:
    raise ValueError(
        "train_pb2.load_all_detection_checkpoint_vars unsupported in TF2")

  config_util.update_fine_tune_checkpoint_type(train_config)
  fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
  fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

  strategy = tf.compat.v2.distribute.get_strategy()
  with strategy.scope():
    detection_model = MODEL_BUILD_UTIL_MAP["detection_model_fn_base"](
        model_config=model_config, is_training=True)

    def train_dataset_fn(input_context):
      train_input = inputs.train_input(
          train_config=train_config,
          train_input_config=train_input_config,
          model_config=model_config,
          model=detection_model,
          input_context=input_context,
      )
      train_input = train_input.repeat()
      return train_input

    train_input = strategy.experimental_distribute_datasets_from_function(
        train_dataset_fn)

    global_step = tf.Variable(
        0,
        trainable=False,
        dtype=tf.compat.v2.dtypes.int64,
        name="global_step",
        aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA,
    )
    optimizer, (learning_rate,) = optimizer_builder.build(
        train_config.optimizer, global_step=global_step)

    if train_config.optimizer.use_moving_average:
      _ensure_model_is_built(detection_model, train_input,
                             unpad_groundtruth_tensors)
      optimizer.shadow_copy(detection_model)

    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      learning_rate_fn = lambda: learning_rate

  summary_writer_filepath = get_filepath(strategy,
                                         os.path.join(model_dir, "train"))
  # summary_writer = tf.compat.v2.summary.create_file_writer(
  #     summary_writer_filepath)
  summary_writer = tf2.summary.create_noop_writer()

  with summary_writer.as_default():
    with strategy.scope():
      with tf2.summary.record_if(
          lambda: global_step % num_steps_per_iteration == 0):
        if train_config.fine_tune_checkpoint:
          load_fine_tune_checkpoint(
              detection_model,
              train_config.fine_tune_checkpoint,
              fine_tune_checkpoint_type,
              fine_tune_checkpoint_version,
              train_config.run_fine_tune_checkpoint_dummy_computation,
              train_input,
              unpad_groundtruth_tensors,
          )

        # ckpt = tf.compat.v2.train.Checkpoint(
        #     step=global_step, model=detection_model, optimizer=optimizer)
        # manager_dir = get_filepath(strategy, model_dir)
        # manager = tf.compat.v2.train.CheckpointManager(
        #     ckpt, manager_dir, max_to_keep=1)
        # latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        # ckpt.restore(latest_checkpoint)

        def train_step_fn(features, labels):
          loss = eager_train_step(
              detection_model,
              features,
              labels,
              unpad_groundtruth_tensors,
              optimizer,
              learning_rate=learning_rate_fn(),
              add_regularization_loss=add_regularization_loss,
              clip_gradients_value=clip_gradients_value,
              global_step=global_step,
              num_replicas=strategy.num_replicas_in_sync,
          )
          global_step.assign_add(1)
          return loss

        def _sample_and_train(strategy, train_step_fn, data_iterator):
          features, labels = data_iterator.next()
          if hasattr(tf.distribute.Strategy, "run"):
            per_replica_losses = strategy.run(
                train_step_fn, args=(features, labels))
          else:
            per_replica_losses = strategy.experimental_run_v2(
                train_step_fn, args=(features, labels))
          return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                 per_replica_losses, axis=None)

        @tf.function
        def _dist_train_step(data_iterator):
          if num_steps_per_iteration > 1:
            for _ in tf.range(num_steps_per_iteration - 1):
              with tf.name_scope(""):
                _sample_and_train(strategy, train_step_fn, data_iterator)
          return _sample_and_train(strategy, train_step_fn, data_iterator)

        train_input_iter = iter(train_input)

        checkpointed_step = int(global_step.value())
        logged_step = global_step.value()

        last_step_time = time.time()
        for _ in range(global_step.value(), train_steps,
                       num_steps_per_iteration):
          loss = _dist_train_step(train_input_iter)

          time_taken = time.time() - last_step_time
          last_step_time = time.time()
          steps_per_sec = num_steps_per_iteration * 1.0 / time_taken

          tf.compat.v2.summary.scalar(
              "steps_per_sec", steps_per_sec, step=global_step)
          steps_per_sec_list.append(steps_per_sec)

          if global_step.value() - logged_step >= 100:
            tf.logging.info(
                "Step {} per-step time {:.3f}s loss={:.3f}".format(
                    global_step.value(), time_taken / num_steps_per_iteration,
                    loss))
            logged_step = global_step.value()

          if ((int(global_step.value()) - checkpointed_step) >=
              checkpoint_every_n):
            # manager.save()
            checkpointed_step = int(global_step.value())

  # clean_temporary_directories(strategy, manager_dir)
  clean_temporary_directories(strategy, summary_writer_filepath)
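
# Hedged follow-up sketch (not in the original of this variant): the function
# above accumulates steps_per_sec_list but never exports it. If its
# performance_summary_exporter argument were to be honored the way the first
# variant in this file does, the export would look roughly like this; the
# helper name is hypothetical and `np` comes from the module's numpy import.
def _example_export_performance(performance_summary_exporter,
                                steps_per_sec_list, last_loss, use_bfloat16):
  if performance_summary_exporter is None or not steps_per_sec_list:
    return
  metrics = {
      "steps_per_sec": np.mean(steps_per_sec_list),
      "steps_per_sec_p50": np.median(steps_per_sec_list),
      "steps_per_sec_max": max(steps_per_sec_list),
      "last_batch_loss": float(last_loss),
  }
  performance_summary_exporter(metrics, "bf16" if use_bfloat16 else "fp32")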