Пример #1
0
def get_optimizer(params):
  """Builds the Keras optimizer described by `params`.

  Args:
    params: dict-like hyperparameters. Reads 'optimizer' ('sgd' or 'adam',
      case-insensitive), 'momentum', 'moving_average_decay', 'strategy',
      'mixed_precision' and, for mixed_float16 only, 'loss_scale'. It is also
      forwarded to `learning_rate_schedule`.

  Returns:
    The SGD or Adam optimizer, wrapped in a tensorflow_addons `MovingAverage`
    when 'moving_average_decay' is truthy, and then in a `LossScaleOptimizer`
    when the precision is 'mixed_float16' and 'loss_scale' is truthy.

  Raises:
    ValueError: if params['optimizer'] is neither 'sgd' nor 'adam'.
  """
  lr = learning_rate_schedule(params)
  momentum = params['momentum']
  optimizer_name = params['optimizer'].lower()
  if optimizer_name == 'sgd':
    logging.info('Use SGD optimizer')
    optimizer = tf.keras.optimizers.SGD(lr, momentum=momentum)
  elif optimizer_name == 'adam':
    logging.info('Use Adam optimizer')
    # The configured momentum is reused as Adam's beta_1.
    optimizer = tf.keras.optimizers.Adam(lr, beta_1=momentum)
  else:
    raise ValueError('optimizers should be adam or sgd')

  ema_decay = params['moving_average_decay']
  if ema_decay:
    # tensorflow_addons is imported only when weight averaging is requested.
    # TODO(tanmingxing): potentially add dynamic_decay for new tfa release.
    from tensorflow_addons import optimizers as tfa_optimizers  # pylint: disable=g-import-not-at-top
    optimizer = tfa_optimizers.MovingAverage(
        optimizer, average_decay=ema_decay, dynamic_decay=True)

  precision = utils.get_precision(params['strategy'], params['mixed_precision'])
  # 'loss_scale' is only looked up when running in mixed_float16.
  needs_loss_scaling = precision == 'mixed_float16' and params['loss_scale']
  if needs_loss_scaling:
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
        optimizer, initial_scale=params['loss_scale'])
  return optimizer
Пример #2
0
    def __init__(self,
                 model_name: Text,
                 ckpt_path: Text = None,
                 batch_size: int = 1,
                 only_network: bool = False,
                 model_params: Dict[Text, Any] = None):
        """Initialize the inference driver.

        Besides storing the arguments, this builds the parameter dict for
        `model_name` and installs the global Keras mixed-precision policy
        derived from it.

        Args:
          model_name: target model name, such as efficientdet-d0.
          ckpt_path: checkpoint path, such as /tmp/efficientdet-d0/.
          batch_size: batch size for inference.
          only_network: only use the network without pre/post processing.
          model_params: model parameters for overriding the config.
        """
        super().__init__()
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.batch_size = batch_size
        self.only_network = only_network

        # Start from the default detection config, then apply user overrides.
        self.params = hparams_config.get_detection_config(model_name).as_dict()
        if model_params:
            self.params.update(model_params)
        # Always forced off here, even if model_params asked otherwise.
        self.params['is_training_bn'] = False
        self.label_map = self.params.get('label_map', None)

        self._model = None

        mixed_precision = self.params.get('mixed_precision', None)
        precision = utils.get_precision(self.params.get('strategy', None),
                                        mixed_precision)
        # Prefer the stable mixed-precision API; the experimental one is
        # deprecated and is only used as a fallback for TensorFlow releases
        # that lack set_global_policy.
        keras_mp = tf.keras.mixed_precision
        if hasattr(keras_mp, 'set_global_policy'):
            keras_mp.set_global_policy(keras_mp.Policy(precision))
        else:
            keras_mp.experimental.set_policy(
                keras_mp.experimental.Policy(precision))
Пример #3
0
def build_model(model_name: Text, inputs: tf.Tensor, **kwargs):
    """Build the detection model and return its per-level outputs.

    Args:
      model_name: the name of the model.
      inputs: an image tensor or a numpy array.
      **kwargs: extra parameters for the model builder. 'mixed_precision',
        'strategy' and 'use_keras_model' are also read here.

    Returns:
      (cls_outputs, box_outputs): the outputs for class and box predictions.
      Each is a dictionary with key as feature level and value as predictions.
      When 'mixed_precision' is truthy the values are cast to float32.
    """
    mixed_precision = kwargs.get('mixed_precision', None)
    precision = utils.get_precision(kwargs.get('strategy', None),
                                    mixed_precision)

    if kwargs.get('use_keras_model', None):

        def model_arch(feats, model_name=None, **kwargs):
            """Run the Keras EfficientDetNet and key its outputs by level."""
            config = hparams_config.get_efficientdet_config(model_name)
            config.override(kwargs)
            net = efficientdet_keras.EfficientDetNet(config=config)
            cls_out_list, box_out_list = net(feats, training=False)
            # The network returns one entry per level, ordered from min_level
            # to max_level; re-key the lists by level number.
            num_levels = config.max_level - config.min_level + 1
            assert len(cls_out_list) == num_levels
            assert len(box_out_list) == num_levels
            levels = range(config.min_level, config.max_level + 1)
            cls_by_level = {
                level: cls_out_list[level - config.min_level]
                for level in levels
            }
            box_by_level = {
                level: box_out_list[level - config.min_level]
                for level in levels
            }
            return cls_by_level, box_by_level

    else:
        model_arch = det_model_fn.get_model_arch(model_name)

    cls_outputs, box_outputs = utils.build_model_with_precision(
        precision, model_arch, inputs, model_name, **kwargs)

    if not mixed_precision:
        return cls_outputs, box_outputs

    def _as_float32(outputs):
        """Cast every value of a level->tensor dict to float32."""
        return {
            level: tf.cast(tensor, tf.float32)
            for level, tensor in outputs.items()
        }

    # Post-processing has multiple places with hard-coded float32.
    # TODO(tanmingxing): Remove them once post-process can adapt to dtypes.
    return _as_float32(cls_outputs), _as_float32(box_outputs)
Пример #4
0
  def __init__(self,
               model_name: str,
               uri: str,
               hparams: str = '',
               model_dir: Optional[str] = None,
               epochs: int = 50,
               batch_size: int = 64,
               steps_per_execution: int = 1,
               moving_average_decay: float = 0,
               var_freeze_expr: str = '(efficientnet|fpn_cells|resample_p6)',
               tflite_max_detections: int = 25,
               strategy: Optional[str] = None,
               tpu: Optional[str] = None,
               gcp_project: Optional[str] = None,
               tpu_zone: Optional[str] = None,
               use_xla: bool = False,
               profile: bool = False,
               debug: bool = False,
               tf_random_seed: int = 111111,
               verbose: int = 0) -> None:
    """Initialize an instance with model parameters.

    Besides storing the configuration, this changes global TensorFlow state:
    it may enable XLA JIT and GPU memory growth, may switch on the debug
    settings, creates the distribution strategy and installs the global Keras
    mixed-precision policy.

    Args:
      model_name: Model name.
      uri: TF-Hub path/url to EfficientDet module.
      hparams: Hyperparameters used to overwrite default configuration. Can be
        1) Dict, contains parameter names and values; 2) String, Comma separated
        k=v pairs of hyperparameters; 3) String, yaml filename which's a module
        containing attributes to use as hyperparameters.
      model_dir: The location to save the model checkpoint files. If None, a
        fresh temporary directory is created.
      epochs: Default training epochs. A falsy value keeps the config's own
        `num_epochs`.
      batch_size: Training & Evaluation batch size.
      steps_per_execution: Number of steps per training execution.
      moving_average_decay: Float. The decay to use for maintaining moving
        averages of the trained parameters.
      var_freeze_expr: Expression to freeze variables.
      tflite_max_detections: The max number of output detections in the TFLite
        model.
      strategy:  A string specifying which distribution strategy to use.
        Accepted values are 'tpu', 'gpus', None. 'tpu' means to use
        TPUStrategy. 'gpus' means to use MirroredStrategy for multi-gpus. Any
        other value uses OneDeviceStrategy on GPU:0 if a GPU is present,
        otherwise on CPU:0.
      tpu: The Cloud TPU to use for training. This should be either the name
        used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470
        url.
      gcp_project: Project name for the Cloud TPU-enabled project. If not
        specified, we will attempt to automatically detect the GCE project from
        metadata.
      tpu_zone: GCE zone where the Cloud TPU is located in. If not specified, we
        will attempt to automatically detect the zone from metadata.
      use_xla: Use XLA even if strategy is not tpu. If strategy is tpu, always
        use XLA, and this flag has no effect.
      profile: Enable profile mode.
      debug: Enable debug mode.
      tf_random_seed: Fixed random seed for deterministic execution across runs
        for debugging.
      verbose: verbosity mode for `tf.keras.callbacks.ModelCheckpoint`, 0 or 1.
    """
    self.model_name = model_name
    self.uri = uri
    self.batch_size = batch_size

    # Default config for the model, then user hparams, then the explicit
    # constructor arguments (which therefore win over hparams).
    config = hparams_config.get_efficientdet_config(model_name)
    config.override(hparams)
    config.image_size = utils.parse_image_size(config.image_size)
    config.var_freeze_expr = var_freeze_expr
    config.moving_average_decay = moving_average_decay
    config.tflite_max_detections = tflite_max_detections
    if epochs:
      config.num_epochs = epochs

    if use_xla and strategy != 'tpu':
      tf.config.optimizer.set_jit(True)
      for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    if debug:
      tf.config.experimental_run_functions_eagerly(True)
      tf.debugging.set_log_device_placement(True)
      os.environ['TF_DETERMINISTIC_OPS'] = '1'
      tf.random.set_seed(tf_random_seed)
      logging.set_verbosity(logging.DEBUG)

    if strategy == 'tpu':
      tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          tpu, zone=tpu_zone, project=gcp_project)
      tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
      tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
      ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
      logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
      tf.config.set_soft_device_placement(True)
    elif strategy == 'gpus':
      ds_strategy = tf.distribute.MirroredStrategy()
      logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
    else:
      # Single-device fallback: first GPU if one is visible, else the CPU.
      device = ('device:GPU:0'
                if tf.config.list_physical_devices('GPU') else 'device:CPU:0')
      ds_strategy = tf.distribute.OneDeviceStrategy(device)

    self.ds_strategy = ds_strategy

    if model_dir is None:
      model_dir = tempfile.mkdtemp()
    params = dict(
        profile=profile,
        model_name=model_name,
        steps_per_execution=steps_per_execution,
        model_dir=model_dir,
        strategy=strategy,
        batch_size=batch_size,
        tf_random_seed=tf_random_seed,
        debug=debug,
        verbose=verbose)
    config.override(params, True)
    self.config = config

    # Set the mixed precision policy through the Keras API. Prefer the stable
    # API; the experimental one is deprecated and is only used as a fallback
    # for TensorFlow releases that lack set_global_policy.
    precision = utils.get_precision(config.strategy, config.mixed_precision)
    keras_mp = tf.keras.mixed_precision
    if hasattr(keras_mp, 'set_global_policy'):
      keras_mp.set_global_policy(keras_mp.Policy(precision))
    else:
      keras_mp.experimental.set_policy(keras_mp.experimental.Policy(precision))
Пример #5
0
def main(_):
    """Train and/or evaluate an EfficientDet Keras model as set by FLAGS.

    If 'train' is in FLAGS.mode the model is fitted (with validation data when
    'eval' is also in the mode). Otherwise the function runs continuous
    evaluation over the checkpoints that appear in FLAGS.model_dir and stops
    once the final epoch, or a checkpoint without a usable epoch number, has
    been evaluated.

    Args:
      _: unused positional argument supplied by the app runner.

    Raises:
      ValueError: if the file pattern flag required for the requested dataset
        is empty.
    """
    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
        config.num_epochs = FLAGS.num_epochs

    # Parse image size in case it is in string format.
    config.image_size = utils.parse_image_size(config.image_size)

    if FLAGS.use_xla and FLAGS.strategy != 'tpu':
        tf.config.optimizer.set_jit(True)
        for gpu in tf.config.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)

    if FLAGS.debug:
        tf.config.run_functions_eagerly(True)
        tf.debugging.set_log_device_placement(True)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        tf.random.set_seed(FLAGS.tf_random_seed)
        logging.set_verbosity(logging.DEBUG)

    # Pick the distribution strategy: TPU, multi-GPU mirrored, or one device.
    if FLAGS.strategy == 'tpu':
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
        logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
    elif FLAGS.strategy == 'gpus':
        ds_strategy = tf.distribute.MirroredStrategy()
        logging.info('All devices: %s', tf.config.list_physical_devices('GPU'))
    else:
        if tf.config.list_physical_devices('GPU'):
            ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
        else:
            ds_strategy = tf.distribute.OneDeviceStrategy('device:CPU:0')

    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.batch_size
    params = dict(profile=FLAGS.profile,
                  model_name=FLAGS.model_name,
                  steps_per_execution=FLAGS.steps_per_execution,
                  model_dir=FLAGS.model_dir,
                  steps_per_epoch=steps_per_epoch,
                  strategy=FLAGS.strategy,
                  batch_size=FLAGS.batch_size,
                  tf_random_seed=FLAGS.tf_random_seed,
                  debug=FLAGS.debug,
                  val_json_file=FLAGS.val_json_file,
                  eval_samples=FLAGS.eval_samples,
                  num_shards=ds_strategy.num_replicas_in_sync)
    config.override(params, True)
    # set mixed precision policy by keras api.
    precision = utils.get_precision(config.strategy, config.mixed_precision)
    policy = tf.keras.mixed_precision.Policy(precision)
    tf.keras.mixed_precision.set_global_policy(policy)

    def get_dataset(is_training, config):
        """Build the train or validation dataset from the matching flag."""
        file_pattern = (FLAGS.train_file_pattern
                        if is_training else FLAGS.val_file_pattern)
        if not file_pattern:
            raise ValueError('No matching files.')

        return dataloader.InputReader(
            file_pattern,
            is_training=is_training,
            use_fake_data=FLAGS.use_fake_data,
            max_instances_per_image=config.max_instances_per_image,
            debug=FLAGS.debug)(config.as_dict())

    with ds_strategy.scope():
        if config.model_optimizations:
            tfmot.set_config(config.model_optimizations.as_dict())
        if FLAGS.hub_module_url:
            model = train_lib.EfficientDetNetTrainHub(
                config=config, hub_module_url=FLAGS.hub_module_url)
        else:
            model = train_lib.EfficientDetNetTrain(config=config)
        model = setup_model(model, config)
        # A pretrained checkpoint is only restored for non-hub models.
        if FLAGS.pretrained_ckpt and not FLAGS.hub_module_url:
            ckpt_path = tf.train.latest_checkpoint(FLAGS.pretrained_ckpt)
            util_keras.restore_ckpt(model, ckpt_path,
                                    config.moving_average_decay)
        init_experimental(config)
        if 'train' in FLAGS.mode:
            val_dataset = get_dataset(False,
                                      config) if 'eval' in FLAGS.mode else None
            model.fit(
                get_dataset(True, config),
                epochs=config.num_epochs,
                steps_per_epoch=steps_per_epoch,
                callbacks=train_lib.get_callbacks(config.as_dict(),
                                                  val_dataset),
                validation_data=val_dataset,
                validation_steps=(FLAGS.eval_samples // FLAGS.batch_size))
        else:
            # Continuous eval.
            for ckpt in tf.train.checkpoints_iterator(FLAGS.model_dir,
                                                      min_interval_secs=180):
                logging.info('Starting to evaluate.')
                # The epoch is the text after the first '-' in the checkpoint
                # basename. A name without '-' or with a non-numeric suffix
                # is treated as epoch 0, which ends the loop after this
                # evaluation instead of crashing the job.
                try:
                    current_epoch = int(os.path.basename(ckpt).split('-')[1])
                except (IndexError, ValueError):
                    current_epoch = 0

                val_dataset = get_dataset(False, config)
                logging.info('start loading model.')
                # NOTE(review): this loads the newest checkpoint in model_dir,
                # which may be more recent than `ckpt` -- confirm intended.
                model.load_weights(tf.train.latest_checkpoint(FLAGS.model_dir))
                logging.info('finish loading model.')
                coco_eval = train_lib.COCOCallback(val_dataset, 1)
                coco_eval.set_model(model)
                eval_results = coco_eval.on_epoch_end(current_epoch)
                logging.info('eval results for %s: %s', ckpt, eval_results)

                try:
                    utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
                except tf.errors.NotFoundError:
                    # Checkpoint might already be deleted by the time eval
                    # finished.
                    logging.info('Checkpoint %s no longer exists, skipping.',
                                 ckpt)

                # Terminate eval job when final checkpoint is reached.
                if current_epoch >= config.num_epochs or not current_epoch:
                    logging.info('Eval epoch %d / %d', current_epoch,
                                 config.num_epochs)
                    break
Пример #6
0
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
  """Model definition entry.

  Builds the forward pass, the losses, the train op (TRAIN mode only), the
  evaluation metrics (EVAL mode only) and the checkpoint-restoring scaffold,
  and packs them into an estimator spec.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN and EVAL.
    params: the dictionary defines hyperparameters of model. Note that it is
      modified here: 'is_training_bn' is always set, and in EVAL mode
      params['nms_configs']['max_nms_inputs'] is set as well.
    model: the model outputs class logits and box regression outputs. Only
      used when params['use_keras_model'] is falsy.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule.

  Returns:
    tpu_spec: the TPUEstimatorSpec to run training, evaluation, or prediction
    when params['strategy'] is 'tpu'; otherwise a plain EstimatorSpec.

  Raises:
    RuntimeError: if both ckpt and backbone_ckpt are set (checked in TRAIN
      mode only).
    ValueError: in TRAIN mode, if params['optimizer'] is not 'sgd' or 'adam'.
  """
  is_tpu = params['strategy'] == 'tpu'
  if params['img_summary_steps']:
    utils.image('input_image', features, is_tpu)
  training_hooks = []
  # Batch norm runs in training mode only while training.
  params['is_training_bn'] = (mode == tf.estimator.ModeKeys.TRAIN)

  if params['use_keras_model']:

    def model_fn(inputs):
      """Runs the Keras EfficientDetNet and keys its outputs by level."""
      model = efficientdet_keras.EfficientDetNet(
          config=hparams_config.Config(params))
      cls_out_list, box_out_list = model(inputs, params['is_training_bn'])
      cls_outputs, box_outputs = {}, {}
      # Entry i of each list is taken as the output of level min_level + i.
      for i in range(params['min_level'], params['max_level'] + 1):
        cls_outputs[i] = cls_out_list[i - params['min_level']]
        box_outputs[i] = box_out_list[i - params['min_level']]
      return cls_outputs, box_outputs
  else:
    model_fn = functools.partial(model, config=hparams_config.Config(params))

  precision = utils.get_precision(params['strategy'], params['mixed_precision'])
  cls_outputs, box_outputs = utils.build_model_with_precision(
      precision, model_fn, features, params['is_training_bn'])

  # Everything downstream (losses, post-processing) works on float32 outputs,
  # whatever precision the network itself ran in.
  levels = cls_outputs.keys()
  for level in levels:
    cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
    box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

  # Set up training loss and learning rate.
  update_learning_rate_schedule_parameters(params)
  global_step = tf.train.get_or_create_global_step()
  learning_rate = learning_rate_schedule(params, global_step)

  # cls_loss and box_loss are for logging. only total_loss is optimized.
  det_loss, cls_loss, box_loss = detection_loss(
      cls_outputs, box_outputs, labels, params)
  reg_l2loss = reg_l2_loss(params['weight_decay'])
  total_loss = det_loss + reg_l2loss

  if mode == tf.estimator.ModeKeys.TRAIN:
    utils.scalar('lrn_rate', learning_rate, is_tpu)
    utils.scalar('trainloss/cls_loss', cls_loss, is_tpu)
    utils.scalar('trainloss/box_loss', box_loss, is_tpu)
    utils.scalar('trainloss/det_loss', det_loss, is_tpu)
    utils.scalar('trainloss/reg_l2_loss', reg_l2loss, is_tpu)
    utils.scalar('trainloss/loss', total_loss, is_tpu)
    # Training progress expressed in (fractional) epochs.
    train_epochs = tf.cast(global_step, tf.float32) / params['steps_per_epoch']
    utils.scalar('train_epochs', train_epochs, is_tpu)

  # The EMA object is created in every mode: TRAIN uses it to update the
  # averages, EVAL uses it to map variables to their averaged values.
  moving_average_decay = params['moving_average_decay']
  if moving_average_decay:
    ema = tf.train.ExponentialMovingAverage(
        decay=moving_average_decay, num_updates=global_step)
    ema_vars = utils.get_ema_vars()

  if mode == tf.estimator.ModeKeys.TRAIN:
    if params['optimizer'].lower() == 'sgd':
      optimizer = tf.train.MomentumOptimizer(
          learning_rate, momentum=params['momentum'])
    elif params['optimizer'].lower() == 'adam':
      optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise ValueError('optimizers should be adam or sgd')

    if is_tpu:
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    # Batch norm requires update_ops to be added as a train_op dependency.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    var_list = tf.trainable_variables()
    if variable_filter_fn:
      var_list = variable_filter_fn(var_list)

    if params.get('clip_gradients_norm', None):
      logging.info('clip gradients norm by %f', params['clip_gradients_norm'])
      grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
      with tf.name_scope('clip'):
        grads = [gv[0] for gv in grads_and_vars]
        tvars = [gv[1] for gv in grads_and_vars]
        # First clip each variable's norm, then clip global norm.
        clip_norm = abs(params['clip_gradients_norm'])
        clipped_grads = [
            tf.clip_by_norm(g, clip_norm) if g is not None else None
            for g in grads
        ]
        clipped_grads, _ = tf.clip_by_global_norm(clipped_grads, clip_norm)
        utils.scalar('gradient_norm', tf.linalg.global_norm(clipped_grads),
                     is_tpu)
        grads_and_vars = list(zip(clipped_grads, tvars))

      with tf.control_dependencies(update_ops):
        train_op = optimizer.apply_gradients(grads_and_vars, global_step)
    else:
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(
            total_loss, global_step, var_list=var_list)

    if moving_average_decay:
      # Update the moving averages only after the optimizer step has run.
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(**kwargs):
      """Returns a dictionary that has the evaluation metrics."""
      if params['nms_configs'].get('pyfunc', True):
        # NumPy NMS path: run per-class NMS image by image via numpy_function.
        detections_bs = []
        nms_configs = params['nms_configs']
        for index in range(kwargs['boxes'].shape[0]):
          detections = tf.numpy_function(
              functools.partial(nms_np.per_class_nms, nms_configs=nms_configs),
              [
                  kwargs['boxes'][index],
                  kwargs['scores'][index],
                  kwargs['classes'][index],
                  tf.slice(kwargs['image_ids'], [index], [1]),
                  tf.slice(kwargs['image_scales'], [index], [1]),
                  params['num_classes'],
                  nms_configs['max_output_size'],
              ], tf.float32)
          detections_bs.append(detections)
        detections_bs = postprocess.transform_detections(
            tf.stack(detections_bs))
      else:
        # These two branches should be equivalent, but currently they are not.
        # TODO(tanmingxing): enable the non_pyfun path after bug fix.
        nms_boxes, nms_scores, nms_classes, _ = postprocess.per_class_nms(
            params, kwargs['boxes'], kwargs['scores'], kwargs['classes'],
            kwargs['image_scales'])
        img_ids = tf.cast(
            tf.expand_dims(kwargs['image_ids'], -1), nms_scores.dtype)
        # One row per detection: image id, two box coordinates, two coordinate
        # differences, score, class.
        # NOTE(review): presumably boxes are [ymin, xmin, ymax, xmax], making
        # the middle four columns x, y, width, height -- confirm against
        # postprocess.per_class_nms.
        detections_bs = [
            img_ids * tf.ones_like(nms_scores),
            nms_boxes[:, :, 1],
            nms_boxes[:, :, 0],
            nms_boxes[:, :, 3] - nms_boxes[:, :, 1],
            nms_boxes[:, :, 2] - nms_boxes[:, :, 0],
            nms_scores,
            nms_classes,
        ]
        detections_bs = tf.stack(detections_bs, axis=-1, name='detnections')

      if params.get('testdev_dir', None):
        # Test-dev run: a placeholder tensor is passed instead of groundtruth.
        logging.info('Eval testdev_dir %s', params['testdev_dir'])
        eval_metric = coco_metric.EvaluationMetric(
            testdev_dir=params['testdev_dir'])
        coco_metrics = eval_metric.estimator_metric_fn(detections_bs,
                                                       tf.zeros([1]))
      else:
        logging.info('Eval val with groudtruths %s.', params['val_json_file'])
        eval_metric = coco_metric.EvaluationMetric(
            filename=params['val_json_file'], label_map=params['label_map'])
        coco_metrics = eval_metric.estimator_metric_fn(
            detections_bs, kwargs['groundtruth_data'])

      # Add metrics to output.
      cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
      box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
      output_metrics = {
          'cls_loss': cls_loss,
          'box_loss': box_loss,
      }
      output_metrics.update(coco_metrics)
      return output_metrics

    # Repeat each scalar loss into a [batch_size, 1] tensor.
    # NOTE(review): presumably because metric_fn inputs need a leading batch
    # dimension on TPU -- confirm.
    cls_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(cls_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    box_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(box_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])

    cls_outputs = postprocess.to_list(cls_outputs)
    box_outputs = postprocess.to_list(box_outputs)
    params['nms_configs']['max_nms_inputs'] = anchors.MAX_DETECTION_POINTS
    boxes, scores, classes = postprocess.pre_nms(params, cls_outputs,
                                                 box_outputs)
    metric_fn_inputs = {
        'cls_loss_repeat': cls_loss_repeat,
        'box_loss_repeat': box_loss_repeat,
        'image_ids': labels['source_ids'],
        'groundtruth_data': labels['groundtruth_data'],
        'image_scales': labels['image_scales'],
        'boxes': boxes,
        'scores': scores,
        'classes': classes,
    }
    eval_metrics = (metric_fn, metric_fn_inputs)

  # A full-model checkpoint takes precedence over a backbone checkpoint here;
  # having both set is rejected below (TRAIN mode only).
  checkpoint = params.get('ckpt') or params.get('backbone_ckpt')

  if checkpoint and mode == tf.estimator.ModeKeys.TRAIN:
    # Initialize the model from an EfficientDet or backbone checkpoint.
    if params.get('ckpt') and params.get('backbone_ckpt'):
      raise RuntimeError(
          '--backbone_ckpt and --checkpoint are mutually exclusive')

    if params.get('backbone_ckpt'):
      var_scope = params['backbone_name'] + '/'
      if params['ckpt_var_scope'] is None:
        # Use backbone name as default checkpoint scope.
        ckpt_scope = params['backbone_name'] + '/'
      else:
        ckpt_scope = params['ckpt_var_scope'] + '/'
    else:
      # Load every var in the given checkpoint
      var_scope = ckpt_scope = '/'

    def scaffold_fn():
      """Loads pretrained model through scaffold function."""
      logging.info('restore variables from %s', checkpoint)

      var_map = utils.get_ckpt_var_map(
          ckpt_path=checkpoint,
          ckpt_scope=ckpt_scope,
          var_scope=var_scope,
          skip_mismatch=params['skip_mismatch'])

      tf.train.init_from_checkpoint(checkpoint, var_map)
      return tf.train.Scaffold()
  elif mode == tf.estimator.ModeKeys.EVAL and moving_average_decay:

    def scaffold_fn():
      """Load moving average variables for eval."""
      logging.info('Load EMA vars with ema_decay=%f', moving_average_decay)
      restore_vars_dict = ema.variables_to_restore(ema_vars)
      saver = tf.train.Saver(restore_vars_dict)
      return tf.train.Scaffold(saver=saver)
  else:
    scaffold_fn = None

  if is_tpu:
    # On TPU the metric and scaffold functions are handed over uncalled.
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        host_call=utils.get_tpu_host_call(global_step, params),
        scaffold_fn=scaffold_fn,
        training_hooks=training_hooks)
  else:
    # Profile every 1K steps.
    if params.get('profile', False):
      profile_hook = tf.estimator.ProfilerHook(
          save_steps=1000, output_dir=params['model_dir'], show_memory=True)
      training_hooks.append(profile_hook)

      # Report memory allocation if OOM; it will slow down the running.
      class OomReportingHook(tf.estimator.SessionRunHook):
        """Session hook that requests tensor allocation reports on OOM."""

        def before_run(self, run_context):
          return tf.estimator.SessionRunArgs(
              fetches=[],
              options=tf.RunOptions(report_tensor_allocations_upon_oom=True))

      training_hooks.append(OomReportingHook())

    logging_hook = tf.estimator.LoggingTensorHook(
        {
            'step': global_step,
            'det_loss': det_loss,
            'cls_loss': cls_loss,
            'box_loss': box_loss,
        },
        every_n_iter=params.get('iterations_per_loop', 100),
    )
    training_hooks.append(logging_hook)

    # Off TPU, the metric function and scaffold function are called right here.
    eval_metric_ops = (
        eval_metrics[0](**eval_metrics[1]) if eval_metrics else None)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        scaffold=scaffold_fn() if scaffold_fn else None,
        training_hooks=training_hooks)