Example #1
def train_loop(
    pipeline_config_path,
    model_dir,
    val_checkpoint_dir,
    config_override=None,
    train_steps=None,
    use_tpu=False,
    save_final_config=False,
    checkpoint_every_n=1000,
    checkpoint_max_to_keep=7,
    record_summaries=True,
    performance_summary_exporter=None,
    **kwargs):
  """Trains a model using eager + functions.

  This method:
    1. Processes the pipeline configs
    2. (Optionally) saves the as-run config
    3. Builds the model & optimizer
    4. Gets the training input data
    5. Loads a fine-tuning detection or classification checkpoint if requested
    6. Loops over the train data, executing distributed training steps inside
       tf.functions.
    7. Checkpoints the model every `checkpoint_every_n` training steps.
    8. Logs the training metrics as TensorBoard summaries.

  Args:
    pipeline_config_path: A path to a pipeline config file.
    model_dir:
      The directory to save checkpoints and summaries to.
    val_checkpoint_dir:
      The directory to save validation checkpoint.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training steps
      is set from the `TrainConfig` proto.
    use_tpu: Boolean, whether training and evaluation should run on TPU.
    save_final_config: Whether to save final config (obtained after applying
      overrides) to `model_dir`.
    checkpoint_every_n:
      Checkpoint every n training steps.
    checkpoint_max_to_keep:
      int, the number of most recent checkpoints to keep in the model directory.
    record_summaries: Boolean, whether or not to record summaries.
    performance_summary_exporter: function for exporting performance metrics.
    **kwargs: Additional keyword arguments for configuration override.
  """

  print('START train loop function ========================')

  ## Parse the configs
  get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
      'get_configs_from_pipeline_file']
  merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
      'merge_external_params_with_configs']
  create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
      'create_pipeline_proto_from_configs']
  steps_per_sec_list = []

  configs = get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      'train_steps': train_steps,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  })
  configs = merge_external_params_with_configs(
      configs, None, kwargs_dict=kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']

  unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors
  add_regularization_loss = train_config.add_regularization_loss
  clip_gradients_value = None
  if train_config.gradient_clipping_by_norm > 0:
    clip_gradients_value = train_config.gradient_clipping_by_norm

  # update train_steps from config but only when non-zero value is provided
  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  if kwargs['use_bfloat16']:
    tf.compat.v2.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')

  if train_config.load_all_detection_checkpoint_vars:
    raise ValueError('train_pb2.load_all_detection_checkpoint_vars '
                     'unsupported in TF2')

  config_util.update_fine_tune_checkpoint_type(train_config)
  fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
  fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

  # Write the as-run pipeline config to disk.
  if save_final_config:
    tf.logging.info('Saving pipeline config file to directory {}'.format(
        model_dir))
    pipeline_config_final = create_pipeline_proto_from_configs(configs)
    config_util.save_pipeline_config(pipeline_config_final, model_dir)

  # Build the model, optimizer, and training input
  strategy = tf.compat.v2.distribute.get_strategy()
  with strategy.scope():
    detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
        model_config=model_config, is_training=True)

    def train_dataset_fn(input_context):
      """Callable to create train input."""
      # Create the inputs.
      train_input = inputs.train_input(
          train_config=train_config,
          train_input_config=train_input_config,
          model_config=model_config,
          model=detection_model,
          input_context=input_context)
      train_input = train_input.repeat()
      return train_input

    train_input = strategy.experimental_distribute_datasets_from_function(
        train_dataset_fn)


    global_step = tf.Variable(
        0, trainable=False, dtype=tf.compat.v2.dtypes.int64, name='global_step',
        aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
    optimizer, (learning_rate,) = optimizer_builder.build(
        train_config.optimizer, global_step=global_step)

    # We run the detection_model on dummy inputs in order to ensure that the
    # model and all its variables have been properly constructed. Specifically,
    # this is currently necessary prior to (potentially) creating shadow copies
    # of the model variables for the EMA optimizer.
    if train_config.optimizer.use_moving_average:
      _ensure_model_is_built(detection_model, train_input,
                             unpad_groundtruth_tensors)
      optimizer.shadow_copy(detection_model)

    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      learning_rate_fn = lambda: learning_rate

  ## Train the model
  # Get the appropriate filepath (temporary or not) based on whether the worker
  # is the chief.
  summary_writer_filepath = get_filepath(strategy,
                                         os.path.join(model_dir, 'train'))
  if record_summaries:
    summary_writer = tf.compat.v2.summary.create_file_writer(
        summary_writer_filepath)
  else:
    summary_writer = tf2.summary.create_noop_writer()

  if use_tpu:
    num_steps_per_iteration = 100
  else:
    # TODO(b/135933080) Explore setting to 100 when GPU performance issues
    # are fixed.
    num_steps_per_iteration = 1

  with summary_writer.as_default():
    with strategy.scope():
      with tf.compat.v2.summary.record_if(
          lambda: global_step % num_steps_per_iteration == 0):
        # Load a fine-tuning checkpoint.
        if train_config.fine_tune_checkpoint:
          load_fine_tune_checkpoint(
              detection_model, train_config.fine_tune_checkpoint,
              fine_tune_checkpoint_type, fine_tune_checkpoint_version,
              train_config.run_fine_tune_checkpoint_dummy_computation,
              train_input, unpad_groundtruth_tensors)

        ckpt = tf.compat.v2.train.Checkpoint(
            step=global_step, model=detection_model, optimizer=optimizer)
        val_ckpt = tf.compat.v2.train.Checkpoint(
            step=global_step, model=detection_model, optimizer=optimizer)

        manager_dir = get_filepath(strategy, model_dir)
        val_manager_dir = get_filepath(strategy, val_checkpoint_dir)



        # NOTE: this variant always keeps only the most recent checkpoint,
        # regardless of strategy.extended.should_checkpoint (the upstream loop
        # only forces max_to_keep=1 on non-chief workers).
        checkpoint_max_to_keep = 1
        manager = tf.compat.v2.train.CheckpointManager(
            ckpt, manager_dir, max_to_keep=checkpoint_max_to_keep)
        val_manager = tf.compat.v2.train.CheckpointManager(
            val_ckpt, val_manager_dir, max_to_keep=checkpoint_max_to_keep)

        model_checkpoint_callback = tfc.ModelCheckpoint(val_manager)
        early_stopping_callback = tfc.EarlyStopping(
            min_delta=0.0001, patience=5, mode='min')
        train_logger_callback = tfc.TrainLogger(model_dir, 'logs.txt')
        cancellation_point = tfc.CancellationPoint()
        

        # We use the following instead of manager.latest_checkpoint because
        # manager_dir does not point to the model directory when we are running
        # in a worker.
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        ckpt.restore(latest_checkpoint)
        val_ckpt.restore(latest_checkpoint)

        def train_step_fn(features, labels):
          """Single train step."""
          loss = eager_train_step(
              detection_model,
              features,
              labels,
              unpad_groundtruth_tensors,
              optimizer,
              learning_rate=learning_rate_fn(),
              add_regularization_loss=add_regularization_loss,
              clip_gradients_value=clip_gradients_value,
              global_step=global_step,
              num_replicas=strategy.num_replicas_in_sync)
          global_step.assign_add(1)
          return loss

        def _sample_and_train(strategy, train_step_fn, data_iterator):
          features, labels = data_iterator.next()
          if hasattr(tf.distribute.Strategy, 'run'):
            per_replica_losses = strategy.run(
                train_step_fn, args=(features, labels))
          else:
            per_replica_losses = strategy.experimental_run_v2(
                train_step_fn, args=(features, labels))
          # TODO(anjalisridhar): explore if it is safe to remove the
          ## num_replicas scaling of the loss and switch this to a ReduceOp.Mean
          return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                 per_replica_losses, axis=None)

        @tf.function
        def _dist_train_step(data_iterator):
          """A distributed train step."""

          if num_steps_per_iteration > 1:
            for _ in tf.range(num_steps_per_iteration - 1):
              # Following suggestion on yaqs/5402607292645376
              with tf.name_scope(''):
                _sample_and_train(strategy, train_step_fn, data_iterator)

          return _sample_and_train(strategy, train_step_fn, data_iterator)

        train_input_iter = iter(train_input)

        if int(global_step.value()) == 0:
          manager.save()

        checkpointed_step = int(global_step.value())
        logged_step = global_step.value()

        # num_epochs = (train_steps - global_step.value()) // num_steps_per_iteration

        last_step_time = time.time()
        for epoch, _ in enumerate(
            range(global_step.value(), train_steps, num_steps_per_iteration)):

          loss = _dist_train_step(train_input_iter)

          time_taken = time.time() - last_step_time
          last_step_time = time.time()
          steps_per_sec = num_steps_per_iteration * 1.0 / time_taken

          tf.compat.v2.summary.scalar(
              'steps_per_sec', steps_per_sec, step=global_step)

          steps_per_sec_list.append(steps_per_sec)

          if global_step.value() - logged_step >= 100:
            tf.logging.info(
                'Step {} per-step time {:.3f}s loss={:.3f}'.format(
                    global_step.value(), time_taken / num_steps_per_iteration,
                    loss))

            manager.save()
            checkpointed_step = int(global_step.value())

            log_metrics = eval_continuously(
                pipeline_config_path, model_dir=model_dir,
                checkpoint_dir=model_dir, timeout=20)
            log_metrics['train_total_loss'] = loss

            model_checkpoint_callback.step(epoch, log_metrics['Loss/total_loss'])
            stop_training = early_stopping_callback.step(
                epoch, log_metrics['Loss/total_loss'])
            train_logger_callback.log(log_metrics)

            if stop_training or cancellation_point.check():
              break

            print(log_metrics)
            logged_step = global_step.value()

    

  # Remove the checkpoint directories of the non-chief workers that
  # MultiWorkerMirroredStrategy forces us to save during sync distributed
  # training.
  clean_temporary_directories(strategy, manager_dir)
  clean_temporary_directories(strategy, summary_writer_filepath)
  # TODO(pkanwar): add accuracy metrics.
  if performance_summary_exporter is not None:
    metrics = {
        'steps_per_sec': np.mean(steps_per_sec_list),
        'steps_per_sec_p50': np.median(steps_per_sec_list),
        'steps_per_sec_max': max(steps_per_sec_list),
        'last_batch_loss': float(loss)
    }
    mixed_precision = 'bf16' if kwargs['use_bfloat16'] else 'fp32'
    performance_summary_exporter(metrics, mixed_precision)
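
A minimal sketch of how a variant like the one above might be invoked, roughly following the pattern of the Object Detection API's model_main_tf2.py binary (create a distribution strategy, then call the loop inside its scope). The paths and step counts below are hypothetical placeholders.

import tensorflow.compat.v2 as tf

if __name__ == '__main__':
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    train_loop(
        pipeline_config_path='configs/pipeline.config',  # hypothetical path
        model_dir='training/model_dir',                  # hypothetical path
        val_checkpoint_dir='training/val_ckpt',          # hypothetical path
        train_steps=25000,
        checkpoint_every_n=1000,
        record_summaries=True)
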
Example #2
def train_loop(pipeline_config_path,
               model_dir,
               config_override=None,
               train_steps=None,
               use_tpu=False,
               save_final_config=False,
               checkpoint_every_n=1000,
               checkpoint_max_to_keep=7,
               record_summaries=True,
               performance_summary_exporter=None,
               num_steps_per_iteration=NUM_STEPS_PER_ITERATION,
               **kwargs):
    # """Trains a model using eager + functions.

    config_override = None
    configs = config_util.get_configs_from_pipeline_file(
        pipeline_config_path, config_override=config_override)
    kwargs.update({
        'train_steps':
        train_steps,
        'use_bfloat16':
        configs['train_config'].use_bfloat16 and use_tpu
    })
    configs = config_util.merge_external_params_with_configs(
        configs, None, kwargs_dict=kwargs)
    model_config = configs['model']
    train_config = configs['train_config']
    train_input_config = configs['train_input_config']
    unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors  # False
    add_regularization_loss = train_config.add_regularization_loss  # True
    clip_gradients_value = None
    if train_config.gradient_clipping_by_norm > 0:  # Not run
        clip_gradients_value = train_config.gradient_clipping_by_norm

    # update train_steps from config but only when non-zero value is provided
    if train_steps is None and train_config.num_steps != 0:
        train_steps = train_config.num_steps

    steps_per_sec_list = []  # throughput samples appended in the training loop

    # NOTE: unlike the upstream loop, this variant sets the mixed-precision
    # policy unconditionally instead of gating it on kwargs['use_bfloat16'].
    tf.compat.v2.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    if train_config.load_all_detection_checkpoint_vars:
        raise ValueError('train_pb2.load_all_detection_checkpoint_vars '
                         'unsupported in TF2')

    config_util.update_fine_tune_checkpoint_type(train_config)
    fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type  # 'detection'
    fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

    # Build the model, optimizer, and training input
    strategy = tf.compat.v2.distribute.get_strategy()
    from object_detection import inputs
    from object_detection.builders import optimizer_builder
    from object_detection.utils import variables_helper
    with strategy.scope():
        detection_model = model_builder.build(model_config=model_config,
                                              is_training=True)

        def train_dataset_fn(input_context):
            """Callable to create train input."""
            # Create the inputs.
            train_input = inputs.train_input(
                train_config=train_config,
                train_input_config=train_input_config,
                model_config=model_config,
                model=detection_model,
                input_context=input_context)
            train_input = train_input.repeat()
            return train_input

        train_input = strategy.experimental_distribute_datasets_from_function(
            train_dataset_fn)
        global_step = tf.Variable(
            0,
            trainable=False,
            dtype=tf.compat.v2.dtypes.int64,
            name='global_step',
            aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
        optimizer, (learning_rate, ) = optimizer_builder.build(
            train_config.optimizer, global_step=global_step)

        if callable(learning_rate):
            learning_rate_fn = learning_rate
        else:

            def learning_rate_fn():
                return learning_rate

    # Train the model
    # Get the appropriate filepath (temporary or not) based on whether the worker
    # is the chief.
    summary_writer_filepath = get_filepath(strategy,
                                           os.path.join(model_dir, 'train'))
    if record_summaries:
        summary_writer = tf.compat.v2.summary.create_file_writer(
            summary_writer_filepath)
    else:
        #summary_writer = tf2.summary.create_noop_writer()
        summary_writer = tf.summary.create_noop_writer()

    with summary_writer.as_default():
        with strategy.scope():
            with tf.compat.v2.summary.record_if(
                    lambda: global_step % num_steps_per_iteration == 0):
                # Load a fine-tuning checkpoint.
                if train_config.fine_tune_checkpoint:
                    variables_helper.ensure_checkpoint_supported(
                        train_config.fine_tune_checkpoint,
                        fine_tune_checkpoint_type, model_dir)
                    load_fine_tune_checkpoint(
                        detection_model, train_config.fine_tune_checkpoint,
                        fine_tune_checkpoint_type,
                        fine_tune_checkpoint_version, train_config.
                        run_fine_tune_checkpoint_dummy_computation,
                        train_input, unpad_groundtruth_tensors)

                ckpt = tf.compat.v2.train.Checkpoint(step=global_step,
                                                     model=detection_model,
                                                     optimizer=optimizer)

                manager_dir = get_filepath(strategy, model_dir)
                if not strategy.extended.should_checkpoint:
                    checkpoint_max_to_keep = 1
                manager = tf.compat.v2.train.CheckpointManager(
                    ckpt, manager_dir, max_to_keep=checkpoint_max_to_keep)

                # We use the following instead of manager.latest_checkpoint because
                # manager_dir does not point to the model directory when we are running
                # in a worker.
                latest_checkpoint = tf.train.latest_checkpoint(model_dir)
                ckpt.restore(latest_checkpoint)

                def train_step_fn(features, labels):
                    """Single train step."""
                    loss = eager_train_step(
                        detection_model,
                        features,
                        labels,
                        unpad_groundtruth_tensors,
                        optimizer,
                        learning_rate=learning_rate_fn(),
                        add_regularization_loss=add_regularization_loss,
                        clip_gradients_value=clip_gradients_value,
                        global_step=global_step,
                        num_replicas=strategy.num_replicas_in_sync)
                    global_step.assign_add(1)
                    return loss

                def _sample_and_train(strategy, train_step_fn, data_iterator):
                    features, labels = data_iterator.next()
                    if hasattr(tf.distribute.Strategy, 'run'):
                        per_replica_losses = strategy.run(train_step_fn,
                                                          args=(features,
                                                                labels))
                    else:
                        per_replica_losses = strategy.experimental_run_v2(
                            train_step_fn, args=(features, labels))
                    # TODO(anjalisridhar): explore if it is safe to remove the
                    # num_replicas scaling of the loss and switch this to a ReduceOp.Mean
                    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_losses,
                                           axis=None)

                @tf.function
                def _dist_train_step(data_iterator):
                    """A distributed train step."""

                    if num_steps_per_iteration > 1:
                        for _ in tf.range(num_steps_per_iteration - 1):
                            # Following suggestion on yaqs/5402607292645376
                            with tf.name_scope(''):
                                _sample_and_train(strategy, train_step_fn,
                                                  data_iterator)

                    return _sample_and_train(strategy, train_step_fn,
                                             data_iterator)

                train_input_iter = iter(train_input)

                if int(global_step.value()) == 0:
                    manager.save()

                checkpointed_step = int(global_step.value())
                logged_step = global_step.value()

                last_step_time = time.time()
                for _ in range(global_step.value(), train_steps,
                               num_steps_per_iteration):

                    loss = _dist_train_step(train_input_iter)

                    time_taken = time.time() - last_step_time
                    last_step_time = time.time()
                    steps_per_sec = num_steps_per_iteration * 1.0 / time_taken

                    tf.compat.v2.summary.scalar('steps_per_sec',
                                                steps_per_sec,
                                                step=global_step)

                    steps_per_sec_list.append(steps_per_sec)

                    if global_step.value() - logged_step >= 100:
                        tf.logging.info(
                            'Step {} per-step time {:.3f}s loss={:.3f}'.format(
                                global_step.value(),
                                time_taken / num_steps_per_iteration, loss))
                        logged_step = global_step.value()

                    if ((int(global_step.value()) - checkpointed_step) >=
                            checkpoint_every_n):
                        manager.save()
                        checkpointed_step = int(global_step.value())
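
The _sample_and_train helper in these variants uses the standard tf.distribute pattern: run the step function on every replica, then SUM-reduce the per-replica losses. A self-contained sketch of just that pattern, with a toy step function standing in for eager_train_step (all names and shapes here are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices(tf.random.uniform([8, 4])).batch(4)
dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def distributed_step(iterator):
    def step_fn(batch):
        # Toy per-replica "loss", pre-divided by the replica count so that the
        # SUM reduction below behaves like a global mean.
        return tf.reduce_mean(batch) / strategy.num_replicas_in_sync

    per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

print(distributed_step(dist_iterator).numpy())
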
Example #3
def train_loop(config_path: str,
               model_dir: str,
               config_override: Optional[
                   pipeline_pb2.TrainEvalPipelineConfig] = None,
               train_steps: Optional[int] = None,
               use_tpu: bool = False,
               save_final_config: bool = False,
               log_every_n: int = 100,
               ckpt_every_n: int = 1000,
               ckpt_max_to_keep: int = 7,
               record_summaries: bool = True,
               **kwargs) -> None:
    """Trains a model using eager + functions.
    
    This method:
    1. Processes the pipeline configs
    2. (Optionally) saves the as-run config
    3. Builds the model & optimizer
    4. Gets the training input data
    5. Loads a fine-tuning detection or classification checkpoint if requested
    6. Loops over the train data, executing distributed training steps inside tf.functions.
    7. Checkpoints the model every `ckpt_every_n` training steps.
    8. Logs the training metrics as TensorBoard summaries.
    
    Args:
        config_path: A path to a pipeline config file.
        model_dir: The directory to save checkpoints and summaries to.
        config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to override the config from `config_path`.
        train_steps: Number of training steps. If None, training steps from `TrainConfig` proto will be adopted.
        use_tpu: Boolean, whether training and evaluation should run on TPU.
        save_final_config: Whether to save final config (obtained after applying overrides) to `model_dir`.
        log_every_n: Log total loss every n training steps.
        ckpt_every_n: Checkpoint every n training steps.
        ckpt_max_to_keep: int, the number of most recent checkpoints to keep in the model directory.
        record_summaries: Boolean, whether or not to record summaries.
        **kwargs: Additional keyword arguments for configuration override.
    """

    # parse config
    configs = config_util.get_configs_from_pipeline_file(
        config_path, config_override=config_override)
    kwargs.update({
        'train_steps':
        train_steps,
        'use_bfloat16':
        configs['train_config'].use_bfloat16 and use_tpu,
    })
    configs = config_util.merge_external_params_with_configs(
        configs, None, kwargs_dict=kwargs)

    model_config = configs['model']
    train_config = configs['train_config']
    train_input_config = configs['train_input_config']

    unpad_gt_tensors = train_config.unpad_groundtruth_tensors
    add_regularization_loss = train_config.add_regularization_loss
    clip_gradient_norm = None

    if train_config.gradient_clipping_by_norm > 0:
        clip_gradient_norm = train_config.gradient_clipping_by_norm

    if kwargs['use_bfloat16']:
        tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')

    if train_config.load_all_detection_checkpoint_vars:
        raise ValueError(
            'train_pb2.load_all_detection_checkpoint_vars unsupported in TF2')

    # base checkpoint to fine-tune from
    config_util.update_fine_tune_checkpoint_type(train_config)
    base_ckpt = train_config.fine_tune_checkpoint
    base_ckpt_type = train_config.fine_tune_checkpoint_type
    base_ckpt_ver = train_config.fine_tune_checkpoint_version

    # write the as-run pipeline config to disk
    if save_final_config:
        pipeline_config_final = config_util.create_pipeline_proto_from_configs(
            configs)
        config_util.save_pipeline_config(pipeline_config_final, model_dir)

    # build model, input, optimizer
    strategy = tf.distribute.get_strategy()
    with strategy.scope():
        # build model
        model = model_builder.build(model_config=model_config,
                                    is_training=True)

        # build input
        def train_dataset_fn(
                input_context: tf.distribute.InputContext) -> tf.data.Dataset:
            """Callable to create train input."""
            train_input = inputs.train_input(
                train_config=train_config,
                train_input_config=train_input_config,
                model_config=model_config,
                model=model,
                input_context=input_context,
            )
            train_input = train_input.repeat()

            return train_input

        train_input = strategy.experimental_distribute_datasets_from_function(
            train_dataset_fn)

        # build optimizer
        global_step = tf.Variable(
            0,
            trainable=False,
            dtype=tf.int64,
            name='global_step',
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )
        optimizer, (learning_rate, ) = optimizer_builder.build(
            train_config.optimizer, global_step=global_step)

        if callable(learning_rate):
            learning_rate_fn = learning_rate
        else:
            learning_rate_fn = lambda: learning_rate

    # prepare for training

    # get appropriate filepath (temporary or not) based on whether the worker is the chief
    summary_log_path = get_filepath(strategy, os.path.join(model_dir, 'train'))

    if record_summaries:
        summary_writer = tf.summary.create_file_writer(summary_log_path)
    else:
        summary_writer = tf.summary.create_noop_writer()

    if use_tpu:
        num_steps_per_iteration = 100
    else:
        num_steps_per_iteration = 1

    with summary_writer.as_default():
        with strategy.scope():
            with tf.summary.record_if(
                    lambda: global_step % num_steps_per_iteration == 0):
                # prepare checkpoint manager
                # (do not use manager.latest_checkpoint as manager_dir is not model_dir while running in worker)
                ckpt = tf.train.Checkpoint(model=model,
                                           step=global_step,
                                           optimizer=optimizer)
                ckpt_max_to_keep = ckpt_max_to_keep if strategy.extended.should_checkpoint else 1
                manager_dir = get_filepath(strategy, model_dir)
                manager = tf.train.CheckpointManager(
                    ckpt, manager_dir, max_to_keep=ckpt_max_to_keep)
                latest_ckpt = tf.train.latest_checkpoint(model_dir)

                if latest_ckpt:
                    # load latest checkpoint being trained
                    ckpt.restore(latest_ckpt).expect_partial()
                elif base_ckpt:
                    # load a pre-trained checkpoint
                    load_base_ckpt(model, base_ckpt, base_ckpt_type,
                                   base_ckpt_ver, train_input,
                                   unpad_gt_tensors)

                # get trainable variables
                train_vars = get_train_vars(model, train_config)

                # define training step
                def train_step_fn(features: Dict, labels: Dict):
                    """Single train step."""
                    loss = eager_train_step(
                        model,
                        train_vars,
                        features,
                        labels,
                        unpad_gt_tensors,
                        optimizer,
                        learning_rate=learning_rate_fn(),
                        add_regularization_loss=add_regularization_loss,
                        clip_gradient_norm=clip_gradient_norm,
                        global_step=global_step,
                        num_replicas=strategy.num_replicas_in_sync,
                    )
                    global_step.assign_add(1)

                    return loss

                def _sample_and_train(strategy, train_step_fn, data_iterator):
                    features, labels = data_iterator.next()
                    per_replica_losses = strategy.run(train_step_fn,
                                                      args=(features, labels))

                    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_losses,
                                           axis=None)

                @tf.function
                def _dist_train_step(data_iterator):
                    """A distributed train step."""
                    if num_steps_per_iteration > 1:
                        for _ in tf.range(num_steps_per_iteration - 1):
                            with tf.name_scope(''):
                                _sample_and_train(strategy, train_step_fn,
                                                  data_iterator)

                    return _sample_and_train(strategy, train_step_fn,
                                             data_iterator)

                train_input_iter = iter(train_input)

                # save initialized version of checkpoint
                if int(global_step.value()) == 0:
                    manager.save()

                ckpt_step = int(global_step.value())
                logged_step = global_step.value()

                # proceed with training
                last_step_time = time.time()
                for _ in range(global_step.value(), train_config.num_steps,
                               num_steps_per_iteration):
                    # execute a step (forward pass + backward pass)
                    loss = _dist_train_step(train_input_iter)

                    # log time
                    curr_step = global_step.value()
                    time_taken = time.time() - last_step_time
                    last_step_time = time.time()

                    tf.summary.scalar(
                        'steps_per_sec',
                        num_steps_per_iteration * 1.0 / time_taken,
                        step=global_step,
                    )

                    # log loss
                    if curr_step - logged_step >= log_every_n:
                        step_time = time_taken / num_steps_per_iteration
                        step_msg = 'Step {} per-step time {:.3f}s loss={:.3f}'.format(
                            curr_step, step_time, loss)
                        v1.logging.info(step_msg)
                        logged_step = curr_step

                    # save checkpoint regularly
                    if (curr_step - ckpt_step) >= ckpt_every_n:
                        manager.save()
                        ckpt_step = curr_step

    # remove checkpoint directories of non-chief workers that MultiWorkerMirroredStrategy forces us to save during sync
    # distributed training.
    clean_temporary_directories(strategy, manager_dir)
    clean_temporary_directories(strategy, summary_log_path)
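
The checkpoint handling in this example is the usual tf.train.Checkpoint / tf.train.CheckpointManager pairing: restore the latest checkpoint if one exists, otherwise fall back to the pre-trained base checkpoint, and save every ckpt_every_n steps. A stripped-down sketch of that pattern with a toy Keras model and a hypothetical checkpoint directory:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD()
global_step = tf.Variable(0, dtype=tf.int64)

ckpt = tf.train.Checkpoint(model=model, step=global_step, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, '/tmp/toy_ckpts', max_to_keep=3)

# Resume from the most recent checkpoint when one is available.
latest = tf.train.latest_checkpoint('/tmp/toy_ckpts')
if latest:
    ckpt.restore(latest).expect_partial()

for _ in range(100):
    global_step.assign_add(1)
    if int(global_step.value()) % 25 == 0:  # checkpoint every n steps
        manager.save()
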
Example #4
def train_loop(pipeline_config_path,
               model_dir,
               config_override=None,
               train_steps=None,
               use_tpu=False,
               save_final_config=False,
               checkpoint_every_n=1000,
               checkpoint_max_to_keep=7,
               **kwargs):
    """Trains a model using eager + functions.

  This method:
    1. Processes the pipeline configs
    2. (Optionally) saves the as-run config
    3. Builds the model & optimizer
    4. Gets the training input data
    5. Loads a fine-tuning detection or classification checkpoint if requested
    6. Loops over the train data, executing distributed training steps inside
       tf.functions.
    7. Checkpoints the model every `checkpoint_every_n` training steps.
    8. Logs the training metrics as TensorBoard summaries.

  Args:
    pipeline_config_path: A path to a pipeline config file.
    model_dir:
      The directory to save checkpoints and summaries to.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training steps
      is set from the `TrainConfig` proto.
    use_tpu: Boolean, whether training and evaluation should run on TPU.
    save_final_config: Whether to save final config (obtained after applying
      overrides) to `model_dir`.
    checkpoint_every_n:
      Checkpoint every n training steps.
    checkpoint_max_to_keep:
      int, the number of most recent checkpoints to keep in the model directory.
    **kwargs: Additional keyword arguments for configuration override.
  """
    ## Parse the configs
    get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
        'get_configs_from_pipeline_file']
    merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
        'merge_external_params_with_configs']
    create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
        'create_pipeline_proto_from_configs']

    configs = get_configs_from_pipeline_file(pipeline_config_path,
                                             config_override=config_override)
    kwargs.update({
        'train_steps':
        train_steps,
        'use_bfloat16':
        configs['train_config'].use_bfloat16 and use_tpu
    })
    configs = merge_external_params_with_configs(configs,
                                                 None,
                                                 kwargs_dict=kwargs)
    model_config = configs['model']
    train_config = configs['train_config']
    train_input_config = configs['train_input_config']

    unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors
    add_regularization_loss = train_config.add_regularization_loss
    clip_gradients_value = None
    if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

    # update train_steps from config but only when non-zero value is provided
    if train_steps is None and train_config.num_steps != 0:
        train_steps = train_config.num_steps

    if kwargs['use_bfloat16']:
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(
            'mixed_bfloat16')

    if train_config.load_all_detection_checkpoint_vars:
        raise ValueError('train_pb2.load_all_detection_checkpoint_vars '
                         'unsupported in TF2')

    config_util.update_fine_tune_checkpoint_type(train_config)
    fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
    fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

    # Write the as-run pipeline config to disk.
    if save_final_config:
        pipeline_config_final = create_pipeline_proto_from_configs(configs)
        config_util.save_pipeline_config(pipeline_config_final, model_dir)

    # Build the model, optimizer, and training input
    strategy = tf.compat.v2.distribute.get_strategy()
    with strategy.scope():
        detection_model = model_builder.build(model_config=model_config,
                                              is_training=True)

        def train_dataset_fn(input_context):
            """Callable to create train input."""
            # Create the inputs.
            train_input = inputs.train_input(
                train_config=train_config,
                train_input_config=train_input_config,
                model_config=model_config,
                model=detection_model,
                input_context=input_context)
            train_input = train_input.repeat()
            return train_input

        train_input = strategy.experimental_distribute_datasets_from_function(
            train_dataset_fn)

        global_step = tf.Variable(
            0,
            trainable=False,
            dtype=tf.compat.v2.dtypes.int64,
            name='global_step',
            aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
        optimizer, (learning_rate, ) = optimizer_builder.build(
            train_config.optimizer, global_step=global_step)

        if callable(learning_rate):
            learning_rate_fn = learning_rate
        else:
            learning_rate_fn = lambda: learning_rate

    ## Train the model
    # Get the appropriate filepath (temporary or not) based on whether the worker
    # is the chief.
    summary_writer_filepath = get_filepath(strategy,
                                           os.path.join(model_dir, 'train'))
    summary_writer = tf.compat.v2.summary.create_file_writer(
        summary_writer_filepath)

    if use_tpu:
        num_steps_per_iteration = 100
    else:
        # TODO(b/135933080) Explore setting to 100 when GPU performance issues
        # are fixed.
        num_steps_per_iteration = 1

    with summary_writer.as_default():
        with strategy.scope():
            with tf.compat.v2.summary.record_if(
                    lambda: global_step % num_steps_per_iteration == 0):
                # Load a fine-tuning checkpoint.
                if train_config.fine_tune_checkpoint:
                    load_fine_tune_checkpoint(
                        detection_model, train_config.fine_tune_checkpoint,
                        fine_tune_checkpoint_type,
                        fine_tune_checkpoint_version, train_input,
                        unpad_groundtruth_tensors)

                ckpt = tf.compat.v2.train.Checkpoint(step=global_step,
                                                     model=detection_model,
                                                     optimizer=optimizer)

                manager_dir = get_filepath(strategy, model_dir)
                if not strategy.extended.should_checkpoint:
                    checkpoint_max_to_keep = 1
                manager = tf.compat.v2.train.CheckpointManager(
                    ckpt, manager_dir, max_to_keep=checkpoint_max_to_keep)

                # We use the following instead of manager.latest_checkpoint because
                # manager_dir does not point to the model directory when we are running
                # in a worker.
                latest_checkpoint = tf.train.latest_checkpoint(model_dir)
                ckpt.restore(latest_checkpoint)

                def train_step_fn(features, labels):
                    """Single train step."""
                    loss = eager_train_step(
                        detection_model,
                        features,
                        labels,
                        unpad_groundtruth_tensors,
                        optimizer,
                        learning_rate=learning_rate_fn(),
                        add_regularization_loss=add_regularization_loss,
                        clip_gradients_value=clip_gradients_value,
                        global_step=global_step,
                        num_replicas=strategy.num_replicas_in_sync)
                    global_step.assign_add(1)
                    return loss

                def _sample_and_train(strategy, train_step_fn, data_iterator):
                    features, labels = data_iterator.next()
                    per_replica_losses = strategy.run(train_step_fn,
                                                      args=(features, labels))
                    # TODO(anjalisridhar): explore if it is safe to remove the
                    ## num_replicas scaling of the loss and switch this to a ReduceOp.Mean
                    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_losses,
                                           axis=None)

                @tf.function
                def _dist_train_step(data_iterator):
                    """A distributed train step."""

                    if num_steps_per_iteration > 1:
                        for _ in tf.range(num_steps_per_iteration - 1):
                            _sample_and_train(strategy, train_step_fn,
                                              data_iterator)

                    return _sample_and_train(strategy, train_step_fn,
                                             data_iterator)

                train_input_iter = iter(train_input)

                if int(global_step.value()) == 0:
                    manager.save()

                checkpointed_step = int(global_step.value())
                logged_step = global_step.value()

                last_step_time = time.time()
                for _ in range(global_step.value(), train_steps,
                               num_steps_per_iteration):

                    loss = _dist_train_step(train_input_iter)

                    time_taken = time.time() - last_step_time
                    last_step_time = time.time()

                    tf.compat.v2.summary.scalar('steps_per_sec',
                                                num_steps_per_iteration * 1.0 /
                                                time_taken,
                                                step=global_step)

                    if global_step.value() - logged_step >= 100:
                        tf.logging.info(
                            'Step {} per-step time {:.3f}s loss={:.3f}'.format(
                                global_step.value(),
                                time_taken / num_steps_per_iteration, loss))
                        logged_step = global_step.value()

                    if ((int(global_step.value()) - checkpointed_step) >=
                            checkpoint_every_n):
                        manager.save()
                        checkpointed_step = int(global_step.value())

    # Remove the checkpoint directories of the non-chief workers that
    # MultiWorkerMirroredStrategy forces us to save during sync distributed
    # training.
    clean_temporary_directories(strategy, manager_dir)
    clean_temporary_directories(strategy, summary_writer_filepath)
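
The steps_per_sec bookkeeping in these loops is the standard tf.summary file-writer pattern combined with a conditional record_if. A minimal sketch with a hypothetical log directory and a sleep standing in for the actual training step:

import time
import tensorflow as tf

log_dir = '/tmp/train_summaries'  # hypothetical path
global_step = tf.Variable(0, dtype=tf.int64)
writer = tf.summary.create_file_writer(log_dir)

with writer.as_default():
    # Only record every 10th step to keep the event files small.
    with tf.summary.record_if(lambda: global_step % 10 == 0):
        last_step_time = time.time()
        for _ in range(50):
            time.sleep(0.01)  # stand-in for a real training step
            global_step.assign_add(1)
            time_taken = time.time() - last_step_time
            last_step_time = time.time()
            tf.summary.scalar(
                'steps_per_sec', 1.0 / time_taken, step=global_step)
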
Example #5
def train_loop(pipeline_config_path,
               model_dir,
               config_override=None,
               train_steps=None,
               use_tpu=False,
               save_final_config=False,
               checkpoint_every_n=1000,
               checkpoint_max_to_keep=7,
               record_summaries=True,
               performance_summary_exporter=None,
               num_steps_per_iteration=NUM_STEPS_PER_ITERATION,
               **kwargs):
    """Trains a model using eager + functions (checkpoint saving and summary
    writing are commented out in this variant)."""

    get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
        "get_configs_from_pipeline_file"]
    merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
        "merge_external_params_with_configs"]
    create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
        "create_pipeline_proto_from_configs"]
    steps_per_sec_list = []

    configs = get_configs_from_pipeline_file(pipeline_config_path,
                                             config_override=config_override)
    kwargs.update({
        "train_steps":
        train_steps,
        "use_bfloat16":
        configs["train_config"].use_bfloat16 and use_tpu,
    })
    configs = merge_external_params_with_configs(configs,
                                                 None,
                                                 kwargs_dict=kwargs)
    model_config = configs["model"]
    train_config = configs["train_config"]
    train_input_config = configs["train_input_config"]

    unpad_groundtruth_tensors = train_config.unpad_groundtruth_tensors
    add_regularization_loss = train_config.add_regularization_loss
    clip_gradients_value = None
    if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

    if train_steps is None and train_config.num_steps != 0:
        train_steps = train_config.num_steps

    if kwargs["use_bfloat16"]:
        tf.compat.v2.keras.mixed_precision.set_global_policy("mixed_bfloat16")

    if train_config.load_all_detection_checkpoint_vars:
        raise ValueError(
            "train_pb2.load_all_detection_checkpoint_vars unsupported in TF2")

    config_util.update_fine_tune_checkpoint_type(train_config)
    fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
    fine_tune_checkpoint_version = train_config.fine_tune_checkpoint_version

    strategy = tf.compat.v2.distribute.get_strategy()
    with strategy.scope():
        detection_model = MODEL_BUILD_UTIL_MAP["detection_model_fn_base"](
            model_config=model_config, is_training=True)

        def train_dataset_fn(input_context):
            train_input = inputs.train_input(
                train_config=train_config,
                train_input_config=train_input_config,
                model_config=model_config,
                model=detection_model,
                input_context=input_context,
            )
            train_input = train_input.repeat()
            return train_input

        train_input = strategy.experimental_distribute_datasets_from_function(
            train_dataset_fn)

        global_step = tf.Variable(
            0,
            trainable=False,
            dtype=tf.compat.v2.dtypes.int64,
            name="global_step",
            aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA,
        )
        optimizer, (learning_rate, ) = optimizer_builder.build(
            train_config.optimizer, global_step=global_step)

        if train_config.optimizer.use_moving_average:
            _ensure_model_is_built(detection_model, train_input,
                                   unpad_groundtruth_tensors)
            optimizer.shadow_copy(detection_model)

        if callable(learning_rate):
            learning_rate_fn = learning_rate
        else:
            learning_rate_fn = lambda: learning_rate

    summary_writer_filepath = get_filepath(strategy,
                                           os.path.join(model_dir, "train"))
    # summary_writer = tf.compat.v2.summary.create_file_writer(summary_writer_filepath)
    summary_writer = tf2.summary.create_noop_writer()
    with summary_writer.as_default():
        with strategy.scope():
            with tf2.summary.record_if(
                    lambda: global_step % num_steps_per_iteration == 0):
                if train_config.fine_tune_checkpoint:
                    load_fine_tune_checkpoint(
                        detection_model,
                        train_config.fine_tune_checkpoint,
                        fine_tune_checkpoint_type,
                        fine_tune_checkpoint_version,
                        train_config.
                        run_fine_tune_checkpoint_dummy_computation,
                        train_input,
                        unpad_groundtruth_tensors,
                    )

                # ckpt = tf.compat.v2.train.Checkpoint(step=global_step, model=detection_model, optimizer=optimizer)

                # manager_dir = get_filepath(strategy, model_dir)
                # manager = tf.compat.v2.train.CheckpointManager(ckpt, manager_dir, max_to_keep=1)

                # latest_checkpoint = tf.train.latest_checkpoint(model_dir)
                # ckpt.restore(latest_checkpoint)

                def train_step_fn(features, labels):
                    loss = eager_train_step(
                        detection_model,
                        features,
                        labels,
                        unpad_groundtruth_tensors,
                        optimizer,
                        learning_rate=learning_rate_fn(),
                        add_regularization_loss=add_regularization_loss,
                        clip_gradients_value=clip_gradients_value,
                        global_step=global_step,
                        num_replicas=strategy.num_replicas_in_sync,
                    )
                    global_step.assign_add(1)
                    return loss

                def _sample_and_train(strategy, train_step_fn, data_iterator):
                    features, labels = data_iterator.next()
                    if hasattr(tf.distribute.Strategy, "run"):
                        per_replica_losses = strategy.run(train_step_fn,
                                                          args=(features,
                                                                labels))
                    else:
                        per_replica_losses = strategy.experimental_run_v2(
                            train_step_fn, args=(features, labels))
                    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_losses,
                                           axis=None)

                @tf.function
                def _dist_train_step(data_iterator):
                    if num_steps_per_iteration > 1:
                        for _ in tf.range(num_steps_per_iteration - 1):
                            with tf.name_scope(""):
                                _sample_and_train(strategy, train_step_fn,
                                                  data_iterator)

                    return _sample_and_train(strategy, train_step_fn,
                                             data_iterator)

                train_input_iter = iter(train_input)

                checkpointed_step = int(global_step.value())
                logged_step = global_step.value()

                last_step_time = time.time()
                for _ in range(global_step.value(), train_steps,
                               num_steps_per_iteration):
                    loss = _dist_train_step(train_input_iter)
                    time_taken = time.time() - last_step_time
                    last_step_time = time.time()
                    steps_per_sec = num_steps_per_iteration * 1.0 / time_taken
                    tf.compat.v2.summary.scalar("steps_per_sec",
                                                steps_per_sec,
                                                step=global_step)
                    steps_per_sec_list.append(steps_per_sec)
                    if global_step.value() - logged_step >= 100:
                        tf.logging.info(
                            "Step {} per-step time {:.3f}s loss={:.3f}".format(
                                global_step.value(),
                                time_taken / num_steps_per_iteration, loss))
                        logged_step = global_step.value()

                    if (int(global_step.value()) -
                            checkpointed_step) >= checkpoint_every_n:
                        # manager.save()
                        checkpointed_step = int(global_step.value())

    # clean_temporary_directories(strategy, manager_dir)
    clean_temporary_directories(strategy, summary_writer_filepath)
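
The num_steps_per_iteration grouping used throughout these variants (running several training steps inside a single tf.function call, mainly to amortize dispatch overhead on TPU) can be shown in isolation; the body below is a toy counter update rather than the detection train step:

import tensorflow as tf

NUM_STEPS_PER_ITERATION = 100
global_step = tf.Variable(0, dtype=tf.int64)

@tf.function
def run_iteration():
    # All but the last step run inside the traced loop; only the last call's
    # result is returned, mirroring the structure of _dist_train_step above.
    for _ in tf.range(NUM_STEPS_PER_ITERATION - 1):
        global_step.assign_add(1)
    global_step.assign_add(1)
    return global_step.value()

print(run_iteration().numpy())  # 100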